# Originals and Reconstructions at the Dataset Level

## This will be used for BUAN score comparisons

In [1]:
import numpy as np
import torch
import torch.nn.functional as F
from models import AllModels
from dipy.segment.bundles import bundle_shape_similarity
import gc


In [2]:
# Hyperparameters shared by every model variant.

# -- training schedule
batch_size, num_training_updates, num_epochs = 256, 30_000, 1

# -- network size (passed to AllModels)
num_hiddens, num_residual_hiddens, num_residual_layers = 128, 32, 2

# -- embedding / quantization settings (passed to AllModels)
embedding_dim, num_embeddings = 64, 512
commitment_cost = 0.25

# -- remaining AllModels arguments
decay, tau, std = 0.99, 10.0, 2.0

In [3]:
# Fix the RNG seed for reproducibility, then pick the GPU when one is present.
torch.manual_seed(1337)
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
val_set = torch.load(f='data_ml/val_set.pt')  # presumably an iterable of (tract, label) pairs — the grouping cell unpacks it that way

## Important Note:

The `recons` and `originals` directories will be used for the dataset as is.

However, the `subject_recons` and `subject_originals` directories accumulate all samples of each tract within a subject, so that the entire bundle can be reconstructed.

In [11]:
# Group the validation samples by subject.
# Labels have the form "<subject>__<tract_name>"; tracts[subject] becomes a list
# of [tract_name, tract] pairs in val_set order.
tracts = {}
for tract, label in val_set:
    subject, name_of_tract = label.split('__')
    tracts.setdefault(subject, []).append([name_of_tract, tract])

In [41]:
# One (number of distinct tract names, subject) tuple per subject, in tracts order.
num_tracts_per_sub = [
    (len({tract_name for tract_name, _ in entries}), subject_id)
    for subject_id, entries in tracts.items()
]

In [46]:
# Most distinct tracts first; the sort is stable, so ties keep their previous order.
num_tracts_per_sub[:] = sorted(num_tracts_per_sub, key=lambda entry: entry[0], reverse=True)

In [49]:
# Display the (distinct tract count, subject) entry with the highest count.
num_tracts_per_sub[0]

(12, 'sub-1178')

In [54]:
# Number of [tract_name, tract] samples stored for subject 'sub-1178'
# (the subject shown with the most distinct tracts in the previous cell).
print(len(tracts['sub-1178']))

194


In [52]:
# Test all 5 models: one one-hot flag configuration per model, in the flag
# order (ae, vae, vq, vq_ema, vq_diff) that reconstruct_full_tract unpacks.
print('Configuring Models')
configs_to_run = [[row == col for col in range(5)] for row in range(5)]


Configuring Models


In [59]:
def _checkpoint_name(ae, vae, vq, vq_ema):
    """Return the saved_models/ checkpoint stem for a set of model flags.

    The first truthy flag among (ae, vae, vq, vq_ema) wins; otherwise 'vq_diff'.
    """
    for name, flag in (('ae', ae), ('vae', vae), ('vq', vq), ('vq_ema', vq_ema)):
        if flag:
            return name
    return 'vq_diff'


def _reconstruct_subject(model, subject):
    """Run every (tract_name, tensor) pair of ``subject`` through ``model``.

    Returns ``(originals, recons)``: two dicts keyed by tract name, each holding
    the list of per-sample NumPy arrays with tensor axes 0 and 1 swapped
    (``permute(1, 0, 2)``, so each sample tensor must be 3-D).
    """
    originals = {}
    recons = {}
    # Inference only: avoid building autograd graphs for every forward pass.
    with torch.no_grad():
        for label, x in subject:
            originals.setdefault(label, []).append(
                x.permute(1, 0, 2).detach().cpu().numpy())

            # Add a batch dimension for the model, then drop it again.
            _, x_recon, _, _ = model(x.to(device).unsqueeze(0))
            recons.setdefault(label, []).append(
                x_recon.squeeze(0).permute(1, 0, 2).detach().cpu().numpy())
    return originals, recons


def reconstruct_full_tract(configs_to_run, subject):
    """Reconstruct every tract of one subject with each configured model and save the results.

    For each configuration an ``AllModels`` instance is built and its checkpoint
    is loaded from ``saved_models/<name>.pt``. Every sample of ``subject`` is then
    passed through the model. The per-tract originals and reconstructions are
    concatenated and written as ``.npy`` files.

    Args:
        configs_to_run: iterable of flag sequences indexed as
            (ae, vae, vq, vq_ema, vq_diff). The first truthy flag among the
            first four selects the checkpoint; otherwise ``vq_diff.pt`` is used.
        subject: iterable of ``(tract_name, tensor)`` pairs, e.g. ``tracts['sub-1178']``.

    Side effects:
        Writes ``./subject_originals/{tract_name}.npy`` and
        ``./subject_recons/{tract_name}_{model_used}.npy``, creating both
        directories if needed. Prints each model's ``model_used()`` name.
    """
    import os  # local import: only this function needs filesystem helpers

    for config in configs_to_run:
        ae, vae, vq, vq_ema, vq_diff = config[0], config[1], config[2], config[3], config[4]

        model = AllModels(num_hiddens, num_residual_layers, num_residual_hiddens,
            num_embeddings, embedding_dim, commitment_cost, tau, std, decay,
            ae, vae, vq, vq_ema, vq_diff
            ).to(device)

        # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
        checkpoint = f'saved_models/{_checkpoint_name(ae, vae, vq, vq_ema)}.pt'
        model.load_state_dict(torch.load(checkpoint, map_location=device))
        model.eval()

        model_name = model.model_used()
        print(model_name)

        originals, recons = _reconstruct_subject(model, subject)

        os.makedirs('./subject_originals', exist_ok=True)
        os.makedirs('./subject_recons', exist_ok=True)
        # Concatenate the pieces of each tract into the full bundle.
        for label, pieces in originals.items():
            np.save(f'./subject_originals/{label}.npy', np.concatenate(pieces))
            np.save(f'./subject_recons/{label}_{model_name}.npy',
                    np.concatenate(recons[label]))

        # Free up memory: drop every reference to the model first, collect
        # cycles, then hand cached CUDA blocks back to the driver.
        del model, originals, recons
        gc.collect()
        torch.cuda.empty_cache()

        

In [60]:
#reconstruct_full_tract(configs_to_run, tracts['sub-1178'])

ae
vae
vq
vq_ema
vq_diff
