In [1]:
import warnings
import torch
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

import config
from utils import get_image_paths
from complete_training_data_processing import CompleteModelsDataProcessor
from graph_models import FullGraphModel
from from_retina_to_connectome_utils import (
    select_random_images,
    initialize_results_df,
    clean_model_outputs,
    update_results_df,
    update_running_loss,
)

from wandb_utils import WandBLogger

# Ignore the "invalid value encountered in cast" RuntimeWarning when it is
# raised from wandb's image data-type module.
warnings.filterwarnings(
    action="ignore",
    category=RuntimeWarning,
    module="wandb.sdk.data_types.image",
    message="invalid value encountered in cast",
)

# Seed torch's global random number generator.
torch.manual_seed(1234)

# Notebook-level settings; everything except ommatidia_size is read from the
# project config module.
ommatidia_size = 8
training_images_dir = config.TRAINING_DATA_DIR
small, small_length = config.small, config.small_length

  _C._set_default_tensor_type(t)


In [3]:
# Gather the training inputs: the images returned by get_image_paths and the
# processor that turns image batches into model inputs.
# `small` / `small_length` are forwarded untouched; presumably they restrict
# the run to a subset of the images -- confirm in utils.get_image_paths.
training_images = get_image_paths(training_images_dir, small, small_length)
data_processor = CompleteModelsDataProcessor()

In [4]:
# Build the graph model. Its input size and edge count are read from the
# data processor's synaptic matrix; the remaining options come from config.
model = FullGraphModel(
    input_shape=data_processor.synaptic_matrix.shape[0],
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    log_transform_weights=config.log_transform_weights,
    batch_size=config.batch_size,
    dtype=config.dtype,
    num_edges=data_processor.synaptic_matrix.nnz,
    retina_connection=False,
)
model = model.to(config.DEVICE)

# Optimisation setup: binary cross-entropy on logits, Adam over all parameters.
criterion = BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.base_lr)

# Book-keeping state consumed and updated by the training loop.
results = initialize_results_df()
already_selected = []
running_loss = 0
total_correct = 0
total = 0

In [5]:
# Train the model for one pass over the training images, logging running
# metrics to Weights & Biases after every batch.
#
# NOTE(review): `results`, `already_selected` and the running counters are
# initialised in the model-setup cell, so re-running only this cell keeps
# accumulating onto the previous totals.
wandb_logger = WandBLogger("adult_complete")
wandb_logger.initialize()

# The try/finally guarantees that the W&B run is closed and the CUDA cache is
# released even when a batch raises part-way through training.
try:
    model.train()
    # In debugging mode run a fixed number of batches; otherwise run as many
    # full batches as the image list provides.
    iterations = (
        config.debug_length
        if config.debugging
        else len(training_images) // config.batch_size
    )
    for i in tqdm(range(iterations)):
        batch_files, already_selected = select_random_images(
            training_images, config.batch_size, already_selected
        )
        inputs, labels = data_processor.process_batch(batch_files)

        # Standard optimisation step: reset grads, forward, loss, backward, update.
        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()

        # Calculate run parameters
        outputs, predictions, labels_cpu, correct = clean_model_outputs(out, labels)
        results = update_results_df(
            results, batch_files, outputs, predictions, labels_cpu, correct
        )
        running_loss += update_running_loss(loss, inputs)
        # Assumes every batch holds exactly config.batch_size samples.
        total += config.batch_size
        total_correct += correct.sum()

        wandb_logger.log_metrics(i, running_loss, total_correct, total, results)

    if total > 0:
        print(
            f"Finished training with loss {running_loss / total} and accuracy {total_correct / total}"
        )
    else:
        # No batch was processed (e.g. fewer images than one batch), so the
        # averages are undefined; avoid dividing by zero.
        print(
            "Finished training without processing any batch; "
            "loss and accuracy are undefined"
        )
finally:
    torch.cuda.empty_cache()
    wandb_logger.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33meudald[0m. Use [1m`wandb login --relogin`[0m to force relogin


100%|██████████| 2/2 [00:03<00:00,  1.58s/it]


Finished training with loss 46482.807758152485 and accuracy 0.5


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁█
epoch,▁▁
iteration,▁█
loss,█▁

0,1
accuracy,0.5
epoch,0.0
iteration,2.0
loss,46482.80776


coses a fer:
- canviar les cel·les de voronoi a cada batch
- fer múltiples passades per veure la mateixa imatge amb més d'una tessel·lació
- escollir menys neurones pel decision making

+---------------------------------+----------------------------------------------------+----------------+----------+
| Layer                           | Input Shape                                        | Output Shape   | #Param   |
|---------------------------------+----------------------------------------------------+----------------+----------|
| FullGraphModel                  |                                                    | [32]           | 2        |
| ├─(connectome)TrainableEdgeConv | [2133056, 1], [2, 39398144], [39398144], [2133056] | [2133056, 1]   | --       |
| ├─(final_fc)Linear              | [32, 1, 1]                                         | [32, 1, 1]     | 2        |
+---------------------------------+----------------------------------------------------+----------------+----------+
