In [1]:
import random
import warnings

import numpy as np
import pandas as pd
import torch
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

import config
from utils import get_image_paths

from graph_models import FullGraphModel
from from_retina_to_connectome_utils import (
    select_random_images,
    paths_to_labels,
    initialize_results_df,
    clean_model_outputs,
    update_results_df,
    update_running_loss,
)
from scipy.sparse import load_npz
from complete_training_funcs import (
    get_voronoi_cells,
    import_images,
    process_images,
    get_voronoi_averages,
    assign_cell_type,
    get_neuron_activations,
    get_side_decision_making_vector,
)
from torch_geometric.data import Data, Batch
from wandb_utils import WandBLogger

# Silence a known, harmless cast warning emitted by wandb's image logging.
warnings.filterwarnings(
    "ignore",
    message="invalid value encountered in cast",
    category=RuntimeWarning,
    module="wandb.sdk.data_types.image",
)

# Seed every RNG the pipeline may draw from. torch.manual_seed alone does not
# seed Python's `random` module or NumPy's global generator, and the batch
# sampling helpers (e.g. select_random_images, not visible here) may rely on
# either of them. Seeding all three keeps Restart & Run All reproducible.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

training_images_dir = config.TRAINING_DATA_DIR
# NOTE(review): `small` is not used by the visible cells; the data-loading
# cell passes a literal False to get_image_paths — confirm which is intended.
small = config.small
small_length = config.small_length
# NOTE(review): not referenced by any visible cell; magic number kept as-is.
ommatidia_size = 8

  _C._set_default_tensor_type(t)


In [2]:
# Load the right-hemisphere visual neurons and the connectome inputs.
# TODO: move this data-loading logic out of the notebook.
right_visual = pd.read_csv("adult_data/right_filtered.csv").drop(columns=["side", "x"])

# get_voronoi_cells yields two things: a per-neuron value that is stored below
# as the `voronoi_indices` column, and `voronoi_indices`, which the training
# loop later passes to process_images.
neuron_indices, voronoi_indices = get_voronoi_cells(right_visual)

# Attach the per-neuron Voronoi index, derive the cell type row by row (the
# lambda sees the frame with the new column already present), then drop the
# remaining coordinate columns.
right_visual = right_visual.assign(
    voronoi_indices=neuron_indices,
    cell_type=lambda frame: frame.apply(assign_cell_type, axis=1),
).drop(columns=["y", "z"])

right_root_ids = pd.read_csv("adult_data/right_root_id_to_index.csv")
decision_making_vector = get_side_decision_making_vector(right_root_ids, "right")
synaptic_matrix = load_npz("adult_data/right_synaptic_matrix.npz")

# NOTE(review): the second argument is a literal False rather than the
# `small` value read from config in the setup cell — confirm this is intended.
training_images = get_image_paths(training_images_dir, False, small_length)

In [3]:
# Build the connectome model; its input size is the row count of the
# synaptic matrix loaded in the data cell.
model_kwargs = dict(
    input_shape=synaptic_matrix.shape[0],
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=decision_making_vector,
    log_transform_weights=config.log_transform_weights,
    batch_size=config.batch_size,
    dtype=config.dtype,
    retina_connection=False,
)
model = FullGraphModel(**model_kwargs).to(config.DEVICE)

# BCEWithLogitsLoss applies the sigmoid itself, so the model presumably emits
# raw logits for a binary decision.
criterion = BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.base_lr)

# Training bookkeeping, mutated by the training-loop cell.
# Re-run this cell before re-running that cell to start a fresh run.
results = initialize_results_df()
already_selected = []
running_loss = total_correct = total = 0

In [4]:
# Train the connectome model over randomly sampled batches of training images.
# NOTE: this cell mutates `results`, `already_selected`, `running_loss`,
# `total_correct` and `total`, which are initialised in the model-setup cell;
# re-run that cell first to start a fresh run instead of continuing the old one.
wandb_logger = WandBLogger("adult_complete")
wandb_logger.initialize()

# The connectome topology and synaptic weights do not change between batches,
# so build these tensors once instead of on every iteration. They are only
# read below (shared by every graph), never modified.
edges = torch.tensor(
    np.array([synaptic_matrix.row, synaptic_matrix.col]),
    # Note: the edges need to be specifically int64
    dtype=torch.int64,
)
weights = torch.tensor(synaptic_matrix.data, dtype=config.dtype)

model.train()
iterations = (
    config.debug_length
    if config.debugging
    else len(training_images) // config.batch_size
)
try:
    for i in tqdm(range(iterations)):
        # `already_selected` is threaded through so the sampler can
        # (presumably) avoid re-drawing images within this run.
        batch_files, already_selected = select_random_images(
            training_images, config.batch_size, already_selected
        )
        labels = paths_to_labels(batch_files)

        # Retina stage: images -> Voronoi-cell averages -> neuron activations,
        # one activation column per image in the batch.
        imgs = import_images(batch_files)
        processed_imgs = process_images(imgs, voronoi_indices)
        voronoi_averages = get_voronoi_averages(processed_imgs)
        neuron_activations = pd.concat(
            [get_neuron_activations(right_visual, a) for a in voronoi_averages],
            axis=1,
        )
        # Re-key activations by the `index` column of right_root_ids
        # (presumably the node index in the synaptic matrix); neurons without
        # an activation get 0.
        activation_df = (
            right_root_ids.merge(
                neuron_activations, left_on="root_id", right_index=True, how="left"
            )
            .fillna(0)
            .set_index("index")
            .drop(columns=["root_id"])
        )
        activation_tensor = torch.tensor(activation_df.values, dtype=config.dtype)

        # One graph per image: shared topology, image-specific node features.
        graph_list_ = []
        for j in range(activation_tensor.shape[1]):
            # Shape [num_nodes, 1], one feature per node
            node_features = activation_tensor[:, j].unsqueeze(1)
            graph = Data(
                x=node_features,
                edge_index=edges,
                edge_attr=weights,
                y=labels[j],
                # NOTE(review): PyG's Data likely stores extra kwargs as plain
                # attributes, so this probably does not move any tensors; the
                # .to() on the batch below does — confirm and consider removing.
                device=config.DEVICE,
            )
            graph_list_.append(graph)

        inputs = Batch.from_data_list(graph_list_).to(config.DEVICE)
        label_tensor = torch.tensor(labels, dtype=config.dtype).to(config.DEVICE)

        optimizer.zero_grad()
        out = model(inputs)
        loss = criterion(out, label_tensor)
        loss.backward()
        optimizer.step()

        # Calculate run parameters
        outputs, predictions, labels_cpu, correct = clean_model_outputs(
            out, label_tensor
        )
        results = update_results_df(
            results, batch_files, outputs, predictions, labels_cpu, correct
        )
        running_loss += update_running_loss(loss, inputs)
        # Count the images actually in this batch rather than assuming a full
        # config.batch_size, so averages stay correct for a short batch.
        total += len(batch_files)
        total_correct += correct.sum()

        wandb_logger.log_metrics(i, running_loss, total_correct, total, results)

    print(
        f"Finished training with loss {running_loss / total} and accuracy {total_correct / total}"
    )
finally:
    # Release cached GPU memory and close the wandb run even if training raises
    # or is interrupted, so the run is not left open.
    torch.cuda.empty_cache()
    wandb_logger.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33meudald[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112425700007911, max=1.0…

100%|██████████| 750/750 [42:11<00:00,  3.37s/it]


Finished training with loss 46232.714590812844 and accuracy 0.503


VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
accuracy,▁▄▆▇█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
iteration,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▅▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
accuracy,0.503
epoch,0.0
iteration,32.0
loss,46232.71459


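Coses a fer (TODO):
- canviar les cel·les de Voronoi a cada batch (change the Voronoi cells on every batch)
- fer múltiples passades per veure la mateixa imatge amb més d'una tessel·lació (do multiple passes so that each image is seen with more than one tessellation)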