In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
from os.path import join
from tqdm import tqdm
import pandas as pd
from os.path import expanduser
import sys
from typing import List
import numpy as np
import joblib
from pprint import pprint
import imodelsx.util
from os.path import dirname
import pickle as pkl
import json
from copy import deepcopy
from numpy.linalg import norm
from math import ceil
import questions
from imodelsx.qaemb.qaemb import QAEmb, get_sample_questions_and_examples

from questions import QS_O1_DEC26
from treebank import STORIES_POPULAR, STORIES_UNPOPULAR, ECOG_DIR


## Get popular stories

A story counts as "popular" if its title appears in the metadata of any subject in `common_subjs`.

In [None]:
subject_metadata_dir = join(ECOG_DIR, 'data', 'subject_metadata')
subject_metadata_files = os.listdir(subject_metadata_dir)


def _load_title(fname: str) -> str:
    """Return the 'title' field of one subject-metadata JSON file."""
    # `with` guarantees the handle is closed (a bare open() inside json.load leaks it).
    with open(join(subject_metadata_dir, fname)) as f:
        return json.load(f)['title']


jsons = {f: _load_title(f) for f in subject_metadata_files}
df = (
    pd.DataFrame({'filename': list(jsons.keys()), 'title': list(jsons.values())})
    # Filenames presumably look like '<...>_<subject>_trial<...>': take the
    # last '_'-separated token before '_trial' as the integer subject id.
    .assign(subject=lambda d: d['filename'].apply(
        lambda x: x.split('_trial')[0].split('_')[-1]).astype(int))
    .sort_values(by='title')
)

# Subjects whose stories define the "popular" set.
common_subjs = {3, 4, 6, 7, 10}
is_common_subj = df.subject.isin(common_subjs)
stories_popular = df[is_common_subj].title.unique()
# NOTE(review): a title listed for both a common and a non-common subject ends
# up in BOTH arrays — confirm that overlap is intended.
stories_unpopular = df[~is_common_subj].title.unique()
print(f'{len(stories_popular)=} {len(stories_unpopular)=}')
stories_popular, stories_unpopular

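## Extract features

For every story in `stories_to_run`, answer each question in `qs_to_run` for every word with `QAEmb`. Raw per-(story, question) answers are cached in `output_dir_raw`, and one combined table per story is written to `output_dir_clean`.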

In [193]:
# --- run configuration: what to embed and with which model ---
stories_to_run = STORIES_POPULAR
qs_to_run = QS_O1_DEC26
checkpoint = 'meta-llama/Meta-Llama-3-8B-Instruct'
# '/' can't appear in a directory name, so encode it.
checkpoint_clean = checkpoint.replace('/', '___')
setting = 'words'

# Available transcript folder names (story names are checked against these).
transcript_folders = os.listdir(join(ECOG_DIR, 'data', 'transcripts'))

# Per-story tables live under features/, per-question caches under features_raw/.
output_dir_clean, output_dir_raw = (
    join(ECOG_DIR, top_level, checkpoint_clean, setting)
    for top_level in ('features', 'features_raw')
)
for out_dir in (output_dir_clean, output_dir_raw):
    os.makedirs(out_dir, exist_ok=True)

In [None]:
# A single embedder is reused for every question: the extraction loop assigns
# `qa_embedder.questions = [q]` before each call, so this initial question is
# only a placeholder. It is taken from `qs_to_run` (not a hard-coded list) so
# it stays in sync with the config cell.
qa_embedder = QAEmb(
    questions=[qs_to_run[0]],
    checkpoint=checkpoint,
    batch_size=512,
    # Alternative value used previously: expanduser("~/cache_qa_ecog").
    CACHE_DIR=None,
)

In [None]:
def get_texts(features_df: pd.DataFrame, setting: str = 'words',
              replace_nan_with_empty_string: bool = True):
    """Extract the per-row texts to embed from a transcript features table.

    Args:
        features_df: table with a 'text' column, one row per unit.
        setting: granularity of the texts; only 'words' is supported.
        replace_nan_with_empty_string: if True, every non-string entry
            (e.g. NaN from an empty CSV cell) is replaced by the literal
            two-character string '""' and a list is returned; otherwise the
            flattened numpy array of the 'text' column is returned as-is.

    Raises:
        ValueError: if `setting` is not supported (previously this surfaced
            as an opaque UnboundLocalError).
    """
    if setting != 'words':
        raise ValueError(f'Unsupported setting: {setting!r}')
    texts = features_df['text'].values.flatten()
    if replace_nan_with_empty_string:
        texts = [t if isinstance(t, str) else '""' for t in texts]
    return texts


def _story_to_fname(story: str) -> str:
    """Map a story title to its transcript folder name."""
    return (
        story.replace(' ', '-').lower()
        .replace('lord-of-the-rings', 'lotr')
        .replace('spiderman', 'spider-man')
        .replace('the-incredibles', 'incredibles')
        .replace('antman', 'ant-man')
        .replace('mr.', 'mr')
        # Must come after 'spiderman' -> 'spider-man' above.
        .replace('spider-man-homecoming', 'spider-man-3-homecoming')
    )


# Spot-check question: texts with a positive answer should be names.
NAME_QUESTION = 'Does the text reference a person’s name?'

for story in tqdm(stories_to_run, desc='stories'):
    story_fname = _story_to_fname(story)
    assert story_fname in transcript_folders, f'{story_fname} not found'
    features_df = pd.read_csv(
        join(ECOG_DIR, 'data', 'transcripts', story_fname, 'features.csv'))

    # Computed once per story, before the cache check, so `texts` is always
    # defined and always belongs to THIS story when used as the index below
    # (previously it was only set on a cache miss).
    texts = get_texts(features_df, setting=setting)
    assert len(texts) == len(features_df), f'{len(texts)=} {len(features_df)=}'

    answers_dict = {}
    for q in tqdm(qs_to_run, desc='question', leave=False):
        output_file_q = join(output_dir_raw, f'{story_fname}___{q}.pkl')
        if os.path.exists(output_file_q):
            answers_dict[q] = joblib.load(output_file_q)
            # tqdm.write keeps the progress bars intact.
            tqdm.write(f'Loaded {output_file_q}')
        else:
            qa_embedder.questions = [q]
            answers = qa_embedder(
                texts, speed_up_with_unique_calls=True).flatten()
            joblib.dump(answers, output_file_q)
            answers_dict[q] = answers

    answers_df = pd.DataFrame(answers_dict, index=texts)
    answers_df.to_pickle(join(output_dir_clean, f'{story_fname}.pkl'))
    answers_df.to_csv(join(output_dir_clean, f'{story_fname}.csv'))

    # Spot check on this story's answers (not the metadata `df`); skipped when
    # the name question isn't among the questions run.
    if NAME_QUESTION in answers_df.columns:
        print('these should be names',
              list(answers_df.index[answers_df[NAME_QUESTION] > 0]))