In [1]:
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=0
%pylab inline
%load_ext autoreload
%autoreload 2

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=0
%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [2]:
import json

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.manifold import TSNE
from tqdm import tqdm

In [3]:
from ssl_neuron.datasets import GraphDataset
from ssl_neuron.utils import plot_neuron, plot_tsne, neighbors_to_adjacency_torch, compute_eig_lapl_torch_batch
from ssl_neuron.graphdino import create_model

#### Load config

In [4]:
# Load the training configuration. The context manager closes the file handle
# deterministically instead of leaving it to the garbage collector.
with open('../configs/config.json') as config_file:
    config = json.load(config_file)

In [5]:
# Display the loaded configuration so the saved notebook records which settings were used.
config

{'model': {'num_classes': 1000,
  'dim': 32,
  'depth': 7,
  'n_head': 8,
  'pos_dim': 32,
  'move_avg': 0.999,
  'center_avg': 0.9,
  'teacher_temp': 0.06},
 'data': {'class': 'allen',
  'path': '/usr/users/agecker/datasets/neuron_morphology_allen/',
  'n_nodes': 200,
  'feat_dim': 8,
  'batch_size': 64,
  'num_workers': 6,
  'jitter_var': 1,
  'translate_var': 10,
  'rotation_axis': 'y',
  'n_drop_branch': 10},
 'optimizer': {'lr': 0.0001, 'max_iter': 100000, 'exp_decay': 0.5},
 'trainer': {'ckpt_dir': 'ssl_neuron/ckpts/',
  'save_ckpt_every': 200,
  'seed': None}}

#### Load model + checkpoint

In [6]:
# Build the GraphDINO model from the loaded config. Its weights are presumably freshly
# initialised here and only become meaningful once the checkpoint is loaded.
model = create_model(config)

In [7]:
# Trained checkpoint to evaluate.
# NOTE(review): this path raised FileNotFoundError on the last run. The config's
# trainer.ckpt_dir is 'ssl_neuron/ckpts/', so confirm where training actually saved it.
ckpt_path = '../ckpts/ckpt.pt'
# map_location='cpu' lets a checkpoint saved from GPU tensors load on any machine; the
# model is moved to the GPU explicitly after the weights are copied in.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
state_dict = torch.load(ckpt_path, map_location='cpu')

FileNotFoundError: [Errno 2] No such file or directory: '../ckpts/ckpt.pt'

In [None]:
# Copy the checkpoint weights into the model. load_state_dict is strict by default, so any
# missing or unexpected keys raise instead of being silently ignored.
model.load_state_dict(state_dict)

In [None]:
# Move the parameters to the (single visible) GPU and switch to inference mode, e.g. so
# dropout is disabled. Both calls return the module, so they chain; `;` hides the repr.
model.cuda().eval();

#### Load dataset

In [None]:
# Build the dataset from the 'data' section of the config. mode='all' presumably selects
# every sample rather than a train/val split. TODO confirm in GraphDataset.
dset = GraphDataset(config, mode='all')

In [None]:
# Fetch one un-batched example (index 0) to inspect.
# `feat` is a numpy array of per-node features. `neigh` is dict-like, presumably
# mapping node -> neighbouring nodes.
feat, neigh = dset.__getsingleitem__(0)

In [None]:
# Sanity-check sizes: the feature array shape vs. the number of keys in the neighbour dict
# (presumably one key per node).
feat.shape, len(neigh)

#### Plot an example neuron (sample 0)

In [None]:
# Render the example neuron from its neighbour dict and node features.
plot_neuron(neigh, feat)

In [None]:
# Record the torch build used to compute the embeddings, for reproducibility.
torch.__version__

#### Run inference: embed every neuron with the student encoder

In [None]:
# Embed every neuron with the student encoder; row i of `latents` is sample i's embedding.
n_samples = dset.num_samples
latent_dim = config['model']['dim']
pos_dim = config['model']['pos_dim']
latents = np.zeros((n_samples, latent_dim))

# Inference only: no_grad avoids building (and holding) an autograd graph for every neuron.
with torch.no_grad():
    for i in tqdm(range(n_samples)):
        # Distinct names so the example `feat`/`neigh` inspected above are not overwritten.
        node_feat, neighbors = dset.__getsingleitem__(i)
        # Add a leading batch dimension of size 1 with `[None, ]`.
        adj = neighbors_to_adjacency_torch(neighbors, list(neighbors.keys())).float().cuda()[None, ]
        lapl = compute_eig_lapl_torch_batch(adj, pos_enc_dim=pos_dim).float().cuda()
        node_feat = torch.from_numpy(node_feat).float().cuda()[None, ]

        # Call the module (not .forward) so any registered hooks still run. `[0]`
        # presumably selects the single batch element. TODO confirm the encoder's output.
        latents[i] = model.student_encoder(node_feat, adj, lapl)[0].detach().cpu().numpy()


In [None]:
# Per-dimension mean (markers) and ±1 std (error bars) of the latents across all neurons,
# e.g. to spot dimensions whose spread is near zero.
dims = np.arange(config['model']['dim'])
latent_mean = latents.mean(axis=0)
latent_std = latents.std(axis=0)

fig, ax = plt.subplots()
ax.scatter(dims, latent_mean)
ax.errorbar(dims, latent_mean, yerr=latent_std, fmt='none')
ax.set(
    xlabel='latent dimension',
    ylabel='value across neurons (mean ± std)',
    title='Latent statistics per dimension',
);

In [None]:
# Two colours from a 10-step seaborn "mako" palette: entry 3 and the second-to-last entry.
colors = [list(sns.color_palette("mako", n_colors=10)[idx]) for idx in (3, -2)]

In [None]:
# Project the latents to 2-D for visualisation. A fixed random_state makes the embedding,
# and therefore the t-SNE figure, reproducible across runs.
TSNE_SEED = 0
z = TSNE(n_components=2, perplexity=30, random_state=TSNE_SEED).fit_transform(latents)

In [None]:
# Class labels for the t-SNE plot: 0 = aspiny, 1 = spiny (matching `targets` below).
# Assumes the dataset lists 200 aspiny neurons first, followed by 230 spiny ones.
# TODO confirm that ordering, and ideally derive labels from dataset metadata.
n_aspiny, n_spiny = 200, 230
assert len(z) == n_aspiny + n_spiny, (
    f'expected {n_aspiny + n_spiny} embedded neurons, got {len(z)}; class counts are stale'
)
labels = np.repeat([0, 1], [n_aspiny, n_spiny])
plot_tsne(z, labels, targets=['aspiny', 'spiny'], colors=[colors[1], colors[0]])