In [4]:
import segmentation_models_pytorch as smp
from dataloader import Angioectasias
import torch
import torchvision
import torch.nn as nn
from torch.utils import data
from torch import optim
from tqdm import tqdm
from models import Models

In [6]:
M = Models()
model = M.PSP(img_ch=3, output_ch=1)
print(model)

PSPNet(
  (encoder): SENetEncoder(
    (layer0): Sequential(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (pool): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
    )
    (layer1): Sequential(
      (0): SEResNeXtBottleneck(
        (conv1): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (re

In [4]:
train_img = Angioectasias('ampulla-of-vater', mode='train')
val_img = Angioectasias('ampulla-of-vater', mode='val')

train_queue = data.DataLoader(train_img, batch_size=1,
                            drop_last=False, shuffle=True)
val_queue = data.DataLoader(val_img, batch_size=1, shuffle=True)

['ampulla1.png', 'ampulla2.png', 'ampulla3.png', 'ampulla4.png', 'ampulla5.png', 'ampulla6.png', 'ampulla7.png', 'ampulla8.png', 'ampulla9.png', 'ampulla10.png', 'ampulla11.png', 'ampulla12.png', 'ampulla13.png']
Mean: [0.48880472 0.30032677 0.23514914], Std: [0.36276894 0.25298407 0.19216456]
['ampulla14.png', 'ampulla15.png', 'ampulla16.png', 'ampulla17.png', 'ampulla18.png', 'ampulla19.png']
Mean: [0.55473959 0.34673989 0.26728009], Std: [0.41426158 0.2971396  0.22648759]


In [None]:
criterion = nn.BCELoss() #nn.BCEWithLogitsLoss()
criterion = criterion.to('cpu')
model_optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.01)

In [None]:
for epoch in range(2):
    model.train()
    tbar = tqdm(train_queue)
    for step, (input, target) in enumerate(tbar):
        input = input.to(device='cpu', dtype=torch.float32)
        target = target.to(device='cpu', dtype=torch.float32)

        predicts = model(input)
        predicts_prob = torch.sigmoid(predicts)
        loss = criterion(predicts_prob, target)
        print('loss')
        model_optimizer.zero_grad()
        loss.backward()
        model_optimizer.step()

In [18]:
loss = smp.utils.losses.DiceLoss()
metrics = [
    smp.utils.metrics.IoU(threshold=0.5),
]

optimizer = torch.optim.Adam([ 
    dict(params=model.parameters(), lr=0.0001),
])

In [19]:
train_epoch = smp.utils.train.TrainEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    optimizer=optimizer,
    device='cpu',
    verbose=True,
)

valid_epoch = smp.utils.train.ValidEpoch(
    model, 
    loss=loss, 
    metrics=metrics, 
    device='cpu',
    verbose=True,
)

In [22]:
max_score = 0

for i in range(0, 2):
    
    print('\nEpoch: {}'.format(i))
    train_logs = train_epoch.run(train_loader)
    valid_logs = valid_epoch.run(valid_loader)
    
    # do something (save model, change lr, etc.)
    if max_score < valid_logs['iou_score']:
        max_score = valid_logs['iou_score']
        torch.save(model, './best_model.pth')
        print('Model saved!')


Epoch: 0


train:   0%|          | 0/7 [00:00<?, ?it/s][A[A

train:   0%|          | 0/7 [00:11<?, ?it/s, dice_loss - -0.4182, iou_score - 3.42e-05][A[A

train:  14%|█▍        | 1/7 [00:11<01:09, 11.52s/it, dice_loss - -0.4182, iou_score - 3.42e-05][A[A

train:  14%|█▍        | 1/7 [00:27<01:09, 11.52s/it, dice_loss - -0.5517, iou_score - 0.001579][A[A

train:  29%|██▊       | 2/7 [00:31<01:18, 15.73s/it, dice_loss - -0.5517, iou_score - 0.001579]


KeyboardInterrupt: 