In [1]:
import os

from torch import nn, optim, device, cuda
import torch

import yaml
from tqdm.notebook import tqdm, trange

import wandb

from data_parser import adj_matrix, nodes
from image_parser import train_loader, test_loader, debug_loader
from utils import log_training_images

from models import CombinedModel
from model_config_manager import ModelConfigManager

In [2]:
config = yaml.safe_load(open("config.yml"))
DEBUG = config["DEBUG"]
epochs = config["EPOCHS"] if not DEBUG else 2
RETINA_MODEL = config["RETINA_MODEL"]
images_fraction = config["IMAGES_FRACTION"]
    
loader = debug_loader if DEBUG else train_loader

# Create the ModelConfigManager and load configurations from YAML files
config_manager = ModelConfigManager()

# Get a specific configuration by model name
config_manager.set_model_config(RETINA_MODEL)

wandb.init(project="connectome", config=config_manager.model_config)

[34m[1mwandb[0m: Currently logged in as: [33meudald[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
# Initialize the combined model
combined_model = CombinedModel(adj_matrix, neurons=nodes, model_config=config_manager.model_config)
dev = device("cuda" if cuda.is_available() else "cpu")
combined_model = combined_model.to(dev)

# Specify the loss function and the optimizer
criterion = nn.NLLLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.00001)
_ = wandb.watch(combined_model, criterion, log="all") 

In [None]:
config_manager.output_model_details()
if DEBUG:
    print("WARNING: Running on DEBUG mode, so using 10% of the images")
elif images_fraction < 1:
    print(f"WARNING: Using {images_fraction * 100}% of the images")

for epoch in trange(epochs):
    running_loss = 0
    correct_predictions = 0
    # If the model is fast~ish
    # for images, labels in train_loader:
    j = 0
    for images, labels in tqdm(loader):
        # print(f"Image {j}")
        # Move images and labels to the device
        images, labels = images.to(dev), labels.to(dev)

        # Checks
        if torch.isnan(images).any():
            raise Exception("NaN in images")
        if torch.isnan(labels).any():
            raise Exception("NaN in labels")

        # Forward pass
        outputs = combined_model(images)

        # Compute the loss
        loss = criterion(outputs, labels)
        
        # Compute the accuracy
        predicted_labels = torch.argmax(outputs, dim=1)
        correct_predictions = (predicted_labels == labels).sum().item()
        accuracy = correct_predictions / len(labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        
        # Logs to wandb
        wandb.log({"loss": loss.item(), "accuracy": accuracy, "epoch": epoch})
        log_training_images(images, labels, outputs)
            
        j += 1

    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(train_loader)}')
    # Save model every 2 epochs
    if (epoch + 1) % 2 == 0:
        torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))
        print(f"Saved model after {epoch + 1} runs")

wandb.finish()

Model configurations:
Model name: cnn_2_1
Number of layers: 2
Output channels: 2
Kernel size: 5
Stride: 2
Padding: 1


  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/680 [00:00<?, ?it/s]

Epoch 1/20, Loss: 43703.58419117647


  0%|          | 0/680 [00:00<?, ?it/s]

[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encounte

Epoch 2/20, Loss: 23296.871507352942
Saved model after 2 runs


  0%|          | 0/680 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:03.694931, resuming normal operation.
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: Network error resolved after 0:00:07.849699, resuming normal operation.


Epoch 3/20, Loss: 17533.167647058825


  0%|          | 0/680 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:07.946046, resuming normal operation.
[34m[1mwandb[0m: Network error resolved after 0:00:08.030516, resuming normal operation.


Epoch 4/20, Loss: 11874.872334558824
Saved model after 4 runs


  0%|          | 0/680 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:07.960032, resuming normal operation.
[34m[1mwandb[0m: Network error resolved after 0:00:08.094106, resuming normal operation.


Epoch 5/20, Loss: 10288.868850528492


  0%|          | 0/680 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:07.814574, resuming normal operation.


In [9]:
# torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))

In [5]:
# Test de model
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(dev), labels.to(dev)
        outputs = combined_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

  0%|          | 0/170 [00:00<?, ?it/s]

In [6]:
print(f"Accuracy on the {total} test images: {100 * correct / total}%")

Accuracy on the 1360 test images: 57.0588264465332%


In [15]:
model = CombinedModel(adj_matrix, nodes, RETINA_MODEL)
model.load_state_dict(torch.load(os.path.join("models", "model_cnn_1_51_epochs.pth")), strict=False)
model = model.to(dev)
model.eval()

CombinedModel(
  (retina_model): RetinaModel(
    (conv_layer): Conv2d(3, 1, kernel_size=(5, 5), stride=(2, 2), padding=(1, 1))
    (activation): ReLU(inplace=True)
    (pooling_layer): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (connectome_model): ConnectomeNetwork(
    (retina_layer): Linear(in_features=65025, out_features=2952, bias=True)
    (rational_layer): Linear(in_features=2952, out_features=10, bias=True)
  )
)

In [16]:
# Test de model
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(dev), labels.to(dev)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()

  0%|          | 0/170 [00:00<?, ?it/s]

In [17]:
print(f"Accuracy on the {total} test images: {100 * correct / total}%")

Accuracy on the 1360 test images: 53.602943420410156%
