In [1]:
import matplotlib.pyplot as plt
from tqdm.auto import trange

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal, Poisson, Uniform, kl_divergence
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
rate0 = 2 * torch.pi / 20.0
# rate1 = 1.4142 * 2 * torch.pi / 20.0
rate1 = 2 * rate0

def sample_trial(weights, bias,
                 num_timesteps=100,
                 rate0=rate0,
                 rate1=rate1,
                 kick_prob=0.05):
    """Sample a single trial of data according to the model above.
    """
    theta = torch.zeros(num_timesteps, 2)
    u = torch.zeros(num_timesteps, 1)
    for t in range(1, num_timesteps):
        if torch.rand(1) < kick_prob:
            u[t] = Uniform(0, 2 * torch.pi).sample()
        theta[t, 0] = theta[t-1, 0] + rate0 + u[t]
        theta[t, 1] = theta[t-1, 1] + rate1 + u[t]

    x = torch.column_stack([torch.cos(theta), torch.sin(theta)])
    y = Poisson(F.softplus(x @ weights.T + bias)).sample()
    return u, x, y


In [8]:
class OscillatorDataset(Dataset):
    """A dataset of randomly generated trials.
    """
    def __init__(self, num_neurons, num_timesteps, num_trials, seed=0):
        self.num_neurons = num_neurons
        self.num_timesteps = num_timesteps
        self.num_trials = num_trials

        # Sample random emission weights
        torch.manual_seed(seed)
        self.weights = torch.randn((num_neurons, 4))
        self.bias = torch.randn(num_neurons)

        # Permute based on read-out angle for \theta_1
        angle = torch.atan2(self.weights[:, 3], self.weights[:, 2])
        print(angle.shape, self.weights.shape)
        print(angle)
        perm = torch.argsort(angle)
        print(perm)
        self.weights = self.weights[perm]
        self.bias = self.bias[perm]

        # Sample trials
        all_trials = [sample_trial(self.weights, self.bias) for _ in range(num_trials)]
        self.all_inputs, self.all_states, self.all_spikes = list(zip(*all_trials))

    def __len__(self):
        return self.num_trials

    def __getitem__(self, idx):
        return dict(inputs=self.all_inputs[idx],
                    states=self.all_states[idx],
                    spikes=self.all_spikes[idx])

dataset=OscillatorDataset(10, 100, 60)

torch.Size([10]) torch.Size([10, 4])
tensor([-2.0945, -1.7191,  0.7219, -0.2179,  0.9507,  1.7530,  0.4323,  0.4803,
        -3.1376,  2.3180])
tensor([8, 0, 1, 3, 6, 7, 2, 4, 5, 9])
