In [34]:
import torch
import gpytorch
from torch import nn
import pandas as pd
import numpy as np

class IndicatorEncoder(nn.Module):
    """Two-layer MLP that maps an indicator vector to a context embedding.

    Args:
        indicator_dim: Size of the last dimension of the indicator input.
        context_dim: Size of the last dimension of the returned embedding.
    """

    def __init__(self, indicator_dim, context_dim):
        super().__init__()
        hidden_width = 128
        # Layers are instantiated in input-to-output order so parameter
        # initialisation draws from the RNG in a fixed sequence.
        layers = [
            nn.Linear(indicator_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, context_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, indicators):
        """Return the embedding of ``indicators``; the last dim becomes ``context_dim``."""
        embedding = self.net(indicators)
        return embedding

class ContextualKernel(gpytorch.kernels.Kernel):
    """Product kernel over concatenated ``(parameters, context embedding)`` inputs.

    Each input is split along its last dimension: the first ``param_dim``
    features are passed to ``base_kernel`` and the remaining features are
    treated as the context embedding. The covariance returned is

        base_kernel(p1, p2) * exp(-||c1 - c2||_2) * sigmoid(w(c1)) * sigmoid(w(c2))

    where ``w`` is the learned linear map ``context_scaling`` (context -> scalar).

    Args:
        base_kernel: Kernel applied to the parameter features.
        context_dim: Number of context-embedding features (input size of ``w``).
        param_dim: Number of leading features that are parameters.
    """

    def __init__(self, base_kernel, context_dim, param_dim):
        super().__init__()
        self.base_kernel = base_kernel
        self.context_dim = context_dim
        self.param_dim = param_dim
        self.context_scaling = nn.Linear(context_dim, 1)

    def _split_params_and_context(self, x):
        """Split ``x`` on its last dim into (parameter features, context features)."""
        boundary = self.param_dim
        return x[..., :boundary], x[..., boundary:]

    def _context_gate(self, context):
        """Per-point sigmoid gate computed from the context; keeps a trailing dim of 1."""
        return self.context_scaling(context).sigmoid()

    def forward(self, x1, x2, diag=False, **params):
        """Evaluate the kernel between ``x1`` and ``x2``.

        With ``diag=True`` only matching points are compared (``x1[i]`` with
        ``x2[i]``); otherwise every point of ``x1`` is compared with every
        point of ``x2``.
        """
        params_a, ctx_a = self._split_params_and_context(x1)
        params_b, ctx_b = self._split_params_and_context(x2)

        base_cov = self.base_kernel(params_a, params_b, diag=diag, **params)

        gate_a = self._context_gate(ctx_a)  # [..., N, 1]
        gate_b = self._context_gate(ctx_b)  # [..., M, 1]

        if not diag:
            # Broadcast to all pairs: [..., N, 1, D] - [..., 1, M, D] -> [..., N, M, D]
            pairwise_delta = ctx_a.unsqueeze(-2) - ctx_b.unsqueeze(-3)
            closeness = torch.exp(-torch.norm(pairwise_delta, dim=-1))  # [..., N, M]
            gate_outer = gate_a * gate_b.transpose(-1, -2)  # [..., N, M]
            return base_cov * closeness * gate_outer

        # Diagonal case: element-wise comparison of matching points only.
        closeness = torch.exp(-torch.norm(ctx_a - ctx_b, dim=-1))  # [..., N]
        gate_pair = gate_a.squeeze(-1) * gate_b.squeeze(-1)  # [..., N]
        return base_cov * closeness * gate_pair
        

class RewardVariationalGPModel(gpytorch.models.ApproximateGP):
    """Sparse variational GP over concatenated ``(parameters, context)`` inputs.

    Uses a Cholesky variational distribution with learnable inducing
    locations, a constant mean, and a ``ContextualKernel`` wrapping an RBF
    kernel on the parameter features.

    Args:
        inducing_points: Tensor whose last dimension is
            ``param_dim + context_dim``; its first dimension is the number of
            inducing points.
        context_dim: Number of trailing context-embedding features.
        param_dim: Number of leading parameter features. When ``None``
            (default) it is derived as
            ``inducing_points.size(-1) - context_dim`` so the class no longer
            depends on a notebook-level global of the same name.

    Raises:
        ValueError: If ``param_dim`` is not positive or does not add up with
            ``context_dim`` to the feature size of ``inducing_points``.
    """

    def __init__(self, inducing_points, context_dim, param_dim=None):
        feature_dim = inducing_points.size(-1)
        if param_dim is None:
            param_dim = feature_dim - context_dim
        if param_dim <= 0 or param_dim + context_dim != feature_dim:
            raise ValueError(
                f"inducing_points has {feature_dim} features, which is inconsistent with "
                f"param_dim={param_dim} and context_dim={context_dim}"
            )

        # Define variational distribution + strategy
        variational_distribution = gpytorch.variational.CholeskyVariationalDistribution(
            inducing_points.size(0)
        )
        variational_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, variational_distribution, learn_inducing_locations=True
        )
        super().__init__(variational_strategy)

        self.mean_module = gpytorch.means.ConstantMean()
        base_kernel = gpytorch.kernels.RBFKernel()
        # NOTE(review): the kernel is not wrapped in a ScaleKernel, so there is
        # no learnable output scale — confirm this suits the reward range.
        self.covar_module = ContextualKernel(base_kernel, context_dim, param_dim)

    def forward(self, x):
        """Return the GP prior at ``x`` as a ``MultivariateNormal``."""
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x, x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)


class NN_VariationalGP_Model(nn.Module):
    """Indicator encoder combined with a sparse variational GP reward model.

    Indicators are embedded by ``IndicatorEncoder``; the embedding is
    concatenated after the parameters and fed to a
    ``RewardVariationalGPModel``. The GP is created lazily by
    ``initialize_gp`` because its inducing points are drawn from data.

    Args:
        indicator_dim: Number of indicator features.
        param_dim: Number of parameter features expected in ``parameters``.
        context_dim: Size of the context embedding produced by the encoder.
        num_inducing: Maximum number of inducing points.
    """

    def __init__(self, indicator_dim, param_dim, context_dim, num_inducing=128):
        super().__init__()
        # Kept so initialize_gp can check the data against the declared size
        # (previously this argument was accepted and silently discarded).
        self.param_dim = param_dim
        self.indicator_encoder = IndicatorEncoder(indicator_dim, context_dim)
        self.likelihood = gpytorch.likelihoods.GaussianLikelihood()
        self.num_inducing = num_inducing
        self.gp_model = None  # created by initialize_gp once data is available

    def initialize_gp(self, indicators, parameters, generator=None):
        """Create ``self.gp_model`` with inducing points sampled from the data.

        Call this before building an optimizer, otherwise the GP parameters
        are not part of ``self.parameters()`` yet.

        Args:
            indicators: Indicator tensor, one row per training point.
            parameters: Parameter tensor aligned with ``indicators``; its last
                dimension must equal ``param_dim``.
            generator: Optional ``torch.Generator`` for a reproducible choice
                of inducing points. ``None`` uses the global RNG.

        Raises:
            ValueError: If ``parameters`` does not have ``param_dim`` features.
        """
        if parameters.shape[-1] != self.param_dim:
            raise ValueError(
                f"parameters has {parameters.shape[-1]} features but the model was "
                f"built with param_dim={self.param_dim}"
            )

        # The inducing points become independent parameters, so no autograd
        # graph through the encoder is needed here.
        with torch.no_grad():
            context_embeddings = self.indicator_encoder(indicators)
            train_x = torch.cat([parameters, context_embeddings], dim=-1)

        # Choose random inducing points from training data
        rand_idx = torch.randperm(train_x.size(0), generator=generator)[:self.num_inducing]
        inducing_points = train_x[rand_idx]

        self.gp_model = RewardVariationalGPModel(inducing_points, context_dim=context_embeddings.shape[-1])

    def forward(self, parameters, indicators):
        """Return the predictive distribution (likelihood applied) at the inputs.

        Side effect: switches ``gp_model`` and ``likelihood`` to eval mode and
        leaves them there; call ``.train()`` again before resuming training.

        Raises:
            RuntimeError: If ``initialize_gp`` has not been called yet.
        """
        if self.gp_model is None:
            raise RuntimeError("initialize_gp must be called before making predictions")

        self.gp_model.eval()
        self.likelihood.eval()
        with torch.no_grad():
            context_embeddings = self.indicator_encoder(indicators)
            test_x = torch.cat([parameters, context_embeddings], dim=-1)
            pred = self.likelihood(self.gp_model(test_x))
        return pred


In [53]:
# Toy dataset: one row per indicator setting. Each row stores the indicator
# pair [indicator, 10 - indicator] and a (100, 2) array of
# [parameter, reward] samples with reward = parameter + indicator.
indicators = np.arange(0, 10)
parameters = np.arange(0, 1, 0.01)  # same grid for every indicator, so built once

example_records = []
for indicator in indicators:
    samples = np.column_stack([parameters, parameters + indicator])
    example_records.append({'indicators': [indicator, 10 - indicator], 'samples': samples})

# Build the frame once from the records instead of growing it with pd.concat
# on every iteration (quadratic copying).
example_df = pd.DataFrame(example_records, columns=['indicators', 'samples'])


In [54]:
# Problem dimensions shared by the data pipeline and the model.
indicator_dim = 2
param_dim = 1
context_dim = 32

# The toy frame stands in for real data; load your own frame here instead,
# e.g. df = pd.read_pickle("your_dataframe.pkl")
df = example_df.copy()

# Flatten the per-row sample arrays into three aligned training tensors.
# Each row of a `samples` array is [param_1, ..., param_k, reward].
indicator_chunks = []
parameter_chunks = []
reward_chunks = []

for row_indicators, row_samples in zip(df['indicators'], df['samples']):
    samples_t = torch.as_tensor(np.asarray(row_samples), dtype=torch.float32)
    if samples_t.numel() == 0:
        continue  # a row without samples contributes no training points

    indicators = torch.tensor(row_indicators, dtype=torch.float32)
    # Repeat the row's indicator vector once per sample so all tensors align.
    indicator_chunks.append(indicators.expand(samples_t.size(0), -1))
    parameter_chunks.append(samples_t[:, :-1])
    reward_chunks.append(samples_t[:, -1])

train_indicators = torch.cat(indicator_chunks)
train_parameters = torch.cat(parameter_chunks)
train_rewards = torch.cat(reward_chunks)

assert train_indicators.size(0) == train_parameters.size(0) == train_rewards.size(0)
assert train_indicators.size(-1) == indicator_dim
assert train_parameters.size(-1) == param_dim

In [55]:
# Train the encoder, the variational GP and the likelihood jointly by
# minimising the negative ELBO.
SEED = 0
NUM_ITERS = 1000
LOG_EVERY = 100

# Seed before building the model: both the encoder initialisation and the
# random choice of inducing points draw from the global torch RNG.
torch.manual_seed(SEED)

model = NN_VariationalGP_Model(indicator_dim, param_dim, context_dim)

model.initialize_gp(train_indicators, train_parameters)

# The optimizer is created after initialize_gp so that the GP parameters
# (model.gp_model is None until then) are included in model.parameters().
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
mll = gpytorch.mlls.VariationalELBO(model.likelihood, model.gp_model, num_data=train_indicators.size(0))

model.train()
for i in range(NUM_ITERS):
    optimizer.zero_grad()

    # Re-encode the indicators every step: the encoder is being trained too.
    train_x = torch.cat([train_parameters, model.indicator_encoder(train_indicators)], dim=-1)

    output = model.gp_model(train_x)
    loss = -mll(output, train_rewards)
    loss.backward()
    optimizer.step()

    if (i + 1) % LOG_EVERY == 0:
        print(f"Iter {i + 1}: Loss = {loss.item():.4f}")


Iter 100: Loss = 11.7795
Iter 200: Loss = 7.7332
Iter 300: Loss = 5.8599
Iter 400: Loss = 4.8083
Iter 500: Loss = 4.1567
Iter 600: Loss = 3.7295
Iter 700: Loss = 3.4399
Iter 800: Loss = 3.2389
Iter 900: Loss = 3.0971
Iter 1000: Loss = 2.9954


In [51]:
# In-sample sanity check: query the trained model at the training inputs.
# For real use, substitute new tensors here, e.g.
#   new_indicators = torch.tensor([[...]], dtype=torch.float32)
#   new_parameters = torch.tensor([[...]], dtype=torch.float32)
new_indicators, new_parameters = train_indicators, train_parameters

# The model takes raw parameters and indicators (in that order).
predictions = model(new_parameters, new_indicators)
pred_mean, pred_var = predictions.mean, predictions.variance


In [52]:
# Display the predictive mean and variance tensors as the cell output.
# NOTE(review): every mean visible in the recorded output is 4.3491, i.e. the
# prediction does not vary with the input. This cell's execution count
# (In [52]) also precedes the data/training cells (In [53]-[55]), so the
# output may be stale — restart the kernel and run all cells to confirm.
pred_mean, pred_var

(tensor([4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491, 4.3491,
         4.3491, 4.3491, 4.3