In [None]:
import copy
import os
import random
import shutil
import zipfile
from math import atan2, cos, sin, sqrt, pi, log

from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from numpy import linalg as LA
from torch import optim, nn
from torch.utils.data import DataLoader, random_split
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from tqdm import tqdm

In [None]:
from unet import *
from ants_dataset import *
from utils_final import *

### Directory creation and train/test split

In [None]:
# Paths to the original image and mask directories
original_images_dir = '/content/train'
original_masks_dir = '/content/train_masks'

# Paths to the new directories
train_images_aug_dir = '/content/train_augmented'
train_masks_aug_dir = '/content/train_masks_augmented'
test_images_dir = '/content/test'
test_masks_dir = '/content/test_masks'

# Create the new directories if they do not exist
for directory in (train_images_aug_dir, train_masks_aug_dir, test_images_dir, test_masks_dir):
    os.makedirs(directory, exist_ok=True)


def list_data_files(directory):
    """Return the sorted entry names of `directory`, skipping hidden ones (e.g. `.ipynb_checkpoints`)."""
    return sorted(name for name in os.listdir(directory) if not name.startswith('.'))


# Images and masks are paired by position after sorting, so both directories are
# assumed to use matching file names (or at least the same sort order).
image_files = list_data_files(original_images_dir)
mask_files = list_data_files(original_masks_dir)

assert len(image_files) == len(mask_files), "Mismatch between images and masks count."

if image_files:
    # 60/40 split; the 40% part is later divided into test and validation subsets.
    train_images, test_images, train_masks, test_masks = train_test_split(
        image_files, mask_files, test_size=0.4, random_state=42)

    # Move training/testing images and masks to their respective directories
    move_files(train_images, original_images_dir, train_images_aug_dir)
    move_files(train_masks, original_masks_dir, train_masks_aug_dir)

    move_files(test_images, original_images_dir, test_images_dir)
    move_files(test_masks, original_masks_dir, test_masks_dir)
else:
    # The files are moved (not copied) out of the original directories, so a second
    # run of this cell finds nothing left; train_test_split would raise on empty input.
    print("No files left in the original directories; assuming the split was already done.")


### Preprocessing, data augmentation and validation split

In [None]:
# Patch datasets: augmentation is switched on for the training split only.
train_patches = AntsDataset(images_path=train_images_aug_dir, masks_path=train_masks_aug_dir, augmentation=True)
test_dataset = AntsDataset(images_path=test_images_dir, masks_path=test_masks_dir, augmentation=False)

train_patches.load_patches(training=True)
test_dataset.load_patches(training=False)

# Materialise the loaded training patches as two stacked tensors; squeeze(0) drops a
# leading size-1 dimension from each mask when present.
patch_images = torch.stack(list(train_patches.loaded_patches_rgb))
patch_masks = torch.stack([gt.squeeze(0) for gt in train_patches.loaded_patches_gt])
train_dataset = torch.utils.data.TensorDataset(patch_images, patch_masks)
# Release the patch-holding dataset object; its tensors have been stacked above.
del train_patches

# Seeded 50/50 split of the held-out data into test and validation subsets.
generator = torch.Generator().manual_seed(25)
test_dataset, val_dataset = random_split(test_dataset, [0.5, 0.5], generator=generator)


In [None]:
# Report how many samples ended up in each split.
for split_name, split_data in (("training", train_dataset), ("testing", test_dataset), ("validation", val_dataset)):
    print(f"Number of {split_name} samples: {len(split_data)}")

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# DataLoader worker processes: two per visible GPU. On CPU fall back to 0 (load in the
# main process) so `num_workers` is always defined for the DataLoader cell below;
# previously it was left unassigned on CPU and the next cell raised NameError.
num_workers = torch.cuda.device_count() * 2 if device == "cuda" else 0

### Data loading, model, optimizer and loss

In [None]:
LEARNING_RATE = 1e-4
BATCH_SIZE = 8


def make_dataloader(dataset, shuffle):
    """Build a DataLoader over `dataset` using the notebook-wide batch size and worker count."""
    return DataLoader(dataset=dataset,
                      num_workers=num_workers, pin_memory=False,
                      batch_size=BATCH_SIZE,
                      shuffle=shuffle)


# Only the training data is shuffled. Validation/test batches keep a fixed order so
# batch-averaged evaluation metrics are reproducible from one call to the next
# (the validation scores drive the best-epoch selection further below).
train_dataloader = make_dataloader(train_dataset, shuffle=True)
val_dataloader = make_dataloader(val_dataset, shuffle=False)
test_dataloader = make_dataloader(test_dataset, shuffle=False)

model = UNet(in_channels=3, num_classes=4).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
# One CrossEntropyLoss weight per class (index = class id, 4 classes).
# NOTE(review): the origin of these values is not recorded — presumably derived from
# class frequencies in the training masks; document or compute them.
weights = torch.tensor([0.54, 0.73, 1.0, 0.06])
criterion = nn.CrossEntropyLoss(weight=weights.to(device))

In [None]:
# Release cached GPU memory before training; there is nothing to release without CUDA.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Model training

In [None]:
EPOCHS = 15

# Per-epoch metric history, used for the plot below and for choosing the best checkpoint.
train_accuracies = []
train_f1s = []
val_accuracies = []
val_f1s = []


def train_one_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass of `model` over `dataloader`.

    Returns:
        (accuracy, f1): training accuracy and F1 score averaged over batches, as
        computed by compute_accuracy / compute_f1_score on each batch's predictions.

    Raises:
        ValueError: if the dataloader yields no batches (the averages are undefined).
    """
    model.train()
    running_accuracy = 0
    running_f1 = 0
    num_batches = 0

    for batch in tqdm(dataloader, position=0, leave=True):
        img = batch[0].float().to(device)
        mask = batch[1].long().to(device)
        # Remove channel dimension (if mask has shape [B, 1, H, W]); no-op otherwise.
        mask = mask.squeeze(1)

        optimizer.zero_grad()
        y_pred = model(img)
        loss = criterion(y_pred, mask)
        loss.backward()
        optimizer.step()

        # Metrics need no gradients: detaching guarantees the running sums cannot keep
        # every batch's autograd graph alive. The values are those of the pre-step forward pass.
        logits = y_pred.detach()
        running_accuracy += compute_accuracy(logits, mask)
        running_f1 += compute_f1_score(logits, mask, num_classes=logits.shape[1])
        num_batches += 1

    if num_batches == 0:
        raise ValueError("The training dataloader yielded no batches.")
    return running_accuracy / num_batches, running_f1 / num_batches


for epoch in tqdm(range(EPOCHS)):
    train_accuracy, train_f1 = train_one_epoch(model, train_dataloader, optimizer, criterion, device)
    train_accuracies.append(train_accuracy)
    train_f1s.append(train_f1)

    # Compute accuracy and f1_score on the validation set
    val_accuracy, val_f1 = eval_model(val_dataloader, model, device)
    val_accuracies.append(val_accuracy)
    val_f1s.append(val_f1)

    print("-" * 30)
    print(f"Training Accuracy EPOCH {epoch + 1}: {train_accuracy:.4f}")
    print(f"Training F1 Score EPOCH {epoch + 1}: {train_f1:.4f}")
    print("\n")
    print(f"Validation Accuracy EPOCH {epoch + 1}: {val_accuracy:.4f}")
    print(f"Validation F1 Score EPOCH {epoch + 1}: {val_f1:.4f}")
    print("-" * 30)

    # Save a checkpoint for every epoch so the best one can be reloaded after training.
    save_model(model, epoch)

plot_training_metrics(EPOCHS, train_accuracies, val_accuracies, train_f1s, val_f1s)

In [None]:
# Find the epoch whose validation accuracy and F1 score have the highest mean.
# max() returns the first maximum on ties (same as a strict `>` scan). Unlike a scan
# seeded with best_score = 0, it always defines best_epoch, even if every score is 0.
if not val_f1s:
    raise ValueError("No validation metrics recorded; run the training cell first.")

best_epoch = max(range(len(val_f1s)),
                 key=lambda e: (val_f1s[e] + val_accuracies[e]) / 2)
best_score = (val_f1s[best_epoch] + val_accuracies[best_epoch]) / 2
print(f'best epoch {best_epoch+1}, Validation Accuracy : {val_accuracies[best_epoch]}, Validation F1 Score : {val_f1s[best_epoch]}')

In [None]:
# Reload the checkpoint of the best epoch into a fresh model, in inference (eval) mode.
# NOTE(review): assumes save_model(model, epoch) from utils_final writes a plain state_dict
# to /content/models/unet_epoch_{epoch}.pth using the same 0-based epoch index — confirm.
model_path = f'/content/models/unet_epoch_{best_epoch}.pth'
trained_model = UNet(in_channels=3, num_classes=4).to(device).eval()
# weights_only=True limits unpickling to tensors/containers, which is all a state_dict needs.
trained_model.load_state_dict(torch.load(model_path, map_location=torch.device(device), weights_only=True))

### Model evaluation

In [None]:
# Evaluate the reloaded best checkpoint on the held-out test split.
test_accuracy, test_f1 = eval_model(test_dataloader, trained_model, device)

for metric_name, metric_value in (("Test Accuracy", test_accuracy), ("Test F1 Score", test_f1)):
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
# Display a few random test images next to their ground-truth masks and the model's predictions.
N_SAMPLES = 3
SAMPLE_SEED = 0  # fixed so Restart & Run All shows the same examples every time

test_samples = test_dataloader.dataset
# Distinct indices (sampling without replacement), capped at the dataset size.
sample_indices = random.Random(SAMPLE_SEED).sample(
    range(len(test_samples)), k=min(N_SAMPLES, len(test_samples)))
samples = [test_samples[i] for i in sample_indices]

image_tensors = [sample[0] for sample in samples]
mask_tensors = [sample[1] for sample in samples]

# NOTE(review): trained_model was moved to `device` when it was built; passing device="cpu"
# assumes mask_pred_plots moves the model and inputs itself — confirm in utils_final.
mask_pred_plots(image_tensors, mask_tensors, trained_model, device="cpu")