In [None]:
import torch
import torch.nn as nn
from datasets import UnderwaterCreaturesDataset, get_train_transform, get_test_transform
from torch.utils.data import DataLoader
from models import SimpleCNN
from trainer import Trainer

In [None]:
# --- Data configuration -------------------------------------------------------
# Label names, in the order used for num_classes, the confusion matrix and the
# prediction captions further down.
# NOTE(review): this order must match the integer labels produced by
# UnderwaterCreaturesDataset — confirm against datasets.py.
class_names = ['fish', 'jellyfish', 'penguin', 'puffin', 'shark', 'starfish', 'stingray']
root_dir = 'aquarium-data-cots/aquarium_pretrain'  # relative to the notebook's working directory

# The train split gets the train transform; valid and test share the test
# transform so both are evaluated under the same preprocessing.
train_transform = get_train_transform()
test_transform = get_test_transform()

# One dataset object per split, all read from the same root directory.
train_dataset = UnderwaterCreaturesDataset(root_dir, split='train', transform=train_transform)
valid_dataset = UnderwaterCreaturesDataset(root_dir, split='valid', transform=test_transform)
test_dataset = UnderwaterCreaturesDataset(root_dir, split='test', transform=test_transform)

# Loader settings shared by every split; only the training loader shuffles.
batch_size = 32
loader_kwargs = {'batch_size': batch_size, 'num_workers': 4}
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DataLoader(valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [None]:
# --- Model, loss and optimizer ------------------------------------------------
# Seed torch's global RNG before anything random happens. Weight initialisation
# below draws from it, and so do the DataLoaders when training starts (the
# shuffle order and the per-worker seeds). Seeding in this cell means that
# re-running it followed by the training cell reproduces the same run.
# NOTE: bit-exact reproducibility on GPU may additionally need deterministic
# cuDNN settings (see PyTorch's reproducibility notes); this only fixes seeds.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# One output per entry in class_names; weights are created under the seed above.
model = SimpleCNN(num_classes=len(class_names))
model = model.to(device)

# Multi-class classification loss, optimised with Adam at a fixed learning rate.
LEARNING_RATE = 1e-4
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# --- Training -----------------------------------------------------------------
# Bundle the model with the device, loss and optimizer prepared above.
trainer = Trainer(model, device, criterion, optimizer)

# valid_loader is handed to train() as well — presumably for per-epoch
# validation; confirm in trainer.py.
num_epochs = 10
trainer.train(train_loader, valid_loader, num_epochs=num_epochs)

# Plot whatever metrics the trainer recorded during train().
trainer.plot_metrics()

# Held-out test evaluation. The result is kept under a name and is still the
# cell's last expression, so the notebook displays it exactly as before.
test_results = trainer.test(test_loader)
test_results

In [None]:
# --- Qualitative evaluation on the test split ---------------------------------
# Per-class view of the test predictions, labelled with class_names.
trainer.plot_confusion_matrix(test_loader, class_names)

# Spot-check a handful of individual test predictions.
num_preview_images = 8
trainer.visualize_predictions(test_loader, class_names, num_images=num_preview_images)