In [None]:
import wandb
import os

from torch import nn, optim, device, cuda
import torch

import yaml
from tqdm import tqdm, trange

from data_parser import adj_matrix, nodes
from image_parser import train_loader, test_loader, debug_loader
from utils import check_for_missing_values

from network_models import CombinedModel
from model_config_manager import ModelConfigManager

In [None]:
# Read the run settings. The context manager closes the file handle as soon
# as parsing is done instead of leaving it open for the garbage collector.
with open("config.yml") as config_file:
    config = yaml.safe_load(config_file)

epochs = config["EPOCHS"]
RETINA_MODEL = config["RETINA_MODEL"]
DEBUG = config["DEBUG"]

# DEBUG is taken from config.yml only; the former hard-coded `DEBUG = True`
# override (marked FIXME) was removed so the value in the file is honoured.
# Debug runs are shortened to 5 epochs and use the debug loader.
if DEBUG:
    epochs = 5

loader = debug_loader if DEBUG else train_loader

# Load model configurations and select the one named by RETINA_MODEL
config_manager = ModelConfigManager()
config_manager.load_configs_from_yaml(["cnn_1.yml", "cnn_2.yml"])
config_manager.set_model_config(RETINA_MODEL)

# Start the Weights & Biases run with the selected model configuration
wandb.init(project="test", config=config_manager.model_config)

In [None]:
# Build the combined network from the adjacency matrix, the node list and the
# selected model configuration
combined_model = CombinedModel(
    adj_matrix,
    neurons=nodes,
    model_config=config_manager.model_config,
)

# Prefer the GPU when one is available, otherwise stay on the CPU
dev = device("cuda") if cuda.is_available() else device("cpu")
combined_model = combined_model.to(dev)

# Loss function and optimizer.
# NOTE(review): NLLLoss expects log-probabilities as input — confirm that
# CombinedModel ends in a log-softmax.
criterion = nn.NLLLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=1e-5)

In [None]:
config_manager.output_model_details()

for epoch in trange(epochs):
    # Make sure the model is in training mode, even if a previous cell
    # (e.g. evaluation) switched it to eval mode before this cell was re-run.
    combined_model.train()

    running_loss = 0.0
    num_batches = 0

    for images, labels in tqdm(loader):
        # Move images and labels to the device
        images, labels = images.to(dev), labels.to(dev)

        # Fail fast on corrupted input rather than propagating NaNs
        if torch.isnan(images).any():
            raise ValueError("NaN in images")
        if torch.isnan(labels).any():
            raise ValueError("NaN in labels")

        # Forward pass
        outputs = combined_model(images)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call copies from the device
        batch_loss = loss.item()
        running_loss += batch_loss
        num_batches += 1

        wandb.log({"loss": batch_loss, "epoch": epoch})

    # Average over the batches actually iterated. The loop runs over `loader`,
    # which may be the debug loader, so dividing by len(train_loader) would
    # report a wrong average.
    mean_loss = running_loss / num_batches if num_batches else float("nan")
    print(f'Epoch {epoch+1}/{epochs}, Loss: {mean_loss}')

    # Save model every 10 epochs
    if (epoch + 1) % 10 == 0:
        # torch.save does not create missing directories
        os.makedirs("models", exist_ok=True)
        torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))
        print(f"Saved model after {epoch + 1} runs")

wandb.finish()

In [None]:
# Test the model (test_loader is already imported in the imports cell)
correct = 0
total = 0

# Evaluate in eval mode so layers such as dropout / batch norm (if the model
# has any) use their inference behaviour; restore the previous mode afterwards
# so re-running the training cell is unaffected.
was_training = combined_model.training
combined_model.eval()
try:
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(dev), labels.to(dev)
            outputs = combined_model(images)
            # Index of the highest score along dim 1 is the predicted class
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            # .item() keeps `correct` a plain Python int rather than a tensor
            correct += (predicted == labels).sum().item()
finally:
    combined_model.train(was_training)

In [None]:
# Report the share of correctly classified test images as a percentage
accuracy_pct = 100 * correct / total
print(f"Accuracy on the {total} test images: {accuracy_pct}%")