In [1]:
import pandas as pd
import numpy as np
from collections import defaultdict, Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from tqdm import tqdm
import shutil
from utils import count_parameters, accuracy, pooling
from config import NUM_EPOCHS, CV_DATA

import pyro
import pyro.distributions as dist
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO, TraceEnum_ELBO, config_enumerate
from pyro.optim import Adam, ClippedAdam, SGD

In [2]:
import matplotlib.pyplot as plt
import seaborn as sns

In [3]:
%matplotlib inline

In [4]:
plt.style.use('ggplot')

In [5]:
# Use the GPU when PyTorch reports one is available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Read the tab-separated, header-less file and label column 0 as 'y'.
df = pd.read_csv(CV_DATA, sep='\t', header=None).rename(columns={0: 'y'})
# Take the 'y' column out of df (leaving only the other columns) and shift it down by one.
ys = df.pop('y') - 1

In [None]:
# int64 tensor built from ys (the shifted 'y' column), moved to the selected device.
gt = torch.LongTensor(ys).to(device)

In [None]:
df.head()

In [None]:
xs = df.values

In [None]:
# NOTE(review): xs is rebound from itself, so re-running this cell pools the already pooled array again.
# The (1, 8) argument is handed straight to utils.pooling; its exact meaning is not visible in this notebook.
xs = pooling(xs, (1,8))

In [None]:
xs.shape

In [6]:
# Load the saved loader object; below it is iterated as (inputs, labels) batches.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
cv_loader = torch.load('cv_loader.pt')

### Model

In [7]:
from bayesian_nn import *

In [8]:
# Load the saved checkpoint; it is indexed below with the keys 'cv_props' and 'state_dict'.
# NOTE(review): torch.load unpickles the file, so only load files from a trusted source.
state = torch.load('nn_state.pth.tar')

In [9]:
mdl = Classifier()

In [10]:
state['cv_props']

{'loss': 485.5384292602539,
 'accuracy': 0.7350000143051147,
 'accuracy_1': 0.0,
 'accuracy_2': 0.0,
 'accuracy_3': 0.0}

In [11]:
mdl.load_state_dict(state['state_dict'])

<All keys matched successfully>

In [12]:
mdl.guide

<bound method Classifier.guide of Classifier(
  (encoder): SimpleConvolutionalEncoder(
    (c1): Conv1d(1, 5, kernel_size=(3,), stride=(1,))
    (pool): AdaptiveMaxPool1d(output_size=10)
    (act): LeakyReLU(negative_slope=0.3)
    (out): Linear(in_features=50, out_features=3, bias=True)
  )
  (log_softmax): LogSoftmax()
)>

In [13]:
mdl.eval()

Classifier(
  (encoder): SimpleConvolutionalEncoder(
    (c1): Conv1d(1, 5, kernel_size=(3,), stride=(1,))
    (pool): AdaptiveMaxPool1d(output_size=10)
    (act): LeakyReLU(negative_slope=0.3)
    (out): Linear(in_features=50, out_features=3, bias=True)
  )
  (log_softmax): LogSoftmax()
)

In [14]:
# Grab the first (inputs, labels) batch from the loader.
x_cv, y_cv = next(iter(cv_loader))

In [15]:
cv_res = mdl.predict(x_cv)

In [16]:
(cv_res == y_cv).float().mean()

tensor(0.3350)

In [17]:
# Accuracy over the whole loader. Correct predictions and seen samples are
# counted separately so that a shorter final batch is not over-weighted
# (averaging per-batch means would give every batch the same weight).
n_correct = 0
n_seen = 0
with torch.no_grad():
    for x_cv, y_cv in cv_loader:
        res = mdl.predict(x_cv)
        hits = (res == y_cv)
        n_correct += hits.sum().item()
        n_seen += hits.numel()
# An empty loader yields NaN instead of a ZeroDivisionError.
aa = n_correct / n_seen if n_seen else float('nan')

In [None]:
aa

In [None]:
def predict(x, num_samples=10):
    """Predict class indices by averaging the outputs of sampled models.

    Draws `num_samples` callables from `mdl.guide(None, None)`, evaluates each
    on `x` (moved to `device`), averages the stacked outputs over the sample
    axis and returns the argmax along dim 1.

    Args:
        x: input tensor accepted by the sampled models.
        num_samples: number of models to draw from the guide.

    Returns:
        Tensor of predicted class indices, one per row of `x`.
    """
    x_dev = x.to(device)
    nets = []
    for _ in range(num_samples):
        nets.append(mdl.guide(None, None))
    # All models are sampled first and evaluated afterwards.
    outputs = torch.stack([net(x_dev).data for net in nets])
    return outputs.mean(dim=0).argmax(dim=1)

In [None]:
# First row of df as a (1, n) array, passed through pooling with (1, 8), then wrapped as a float tensor.
x = torch.FloatTensor(pooling(df.iloc[0].values.reshape(1, -1), (1, 8)))

In [None]:
predict(x)

In [None]:
import pdb

### Calculate certainties 

In [None]:
def sample_preds(x, num_samples=1000, targets=None):
    """Sample models from the guide and collect their predictions on `x`.

    Args:
        x: array-like or tensor of inputs; a 1-D input is treated as one row.
        num_samples: number of models drawn via `mdl.guide(None, None)`.
        targets: optional labels, one per row of `x`. When omitted, the
            notebook-level `ys` is used, but only if its length matches the
            number of rows in `x`.

    Returns:
        Tuple `(logits, res, preds, accs)`:
            logits: list with one output tensor per sampled model.
            res: argmax (dim 1) of the outputs averaged over all samples.
            preds: list with the argmax (dim 1) of each sampled model.
            accs: numpy array with one accuracy per sampled model. It is
                filled with NaN when no labels of matching length are
                available (previously a single-row input was silently
                broadcast against every label in the dataset).

    Raises:
        ValueError: if `targets` is given and its length differs from the
            number of rows in `x`.
    """
    with torch.no_grad():
        x = torch.as_tensor(x, dtype=torch.float32)
        if x.dim() == 1:
            x = x.unsqueeze(0)
        x = x.to(device)

        sampled_models = [mdl.guide(None, None) for _ in range(num_samples)]
        logits = [model(x) for model in sampled_models]
        res = torch.stack(logits).mean(dim=0).argmax(dim=1)
        preds = [sample.argmax(dim=1) for sample in logits]

        explicit = targets is not None
        labels = targets if explicit else ys
        if torch.is_tensor(labels):
            gt = labels.long().to(device)
        else:
            gt = torch.as_tensor(np.asarray(labels), dtype=torch.long).to(device)

        if gt.shape[0] == x.shape[0]:
            accs = np.stack([(pred == gt).float().mean().cpu().numpy() for pred in preds])
        elif explicit:
            raise ValueError(
                'targets has %d entries but x has %d rows' % (gt.shape[0], x.shape[0])
            )
        else:
            # The global labels do not describe these rows; do not report bogus accuracies.
            accs = np.full(len(preds), np.nan)
        return logits, res, preds, accs

In [None]:
xs.shape

In [None]:
logits, res, preds, accs = sample_preds(xs)

In [None]:
(res == gt).float().mean()

In [None]:
# NOTE(review): bincnt is first assigned inside the vote-counting loop in a later cell,
# so this cell raises NameError on a fresh top-to-bottom run.
bincnt

In [None]:
preds = torch.stack(preds)

In [None]:
preds = preds.T

In [None]:
preds.shape

In [None]:
preds = preds.detach().cpu().numpy()

In [None]:
# For every row of preds: count how often each class was predicted,
# keep the counts in bins and the most frequent class in outp.
bins, outp = [], []
for pred in preds:
    bincnt = np.bincount(pred)
    y = bincnt.argmax()
    bins.append(bincnt)
    outp.append(y)

In [None]:
# Largest vote count per row, i.e. how many sampled models agreed on the winning class.
maxvals = np.stack([counts.max() for counts in bins])

In [None]:
sns.distplot(maxvals.reshape(-1))

In [None]:
sns.distplot(accs, bins=10)

In [None]:
sns.distplot(accs, bins=10)

In [None]:
accs.mean(), accs.std()

In [None]:
ys = np.array(ys)

In [None]:
# Python 3 print call (the Python 2 statement form is a SyntaxError in this kernel).
print(1)

### Subset Analysis

In [None]:
preds

In [None]:
cert_thresh = np.percentile(maxvals, 95)

In [None]:
ixs = (maxvals > cert_thresh)

In [None]:
ixs.sum()

In [None]:
# Rows of preds whose winning vote count exceeds the certainty threshold.
top = preds[ixs]
# gt was moved to `device`; bring it back to the CPU before converting to numpy
# (calling .numpy() directly fails for a CUDA tensor), then apply the same mask.
gt_top = gt.detach().cpu().numpy()[ixs]

In [None]:
sns.countplot(top[200])

In [None]:
preds.shape

In [None]:
# Majority-vote class for every row of top. A comprehension is used so the
# notebook-level name `x` (the input tensor defined earlier) is not overwritten
# by a loop variable.
top_preds = np.stack([np.bincount(row).argmax() for row in top])

In [None]:
high_thresh_acc = (top_preds == gt_top)

In [None]:
top_preds.shape

In [None]:
top_preds

In [None]:
def summary(ts, num_samples=100):
    """Summarise the sampled-model outputs for a single series.

    `sample_preds` returns `(logits, res, preds, accs)`; the previous version
    unpacked these in the wrong order and called tensor methods on numpy
    arrays, so it always raised.

    Args:
        ts: one input series (1-D array-like), already pooled.
        num_samples: number of models to sample.

    Returns:
        dict with
            'preds': 1-D numpy array, predicted class of each sampled model,
            'u': numpy array of shape (1, n_outputs), mean output over samples,
            'std': numpy array of shape (1, n_outputs), standard deviation of
                the outputs over samples.
        NOTE(review): 'u' and 'std' are taken to be the mean and spread of the
        model outputs, inferred from how the result is used later in the
        notebook — confirm.
    """
    x = torch.FloatTensor(ts).unsqueeze(0)
    logits, _, preds, _ = sample_preds(x, num_samples=num_samples)
    # (num_samples, 1, n_outputs) on the CPU so it can be converted to numpy.
    stacked = torch.stack(logits).detach().cpu()
    preds = torch.stack(preds).reshape(-1).detach().cpu().numpy()
    u = stacked.mean(dim=0).numpy()
    std = stacked.std(dim=0).numpy()
    out = {'preds': preds, 'u': u, 'std': std}
    return out

In [None]:
idx = 514

In [None]:
# Ground-truth label of the selected series. `y` is only the scalar left over
# from the vote-counting loop; the label array is `ys`.
ys[idx]

In [None]:
out = summary(xs[idx], num_samples=100)

In [None]:
out

In [None]:
sns.countplot(out['preds'])

In [None]:
from collections import Counter

In [None]:
Counter(list(out['preds']))

In [None]:
out['u'].argmax(1)

### Subset 

In [None]:
ixs = np.arange(100)

In [None]:
inputs = df.iloc[ixs].values
inputs = pooling(inputs, (1,8))

In [None]:
inputs.shape

In [None]:
# NOTE(review): the loop body was missing, which is a SyntaxError. Collecting
# one summary per series is an assumption about the intended body — confirm.
summaries = []
for ts in tqdm(inputs):
    summaries.append(summary(ts))

### Plot series 

In [None]:
def plot_ts(ts):
    """Plot a single series as a line chart against its positional index.

    Args:
        ts: object exposing `.values` (e.g. a row taken with `df.iloc[i]`).
    """
    y = ts.values
    x = np.arange(len(y))
    plt.figure(figsize=(7, 5))
    # x and y are passed by keyword: newer seaborn releases no longer accept
    # them positionally, and older releases accept the keyword form as well.
    sns.lineplot(x=x, y=y)
    plt.show()

In [None]:
plot_ts(df.iloc[500])

In [None]:
def plt_multiple(ixs):
    """Plot several rows of `df` as stacked line charts with shared axes.

    One subplot row is created per entry of `ixs` (the grid used to be fixed
    at two rows, which broke for any other number of indices).

    Args:
        ixs: iterable of positional row indices into `df`.
    """
    ixs = list(ixs)
    # squeeze=False keeps `axes` two-dimensional even for a single row;
    # at least one row is requested so an empty `ixs` still yields a figure.
    fig, axes = plt.subplots(nrows=max(len(ixs), 1), sharey=True, sharex=True,
                             squeeze=False)
    for ax, ix in zip(axes[:, 0], ixs):
        y = df.iloc[ix]
        x = np.arange(len(y))
        # Keyword x/y works on both old and new seaborn releases.
        sns.lineplot(x=x, y=y, ax=ax)

In [None]:
plt_multiple([4, 5])