In [1]:
import matplotlib.pyplot as plt
from vowpalwabbit import pyvw
import random
import json


In [2]:
# Cost values emitted by the simulated user (lower is better: the reported CTR is -cost).
USER_LIKED_TEMPERATURE, USER_DISLIKED_TEMPERATURE = -1.0, 0.0

In [3]:
# (user, time_of_day, low, high): open temperature windows each simulated user likes.
_LIKED_TEMPERATURE_WINDOWS = (
    ("Tom", "morning", 15.0, 18.0),
    ("Tom", "afternoon", 25.0, 29.0),
    ("Anna", "morning", 22.0, 29.0),
    ("Anna", "afternoon", 15.0, 18.0),
)


def get_cost(context, temperature):
    """Simulated cost of suggesting `temperature` in `context`.

    Args:
        context: dict with a 'user' key and, for known users, a 'time_of_day' key.
        temperature: the suggested temperature.

    Returns:
        USER_LIKED_TEMPERATURE if `temperature` lies strictly inside the window
        the context's user likes at that time of day, USER_DISLIKED_TEMPERATURE
        otherwise (including unknown users or unknown times of day).
    """
    for user, time_of_day, low, high in _LIKED_TEMPERATURE_WINDOWS:
        # The user is compared first, so 'time_of_day' is only read for known users.
        if context['user'] == user and context['time_of_day'] == time_of_day:
            # Windows are exclusive at both ends.
            return USER_LIKED_TEMPERATURE if low < temperature < high else USER_DISLIKED_TEMPERATURE
    return USER_DISLIKED_TEMPERATURE
    

# (user, time_of_day, temperature) hand-picked as an ideal suggestion for each context.
_SUGGESTED_TEMPERATURES = (
    ("Tom", "morning", 16.0),
    ("Tom", "afternoon", 28.0),
    ("Anna", "morning", 26.0),
    ("Anna", "afternoon", 16.0),
)


def get_temp_to_suggest(context):
    """Return the hand-picked ideal temperature for `context`, or None.

    Args:
        context: dict with a 'user' key and, for known users, a 'time_of_day' key.

    Returns:
        The suggested temperature for a known (user, time_of_day) pair, otherwise None.
    """
    for user, time_of_day, temperature in _SUGGESTED_TEMPERATURES:
        # The user is compared first, so 'time_of_day' is only read for known users.
        if context['user'] == user and context['time_of_day'] == time_of_day:
            return temperature
    return None

In [4]:
def to_vw_example_format(context, temperature=None, cats_label=None):
    """Serialise a context as a JSON example string for VW.

    Args:
        context: dict with 'user' and 'time_of_day' keys; each becomes a
            'name=value' feature with weight 1 in the 'c' namespace.
        temperature: if given, emitted as the 'pdf' entry's 'chosen_action'.
        cats_label: optional (action, cost, pdf_value) triple, emitted as the
            '_label_ca' label.

    Returns:
        The example as a JSON string, keys in the order '_label_ca', 'pdf', 'c'.
    """
    example = {}
    if cats_label is not None:
        action, cost, pdf_value = cats_label
        example['_label_ca'] = dict(action=action, cost=cost, pdf_value=pdf_value)
    if temperature is not None:
        example['pdf'] = [dict(chosen_action=temperature)]
    user_feature = 'user={}'.format(context['user'])
    time_feature = 'time_of_day={}'.format(context['time_of_day'])
    example['c'] = {user_feature: 1, time_feature: 1}
    return json.dumps(example)

In [5]:
# Population sampled by the simulation: one user and one time of day per example.
users = ["Tom", "Anna"]
times_of_day = ["morning", "afternoon"]

def choose_user(users):
    """Draw one user uniformly at random from the sequence `users`."""
    picked_user = random.choice(users)
    return picked_user

def choose_time_of_day(times_of_day):
    """Draw one time of day uniformly at random from the sequence `times_of_day`."""
    picked_time = random.choice(times_of_day)
    return picked_time

def get_temperature(vw, context, temperature=None):
    """Ask the VW model `vw` to predict for `context`.

    If `temperature` is given, it is embedded in the example as the 'pdf'
    chosen_action. Callers unpack the result as (temperature, pdf_value).
    """
    return vw.predict(to_vw_example_format(context, temperature))

In [46]:
def generate_new_example(skip, predict_vw_first_only, vw, i, cost_function):
    """Simulate one interaction and return it as a labelled VW example.

    A random user and time of day form the context. For the first `skip`
    iterations, the hand-picked ideal temperature from get_temp_to_suggest is
    passed to `predict_vw_first_only` (as the example's 'pdf' chosen_action).
    `i` is 1-based, as in run_simulation, so this covers i in 1..skip. For all
    later iterations, `vw` chooses the temperature itself.

    Args:
        skip: number of leading hand-fed iterations.
        predict_vw_first_only: VW model used only for the hand-fed iterations.
        vw: VW model that predicts the temperature after the hand-fed phase.
        i: 1-based iteration index.
        cost_function: callable(context, temperature) -> cost.

    Returns:
        (example string carrying a '_label_ca' label, cost of the chosen temperature)

    Raises:
        RuntimeError: if the hand-picked temperature is not one get_cost rewards.
    """
    user = choose_user(users)
    time_of_day = choose_time_of_day(times_of_day)
    context = {'user': user, 'time_of_day': time_of_day}

    # `i` starts at 1, so `i <= skip` hand-feeds exactly `skip` examples. This
    # matches run_simulation, which only counts iterations with i > skip.
    if i <= skip:
        suggested = get_temp_to_suggest(context)
        # Sanity check: the hand-picked temperature must be one the simulated user likes.
        if get_cost(context, suggested) != USER_LIKED_TEMPERATURE:
            raise RuntimeError(
                "suggested temperature {} is not liked for context {}".format(suggested, context))
        temperature, pdf_value = get_temperature(predict_vw_first_only, context, suggested)
    else:
        # just predict
        temperature, pdf_value = get_temperature(vw, context)

    cost = cost_function(context, temperature)
    return to_vw_example_format(context, cats_label=(temperature, cost, pdf_value)), cost

In [59]:
def reuse_example_from_file(vw, cost_function, file_reader):
    """Replay the next logged example and score the model's own prediction.

    Args:
        vw: the VW model being trained.
        cost_function: callable(context, temperature) -> cost.
        file_reader: iterator yielding (action, cost, pdf_value, context),
            e.g. get_next_from_file. Its StopIteration propagates to the caller.

    Returns:
        (example string labelled with the *logged* action/cost/pdf_value,
         cost of the temperature `vw` currently predicts for the logged context)
    """
    logged_action, logged_cost, logged_pdf, context = next(file_reader)
    predicted_temperature, _ = get_temperature(vw, context)
    simulated_cost = cost_function(context, predicted_temperature)
    replay_example = to_vw_example_format(context, cats_label=(logged_action, logged_cost, logged_pdf))
    return replay_example, simulated_cost

In [61]:
def get_next_from_file(output_file):
    """Yield logged examples from a JSON-lines file written by run_simulation.

    Each non-blank line must be a JSON object with a '_label_ca' entry
    ('action', 'cost', 'pdf_value') and a 'c' namespace whose keys have the
    form 'name=value', as produced by to_vw_example_format.

    Args:
        output_file: path of the file to read.

    Yields:
        (temperature, cost, pdf_value, context), where context maps each
        feature name to its value string, e.g. {'user': 'Tom', 'time_of_day': 'morning'}.

    Raises:
        IndexError: if a 'c' key contains no '='.
    """
    with open(output_file, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads.
                continue
            data = json.loads(line)
            label = data['_label_ca']
            context = {}
            for key in data['c']:
                # Split only on the first '=' so values containing '=' stay intact.
                parts = key.split("=", 1)
                context[parts[0]] = parts[1]
            yield label['action'], label['cost'], label['pdf_value'], context

def run_simulation(num_iterations, users, times_of_day, cost_function, num_actions, bandwidth, skip=0, output_file=None, write_to_file=True, do_learn=True):
    """Run a CATS temperature-suggestion simulation, generating or replaying data.

    Fresh mode (output_file is None): each iteration simulates a new interaction
    with generate_new_example. Iterations 1..skip are hand-fed and excluded from
    the statistics. With write_to_file=True, the generated examples are written
    to 'outfile.json', one JSON line per example.

    Replay mode (output_file given): each iteration takes the next logged example
    from `output_file` (learning from its logged label) and scores the model's
    own prediction with `cost_function`. The loop stops early when the file runs
    out, so the returned curve can be shorter than num_iterations. Nothing is
    written in this mode, so replaying 'outfile.json' cannot truncate it.

    Args:
        num_iterations: maximum number of iterations.
        users, times_of_day: currently unused here.
            NOTE(review): generate_new_example samples from the module-level
            lists of the same name.
        cost_function: callable(context, temperature) -> cost.
        num_actions: value of VW's --cats option.
        bandwidth: value of VW's --bandwidth option.
        skip: number of leading hand-fed iterations (fresh mode only).
        output_file: log file to replay, or None to generate fresh data.
        write_to_file: in fresh mode, whether to write 'outfile.json'.
        do_learn: whether the model learns from each example.

    Returns:
        (ctr, pos): the running average of -cost after each counted iteration,
        and how many counted iterations had cost == USER_LIKED_TEMPERATURE.
    """
    base_args = "--cats " + str(num_actions) + "  --bandwidth " + str(bandwidth) + " --quiet --min_value 0 --max_value 32"
    replaying = output_file is not None

    cost_sum = 0.
    ctr = []
    pos = 0
    neg = 0
    generated = []  # example strings produced in fresh mode, written out at the end

    vw = pyvw.vw(base_args + " --dsjson --chain_hash --coin --loss_option 1 -f mymodel.model")
    predict_vw_first_only = None
    file_reader = None
    try:
        if replaying:
            file_reader = get_next_from_file(output_file)
        elif skip > 0:
            # generate_new_example only consults this predictor for hand-fed
            # iterations, which cannot happen when skip == 0.
            predict_vw_first_only = pyvw.vw(base_args + " --first_only --coin --loss_option 1 --dsjson --chain_hash -t")

        for i in range(1, num_iterations + 1):
            if replaying:
                try:
                    # re-use the data from the file
                    txt_ex, cost = reuse_example_from_file(vw, cost_function, file_reader)
                except StopIteration:
                    break
                counted = True
            else:
                txt_ex, cost = generate_new_example(skip, predict_vw_first_only, vw, i, cost_function)
                counted = i > skip  # don't count the hand fed predictions
                if write_to_file:
                    generated.append(txt_ex)

            if counted:
                if cost == USER_LIKED_TEMPERATURE:
                    pos += 1
                else:
                    neg += 1
                cost_sum += cost

            if do_learn:
                vw_format = vw.parse(txt_ex, pyvw.vw.lContinuous)
                vw.learn(vw_format)
                vw.finish_example(vw_format)

            if replaying:
                ctr.append(-1 * cost_sum / i)
            elif i > skip:
                ctr.append(-1 * cost_sum / (i - skip))

        if write_to_file and not replaying:
            with open('outfile.json', 'w') as f:
                for example in generated:
                    f.write(example)
                    f.write('\n')
    finally:
        # Release the log file (if the loop stopped before exhausting it) and both VW instances.
        if file_reader is not None:
            file_reader.close()
        if predict_vw_first_only is not None:
            predict_vw_first_only.finish()
        vw.finish()

    print("actions {}, bandwidth {}, positive {}, negative {}".format(num_actions, bandwidth, pos, neg))
    return ctr, pos

In [40]:
def plot_ctr(num_iterations, ctr, title, skip=0):
    """Plot a single running-CTR curve on a fresh figure.

    Args:
        num_iterations: total iterations of the run. Kept for compatibility;
            the x axis is derived from len(ctr).
        ctr: sequence of running CTR values, one per plotted iteration.
        title: figure title.
        skip: number of leading hand-fed iterations. Kept for compatibility;
            the x axis is derived from len(ctr).
    """
    _, ax = plt.subplots()
    # Use the data's own length: a replay run stops early when its log runs
    # out, so len(ctr) need not equal num_iterations - skip.
    ax.plot(range(1, len(ctr) + 1), ctr)
    ax.set_xlabel('num_iterations', fontsize=14)
    ax.set_ylabel('ctr', fontsize=14)
    ax.set_title(title)
    ax.set_ylim([0, 1])

In [34]:
def plot_ctr_cfe(num_iterations, actions, bandwidths, data, skip=0):
    """Draw a grid of running-CTR curves from a (num_actions x bandwidth) sweep.

    Row r corresponds to actions[r] and column c to bandwidths[c]. Cells whose
    bandwidth is >= the number of actions were not simulated and are titled
    'NA'. Every other cell plots data[str(k)][str(b)], a (ctr, pos) pair as
    returned by run_simulation, titled with the hit rate pos / num_iterations
    as a percentage.

    Args:
        num_iterations: iterations per run, used for the hit rate and the suptitle.
        actions: list of --cats values (one grid row each).
        bandwidths: list of --bandwidth values (one grid column each).
        data: nested dict data[str(k)][str(b)] -> (ctr, pos).
        skip: kept for compatibility; x axes are derived from len(ctr).
    """
    # squeeze=False keeps `axs` 2-D even when there is a single row or column.
    fig, axs = plt.subplots(len(actions), len(bandwidths), squeeze=False, figsize=(30, 18))
    for row, k in enumerate(actions):
        for col, b in enumerate(bandwidths):
            ax = axs[row, col]
            ax.set_xlabel('b: ' + str(b), fontsize=14)
            ax.set_ylabel('k: ' + str(k), fontsize=14)
            if b >= k:
                ax.set_title('NA')
                continue
            ctr, pos = data[str(k)][str(b)]
            hits = (pos / num_iterations) * 100
            # Use the curve's own length: replay runs may stop before num_iterations.
            ax.plot(range(1, len(ctr) + 1), ctr)
            ax.set_title('hits {:.2f}%'.format(hits))
            ax.set_ylim([0, 1])

    fig.text(0.5, 0.04, 'num_iterations', ha='center', fontsize=14)
    fig.text(0.04, 0.5, 'ctr', va='center', rotation='vertical', fontsize=14)
    fig.suptitle('#examples {}'.format(num_iterations))

    # Hide x labels and tick labels for top plots and y ticks for right plots.
    for ax in axs.flat:
        ax.label_outer()

In [62]:
num_iterations = 50000

# Generate the data using 32 actions; this run also writes outfile.json for the sweep below.
ctr, pos = run_simulation(num_iterations, users, times_of_day, get_cost, num_actions=32, bandwidth=1, skip=0, output_file=None, write_to_file=True, do_learn=True)

# Use the logged data to sweep parameters. data[str(k)][str(b)] holds (ctr, pos).
num_actions = [32, 64, 128, 256, 512, 1024, 2048, 4096]
bandwidths = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]
data = {}
for actions in num_actions:
    results_for_actions = data.setdefault(str(actions), {})
    for bd in bandwidths:
        # Combinations with bandwidth >= number of actions are skipped (plot_ctr_cfe marks them 'NA').
        if bd >= actions:
            continue
        ctr, pos = run_simulation(num_iterations, users, times_of_day, get_cost, num_actions=actions, bandwidth=bd, skip=0, output_file="outfile.json", write_to_file=False, do_learn=True)
        results_for_actions[str(bd)] = (ctr, pos)

plot_ctr_cfe(num_iterations, num_actions, bandwidths, data, 0)