In [1]:
import matplotlib.pyplot as plt
from vowpalwabbit import pyvw
import random
import json


In [2]:
# Costs handed back to the learner by get_cost: a liked temperature earns a
# negative cost (lower is better for the learner), a disliked one costs nothing.
USER_LIKED_TEMPERATURE, USER_DISLIKED_TEMPERATURE = -1.0, 0.0

In [3]:
def get_cost(context, temperature):
    """Return the simulated cost of suggesting ``temperature`` in ``context``.

    Preferences encoded here:
      * Tom likes it cold (< 18.0) in the morning and hot (> 25.0) in the afternoon.
      * Anna likes it warm (> 22.0) in the morning and cold (< 18.0) in the afternoon.
      * Every other user / time-of-day / temperature combination is disliked.

    Args:
        context: mapping with a ``'user'`` key and, for Tom and Anna, a
            ``'time_of_day'`` key.
        temperature: the suggested temperature.

    Returns:
        ``USER_LIKED_TEMPERATURE`` if the user likes the suggestion,
        otherwise ``USER_DISLIKED_TEMPERATURE``.
    """
    who = context['user']
    if who == "Tom":
        period = context['time_of_day']
        liked = (period == "morning" and temperature < 18.0) or \
                (period == "afternoon" and temperature > 25.0)
    elif who == "Anna":
        period = context['time_of_day']
        liked = (period == "morning" and temperature > 22.0) or \
                (period == "afternoon" and temperature < 18.0)
    else:
        # Unknown users never like anything; time_of_day is not even consulted.
        liked = False
    return USER_LIKED_TEMPERATURE if liked else USER_DISLIKED_TEMPERATURE
    

def get_temp_to_suggest(context):
    """Return the "teacher" temperature used to guide the model during warm-up.

    Args:
        context: mapping with a ``'user'`` key and, for known users, a
            ``'time_of_day'`` key.

    Returns:
        A temperature the user is expected to like, or ``None`` when the
        user or time of day has no suggestion.
    """
    per_user = {
        "Tom": {"morning": 13.0, "afternoon": 28.0},
        "Anna": {"morning": 26.0, "afternoon": 13.0},
    }.get(context['user'])
    if per_user is None:
        # Unknown user: the time of day is never looked at.
        return None
    return per_user.get(context['time_of_day'])

In [4]:
def to_vw_example_format(context, temperature=None, cats_label=None):
    """Serialise one simulation round into a VW DSJSON line.

    Args:
        context: mapping with ``'user'`` and ``'time_of_day'`` keys; each becomes
            a ``key=value`` feature with weight 1 in the ``c`` namespace.
        temperature: optional action to force; written as
            ``"pdf": [{"chosen_action": temperature}]`` (presumably what
            ``--first_only`` consumes during warm-up — confirm against VW docs).
        cats_label: optional ``(action, cost, pdf_value)`` triple written as the
            ``_label_ca`` continuous-action label used for learning.

    Returns:
        The example encoded as a JSON string.
    """
    example = {}
    if cats_label is not None:
        action, label_cost, density = cats_label
        example['_label_ca'] = {
            'action': action,
            'cost': label_cost,
            'pdf_value': density,
        }
    if temperature is not None:
        example['pdf'] = [{'chosen_action': temperature}]
    example['c'] = {'{}={}'.format(key, context[key]): 1 for key in ('user', 'time_of_day')}
    return json.dumps(example)

In [5]:
# Population the simulator samples contexts from; get_cost and
# get_temp_to_suggest only know preferences for these exact values.
users = [
    'Tom',
    'Anna',
]
times_of_day = [
    'morning',
    'afternoon',
]

def choose_user(users):
    """Pick one entry of ``users`` uniformly at random using the global ``random`` RNG.

    Raises:
        IndexError: if ``users`` is empty.
    """
    picked_user = random.choice(users)
    return picked_user

def choose_time_of_day(times_of_day):
    """Pick one entry of ``times_of_day`` uniformly at random using the global ``random`` RNG.

    Raises:
        IndexError: if ``times_of_day`` is empty.
    """
    picked_time = random.choice(times_of_day)
    return picked_time

def get_temperature(vw, context, temperature=None):
    """Ask ``vw`` for a prediction in ``context``.

    When ``temperature`` is given it is embedded in the example as the forced
    ``chosen_action``. The result is whatever ``vw.predict`` returns; callers in
    this notebook unpack it as ``(temperature, pdf_value)``.
    """
    return vw.predict(to_vw_example_format(context, temperature=temperature))

In [15]:
def run_simulation(vw, num_iterations, users, times_of_day, cost_function, do_learn=True,
                   warmup_iterations=500,
                   reload_args="--cats 32 --min_value 0 --max_value 32 --bandwidth 1 --quiet "
                               "--dsjson --chain_hash -i mymodel.model"):
    """Simulate ``num_iterations`` rounds of the thermostat contextual bandit.

    Rounds ``1..warmup_iterations`` are guided: the teacher temperature from
    ``get_temp_to_suggest`` is passed to the model as the forced action (intended
    for a workspace created with ``--first_only``). Right before round
    ``warmup_iterations + 1``, the warm-up workspace is finished and a new one is
    loaded from ``reload_args``; all later rounds let the model choose on its own.

    Args:
        vw: ``pyvw.vw`` CATS workspace for the warm-up phase. It is finished by
            this function. Its ``-f`` model path must be the file that
            ``reload_args`` loads with ``-i``.
        num_iterations: number of simulated rounds.
        users: candidates passed to ``choose_user`` each round.
        times_of_day: candidates passed to ``choose_time_of_day`` each round.
        cost_function: ``(context, temperature) -> cost``; a cost equal to
            ``USER_LIKED_TEMPERATURE`` counts as "liked".
        do_learn: whether to update the model after every round.
        warmup_iterations: number of guided rounds before the model acts alone.
        reload_args: VW command line used to reload the saved warm-up model
            without ``--first_only``.

    Returns:
        list: after each round ``i``, the running rate ``-cost_sum / i``.

    Raises:
        RuntimeError: if a warm-up suggestion is not liked under ``cost_function``.
    """
    cost_sum = 0.
    ctr = []
    pos = 0  # liked suggestions after warm-up
    neg = 0  # disliked suggestions after warm-up

    for i in range(1, num_iterations + 1):
        in_warmup = i <= warmup_iterations

        if i == warmup_iterations + 1:
            # Swap workspaces BEFORE the first unguided prediction. finish() flushes
            # the -f model file, which reload_args reads back with -i (no --first_only).
            # NOTE(review): for the reloaded model to resume learning with its full
            # optimizer state, the warm-up workspace may need --save_resume on older
            # VW versions — confirm for the installed version.
            vw.finish()
            vw = pyvw.vw(reload_args)

        # Draw the user first, then the time of day (same RNG order as before).
        context = {'user': choose_user(users), 'time_of_day': choose_time_of_day(times_of_day)}

        if in_warmup:
            # Tell the model which temperature to pick; make sure the teacher's
            # suggestion is actually liked under the cost being simulated.
            suggested = get_temp_to_suggest(context)
            if cost_function(context, suggested) != USER_LIKED_TEMPERATURE:
                raise RuntimeError(
                    "warm-up suggestion {} is not liked for context {}".format(suggested, context))
            temperature, pdf_value = get_temperature(vw, context, suggested)
        else:
            temperature, pdf_value = get_temperature(vw, context)

        cost = cost_function(context, temperature)
        cost_sum += cost
        if not in_warmup:
            if cost == USER_LIKED_TEMPERATURE:
                pos += 1
            else:
                neg += 1

        if do_learn:
            labelled = to_vw_example_format(context, cats_label=(temperature, cost, pdf_value))
            vw_format = vw.parse(labelled, pyvw.vw.lContinuous)
            vw.learn(vw_format)
            vw.finish_example(vw_format)

        ctr.append(-1 * cost_sum / i)

    vw.finish()
    print("positive {}, negative {}".format(pos, neg))
    return ctr

In [7]:
def plot_ctr(num_iterations, ctr):
    """Plot the running rate returned by ``run_simulation`` against the round number.

    Args:
        num_iterations: number of rounds; must equal ``len(ctr)``.
        ctr: running rate after each round. The y-axis is fixed to [0, 1], which
            matches costs in {-1, 0} as produced by ``get_cost``.
    """
    _, ax = plt.subplots()
    ax.plot(range(1, num_iterations + 1), ctr)
    ax.set_xlabel('num_iterations', fontsize=14)
    ax.set_ylabel('ctr', fontsize=14)
    ax.set_ylim([0, 1])
    ax.set_title('Running rate of liked temperatures', fontsize=14)
    # Show only after drawing: calling plt.show() first (as before) displays an
    # empty/stale figure outside an inline backend.
    plt.show()

In [16]:
# Seed Python's RNG so the sampled users / times of day are reproducible
# across "Restart & Run All".
random.seed(0)

# Warm-up workspace: --first_only makes it follow the forced action during warm-up.
# The -f path must match the -i path in run_simulation's reload_args.
vw = pyvw.vw("--cats 32 --min_value 0 --max_value 32 --first_only --bandwidth 1 --quiet --dsjson --chain_hash -f mymodel.model")
num_iterations = 5000
ctr = run_simulation(vw, num_iterations, users, times_of_day, get_cost)
plot_ctr(num_iterations, ctr)

positive 1902, negative 2598
