In [1]:
import torch

from copy import deepcopy
from matplotlib import pyplot as plt

In [2]:
def train(model, loader, optimizer, criterion, device=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: Module to optimise; it is switched to training mode.
        loader: Sized iterable yielding ``(images, masks)`` tensor pairs.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per batch.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar loss tensor.
        device: Device each batch is moved to before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU
            for a model without parameters), so the function no longer relies
            on a notebook-level ``device`` variable being defined.

    Returns:
        float: Sum of the per-batch ``loss.item()`` values divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` is empty.
    """
    if device is None:
        # Keep inputs on the same device as the weights instead of depending
        # on hidden kernel state.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.

    model.train()
    for images, masks in loader:
        images = images.to(device)
        masks = masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        # .item() detaches the scalar so the autograd graph is not kept alive.
        running_loss += loss.item()

    return running_loss / len(loader)

In [3]:
def evaluate(model, loader, criterion, device=None):
    """Compute the mean batch loss of ``model`` over ``loader`` without gradients.

    Args:
        model: Module to evaluate; it is switched to evaluation mode and left there.
        loader: Sized iterable yielding ``(images, masks)`` tensor pairs.
        criterion: Callable ``criterion(outputs, masks)`` returning a scalar loss tensor.
        device: Device each batch is moved to before the forward pass. When
            omitted, the device of the model's first parameter is used (CPU
            for a model without parameters), so the function no longer relies
            on a notebook-level ``device`` variable being defined.

    Returns:
        float: Sum of the per-batch ``loss.item()`` values divided by ``len(loader)``.

    Raises:
        ZeroDivisionError: If ``loader`` is empty.
    """
    if device is None:
        # Keep inputs on the same device as the weights instead of depending
        # on hidden kernel state.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    running_loss = 0.

    model.eval()
    with torch.no_grad():
        for images, masks in loader:
            images = images.to(device)
            masks = masks.to(device)

            outputs = model(images)
            loss = criterion(outputs, masks)

            running_loss += loss.item()

    return running_loss / len(loader)

In [4]:
def train_loop_with_validation(model, num_epoch, train_loader, val_loader, optimizer, criterion):
    """Train for ``num_epoch`` epochs, validating after each, and plot both loss curves.

    After every epoch the model is scored with ``evaluate`` on ``val_loader``;
    whenever that score is strictly lower than any seen so far, a deep copy of
    the model is kept as the current best.

    Returns:
        tuple: ``(best_model, best_loss)`` - the deep-copied model with the
        lowest validation loss and that loss. ``(None, inf)`` if no epoch
        improved on infinity (e.g. ``num_epoch`` is 0).
    """
    history_train = []
    history_val = []
    snapshot = None
    lowest_val = float('inf')

    for epoch in range(num_epoch):
        print(f"Epoch {epoch + 1}/{num_epoch} started...")

        train_loss = train(model, train_loader, optimizer, criterion)
        validation_loss = evaluate(model, val_loader, criterion)

        history_train.append(train_loss)
        history_val.append(validation_loss)

        # Strict comparison: on ties the earliest epoch's copy is retained.
        if validation_loss < lowest_val:
            snapshot, lowest_val = deepcopy(model), validation_loss

        print(f'Training loss: {train_loss}')
        print(f'Validation loss: {validation_loss}')
        print('...........................................')

    # One curve per recorded series, drawn on the current pyplot figure.
    epoch_axis = range(num_epoch)
    for series, legend_name in ((history_train, 'Train loss'), (history_val, 'Validation loss')):
        plt.plot(epoch_axis, series, label=legend_name)
    plt.legend()
    plt.show()

    return snapshot, lowest_val

In [5]:
def train_loop(model, num_epoch, loader, optimizer, criterion):
    """Train for ``num_epoch`` epochs without validation and plot the training loss.

    After each epoch, if the mean training loss is strictly lower than any
    seen so far, a deep copy of the model is kept as the current best.

    Args:
        model: Model passed to ``train`` every epoch.
        num_epoch: Number of epochs to run.
        loader: Training data loader passed to ``train``.
        optimizer: Optimizer passed to ``train``.
        criterion: Loss callable passed to ``train``.

    Returns:
        tuple: ``(best_model, best_loss)`` - the deep-copied model with the
        lowest training loss and that loss. ``(None, inf)`` if ``num_epoch`` is 0.
    """
    best_model, best_loss = None, float('inf')

    train_losses = []

    for epoch in range(num_epoch):
        print(f"Epoch {epoch + 1}/{num_epoch} started...")

        # Use the `loader` argument; the previous code referenced a global
        # `train_loader`, which ignored the argument (or raised NameError).
        train_loss = train(model, loader, optimizer, criterion)

        if train_loss < best_loss:
            best_loss = train_loss
            best_model = deepcopy(model)

        train_losses.append(train_loss)

        print(f'Training loss: {train_loss}')
        print('...........................................')

    plt.plot(range(num_epoch), train_losses, label='Train loss')
    plt.legend()
    plt.show()

    return best_model, best_loss

In [6]:
def display_predictions(model, loader, device, display_size):
    """Plot input, ground-truth mask and predicted mask for samples of one batch.

    Only the first batch produced by ``loader`` is used. One figure with three
    side-by-side panels is shown per sample.

    Args:
        model: Model used for prediction; it is switched to evaluation mode and left there.
        loader: Iterable yielding ``(images, masks)`` batches.
        device: Device the images are moved to before the forward pass.
        display_size: Maximum number of samples to show. Values larger than the
            batch size are clamped to the batch size instead of raising
            ``IndexError`` part-way through plotting.
    """
    model.eval()  # Set to evaluation mode

    # Get a batch of test data
    images, masks = next(iter(loader))
    images = images.to(device)

    # Predict
    with torch.no_grad():
        outputs = model(images)

    # Never index past the end of the batch.
    num_shown = min(display_size, len(images))

    # Plot the results
    for i in range(num_shown):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 3, 1)
        # Move the channel axis last for imshow; assumes each image is (C, H, W) - TODO confirm.
        plt.imshow(images[i].cpu().permute(1, 2, 0), cmap='gray')  # Original image
        plt.title("Input Image")

        plt.subplot(1, 3, 2)
        # .cpu() is a no-op for CPU tensors and keeps this working if the
        # loader yields masks that already live on an accelerator.
        plt.imshow(masks[i][0].cpu(), cmap='gray')  # Ground truth mask
        plt.title("Ground Truth Mask")

        plt.subplot(1, 3, 3)
        plt.imshow(outputs[i][0].cpu(), cmap='gray')  # Predicted mask
        plt.title("Predicted Mask")

        plt.show()