In [1]:
from importnb import Notebook

with Notebook():
    from DataClass import DataSet

In [2]:
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader

In [3]:
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    Args:
        state: Object to persist; it is handed to ``torch.save`` unchanged.
        filename: Destination passed to ``torch.save`` as its file argument.
            Defaults to ``"my_checkpoint.pth.tar"``.
    """
    # Announce first so the message is visible even if the write fails.
    print("=> Saving checkpoint")
    torch.save(obj=state, f=filename)

In [4]:
def load_checkpoint(checkpoint, model):
    """Copy the weights stored under ``checkpoint["state_dict"]`` into ``model``.

    Args:
        checkpoint: Mapping that holds the saved weights under the
            ``"state_dict"`` key.
        model: Object exposing ``load_state_dict``; it is updated in place.
    """
    print("=> Loading checkpoint")
    saved_weights = checkpoint["state_dict"]
    model.load_state_dict(saved_weights)

In [5]:
def get_data_loaders(
        train_dir, train_mask_dir, val_dir, val_maskdir, batch_size,
        train_transform, val_transform):
    """Build the training and validation ``DataLoader`` pair.

    Each split wraps a ``DataSet`` (the class imported from the ``DataClass``
    notebook) constructed from an image directory, a mask directory and a
    transform. The training loader shuffles; the validation loader does not.

    Args:
        train_dir: Image directory for the training split.
        train_mask_dir: Mask directory for the training split.
        val_dir: Image directory for the validation split.
        val_maskdir: Mask directory for the validation split.
        batch_size: Batch size used by both loaders.
        train_transform: Transform forwarded to the training ``DataSet``.
        val_transform: Transform forwarded to the validation ``DataSet``.

    Returns:
        tuple: ``(train_loader, val_loader)``.
    """

    def _make_loader(images, masks, transform, shuffle):
        # One split = one DataSet wrapped in one DataLoader.
        split = DataSet(image_dir=images, mask_dir=masks, transform=transform)
        return DataLoader(split, batch_size=batch_size, shuffle=shuffle)

    # Training split is built first, then validation.
    train_loader = _make_loader(train_dir, train_mask_dir, train_transform, True)
    val_loader = _make_loader(val_dir, val_maskdir, val_transform, False)

    return train_loader, val_loader

In [6]:
def calculate_epoch_accuracy(model, data_loader, device):
    """Compute pixel-wise accuracy of ``model`` over every batch in ``data_loader``.

    The model runs in evaluation mode with gradients disabled. On exit the
    model is put back into whichever mode (train or eval) it was in when the
    function was called, even if a batch raises.

    Args:
        model: ``torch.nn.Module`` whose output carries the class scores on
            dimension 1; ``argmax(dim=1)`` is taken as the predicted label.
        data_loader: Iterable yielding ``(data, targets)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Number of matching elements divided by the number of predicted
        elements, or ``nan`` when the loader produced no predictions.
    """
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    model.eval()

    correct_pixels = 0
    total_pixels = 0

    try:
        with torch.no_grad():  # no gradients needed for evaluation
            for data, targets in data_loader:
                data = data.to(device)
                targets = targets.to(device)

                # Predicted class per element: argmax over the class dimension.
                preds = torch.argmax(model(data), dim=1)

                # NOTE(review): assumes targets has the same shape as preds;
                # a different shape would broadcast in the comparison - confirm
                # against what DataSet yields.
                correct_pixels += (preds == targets).sum().item()
                total_pixels += preds.numel()
    finally:
        model.train(was_training)

    if total_pixels == 0:
        # Nothing was evaluated, so accuracy is undefined rather than a crash.
        return float("nan")

    return correct_pixels / total_pixels


In [7]:
def calculate_validation_loss(model, data_loader, loss_fn, device):
    """Return the mean per-batch loss of ``model`` over ``data_loader``.

    The model runs in evaluation mode with gradients disabled. On exit the
    model is put back into whichever mode (train or eval) it was in when the
    function was called, even if a batch raises.

    Args:
        model: ``torch.nn.Module`` to evaluate.
        data_loader: Iterable yielding ``(data, targets)`` tensor pairs.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``; its
            result must support ``.item()``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        float: Sum of the batch losses divided by the number of batches, or
        ``nan`` when the loader yields no batches. Every batch carries equal
        weight regardless of how many samples it holds.
    """
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    model.eval()

    total_loss = 0.0
    num_batches = 0

    try:
        with torch.no_grad():  # no gradients needed for validation
            for data, targets in data_loader:
                data = data.to(device)
                # Targets are cast to integer (long) before the loss call.
                targets = targets.long().to(device)

                predictions = model(data)
                total_loss += loss_fn(predictions, targets).item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        # No batches: the average is undefined rather than a crash.
        return float("nan")

    return total_loss / num_batches


In [5]:
def calculate_miou(model, data_loader, device, num_classes):
    """Compute the mean Intersection-over-Union of ``model`` over ``data_loader``.

    Per-class intersection and union counts are accumulated across every
    batch, IoU is computed once per class from those totals, and the per-class
    values are averaged. Accumulating before dividing keeps the result
    independent of batch size and batch order, and a class that is merely
    missing from one batch is not scored as a perfect match for that batch.

    The model runs in evaluation mode with gradients disabled. On exit the
    model is put back into whichever mode (train or eval) it was in when the
    function was called, even if a batch raises.

    Args:
        model: ``torch.nn.Module`` whose output carries the class scores on
            dimension 1; ``argmax(dim=1)`` is taken as the predicted label.
        data_loader: Iterable yielding ``(data, targets)`` tensor pairs.
        device: Device each batch is moved to before the forward pass.
        num_classes: Number of class indices (``0 .. num_classes - 1``) to
            score.

    Returns:
        float: Mean of the per-class IoU values. A class whose union is empty
        over the whole loader (never predicted and never present in the
        targets) counts as 1.0. ``nan`` is returned when the loader yields no
        batches.
    """
    # Remember the caller's mode so evaluation does not silently flip an
    # eval-mode model into training mode.
    was_training = model.training
    model.eval()

    intersections = [0] * num_classes
    unions = [0] * num_classes
    num_batches = 0

    try:
        with torch.no_grad():  # no gradients needed for evaluation
            for data, targets in data_loader:
                data = data.to(device)
                targets = targets.to(device)

                # Predicted class per element: argmax over the class dimension.
                preds = torch.argmax(model(data), dim=1)

                # NOTE(review): assumes targets has the same shape as preds;
                # a different shape would broadcast in the mask ops - confirm
                # against what DataSet yields.
                for cls in range(num_classes):
                    target_mask = targets == cls
                    pred_mask = preds == cls
                    intersections[cls] += torch.logical_and(target_mask, pred_mask).sum().item()
                    unions[cls] += torch.logical_or(target_mask, pred_mask).sum().item()

                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        # Without data every union is empty; report "undefined" instead of a
        # misleading perfect score.
        return float("nan")

    iou_sum = 0.0
    for intersection, union in zip(intersections, unions):
        # Empty union over the whole loader: class never appeared anywhere,
        # scored as a perfect match.
        iou_sum += intersection / union if union else 1.0

    return iou_sum / num_classes
