In [1]:
%load_ext autoreload
%autoreload 2
import torch
import torch.nn as nn
from datasets import UnderwaterCreaturesMultiLabelDataset, collate_fn
from torch.utils.data import DataLoader
from models import SimpleCNN
from trainer import Trainer


In [2]:
# Label vocabulary; the index of each name is its class id.
class_names = ['fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray']
num_classes = len(class_names)
# Root directory of the dataset
root_dir = 'datasets/aquarium-data-cots/aquarium_pretrain'

# Create one dataset per split (multi-label targets sized by `num_classes`).
train_dataset = UnderwaterCreaturesMultiLabelDataset(root_dir, split='train', num_classes=num_classes)
valid_dataset = UnderwaterCreaturesMultiLabelDataset(root_dir, split='valid', num_classes=num_classes)
test_dataset = UnderwaterCreaturesMultiLabelDataset(root_dir, split='test', num_classes=num_classes)

# Create data loaders. Options shared by every split are defined once so they
# cannot drift apart; only the training split is shuffled.
batch_size = 32
num_workers = 4
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Model, loss and optimizer. The Trainer cell below depends on every name
# defined here, so this cell must run before it.

# Seed the global torch RNG so weight init and DataLoader shuffle order are
# repeatable across runs (cuDNN kernels may still be nondeterministic).
seed = 42
torch.manual_seed(seed)

# Fall back to CPU when no GPU is visible instead of crashing on `.cuda()`;
# the model is moved to the same `device` that is handed to the Trainer.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SimpleCNN(num_classes=num_classes).to(device)

# Multi-label setup: one independent sigmoid + binary cross-entropy per class.
criterion = nn.BCEWithLogitsLoss()
learning_rate = 1e-4
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [3]:
# Initialize the Trainer.
# NOTE: requires `model`, `device`, `criterion` and `optimizer` from the
# model-setup cell above; run that cell first or this one raises NameError.
trainer = Trainer(model, device, criterion, optimizer, num_classes=num_classes, class_names=class_names)

# Train the model
num_epochs = 20
trainer.train(train_loader, valid_loader, num_epochs=num_epochs)

# Plot training metrics
trainer.plot_metrics()

# Evaluate metrics on the validation set; the returned targets/outputs are
# reused by the confusion-matrix cell below.
val_targets, val_outputs, average_precisions, mAP = trainer.evaluate_metrics(valid_loader)

NameError: name 'model' is not defined

In [11]:
# Pairwise confusion matrix for the validation set, using `val_threshold` as
# the decision cut-off on the model outputs.
val_threshold = 0.5
val_confusion_matrix = trainer.compute_pairwise_confusion_matrix(val_targets, val_outputs, threshold=val_threshold)
trainer.plot_pairwise_confusion_matrix(val_confusion_matrix, class_names, epoch=num_epochs, threshold=val_threshold)


In [None]:
# Show model predictions for 8 validation images.
# NOTE(review): the 0.8 cut-off here is stricter than the 0.5 used for the
# validation confusion matrix and the 0.7 used for the test-set visualization;
# confirm this difference is intentional.
trainer.visualize_predictions(valid_loader, threshold=0.8, num_images=8)


# Test set

In [None]:
# Evaluate metrics on the test set
test_targets, test_outputs, test_average_precisions, test_mAP = trainer.evaluate_metrics(test_loader)

# Plot ROC curves for the test set
trainer.plot_roc_curves(test_targets, test_outputs)

# Pairwise confusion matrix for the test set, built with the same Trainer
# methods as the validation analysis so the two are directly comparable.
test_threshold = 0.5
test_confusion_matrix = trainer.compute_pairwise_confusion_matrix(test_targets, test_outputs, threshold=test_threshold)
trainer.plot_pairwise_confusion_matrix(test_confusion_matrix, class_names, epoch=num_epochs, threshold=test_threshold)

# Visualize predictions on the test set.
# NOTE(review): 0.7 differs from the 0.8 used for the validation visualization.
trainer.visualize_predictions(test_loader, num_images=8, threshold=0.7)