In [1]:
# autoreload
%load_ext autoreload
%autoreload 2

In [2]:
import timm
import torch
import torchvision
import torchvision.transforms as transforms
from datasets.TrainDataset import TrainDataset

In [3]:
# define data transformation
transform = transforms.Compose([
    transforms.Resize((400, 400)), # crashes with 400 x 400
    transforms.ToTensor(), 
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

In [4]:
# paths to image and ground truth folders
image_folder = 'datasets/train/images/'
gt_folder = 'datasets/train/groundtruth/'

In [5]:
# create an instance of custom dataset class
dataset = TrainDataset(image_folder, gt_folder, transform=transform)

In [6]:
# seed for reproducibility
torch.manual_seed(0)
# split the dataset into training and validation sets
train_set, val_set = torch.utils.data.random_split(dataset, [int(0.8*len(dataset)), int(0.2*len(dataset))])

In [7]:
# define batch size
batch_size = 4
# create data loaders
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=True)

In [8]:
# define model
# model = timm.create_model('unet', pretrained=False)
model = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana', pretrained=False, scale=0.5)
# model = torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3, out_channels=1, init_features=32, pretrained=False)

Using cache found in /home/nadezhda/.cache/torch/hub/milesial_Pytorch-UNet_master


In [9]:
# define loss function
criterion = torch.nn.BCEWithLogitsLoss()

In [10]:
# define optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [11]:
# number of epochs to train the model
n_epochs = 10
# device to use for training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [12]:
# training loop
for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # take one channel of the output
        outputs = outputs[:, [0], :, :]
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print('Epoch: %d | Loss: %.4f' %(epoch, running_loss/len(train_loader)))

Epoch: 0 | Loss: 0.5123
Epoch: 1 | Loss: 0.3902
Epoch: 2 | Loss: 0.3574
Epoch: 3 | Loss: 0.3254
Epoch: 4 | Loss: 0.3196
Epoch: 5 | Loss: 0.2881
Epoch: 6 | Loss: 0.2719
Epoch: 7 | Loss: 0.2701
Epoch: 8 | Loss: 0.2596
Epoch: 9 | Loss: 0.2344


In [13]:
# validation loop
model.eval()
running_loss = 0.0
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        # take one channel of the output
        outputs = outputs[:, [0], :, :]
        loss = criterion(outputs, labels.float())
        running_loss += loss.item()
    # print loss
    print('Validation loss: %.4f' %(running_loss/len(val_loader)))

Validation loss: 0.2709


In [14]:
# save the model
torch.save(model.state_dict(), 'models/unet-v1.pt')