In [1]:
import numpy as np
import os
import torch 
os.chdir("../..")

from aexgym.env import PersSyntheticEnv
from aexgym.model import PersonalizedLinearModel
from aexgym.agent import LinearTS, LinearUniform, LinearUCB
from aexgym.objectives import contextual_best_arm, contextual_simple_regret
from scripts.setup_script import make_uniform_prior

In [2]:
n_days = 10
n_arms = 10
context_len = n_days 
n_steps = n_days 
batch_size = 100
s2 = 0.1 * torch.ones((n_days, 1))

if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
print(device)


cpu


In [3]:
#personalization 

#initialize parameterss
n_objs = 1
scaling = 1 / (batch_size*10)
pers_beta, pers_sigma = make_uniform_prior(context_len*n_arms, scaling, n_objs=n_objs)
context_mu, context_var = torch.ones(context_len), 5*torch.eye(context_len)

#initialize synthetic and agent model 
model = PersonalizedLinearModel(
    beta_0 = pers_beta, 
    sigma_0 = pers_sigma, 
    n_arms = n_arms, 
    s2 = s2,  
    n_objs=n_objs
)

#initialize synthetic environment
env = PersSyntheticEnv(
    model = model, 
    context_mu = context_mu, 
    context_var = context_var, 
    context_len = context_len, 
    batch_size = batch_size, 
    n_steps = n_steps
)





In [4]:
#initialize agent 
agent = LinearUniform(model, "Linear Uniform")
#agent = LinearTS(mdp, "Linear TS", toptwo=False, n_samples = 1)
#agent = LinearTS(mdp, "Linear TS", toptwo=True, n_samples = 100)


In [5]:
print_probs = False
torch.manual_seed(0)
objective = contextual_simple_regret()
torch.set_printoptions(sci_mode=False)
regret_list = []
percent_arms_correct_list = []



for i in range(10000):
    env.reset()
    #print(env.mean_matrix)
    all_contexts, cur_step = env.reset()
    beta, sigma = agent.model.reset()
    #print(beta, sigma)
    beta, sigma = beta.to(device), sigma.to(device)
    beta_0, sigma_0 = beta.clone(), sigma.clone()
    
    while env.n_steps - cur_step > 0:

        #move to device 
        state_contexts, action_contexts, eval_contexts = tuple(contexts.to(device) for contexts in all_contexts)
        
        #get batch size 
        batch = state_contexts.shape[0]

        #get probabilities
        probs = agent(
            beta = beta, 
            sigma = sigma, 
            contexts = state_contexts, 
            action_contexts = action_contexts, 
            objective = objective
        )
     
        #print probabilities 
        if print_probs == True:
            print(agent.name, env.n_steps - cur_step, torch.mean(probs, dim=0))
        
        #get actions and move to new state
        actions = torch.distributions.Categorical(probs).sample()
        
        #move to next environment state 
        all_contexts, sampled_rewards, sampled_features, cur_step  = env.step(
            state_contexts = state_contexts, 
            action_contexts = action_contexts, 
            actions = actions
        )
        
        #update model state 
        beta, sigma = agent.model.update_posterior(
            beta = beta_0, 
            sigma = sigma_0, 
            rewards = sampled_rewards, 
            features = agent.model.feature_map(actions, state_contexts, action_contexts), 
            idx = cur_step-1
        )

    #get evaluation contexts and true rewards 
    eval_contexts = env.sample_eval_contexts(access=True).to(device)
    true_eval_rewards = env.get_true_rewards(eval_contexts, action_contexts)
    
    #calculate results from objective 
    results_dict = objective(
        fantasy_rewards = agent.fantasize(beta, eval_contexts, action_contexts).to(device), 
        true_rewards = true_eval_rewards.to(device)
    )
    
    #append results 
    percent_arms_correct_list.append(results_dict['percent_arms_correct'])
    regret_list.append(results_dict['regret'])

    #print results 
    if i % 10 == 0:
        
        print("Regret: ", np.mean(regret_list))
        print("Percent Arms Correct: ", np.mean(percent_arms_correct_list))

Regret:  0.08411479741334915
Percent Arms Correct:  0.5
Regret:  0.1705380156636238
Percent Arms Correct:  0.29545454545454547
Regret:  0.17850385216020404
Percent Arms Correct:  0.2785714285714286
Regret:  0.17754314359157317
Percent Arms Correct:  0.2887096774193549
Regret:  0.17021061661766795
Percent Arms Correct:  0.301219512195122
Regret:  0.1695796426604776
Percent Arms Correct:  0.30137254901960786
Regret:  0.16913791122983712
Percent Arms Correct:  0.29950819672131146
Regret:  0.16774130948412586
Percent Arms Correct:  0.3050704225352114
Regret:  0.1681516061042562
Percent Arms Correct:  0.30111111111111105
Regret:  0.16998028370377782
Percent Arms Correct:  0.2997802197802198
Regret:  0.170214253339437
Percent Arms Correct:  0.300990099009901
Regret:  0.16959955227804613
Percent Arms Correct:  0.30009009009009
Regret:  0.1683657125623758
Percent Arms Correct:  0.30181818181818176
Regret:  0.16872038436299971
Percent Arms Correct:  0.30282442748091604
Regret:  0.16883441367259

KeyboardInterrupt: 