In [1]:
#Import required libraries
import torch
import torch.nn as nn
import cv2
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torch_snippets import stems, read
from torch.optim.lr_scheduler import ExponentialLR

from enel645_group5_functions import *

In [2]:
# Select the compute device: the first CUDA GPU when PyTorch can see one,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Report the choice so the run log records which device was used
print("The device being used is:", device)

The device being used is: cpu


In [3]:
# Build one dataset per split ('train' first, then 'val')
train_ds, val_ds = SegmentationDataset('train'), SegmentationDataset('val')

# Wrap each dataset in a loader producing batches of 8, assembled by the
# dataset's own collate_fn. Only the training loader shuffles its samples.
train_dl = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    collate_fn=train_ds.collate_fn,
)
val_dl = DataLoader(
    val_ds,
    batch_size=8,
    collate_fn=val_ds.collate_fn,
)

In [4]:
# Instantiate the network, then move it onto the selected device
net = UNet()
net = net.to(device)

# Training objective: cross-entropy loss
criterion = nn.CrossEntropyLoss()

# AdamW with an initial learning rate of 1e-3; each scheduler.step() call
# scales the learning rate by gamma=0.9 so it decreases over time
optimizer = torch.optim.AdamW(params=net.parameters(), lr=1e-3)
scheduler = ExponentialLR(optimizer=optimizer, gamma=0.9)

In [5]:
# Train for up to `nepochs` epochs. After every epoch the model is evaluated on
# the validation loader; the weights are checkpointed to PATH whenever the
# validation loss improves, and training stops early once `patience` epochs
# pass without an improvement.
nepochs = 100
PATH = './best_model__.pth' # Path to save the best model
patience = 10 # Epochs without validation improvement tolerated before stopping
last_model_update = 0 # Epochs elapsed since the last checkpoint was saved

best_loss = 1e+20 # Lowest summed validation loss so far (large sentinel to start)
for epoch in range(nepochs):  # loop over the dataset multiple times
    # Training Loop
    train_loss = 0.0
    n_train_batches = 0 # stays 0 if the loader yields nothing
    net.train()
    # NOTE(review): batches are fed to `net` exactly as the loader returns them;
    # presumably collate_fn already places them on `device` — confirm.
    for n_train_batches, data in enumerate(train_dl, start=1):
        images, ground_truth_masks = data
        optimizer.zero_grad()

        # forward + backward + optimize
        _masks = net(images)
        loss = criterion(_masks, ground_truth_masks)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()

    # Average over the number of batches actually processed. Dividing by the
    # last zero-based index would overstate the mean and raise
    # ZeroDivisionError when the loader yields a single batch.
    print(f'{epoch + 1},  train loss: {train_loss / max(n_train_batches, 1):.3f},', end = ' ')
    scheduler.step() # decay the learning rate once per epoch

    # Validation Loop
    val_loss = 0.0
    n_val_batches = 0
    net.eval()
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for n_val_batches, data in enumerate(val_dl, start=1):
            images, ground_truth_masks = data
            _masks = net(images)
            loss = criterion(_masks, ground_truth_masks)

            val_loss += loss.item()

    print(f'val loss: {val_loss / max(n_val_batches, 1):.3f}')

    last_model_update += 1

    # Save best model. The summed validation loss is compared from epoch to
    # epoch, as in the original bookkeeping; val_dl is iterated in full each time.
    if val_loss < best_loss:
        print("Saving model")
        torch.save(net.state_dict(), PATH)
        best_loss = val_loss
        last_model_update = 0

    # Early stopping if model doesn't improve in `patience` epochs
    if last_model_update >= patience:
        print(f"Stopping early. No model improvement in {patience} epochs.")
        break

print('Finished Training')

KeyboardInterrupt: 