In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import numpy as np
import os

In [2]:
# trader_model_name values treated as human/automated subjects (see df_player_processor).
SUBJECT_ROLES = [
    'automated',
    'out',
    'manual',
]

# Agent-record columns multiplied by K_SCALE in df_player_processor.
TRADER_SCALED_FIELDS = [
    'bid', 'offer',
    'staged_bid', 'staged_offer',
    'net_worth', 'cash',
    'implied_bid', 'implied_offer',
    'best_bid_except_me', 'best_offer_except_me',
]

# Market-record columns multiplied by K_SCALE in df_market_processor.
MARKET_SCALED_FIELDS = [
    'best_bid', 'best_offer',
    'next_bid', 'next_offer',
    'e_best_bid', 'e_best_offer',
    'reference_price',
]

# Multiplier applied to the scaled fields above.
K_SCALE = 1e-4

# NOTE(review): presumably sentinel "no bid" / "no offer" values; not used in the
# visible cells — confirm before relying on them.
MIN_BID = 0
MAX_OFFER = 2 ** 31 - 1  # largest signed 32-bit integer (2147483647)

# Output folder per (subsession_id, session start time) and the PNG path inside it.
FOLDER_NAME_BASE = 'results/subsession_%s_%s'
REPORT_FILENAME_BASE = os.path.join(FOLDER_NAME_BASE, '%s.png')


# helpers
def category_filter(df, cat, col_name='player_id'):
    """Return the rows of ``df`` whose ``col_name`` column equals ``cat``.

    The original index labels are kept; ``df`` itself is not modified.
    """
    return df[df[col_name].eq(cat)]

def get_files_by_session_id(session_id, data_dir='data'):
    """Locate the market and agent CSV files for one session.

    Assumes file names of the form
    ``'{session_id}_{record_class}_accessed_{timestamp}.csv'`` inside ``data_dir``.

    Parameters
    ----------
    session_id : session identifier that prefixes the file names.
    data_dir : directory to search (default ``'data'``, relative to the CWD).

    Returns
    -------
    tuple of str
        ``(market_file, agent_file)`` — bare file names, not joined with ``data_dir``.

    Raises
    ------
    AssertionError
        If there is not exactly one market file and one agent file.
    """
    # Require the '_' separator so session 's1' does not also match 's10_...'.
    prefix = '{0}_'.format(session_id)
    session_files = [f for f in os.listdir(data_dir) if f.startswith(prefix)]
    # Look for the record class only after the prefix, so a session id that
    # itself contains 'market' or 'agent' cannot match every file.
    market_files = [f for f in session_files if 'market' in f[len(prefix):]]
    agent_files = [f for f in session_files if 'agent' in f[len(prefix):]]
    assert len(agent_files) == len(market_files) == 1, (
        'expected exactly one market and one agent file for session {0!r} in {1!r}, '
        'found market={2} agent={3}'.format(session_id, data_dir, market_files, agent_files))
    session_files = (market_files[0], agent_files[0])
    print('session files: {0}'.format(session_files))
    return session_files
    
def extract_session_ids(df):
    """Return ``(subsession_id, market_ids)`` for a session frame.

    ``subsession_id`` is taken from the first row; ``market_ids`` are the unique
    values of the ``market_id`` column in order of appearance.

    Raises IndexError if ``df`` is empty.
    """
    # Read the column first, then the row. ``df.iloc[0]`` would build a row Series
    # with one common dtype, so in an all-numeric frame (e.g. int ids plus the
    # float timestamps from df_processor) the id would be upcast: 7 -> 7.0.
    subsession_id = df['subsession_id'].iloc[0]
    market_ids = df['market_id'].unique()
    return (subsession_id, market_ids)


def make_dir(subsession_id, ts):
    """Create the output folder for one subsession report and return its path.

    The path is ``FOLDER_NAME_BASE % (subsession_id, ts)``. Missing parent
    directories (e.g. ``results/``) are created too. An already-existing folder
    is not an error.

    Raises OSError for real failures (e.g. permission denied). The previous
    ``os.mkdir`` + ``except OSError: pass`` hid a missing parent directory, so the
    error only surfaced later in ``savefig``.
    """
    path = FOLDER_NAME_BASE % (subsession_id, ts)
    os.makedirs(path, exist_ok=True)
    return path
    

def extract_date(df, ts_col_name='timestamp'):
    """Return the first value of column ``ts_col_name`` converted to ``str``.

    Raises IndexError when ``df`` has no rows.
    """
    first_value = df[ts_col_name].iloc[0]
    return str(first_value)

def get_nice_tag_for_market_id(market_id):
    """Map a market id to a human-readable tag: 0 -> 'focal', anything else -> 'external'."""
    if market_id == 0:
        return 'focal'
    return 'external'
    

def df_player_processor(df, roles_to_include=SUBJECT_ROLES):
    """Keep only subject traders and rescale their value columns.

    Rows whose ``trader_model_name`` is in ``roles_to_include`` are kept, and the
    columns in TRADER_SCALED_FIELDS are multiplied by K_SCALE. The input
    frame is not modified; a new frame is returned.
    """
    is_subject = df['trader_model_name'].isin(roles_to_include)
    # Filter first, then copy: only the surviving rows are duplicated, and the
    # result is an independent frame that is safe to assign into.
    subjects = df.loc[is_subject].copy()
    subjects[TRADER_SCALED_FIELDS] = K_SCALE * subjects[TRADER_SCALED_FIELDS]
    return subjects

def df_market_processor(df, market_id):
    """Return the rows belonging to ``market_id`` with market value columns rescaled.

    The columns in MARKET_SCALED_FIELDS are multiplied by K_SCALE. The input
    frame is not modified; a new frame is returned.
    """
    in_market = df['market_id'] == market_id
    # Filter first, then copy, so the returned frame is independent of ``df``.
    market_rows = df.loc[in_market].copy()
    market_rows[MARKET_SCALED_FIELDS] = K_SCALE * market_rows[MARKET_SCALED_FIELDS]
    return market_rows


def df_processor(df, ts_col_name='timestamp'):
    """Return a copy of ``df`` with ``ts_col_name`` as seconds since the first row.

    The column is parsed with ``pd.to_datetime``. Each value is replaced by its
    offset, in seconds, from the value in the first row. If ``df`` is empty, the
    copy is returned with that column unconverted.
    """
    result = df.copy()
    stamps = pd.to_datetime(result[ts_col_name])
    if stamps.shape[0] == 0:
        return result
    # Offsets are relative to the first row, not the minimum.
    # NOTE(review): assumes rows are in time order — confirm.
    origin = stamps.iloc[0]
    result[ts_col_name] = (stamps - origin).dt.total_seconds()
    return result

In [3]:
# plotting
# Global matplotlib style (updates rcParams) for every figure drawn by the helpers below.
plt.style.use('fivethirtyeight')

def line_axis_builder(ax, ts_index_column, y, legend_labels, y_axis_title, y_axis_limit, x_tick_step=1):
    """Draw one step line per column of ``y`` against ``ts_index_column`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    ts_index_column : sequence of x values. In this notebook these are seconds
        since the first record, as produced by ``df_processor``.
    y : DataFrame; each column becomes one step series.
    legend_labels : legend labels, matched to the columns of ``y`` by position.
    y_axis_title : y-axis label.
    y_axis_limit : ``(bottom, top)`` pair passed to ``set_ylim``. ``None`` keeps
        matplotlib's autoscaling.
    x_tick_step : spacing between x ticks (default 1, as before). Increase it
        for long sessions, where one tick per unit makes the labels unreadable.
    """
    for i, col in enumerate(y):
        ax.step(ts_index_column, y[col], linewidth=2, label=legend_labels[i], alpha=0.7)
    ax.legend(fontsize='medium')
    ax.set_xlabel('Time')
    if y_axis_limit is not None:
        bot_y_lim, top_y_lim = y_axis_limit
        ax.set_ylim(bottom=bot_y_lim, top=top_y_lim)
    ax.set_ylabel(y_axis_title, fontsize=18)
    ax.tick_params(axis='x', which='both', labelsize=12, labelbottom=True)
    ax.tick_params(axis='y', which='both', labelsize=16)
    # max() of an empty sequence raises ValueError; with no data there is nothing to tick.
    if len(ts_index_column) > 0:
        ax.set_xticks(np.arange(0, int(max(ts_index_column)) + 1, x_tick_step))
    ax.grid(linestyle='-', linewidth=2, alpha=0.2)
    
def exchange_event_map_builder(ax, ts_col_name, df):
    """Scatter exchange events over time on ``ax``, one marker style per message type.

    Rows are grouped by ``trigger_msg_type``: 'A' (black '+'), 'U' (cyan '^'),
    'E' (red 5-point star). They are drawn in that order. The y value is the
    message-type letter itself, so the y axis shows one row per type.
    """
    event_styles = (
        ('A', '+', 'black'),
        ('U', '^', 'cyan'),
        ('E', (5, 2), 'red'),
    )
    msg_types = df['trigger_msg_type']
    for code, marker_style, colour in event_styles:
        events = df[msg_types == code]
        ax.scatter(events[ts_col_name], events['trigger_msg_type'], s=100, marker=marker_style, c=colour)
    ax.set_ylabel('Exchange Events', fontsize=18)
    ax.tick_params(axis='y', which='both', labelsize=20)
        
        

FIGURE_NAME_FORMAT = 'Subsession {session_id}:{market_id} - {model_name} {model_id} Report'
def hft_report(source_filename, model_name, plots_meta, figsize=(40, 16), dpi=100, ts_col_name='timestamp'):
    """Read one session CSV and save/show a report figure per market or per trader.

    Parameters
    ----------
    source_filename : path of the CSV to read with ``pd.read_csv``.
    model_name : ``'Market'`` gives one figure per unique ``market_id``.
        ``'Trader'`` gives one figure per player (subject roles only). Any
        other value produces no figures.
    plots_meta : dict consumed by ``ts_plotter`` ('y_groups', 'legend_labels',
        'y_titles', 'y_limits'). For ``'Trader'`` it also needs
        'player_id_column_name'.
    figsize : forwarded to ``ts_plotter``.
    dpi : currently unused (kept for backward compatibility).
    ts_col_name : timestamp column. It is used for the session label, the
        seconds-since-start conversion, and the x axis.

    Figures are written under ``REPORT_FILENAME_BASE``. A CSV with no rows
    produces nothing.
    """
    df = pd.read_csv(source_filename)
    # Check for emptiness before extract_date: it indexes the first row and
    # would raise IndexError on an empty frame.
    if df.empty:
        return
    # NOTE(review): the raw first timestamp becomes part of the folder name; it
    # may contain ':' or spaces, which some filesystems reject.
    session_time = extract_date(df, ts_col_name=ts_col_name)
    df = df_processor(df, ts_col_name=ts_col_name)
    subsession_id, market_ids = extract_session_ids(df)
    make_dir(subsession_id, session_time)
    if model_name == 'Market':
        for market_id in market_ids:
            df_market = df_market_processor(df, market_id)
            market_tag = get_nice_tag_for_market_id(market_id)
            # NOTE(review): model_id repeats the market tag. An earlier version
            # set model_id = market_id but never used it; confirm the intended title.
            fig_title = FIGURE_NAME_FORMAT.format(model_name=model_name,
                                                  market_id=market_tag,
                                                  model_id=market_tag,
                                                  session_id=subsession_id)
            fig_path = REPORT_FILENAME_BASE % (subsession_id, session_time, fig_title)
            ts_plotter(df_market, plots_meta, fig_title, figsize, ts_col_name, save_path=fig_path)
    elif model_name == 'Trader':
        player_col = plots_meta['player_id_column_name']
        df_sub = df_player_processor(df)
        for player_id in df_sub[player_col].unique():
            df_player = category_filter(df_sub, player_id, col_name=player_col)
            # Column-first access keeps market_id's own dtype (no row-level upcasting).
            market_id = df_player['market_id'].iloc[0]
            fig_title = FIGURE_NAME_FORMAT.format(model_name=model_name, market_id=market_id, model_id=player_id,
                                                  session_id=subsession_id)
            fig_path = REPORT_FILENAME_BASE % (subsession_id, session_time, fig_title)
            try:
                ts_plotter(df_player, plots_meta, fig_title, figsize, ts_col_name, save_path=fig_path, exchange_events=True)
            except Exception as e:
                # Best-effort: one bad player should not abort the rest of the report,
                # but say which player failed so the problem can be found.
                print('exception while plotting player {0}: {1!r}'.format(player_id, e))


def ts_plotter(df, plots_meta, fig_title, figsize, ts_col_name, save_path='figure', exchange_events=False):
    """Draw a vertically stacked, shared-x time-series figure, save it, and show it.

    Parameters
    ----------
    df : frame holding ``ts_col_name`` and every column named in ``plots_meta['y_groups']``.
    plots_meta : dict of parallel per-panel lists:
        'y_groups' (column lists), 'legend_labels', 'y_titles', 'y_limits'.
    fig_title : figure suptitle.
    figsize : passed to ``plt.subplots``.
    ts_col_name : x-axis column.
    save_path : destination passed to ``Figure.savefig``.
    exchange_events : if True, add a bottom panel drawn by ``exchange_event_map_builder``.

    The figure is always closed afterwards, even if drawing fails. This stops
    figures from piling up when a caller catches the error and keeps looping.
    """
    y_groups = plots_meta['y_groups']
    num_axes = len(y_groups) + (1 if exchange_events else 0)
    # squeeze=False always returns a 2-D array. Without it, a single panel comes
    # back as a bare Axes and ``axes[ix]`` raises TypeError.
    fig, axes = plt.subplots(num_axes, 1, figsize=figsize, sharex=True, squeeze=False)
    axes = axes[:, 0]
    try:
        fig.suptitle(fig_title, fontsize=32)
        for ix, columns_set in enumerate(y_groups):
            line_axis_builder(axes[ix], df[ts_col_name], df[columns_set], plots_meta['legend_labels'][ix],
                              plots_meta['y_titles'][ix], plots_meta['y_limits'][ix])
        if exchange_events:
            exchange_event_map_builder(axes[-1], ts_col_name, df)
        fig.savefig(save_path)
        plt.show()
    finally:
        plt.close(fig)
    
