In [2]:
import os

import pandas as pd
from torch import nn, optim, device, cuda
import torch

import yaml
from tqdm.notebook import tqdm, trange

import logging
import wandb
from torch.utils.tensorboard import SummaryWriter

from data_parser import adj_matrix, nodes
from image_parser import train_loader, test_loader, debug_loader
from utils import log_training_images

from models import CombinedModel
from model_config_manager import ModelConfigManager
from model_manager import ModelManager

import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
# Load the run configuration. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open("config.yml") as config_file:
    config = yaml.safe_load(config_file)

DEBUG = config["DEBUG"]
# Debug runs are capped at 2 epochs regardless of the configured value.
epochs = config["EPOCHS"] if not DEBUG else 2
RETINA_MODEL = config["RETINA_MODEL"]
images_fraction = config["IMAGES_FRACTION"]
continue_training = config["CONTINUE_TRAINING"]
saved_model_path = config["SAVED_MODEL_PATH"]
model_name = config["SAVED_MODEL_NAME"]
save_every = config["SAVE_EVERY"]
connectome_layer_number = config["CONNECTOME_LAYER_NUMBER"]

# Debug runs iterate over the debug loader instead of the training loader.
loader = debug_loader if DEBUG else train_loader

# Root logging goes to a file; the named logger additionally echoes to the console.
logging.basicConfig(
    filename="training_log.log",
    level=logging.DEBUG if DEBUG else logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("training_log")
# Re-running this cell must not stack another console handler, otherwise every
# message is printed once per execution of the cell. `type(...) is` is used
# because FileHandler subclasses StreamHandler.
if not any(type(handler) is logging.StreamHandler for handler in logger.handlers):
    logger.addHandler(logging.StreamHandler())

dev = device("cuda" if cuda.is_available() else "cpu")
if dev.type == "cpu":
    logger.warning("WARNING: Running on CPU, so it might be slow")

# Create the ModelConfigManager from the loaded configuration
config_manager = ModelConfigManager(config)

# Select the configuration of the retina model named in the config file
config_manager.set_model_config(RETINA_MODEL)

In [4]:
# Build the network from the adjacency matrix and node list imported from
# data_parser, using the model configuration currently selected in config_manager.
combined_model = CombinedModel(adj_matrix, neurons=nodes, model_config=config_manager.current_model_config)

# Saving and loading manager
# NOTE(review): clean_previous=True presumably removes checkpoints from earlier runs — confirm in ModelManager.
model_manager = ModelManager(config, clean_previous = True)

if continue_training:
    # Resume from a previously saved model: weights are loaded before the
    # model is moved to the target device below.
    model_manager.load_model(combined_model, saved_model_path)

combined_model = combined_model.to(dev)

# Logs
# Weights & Biases logging can be switched off here when the service misbehaves;
# every wandb call in this notebook is guarded by `if wb`.
wb = True
if wb:
    wandb.init(project="connectome", config=config_manager.current_model_config)
# TensorBoard writer (default arguments); the training cell uses it to record the model graph.
tensorboard_writer = SummaryWriter()

# Specify the loss function and the optimizer
# NLLLoss expects log-probabilities as input — assumes CombinedModel's forward returns them; TODO confirm.
criterion = nn.NLLLoss()
optimizer = optim.Adam(combined_model.parameters(), lr=0.00001)

if wb:
    # Let wandb track the model (log="all") together with the criterion.
    _ = wandb.watch(combined_model, criterion, log="all") 


[34m[1mwandb[0m: Currently logged in as: [33meudald[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [5]:
# Show the selected model configuration and flag any non-standard run settings.
config_manager.output_model_details()
if DEBUG:
    # NOTE(review): the 10% figure is hard-coded in the message; confirm it matches debug_loader.
    logger.warning("WARNING: Running on DEBUG mode, so using 10% of the images")
elif images_fraction < 1:
    logger.warning(f"WARNING: Using {images_fraction * 100}% of the images")
if continue_training:
    # Message fixed: it previously read "training and already trained model".
    logger.warning("Warning: I'm training an already trained model")

Model configurations:
Model name: cnn_1
Number of connectome layers: 2
Number of retina layers: 1
Output channels: 1
Kernel size: 5
Stride: 2
Padding: 1




In [None]:
# Training loop: one optimisation pass over `loader` per epoch, with per-batch
# metrics sent to wandb (when enabled) and periodic checkpoints.
for epoch in trange(epochs):
    running_loss = 0

    # Log model architecture to tensorboard (once, on the first epoch).
    # NOTE(review): the dummy input assumes the model takes 3x512x512 images — confirm against the loaders.
    if epoch == 0:
        tensorboard_writer.add_graph(combined_model, torch.rand(1, 3, 512, 512).to(dev), verbose=False)

    # `loader` is debug_loader in DEBUG mode and train_loader otherwise.
    for images, labels in tqdm(loader):

        # Move images and labels to the device
        images, labels = images.to(dev), labels.to(dev)

        # Forward pass
        outputs = combined_model(images)

        # Compute the loss
        loss = criterion(outputs, labels)

        # Compute the accuracy of this batch
        predicted_labels = torch.argmax(outputs, dim=1)
        batch_correct = (predicted_labels == labels).sum().item()
        accuracy = batch_correct / len(labels)

        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update loss (read the scalar once and reuse it for logging)
        batch_loss = loss.item()
        running_loss += batch_loss

        # Logs to wandb
        if wb:
            wandb.log({"loss": batch_loss, "accuracy": accuracy, "epoch": epoch})
        # NOTE(review): called even when wb is False; confirm log_training_images does not require an active wandb run.
        log_training_images(images, labels, outputs)

    # Average over the loader that was actually iterated. Dividing by
    # len(train_loader) under-reported the loss whenever DEBUG swapped in debug_loader.
    logger.info(f'Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(loader)}')

    # Save model
    if (epoch + 1) % save_every == 0:
        model_manager.save_model(combined_model, epoch)

# Close logs
if wb:
    wandb.finish()
tensorboard_writer.close()

# Clean intermediate models
model_manager.clean_previous_runs()

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  check = torch.cuda.FloatTensor(1).fill_(0)
Epoch 1/100, Loss: 3.8950953148227416


  0%|          | 0/1000 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:01.723958, resuming normal operation.
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 encountered ({"error":"rate limit exceeded"}), retrying request
[34m[1mwandb[0m: 429 enco

  0%|          | 0/1000 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:01.336014, resuming normal operation.
[34m[1mwandb[0m: Network error resolved after 0:00:01.418141, resuming normal operation.
[34m[1mwandb[0m: Network error resolved after 0:00:03.789109, resuming normal operation.
Epoch 3/100, Loss: 0.5422907839640975


  0%|          | 0/1000 [00:00<?, ?it/s]

[34m[1mwandb[0m: Network error resolved after 0:00:07.944725, resuming normal operation.
[34m[1mwandb[0m: Network error resolved after 0:00:07.705278, resuming normal operation.
Epoch 4/100, Loss: 0.43888491443917155


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 5/100, Loss: 0.35932101981155573


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 6/100, Loss: 0.30743194065010176


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 7/100, Loss: 0.2376996778359171


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 8/100, Loss: 0.20656843411945738


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 9/100, Loss: 0.17661806338914904


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 10/100, Loss: 0.14327574405234192
Saved model after 10 runs


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 11/100, Loss: 0.09521678857006191


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 12/100, Loss: 0.11585164472763063


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 13/100, Loss: 0.10478019039016635


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 14/100, Loss: 0.06836845326331445


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 15/100, Loss: 0.05914963753125448


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 16/100, Loss: 0.06734727846351597


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 17/100, Loss: 0.07954374932516874


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 18/100, Loss: 0.040776129772599916


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 19/100, Loss: 0.03471360785443761


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 20/100, Loss: 0.038702972921278285
Saved model after 20 runs


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 21/100, Loss: 0.051380980976086815


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 22/100, Loss: 0.03865878658497902


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 23/100, Loss: 0.03232115250509607


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 24/100, Loss: 0.03534471280793535


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 25/100, Loss: 0.0382659070627692


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 26/100, Loss: 0.030802768591195785


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 27/100, Loss: 0.021069424193701402


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 28/100, Loss: 0.027712174238426063


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 29/100, Loss: 0.031322927203255226


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 30/100, Loss: 0.030312454633997645
Saved model after 30 runs


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 31/100, Loss: 0.023013409793225276


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 32/100, Loss: 0.01762182297586243


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 33/100, Loss: 0.015133971230845328


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 34/100, Loss: 0.023709632054141262


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 35/100, Loss: 0.0247464727926882


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 36/100, Loss: 0.0239291314337145


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 37/100, Loss: 0.025615768576564914


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 38/100, Loss: 0.019999024732209717


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 39/100, Loss: 0.009393984254222684


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch 40/100, Loss: 0.013393875003063596
Saved model after 40 runs


  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# torch.save(combined_model.state_dict(), os.path.join("models", f"model_{RETINA_MODEL}_{epoch + 1}_epochs.pth"))

In [6]:
# Test de model
correct = 0
total = 0
test_results_df = pd.DataFrame(columns=["Image", "Real Label", "Predicted Label", "Correct Prediction"])

j = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader):
        images, labels = images.to(dev), labels.to(dev)
        outputs = combined_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum()
        # Convert the tensor values to CPU and numpy
        labels_cpu = labels.cpu().numpy()
        predicted_cpu = predicted.cpu().numpy()
        
        # Check if the prediction is correct
        correct_predictions = (predicted == labels)
        correct_cpu = correct_predictions.cpu().numpy()
        
        image_names = [a[0] for a in test_loader.dataset.dataset.samples[j * test_loader.batch_size: (j + 1) * test_loader.batch_size]]
        j += 1
        
        batch_df = pd.DataFrame({
            "Image": image_names,
            "Real Label": labels_cpu,
            "Predicted Label": predicted_cpu,
            "Correct Prediction": correct_cpu.astype(int)
        })
        
        # Append the batch DataFrame to the list
        test_results_df = pd.concat([test_results_df, batch_df], ignore_index=True)

logger.info(f"Accuracy on the {total} test images: {100 * correct / total}%")

  0%|          | 0/250 [00:00<?, ?it/s]

Accuracy on the 2000 test images: 53.500003814697266%
[34m[1mwandb[0m: While tearing down the service manager. The following error has occurred: [WinError 10054] An existing connection was forcibly closed by the remote host


In [None]:
# Calculate the percentage of correct answers for each Weber ratio
test_results_df['yellow'] = test_results_df['Image'].apply(lambda x: x.split('_')[1])
test_results_df['blue'] = test_results_df['Image'].apply(lambda x: x.split('_')[2])
test_results_df['weber_ratio'] = test_results_df.apply(lambda row: max(int(row['yellow']), int(row['blue'])) / min(int(row['yellow']), int(row['blue'])), axis=1)
correct_percentage = test_results_df.groupby('weber_ratio')['Correct Prediction'].mean() * 100

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(x=correct_percentage.index, y=correct_percentage.values)
plt.xlabel('Weber Ratio')
plt.ylabel('Percentage of Correct Answers')
plt.title('Percentage of Correct Answers for Each Weber Ratio')
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


In [None]:
# Display the overall test accuracy in percent, from the `correct` and `total`
# counters accumulated by the evaluation cell.
100 * correct / total