In [13]:
import torch
import pickle

import numpy as np
from copy import copy, deepcopy
from utils.vmf_batch import vMF
from models import SeqEncoder, SeqDecoder, Seq2Seq_VAE, PoolingClassifier
from utils.cluster_utils import _convert_cluster_results_dict_into_array, get_clustered_rws_agglom, tree_from_clustered_result
from utils.sampling_utils import _fill_with_infty, decode_z, sample_rws

## plotting ###

import warnings
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt
import seaborn as sns

%matplotlib inline

In [2]:
# Read the array stored for the 3-population toy data set.
walk_repr_path = './data/toy_data/3_populations/walk_representation_32.npy'
with open(walk_repr_path, 'rb') as walk_file:
    walk_representation = np.load(walk_file)

In [3]:
# Restore the pickled test iterator of the toy data set.
# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
test_iterator_path = './data/toy_data/3_populations/iterator/test_iterator.pkl'
with open(test_iterator_path, 'rb') as iterator_file:
    test_iterator = pickle.load(iterator_file)

In [4]:
SEED = 17

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Take the first batch; list() runs through the whole iterator before indexing.
first_batch = list(test_iterator)[0]
src_data, trg_data, seq_len, indices, labels = first_batch

# target walks rounded to two decimals
rw_i = np.round(trg_data, decimals=2)

In [9]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameters; they have to match the checkpoints loaded into `model` afterwards.
INPUT_DIM, EMBED_DIM, HIDDEN_DIM = 3, 16, 16
LATENT_DIM = 8
NUM_LAYERS = 2
KAPPA = 500
DROPOUT = 0.1

# model: encoder and decoder take the same constructor arguments
layer_args = (INPUT_DIM, EMBED_DIM, HIDDEN_DIM, NUM_LAYERS, DROPOUT)
enc = SeqEncoder(*layer_args)
dec = SeqDecoder(*layer_args)
dist = vMF(LATENT_DIM, kappa=KAPPA, device=device)
model = Seq2Seq_VAE(enc, dec, dist, device).to(device)

KLD: 18.465579986572266


In [10]:
# Load the trained weights. map_location remaps tensors that were saved on a GPU
# onto `device`, so the checkpoint also loads on a CPU-only machine (the other
# checkpoint loads in this notebook already do this).
state_dict = torch.load('./models/parameter_search/emb16_hid16_lat8_dp0.1_k500_avg_run1_best.pt',
                        map_location=device)
model.load_state_dict(state_dict['model_state_dict'])

# Inference only: evaluation mode and no gradient tracking.
model.eval()
with torch.no_grad():
    bs, n_walks, walk_length, input_dim = src_data.shape
    # merge the first two axes (bs, n_walks) and move the walk axis to the front
    src = src_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    # src = [walk length , bs * n_walks, input_dim]
    trg = trg_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    seq_len = seq_len.view(-1).to(device)
    # NOTE(review): the trailing 0 is presumably the teacher-forcing ratio -- confirm
    # against Seq2Seq_VAE.forward. Later cells read model.h after this call.
    output = model(src, seq_len, trg, 0)

## sample neurons

In [12]:
# For every neuron of the batch and every concentration kappa: sample walks around
# the encoded latents, cluster them, build a tree and write it to an SWC file.
for k in range(len(indices)):
    # the walks of neuron k form one contiguous slice of the flattened batch
    walk_slice = slice(k * n_walks, k * n_walks + n_walks)

    for kappa in (100, 300, 500):
        # Build the sampling distribution on the same device as the model, as the
        # other vMF constructions in this notebook do; without `device` the sampler
        # may end up on a different device than the latents in model.h.
        vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
        mus = model.h[walk_slice]
        original_seq_len = seq_len[walk_slice].cpu()
        decoded_rws = sample_rws(
            model, vmf, mus,
            orig_seq_len=original_seq_len,
            n_samples=1,
            max_trg_len=walk_length,
            min_angle=np.pi / 2.4,
        )

        # cluster the rws
        clustered_rws = []
        clustered_results = []
        for rws in decoded_rws:
            clus_res, clus_rws = get_clustered_rws_agglom(rws, dist_thresh=.5)
            # add a leading axis so the per-sample arrays can be stacked below
            clustered_rws.append(clus_rws.reshape((1,) + clus_rws.shape))
            clustered_results.append(clus_res)
        clustered_rws = np.vstack(clustered_rws)

        # reduce to trees
        # NOTE(review): every result of one (k, kappa) pair is written under the same
        # file name; with n_samples=1 there is presumably only one result -- confirm.
        out_dir = './data/toy_data/3_populations/sampled_neurons/test_data/v3/k%i/' % kappa
        for clus_res in clustered_results:
            N = tree_from_clustered_result(clus_res)
            N.write_to_swc('%i' % indices[k], path=out_dir)

KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 1

# On real data

## M1 EXC

In [None]:
# Paths of the stored walk array and the pickled test iterator (M1 excitatory data).
walks_path = './data/M1_exc_data/walks/walk_representation.npy'
iterator_path = './data/M1_exc_data/iterator/m_labels/test_iterator.pkl'

with open(walks_path, 'rb') as walks_file:
    walk_representation = np.load(walks_file)

# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(iterator_path, 'rb') as iterator_file:
    test_iterator = pickle.load(iterator_file)

In [None]:
SEED = 17

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Take the first batch; list() runs through the whole iterator before indexing.
first_batch = list(test_iterator)[0]
src_data, trg_data, seq_len, indices, labels = first_batch

# target walks rounded to two decimals
rw_i = np.round(trg_data, decimals=2)

# sizes of the four axes of the source batch
N, n_walks, walk_length, input_dim = src_data.shape

In [None]:
# Ask PyTorch's CUDA caching allocator to release cached GPU memory that is currently unused.
torch.cuda.empty_cache()

In [None]:
# Checkpoint of the model with the best validation loss; map_location places the
# stored tensors on `device`.
checkpoint_path = './models/M1_exc/m_label/finetuned_vae_k500_frac1.0_best_run2.pt'
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])

In [None]:

# Run the model on the whole test batch in evaluation mode, without gradient tracking.
model.eval()
with torch.no_grad():
    bs, n_walks, walk_length, input_dim = src_data.shape
    # merge the first two axes (bs, n_walks) and move the walk axis to the front
    src = src_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    # src = [walk length , bs * n_walks, input_dim]
    trg = trg_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    seq_len = seq_len.view(-1).to(device)
    # benchmark only: 3 repeats of 10 forward passes each; names assigned inside
    # %timeit are not kept, which is why the call is repeated below
    %timeit -r 3 -n 10 output = model(src, seq_len, trg, 0)
    # the forward pass whose result is kept; later cells read model.h after this call
    # NOTE(review): the trailing 0 is presumably the teacher-forcing ratio -- confirm
    output = model(src, seq_len, trg, 0)

In [None]:
### time the code: ####
timing = []
k = 0
kappa = 500

vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
mus = model.h[k*n_walks: k*n_walks+n_walks]
original_seq_len = seq_len[k*n_walks: k*n_walks+n_walks].cpu()

o = %timeit -r 3 -n 100 -o decoded_rws = sample_rws(model, vmf, mus, orig_seq_len=original_seq_len,n_samples=1, max_trg_len=walk_length, min_angle=np.pi/2.4)
timing.append(('sampling', o))

decoded_rws = sample_rws(model, vmf, mus, orig_seq_len=original_seq_len,
                         n_samples=1, max_trg_len=walk_length, min_angle=np.pi/2.4)

# cluster the rws
clustered_rws = []
clustered_results = []
for rws in decoded_rws:
    o = %timeit -r 3 -o clus_res, clus_rws = get_clustered_rws_agglom(rws,dist_thresh=.4 )
    timing.append(('clustering', o))
    clus_res, clus_rws = get_clustered_rws_agglom(rws,dist_thresh=.4 )
    
    clustered_rws.append(clus_rws.reshape((1,)+clus_rws.shape))
    clustered_results.append(clus_res)
clustered_rws = np.vstack(clustered_rws)
# reduce to trees
for clus_res in clustered_results:
    o = %timeit -r 3 -o N = tree_from_clustered_result(clus_res)
    timing.append(('get_tree', o))

In [None]:

# For every neuron of the batch and every concentration kappa: sample walks around
# the encoded latents, cluster them, build a tree and write it to an SWC file.
for k in range(len(indices)):
    # the walks of neuron k form one contiguous slice of the flattened batch
    walk_slice = slice(k * n_walks, k * n_walks + n_walks)

    for kappa in (100, 300, 500):
        vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
        mus = model.h[walk_slice]
        original_seq_len = seq_len[walk_slice].cpu()
        decoded_rws = sample_rws(
            model, vmf, mus,
            orig_seq_len=original_seq_len,
            n_samples=1,
            max_trg_len=walk_length,
            min_angle=np.pi / 2.4,
        )

        # cluster the rws
        clustered_rws = []
        clustered_results = []
        for rws in decoded_rws:
            clus_res, clus_rws = get_clustered_rws_agglom(rws, dist_thresh=.4)
            # add a leading axis so the per-sample arrays can be stacked below
            clustered_rws.append(clus_rws.reshape((1,) + clus_rws.shape))
            clustered_results.append(clus_res)
        clustered_rws = np.vstack(clustered_rws)

        # reduce to trees
        out_dir = './data/M1_exc_data/sampled_neurons/test_data/k%i/' % kappa
        for clus_res in clustered_results:
            N = tree_from_clustered_result(clus_res)
            N.write_to_swc('%i' % indices[k], path=out_dir)

## M1 Inh data

In [None]:
# Paths of the stored walk array and the pickled test iterator (M1 inhibitory data, axon).
walks_path = './data/M1_inh_data/walks/axon/walk_representation_32.npy'
iterator_path = './data/M1_inh_data/iterator/axon/test_iterator_32.pkl'

with open(walks_path, 'rb') as walks_file:
    walk_representation = np.load(walks_file)

# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(iterator_path, 'rb') as iterator_file:
    test_iterator = pickle.load(iterator_file)

In [None]:
SEED = 17

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Take the first batch; list() runs through the whole iterator before indexing.
first_batch = list(test_iterator)[0]
src_data, trg_data, seq_len, indices, labels = first_batch

# target walks rounded to two decimals
rw_i = np.round(trg_data, decimals=2)

# sizes of the four axes of the source batch
N, n_walks, walk_length, input_dim = src_data.shape

In [None]:
# Checkpoint of the model with the best validation loss; map_location places the
# stored tensors on `device`.
checkpoint_path = './models/M1_inh/finetuned/axon/finetuned_vae_frac0.5_best_run2.pt'
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])

In [None]:

# Run the model on the whole test batch in evaluation mode, without gradient tracking.
model.eval()
with torch.no_grad():
    bs, n_walks, walk_length, input_dim = src_data.shape
    # merge the first two axes (bs, n_walks) and move the walk axis to the front
    src = src_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    # src = [walk length , bs * n_walks, input_dim]
    trg = trg_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    seq_len = seq_len.view(-1).to(device)
    # benchmark only: 3 repeats of 10 forward passes each; names assigned inside
    # %timeit are not kept, which is why the call is repeated below
    %timeit -r 3 -n 10 output = model(src, seq_len, trg, 0)
    # the forward pass whose result is kept; later cells read model.h after this call
    # NOTE(review): the trailing 0 is presumably the teacher-forcing ratio -- confirm
    output = model(src, seq_len, trg, 0)

In [None]:
### time the code: ####
# Benchmarks the three pipeline stages (sampling, clustering, tree construction)
# for the first neuron (k = 0) at kappa = 500 and appends one
# ('inh', stage, TimeitResult) record per measurement to `timing`.
# NOTE(review): `timing` is not created in this cell; it has to exist in the kernel
# already (the M1 EXC timing cell creates it).
k = 0
kappa = 500

vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
mus = model.h[k*n_walks: k*n_walks+n_walks]
original_seq_len = seq_len[k*n_walks: k*n_walks+n_walks].cpu()

# -o makes %timeit return its result object so it can be stored
o = %timeit -r 3 -n 100 -o decoded_rws = sample_rws(model, vmf, mus, orig_seq_len=original_seq_len,n_samples=1, max_trg_len=walk_length, min_angle=np.pi/2.4)
timing.append(('inh', 'sampling', o))

# names assigned inside %timeit are not kept, so sample once more for the next stage
decoded_rws = sample_rws(model, vmf, mus, orig_seq_len=original_seq_len,
                         n_samples=1, max_trg_len=walk_length, min_angle=np.pi/2.4)

# cluster the rws
clustered_rws = []
clustered_results = []
for rws in decoded_rws:
    o = %timeit -r 3 -o clus_res, clus_rws = get_clustered_rws_agglom(rws,dist_thresh=.4 )
    timing.append(('inh','clustering', o))
    clus_res, clus_rws = get_clustered_rws_agglom(rws,dist_thresh=.4 )
    
    # add a leading axis so the per-sample arrays can be stacked below
    clustered_rws.append(clus_rws.reshape((1,)+clus_rws.shape))
    clustered_results.append(clus_res)
clustered_rws = np.vstack(clustered_rws)
# reduce to trees
for clus_res in clustered_results:
    o = %timeit -r 3 -o N = tree_from_clustered_result(clus_res)
    timing.append(('inh', 'get_tree', o))

In [None]:
# For every neuron of the batch and every concentration kappa: sample walks around
# the encoded latents, cluster them, build a tree and write it to an SWC file.
for k in range(len(indices)):
    # the walks of neuron k form one contiguous slice of the flattened batch
    walk_slice = slice(k * n_walks, k * n_walks + n_walks)

    for kappa in (100, 300, 500):
        # Build the sampling distribution on the same device as the model, as the
        # other vMF constructions in this notebook do; without `device` the sampler
        # may end up on a different device than the latents in model.h.
        vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
        mus = model.h[walk_slice]
        original_seq_len = seq_len[walk_slice].cpu()
        decoded_rws = sample_rws(
            model, vmf, mus,
            orig_seq_len=original_seq_len,
            n_samples=1,
            max_trg_len=walk_length,
            min_angle=np.pi / 2.4,
        )

        # cluster the rws
        clustered_rws = []
        clustered_results = []
        for rws in decoded_rws:
            clus_res, clus_rws = get_clustered_rws_agglom(rws, dist_thresh=.3)
            # add a leading axis so the per-sample arrays can be stacked below
            clustered_rws.append(clus_rws.reshape((1,) + clus_rws.shape))
            clustered_results.append(clus_res)
        clustered_rws = np.vstack(clustered_rws)

        # reduce to trees
        # NOTE(review): every result of one (k, kappa) pair is written under the same
        # file name; with n_samples=1 there is presumably only one result -- confirm.
        out_dir = './data/M1_inh_data/sampled_neurons/axon/test_data/k%i/' % kappa
        for clus_res in clustered_results:
            N = tree_from_clustered_result(clus_res)
            N.write_to_swc('%i' % indices[k], path=out_dir)

## Farrow data

In [14]:
# sub-directory of the Farrow data to load
part = 'soma_centered'

walks_path = './data/Farrow_data/walks/%s/walk_representation.npy' % part
iterator_path = './data/Farrow_data/iterator/%s/test_iterator.pkl' % part

with open(walks_path, 'rb') as walks_file:
    walk_representation = np.load(walks_file)

# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(iterator_path, 'rb') as iterator_file:
    test_iterator = pickle.load(iterator_file)

In [16]:
SEED = 17

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Take the first batch; list() runs through the whole iterator before indexing.
first_batch = list(test_iterator)[0]
src_data, trg_data, seq_len, indices, labels = first_batch

# target walks rounded to two decimals
rw_i = np.round(trg_data, decimals=2)

# sizes of the four axes of the source batch
N, n_walks, walk_length, input_dim = src_data.shape

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameter record; only latent_dim is read in this cell.
config = {
    'input_dim': 3,
    'embed_dim': 16,
    'hidden_dim': 16,
    'latent_dim': 8,
    'num_layers': 2,
    'kappa': 500,
    'dropout': .1,
}
LATENT_DIM = config['latent_dim']

# model with the best validation mse loss
checkpoint_path = './models/Farrow/finetuned/%s/finetuned_vae_frac1.0_best_run1.pt' % part
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])

<All keys matched successfully>

In [17]:

# Run the model on the whole test batch in evaluation mode, without gradient tracking.
model.eval()
with torch.no_grad():
    bs, n_walks, walk_length, input_dim = src_data.shape
    # merge the first two axes (bs, n_walks) and move the walk axis to the front
    src = src_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    # src = [walk length , bs * n_walks, input_dim]
    trg = trg_data.view(-1,walk_length,input_dim).transpose(0,1).to(device)
    seq_len = seq_len.view(-1).to(device)
    # benchmark only: 3 repeats of 10 forward passes each; names assigned inside
    # %timeit are not kept, which is why the call is repeated below
    %timeit -r 3 -n 10 output = model(src, seq_len, trg, 0)
    # the forward pass whose result is kept; later cells read model.h after this call
    # NOTE(review): the trailing 0 is presumably the teacher-forcing ratio -- confirm
    output = model(src, seq_len, trg, 0)

2.62 s ± 4.78 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [18]:
# time the code

### time the code: ####
# Benchmarks the three pipeline stages (sampling, clustering, tree construction)
# for the first neuron (k = 0) at kappa = 500 and stores one
# ('rgc', stage, TimeitResult) record per measurement in `timing`.
k = 0
kappa = 500
# NOTE(review): this re-creates `timing`, so records appended by the M1 EXC / M1 Inh
# timing cells are dropped when the notebook runs top to bottom -- confirm that this
# is intended.
timing = []
vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
mus = model.h[k*n_walks: k*n_walks+n_walks]
original_seq_len = seq_len[k*n_walks: k*n_walks+n_walks].cpu()

# -o makes %timeit return its result object so it can be stored
o = %timeit -r 3 -n 100 -o decoded_rws = sample_rws(model, vmf, mus, orig_seq_len=original_seq_len,n_samples=1, max_trg_len=walk_length, min_angle=np.pi/2.4)
timing.append(('rgc', 'sampling', o))

# names assigned inside %timeit are not kept, so sample once more for the next stage
decoded_rws = sample_rws(model, vmf, mus, orig_seq_len=original_seq_len,
                         n_samples=1, max_trg_len=walk_length, min_angle=np.pi/2.4)

# cluster the rws
clustered_rws = []
clustered_results = []
for rws in decoded_rws:
    o = %timeit -r 3 -o clus_res, clus_rws = get_clustered_rws_agglom(rws,dist_thresh=.4 )
    timing.append(('rgc','clustering', o))
    clus_res, clus_rws = get_clustered_rws_agglom(rws,dist_thresh=.4 )
    
    # add a leading axis so the per-sample arrays can be stacked below
    clustered_rws.append(clus_rws.reshape((1,)+clus_rws.shape))
    clustered_results.append(clus_res)
clustered_rws = np.vstack(clustered_rws)
# reduce to trees
for clus_res in clustered_results:
    o = %timeit -r 3 -o N = tree_from_clustered_result(clus_res)
    timing.append(('rgc', 'get_tree', o))

KLD: 18.465579986572266
97.6 ms ± 459 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)
28.6 ms ± 13.9 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
3.87 ms ± 7.91 µs per loop (mean ± std. dev. of 3 runs, 100 loops each)


In [19]:
# For every neuron of the batch and every concentration kappa: sample walks around
# the encoded latents, cluster them, build a tree and write it to an SWC file.
for k in range(len(indices)):
    # the walks of neuron k form one contiguous slice of the flattened batch
    walk_slice = slice(k * n_walks, k * n_walks + n_walks)

    for kappa in (100, 300, 500):
        # Build the sampling distribution on the same device as the model, as the
        # other vMF constructions in this notebook do; without `device` the sampler
        # may end up on a different device than the latents in model.h.
        vmf = vMF(LATENT_DIM, kappa=kappa, device=device)
        mus = model.h[walk_slice]
        original_seq_len = seq_len[walk_slice].cpu()
        decoded_rws = sample_rws(
            model, vmf, mus,
            orig_seq_len=original_seq_len,
            n_samples=1,
            max_trg_len=walk_length,
            min_angle=np.pi / 2.4,
        )

        # cluster the rws
        clustered_rws = []
        clustered_results = []
        for rws in decoded_rws:
            clus_res, clus_rws = get_clustered_rws_agglom(rws, dist_thresh=.25)
            # add a leading axis so the per-sample arrays can be stacked below
            clustered_rws.append(clus_rws.reshape((1,) + clus_rws.shape))
            clustered_results.append(clus_res)
        clustered_rws = np.vstack(clustered_rws)

        # reduce to trees
        # NOTE(review): every result of one (k, kappa) pair is written under the same
        # file name; with n_samples=1 there is presumably only one result -- confirm.
        out_dir = './data/Farrow_data/sampled_neurons/soma_centered/test_data/k%i/' % kappa
        for clus_res in clustered_results:
            N = tree_from_clustered_result(clus_res)
            N.write_to_swc('%i' % indices[k], path=out_dir)

KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 11.350215911865234
KLD: 16.18794822692871
KLD: 18.465579986572266
KLD: 1

## Urban data

In [21]:
# sub-directory of the urban data to load
part = 'soma_centered'

walks_path = './data/urban_data/walks/%s/walk_representation_16.npy' % part
iterator_path = './data/urban_data/iterator/%s/test_iterator.pkl' % part

with open(walks_path, 'rb') as walks_file:
    walk_representation = np.load(walks_file)

# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
with open(iterator_path, 'rb') as iterator_file:
    test_iterator = pickle.load(iterator_file)

In [None]:
SEED = 17

# Seed the global NumPy and PyTorch random number generators.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Take the first batch; list() runs through the whole iterator before indexing.
first_batch = list(test_iterator)[0]
src_data, trg_data, seq_len, indices, labels = first_batch

# target walks rounded to two decimals
rw_i = np.round(trg_data, decimals=2)

# sizes of the four axes of the source batch
N, n_walks, walk_length, input_dim = src_data.shape

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyper-parameter record; only latent_dim is read in this cell.
config = {
    'input_dim': 3,
    'embed_dim': 16,
    'hidden_dim': 16,
    'latent_dim': 8,
    'num_layers': 2,
    'kappa': 500,
    'dropout': .1,
}
LATENT_DIM = config['latent_dim']

# model with the best validation mse loss
checkpoint_path = './models/urban/finetuned/%s/finetuned_vae_frac1.0_best_run1.pt' % part
state_dict = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(state_dict['model_state_dict'])