In [1]:
# import dependencies
import numpy as np
import probtorch
import scipy.io as sio
from scipy.stats import norm
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.nn import Parameter

In [2]:
# check the availability of CUDA
CUDA = torch.cuda.is_available()

In [3]:
# placeholder values for hyperparameters
LEARNING_RATE = 0.1
NUM_FACTORS   = 50
NUM_SAMPLES   = 10
SOURCE_WEIGHT_VARIANCE = 2
SOURCE_WIDTH_VARIANCE  = 3
VOXEL_NOISE            = 0.1

In [4]:
# Load our sample dataset
dataset = sio.loadmat('s0.mat')

# pull out the voxel activations and locations
voxel_activations = torch.Tensor(dataset['data']).transpose(0, 1)
voxel_locations = torch.Tensor(dataset['R'])

# This could be a huge file.  Close it
del dataset

In [5]:
# Pull out relevant dimensions: the number of times-of-recording, and the number of voxels in each timewise "slice"
num_times = voxel_activations.shape[0]
num_voxels = voxel_activations.shape[1]

In [6]:
# Estimate further hyperparameters from the dataset
brain_center = torch.mean(voxel_locations, 0).unsqueeze(0)
brain_center_variance = 10 * torch.var(voxel_locations, 0).unsqueeze(0)

In [7]:
def radial_basis(locations, centers, distances, num_samples = NUM_SAMPLES, num_voxels = num_voxels):
    locations = locations.unsqueeze(0).expand(num_samples, num_voxels, 3).unsqueeze(1)
    centers = centers.unsqueeze(2)
    distances = distances.unsqueeze(2)
    
    return torch.exp(((locations - centers)**2).sum(3) / -torch.exp(distances))

def elbo(q, p, num_samples = NUM_SAMPLES):
    # Remember: the Evidence Lower Bound is the negative free-energy/negative KL divergence
    return -probtorch.objectives.montecarlo.kl(q, p, sample_dim = 0)

In [8]:
class TFAEncoder(nn.Module):
    def __init__(self, num_times = num_times, num_factors = NUM_FACTORS):
        super(self.__class__, self).__init__()
        self._num_times = num_times
        self._num_factors = num_factors
        
        self._mean_weight = Parameter(torch.randn((self._num_times, self._num_factors)))
        self._weight_variance = Parameter(torch.randn((self._num_times, self._num_factors)))
        
        self._mean_factor_center = Parameter(torch.randn((self._num_factors, 3)))
        self._factor_center_variance = Parameter(torch.randn((self._num_factors, 3)))
        
        self._mean_factor_width = Parameter(torch.randn((self._num_factors)))
        self._factor_width_variance = Parameter(torch.randn((self._num_factors)))

    def forward(self, num_samples = NUM_SAMPLES):
        q = probtorch.Trace()

        mean_weight = self._mean_weight.expand(num_samples, self._num_times, self._num_factors)
        weight_variance = self._weight_variance.expand(num_samples, self._num_times, self._num_factors)
        
        mean_factor_center = self._mean_factor_center.expand(num_samples, self._num_factors, 3)
        factor_center_variance = self._factor_center_variance.expand(num_samples, self._num_factors, 3)
        
        mean_factor_width = self._mean_factor_width.expand(num_samples, self._num_factors)
        factor_width_variance = self._factor_width_variance.expand(num_samples, self._num_factors)
        
        weights = q.normal(mean_weight, weight_variance, name='Weights')
        
        factor_centers = q.normal(mean_factor_center, factor_center_variance, name='FactorCenters')
        factor_widths = q.normal(mean_factor_width, factor_width_variance, name='FactorWidths')
        
        return q

In [9]:
class TFADecoder(nn.Module):
    def __init__(self, num_times = num_times, num_factors = NUM_FACTORS, num_voxels = num_voxels):
        super(self.__class__, self).__init__()
        self._num_times = num_times
        self._num_factors = num_factors
        self._num_voxels = num_voxels
        
        self._mean_weight = Parameter(torch.zeros((self._num_times, self._num_factors)))
        self._weight_variance = Parameter(SOURCE_WEIGHT_VARIANCE * torch.ones((self._num_times, self._num_factors)))
        
        self._mean_factor_center = Parameter(brain_center.expand(self._num_factors, 3) * torch.ones((self._num_factors, 3)))
        self._factor_center_variance = Parameter(brain_center_variance.expand(self._num_factors, 3) * torch.ones((self._num_factors, 3)))
        
        self._mean_factor_width = Parameter(torch.ones((self._num_factors)))
        self._factor_width_variance = Parameter(SOURCE_WIDTH_VARIANCE * torch.ones((self._num_factors)))
        
        self._voxel_noise = Parameter(VOXEL_NOISE * torch.ones(self._num_times, self._num_voxels))
        
    def forward(self, activations = voxel_activations, locations = voxel_locations, q=None):
        p = probtorch.Trace()
        
        weights = p.normal(self._mean_weight, self._weight_variance, value=q['Weights'], name='Weights')
        factor_centers = p.normal(self._mean_factor_center, self._factor_center_variance, value=q['FactorCenters'], name='FactorCenters')
        factor_widths = p.normal(self._mean_factor_width, self._factor_width_variance, value=q['FactorWidths'], name='FactorWidths')
        factors = radial_basis(locations, factor_centers, factor_widths, num_voxels = self._num_voxels)
        observations = p.normal(torch.matmul(weights, factors), self._voxel_noise, value=activations, name='Y')
        
        return p

In [10]:
enc = TFAEncoder()
dec = TFADecoder()

if CUDA:
    enc.cuda()
    dec.cuda()

    Found GPU0 GeForce GTX 1080 Ti which requires CUDA_VERSION >= 8000 for
     optimal performance and fast startup time, but your PyTorch was compiled
     with CUDA_VERSION 7050. Please install the correct PyTorch binary
     using instructions from http://pytorch.org
    


In [11]:
def train(activations, locations, enc, dec, num_steps = 10):
    optimizer = torch.optim.Adam(list(enc.parameters()), lr = LEARNING_RATE)
    if CUDA:
            activations = activations.cuda()
            locations = locations.cuda()

    enc.train()
    dec.train()
    
    losses = np.zeros(num_steps)
    for n in range(num_steps):
        optimizer.zero_grad()
        q = enc()
        p = dec(activations = activations, locations = locations, q = q)
        
        loss = elbo(q, p)
        loss.backward()
        
        optimizer.step()
        if CUDA:
            loss = loss.cpu()
        losses[n] = loss.data.numpy()[0]
        print(losses[n])
        
    return losses

In [12]:
losses = train(Variable(voxel_activations), Variable(voxel_locations), enc, dec)

-17434.318359375
-21401.673828125
-25290.572265625
-28538.748046875
-31673.337890625
-34475.75390625
-36921.83203125
-39085.265625
-41075.375
-42802.57421875


In [13]:
if CUDA:
    q = enc()
    weights = q['Weights'].value.data.cpu().numpy()
    factor_centers = q['FactorCenters'].value.data.cpu().numpy()
    factor_widths = q['FactorWidths'].value.data.cpu().numpy()