In [4]:
import pickle
import sys
from copy import deepcopy
from itertools import chain
from pathlib import Path

import h5py
import numpy as np
import pandas as pd
import torch
from torch import ones, zeros, eye, as_tensor, tensor, float32
from torch import nn
from torch.optim import Adam

sys.path.append("../../../nex/rgc/utils/")
from data_utils import (
    read_data,
    build_avg_recordings,
    build_training_data,
)

In [5]:
# Record the library versions this notebook was executed with, so the saved
# results can be tied to a specific environment.
for _lib_name, _lib in (("torch", torch), ("pandas", pd), ("numpy", np)):
    print(_lib_name, _lib.__version__)

torch 2.4.0
pandas 2.2.2
numpy 1.26.4


In [6]:
import pandas as pd
from warnings import simplefilter
simplefilter(action="ignore", category=pd.errors.PerformanceWarning)

In [7]:
def eval_rho(net, data, inds, loss_weights, labels):
    """Mean per-ROI Pearson correlation between network predictions and labels.

    For every output column (ROI), only the selected datapoints whose loss
    weight is non-zero for that ROI enter the correlation.

    Args:
        net: Callable model mapping ``data[inds]`` to a 2-D prediction tensor
            with one column per ROI.
        data: Input tensor, indexed by ``inds`` along its first axis.
        inds: Index tensor selecting the datapoints to evaluate.
        loss_weights: Weight tensor indexed like ``labels``; entries that are
            non-zero mark (datapoint, ROI) pairs that count. Presumably the
            same shape as ``labels`` — TODO confirm against build_training_data.
        labels: Target tensor indexed like ``data``, one column per ROI.

    Returns:
        The mean over ROIs of the correlation coefficients. ``np.mean``
        propagates NaN, so the result is NaN whenever any single ROI's
        correlation is undefined (e.g. constant predictions over its valid
        datapoints).
    """
    # Evaluation only: avoid building an autograd graph that is never used.
    with torch.no_grad():
        predictions = net(data[inds])
    labels_split = labels[inds]
    valid = torch.as_tensor(loss_weights[inds], dtype=torch.bool)

    rhos = []
    # One correlation per ROI (output column), restricted to its valid rows.
    for roi in range(predictions.shape[1]):
        mask = valid[:, roi]
        pred_roi = predictions[mask, roi].numpy()
        label_roi = labels_split[mask, roi].numpy()
        rhos.append(np.corrcoef(pred_roi, label_roi)[0, 1])

    return np.mean(rhos)


def eval_nn(net_, data, labels, loss_weights, seed, val_frac, num_test, verbose=False):
    """Train a copy of ``net_`` and report correlations at the best validation epoch.

    The datapoints are permuted with ``torch.manual_seed(seed)``. The first
    ``num_test`` become the test set. The remainder is split into training and
    validation data, where ``val_frac`` is the validation fraction *of the
    non-test data*. A deep copy of ``net_`` is trained for 500 full-batch Adam
    steps (lr 1e-3) on a loss-weighted squared error. Every 10 epochs the
    validation correlation (``eval_rho``) is computed. Whenever it improves,
    the train and test correlations of that epoch are recorded.

    Args:
        net_: Template model; it is deep-copied and never modified.
        data: Input tensor, datapoints along the first axis.
        labels: Target tensor, indexed like ``data``.
        loss_weights: Per-(datapoint, output) weights multiplied into the loss
            and used as validity masks by ``eval_rho``.
        seed: Torch seed that fixes the data split.
        val_frac: Validation fraction of the non-test datapoints.
        num_test: Number of held-out test datapoints.
        verbose: If True, print the validation correlation at every evaluation.

    Returns:
        Tuple ``(test_rho, train_rho, val_rho)``, each averaged over splits
        (currently a single split). All three are NaN for a split in which no
        evaluation produced a finite validation correlation.

    Raises:
        ValueError: If ``num_test`` is negative or leaves no non-test data.
    """
    _ = torch.manual_seed(seed)

    num_datapoints = len(data)
    if not 0 <= num_test < num_datapoints:
        raise ValueError(
            f"num_test must be in [0, {num_datapoints}), got {num_test}"
        )

    # Validation fraction is computed without considering the test set.
    test_frac = num_test / num_datapoints
    val_frac = val_frac * (1 - test_frac)
    num_train = int(num_datapoints * (1 - val_frac - test_frac))

    num_splits = 1
    splits = []
    for _ in range(num_splits):
        permutation = torch.randperm(num_datapoints)
        splits.append((
            permutation[num_test:num_test + num_train],  # train
            permutation[num_test + num_train:],          # validation
            permutation[:num_test],                      # test
        ))

    test_rhos, train_rhos, val_rhos = [], [], []
    for train_inds, val_inds, test_inds in splits:
        net = deepcopy(net_)
        optimizer = Adam(list(net.parameters()), lr=1e-3)

        # Any finite correlation beats -inf. NaN never compares greater, so if
        # every evaluation is NaN the NaN defaults below are returned instead
        # of leaving these names unbound.
        best_rho = -np.inf
        best_rho_train = np.nan
        best_rho_test = np.nan

        for epoch in range(500):
            optimizer.zero_grad()
            predictions = net(data[train_inds])
            squared_error = (predictions - labels[train_inds]) ** 2
            loss = torch.mean(squared_error * loss_weights[train_inds])
            loss.backward()
            optimizer.step()

            if epoch % 10 == 0:
                rho = eval_rho(net, data, val_inds, loss_weights, labels)
                if verbose:
                    print(f"epoch {epoch}: validation rho {rho:.4f}")

                if rho > best_rho:
                    best_rho = rho
                    best_rho_train = eval_rho(net, data, train_inds, loss_weights, labels)
                    best_rho_test = eval_rho(net, data, test_inds, loss_weights, labels)

        test_rhos.append(best_rho_test)
        train_rhos.append(best_rho_train)
        val_rhos.append(best_rho if best_rho > -np.inf else np.nan)

    return np.mean(test_rhos), np.mean(train_rhos), np.mean(val_rhos)

In [8]:
# Per-recording result containers, keyed by recording id (`myrec`).
all_ann_accuracies, all_linreg_accuracies = {}, {}

In [26]:
################## RECORDINGS ##################
# Sweep the size of the train+validation set for a single recording (`myrec`).
# The seed-averaged train/test correlations of a linear readout and of an MLP
# are stored under `all_*_accuracies[myrec]`.
# NOTE(review): other recordings (1, 2, 3, 5, 7, 9, 13) only end up in the
# saved results if this cell is re-run by hand with a different `myrec`.
myrec = 3  # 1,2,3,5,7,9,13
rec_ids = [myrec]
cell_id = "20161028_1"
start_n_scan = 0
nseg = 4
i_amp = 0.1
test_num = 512  # size of the held-out test set, fixed across the sweep
val_frac = 0.2  # validation fraction of the non-test datapoints
n_train_sweep = [32, 64, 128, 256, 512]
num_seeds = 5

linreg_accuracies = {"train": [], "test": []}
ann_accuracies = {"train": [], "test": []}

for n_train in n_train_sweep:
    num_datapoints_per_scanfield = test_num + n_train

    stimuli, recordings, setup, noise_full = read_data(
        start_n_scan,
        num_datapoints_per_scanfield,
        cell_id,
        rec_ids,
        "noise",
        "..",
    )
    avg_recordings = build_avg_recordings(
        recordings, rec_ids, nseg, num_datapoints_per_scanfield
    )

    ################## DATASET ##################
    number_of_recordings_each_scanfield = list(avg_recordings.groupby("rec_id").size())
    print(f"number_of_recordings_each_scanfield {number_of_recordings_each_scanfield}")
    assert len(number_of_recordings_each_scanfield) == len(rec_ids)
    number_of_recordings = np.sum(number_of_recordings_each_scanfield)
    n_out = int(number_of_recordings)

    _, labels, loss_weights = build_training_data(
        i_amp,
        stimuli,
        avg_recordings,
        rec_ids,
        num_datapoints_per_scanfield,
        number_of_recordings_each_scanfield,
    )

    # Reshape the stimulus to (n_in, num datapoints) and transpose so that
    # rows are datapoints. n_in is derived from the array rather than
    # hard-coded, so the models' input size always matches the stimulus
    # (it was 300 when this notebook was run).
    n_in = np.size(noise_full) // noise_full.shape[2]
    data_global = torch.as_tensor(np.reshape(noise_full, (n_in, noise_full.shape[2])).T)
    labels_global = torch.as_tensor(labels)
    loss_weights = torch.as_tensor(loss_weights)

    ################## TRAINING ##################
    linreg = {"train": [], "test": []}
    ann = {"train": [], "test": []}
    for seed in range(num_seeds):
        # Both architectures are initialised from the same torch seed.
        # eval_nn re-seeds with `seed + 1` before permuting, so both models
        # see identical data splits.
        _ = torch.manual_seed(seed)
        net = nn.Linear(n_in, n_out)
        lin_test, lin_train, _lin_val = eval_nn(
            net, data_global, labels_global, loss_weights, seed + 1, val_frac, test_num, False
        )
        # Average across splits (eval_nn currently uses a single split).
        linreg["train"].append(np.mean(lin_train))
        linreg["test"].append(np.mean(lin_test))

        _ = torch.manual_seed(seed)
        net = nn.Sequential(
            nn.Linear(n_in, 100),
            nn.ReLU(),
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, n_out),
        )
        mlp_test, mlp_train, _mlp_val = eval_nn(
            net, data_global, labels_global, loss_weights, seed + 1, val_frac, test_num, False
        )
        # Average across splits (eval_nn currently uses a single split).
        ann["train"].append(np.mean(mlp_train))
        ann["test"].append(np.mean(mlp_test))

    # Average across seeds: one entry per training-set size.
    linreg_accuracies["train"].append(np.mean(linreg["train"]))
    linreg_accuracies["test"].append(np.mean(linreg["test"]))

    ann_accuracies["train"].append(np.mean(ann["train"]))
    ann_accuracies["test"].append(np.mean(ann["test"]))

all_ann_accuracies[myrec] = ann_accuracies
all_linreg_accuracies[myrec] = linreg_accuracies

number_of_recordings_each_scanfield [15]
number_of_recordings_each_scanfield [15]
number_of_recordings_each_scanfield [15]
number_of_recordings_each_scanfield [15]
number_of_recordings_each_scanfield [15]


In [28]:
# Persist the per-recording sweeps. Each pickle maps a recording id to
# {"train": [...], "test": [...]} with one entry per training-set size.
results_dir = Path("../results/05_ann_inductive_bias")
# Create the output directory up front so a missing folder cannot discard
# the expensive training results computed above.
results_dir.mkdir(parents=True, exist_ok=True)

for filename, accuracies in (
    ("linreg_accuracies.pkl", all_linreg_accuracies),
    ("ann_accuracies.pkl", all_ann_accuracies),
):
    with open(results_dir / filename, "wb") as handle:
        pickle.dump(accuracies, handle)