In [1]:
import warnings
import torch
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm

import config
from utils import get_image_paths
from complete_training_data_processing import CompleteModelsDataProcessor
from graph_models import FullGraphModel
from from_retina_to_connectome_utils import (
    select_random_images,
    initialize_results_df,
    clean_model_outputs,
    update_results_df,
    update_running_loss,
)

from wandb_utils import WandBLogger

# Ignore the "invalid value encountered in cast" RuntimeWarning that is
# attributed to wandb's image data-type module, so it does not clutter the
# cell output.
warnings.filterwarnings(
    action="ignore",
    category=RuntimeWarning,
    module="wandb.sdk.data_types.image",
    message="invalid value encountered in cast",
)

# Fix torch's global RNG seed.
# NOTE(review): only torch is seeded here; if the helpers draw from `random`
# or numpy, those streams remain unseeded — confirm.
torch.manual_seed(1234)

# Read the run settings from the project config once, up front, so later
# cells can refer to short local names.
training_images_dir = config.TRAINING_DATA_DIR
num_epochs = config.num_epochs
small, small_length = config.small, config.small_length

  _C._set_default_tensor_type(t)


In [2]:
# Collect the training image paths. `small` and `small_length` are forwarded
# unchanged from the config to the helper.
training_images = get_image_paths(training_images_dir, small, small_length)

# Build the data processor and generate an initial set of Voronoi cells.
data_processor = CompleteModelsDataProcessor()
data_processor.create_voronoi_cells()

In [3]:
# The processor's synaptic matrix sizes the model: its first dimension is
# passed as `input_shape` and its `nnz` count as `num_edges`.
synaptic_matrix = data_processor.synaptic_matrix

model = FullGraphModel(
    input_shape=synaptic_matrix.shape[0],
    num_connectome_passes=config.NUM_CONNECTOME_PASSES,
    decision_making_vector=data_processor.decision_making_vector,
    log_transform_weights=config.log_transform_weights,
    batch_size=config.batch_size,
    dtype=config.dtype,
    num_edges=synaptic_matrix.nnz,
    retina_connection=False,
)
# Move the model to the configured device before the optimizer captures its
# parameters.
model = model.to(config.DEVICE)

# Loss on raw logits, and an Adam optimizer at the configured base learning
# rate.
criterion = BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.base_lr)

In [4]:
# Train the model for `num_epochs` epochs, logging running metrics to
# Weights & Biases after every batch.


def _ratio(numerator, denominator):
    """Return numerator / denominator, or NaN when nothing was accumulated.

    Guards the summary prints against ZeroDivisionError when no batch was
    processed (e.g. fewer images than `config.batch_size`, or a debug length
    of zero).
    """
    return numerator / denominator if denominator else float("nan")


wandb_logger = WandBLogger("adult_complete")
wandb_logger.initialize()

try:
    model.train()
    # Batches per epoch: a fixed short run when debugging, otherwise as many
    # full batches as the image list allows (a trailing partial batch is
    # dropped by the integer division).
    iterations = (
        config.debug_length
        if config.debugging
        else len(training_images) // config.batch_size
    )
    results = initialize_results_df()
    # These accumulators are never reset, so the per-epoch figures printed
    # below are cumulative over all epochs run so far.
    running_loss, total_correct, total = 0, 0, 0
    for ep in range(num_epochs):
        # Images already drawn this epoch; the selection helper receives and
        # returns this list on every call.
        already_selected = []
        for i in tqdm(range(iterations)):
            batch_files, already_selected = select_random_images(
                training_images, config.batch_size, already_selected
            )
            # Recreate the Voronoi cells so every batch sees a different
            # tessellation.
            data_processor.create_voronoi_cells()
            inputs, labels = data_processor.process_batch(batch_files)

            # Standard optimisation step.
            optimizer.zero_grad()
            out = model(inputs)
            loss = criterion(out, labels)
            loss.backward()
            optimizer.step()

            # Update the running metrics and the per-image results table.
            outputs, predictions, labels_cpu, correct = clean_model_outputs(
                out, labels
            )
            results = update_results_df(
                results, batch_files, outputs, predictions, labels_cpu, correct
            )
            running_loss += update_running_loss(loss, inputs)
            total += config.batch_size
            total_correct += correct.sum()

            # NOTE(review): `i` is the batch index within the epoch and so
            # restarts at 0 every epoch — confirm that is what the logger
            # expects as its step.
            wandb_logger.log_metrics(i, running_loss, total_correct, total, results)

        print(
            f"Finished epoch {ep} with loss {_ratio(running_loss, total)} and accuracy {_ratio(total_correct, total)}"
        )
        torch.cuda.empty_cache()

    print(
        f"Finished training with loss {_ratio(running_loss, total)} and accuracy {_ratio(total_correct, total)}"
    )
finally:
    # Always close the W&B run, even if training raised part-way through.
    wandb_logger.finish()

100%|██████████| 2/2 [00:03<00:00,  1.68s/it]


Finished epoch 0 with loss 46737.92021846771 and accuracy 0.484375


100%|██████████| 2/2 [00:02<00:00,  1.39s/it]

Finished epoch 1 with loss 46379.585933983326 and accuracy 0.5078125
Finished training with loss 46379.585933983326 and accuracy 0.5078125





Coses a fer:
- canviar les cel·les de Voronoi a cada batch
- fer múltiples passades per veure la mateixa imatge amb més d'una tessel·lació
- escollir menys neurones pel decision making