In [None]:
# autoreload: re-import edited project modules (e.g. the local `datasets` package)
# before each cell executes, so changes to .py files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

# Run relative to the repository root so relative paths such as "datasets/..."
# resolve. If the current directory is not the repo root, step up one level
# (assumes the notebook lives one directory below the root).
# os.path.basename is used rather than splitting on "/" so the check also works
# with Windows path separators; otherwise it never matches there and every
# re-run of this cell climbs one more directory.
if os.path.basename(os.getcwd()) != "Road-Segmentation-ML":
    os.chdir("..")
print("CWD:", os.getcwd())

In [None]:
import torch
import torchvision.transforms as transforms
from datasets.BaseDataset import BaseDataset
from datasets.TransformDataset import TransformDataset

In [None]:
# Training data locations, relative to the current working directory:
# the input images and their matching ground-truth masks.
image_folder, gt_folder = (
    "datasets/train/images/",
    "datasets/train/groundtruth/",
)

In [None]:
# Base (untransformed) dataset built from the image and ground-truth folders.
dataset = BaseDataset(image_folder, gt_folder)
# Fail fast if nothing was found (wrong working directory or path). An empty
# dataset would otherwise only surface later as a less obvious error from the
# split or the DataLoader.
assert len(dataset) > 0, f"No samples found under {image_folder!r} / {gt_folder!r}"

In [None]:
# Seed the global torch RNG so the split below (and any later torch-RNG
# consumers) is reproducible across runs.
torch.manual_seed(0)
# Split 80/20 into training and validation subsets. The validation size is the
# remainder, so the two lengths always sum to len(dataset). random_split raises
# a ValueError otherwise, e.g. int(0.8 * 101) + int(0.2 * 101) == 100.
n_train = int(0.8 * len(dataset))
n_val = len(dataset) - n_train
train_set, val_set = torch.utils.data.random_split(dataset, [n_train, n_val])

In [None]:
# Preprocessing shared by images and ground-truth masks: convert to tensors.
# NOTE: a transforms.Resize((400, 400)) step was previously tried here and
# crashed (per the original note), so inputs are left at their native size.
transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# Wrap both subsets with TransformDataset, using the same `transform` for the
# images and for the ground truth.
train_set, val_set = (
    TransformDataset(subset, image_transform=transform, gt_transform=transform)
    for subset in (train_set, val_set)
)

In [None]:
# Mini-batch size shared by both loaders.
batch_size = 2
# Training data is reshuffled every epoch. Validation data is not shuffled:
# evaluation needs no randomisation, and a fixed order keeps the batch
# composition (and thus the per-batch losses) reproducible.
train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=batch_size, shuffle=True
)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False)

In [None]:
# Model: U-Net from the "unet_carvana" torch.hub entry point of
# milesial/Pytorch-UNet, without pretrained weights (pretrained=False), scale=0.5.
# NOTE(review): the validation loop keeps only output channel 0, which suggests
# this model emits more than one channel; confirm against the repo's hubconf.
# Alternatives kept for reference:
#   timm.create_model('unet', pretrained=False)
#   torch.hub.load('mateuszbuda/brain-segmentation-pytorch', 'unet', in_channels=3,
#                  out_channels=1, init_features=32, pretrained=False)
hub_kwargs = dict(pretrained=False, scale=0.5)
model = torch.hub.load("milesial/Pytorch-UNet", "unet_carvana", **hub_kwargs)

In [None]:
# Binary cross-entropy on raw logits (the sigmoid is folded into the loss),
# averaged over all elements.
criterion = torch.nn.BCEWithLogitsLoss(reduction="mean")

In [None]:
# Adam over all model parameters. Alternative left commented out:
# torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
learning_rate = 1e-3
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Training configuration: number of passes over the training set, and the
# device to train on (CUDA when available, CPU otherwise).
n_epochs = 10
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model.to(device)

In [None]:
# Training loop: one optimisation pass over train_loader per epoch, printing
# the average per-sample training loss for each epoch.
for epoch in range(n_epochs):
    model.train()
    running_loss = 0.0  # summed over samples (batch-mean loss * batch size)
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        # Keep only channel 0 of the logits so the shape matches the masks;
        # BCEWithLogitsLoss requires input and target sizes to be equal. This is
        # a no-op for a single-channel output.
        # NOTE(review): assumes masks are single-channel (B, 1, H, W); confirm.
        outputs = outputs[:, [0], :, :]
        loss = criterion(outputs, labels.float())
        loss.backward()
        optimizer.step()
        # `loss` is a batch mean; weight it by batch size so a smaller final
        # batch does not skew the epoch average.
        running_loss += loss.item() * inputs.size(0)
    print("Epoch: %d | Loss: %.4f" % (epoch, running_loss / len(train_loader.dataset)))

In [None]:
# Validation: evaluate the model on val_loader without gradient tracking and
# print the average per-sample loss.
model.eval()
running_loss = 0.0  # summed over samples (batch-mean loss * batch size)
with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # Keep only channel 0 of the logits so the shape matches the masks.
        outputs = model(inputs)[:, [0], :, :]
        loss = criterion(outputs, labels.float())
        # Weight the batch-mean loss by batch size so a smaller final batch
        # is not over-weighted in the average.
        running_loss += loss.item() * inputs.size(0)
print("Validation loss: %.4f" % (running_loss / len(val_loader.dataset)))