In [60]:
%load_ext autoreload
%autoreload 2


import torch

import torch.nn as nn
from torch.distributions import Categorical, Poisson, MixtureSameFamily
from matplotlib import pyplot as plt


# Cd to code
import os
import sys
os.chdir('/cluster/home/kheuto01/code/prob_diff_topk')
sys.path.append('/cluster/home/kheuto01/code/prob_diff_topk')

from datasets import example_datasets, to_numpy
from torch_perturb.torch_pert_topk import PerturbedTopK
from torch_models import MixtureOfPoissonsModel, torch_bpr_uncurried, deterministic_bpr
from torch_distributions import TruncatedNormal

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
seed=360
# tracts/distributions
S=12
# history/features
H = 3
# total timepoints
T= 500
num_components=4
K=4

In [25]:
train_dataset, val_dataset, test_dataset = example_datasets(H, T, seed=seed)
train_X_THS, train_y_TS = to_numpy(train_dataset)

In [26]:
# take inverse softplus
ideal_means = torch.tensor([0+1e-8, 7, 10, 100])
ideal_softinv_means = ideal_means + torch.log(-torch.expm1(-ideal_means))
ideal_scales = torch.tensor([0.2, 0.2, 0.2, 0.2])
ideal_softinv_scales = ideal_scales + torch.log(-torch.expm1(-ideal_scales)) 
ideal_mix_weights = torch.log(1e-13 + torch.tensor(
                                [[0,1,0,0],
                                 [0,1,0,0],
                                 [0,1,0,0],
                                 [0,1,0,0],
                                 [0.3,0,0.7,0],
                                 [0.3,0,0.7,0],
                                 [0.3,0,0.7,0],
                                 [0.3,0,0.7,0],
                                 [0.9,0,0,0.1],
                                 [0.9,0,0,0.1],
                                 [0.9,0,0,0.1],
                                 [0.9,0,0,0.1]]))

In [27]:
model = MixtureOfTruncNormModel()
step_size = 0.05
optimizer = torch.optim.Adam(model.parameters(), lr=step_size)

model.update_params(torch.cat([ideal_softinv_means, ideal_softinv_scales, ideal_mix_weights.view(-1)]))

In [28]:
M_score_func =  200
M_action = 200
train_T = train_y_TS.shape[0]
perturbed_top_K_func = PerturbedTopK(k=4)


In [18]:
mix_model = model()
sample = mix_model.sample()

In [19]:
# print sample as floats rounded to 2 decimal places
print([samp.round(2) for samp in sample.numpy()])

[7.1, 7.22, 7.17, 6.96, 9.98, 0.1, 9.45, 10.12, 0.04, 0.04, 100.17, 0.1]


In [20]:
torch.sum(mix_model.log_prob(sample))

tensor(0.3683, grad_fn=<SumBackward0>)

In [22]:
torch.sum(mix_model.log_prob(torch.tensor(train_y_TS)))

tensor(2408.5791, grad_fn=<SumBackward0>)

In [97]:
losses = []
bprs = []
nlls = []

In [98]:
model.params_to_single_tensor()

tensor([-18.4207,   6.9991,  10.0000, 100.0000,  -1.5078,  -1.5078,  -1.5078,
         -1.5078, -29.9336,   0.0000, -29.9336, -29.9336, -29.9336,   0.0000,
        -29.9336, -29.9336, -29.9336,   0.0000, -29.9336, -29.9336, -29.9336,
          0.0000, -29.9336, -29.9336,  -1.2040, -29.9336,  -0.3567, -29.9336,
         -1.2040, -29.9336,  -0.3567, -29.9336,  -1.2040, -29.9336,  -0.3567,
        -29.9336,  -1.2040, -29.9336,  -0.3567, -29.9336,  -0.1054, -29.9336,
        -29.9336,  -2.3026,  -0.1054, -29.9336, -29.9336,  -2.3026,  -0.1054,
        -29.9336, -29.9336,  -2.3026,  -0.1054, -29.9336, -29.9336,  -2.3026],
       grad_fn=<CatBackward0>)

In [99]:

for epoch in range(1000):
    mix_model = model()
    
    y_sample_TMS = mix_model.sample((train_T, M_score_func))
    y_sample_action_TMS = mix_model.sample((train_T, M_action))

    ratio_rating_TMS = y_sample_action_TMS/y_sample_action_TMS.sum(dim=-1, keepdim=True)
    ratio_rating_TS =  ratio_rating_TMS.mean(dim=1)
    ratio_rating_TS.requires_grad_(True)

    #pred_y_TS = torch.mean(y_sample_action_TMS, dim=1)
    #pred_y_TS.requires_grad_(True)

    def get_log_probs_baked(param):
        distribution = model.build_from_single_tensor(param)
        log_probs_TMS = distribution.log_prob(y_sample_TMS)

        return log_probs_TMS

    jac_TMSP = torch.autograd.functional.jacobian(get_log_probs_baked, (model.params_to_single_tensor()), strategy='forward-mode', vectorize=True)

    score_func_estimator_TMSP = jac_TMSP * ratio_rating_TMS.unsqueeze(-1)
    score_func_estimator_TSP = score_func_estimator_TMSP.mean(dim=1)    

    # get gradient of negative bpr_t  with respect to ratio rating_TS
    positive_bpr_T = torch_bpr_uncurried(ratio_rating_TS, torch.tensor(train_y_TS), K=4, perturbed_top_K_func=perturbed_top_K_func)
    negative_bpr = torch.mean(-positive_bpr_T)
    
    nll = torch.sum(-mix_model.log_prob( torch.tensor(train_y_TS)))

    print(f'Neg bpr: {negative_bpr}')
    print(f'nll: {nll}')

    loss = 500*negative_bpr + nll
    print(f'Loss: {loss}')
    losses.append(loss)
    bprs.append(negative_bpr)
    nlls.append(nll)
    
    loss.backward()

    loss_grad_TS = ratio_rating_TS.grad

    gradient_TSP = score_func_estimator_TSP * torch.unsqueeze(loss_grad_TS, -1)
    gradient_P = torch.sum(gradient_TSP, dim=[0,1])

    gradient_tuple = model.single_tensor_to_params(gradient_P)

    for param, gradient in zip(model.parameters(), gradient_tuple):
        param.grad = gradient
    optimizer.step()
        
    #model.update_params(model.params_to_single_tensor() - step_size * gradient_P)


    
    


Neg bpr: -0.5302702188491821
nll: -2408.5791015625
Loss: -2673.714111328125
Neg bpr: -0.5286932587623596
nll: -2322.26123046875
Loss: -2586.60791015625
Neg bpr: -0.5222402215003967
nll: -2304.0625
Loss: -2565.1826171875
Neg bpr: -0.5265845656394958
nll: -2260.807861328125
Loss: -2524.10009765625
Neg bpr: -0.5275028347969055
nll: -2238.63720703125
Loss: -2502.388671875
Neg bpr: -0.5284132361412048
nll: -2202.7490234375
Loss: -2466.95556640625
Neg bpr: -0.5229679346084595
nll: -2192.816650390625
Loss: -2454.300537109375
Neg bpr: -0.5284321308135986
nll: -2195.840087890625
Loss: -2460.05615234375
Neg bpr: -0.5273725390434265
nll: -2148.55419921875
Loss: -2412.240478515625
Neg bpr: -0.5310458540916443
nll: -2115.156494140625
Loss: -2380.679443359375
Neg bpr: -0.5344506502151489
nll: -2143.5654296875
Loss: -2410.790771484375
Neg bpr: -0.5283373594284058
nll: -2153.397216796875
Loss: -2417.56591796875
Neg bpr: -0.5284692645072937
nll: -2194.61767578125
Loss: -2458.852294921875
Neg bpr: -0.52

In [18]:
y_sample_TMS.shape

torch.Size([300, 200, 12])

In [20]:
 model.build_from_single_tensor(model.params_to_single_tensor()).log_prob(y_sample_TMS)

RuntimeError: The size of tensor a (4) must match the size of tensor b (12) at non-singleton dimension 2

In [23]:
 model.build_from_single_tensor(model.params_to_single_tensor()).log_prob(y_sample_TMS)

RuntimeError: The size of tensor a (4) must match the size of tensor b (12) at non-singleton dimension 2

In [25]:
distribution= model.build_from_single_tensor(model.params_to_single_tensor())

In [29]:
distribution.log_prob(distribution.sample((300, M_score_func)))

RuntimeError: The size of tensor a (4) must match the size of tensor b (12) at non-singleton dimension 2

In [29]:
mix_model = model()

y_sample_TMS = mix_model.sample((train_T, M_score_func))
y_sample_action_TMS = mix_model.sample((train_T, M_action))

ratio_rating_TMS = y_sample_action_TMS/y_sample_action_TMS.sum(dim=-1, keepdim=True)
ratio_rating_TS =  ratio_rating_TMS.mean(dim=1)
ratio_rating_TS.requires_grad_(True)

tensor([[0.0963, 0.0965, 0.0968,  ..., 0.0605, 0.0558, 0.0600],
        [0.1012, 0.1004, 0.1007,  ..., 0.0532, 0.0384, 0.0672],
        [0.0903, 0.0905, 0.0902,  ..., 0.0771, 0.0702, 0.0654],
        ...,
        [0.1042, 0.1037, 0.1038,  ..., 0.0390, 0.0443, 0.0439],
        [0.0997, 0.0997, 0.0996,  ..., 0.0672, 0.0556, 0.0592],
        [0.0946, 0.0949, 0.0946,  ..., 0.0579, 0.0533, 0.0592]],
       requires_grad=True)

In [31]:
torch.mean(ratio_rating_TS, dim=0)

tensor([0.0979, 0.0979, 0.0979, 0.0979, 0.0935, 0.0929, 0.0932, 0.0932, 0.0599,
        0.0588, 0.0582, 0.0589], grad_fn=<MeanBackward1>)

In [32]:
ratio_rating_TS.shape

torch.Size([300, 12])

In [35]:
perturbed_top_K_func(torch.mean(ratio_rating_TS, dim=0, keepdim=True)).sum(dim=-2)

tensor([[0.4100, 0.4620, 0.4440, 0.4140, 0.3800, 0.4360, 0.3800, 0.3980, 0.1840,
         0.1760, 0.1680, 0.1480]], grad_fn=<SumBackward1>)

In [36]:
torch.topk(torch.tensor([[1,2,3],[6,5,4]]), k=2)

torch.return_types.topk(
values=tensor([[3, 2],
        [6, 5]]),
indices=tensor([[2, 1],
        [0, 1]]))

In [39]:
perturbed_top_K_func(torch.tensor([[1,2,3,4,5],[6,5,4,3,2]]))

tensor([[[0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.],
         [0., 0., 0., 0., 1.]],

        [[1., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.]]])

In [42]:
torch.mean(torch_bpr_uncurried(ratio_rating_TS, torch.tensor(train_y_TS), K=4, perturbed_top_K_func=perturbed_top_K_func))

tensor(0.5286, grad_fn=<MeanBackward0>)

In [62]:
torch.mean(deterministic_bpr(ratio_rating_TS, torch.tensor(train_y_TS), K=4))

tensor(0.5899)

In [52]:
true_topk = torch.topk(torch.tensor(train_y_TS), K)
pred_value_at_topk  = torch.gather(ratio_rating_TS, 1, true_topk.indices)
numerator = torch.sum(pred_value_at_topk, dim=-1)
denominator = torch.sum(true_topk.values, dim=-1)
bpr = numerator/denominator

In [59]:
denominator

tensor([130.,  31.,  37.,  37.,  34.,  34.,  31.,  34.,  34.,  34.,  37.,  37.,
         40.,  37.,  37.,  37., 130., 124.,  34.,  37.,  40., 124.,  37., 130.,
        130.,  34., 130., 130., 130.,  37., 130., 130., 130.,  37.,  37., 124.,
         34., 121.,  34.,  37.,  40.,  40.,  31.,  37.,  31.,  37.,  37., 130.,
         40.,  37., 130.,  37.,  40.,  37.,  34., 127., 124., 127.,  34.,  37.,
         34.,  37.,  40.,  40., 130.,  37.,  40.,  37.,  40.,  31.,  37., 220.,
         37.,  37., 130.,  37.,  31., 220.,  40.,  40., 220.,  31.,  37.,  34.,
         37.,  37.,  31.,  40.,  37., 130.,  34.,  40., 130.,  34., 220.,  37.,
        127.,  37.,  40.,  40., 130.,  37.,  37., 130.,  37., 130.,  37.,  37.,
        130., 124., 220.,  31., 130., 127.,  34.,  37.,  40.,  34.,  34.,  40.,
         31., 130.,  37.,  37.,  40.,  40., 127.,  31.,  34., 220.,  37.,  37.,
         34.,  37.,  40.,  40.,  34., 130., 130.,  37.,  40.,  37.,  37.,  37.,
         40.,  37., 130.,  37., 124.,  4

In [46]:
denominator.shape

torch.Size([300])

In [53]:
pred_value_at_topk

tensor([[0.0558, 0.0907, 0.0929, 0.0958],
        [0.0937, 0.1004, 0.1012, 0.1007],
        [0.0848, 0.0831, 0.0919, 0.0905],
        ...,
        [0.0439, 0.0996, 0.0986, 0.1037],
        [0.0942, 0.0815, 0.0952, 0.0997],
        [0.0937, 0.1033, 0.0858, 0.0900]], grad_fn=<GatherBackward0>)

In [54]:
ratio_rating_TS[0]

tensor([0.0963, 0.0965, 0.0968, 0.0965, 0.0907, 0.0939, 0.0929, 0.0958, 0.0644,
        0.0605, 0.0558, 0.0600], grad_fn=<SelectBackward0>)

In [55]:
train_y_TS[0]

array([  7.,   7.,   7.,   7.,  10.,   0.,  10.,  10.,   0.,   0., 100.,
         0.], dtype=float32)