In [1]:
import os

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from torch import nn, optim, device, cuda
import torch

import yaml
from tqdm.notebook import tqdm, trange

import wandb

from data_parser import adj_matrix, nodes
from image_parser import train_loader, test_loader, debug_loader
from utils import log_training_images

from models import CombinedModel
from model_config_manager import ModelConfigManager

In [2]:
# Load the run settings from config.yml. The context manager closes the file
# handle as soon as parsing is done instead of leaving it to the garbage collector.
with open("config.yml") as config_file:
    config = yaml.safe_load(config_file)

DEBUG = config["DEBUG"]
# DEBUG runs are capped at 2 epochs regardless of the configured EPOCHS
epochs = 2 if DEBUG else config["EPOCHS"]
RETINA_MODEL = config["RETINA_MODEL"]
images_fraction = config["IMAGES_FRACTION"]
continue_training = config["CONTINUE_TRAINING"]
saved_model_path = config["SAVED_MODEL_PATH"]
model_name = config["MODEL_NAME"]

# DEBUG runs iterate over debug_loader instead of train_loader
loader = debug_loader if DEBUG else train_loader

# Use the GPU when one is available, otherwise fall back to the CPU
dev = device("cuda" if cuda.is_available() else "cpu")

# Create the ModelConfigManager
config_manager = ModelConfigManager()

# Select the configuration for the retina model named in config.yml
config_manager.set_model_config(RETINA_MODEL)

In [3]:
# Build the model from the connectome adjacency matrix and the selected model config
combined_model = CombinedModel(adj_matrix, neurons=nodes, model_config=config_manager.model_config)

if continue_training:
    # Resume from a previously saved state dict. map_location remaps the stored
    # tensors onto the current device, so a checkpoint written on a GPU machine
    # can also be loaded on a CPU-only one.
    # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
    combined_model.load_state_dict(torch.load(saved_model_path, map_location=dev))

combined_model = combined_model.to(dev)
wandb.init(project="connectome", config=config_manager.model_config)

# Specify the loss function and the optimizer
# NOTE(review): NLLLoss expects log-probabilities as input; presumably
# CombinedModel ends in a log-softmax - confirm against models.py.
criterion = nn.NLLLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.00001)
# Let wandb track the model (log="all"); the return value is discarded on purpose
_ = wandb.watch(combined_model, criterion, log="all")


[34m[1mwandb[0m: Currently logged in as: [33meudald[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
# Train the model: one optimizer step per batch, per-batch metrics logged to wandb,
# and a checkpoint written every 2 epochs.
config_manager.output_model_details()
if DEBUG:
    print("WARNING: Running on DEBUG mode, so using 10% of the images")
elif images_fraction < 1:
    print(f"WARNING: Using {images_fraction * 100}% of the images")

# The epoch loss must be averaged over the loader that is actually iterated
# (debug_loader in DEBUG mode), not unconditionally over train_loader.
num_batches = len(loader)

# Checkpoints are written into this directory; create it up front so that
# torch.save does not fail after a full epoch of training.
checkpoint_dir = "models"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in trange(epochs):
    # Make sure the model is in training mode, e.g. when this cell is re-run
    # after an evaluation cell switched it to eval mode.
    combined_model.train()
    running_loss = 0
    # tqdm gives per-batch progress; iterate the loader directly if the bar is not needed
    for images, labels in tqdm(loader):

        # Move images and labels to the device
        images, labels = images.to(dev), labels.to(dev)

        # Forward pass
        outputs = combined_model(images)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Compute the accuracy of this batch
        predicted_labels = torch.argmax(outputs, dim=1)
        correct_predictions = (predicted_labels == labels).sum().item()
        accuracy = correct_predictions / len(labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update loss
        batch_loss = loss.item()
        running_loss += batch_loss

        # Logs to wandb
        wandb.log({"loss": batch_loss, "accuracy": accuracy, "epoch": epoch})
        log_training_images(images, labels, outputs)

    # max(..., 1) avoids a ZeroDivisionError if the loader yields no batches
    print(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/max(num_batches, 1)}')
    # Save model every 2 epochs
    if (epoch + 1) % 2 == 0:
        torch.save(combined_model.state_dict(), os.path.join(checkpoint_dir, f"model_{RETINA_MODEL}_{model_name}_{epoch + 1}_epochs.pth"))
        print(f"Saved model after {epoch + 1} runs")

wandb.finish()

Model configurations:
Model name: vgg16
This is a pretrained model


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
# torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))

In [18]:
# Evaluate the model on the test set and collect one result row per image
correct = 0
total = 0
batch_frames = []

# Switch layers such as dropout / batch norm to inference behaviour
combined_model.eval()

# NOTE(review): the image names are taken from consecutive slices of the
# underlying dataset's `samples` list; this assumes test_loader is not shuffled
# and walks those samples in order - confirm against image_parser.
all_samples = test_loader.dataset.dataset.samples
batch_size = test_loader.batch_size

with torch.no_grad():
    for batch_idx, (images, labels) in enumerate(tqdm(test_loader)):
        images, labels = images.to(dev), labels.to(dev)
        outputs = combined_model(images)
        predicted = torch.argmax(outputs, dim=1)

        # Element-wise correctness of this batch
        is_correct = predicted == labels
        total += labels.size(0)
        # .item() keeps `correct` a plain Python int instead of a tensor
        correct += is_correct.sum().item()

        # First element of each sample entry is used as the image name
        batch_samples = all_samples[batch_idx * batch_size: (batch_idx + 1) * batch_size]
        image_names = [sample[0] for sample in batch_samples]

        batch_frames.append(pd.DataFrame({
            "Image": image_names,
            "Real Label": labels.cpu().numpy(),
            "Predicted Label": predicted.cpu().numpy(),
            "Correct Prediction": is_correct.cpu().numpy().astype(int)
        }))

# Concatenate once after the loop: avoids re-copying the growing frame on every
# batch and keeps the numeric dtypes of the per-batch frames.
if batch_frames:
    test_results_df = pd.concat(batch_frames, ignore_index=True)
else:
    test_results_df = pd.DataFrame(columns=["Image", "Real Label", "Predicted Label", "Correct Prediction"])

In [19]:
# `correct` may be a plain number or a 0-dim tensor (it is accumulated from
# tensor sums); float() normalizes both so the message shows a plain percentage
# rather than a tensor repr. With no test images the accuracy is reported as nan.
accuracy_pct = 100 * float(correct) / total if total else float("nan")
print(f"Accuracy on the {total} test images: {accuracy_pct:.2f}%")

In [20]:
# Derive the yellow / blue counts and their Weber ratio from the 'Image' column
# of test_results_df.
# NOTE(review): assumes the file names look like "<prefix>_<yellow>_<blue>_..."
# with integer counts in positions 1 and 2 - confirm against the dataset.
# Only the file name is split, so underscores in directory names cannot shift the fields.
name_parts = test_results_df['Image'].map(os.path.basename).str.split('_')
test_results_df['yellow'] = name_parts.str[1]
test_results_df['blue'] = name_parts.str[2]

# Parse the counts once and compute larger / smaller per row.
# A count of zero yields inf here instead of raising ZeroDivisionError.
dot_counts = test_results_df[['yellow', 'blue']].astype(int)
test_results_df['weber_ratio'] = dot_counts.max(axis=1) / dot_counts.min(axis=1)

In [21]:
# Calculate the percentage of correct answers for each Weber ratio.
# NOTE(review): depending on how test_results_df was assembled, 'Correct Prediction'
# may be object-typed; the cast keeps the mean numeric either way.
correct_percentage = (
    test_results_df
    .astype({'Correct Prediction': int})
    .groupby('weber_ratio')['Correct Prediction']
    .mean()
    * 100
)

# Create a bar plot on an explicit Axes rather than the implicit pyplot state
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(x=correct_percentage.index, y=correct_percentage.values, ax=ax)
ax.set_xlabel('Weber Ratio')
ax.set_ylabel('Percentage of Correct Answers')
ax.set_title('Percentage of Correct Answers for Each Weber Ratio')
ax.tick_params(axis='x', labelrotation=45)
fig.tight_layout()

# Show the plot
plt.show()


In [9]:
# Evaluate the model on the test set (accuracy only, no per-image table)
correct = 0
total = 0

# Switch layers such as dropout / batch norm to inference behaviour
combined_model.eval()

with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(dev), labels.to(dev)
        # The only model defined in this notebook is `combined_model`; the former
        # reference to an undefined `model` fails on a fresh kernel.
        outputs = combined_model(images)
        predicted = torch.argmax(outputs, dim=1)
        total += labels.size(0)
        # .item() keeps `correct` a plain Python int instead of a tensor
        correct += (predicted == labels).sum().item()

In [10]:
# `correct` may be a plain number or a 0-dim tensor (it is accumulated from
# tensor sums); float() normalizes both so the message shows a plain percentage
# rather than a tensor repr. With no test images the accuracy is reported as nan.
accuracy_pct = 100 * float(correct) / total if total else float("nan")
print(f"Accuracy on the {total} test images: {accuracy_pct:.2f}%")