In [1]:
%load_ext autoreload
%autoreload 2


import torch

import torch.nn as nn
from torch.distributions import Categorical, Poisson, MixtureSameFamily
# Make the project's modules (datasets, torch_perturb) importable: work from the
# project directory and add it to sys.path. Set the PROB_DIFF_TOPK_DIR environment
# variable to use a checkout somewhere else.
import os
import sys
PROJECT_DIR = os.environ.get('PROB_DIFF_TOPK_DIR', '/cluster/home/kheuto01/code/prob_diff_topk')
os.chdir(PROJECT_DIR)
# Guard so re-running this cell does not keep appending duplicate entries.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

from datasets import example_datasets, to_numpy
from torch_perturb.torch_pert_topk import PerturbedTopK

2024-07-08 18:25:37.715545: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-07-08 18:25:37.779106: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-08 18:25:37.779142: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-08 18:25:37.780249: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-08 18:25:37.788536: I tensorflow/core/platform/cpu_feature_guar

In [2]:
# Random seed: passed to the dataset generator below and also used to seed torch,
# so the model's torch.rand initialisation and all Poisson sampling are reproducible.
seed = 360
torch.manual_seed(seed)
# tracts/distributions
S = 12
# history/features
H = 3
# total timepoints
T = 500
# number of Poisson components in the mixture model
num_components = 4
# size of the top-K selection used by the BPR metric
K = 4

In [3]:
# Build the train/val/test splits (H history/features, T timepoints, fixed seed).
(train_dataset,
 val_dataset,
 test_dataset) = example_datasets(H, T, seed=seed)
# NOTE(review): name suffixes suggest shapes (T, H, S) and (T, S) — confirm against to_numpy.
train_X_THS, train_y_TS = to_numpy(train_dataset)

2024-07-08 18:26:23.937274: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2024-07-08 18:26:23.937312: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:129] retrieving CUDA diagnostic information for host: s1cmp008.pax.tufts.edu
2024-07-08 18:26:23.937317: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:136] hostname: s1cmp008.pax.tufts.edu
2024-07-08 18:26:23.937353: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:159] libcuda reported version is: 535.129.3
2024-07-08 18:26:23.937380: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:163] kernel reported version is: 535.129.3
2024-07-08 18:26:23.937386: I external/local_xla/xla/stream_executor/cuda/cuda_diagnostics.cc:241] kernel version seems to match DSO: 535.129.3


In [66]:

class MixtureOfPoissonsModel(nn.Module):
    """Per-tract mixture of Poissons whose component rates are shared by all tracts.

    Each of the ``S`` tracts has its own mixture weights over ``num_components``
    Poisson components. The component rates are common to every tract.

    Learnable parameters:
        log_poisson_rates: shape (num_components,), log of the Poisson rates.
        mixture_probs: shape (S, num_components). Despite the name, these are
            unnormalised logits; a softmax over dim 1 gives the mixture weights.
    """

    def __init__(self, num_components=4, S=12):
        super().__init__()
        self.num_components = num_components
        self.S = S

        # Uniform [0, 1) initial values for both parameter groups.
        self.log_poisson_rates = nn.Parameter(torch.rand(num_components))
        self.mixture_probs = nn.Parameter(torch.rand(S, num_components))

    def params_to_single_tensor(self):
        """Flatten the parameters into one 1-D tensor: [log rates, logits (row-major)].

        The result is built with ``torch.cat`` and stays connected to the
        parameters' autograd graph.
        """
        return torch.cat([self.log_poisson_rates, self.mixture_probs.view(-1)])

    def single_tensor_to_params(self, single_tensor):
        """Inverse of ``params_to_single_tensor``.

        Returns ``(log_poisson_rates, mixture_probs)`` with shapes
        (num_components,) and (S, num_components). Both are views of ``single_tensor``.
        """
        log_poisson_rates = single_tensor[:self.num_components]
        mixture_probs = single_tensor[self.num_components:].view(self.S, self.num_components)
        return log_poisson_rates, mixture_probs

    def update_params(self, single_tensor):
        """Overwrite the parameter values in place from a flat tensor.

        The existing ``nn.Parameter`` objects are kept, so optimizers and any
        other references taken from ``self.parameters()`` remain valid. The
        values are copied, so the parameters never alias ``single_tensor`` or
        each other. Any accumulated ``.grad`` is cleared, because it belonged to
        the previous values.
        """
        log_poisson_rates, mixture_probs = self.single_tensor_to_params(single_tensor)
        with torch.no_grad():
            self.log_poisson_rates.copy_(log_poisson_rates)
            self.mixture_probs.copy_(mixture_probs)
        self.log_poisson_rates.grad = None
        self.mixture_probs.grad = None

    def _build_mixture(self, log_poisson_rates, mixture_probs, validate_args=None):
        """Build the (S,)-batched MixtureSameFamily for the given raw parameters."""
        poisson_rates = torch.exp(log_poisson_rates)
        mixture_probs_normalized = torch.nn.functional.softmax(mixture_probs, dim=1)
        categorical_dist = Categorical(mixture_probs_normalized)
        # Every tract uses the same component rates.
        expanded_rates = poisson_rates.expand(self.S, self.num_components)
        poisson_dist = Poisson(expanded_rates, validate_args=False)
        return MixtureSameFamily(categorical_dist, poisson_dist, validate_args=validate_args)

    def build_from_single_tensor(self, single_tensor):
        """Build the mixture from a flat parameter tensor (e.g. for jacobians).

        Uses MixtureSameFamily's default argument validation.
        """
        log_poisson_rates, mixture_probs = self.single_tensor_to_params(single_tensor)
        return self._build_mixture(log_poisson_rates, mixture_probs)

    def forward(self):
        """Build the mixture from the current parameters.

        Validation is disabled so that ``log_prob`` also accepts non-integer
        values (e.g. averaged samples).
        """
        return self._build_mixture(self.log_poisson_rates, self.mixture_probs,
                                   validate_args=False)
    
def torch_bpr_uncurried(y_pred, y_true, K=4, perturbed_top_K_func=None):
    """BPR ratio of a top-K selection, computed along the last dimension.

    BPR = (sum of ``y_true`` over the selected entries) /
          (sum of the K largest ``y_true`` values).

    Args:
        y_pred: scores of shape (..., S); the selection is made from these.
        y_true: true values of shape (..., S).
        K: size of the true top-K used in the denominator. When
            ``perturbed_top_K_func`` is None, also the size of the selection.
        perturbed_top_K_func: optional callable applied to ``y_pred``. Its output
            is summed over dim -2, so it is assumed to return a (possibly relaxed)
            selection of shape (..., k, S). If None, an exact, non-differentiable
            0/1 top-K mask of ``y_pred`` is used (ties broken by ``torch.topk``).

    Returns:
        Tensor of shape (...). The division is unguarded, so rows whose true
        top-K sum is zero give NaN/inf.
    """
    if perturbed_top_K_func is None:
        _, top_idx = torch.topk(y_pred, K, dim=-1)
        top_K_ids = torch.zeros(y_pred.shape, dtype=torch.get_default_dtype(),
                                device=y_pred.device).scatter_(-1, top_idx, 1.0)
    else:
        # Sum over the k dim to get one selection weight per entry.
        top_K_ids = perturbed_top_K_func(y_pred).sum(dim=-2)

    true_top_K_val, _ = torch.topk(y_true, K)
    denominator = torch.sum(true_top_K_val, dim=-1)
    numerator = torch.sum(top_K_ids * y_true, dim=-1)
    return numerator / denominator

In [67]:
# Hand-built "ideal" parameters.
# Component rates: ~0 (1e-8 keeps the log finite), 7, 10, 100.
# Per-tract log mixture weights, in groups of four tracts:
#   tracts 0-3 use only the rate-7 component,
#   tracts 4-7 use 30% ~0 / 70% rate-10,
#   tracts 8-11 use 90% ~0 / 10% rate-100.
# The 1e-13 keeps log(0) finite.
ideal_log_rates = torch.log(torch.tensor([0 + 1e-8, 7, 10, 100]))
ideal_mix_weights = torch.log(
    1e-13
    + torch.tensor([[0, 1, 0, 0],
                    [0.3, 0, 0.7, 0],
                    [0.9, 0, 0, 0.1]]).repeat_interleave(4, dim=0)
)

In [74]:
# Hand-built parameters that put every tract on a single component.
# Component rates: ~0 (1e-8 keeps the log finite), 7, 10, 100.
# Tracts 0-3 sit on the rate-100 component and tracts 4-11 on the ~0 component.
# The 1e-13 keeps log(0) finite.
bpr_log_rates = torch.log(torch.tensor([0 + 1e-8, 7, 10, 100]))
bpr_mix_weights = torch.log(
    1e-13
    + torch.tensor([[0, 0, 0, 1],
                    [1, 0, 0, 0]]).repeat_interleave(torch.tensor([4, 8]), dim=0)
)

In [75]:
# Instantiate the model with the sizes from the config cell.
model = MixtureOfPoissonsModel(num_components=num_components, S=S)

# Starting parameters: the hand-built bpr_* values. Set these to
# (ideal_log_rates, ideal_mix_weights) to start from the ideal ones instead.
init_log_rates, init_mix_weights = bpr_log_rates, bpr_mix_weights
model.update_params(torch.cat([init_log_rates, init_mix_weights.view(-1)]))

# NOTE(review): the visible training cells update parameters through
# model.update_params and never call optimizer.step().
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [76]:
# Monte Carlo samples per timestep: one set feeds the score-function estimator,
# the other feeds the point prediction.
M_score_func = 100
M_action = 100
train_T = train_y_TS.shape[0]
# Relaxed top-K selector; k matches the K used by the BPR metric.
perturbed_top_K_func = PerturbedTopK(k=K)

In [81]:
# NOTE(review): the recorded output (10) disagrees with `M_action = 100` set in the
# sampling-config cell, so the notebook was executed out of order. Re-run it top
# to bottom before trusting the results.
M_action

10

In [77]:
# Build the current mixture and draw a single sample (one value per tract).
mix_model = model()
mix_model.sample(sample_shape=torch.Size())

tensor([106., 115., 101., 114.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.])

In [78]:
# Raw, unnormalised mixture logits; the model applies a softmax over dim 1.
model.mixture_probs

Parameter containing:
tensor([[-29.9336, -29.9336, -29.9336,   0.0000],
        [-29.9336, -29.9336, -29.9336,   0.0000],
        [-29.9336, -29.9336, -29.9336,   0.0000],
        [-29.9336, -29.9336, -29.9336,   0.0000],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336],
        [  0.0000, -29.9336, -29.9336, -29.9336]], requires_grad=True)

In [79]:
# Normalised per-tract mixture weights (each row sums to 1).
model.mixture_probs.softmax(dim=1)

tensor([[1.0000e-13, 1.0000e-13, 1.0000e-13, 1.0000e+00],
        [1.0000e-13, 1.0000e-13, 1.0000e-13, 1.0000e+00],
        [1.0000e-13, 1.0000e-13, 1.0000e-13, 1.0000e+00],
        [1.0000e-13, 1.0000e-13, 1.0000e-13, 1.0000e+00],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13],
        [1.0000e+00, 1.0000e-13, 1.0000e-13, 1.0000e-13]],
       grad_fn=<SoftmaxBackward0>)

In [80]:
# Manual score-function training loop.
# Autograd gives dLoss/d ratio_rating_TS. A score-function (REINFORCE) estimate of
# d ratio_rating_TS / d params chains that back to the model parameters, which are
# then moved by a plain gradient step through model.update_params.
SCORE_FUNC_LR = 0.0005
# Training targets never change, so convert them once rather than every epoch.
train_y_tensor_TS = torch.tensor(train_y_TS)

for epoch in range(1000):
    mix_model = model()

    # Two independent sample sets: one for the score-function estimate,
    # one for the point prediction pred_y_TS.
    y_sample_TMS = mix_model.sample((train_T, M_score_func))
    y_sample_action_TMS = mix_model.sample((train_T, M_action))

    # Each tract's share of its sample's total. clamp_min(1) turns the 0/0 of an
    # all-zero sample into 0 instead of NaN.
    ratio_rating_TMS = y_sample_TMS / y_sample_TMS.sum(dim=-1, keepdim=True).clamp_min(1)
    ratio_rating_TS = ratio_rating_TMS.mean(dim=1)
    # Enable grad so that loss.backward() fills ratio_rating_TS.grad.
    ratio_rating_TS.requires_grad_(True)

    pred_y_TS = torch.mean(y_sample_action_TMS, dim=1)

    # Bind this epoch's samples explicitly (default arg) instead of via closure.
    def get_log_probs_baked(param, y=y_sample_TMS):
        return model.build_from_single_tensor(param).log_prob(y)

    jac_TMSP = torch.autograd.functional.jacobian(
        get_log_probs_baked, model.params_to_single_tensor(),
        strategy='forward-mode', vectorize=True)

    # Tracts are sampled independently, so the score of a whole sample is
    # sum_s grad log p(y_s). Each ratio y_s / sum_j y_j depends on every tract,
    # so it must be weighted by that joint score. Weighting it by only its own
    # tract's score drops the cross-tract terms and biases the gradient.
    score_func_estimator_TMSP = jac_TMSP.sum(dim=2, keepdim=True) * ratio_rating_TMS.unsqueeze(-1)
    score_func_estimator_TSP = score_func_estimator_TMSP.mean(dim=1)

    # Negative BPR of the relaxed top-K selection made from ratio_rating_TS.
    positive_bpr_T = torch_bpr_uncurried(ratio_rating_TS, train_y_tensor_TS, K=K,
                                         perturbed_top_K_func=perturbed_top_K_func)
    negative_bpr = torch.mean(-positive_bpr_T)

    # NOTE(review): this is the model's log-likelihood of its own mean prediction,
    # not of the training data. It does not depend on ratio_rating_TS, so it
    # changes the printed loss but not the update below, which uses only
    # ratio_rating_TS.grad. Confirm the intended objective.
    log_prob = torch.mean(mix_model.log_prob(pred_y_TS))

    print(f'Neg bpr: {negative_bpr}')
    print(f'Log prob: {log_prob}')

    loss = negative_bpr + log_prob
    print(f'Loss: {loss}')

    loss.backward()

    loss_grad_TS = ratio_rating_TS.grad

    # Chain rule: dLoss/dparams = sum_{t,s} dLoss/dratio_ts * dratio_ts/dparams.
    gradient_TSP = score_func_estimator_TSP * torch.unsqueeze(loss_grad_TS, -1)
    gradient_P = torch.sum(gradient_TSP, dim=[0, 1])
    model.update_params(model.params_to_single_tensor().detach() - SCORE_FUNC_LR * gradient_P)


    
    


Neg bpr: -0.5936790704727173
Log prob: -1.0909003019332886
Loss: -1.6845793724060059
Neg bpr: -0.5936261415481567
Log prob: -1.0909314155578613
Loss: -1.684557557106018
Neg bpr: -0.593609631061554
Log prob: -1.0920900106430054
Loss: -1.685699701309204
Neg bpr: -0.5935839414596558
Log prob: -1.0907151699066162
Loss: -1.684299111366272
Neg bpr: -0.5935624241828918
Log prob: -1.0903267860412598
Loss: -1.6838891506195068
Neg bpr: -0.5936207175254822
Log prob: -1.0907315015792847
Loss: -1.684352159500122
Neg bpr: -0.5935508608818054
Log prob: -1.0908491611480713
Loss: -1.6844000816345215
Neg bpr: -0.5935357213020325
Log prob: -1.0912786722183228
Loss: -1.684814453125
Neg bpr: -0.5936082601547241
Log prob: -1.0910122394561768
Loss: -1.6846204996109009
Neg bpr: -0.5935199856758118
Log prob: -1.0908797979354858
Loss: -1.6843998432159424
Neg bpr: -0.5935778021812439
Log prob: -1.090424656867981
Loss: -1.68400239944458
Neg bpr: -0.5936400294303894
Log prob: -1.0910803079605103
Loss: -1.684720277

KeyboardInterrupt: 

In [84]:
    # One step of the training-loop computation (no parameter update), for
    # inspecting the loss and the gradient.
    mix_model = model()

    y_sample_TMS = mix_model.sample((train_T, M_score_func))
    y_sample_action_TMS = mix_model.sample((train_T, M_action))

    # The score-function estimator pairs each ratio with the score of the SAME
    # sample, so the ratios must come from y_sample_TMS (the samples passed to
    # the jacobian). Ratios from the independent y_sample_action_TMS make the
    # estimator zero in expectation.
    # clamp_min(1) turns the 0/0 of an all-zero sample into 0 instead of NaN.
    ratio_rating_TMS = y_sample_TMS / y_sample_TMS.sum(dim=-1, keepdim=True).clamp_min(1)
    ratio_rating_TS = ratio_rating_TMS.mean(dim=1)
    # Enable grad so that loss.backward() fills ratio_rating_TS.grad.
    ratio_rating_TS.requires_grad_(True)

    pred_y_TS = torch.mean(y_sample_action_TMS, dim=1)

    # Bind the score-function samples explicitly (default arg).
    def get_log_probs_baked(param, y=y_sample_TMS):
        return model.build_from_single_tensor(param).log_prob(y)

    jac_TMSP = torch.autograd.functional.jacobian(
        get_log_probs_baked, model.params_to_single_tensor(),
        strategy='forward-mode', vectorize=True)

    # Tracts are sampled independently, so the score of a whole sample is
    # sum_s grad log p(y_s). Each ratio depends on every tract and must be
    # weighted by this joint score, not by its own tract's score alone.
    score_func_estimator_TMSP = jac_TMSP.sum(dim=2, keepdim=True) * ratio_rating_TMS.unsqueeze(-1)
    score_func_estimator_TSP = score_func_estimator_TMSP.mean(dim=1)

    # Negative BPR of the relaxed top-K selection made from ratio_rating_TS.
    positive_bpr_T = torch_bpr_uncurried(ratio_rating_TS, torch.tensor(train_y_TS), K=K,
                                         perturbed_top_K_func=perturbed_top_K_func)
    negative_bpr = torch.mean(-positive_bpr_T)

    # NOTE(review): log-likelihood of the model's own mean prediction, not of the
    # training data. It does not depend on ratio_rating_TS, so it does not reach
    # gradient_P. Confirm the intended objective.
    log_prob = torch.mean(mix_model.log_prob(pred_y_TS))

    print(f'Neg bpr: {negative_bpr}')
    print(f'Log prob: {log_prob}')

    loss = negative_bpr + log_prob
    print(f'Loss: {loss}')

    loss.backward()

    loss_grad_TS = ratio_rating_TS.grad

    # Chain rule: dLoss/dparams = sum_{t,s} dLoss/dratio_ts * dratio_ts/dparams.
    gradient_TSP = score_func_estimator_TSP * torch.unsqueeze(loss_grad_TS, -1)
    gradient_P = torch.sum(gradient_TSP, dim=[0, 1])

Neg bpr: -0.48453548550605774
Log prob: -3.0042552947998047
Loss: -3.48879075050354


In [79]:
# Baseline: mean BPR of a fixed ranking that scores the last four tracts
# (indices 8-11) as 1 and the rest as 0, broadcast against every row of
# train_y_TS (assumed shape (T, 12) — confirm).
# NOTE(review): the selection still goes through PerturbedTopK, so the value may
# be noisy if that selector injects random perturbations — confirm.
torch.mean(torch_bpr_uncurried(torch.tensor([[0,0,0,0,0,0,0,0,1,1,1,1]])*torch.ones_like(torch.tensor(train_y_TS)),torch.tensor(train_y_TS),K=4, perturbed_top_K_func=perturbed_top_K_func))


tensor(0.2520)

In [73]:
# NOTE(review): unexplained scratch arithmetic. The origin of 28 and 130 is not
# recorded; document it or delete this cell.
28/130

0.2153846153846154

In [25]:
# Draw M samples per training timestep from the current mixture.
M = 1000
train_T = len(train_y_TS)
y_sample_TMS = mix_model.sample(sample_shape=(train_T, M))

In [26]:
# Each tract's share of its sample's total, averaged over the M samples.
# clamp_min(1): an all-zero sample would otherwise give 0/0 = NaN; it now gives 0.
ratio_rating_TMS = y_sample_TMS / y_sample_TMS.sum(dim=-1, keepdim=True).clamp_min(1)
ratio_rating_TS = ratio_rating_TMS.mean(dim=1)
# Enable grad so a later backward() fills ratio_rating_TS.grad.
ratio_rating_TS.requires_grad_(True)

tensor([[0.0829, 0.0823, 0.0825,  ..., 0.0791, 0.0794, 0.0867],
        [0.0837, 0.0819, 0.0831,  ..., 0.0869, 0.0860, 0.0828],
        [0.0862, 0.0787, 0.0771,  ..., 0.0796, 0.0818, 0.0868],
        ...,
        [0.0798, 0.0823, 0.0776,  ..., 0.0822, 0.0822, 0.0849],
        [0.0867, 0.0820, 0.0799,  ..., 0.0804, 0.0839, 0.0822],
        [0.0847, 0.0793, 0.0790,  ..., 0.0841, 0.0791, 0.0862]],
       requires_grad=True)

In [27]:
def get_log_probs_baked(param):
    """Log-probabilities of the global ``y_sample_TMS`` under the mixture built
    from the flat parameter tensor ``param``.

    Depends on ``param`` only, so it can be passed to
    ``torch.autograd.functional.jacobian``.
    """
    return model.build_from_single_tensor(param).log_prob(y_sample_TMS)

In [28]:
# Jacobian of every per-sample, per-tract log-prob w.r.t. the flat parameter vector.
jac_TMSP = torch.autograd.functional.jacobian(
    func=get_log_probs_baked,
    inputs=model.params_to_single_tensor(),
    vectorize=True,
    strategy='forward-mode',
)

In [29]:
# Score-function estimate of d ratio_rating_TS / d params.
# Tracts are sampled independently, so the score of a whole sample is
# sum_s grad log p(y_s). Each ratio y_s / sum_j y_j depends on every tract, so it
# must be weighted by this joint score; weighting by its own tract's score alone
# biases the gradient. keepdim keeps the (T, M, S, P) layout.
score_func_estimator_TMSP = jac_TMSP.sum(dim=2, keepdim=True) * ratio_rating_TMS.unsqueeze(-1)
score_func_estimator_TSP = score_func_estimator_TMSP.mean(dim=1)

In [30]:
# Relaxed top-K selector; k matches the config K used by the BPR metric.
perturbed_top_K_func = PerturbedTopK(k=K)

In [31]:
# get gradient of negative bpr_t  with respect to ratio rating_TS
# Gradient of the negative (summed) BPR w.r.t. ratio_rating_TS.
# backward() accumulates into .grad, so clear anything left by a previous run of
# this cell; otherwise re-running it silently doubles the gradient.
ratio_rating_TS.grad = None
positive_bpr_T = torch_bpr_uncurried(ratio_rating_TS, torch.tensor(train_y_TS), K=K,
                                     perturbed_top_K_func=perturbed_top_K_func)
negative_bpr = torch.sum(-positive_bpr_T)
negative_bpr.backward()

In [32]:
# dLoss/d ratio_rating_TS, filled in by the backward() call in the previous cell.
loss_grad_TS = ratio_rating_TS.grad

In [33]:
# Per-(t, s) chain-rule terms: dLoss/dratio_ts * dratio_ts/dparams.
gradient_TSP = score_func_estimator_TSP * loss_grad_TS[..., None]

In [34]:
# Total gradient w.r.t. the flat parameter vector: sum over timesteps and tracts.
gradient_P = gradient_TSP.sum(dim=(0, 1))

In [37]:
# Plain gradient-descent step on the flat parameter vector.
# NOTE(review): the step size here is 0.05, but the training loop uses 0.0005 —
# confirm which is intended.
model.update_params(model.params_to_single_tensor() - 0.05 * gradient_P)

['OptimizerPostHook',
 'OptimizerPreHook',
 '__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_cuda_graph_capture_health_check',
 '_group_tensors_by_device_and_dtype',
 '_init_group',
 '_optimizer_load_state_dict_post_hooks',
 '_optimizer_load_state_dict_pre_hooks',
 '_optimizer_state_dict_post_hooks',
 '_optimizer_state_dict_pre_hooks',
 '_optimizer_step_code',
 '_optimizer_step_post_hooks',
 '_optimizer_step_pre_hooks',
 '_patch_step_function',
 '_process_value_according_to_param_policy',
 '_warned_capturable_if_run_uncaptured',
 '_zero_grad_profile_name',
 'add_param_group',
 'defaults',
 'load_state_dict',
 'param