In [1]:
#import
import os
from functools import partial
import torch
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

import pyro
import pyro.distributions as dist
import warnings
from pandas.errors import SettingWithCopyWarning
from pyro.infer import MCMC, NUTS

# for CI testing
warnings.simplefilter(action='ignore', category=SettingWithCopyWarning)
smoke_test = ('CI' in os.environ)
assert pyro.__version__.startswith('1.8.4')
pyro.set_rng_seed(1)


# Set matplotlib settings
%matplotlib inline
plt.style.use('default')

## Adapting a Dirichlet–Categorical model to our problem

In [2]:
# Load the trial table and keep the two columns the model needs.
# NOTE(review): dropna() is applied to *all* columns before selecting, so a row is
# discarded if any column (not only 'morphing level'/'shock') is missing — confirm
# that is intended.
TARGET_MORPH_LEVEL = 6  # morphing level encoded as 1; every other level becomes 0

df = pd.read_csv('../data/newLookAtMe/newLookAtMe20.csv').dropna()
# .copy() makes `data` an independent frame, so the column assignments below do
# not write into a slice of `df` (the source of SettingWithCopyWarning).
data = df[['morphing level', 'shock']].copy()
data['shock'] = data['shock'].astype(int)
# Binarize: 1 for the target morphing level, 0 otherwise (vectorized).
data['morphing level'] = (data['morphing level'] == TARGET_MORPH_LEVEL).astype(int)
data

Unnamed: 0,morphing level,shock
0,0,0
1,1,0
2,1,0
3,0,0
4,0,0
...,...,...
155,0,0
156,0,0
157,0,0
158,0,0


In [3]:
# Split `data` into two tensors. Column order is fixed by the selection in the
# loading cell: column 0 = binarized morphing level, column 1 = shock flag.
# NOTE(review): neither tensor is used by the model cells visible here, which
# consume `data_final` instead.
x_data, y_data = (torch.tensor(data.to_numpy()[:, col]) for col in (0, 1))

In [16]:
def encode_outcomes(frame):
    """Map each trial to one of three outcome codes for the Categorical likelihood.

    Codes:
        0 -> 'morphing level' == 0 and 'shock' == 0
        1 -> 'morphing level' == 1 and 'shock' == 0
        2 -> every other combination (i.e. any shocked trial)

    Args:
        frame: DataFrame with integer columns 'morphing level' and 'shock'.

    Returns:
        1-D ``torch.long`` tensor with one code per row (also for an empty frame).
    """
    morph = frame['morphing level'].to_numpy()
    shock = frame['shock'].to_numpy()
    codes = np.select(
        [(morph == 0) & (shock == 0), (morph == 1) & (shock == 0)],
        [0, 1],
        default=2,
    )
    # Explicit long dtype: Categorical observations must be integer indices.
    return torch.as_tensor(codes, dtype=torch.long)


data_final = encode_outcomes(data)
data_final

tensor([0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 2, 0, 0, 2, 0, 2, 2, 0,
        1, 2, 0, 0, 0, 0, 2, 1, 0, 2, 2, 0, 1, 0, 2, 0, 1, 0, 2, 0, 2, 0, 2, 0,
        0, 2, 0, 0, 1, 0, 0, 0, 0, 2, 0, 0, 1, 0, 1, 0, 0, 2, 0, 2, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 2, 1, 0, 0, 0, 1, 0, 0, 2, 0,
        0, 1, 0, 2, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, 0, 0, 0,
        1, 0, 2, 0, 0, 0, 0, 2, 1, 0, 0, 1, 0, 0, 0, 1, 2, 0, 0, 0, 0, 2, 0, 0,
        2, 0, 0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 1])

In [17]:
# Default Dirichlet concentration: all ones over 3 categories (a uniform prior on
# the probability simplex).
prior_ = torch.ones(3)

def model(data, concentration=None):
    """Dirichlet–Categorical model of per-trial outcome codes.

    Samples outcome probabilities ``theta`` from a Dirichlet prior. Each element of
    ``data`` is then observed as a Categorical draw from ``theta`` inside a plate of
    size ``len(data)``.

    Args:
        data: observed category codes, presumably a 1-D integer tensor such as
            ``data_final`` (values in 0..K-1, K = len(concentration)).
        concentration: Dirichlet concentration vector. Defaults to the module-level
            ``prior_``, looked up at call time, so existing callers are unchanged.
            Pass a different vector to try other priors without rebinding the global.
    """
    if concentration is None:
        concentration = prior_
    theta = pyro.sample('theta', dist.Dirichlet(concentration))
    with pyro.plate('data', len(data)):
        pyro.sample('obs', dist.Categorical(theta), obs=data)

In [19]:
# Posterior inference with NUTS. Under CI (`smoke_test`), use a tiny run so the
# notebook finishes quickly.
if smoke_test:
    num_samples, warmup_steps = 10, 10
else:
    num_samples, warmup_steps = 1000, 200

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, warmup_steps=warmup_steps, num_samples=num_samples)
mcmc.run(data_final)

# Posterior draws converted to NumPy arrays, keyed by sample-site name.
hmc_samples = dict(
    (site, draws.detach().cpu().numpy())
    for site, draws in mcmc.get_samples().items()
)

Sample: 100%|██████████| 1200/1200 [00:08, 137.95it/s, step size=1.03e+00, acc. prob=0.903]
