In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os
import torch 
# Work from two directories above the notebook — presumably the repository
# root, so that the `aexgym` / `scripts` imports below resolve.
# The target is resolved to an absolute path only once and then reused:
# a bare relative chdir would climb two more levels on every re-run of this cell.
NOTEBOOK_ROOT = globals().get("NOTEBOOK_ROOT") or os.path.abspath("../..")
os.chdir(NOTEBOOK_ROOT)

from aexgym.env import PersSyntheticEnv, RankingSyntheticEnv
from aexgym.model import PersonalizedLinearModel, PersonalizedRankingModel
from aexgym.agent import LinearTS, LinearUniform, LinearUCB, LinearRho, RankingUniform, RankingTS, RankingRho
from aexgym.objectives import contextual_best_arm, contextual_simple_regret
from scripts.setup_script import make_uniform_prior

In [2]:
# --- experiment dimensions ---
n_days, n_arms, context_len = 5, 10, 8
n_steps = n_days  # one step per day
batch_size = 100
# s2 holds one value (0.1) per day, shaped (n_days, 1);
# presumably the observation-noise variance — confirm against the model.
s2 = torch.ones((n_days, 1)) * 0.1

# --- item counts handed to the ranking environment ---
n_items, total_items = 6, 50

# Use the first CUDA device when one is visible, otherwise stay on CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)


cpu


In [3]:
# Ranking experiment setup: prior, context distributions, model, environment.

# Prior parameters come from the shared setup helper; `scaling` shrinks with
# the batch size (1 / (100 * batch_size)).
n_objs = 1
scaling = 1 / (batch_size*100)
pers_beta, pers_sigma = make_uniform_prior(context_len, scaling, n_objs=n_objs)

# User and item contexts are configured identically:
# mu = vector of ones, var = 0.5 * identity, both of size context_len.
user_context_mu = torch.ones(context_len)
user_context_var = 0.5*torch.eye(context_len)
item_context_mu = torch.ones(context_len)
item_context_var = 0.5*torch.eye(context_len)

# A single model instance serves as the environment's `true_env` below and is
# also the model handed to the agent.
model = PersonalizedRankingModel(
    beta_0=pers_beta, sigma_0=pers_sigma,
    n_arms=n_arms, s2=s2, n_objs=n_objs,
)

# Synthetic ranking environment driven by that model.
env = RankingSyntheticEnv(
    true_env=model,
    n_steps=n_steps,
    user_context_mu=user_context_mu, user_context_var=user_context_var,
    item_context_mu=item_context_mu, item_context_var=item_context_var,
    context_len=context_len,
    batch_size=batch_size,
    n_arms=n_arms,
    n_items=n_items,
    total_items=total_items,
)





In [4]:
# Sanity check: reset the environment once and unpack the nested context
# structure it returns so the tensor shapes can be inspected.
contexts, cur_step = env.reset()
state_contexts, action_contexts, eval_contexts = contexts
user_contexts, item_contexts = state_contexts
# Unpack into a fresh name so the configured `n_items` is not silently rebound
# by whatever the environment reports.
returned_n_items, ranking_contexts = action_contexts
print(user_contexts.shape, item_contexts.shape)
print(returned_n_items, ranking_contexts.shape)

torch.Size([100, 8]) torch.Size([100, 50, 8])
6 torch.Size([10, 8])


In [36]:
# Select the agent to evaluate. Exactly one assignment should be active; the
# commented-out lines are alternative agents built on the same `model`.
# The second positional argument is presumably the display name (the run loop
# prints `agent.name`) — TODO confirm against the agent constructors.
# NOTE(review): this cell's execution count (36) does not follow the previous
# cell (4); re-run the notebook top to bottom before trusting the outputs.
agent = RankingUniform(model, "Linear Uniform")
#agent = RankingTS(model, "Linear TS", toptwo=False, n_samples = 1)
#agent = RankingTS(model, "Linear TS", toptwo=True, n_samples = 100)
#agent = RankingRho(model, "Linear Rho", lr=0.4)

In [37]:
# Monte-Carlo evaluation of `agent` on `env`: run N_TRIALS independent
# episodes, score the final decision of each with `objective`, and print the
# running averages of regret and percent-arms-correct.

# ---- run configuration ----
N_TRIALS = 10000        # number of independent episodes to average over
TRAIN_REPEATS = 10000   # passed through as `repeats` to agent.train_agent
LOG_EVERY = 1           # print the running averages every LOG_EVERY episodes
print_probs = False     # True: print the batch-mean action probabilities each step

torch.manual_seed(0)
objective = contextual_simple_regret()
torch.set_printoptions(sci_mode=False)
regret_list = []
percent_arms_correct_list = []

for i in range(N_TRIALS):
    # In-experiment regret summed over the steps of this episode. Kept as a
    # diagnostic only: it is not part of the summary printed below.
    cumul_regret = 0

    # Begin a new episode. One reset is enough: its return value carries the
    # initial contexts and the step counter.
    all_contexts, cur_step = env.reset()
    beta, sigma = agent.model.reset()
    beta, sigma = beta.to(device), sigma.to(device)

    while env.n_steps - cur_step > 0:

        # Unpack this step's contexts and move the tensors to `device`.
        # action_contexts[0] is left untouched; only its second element is moved.
        state_contexts, action_contexts, eval_contexts = all_contexts
        state_contexts = tuple(ctx.to(device) for ctx in state_contexts)
        eval_contexts = tuple(ctx.to(device) for ctx in eval_contexts)
        action_contexts = (action_contexts[0], action_contexts[1].to(device))

        # Train the agent's policy for the current posterior and step.
        agent.train_agent(
            beta = beta,
            sigma = sigma,
            cur_step = cur_step,
            n_steps = n_steps,
            train_context_sampler = env.sample_train_contexts,
            eval_contexts = eval_contexts,
            eval_action_contexts = action_contexts,
            real_batch = batch_size,
            print_losses=False,
            objective=objective,
            repeats=TRAIN_REPEATS
        )

        # Action probabilities for the current batch of contexts.
        probs = agent(
            beta = beta,
            sigma = sigma,
            contexts = state_contexts,
            action_contexts = action_contexts,
            objective = objective
        )

        if print_probs:
            print(agent.name, env.n_steps - cur_step, torch.mean(probs, dim=0))

        # Sample one action per row of `probs`.
        actions = torch.distributions.Categorical(probs).sample()

        # Advance the environment; `cur_step` is updated from its return value.
        all_contexts, sampled_rewards, sampled_features, cur_step  = env.step(
            state_contexts = state_contexts,
            action_contexts = action_contexts,
            actions = actions
        )

        # Score the sampled actions against the environment's true rewards.
        rewards = objective(
            agent_actions = actions,
            true_rewards = env.get_true_rewards(state_contexts, action_contexts)
        )
        cumul_regret += rewards['regret']

        # Posterior update with the observed rewards. `cur_step` was already
        # advanced by env.step, hence the index of the step just played is cur_step-1.
        beta, sigma = agent.model.update_posterior(
            beta = beta,
            sigma = sigma,
            rewards = sampled_rewards,
            features = sampled_features,
            idx = cur_step-1
        )

    # Final evaluation on freshly sampled evaluation contexts.
    # `action_contexts` is the value left over from the last loop iteration.
    eval_contexts = env.sample_eval_contexts(access=True)
    eval_contexts = tuple(ctx.to(device) for ctx in eval_contexts)
    true_eval_rewards = env.get_true_rewards(eval_contexts, action_contexts)
    fantasy_rewards = agent.fantasize(beta, eval_contexts, action_contexts).to(device)
    # NOTE(review): squeeze() without a dim drops every size-1 axis; this
    # assumes the result still has the arm axis at dim 1 — confirm shapes.
    agent_actions = torch.argmax(fantasy_rewards.squeeze(), dim=1)

    results_dict = objective(
        agent_actions = agent_actions,
        true_rewards = true_eval_rewards.to(device)
    )

    percent_arms_correct_list.append(results_dict['percent_arms_correct'])
    regret_list.append(results_dict['regret'])

    # Running averages over all episodes completed so far.
    if i % LOG_EVERY == 0:
        print(i, "Regret: ", np.mean(regret_list))
        print("Percent Arms Correct: ", np.mean(percent_arms_correct_list))

0 Regret:  0.01405759435147047
Percent Arms Correct:  0.7
1 Regret:  0.010629332158714533
Percent Arms Correct:  0.725
2 Regret:  0.011854222354789576
Percent Arms Correct:  0.6833333333333332
3 Regret:  0.010795207461342216
Percent Arms Correct:  0.6875
4 Regret:  0.016048735193908214
Percent Arms Correct:  0.6519999999999999
5 Regret:  0.017061014504482348
Percent Arms Correct:  0.6466666666666666
6 Regret:  0.016052419851933206
Percent Arms Correct:  0.6571428571428571
7 Regret:  0.014642253110650927
Percent Arms Correct:  0.675
8 Regret:  0.01399844755522079
Percent Arms Correct:  0.6755555555555556
9 Regret:  0.012962699262425303
Percent Arms Correct:  0.687
10 Regret:  0.013904563362964174
Percent Arms Correct:  0.67
11 Regret:  0.014679149840958416
Percent Arms Correct:  0.6566666666666666
12 Regret:  0.01467341030589663
Percent Arms Correct:  0.643076923076923
13 Regret:  0.01579371257685125
Percent Arms Correct:  0.6214285714285713
14 Regret:  0.015383812443663677
Percent Arms

KeyboardInterrupt: 