In [1]:
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import cosine_similarity
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import matplotlib.pyplot as plt

from myow import MLP, MYOW, MLP3
from myow.data import MonkeyReachNeuralDataset
from myow.transforms import get_neural_transform
from myow.samplers import RelativeSequenceDataLoader
from myow.utils import seed_everything, collect_params
from myow.tasks.train_reach_angle_regressor import linear_evaluate

In [None]:
# Seed the random number generators; `seed` is forwarded as-is to `seed_everything`.
seed = None
seed_everything(seed=seed)

# Select the compute device: the GPU when CUDA is available, the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'Using {device}.')

# Prepare dataset

The datasets were collected from two macaques in Lee Miller's lab at Northwestern University, as described in the DAD paper [[1]](#1).
Data were collected over two days for each individual; the datasets are listed in the table below:

| Primate | Day        | Day index | # of trials | # of neurons |
|---------|------------|:---------:|:-----------:|:------------:|
| Mihi    | 03/03/2014 |     1     |     209     |      187     |
| Mihi    | 03/06/2014 |     2     |     215     |      172     |
| Chewie  | 10/03/2013 |     1     |     159     |      174     |
| Chewie  | 12/19/2013 |     2     |     180     |      155     |

Neural activity is recorded in the primary motor cortex (M1), for which we have the corresponding reach direction (one of eight). All neural data are binned using 100ms windows.

<a id="1">[1]</a>
Dyer, E.L., Gheshlaghi Azar, M., Perich, M.G. et al.
A cryptography-based approach for movement decoding.
Nat Biomed Eng 1, 967–976 (2017).

In [3]:
root = '../data/mihi-chewie'  # path to data
primate = 'chewie'  # options: 'chewie', 'mihi'
day = 1  # options: 1, 2


def _load_split(split):
    """Build the dataset for the selected `primate`/`day`, restricted to `split`."""
    return MonkeyReachNeuralDataset(root, primate=primate, day=day, split=split)


# dataset used for self-supervised pre-training
dataset = _load_split('trainval')

# datasets used for linear evaluation
train_dataset = _load_split('train')
val_dataset = _load_split('val')
test_dataset = _load_split('test')

# dataset used for visualization (no split requested)
full_dataset = _load_split(None)

In [4]:
# Show the first item of the full dataset to see which fields it carries.
print(full_dataset[0])

Data(x=[6, 174], edge_index=[2, 5], y=[6], pos=[6, 2], vel=[6, 2], acc=[6, 2], force=[6, 2], timestep=[6])


# Prepare data transforms
We apply three transformations to the neural data:
- **Randomized Dropout:** sets the firing rate of a random subset of neurons to zero. The dropout rate is uniformly sampled between `0` and `dropout_p`.
- **Noise:** adds Gaussian noise with standard deviation `noise_sigma` before normalization.
- **Pepper or sparse additive noise:** increases the firing rate of a neuron by a constant `pepper_sigma` with a probability `pepper_p`.

These transformations are stochastic: each one is applied with probability `[]_apply_p` (e.g. `dropout_apply_p`).

In [5]:
# randomized dropout: maximum dropout rate and probability of applying the augmentation
dropout_p, dropout_apply_p = 0.2, 1.0

# additive noise: standard deviation and probability of applying the augmentation
noise_sigma, noise_apply_p = 1.5, 1.0

# pepper noise: per-neuron probability, added constant, and probability of applying the augmentation
pepper_p, pepper_sigma, pepper_apply_p = 0.3, 1.5, 0.5

# firing-rate statistics of the pre-training dataset, used for normalization
fr_mean, fr_std = dataset.get_mean_std('firing_rates')

# augmentation pipeline used during pre-training
transform = get_neural_transform(
    randomized_dropout={'p': dropout_p, 'apply_p': dropout_apply_p},
    pepper={'p': pepper_p, 'c': pepper_sigma, 'apply_p': pepper_apply_p},
    noise={'std': noise_sigma, 'apply_p': noise_apply_p},
    normalize={'mean': fr_mean, 'std': fr_std},
)

# normalization-only pipeline, used for evaluation and visualization
normalize = get_neural_transform(normalize={'mean': fr_mean, 'std': fr_std})

In addition to these neural augmentations, we use **Temporal jitter** where temporally neighboring samples are considered to be positive examples for one another. `max_lookahead` defines the positivity range. 

Here we set `max_lookahead = 2`, which means that samples `t-2` to `t+2` are positive views of `t`.
`100ms` is our unit of time since we used it for binning.

`RelativeSequenceDataLoader` is a custom dataloader that handles building positive views as well as the pool of candidates for view mining.

In [6]:
max_lookahead = 2  # positivity range, in time bins
batch_size = 512

# Loader that yields the augmented views together with the mining candidates.
dataloader = RelativeSequenceDataLoader(
    dataset,
    batch_size=batch_size,
    drop_last=True,
    shuffle=True,
    transform=transform,
    pos_kmin=0,
    pos_kmax=max_lookahead,
    num_workers=4,
    persistent_workers=True,
)

# Prepare model

The input is the firing rates vector for neurons. We use a Multi-layer perceptron (MLP) with batch normalization layers and ReLU activation as the encoder/feature extractor.

In [7]:
hidden_layers = [64, 64, 64]
representation_size = 64 # output representation size

# Only neurons with a non-zero firing-rate std are counted (some neurons never fire).
# Cast to a plain Python int so the layer size is not a 0-dim tensor / numpy scalar.
input_size = int((fr_std != 0).sum())
encoder = MLP([input_size, *hidden_layers, representation_size], batchnorm=True)

Augmented and mined views are aligned in two different spaces, hence the use of two different pairs of projector/predictor.

In [8]:
# define projector architecture
projector_hidden_size = 256 # Hidden size of projectors
projector_output_size = 32 # Output size of projectors

# create projectors and predictors
# the augmented-view branch uses no projector, so its predictor works at `representation_size`
projector = nn.Identity()
projector_m = MLP3(representation_size, projector_output_size, hidden_size=projector_hidden_size)
predictor = MLP3(representation_size, representation_size, hidden_size=projector_hidden_size)  # used to predict across augmented views
predictor_m = MLP3(projector_output_size, projector_output_size, hidden_size=projector_hidden_size)  # used to predict across mined views

# make MYOW
model = MYOW(encoder, projector, projector_m, predictor, predictor_m, layout='cascaded')
model.to(device);

During mining, the mined view is randomly sampled from the top-k nearest neighbors (`knn_nneighs`) of the sample.

In [9]:
# number of nearest neighbors (k) kept during mining; the mined view is drawn at random among them
knn_nneighs = 5

# Prepare optimization algorithm

During training, we use different schedulers for our hyperparameters:
- **learning rate**: After a linear warmup period of 100 epochs, the learning rate is decayed following a cosine decay scheduler.
- **exponential moving average parameter**: $\tau$ is increased from 0.98 to 1, following a cosine schedule. The target network is updated as an exponential moving average of the online network: target $\leftarrow$ $\tau$  target + (1 - $\tau$) online
- **loss weights**: early in training the representation is still forming, so we use an initial linear warmup period of a few epochs (10) during which the mined loss term's contribution is small. The total loss is $(1 - \lambda) \textrm{loss}_{aug} +  \lambda \textrm{loss}_{mined}$, where $\lambda$ is the weight of the mined term.

Note that these curves are logged to TensorBoard, under the `SCALARS` tab.

In [10]:
# number of passes over the pre-training dataloader
num_epochs = 10000

# compute total number of gradient steps
# integer division: only full batches are counted (the dataloader is built with drop_last=True)
num_steps_per_epoch = dataloader.num_examples // batch_size
total_steps = num_steps_per_epoch * num_epochs

In [11]:
lr = 0.002 # base learning rate
lr_warmup_epochs = 100
lr_warmup_steps = num_steps_per_epoch * lr_warmup_epochs

def update_learning_rate(step, max_val=lr, total_steps=total_steps, warmup_steps=lr_warmup_steps):
    """Return the learning rate to use at gradient step `step`.

    The rate grows linearly from ~0 to `max_val` over the first `warmup_steps` steps, then
    follows a half-cosine from `max_val` down to 0 at `total_steps`.

    Args:
        step: current global gradient step.
        max_val: peak (base) learning rate.
        total_steps: total number of gradient steps in the run.
        warmup_steps: number of linear warmup steps; 0 disables the warmup.

    Returns:
        The learning rate for this step, as a float.
    """
    # `warmup_steps > 0` guards the division below when the warmup is disabled
    if warmup_steps > 0 and 0 <= step <= warmup_steps:
        # the small offset keeps the very first step from using an exactly-zero learning rate
        return max_val * step / warmup_steps + 1e-9
    # never divide by zero, even if the warmup covers the whole run
    decay_steps = max(total_steps - warmup_steps, 1)
    return max_val * (1 + np.cos((step - warmup_steps) * np.pi / decay_steps)) / 2

In [12]:
mm = 0.98 # base momentum for moving average

def update_momentum(step, max_val=mm, total_steps=total_steps):
    """Return the exponential-moving-average coefficient (tau) for gradient step `step`.

    tau starts at the base momentum `max_val` and rises to 1 along a half-cosine:
    tau = 1 - (1 - max_val) * (1 + cos(pi * step / total_steps)) / 2.

    Args:
        step: current global gradient step.
        max_val: base momentum, i.e. the value of tau at step 0.
        total_steps: total number of gradient steps in the run.

    Returns:
        The momentum for this step, as a float in [max_val, 1].
    """
    # scale the cosine term by (1 - max_val), not max_val: otherwise tau would start at
    # 1 - max_val (0.02 for the default) instead of max_val
    return 1 - (1 - max_val) * (1 + np.cos(step * np.pi / total_steps)) / 2

In [13]:
mined_weight = 0.5 # base loss weight for mined term
mined_weight_warmup_epochs = 10 # warmup period
mined_weight_warmup_steps = num_steps_per_epoch * mined_weight_warmup_epochs

def update_weight(step, max_val=mined_weight, warmup_steps=mined_weight_warmup_steps):
    """Return the weight of the mined loss term at gradient step `step`.

    The weight grows linearly from ~0 to `max_val` over the first `warmup_steps` steps and
    stays at `max_val` afterwards.

    Args:
        step: current global gradient step.
        max_val: final weight of the mined loss term.
        warmup_steps: number of linear warmup steps; 0 disables the warmup.

    Returns:
        The mined-loss weight for this step, as a float.
    """
    # `warmup_steps > 0` guards the division below when the warmup is disabled
    if warmup_steps > 0 and 0 <= step <= warmup_steps:
        return max_val * step / warmup_steps + 1e-9
    return max_val

In [14]:
# optimizer
weight_decay = 2e-5 
# parameters are collected from `model.trainable_modules`
# NOTE(review): exclude_bias_and_bn=False presumably keeps biases and batch-norm parameters
# under weight decay — confirm against `collect_params`
params = collect_params(model.trainable_modules, exclude_bias_and_bn=False)
# AdamW is constructed with the base learning rate `lr`
optimizer = torch.optim.AdamW(params, lr=lr, weight_decay=weight_decay)

In [15]:
# tensorboard writer
# NOTE(review): with `logdir = None`, SummaryWriter falls back to its default log location —
# presumably a run folder under `runs/`, matching `--logdir runs` in the TensorBoard cell; confirm.
logdir = None # where the logs are stored
writer = SummaryWriter(logdir)

# Training

In [16]:
def train(step):
    """Train the model for one pass over `dataloader`.

    For every batch the scheduled learning rate, EMA momentum and mined-loss weight are
    refreshed, the augmented-view and mined-view losses are computed, the online network is
    updated by the optimizer, and the target network is updated through
    `model.update_target_network`. Scalars are logged to `writer` every 50 steps.

    Args:
        step: global gradient-step counter at the start of this pass.

    Returns:
        The global step counter after the pass (incremented once per batch).
    """
    for inputs in dataloader:
        optimizer.zero_grad()

        # update params
        lr = update_learning_rate(step)
        mm = update_momentum(step)
        weight = update_weight(step)

        # push the scheduled learning rate into the optimizer; otherwise the schedule is
        # only logged and the optimizer keeps the rate it was constructed with
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        data = inputs['data'].to(device)
        out = model(online_view=data, target_view=data)

        view_2 = inputs['view_2'].to(device)
        out_2 = model(online_view=view_2, target_view=view_2)

        # Augmented Views
        view_1_index = inputs['view_1_index'].to(device)
        online_q, target_z = out.online.q[view_1_index], out_2.target.z

        # Augmented Views (Symmetric)
        online_q_s, target_z_s = out_2.online.q, out.target.z[view_1_index]

        # Mining
        online_y, target_candidate_y = out.online.y, out.target.y
        online_y = online_y[view_1_index]

        # Compute cosine distance (negated cosine similarity: smaller means closer)
        online_y = F.normalize(online_y, dim=-1, p=2)
        target_candidate_y = F.normalize(target_candidate_y, dim=-1, p=2)
        dist = - torch.einsum('nc,kc->nk', [online_y, target_candidate_y])

        # remove ineligible candidates
        # `n_idx` maps each value found in `row` to its position among the unique values of `row`;
        # it is created on the same device as `dist` so that indexing also works on GPU
        row, col = inputs['ccand_edge_index'].to(device)
        n_mask = torch.unique(row)
        n_idx = torch.zeros(target_candidate_y.size(0), dtype=torch.long, device=dist.device)
        n_idx[n_mask] = torch.arange(n_mask.size(0), device=dist.device)
        # the largest finite distance guarantees these pairs are never among the nearest neighbors
        dist[n_idx[row], col] = torch.finfo(dist.dtype).max

        # get k nearest neighbors
        _, topk_index = torch.topk(dist, k=knn_nneighs, largest=False)

        # randomly select the mined view out of the k nearest neighbors
        num_queries = topk_index.size(0)
        mined_view_id = topk_index[torch.arange(num_queries, dtype=torch.long, device=dist.device),
                                   torch.randint(knn_nneighs, size=(num_queries,), device=dist.device)]

        # Mined views
        online_q_m = out.online.q_m[view_1_index]
        target_v = out.target.v[mined_view_id]

        # loss (targets are detached: no gradient flows into the target network)
        aug_loss = 1 - 0.5 * cosine_similarity(online_q, target_z.detach(), dim=-1).mean() \
                   - 0.5 * cosine_similarity(online_q_s, target_z_s.detach(), dim=-1).mean()
        mined_loss = 1 - cosine_similarity(online_q_m, target_v.detach(), dim=-1).mean()

        loss = (1 - weight) * aug_loss + weight * mined_loss

        loss.backward()
        # update online network
        optimizer.step()
        # update target network
        model.update_target_network(mm)

        # log scalars
        if step % 50 == 0:
            writer.add_scalar('params/lr', lr, step)
            writer.add_scalar('params/mm', mm, step)
            writer.add_scalar('params/weight', weight, step)
            writer.add_scalar('train/loss', loss, step)
            writer.add_scalar('train/aug_loss', aug_loss, step)
            writer.add_scalar('train/mined_loss', mined_loss, step)

        step += 1
    return step

Linear evaluation is performed every `linear_eval_epochs` epochs.

In [17]:
linear_eval_epochs = 10000

def test(step):
    """Run the linear evaluation on a frozen copy of the online encoder.

    Args:
        step: global gradient step, forwarded to `linear_evaluate` as the logging index.

    Returns:
        The `(test_acc, test_delta_acc)` pair returned by `linear_evaluate`.
    """
    # Copy first, then switch only the copy to eval mode: `.eval()` acts in place and returns
    # the module itself, so calling it on `model.online_encoder` would leave the live encoder
    # in eval mode for any training that follows.
    encoder = copy.deepcopy(model.online_encoder).eval()
    # use the `step` argument instead of relying on a global `epoch` set by the training loop
    test_acc, test_delta_acc = linear_evaluate(encoder, train_dataset, val_dataset, test_dataset, normalize, writer, device, step)
    return test_acc, test_delta_acc

# Training

In [None]:
step = 0
for epoch in tqdm(range(1, num_epochs+1)):
    step = train(step)
    # evaluate every `linear_eval_epochs` epochs, and always after the last epoch so that the
    # final metrics exist even when `num_epochs` is not a multiple of `linear_eval_epochs`
    if epoch % linear_eval_epochs == 0 or epoch == num_epochs:
        test_acc, test_delta_acc = test(step)

print('Accuracy: %.2f\nDelta-Accuracy: %.2f' % (100*test_acc, 100*test_delta_acc))

# Visualizing learned representations

In [19]:
# NOTE(review): TensorFlow is imported only so that `tf.io.gfile` can be pointed at TensorBoard's
# stub implementation; presumably this works around `writer.add_embedding` failing when both
# TensorFlow and TensorBoard are installed — TODO confirm.
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

In [20]:
def visualize(step):
    """Log the encoder's representations of `full_dataset` to TensorBoard's embedding projector.

    Args:
        step: global step under which the embedding is logged.
    """
    # Copy first, then switch only the copy to eval mode: `.eval()` acts in place, so calling it
    # on `model.online_encoder` would change the mode of the live encoder.
    encoder = copy.deepcopy(model.online_encoder).eval()
    # prepare data
    x = normalize(full_dataset.x).to(device)

    # compute representations
    with torch.inference_mode():
        representations = encoder(x).to('cpu')

    # get metadata
    reach_direction = full_dataset.y.numpy()
    timestep = full_dataset.timestep.numpy()
    trial = full_dataset.batch.numpy()
    vel = torch.norm(full_dataset.vel, 2, dim=1).numpy()

    # get __seq_next__ to display trajectory
    # samples that are not the source of an edge keep the default empty string
    seq_next = np.zeros(full_dataset.num_samples, dtype='U8')
    seq_next[full_dataset.edge_index[0]] = full_dataset.edge_index[1].numpy().astype('U8')

    # combine metadata
    metadata = np.column_stack([reach_direction, timestep, trial, vel, seq_next]).tolist()
    metadata_header = ['reach_direction', 'timestep', 'trial_id', 'velocity', '__seq_next__']

    # log to tensorboard
    writer.add_embedding(representations, metadata=metadata, metadata_header=metadata_header, global_step=step)

In [21]:
# Log the embeddings under the current value of the global step counter.
visualize(step)

The learned embeddings are saved to the embedding directory and can be viewed through TensorBoard's embedding projector. 
In TensorBoard, go to the `PROJECTOR` tab (can be found in the dropdown menu).

This is an example of what to expect. We use t-SNE with a perplexity of 80, which highlights the global structure of the data.
![](../docs/embedding_projector_screenshot.png)

In [22]:
# Load the TensorBoard notebook extension and open TensorBoard on the `runs` directory.
%load_ext tensorboard
%tensorboard --logdir runs