In [1]:
%load_ext autoreload
%autoreload 2

# Third-party / stdlib dependencies used throughout the notebook.
import numpy as np
import os
import torch 
# Move the working directory two levels up from wherever the kernel started.
# NOTE(review): presumably this makes the repository root the CWD so that the
# `aexgym` and `scripts` imports below resolve -- confirm against the repo layout.
# The path is relative, so re-running this cell in a live kernel climbs two more
# levels each time; restart the kernel before re-executing it.
os.chdir("../..")

# Project imports: environments, models, agents, objectives and the prior helper.
from aexgym.env import PersSyntheticEnv, RankingSyntheticEnv
from aexgym.model import PersonalizedLinearModel, PersonalizedRankingModel
from aexgym.agent import LinearTS, LinearUniform, LinearUCB, LinearRho, RankingUniform, RankingTS, RankingRho
from aexgym.objectives import contextual_best_arm, contextual_simple_regret
from scripts.setup_script import make_uniform_prior

In [37]:
# --- Experiment dimensions -------------------------------------------------
n_days = 5
n_arms = 10
context_len = 100
batch_size = 100

# One decision step per day.
n_steps = n_days

# Constant 0.1 for every day, stored as an (n_days, 1) float tensor.
s2 = torch.full((n_days, 1), 0.1)

# --- Ranking problem sizes -------------------------------------------------
n_items = 6
total_items = 50

# --- Compute device --------------------------------------------------------
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)


cuda:0


In [38]:
# --- Personalized ranking: prior, model and synthetic environment ----------

# Prior parameters produced by the project helper `make_uniform_prior`.
n_objs = 1
scaling = 1 / (batch_size*100)
pers_beta, pers_sigma = make_uniform_prior(context_len, scaling, n_objs=n_objs)

# User and item context parameters: an all-ones vector and 0.5 * identity,
# built as separate tensors for users and for items.
# NOTE(review): presumably the mean / covariance of the context sampling
# distribution inside the environment -- confirm in RankingSyntheticEnv.
user_context_mu = torch.ones(context_len)
user_context_var = 0.5*torch.eye(context_len)
item_context_mu = torch.ones(context_len)
item_context_var = 0.5*torch.eye(context_len)

# The same model object is handed to the environment below as `true_env`
# and (in a later cell) to the agent.
model = PersonalizedRankingModel(
    beta_0=pers_beta,
    sigma_0=pers_sigma,
    n_arms=n_arms,
    s2=s2,
    n_objs=n_objs,
)

# Synthetic ranking environment driven by the model above.
env = RankingSyntheticEnv(
    true_env=model,
    n_steps=n_steps,
    user_context_mu=user_context_mu,
    user_context_var=user_context_var,
    item_context_mu=item_context_mu,
    item_context_var=item_context_var,
    context_len=context_len,
    batch_size=batch_size,
    n_arms=n_arms,
    n_items=n_items,
    total_items=total_items,
)





In [39]:
# Sanity check: reset the environment and look at what it hands back.
# NOTE(review): reset() is invoked twice and the first result is discarded;
# kept as-is because dropping a call may change the random draws -- confirm
# whether the extra call is needed.
env.reset()
contexts, cur_step = env.reset()

# `contexts` unpacks into three groups; the state group holds user and item
# contexts, the action group holds an item count and the ranking contexts.
# Note that this rebinds the notebook-level `n_items`.
state_contexts, action_contexts, eval_contexts = contexts
(user_contexts, item_contexts), (n_items, ranking_contexts) = state_contexts, action_contexts

# Report the shapes so the dimensions can be eyeballed.
print(f"{user_contexts.shape} {item_contexts.shape}")
print(f"{n_items} {ranking_contexts.shape}")

torch.Size([100, 100]) torch.Size([100, 50, 100])
6 torch.Size([10, 100])


In [46]:
# Choose the agent to evaluate. Exactly one assignment should be active; the
# commented lines are alternative agents kept for quick switching.
# The second positional argument is a display name (the simulation cell prints
# `agent.name` when probability printing is enabled).
# NOTE(review): the active agent is RankingUniform but is labelled
# "Linear Uniform" -- confirm the label is intentional.
agent = RankingUniform(model, "Linear Uniform")
#agent = RankingTS(model, "Linear TS", toptwo=False, n_samples = 1)
#agent = RankingTS(model, "Linear TS", toptwo=True, n_samples = 100)
#agent = RankingRho(model, "Linear Rho", lr=0.4)

In [47]:
# ---------------------------------------------------------------------------
# Simulation: run the agent on freshly reset environments, evaluate its final
# arm choices, and print running averages of regret and percent-arms-correct.
# ---------------------------------------------------------------------------

# Tunables (previously inline literals in the loop below).
print_probs = False      # True -> print the mean action probabilities at every step
n_trials = 10000         # number of independent simulated experiments
train_repeats = 10000    # forwarded to agent.train_agent(repeats=...)
print_every = 1          # report running averages every `print_every` trials

torch.manual_seed(0)
objective = contextual_simple_regret()
torch.set_printoptions(sci_mode=False)

# Per-trial evaluation metrics; the printed numbers are means over these lists.
regret_list = []
percent_arms_correct_list = []

for i in range(n_trials):
    # In-experiment regret summed over steps. It is accumulated below but not
    # read anywhere in this cell.
    cumul_regret = 0

    # NOTE(review): reset() is called twice and the first result is discarded.
    # Kept unchanged because removing a call may alter the random stream (and
    # therefore the reported numbers) -- confirm whether it is required.
    env.reset()
    all_contexts, cur_step = env.reset()

    # Start every trial from the model's reset (beta, sigma) on the compute device.
    beta, sigma = agent.model.reset()
    beta, sigma = beta.to(device), sigma.to(device)

    while env.n_steps - cur_step > 0:

        # Split the context bundle and move its tensors to the compute device.
        # action_contexts[0] is passed through untouched; only action_contexts[1]
        # is moved.
        state_contexts, action_contexts, eval_contexts = all_contexts
        state_contexts = tuple(contexts.to(device) for contexts in state_contexts)
        eval_contexts = tuple(contexts.to(device) for contexts in eval_contexts)
        action_contexts = (action_contexts[0], action_contexts[1].to(device))

        # Let the agent (re)train for the current step given the current (beta, sigma).
        agent.train_agent(
            beta=beta,
            sigma=sigma,
            cur_step=cur_step,
            n_steps=n_steps,
            train_context_sampler=env.sample_train_contexts,
            eval_contexts=eval_contexts,
            eval_action_contexts=action_contexts,
            real_batch=batch_size,
            print_losses=False,
            objective=objective,
            repeats=train_repeats,
        )

        # Sampling probabilities for the current batch of contexts.
        probs = agent(
            beta=beta,
            sigma=sigma,
            contexts=state_contexts,
            action_contexts=action_contexts,
            objective=objective,
        )

        if print_probs:
            print(agent.name, env.n_steps - cur_step, torch.mean(probs, dim=0))

        # Draw actions from the categorical distribution defined by `probs`.
        actions = torch.distributions.Categorical(probs).sample()

        # Advance the environment; it returns the next context bundle, the
        # sampled rewards/features for the posterior update, and the new step.
        all_contexts, sampled_rewards, sampled_features, cur_step = env.step(
            state_contexts=state_contexts,
            action_contexts=action_contexts,
            actions=actions,
        )

        # Score the sampled actions against the true rewards of the contexts
        # they were chosen for (the pre-step state/action contexts).
        rewards = objective(
            agent_actions=actions,
            true_rewards=env.get_true_rewards(state_contexts, action_contexts),
        )
        cumul_regret += rewards['regret']

        # Posterior update with the freshly sampled data.
        # NOTE(review): idx = cur_step - 1 presumably indexes the step that was
        # just played, since cur_step now comes from env.step -- confirm.
        beta, sigma = agent.model.update_posterior(
            beta=beta,
            sigma=sigma,
            rewards=sampled_rewards,
            features=sampled_features,
            idx=cur_step-1,
        )

    # --- Final evaluation for this trial -----------------------------------
    # `action_contexts` here is the value left over from the last loop step.
    eval_contexts = env.sample_eval_contexts(access=True)
    eval_contexts = tuple(contexts.to(device) for contexts in eval_contexts)
    true_eval_rewards = env.get_true_rewards(eval_contexts, action_contexts)
    fantasy_rewards = agent.fantasize(beta, eval_contexts, action_contexts).to(device)

    # Pick the arm with the highest fantasized reward along dim 1.
    # NOTE(review): squeeze() removes every size-1 dimension; this assumes the
    # squeezed tensor is still at least 2-D with arms on dim 1 -- TODO confirm
    # (an evaluation batch of size 1 would also be squeezed away).
    agent_actions = torch.argmax(fantasy_rewards.squeeze(), dim=1)

    results_dict = objective(
        agent_actions=agent_actions,
        true_rewards=true_eval_rewards.to(device),
    )

    percent_arms_correct_list.append(results_dict['percent_arms_correct'])
    regret_list.append(results_dict['regret'])

    # Running averages over all trials completed so far.
    if i % print_every == 0:
        print(i, "Regret: ", np.mean(regret_list))
        print("Percent Arms Correct: ", np.mean(percent_arms_correct_list))

0 Regret:  0.038807712495326996
Percent Arms Correct:  0.66
1 Regret:  0.04593134671449661
Percent Arms Correct:  0.615
2 Regret:  0.045784598837296166
Percent Arms Correct:  0.61
3 Regret:  0.04896984621882439
Percent Arms Correct:  0.5800000000000001
4 Regret:  0.0506954088807106
Percent Arms Correct:  0.562
5 Regret:  0.04813119644920031
Percent Arms Correct:  0.5800000000000001
6 Regret:  0.04655326743211065
Percent Arms Correct:  0.5857142857142857
7 Regret:  0.04586641909554601
Percent Arms Correct:  0.6012500000000001
8 Regret:  0.045859927104579076
Percent Arms Correct:  0.6055555555555556
9 Regret:  0.04438349362462759
Percent Arms Correct:  0.608
10 Regret:  0.04441931217231534
Percent Arms Correct:  0.6072727272727273
11 Regret:  0.04423101603363951
Percent Arms Correct:  0.6075
12 Regret:  0.04285366188448209
Percent Arms Correct:  0.6146153846153847
13 Regret:  0.042702760946537764
Percent Arms Correct:  0.6157142857142858
14 Regret:  0.04274266796807448
Percent Arms Corre

KeyboardInterrupt: 