In [2]:
import os

from torch import nn, optim, device, cuda
import torch

import yaml
from tqdm.notebook import tqdm, trange

from data_parser import adj_matrix, nodes
from image_parser import train_loader, test_loader
from utils import check_for_missing_values

from network_models import CombinedModel
from model_config_manager import ModelConfigManager

In [3]:
config = yaml.safe_load(open("config.yml"))
epochs = config["EPOCHS"]
RETINA_MODEL = config["RETINA_MODEL"]
DEBUG = config["DEBUG"]

# Create the ModelConfigManager and load configurations from YAML files
# FIXME: duplicated code
config_manager = ModelConfigManager()
config_manager.load_configs_from_yaml(["cnn_1.yml", "cnn_2.yml"])

# Get a specific configuration by model name
config_manager.set_model_config(RETINA_MODEL)
config_manager.output_model_details()

Model configurations:
Model name: vgg16
This is a pretrained model


In [4]:
# Initialize the combined model
combined_model = CombinedModel(adj_matrix, neurons=nodes, model_config=config_manager.model_config)
dev = device("cuda" if cuda.is_available() else "cpu")
combined_model = combined_model.to(dev)

# Specify the loss function and the optimizer
criterion = nn.NLLLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.00001)



In [None]:
for epoch in trange(epochs):
    running_loss = 0
    # If the model is fast~ish
    # for images, labels in train_loader:
    for images, labels in tqdm(train_loader):
        # Move images and labels to the device
        images, labels = images.to(dev), labels.to(dev)

        # Checks
        if torch.isnan(images).any():
            raise Exception("NaN in images")
        if torch.isnan(labels).any():
            raise Exception("NaN in labels")

        # Forward pass
        outputs = combined_model(images)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()

        if DEBUG:
            # check for missing values in model parameters
            check_for_missing_values(combined_model, epoch)
            # Clip gradients to avoid exploding gradients
            nn.utils.clip_grad_norm_(combined_model.parameters(), 1)
            # Clip parameters to avoid exploding parameters
            for p in combined_model.parameters():
                p.data.clamp_(-1, 1) 

        optimizer.step()
 
        running_loss += loss.item()

    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}')
    # Save model every 10 epochs
    if epoch > 0 and (epoch + 1) % 10 == 0:
        torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))
        print(f"Saved model after {epoch + 1} runs")

  0%|          | 0/100 [00:00<?, ?it/s]
  0%|          | 0/680 [00:00<?, ?it/s][A
  0%|          | 1/680 [00:03<39:57,  3.53s/it][A
  0%|          | 2/680 [00:06<34:49,  3.08s/it][A
  0%|          | 3/680 [00:09<34:14,  3.03s/it][A
  1%|          | 4/680 [00:12<33:56,  3.01s/it][A
  1%|          | 5/680 [00:15<33:47,  3.00s/it][A
  1%|          | 6/680 [00:16<26:04,  2.32s/it][A
  1%|          | 7/680 [00:20<31:41,  2.83s/it][A
  1%|          | 8/680 [00:23<33:18,  2.97s/it][A
  1%|▏         | 9/680 [00:26<34:41,  3.10s/it][A
  1%|▏         | 10/680 [00:29<33:11,  2.97s/it][A
  2%|▏         | 11/680 [00:32<34:41,  3.11s/it][A
  2%|▏         | 12/680 [00:36<35:40,  3.20s/it][A
  2%|▏         | 13/680 [00:39<34:11,  3.08s/it][A
  2%|▏         | 14/680 [00:42<35:21,  3.18s/it][A
  2%|▏         | 15/680 [00:45<36:04,  3.25s/it][A
  2%|▏         | 16/680 [00:48<34:32,  3.12s/it][A
  2%|▎         | 17/680 [00:51<34:21,  3.11s/it][A
  3%|▎         | 18/680 [00:54<34:12,  3.1

In [9]:
torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))

In [10]:
epoch

28

In [9]:
# Test de model
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(dev), labels.to(dev)
        outputs = combined_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

  0%|          | 0/170 [00:00<?, ?it/s]

In [10]:
print(f"Accuracy on the {total} test images: {100 * correct / total}%")

Accuracy on the 1360 test images: 53.45588302612305%


In [15]:
model = CombinedModel(adj_matrix, nodes, RETINA_MODEL)
model.load_state_dict(torch.load(os.path.join("models", "model_cnn_1_51_epochs.pth")), strict=False)
model = model.to(dev)
model.eval()

CombinedModel(
  (retina_model): RetinaModel(
    (conv_layer): Conv2d(3, 1, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (activation): ReLU(inplace=True)
    (pooling_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (connectome_model): ConnectomeNetwork(
    (retina_layer): Linear(in_features=65025, out_features=2952, bias=True)
    (rational_layer): Linear(in_features=2952, out_features=10, bias=True)
  )
)

In [16]:
# Test de model
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(dev), labels.to(dev)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

  0%|          | 0/170 [00:00<?, ?it/s]

In [17]:
print(f"Accuracy on the {total} test images: {100 * correct / total}%")

Accuracy on the 1360 test images: 53.602943420410156%
