In [6]:
%load_ext autoreload
%autoreload 2
import os
import sys
import torch
import numpy as np
import pandas as pd
import time
import argparse
from functools import partial

# add code directory to path
import sys
sys.path.append('/cluster/home/kheuto01/code/prob_diff_topk')

from metrics import top_k_onehot_indicator
from torch_perturb.perturbations import perturbed
from torch_models import NegativeBinomialRegressionModel, torch_bpr_uncurried, deterministic_bpr


The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
def load_data(data_dir):
    """Load and process training, validation, and test data."""
    # Load data
    train_X_df = pd.read_csv(os.path.join(data_dir, 'train_x.csv'), index_col=[0,1])
    train_Y_df = pd.read_csv(os.path.join(data_dir, 'train_y.csv'), index_col=[0,1])
    val_X_df = pd.read_csv(os.path.join(data_dir, 'valid_x.csv'), index_col=[0,1])
    val_Y_df = pd.read_csv(os.path.join(data_dir, 'valid_y.csv'), index_col=[0,1])
    test_X_df = pd.read_csv(os.path.join(data_dir, 'test_x.csv'), index_col=[0,1])
    test_Y_df = pd.read_csv(os.path.join(data_dir, 'test_y.csv'), index_col=[0,1])
    
    def convert_df_to_3d_array(df):
        geoids = sorted(df.index.get_level_values('geoid').unique())
        timesteps = sorted(df.index.get_level_values('timestep').unique())
        geoid_to_idx = {geoid: idx for idx, geoid in enumerate(geoids)}
        
        num_timesteps = len(timesteps)
        num_locations = len(geoids)
        num_features = len(df.columns)
        X = np.zeros((num_timesteps, num_locations, num_features))
        
        for (geoid, timestep), row in df.iterrows():
            t_idx = timesteps.index(timestep)
            g_idx = geoid_to_idx[geoid]
            X[t_idx, g_idx, :] = row.values
            
        return X, geoids, timesteps

    def convert_y_df_to_2d_array(y_df, geoids, timesteps):
        num_timesteps = len(timesteps)
        num_locations = len(geoids)
        y = np.zeros((num_timesteps, num_locations))
        geoid_to_idx = {geoid: idx for idx, geoid in enumerate(geoids)}
        
        for (geoid, timestep), value in y_df.iloc[:, 0].items():
            t_idx = timesteps.index(timestep)
            g_idx = geoid_to_idx[geoid]
            y[t_idx, g_idx] = value
            
        return y

    # Process training data
    train_X, geoids, timesteps = convert_df_to_3d_array(train_X_df)
    train_time = np.array([timesteps] * len(geoids)).T
    train_y = convert_y_df_to_2d_array(train_Y_df, geoids, timesteps)

    # Process validation data
    val_X, val_geoids, val_timesteps = convert_df_to_3d_array(val_X_df)
    val_time = np.array([val_timesteps] * len(val_geoids)).T
    val_y = convert_y_df_to_2d_array(val_Y_df, val_geoids, val_timesteps)

    # Process test data
    test_X, test_geoids, test_timesteps = convert_df_to_3d_array(test_X_df)
    test_time = np.array([test_timesteps] * len(test_geoids)).T
    test_y = convert_y_df_to_2d_array(test_Y_df, test_geoids, test_timesteps)

    return {
        'train': (torch.tensor(train_X, dtype=torch.float32), 
                 torch.tensor(train_time, dtype=torch.float32),
                 torch.tensor(train_y, dtype=torch.float32),
                 geoids, timesteps),
        'val': (torch.tensor(val_X, dtype=torch.float32),
               torch.tensor(val_time, dtype=torch.float32),
               torch.tensor(val_y, dtype=torch.float32)),
        'test': (torch.tensor(test_X, dtype=torch.float32),
                torch.tensor(test_time, dtype=torch.float32),
                torch.tensor(test_y, dtype=torch.float32))
    }

In [8]:
good_nll_model = '/cluster/tufts/hugheslab/kheuto01/opioid_grid_try_fix_params/cook/K100_bw30_nw1_ss0.001_nss100_nps100_seed123_sig0.001_tr0.5'
data_dir = '/cluster/tufts/hugheslab/datasets/NSF_OD/cleaned/cook'

In [9]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)
data = load_data(data_dir)

Device: cuda


In [10]:
# Load model
model_path = os.path.join(good_nll_model, 'best_model.pth')

    
# Initialize model with correct parameters
model = NegativeBinomialRegressionModel(
    num_locations=data['train'][0].shape[1],
    num_fixed_effects=data['train'][0].shape[2]
).to(device)


# Load saved weights
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

NegativeBinomialRegressionModel()

In [11]:
torch.load(model_path, map_location=device)

OrderedDict([('beta_0', tensor([-0.1400], device='cuda:0')),
             ('beta',
              tensor([ 0.6625, -0.4004, -0.2112, -0.6044,  0.7179, -0.4108, -0.4261,  0.7077,
                       0.3924,  0.2103,  0.2493,  0.3929,  0.1892, -0.1764], device='cuda:0')),
             ('b_0',
              tensor([ 1.0817,  0.0412,  0.1494,  ...,  0.2998, -0.9449,  0.9793],
                     device='cuda:0')),
             ('b_1',
              tensor([0.1904, 0.4757, 0.2961,  ..., 0.1005, 0.0554, 0.9847], device='cuda:0')),
             ('log_sigma_0', tensor([2.4615], device='cuda:0')),
             ('log_sigma_1', tensor([1.3859], device='cuda:0')),
             ('rho', tensor([0.0228], device='cuda:0')),
             ('softinv_theta', tensor([-0.9522], device='cuda:0'))])

In [13]:
train_X = data['train'][0].to(device)

In [14]:
X_train = torch.tensor(train_X, dtype=torch.float32).to(device)
device

  X_train = torch.tensor(train_X, dtype=torch.float32).to(device)


device(type='cuda')

In [15]:
geoids = data['train'][3]
timesteps = data['train'][4]
train_time_arr = np.array([timesteps] * len(geoids)).T
time_train = torch.tensor(train_time_arr, dtype=torch.float32).to(device)

In [16]:
num_score_samples = 100
num_pert_samples = 100
model=model.to(device)

In [17]:
time_train

tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [3., 3., 3.,  ..., 3., 3., 3.],
        [4., 4., 4.,  ..., 4., 4., 4.],
        [5., 5., 5.,  ..., 5., 5., 5.]], device='cuda:0')

In [18]:
dist = model(X_train, time_train)

y_sample_TMS = dist.sample((num_score_samples,)).permute(1, 0, 2)
y_sample_action_TMS = y_sample_TMS

ratio_rating_TMS = y_sample_action_TMS/y_sample_action_TMS.sum(dim=-1, keepdim=True)
ratio_rating_TS = ratio_rating_TMS.mean(dim=1)
ratio_rating_TS.requires_grad_(True)

def get_log_probs_baked(param):
    distribution = model.build_from_single_tensor(param, X_train, time_train)
    log_probs_TMS = distribution.log_prob(y_sample_TMS.permute(1, 0, 2)).permute(1, 0, 2)
    return log_probs_TMS

jac_TMSP = torch.autograd.functional.jacobian(get_log_probs_baked, 
                                            (model.params_to_single_tensor()), 
                                            strategy='forward-mode', 
                                            vectorize=True)

score_func_estimator_TMSP = jac_TMSP * ratio_rating_TMS.unsqueeze(-1)
score_func_estimator_TSP = score_func_estimator_TMSP.mean(dim=1)    

unconstrained theta: Parameter containing:
tensor([-0.9522], device='cuda:0', requires_grad=True)
Theta: tensor([0.3264], device='cuda:0', grad_fn=<SoftplusBackward0>)
unconstrained theta: Parameter containing:
tensor([-0.9522], device='cuda:0', requires_grad=True)
Theta: tensor([0.3264], device='cuda:0', grad_fn=<SoftplusBackward0>)


In [23]:
jac_TMSP.max()

tensor(0., device='cuda:0')

In [26]:
get_log_probs_baked(model.params_to_single_tensor())

unconstrained theta: Parameter containing:
tensor([-0.9522], device='cuda:0', requires_grad=True)
Theta: tensor([0.3264], device='cuda:0', grad_fn=<SoftplusBackward0>)


tensor([[[-1.6192, -1.1911, -2.4130,  ..., -0.8027, -0.3322, -1.4380],
         [-1.2150, -0.9942, -1.1915,  ..., -0.8027, -0.3322, -1.5018],
         [-1.2113, -0.9942, -1.7219,  ..., -1.2135, -0.3322, -1.8762],
         ...,
         [-1.6192, -0.9942, -1.0293,  ..., -0.8027, -0.3322, -2.2692],
         [-1.2113, -0.9942, -1.1915,  ..., -0.8027, -0.3322, -1.8762],
         [-2.2131, -2.4571, -1.1915,  ..., -1.9173, -0.3322, -1.4380]],

        [[-2.2596, -1.6920, -1.6904,  ..., -2.0296, -0.4506, -1.5144],
         [-1.2037, -1.1938, -1.0786,  ..., -0.7086, -1.4389, -1.5144],
         [-1.2037, -1.0760, -2.3548,  ..., -0.7086, -0.4506, -1.5144],
         ...,
         [-1.1674, -1.0760, -1.0786,  ..., -0.7086, -0.4506, -2.0603],
         [-2.9814, -2.3577, -3.1151,  ..., -1.2441, -0.4506, -2.1567],
         [-1.1674, -1.0760, -1.0786,  ..., -2.0296, -0.4506, -1.5285]],

        [[-1.4204, -1.1933, -1.0183,  ..., -1.2826, -1.6508, -2.3823],
         [-1.2605, -1.0682, -1.7294,  ..., -3