In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
from os.path import join
from tqdm import tqdm
import pandas as pd
from os.path import expanduser
import sys
from typing import List
import numpy as np
import joblib
from pprint import pprint
import imodelsx.util
from os.path import dirname
import pickle as pkl
import json
from copy import deepcopy
from numpy.linalg import norm
from math import ceil
from imodelsx.qaemb.qaemb import QAEmb, get_sample_questions_and_examples
from neuro.treebank.config import STORIES_POPULAR, STORIES_UNPOPULAR, ECOG_DIR

In [None]:
# For every word, collect the texts of all words whose 'end' time lies in the
# closed window [end - sec_window, end] that finishes at that word.
story_fname = 'cars-2'
features_df = pd.read_csv(
    join(ECOG_DIR, 'data', 'transcripts', story_fname, 'features.csv'))
sec_window = 3

# pull the columns out once instead of re-masking the whole frame per row
word_ends = features_df['end'].values
word_texts = features_df['text'].values
ngram_list = []
for window_end in tqdm(word_ends):
    in_window = (word_ends >= window_end - sec_window) & (word_ends <= window_end)
    ngram_list.append(word_texts[in_window].tolist())
features_df['ngram'] = ngram_list

### Look at story

In [107]:
# Render the transcript as a readable script: one entry per distinct
# consecutive sentence, prefixed with "\n<speaker>: " only when the speaker
# changes. Missing sentences are dropped *first* so that de-duplication and
# speaker-change detection both operate on the rows that are actually rendered
# (previously a NaN row between two copies of a sentence duplicated it, and a
# speaker change could be detected against a row that was later dropped).
script_df = features_df[features_df['sentence'].notna()]

# keep only the first row of each run of identical consecutive sentences;
# .copy() so the column assignments below act on an independent frame
script_df = script_df.loc[
    script_df['sentence'].shift() != script_df['sentence']].copy()

# label a sentence with its speaker only when the speaker differs from the
# previous rendered sentence; otherwise use an empty prefix
duplicate_speaker = script_df['speaker'].shift() == script_df['speaker']
script_df['speaker'] = ('\n<' + script_df['speaker'] + '>: ').where(
    ~duplicate_speaker, '')
script_df['script'] = script_df['speaker'] + script_df['sentence']

# only the first 500 sentences are joined into the preview
story = ' '.join(script_df['script'].iloc[:500])

In [None]:
# print (rather than a bare `story` expression) so the '\n' inserted before
# each '<speaker>: ' tag renders as a real line break instead of an escaped repr
print(story, end='\n')

### Split stories into popular (watched by the common subjects) and unpopular

In [2]:
def _load_title(path: str) -> str:
    """Return the 'title' field of one subject-metadata JSON file."""
    # context manager closes the handle; the old inline open() never did
    with open(path) as f:
        return json.load(f)['title']


metadata_dir = join(ECOG_DIR, 'data', 'subject_metadata')
subject_metadata_files = os.listdir(metadata_dir)
jsons = {f: _load_title(join(metadata_dir, f)) for f in subject_metadata_files}

# build the frame with explicit column names (also well-formed when empty)
metadata_df = pd.DataFrame({
    'filename': list(jsons.keys()),
    'title': list(jsons.values()),
})
# subject id = last '_'-separated token before '_trial' in the filename
metadata_df['subject'] = metadata_df['filename'].apply(
    lambda x: x.split('_trial')[0].split('_')[-1]).astype(int)
metadata_df = metadata_df.sort_values(by='title')

# NOTE(review): meaning of this subject set is not documented here - confirm.
common_subjs = {3, 4, 6, 7, 10}
stories_popular = metadata_df[
    metadata_df.subject.isin(common_subjs)].title.unique()
# NOTE(review): a title watched by both a common and a non-common subject
# lands in both arrays; also compare with STORIES_POPULAR/STORIES_UNPOPULAR
# imported from neuro.treebank.config.
stories_unpopular = metadata_df[
    ~metadata_df.subject.isin(common_subjs)].title.unique()
print(f'{len(stories_popular)=} {len(stories_unpopular)=}')
stories_popular, stories_unpopular

len(stories_popular)=9 len(stories_unpopular)=12


(array(['Cars 2', 'Coraline', 'Lord Of The Rings 1', 'Lord Of The Rings 2',
        'Megamind', 'Shrek The Third', 'Spiderman Far From Home',
        'The Incredibles', 'Toy Story'], dtype=object),
 array(['Antman', 'Aquaman', 'Avengers Infinity War', 'Black Panther',
        'Fantastic Mr. Fox', 'Guardians Of The Galaxy 2',
        'Guardians Of the Galaxy', 'Sesame Street Episode 3990',
        'Spiderman Homecoming', 'The Martian', 'Thor Ragnarok', 'venom'],
       dtype=object))

### Save ensemble features (per-story mean over the ensemble's checkpoints)

In [2]:
# setting = 'words'

# output_dir_ensemble = join(ECOG_DIR, 'features', checkpoint_clean, setting)
# output_dir_raw = join(ECOG_DIR, 'features_raw', checkpoint_clean, setting)
# os.makedirs(output_dir_ensemble, exist_ok=True)
# os.makedirs(output_dir_raw, exist_ok=True)

In [None]:
# Average the cached per-story features of every checkpoint in `ensemble1`
# and save the mean (pickle + CSV) under features/<out_checkpoint>/<setting>/.
out_checkpoint = 'ensemble1'
ensemble1 = [
    'mistralai/Mistral-7B-Instruct-v0.2',
    'meta-llama/Meta-Llama-3-8B-Instruct',
    'google/gemma-7b-it',
]


def _load_checkpoint_feats(checkpoint: str, setting: str) -> dict:
    """Load every cached ``.pkl`` file for one checkpoint/setting.

    Returns a dict mapping the pickle filename to the object loaded from it.
    """
    checkpoint_clean = checkpoint.replace('/', '___')
    feats_dir = join(ECOG_DIR, 'features', checkpoint_clean, setting)
    return {
        fname: joblib.load(join(feats_dir, fname))
        for fname in os.listdir(feats_dir)
        if fname.endswith('.pkl')
    }


def _average_feats(dfs: List[pd.DataFrame]) -> pd.DataFrame:
    """Element-wise mean of feature frames after explicit alignment.

    Every frame is reordered to the first frame's columns (a missing column
    raises ``KeyError``) and must have exactly the first frame's index, so the
    sum never silently reorders columns or introduces NaNs via label alignment.
    """
    ref = dfs[0]
    aligned = []
    for member in dfs:
        member = member[ref.columns]
        assert member.index.equals(ref.index), \
            'ensemble feature frames have different indexes'
        aligned.append(member.astype(float))
    # left-to-right sum, same evaluation order as the old (a + b + c) / 3
    return sum(aligned[1:], aligned[0]) / len(aligned)


for setting in ['words', 'sec_3']:
    output_dir_ensemble = join(ECOG_DIR, 'features', out_checkpoint, setting)
    os.makedirs(output_dir_ensemble, exist_ok=True)

    # checkpoint -> {story pickle filename -> features}
    ensemble_checkpoint_story_dict = {
        checkpoint: _load_checkpoint_feats(checkpoint, setting)
        for checkpoint in tqdm(ensemble1)
    }

    # only stories that every checkpoint has features for can be averaged
    common_stories = set.intersection(
        *[set(story_dict) for story_dict in ensemble_checkpoint_story_dict.values()]
    )
    for story_fname in tqdm(sorted(common_stories)):
        avg_df = _average_feats(
            [ensemble_checkpoint_story_dict[checkpoint][story_fname]
             for checkpoint in ensemble1])

        # save
        avg_df.to_pickle(join(output_dir_ensemble, story_fname))
        avg_df.to_csv(join(output_dir_ensemble,
                           story_fname.replace('.pkl', '.csv')))
    print('avg feats', output_dir_ensemble, os.listdir(output_dir_ensemble))

In [None]:
!ls /home/chansingh/mntv1/ecog/features/ensemble1/

In [None]:
!rclone copy /home/chansingh/mntv1/ecog/features/ensemble1/ box:DeepTune/QA/cached_qa_tree --progress