In [1]:
import torch
import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2 as transforms
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import KFold

import mininet
import torchmetrics

from train import normalize
import mininet
import datasets

In [2]:
image_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.ToTensor(),
])

label_transform = transforms.Compose([
    transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.NEAREST),
])

device = 'cuda:0'
torch.cuda.empty_cache()
corales_data = datasets.CORALES('datasets/corales', transform=image_transform, target_transform=label_transform, train=True)

# Initialize the model
model = mininet.MiniNetv2(3, corales_data.num_classes, interpolate=True)
model = model.to(device)




TRAINING WITH CORALES DATASET

In [4]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

batch_size = 24
num_epochs = 100
lr = 0.001


train_size = int(0.8 * len(corales_data))  # 80% for training
val_size = len(corales_data) - train_size  # 20% for validation
train_dataset, val_dataset = torch.utils.data.random_split(corales_data, [train_size, val_size])

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Define your loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(model.parameters(), lr=lr)
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(num_epochs):
    # Training
    model.train()
    epoch_loss = 0
    correct_predictions = 0
    total_predictions = 0
    num_batches = 0
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.to(device).float()
        labels = labels.to(device).squeeze(1).long()

        #plot labels of the batch
        # for i in range(0, len(labels)):
        #     plt.figure(figsize=(10, 5))
            
        #     # Plot input image
        #     plt.subplot(1, 2, 1)
        #     plt.imshow(inputs[i].cpu().numpy().transpose((1, 2, 0)))  # Assuming your input is in (C, H, W) format
        #     plt.title('Input Image')
            
        #     # Plot label
        #     plt.subplot(1, 2, 2)
        #     plt.imshow(labels[i].cpu().numpy())
        #     plt.title('Label')
            
        #     plt.show()
            
        # Forward pass
        outputs = model(inputs)
        # print(f"outputs size: {outputs.size()}")

        loss = criterion(outputs, labels)

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        # print(f"max predicted: {predicted.max()}")
        correct_predictions += (predicted[labels != train_dataset.dataset.ignore_index] == labels[labels != train_dataset.dataset.ignore_index]).float().sum().item()
        total_predictions += torch.numel(labels[labels != train_dataset.dataset.ignore_index])

        # Accumulate loss for the epoch
        epoch_loss += loss.item()
        num_batches += 1

    # Calculate average loss and accuracy for the epoch
    avg_loss = epoch_loss / num_batches
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

    # Store the average loss and accuracy
    train_losses.append(avg_loss)
    train_accuracies.append(accuracy)

    # Validation
    model.eval()
    with torch.no_grad():
        epoch_loss = 0
        correct_predictions = 0
        total_predictions = 0
        num_batches = 0
        for i, (inputs, labels) in enumerate(val_dataloader):
            inputs = inputs.to(device).float()
            labels = labels.to(device).squeeze(1).long()

            # print(f"min label: {labels.min()}, max label: {labels.max()}")

            # Forward pass
            outputs = model(inputs)
            # print(f"outputs size val: {outputs.size()}")

            loss = criterion(outputs, labels)

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            # print(f"max predicted val: {predicted.max()}")
            correct_predictions += (predicted[labels != train_dataset.dataset.ignore_index] == labels[labels != train_dataset.dataset.ignore_index]).float().sum().item()
            total_predictions += torch.numel(labels[labels != train_dataset.dataset.ignore_index])

            # Accumulate loss for the epoch
            epoch_loss += loss.item()
            num_batches += 1

        # Calculate average loss and accuracy for the epoch
        avg_loss = epoch_loss / num_batches
        accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0

        # Store the average loss and accuracy
        val_losses.append(avg_loss)
        val_accuracies.append(accuracy)

    print(f'Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Train Accuracy: {train_accuracies[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}, Val Accuracy: {val_accuracies[-1]:.4f}')

# Create a figure and axes
fig, (ax1, ax2) = plt.subplots(2)

# Plot the losses
ax1.plot(train_losses, label='Train')
ax1.plot(val_losses, label='Val')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.legend()

# Plot the accuracies
ax2.plot(train_accuracies, label='Train')
ax2.plot(val_accuracies, label='Val')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training and Validation Accuracy')
ax2.legend()

plt.tight_layout()
plt.show()

# Save the model weights
torch.save(model.state_dict(), 'corales_weights.pth')

Epoch 1, Train Loss: 4.1113, Train Accuracy: 0.0068, Val Loss: 4.0344, Val Accuracy: 0.0024
Epoch 2, Train Loss: 4.1115, Train Accuracy: 0.0067, Val Loss: 4.0443, Val Accuracy: 0.0022
Epoch 3, Train Loss: 4.1102, Train Accuracy: 0.0067, Val Loss: 4.0556, Val Accuracy: 0.0038
Epoch 4, Train Loss: 4.1097, Train Accuracy: 0.0068, Val Loss: 4.0665, Val Accuracy: 0.0069
Epoch 5, Train Loss: 4.1107, Train Accuracy: 0.0068, Val Loss: 4.0800, Val Accuracy: 0.0070
Epoch 6, Train Loss: 4.1097, Train Accuracy: 0.0069, Val Loss: 4.0921, Val Accuracy: 0.0071
Epoch 7, Train Loss: 4.1092, Train Accuracy: 0.0066, Val Loss: 4.1016, Val Accuracy: 0.0072
Epoch 8, Train Loss: 4.1111, Train Accuracy: 0.0069, Val Loss: 4.1097, Val Accuracy: 0.0070
Epoch 9, Train Loss: 4.1119, Train Accuracy: 0.0067, Val Loss: 4.1140, Val Accuracy: 0.0072
Epoch 10, Train Loss: 4.1106, Train Accuracy: 0.0067, Val Loss: 4.1117, Val Accuracy: 0.0074
Epoch 11, Train Loss: 4.1090, Train Accuracy: 0.0066, Val Loss: 4.1073, Val Acc

KeyboardInterrupt: 

In [None]:
transform_test = transforms.Compose([
    transforms.ToTensor(),
])

# Evaluation on the test set
corales_data_test = datasets.CORALES('datasets/corales', transform=transform_test, train=False)
test_dataloader = torch.utils.data.DataLoader(corales_data_test, batch_size=16, shuffle=False)

inputs, labels = next(iter(test_dataloader))
inputs = inputs.to(device)
labels = labels.to(device)

# Load the model weights
model.load_state_dict(torch.load('corales_weights.pth'))

# Get the model's predictions
model.eval()
with torch.no_grad():
    outputs = model(inputs)
    _, predictions = torch.max(outputs, 1)

inputs = inputs.cpu()
labels = labels.cpu()
predictions = predictions.cpu()

# Get unique colors in the predicted images
unique_colors = torch.unique(predictions)

# Plot the images, ground truth, and predictions
fig, axs = plt.subplots(nrows=inputs.size(0), ncols=3, figsize=(10, inputs.size(0)*5))
for i in range(inputs.size(0)):
    # Plot the image
    axs[i, 0].imshow(inputs[i].permute(1, 2, 0))
    axs[i, 0].set_title('Image')
    axs[i, 0].axis('off')

    # Plot the ground truth
    axs[i, 1].imshow(labels[i].squeeze(0), cmap='gray')
    axs[i, 1].set_title('Ground Truth')
    axs[i, 1].axis('off')

    # Plot the prediction
    axs[i, 2].imshow(predictions[i], cmap='gray')
    axs[i, 2].set_title('Prediction')
    axs[i, 2].axis('off')

plt.tight_layout()
plt.show()

FINE-TUNNING

In [None]:
batch_size = 20
num_epochs = 100
lr = 0.0001

# Load the training dataset
sebens_train_data = datasets.SEBENS('datasets/Sebens_MA_LTM', transform=transform, train=True)
sebens_train_dataloader = torch.utils.data.DataLoader(sebens_train_data, batch_size=batch_size, shuffle=True)

# Load the testing dataset
sebens_test_data = datasets.SEBENS('datasets/Sebens_MA_LTM', transform=transform, train=False)
sebens_test_dataloader = torch.utils.data.DataLoader(sebens_test_data, batch_size=batch_size, shuffle=False)

# Initialize the model
model = mininet.MiniNetv2(3, corales_data.num_classes, interpolate=True)

# Load the saved weights
model.load_state_dict(torch.load('corales_weights.pth'))

# Adjust the output layer
model.output.conv = torch.nn.Conv2d(64, sebens_train_data.num_classes, kernel_size=(2, 2), stride=(1, 1))

# Move the model to the device
model = model.to(device)

# Define your loss function and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Initialize a list to store the losses
losses = []

# Initialize a list to store the accuracies
accuracies = []

# Fine-tune the model on the Sebens_MA_LTM dataset
for epoch in range(num_epochs):
    epoch_loss = 0
    correct_predictions = 0
    total_predictions = 0
    num_batches = 0
    for i, (inputs, labels) in enumerate(sebens_train_dataloader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Forward pass
        inputs = inputs.float()
        outputs = model(inputs)
        labels = labels.squeeze(1)
        labels = labels.long()
        loss = criterion(outputs, labels)

        # # Calculate accuracy
        # _, predicted = torch.max(outputs.data, 1)
        # total_predictions += labels.nelement()
        # correct_predictions += predicted.eq(labels.data).sum().item()

        # Backward and optimize
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Accumulate loss for the epoch
        epoch_loss += loss.item()
        num_batches += 1

    # Calculate average loss and accuracy for the epoch
    avg_loss = epoch_loss / num_batches
    # accuracy = correct_predictions / total_predictions

    # Store the average loss and accuracy
    losses.append(avg_loss)
    accuracies.append(accuracy)

    print(f'Epoch {epoch+1}, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')

# Create a figure and axes
fig, (ax1, ax2) = plt.subplots(2)

# Plot the losses
ax1.plot(losses)
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training Loss')

# Plot the accuracies
ax2.plot(accuracies)
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy')
ax2.set_title('Training Accuracy')

plt.tight_layout()
plt.show()

In [None]:
def color_labels(label, color_dict):
    # Create a default background color (black)
    colored_label = np.zeros((label.shape[0], label.shape[1], 3), 
dtype=np.uint8)
    
    # Iterate over the label and replace values with corresponding colors
    for i in range(label.shape[0]):
        for j in range(label.shape[1]):
            key = label[i, j]
            if key in color_dict:
                color = color_dict[key]
                colored_label[i, j] = [color[0], color[1], color[2]]
    return colored_label

color_dict_path = '/home/cbm/BOSTON/CoralNet_expansion/color_dict.csv'
color_df = pd.read_csv(color_dict_path)
# Transpose the DataFrame so that each label has its own row
color_df = color_df.transpose()

# Create a dictionary where the keys are integers and the values are the colors
color_dict = {i: [row[1], row[2], row[3]] for i, row in enumerate(color_df.iterrows(), start=1)}

# Convert the colors to integers
for i, color in color_dict.items():
    color_dict[i] = list(map(int, color))

# Get a batch of test images and labels
for x_test, y_test in sebens_train_dataloader:
    # Move the images and labels to the device
    x_test = x_test.to(device)
    x_test = x_test.float()
    y_test = y_test.to(device)

    # Normalize the images
    x_test = (x_test - 0.5) / 0.5

    # Make predictions
    y_pred_test = model(x_test).argmax(dim=1)

    # Plot the images, labels, and predictions
    for i in range(x_test.shape[0]):
        plt.figure(figsize=(10, 25))
        plt.subplot(131)
        plt.imshow(x_test[i].cpu().permute(1, 2, 0).numpy() / x_test.max().item())
        plt.title('input')
        plt.axis('off')
        plt.subplot(132)
        plt.imshow(color_labels(y_test[i].cpu().numpy().squeeze(), color_dict) / 255.0)
        plt.title('target')
        plt.axis('off')
        plt.subplot(133)
        plt.imshow(color_labels(y_pred_test[i].cpu().numpy().squeeze(), color_dict) / 255.0)
        plt.title('prediction')
        plt.axis('off')
        plt.show()

    # Stop after the first batch
    break
