In [2]:
import os

from torch import nn, optim, device, cuda
import torch

import yaml
from tqdm.notebook import tqdm, trange

from data_parser import adj_matrix, nodes
from image_parser import train_loader, test_loader
from utils import check_for_missing_values

from network_models import CombinedModel
from model_config_manager import ModelConfigManager

In [3]:
# Load the run-level configuration. The `with` block guarantees the file handle
# is closed (a bare `open()` passed to safe_load leaks it until garbage collection).
with open("config.yml") as config_file:
    config = yaml.safe_load(config_file)

# Number of passes over `train_loader` made by the training cell.
epochs = config["EPOCHS"]
# Name of the retina sub-network variant to build and report on (e.g. "cnn_2").
RETINA_MODEL = config["RETINA_MODEL"]

# Create the ModelConfigManager and load configurations from YAML files
# FIXME: duplicated code
config_manager = ModelConfigManager()
config_manager.load_configs_from_yaml(["cnn_1.yml", "cnn_2.yml"])

# Look up the hyper-parameters registered for the selected retina model
# (out_channels, kernel_size, stride, padding are printed by the training cell).
model_config = config_manager.get_config(RETINA_MODEL)

# Debug switch: when True the training loop checks the model parameters for
# missing values and clips gradients/parameters to fight exploding values.
problems = False

In [4]:
# Initialize the combined model
combined_model = CombinedModel(adj_matrix, neurons=nodes, retina_model_type=RETINA_MODEL)

dev = device("cuda" if cuda.is_available() else "cpu")
combined_model = combined_model.to(dev)

# Specify the loss function and the optimizer
criterion = nn.NLLLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.00001)

In [5]:
print(f"Training {RETINA_MODEL} for {epochs} epochs on {dev} with the following params")
print(f"Out channels: {model_config.out_channels}")
print(f"Kernel size: {model_config.kernel_size}")
print(f"Stride: {model_config.stride}")
print(f"Padding: {model_config.padding}")

# Checkpoint cadence (in epochs) and the bound used by the `problems` safeguards.
SAVE_EVERY = 10
CLIP_VALUE = 1

# torch.save does not create directories; make sure the target exists up front.
os.makedirs("models", exist_ok=True)

for epoch in trange(epochs):
    # Re-enter training mode in case an evaluation cell left the model in eval().
    combined_model.train()
    running_loss = 0.0
    # leave=False: drop each finished per-epoch bar instead of stacking 100 of them.
    for images, labels in tqdm(train_loader, leave=False):
        # Move images and labels to the device
        images, labels = images.to(dev), labels.to(dev)

        # Fail fast on corrupted input batches
        if torch.isnan(images).any():
            raise Exception("NaN in images")
        if torch.isnan(labels).any():
            raise Exception("NaN in labels")

        # Forward pass and loss
        outputs = combined_model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        if problems:
            # check for missing values in model parameters
            check_for_missing_values(combined_model, epoch)
            # Gradient clipping must sit between backward() and step()
            nn.utils.clip_grad_norm_(combined_model.parameters(), CLIP_VALUE)

        optimizer.step()

        if problems:
            # Clamp AFTER the update: clamping before step() let the optimizer
            # push parameters straight back outside [-CLIP_VALUE, CLIP_VALUE].
            with torch.no_grad():
                for p in combined_model.parameters():
                    p.clamp_(-CLIP_VALUE, CLIP_VALUE)

        running_loss += loss.item()

    # Mean per-batch loss; max(..., 1) avoids ZeroDivisionError on an empty loader.
    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/max(len(train_loader), 1)}')
    # Save model every SAVE_EVERY epochs (epoch is 0-based, hence epoch + 1)
    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))
        print(f"Saved model after {epoch + 1} runs")

Training cnn_2 for 100 epochs on cuda with the following params
Out channels: 2
Kernel size: 5
Stride: 2
Padding: 1


  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 1/100, Loss: 8808229.297058824


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 2/100, Loss: 1705317.505882353


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 3/100, Loss: 1387597.4


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 4/100, Loss: 971960.0647058823


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 5/100, Loss: 658568.1470588235


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 6/100, Loss: 485027.52352941176


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 7/100, Loss: 400308.00588235294


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 8/100, Loss: 316328.3661764706


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 9/100, Loss: 173369.8323529412


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 10/100, Loss: 115928.20735294117
Saved model after 10 runs


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 11/100, Loss: 109269.25863970588


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 12/100, Loss: 62873.05863970588


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 13/100, Loss: 66162.46029411764


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 14/100, Loss: 69455.75505514706


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 15/100, Loss: 46032.34926470588


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 16/100, Loss: 46046.67261029412


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 17/100, Loss: 53083.099172794115


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 18/100, Loss: 47493.83667279412


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 19/100, Loss: 27374.761029411766


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 20/100, Loss: 39544.416176470586
Saved model after 20 runs


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 21/100, Loss: 23325.491911764704


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 22/100, Loss: 21502.515625


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 23/100, Loss: 22969.36994485294


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 24/100, Loss: 21037.33419117647


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 25/100, Loss: 17693.852205882355


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 26/100, Loss: 11002.12794117647


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 27/100, Loss: 10624.827435661764


  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 28/100, Loss: 8269.350735294118


  0%|          | 0/680 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [9]:
# Manually checkpoint the current weights (e.g. after interrupting training).
# NOTE(review): after a KeyboardInterrupt, `epoch` is the index of the epoch that
# was still in progress, so `epoch + 1` counts that partial epoch as finished
# (in this run the file is named ..._29_epochs although 28 epochs completed).
# Confirm the intended naming.
os.makedirs("models", exist_ok=True)
torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))

In [10]:
# Inspect the loop variable left behind by the (interrupted) training cell:
# the 0-based index of the last epoch that was started.
epoch

28

In [9]:
# Test the model: count correct top-1 predictions of `combined_model` on the test set.
# The model is put in eval mode for this pass and its previous mode is restored
# afterwards, so a later training run is unaffected.
was_training = combined_model.training
combined_model.eval()
correct = 0
total = 0
try:
    with torch.no_grad():
        for images, labels in tqdm(test_loader, leave=False):
            images, labels = images.to(dev), labels.to(dev)
            outputs = combined_model(images)
            # Predicted class = index of the highest score along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            # .item() keeps `correct` a plain Python int rather than a device tensor.
            correct += (predicted == labels).sum().item()
finally:
    combined_model.train(was_training)

  0%|          | 0/170 [00:00<?, ?it/s]

In [10]:
# Report accuracy from the `correct`/`total` counters of the evaluation cell.
# float() accepts either a Python int or a 0-d tensor; an empty test set yields nan
# instead of a ZeroDivisionError.
accuracy = 100 * float(correct) / total if total else float("nan")
print(f"Accuracy on the {total} test images: {accuracy:.2f}%")

Accuracy on the 1360 test images: 53.45588302612305%


In [15]:
# Reload a saved checkpoint for evaluation. The architecture must match the one
# the checkpoint was trained with, so both the model type and the file name are
# derived from one place (previously the model used RETINA_MODEL while the file
# was hard-coded to a cnn_1 checkpoint).
CHECKPOINT_RETINA_MODEL = "cnn_1"
CHECKPOINT_EPOCHS = 51
checkpoint_path = os.path.join("models", f"model_{CHECKPOINT_RETINA_MODEL}_{CHECKPOINT_EPOCHS}_epochs.pth")

model = CombinedModel(adj_matrix, neurons=nodes, retina_model_type=CHECKPOINT_RETINA_MODEL)
# map_location lets a checkpoint saved from CUDA load on a CPU-only machine.
state_dict = torch.load(checkpoint_path, map_location=dev)
# strict=False tolerates key mismatches, but report them: skipped keys would leave
# those weights at their random initialisation and make the accuracy meaningless.
incompatible = model.load_state_dict(state_dict, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(f"Warning: missing keys {incompatible.missing_keys}, "
          f"unexpected keys {incompatible.unexpected_keys}")
model = model.to(dev)
model.eval()

CombinedModel(
  (retina_model): RetinaModel(
    (conv_layer): Conv2d(3, 1, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (activation): ReLU(inplace=True)
    (pooling_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (connectome_model): ConnectomeNetwork(
    (retina_layer): Linear(in_features=65025, out_features=2952, bias=True)
    (rational_layer): Linear(in_features=2952, out_features=10, bias=True)
  )
)

In [16]:
# Test the model: count correct top-1 predictions of the reloaded `model` on the
# test set. Eval mode is (re)applied so the cell is correct on its own, and the
# previous mode is restored afterwards.
was_training = model.training
model.eval()
correct = 0
total = 0
try:
    with torch.no_grad():
        for images, labels in tqdm(test_loader, leave=False):
            images, labels = images.to(dev), labels.to(dev)
            outputs = model(images)
            # Predicted class = index of the highest score along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            # .item() keeps `correct` a plain Python int rather than a device tensor.
            correct += (predicted == labels).sum().item()
finally:
    model.train(was_training)

  0%|          | 0/170 [00:00<?, ?it/s]

In [17]:
# Report accuracy of the reloaded model from the `correct`/`total` counters.
# float() accepts either a Python int or a 0-d tensor; an empty test set yields nan
# instead of a ZeroDivisionError.
accuracy = 100 * float(correct) / total if total else float("nan")
print(f"Accuracy on the {total} test images: {accuracy:.2f}%")

Accuracy on the 1360 test images: 53.602943420410156%
