In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join
from tqdm import tqdm
import pandas as pd
import sys
from IPython.display import display, HTML
from typing import List
from sasc.modules.emb_diff_module import EmbDiffModule
import numpy as np
import matplotlib
import imodelsx.util
from copy import deepcopy
import re
import notebook_helper
import sasc.viz
import scipy.special
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
import pandas as pd
from sasc.config import RESULTS_DIR
import torch.cuda
import json
from sklearn.linear_model import RidgeCV
# EXPT_NAME in ['huth2016clusters_mar21_i_time_traveled', 'voxels_mar21_hands_arms_emergency']:
# EXPT_NAME  = 'huth2016clusters_mar21_i_time_traveled'
# EXPT_DIR = join(RESULTS_DIR, 'stories', EXPT_NAME)
# Experiment directories to process: one per (prompt version, seed) pair,
# ordered with version as the outer loop and seed as the inner loop.
# Alternative naming scheme previously used:
#   f'uts02_concepts_pilot_selected_mar28___ver={version}___seed={seed}'
EXPT_NAMES = [
    f'uts02_pilot_gpt4_mar28___ver={version}___seed={seed}'
    for version in ('v4_noun', 'v5_noun')
    for seed in range(1, 8)  # seeds 1..7 inclusive
]

In [None]:
def process_timings(df: pd.DataFrame, max_timing: float = 0.85) -> pd.DataFrame:
    """Cap per-word timings and impute the missing ones.

    Adds the feature columns ``word_len``, ``ends_in_period`` and
    ``ends_in_comma``, caps ``timing`` at ``max_timing``, and, when any
    ``timing`` value is missing, fills it with the prediction of a ``RidgeCV``
    model fit on the rows whose timing is present. In that case
    ``time_running`` is overwritten with the cumulative sum of the filled
    timings.

    Parameters
    ----------
    df : pd.DataFrame
        Needs a ``word`` column of strings and a numeric ``timing`` column in
        which NaN marks a missing value. The frame is modified in place.
    max_timing : float, default 0.85
        Upper bound applied to ``timing`` before the model is fit.

    Returns
    -------
    pd.DataFrame
        The same ``df`` object with the columns above added / updated.

    Raises
    ------
    ValueError
        If some timings are missing but no row has an observed timing, so
        there is nothing to fit the imputation model on.
    """
    df['word_len'] = df['word'].apply(len)
    df['ends_in_period'] = df['word'].str.endswith('.')
    df['ends_in_comma'] = df['word'].str.endswith(',')

    # truncate values that are too large (NaNs pass through clip unchanged)
    df['timing'] = df['timing'].clip(upper=max_timing)

    # Explicit float cast: the raw columns mix int and bool, which would
    # otherwise yield an object-dtype array.
    feature_cols = ['word_len', 'ends_in_period', 'ends_in_comma']
    X = df[feature_cols].to_numpy(dtype=float)
    y = df['timing'].to_numpy()
    missing = pd.isna(y)
    print('n', y.size, 'n_nan', np.sum(missing))

    # Nothing to impute: skip the model fit entirely. `time_running` is left
    # as it came in, since it is only rebuilt when timings were imputed.
    if not missing.any():
        return df

    observed = ~missing
    if not observed.any():
        raise ValueError('cannot impute timings: every timing value is missing')

    # fill na values with linreg
    m = RidgeCV()
    m.fit(X[observed], y[observed])
    # Single .loc assignment instead of chained indexing
    # (df['timing'][mask] = ...), which does not reliably write back into df.
    df.loc[missing, 'timing'] = m.predict(X[missing])

    # recompute running time from the now-complete timings
    df['time_running'] = df['timing'].cumsum()
    return df

# Run process_timings over every experiment directory and write the result
# next to the input as timings_processed.csv. Experiments without a
# timings.csv, or whose processing raises, are reported and counted as skipped.
n_skipped, n_processed = 0, 0
for EXPT_NAME in EXPT_NAMES:
    EXPT_DIR = join(RESULTS_DIR, 'stories', EXPT_NAME)
    timings_file = join(EXPT_DIR, 'timings.csv')

    if os.path.exists(timings_file):
        print('processing', EXPT_NAME)
        df = pd.read_csv(timings_file)
        print('shape', df.shape)
        try:
            df = process_timings(df)
            df.to_csv(join(EXPT_DIR, 'timings_processed.csv'), index=False)
        except Exception as e:
            # Best effort: report the failure and move on to the next experiment.
            print('error', e)
            print('skipping', EXPT_NAME)
            n_skipped += 1
        else:
            n_processed += 1
    else:
        # No timings file for this experiment.
        print('skipping', EXPT_NAME)
        n_skipped += 1

print('n_skipped', n_skipped, 'n_processed', n_processed)

In [None]:
# Tag each word by its trailing punctuation; the tag is used as the plot hue.
# `df` is whatever frame the processing loop left behind.
df['ends_in'] = 'Normal'
for flag_col, tag in (('ends_in_period', 'Period'), ('ends_in_comma', 'Comma')):
    df.loc[df[flag_col], 'ends_in'] = tag

# Word timing against word length, coloured by punctuation tag.
sns.lmplot(
    data=df,
    x='word_len',
    y='timing',
    hue='ends_in',
    markers='.',
    legend=False,
    aspect=1.5,
)
ax = plt.gca()
ax.set_xlabel('Word length (characters)')
ax.set_ylabel('Word timing (sec)')

# Horizontal reference line at the mean timing, with a text label just below it.
mean_timing = np.mean(df['timing'])
ax.axhline(mean_timing, color='gray')
ax.text(16, mean_timing - 0.02, f'Mean: {mean_timing:.2f} sec', color='gray')

# lmplot's own legend is turned off above; attach one to the axes instead,
# with label text coloured to match each entry.
handles, labels = ax.get_legend_handles_labels()
ax.legend(bbox_to_anchor=(1, 0.7), labelcolor='linecolor')
plt.show()