In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from copy import deepcopy
from sklearn.metrics import silhouette_score
import pandas as pd 
import torch.optim as optim
from braingeneers.analysis import SpikeData
import os
import json
import pytorch_lightning as pl
from backbones import ResNet18Enc, ResNet18Dec

In [None]:
# Pair every waveform array in ./data with its JSON metadata sidecar
# (same path, ".json" instead of ".npy").
# The listing is sorted because os.listdir returns entries in arbitrary order;
# a stable file order keeps the row order of `waveforms` / `isi_dist`
# reproducible between runs.
waveform_files = sorted("data/" + f for f in os.listdir("./data") if f.endswith('.npy'))
meta_files = [f.replace('.npy', '.json') for f in waveform_files]

N_ISI_BINS = 50  # number of histogram bins used for every ISI distribution

waveforms = []  # one array per .npy file, in file order
isi_dist = []   # one density histogram per spike train, in file order

for wf, mf in zip(waveform_files, meta_files):
    waveforms.append(np.load(wf))

    with open(mf, 'r') as file:
        json_data = json.load(file)

    # One spike train per metadata entry, in the JSON's key order.
    sd = SpikeData([entry['train'] for entry in json_data.values()])

    all_isi = sd.interspike_intervals()

    for isi in all_isi:
        if np.size(isi) == 0:
            # A density histogram of an empty sample is 0/0 = NaN in every
            # bin; store an all-zero histogram instead so that one silent
            # unit cannot inject NaNs into training.
            hist = np.zeros(N_ISI_BINS)
        else:
            # NOTE(review): no explicit `range` is given, so each histogram
            # spans its own min..max and bins are not aligned across units.
            hist, _bin_edges = np.histogram(isi, bins=N_ISI_BINS, density=True)
        isi_dist.append(hist)


In [None]:
# Stack the per-file waveform arrays into one array along axis 0.
# Guarded so that re-running this cell is a no-op: once `waveforms` is already
# an ndarray, concatenating it again would treat its rows as separate arrays
# and join them end to end, silently changing the array's shape.
if isinstance(waveforms, list):
    waveforms = np.concatenate(waveforms, axis=0)

In [None]:
# Display the first ISI histogram as a quick sanity check of the loading cell.
isi_dist[0]

In [None]:
class EphysDataset(Dataset):
    """Paired dataset of spike waveforms and ISI histograms.

    Item ``i`` pairs ``waveforms[i]`` with ``isi_dists[i]``.

    Args:
        waveforms: array-like whose first axis indexes samples.
        isi_dists: array-like with the same number of samples as ``waveforms``.
        normalize: when True, each waveform is z-scored on its own (mean and
            standard deviation taken over that single item). ISI histograms
            are always returned unchanged.
        eps: lower bound applied to the standard deviation before dividing,
            so a constant waveform normalizes to zeros instead of NaN.
    """

    def __init__(self, waveforms, isi_dists, normalize=True, eps=1e-8):
        self.waveforms = np.array(waveforms)
        self.isi_dists = np.array(isi_dists)

        assert len(self.waveforms) == len(self.isi_dists), (
            "waveforms and isi_dists must have the same number of samples, "
            f"got {len(self.waveforms)} and {len(self.isi_dists)}"
        )
        self.normalize = normalize
        self.eps = eps

    def __getitem__(self, idx):
        """Return ``(waveform, isi_dist)`` as float tensors with a new leading axis of size 1."""
        waveform = torch.as_tensor(self.waveforms[idx, ...]).float()
        isi_dist = torch.as_tensor(self.isi_dists[idx, ...]).float()

        if self.normalize:
            # Flooring the scale avoids 0/0 for flat waveforms; for any
            # waveform with std above eps the result is the plain z-score.
            scale = torch.clamp(waveform.std(), min=self.eps)
            waveform = (waveform - waveform.mean()) / scale

        return waveform.unsqueeze(0), isi_dist.unsqueeze(0)

    def __len__(self):
        """Number of (waveform, ISI histogram) pairs."""
        return len(self.waveforms)

In [None]:
class ContrastiveLoss(nn.Module):
    """SimCLR-style contrastive loss over two batches of paired embeddings.

    Row ``k`` of ``emb_i`` and row ``k`` of ``emb_j`` form a positive pair;
    every other row of the concatenated batch acts as a negative for it.

    Args:
        batch_size: expected number of pairs per batch. A mask for this size
            is precomputed and stored as a buffer; batches of any other size
            are still accepted and get a mask built on the fly.
        temperature: divisor applied to the similarities before
            exponentiation.
    """

    def __init__(self, batch_size, temperature=0.5):
        super().__init__()
        self.batch_size = batch_size
        self.register_buffer("temperature", torch.tensor(temperature))
        # 1 everywhere except the diagonal: removes each row's similarity with
        # itself from the denominator.
        self.register_buffer("negatives_mask", (~torch.eye(batch_size * 2, batch_size * 2, dtype=bool)).float())

    def _mask_for(self, n_pairs, device):
        """Return the off-diagonal mask for a batch of ``n_pairs`` pairs."""
        if n_pairs == self.batch_size:
            return self.negatives_mask
        off_diagonal = ~torch.eye(2 * n_pairs, dtype=torch.bool, device=device)
        return off_diagonal.to(self.negatives_mask.dtype)

    def forward(self, emb_i, emb_j):
        """
        emb_i and emb_j are batches of embeddings, where corresponding indices are pairs
        z_i, z_j as per SimCLR paper

        Both must be 2-D with identical shape ``(n_pairs, dim)``. Returns the
        scalar loss averaged over all ``2 * n_pairs`` rows.

        Raises:
            ValueError: if the two batches do not have the same shape.
        """
        if emb_i.shape != emb_j.shape:
            raise ValueError(
                f"emb_i and emb_j must have the same shape, got {tuple(emb_i.shape)} and {tuple(emb_j.shape)}"
            )
        # Use the actual batch size rather than the constructor value so a
        # shorter (e.g. final) batch does not break the diagonal offsets.
        n_pairs = emb_i.shape[0]

        z_i = F.normalize(emb_i, dim=1)
        z_j = F.normalize(emb_j, dim=1)

        representations = torch.cat([z_i, z_j], dim=0)
        # Rows are unit-norm, so the Gram matrix is the cosine-similarity
        # matrix; this avoids broadcasting a (2B, 2B, dim) intermediate.
        similarity_matrix = representations @ representations.t()

        # The +/- n_pairs diagonals hold the similarities of the positive pairs.
        sim_ij = torch.diag(similarity_matrix, n_pairs)
        sim_ji = torch.diag(similarity_matrix, -n_pairs)
        positives = torch.cat([sim_ij, sim_ji], dim=0)

        nominator = torch.exp(positives / self.temperature)
        mask = self._mask_for(n_pairs, similarity_matrix.device)
        denominator = mask * torch.exp(similarity_matrix / self.temperature)

        loss_partial = -torch.log(nominator / torch.sum(denominator, dim=1))
        loss = torch.sum(loss_partial) / (2 * n_pairs)
        return loss


In [None]:
class MixedModel(nn.Module):
    """Autoencoder that runs one shared encoder/decoder over two inputs.

    ``forward(wave, time)`` encodes both inputs with the same ``encoder``,
    decodes both codes with the same ``decoder``, and resizes each
    reconstruction along its last axis to the length of the matching input.

    Args:
        encoder: module mapping an input to its embedding. Optional at
            construction time, but required before ``forward`` is called.
        decoder: module mapping an embedding to a reconstruction. Optional at
            construction time, but required before ``forward`` is called.
    """

    def __init__(self, encoder=None, decoder=None):
        super().__init__()

        if encoder is not None:
            self.encoder = encoder

        # The projection / normalisation layers below are not used by
        # forward(). They are kept so the module's parameter names stay
        # unchanged and code that refers to them by attribute keeps working.
        self.fc_wave = nn.Linear(16*(4), 16)
        self.fc_time = nn.Linear(16*(4), 16)

        self.norm1 = nn.BatchNorm1d(16)

        self.wave_upsample = nn.Linear(16, 16*(4))
        self.time_upsample = nn.Linear(16, 16*(4))

        self.norm2 = nn.BatchNorm1d(64)

        if decoder is not None:
            self.decoder = decoder

    def forward(self, wave, time):
        """Return ``(e_wave, e_time, decode_w, decode_t)``.

        ``e_*`` are the encoder outputs; ``decode_*`` are the decoder outputs
        interpolated so their last dimension equals that of ``wave`` /
        ``time`` respectively.

        Raises:
            AttributeError: if the model was built without an encoder or
                without a decoder.
        """
        encoder = getattr(self, "encoder", None)
        decoder = getattr(self, "decoder", None)
        if encoder is None or decoder is None:
            # Same exception type the missing attribute would raise, but with
            # a message that says what to do about it.
            raise AttributeError(
                "MixedModel.forward requires both an encoder and a decoder; "
                "pass them to MixedModel(encoder=..., decoder=...)"
            )

        e_wave, e_time = encoder(wave), encoder(time)
        d_wave, d_time = decoder(e_wave), decoder(e_time)

        # The decoder's output length need not match the input length, so
        # resize each reconstruction to its own input before returning it.
        decode_w = F.interpolate(d_wave, (wave.size(-1),))
        decode_t = F.interpolate(d_time, (time.size(-1),))

        return e_wave, e_time, decode_w, decode_t

In [None]:
# Smoke test: build the model with the ResNet-18 backbones and push random
# data through it to check that the forward pass runs.
model = MixedModel(encoder=ResNet18Enc(z_dim=5), decoder=ResNet18Dec(z_dim=5))

# Two independent random tensors of shape (8, 1, 50), one per model input.
sample = torch.unbind(torch.randn(2, 8, 1, 50), dim=0)

x, y, z, w = model(sample[0], sample[1])

# Show the shapes of both embeddings and both reconstructions.
tuple(out.shape for out in (x, y, z, w))

In [None]:
from pytorch_lightning.utilities import grad_norm

class MultimodalEmbedding(pl.LightningModule):
    """Lightning wrapper that trains ``base_model`` with two objectives.

    The loss is ``mse + contrastive_loss_weight * contrastive`` where ``mse``
    is the sum of the reconstruction errors of both inputs and ``contrastive``
    is the ContrastiveLoss between the two embeddings.

    Args:
        base_model: module called as ``base_model(wave, time)`` and returning
            ``(rep_w, rep_t, decode_w, decode_t)``.
        batch_size: batch size handed to ContrastiveLoss.
        contrastive_loss_weight: multiplier on the contrastive term.
    """

    # Submodules of ``base_model`` whose gradient norms are logged.
    _GRAD_NORM_MODULES = ("encoder", "decoder", "fc_wave")

    def __init__(self, base_model, batch_size, contrastive_loss_weight=1.0):
        super().__init__()
        self.model = base_model
        self.contrastive_loss_weight = contrastive_loss_weight
        self.mse_loss = nn.MSELoss()
        self.contrastive_loss = ContrastiveLoss(batch_size, temperature=0.5)

    def _compute_losses(self, batch):
        """Run the model on one batch and return ``(total, mse, contrastive)``."""
        wave, time = batch
        rep_w, rep_t, decode_w, decode_t = self.model(wave, time)

        # Reconstruction error of both inputs.
        mse_loss = self.mse_loss(decode_w, wave) + self.mse_loss(decode_t, time)

        # Pull the two embeddings of the same sample together.
        contrastive_loss = self.contrastive_loss(rep_w, rep_t)

        total_loss = mse_loss + self.contrastive_loss_weight * contrastive_loss
        return total_loss, mse_loss, contrastive_loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        total_loss, mse_loss, contrastive_loss = self._compute_losses(batch)

        self.log('train_loss', total_loss)
        self.log('train_mse_loss', mse_loss)
        self.log('train_xe_loss', contrastive_loss)

        return total_loss

    def on_before_optimizer_step(self, optimizer, *args, **kwargs):
        """Log per-submodule gradient 2-norms.

        This runs after backward, so the gradients belong to the batch that
        was just processed. Inside ``training_step`` the backward pass for the
        current batch has not happened yet, so norms read there are missing
        or stale.
        """
        for name in self._GRAD_NORM_MODULES:
            module = getattr(self.model, name, None)
            if module is None:
                continue
            norms = grad_norm(module, norm_type=2)
            if norms:
                # Prefix with the submodule name so entries coming from
                # different submodules cannot overwrite each other.
                self.log_dict({f"{name}/{key}": value for key, value in norms.items()})

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss for one batch."""
        total_loss, mse_loss, contrastive_loss = self._compute_losses(batch)

        self.log('val_loss', total_loss)
        self.log('val_mse_loss', mse_loss)
        self.log('val_xe_loss', contrastive_loss)

        return total_loss

    def configure_optimizers(self):
        """AdamW over all parameters with a fixed learning rate of 1e-3."""
        optimizer = optim.AdamW(self.parameters(), lr=0.001)
        return optimizer

In [None]:
import wandb
wandb.init(reinit=True)

# Seed every RNG so weight initialisation and batch shuffling are repeatable.
pl.seed_everything(42)

bs = 512

# Index-based split: the first 15000 samples train, the remainder validates.
wf_train, isi_train = waveforms[:15000], isi_dist[:15000]
wf_val, isi_val = waveforms[15000:], isi_dist[15000:]

traindata = EphysDataset(wf_train, isi_train)
valdata = EphysDataset(wf_val, isi_val)

# Shuffle the training set so batches are not tied to the order in which the
# samples were loaded. drop_last keeps every batch at exactly `bs` samples,
# the size the contrastive loss is constructed for.
train_loader = DataLoader(traindata, batch_size=bs, shuffle=True, drop_last=True)
val_loader = DataLoader(valdata, batch_size=bs, drop_last=True)

# MixedModel only has an encoder/decoder when they are passed in, and its
# forward pass needs both. This uses the same backbone configuration as the
# smoke-test cell; adjust z_dim there and here together.
base_model = MixedModel(
    encoder=ResNet18Enc(z_dim=5),
    decoder=ResNet18Dec(z_dim=5),
)
model = MultimodalEmbedding(base_model=base_model, batch_size=bs, contrastive_loss_weight=0.5)

trainer = pl.Trainer(
    logger=pl.loggers.WandbLogger(),
    # limit_train_batches=1,  # uncomment for a quick one-batch debug run
    log_every_n_steps=5,
)

trainer.fit(model, train_loader, val_loader)