In [1]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
from os.path import join
from tqdm import tqdm
import pandas as pd
from os.path import expanduser
import sys
from typing import List
import numpy as np
import joblib
from pprint import pprint
import imodelsx.util
from os.path import dirname
import pickle as pkl
import json
from copy import deepcopy
from numpy.linalg import norm
from math import ceil
from imodelsx.qaemb.qaemb import QAEmb, get_sample_questions_and_examples
from neuro.treebank.config import STORIES_POPULAR, STORIES_UNPOPULAR, ECOG_DIR

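### Save ensemble features

For every setting, average each model's cached question-answer features story by story, and save the result under the `ensemble1` checkpoint name.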

In [None]:
# Build "ensemble" features: for every setting, average the cached per-story
# question-answer features of all models in `ensemble1` and save the mean
# (as .csv and .pkl) under the `out_checkpoint` name.
# suffix_qs = ''
suffix_qs = '___qs_35_stable'

settings = ['words', 'sec_1.5', 'sec_3', 'sec_6']
# settings = ['sec_1.5', 'sec_6']
# settings = ['words', 'sec_3']
out_checkpoint = 'ensemble1'
ensemble1 = [
    'mistralai/Mistral-7B-Instruct-v0.2',
    'meta-llama/Meta-Llama-3-8B-Instruct',
    'google/gemma-7b-it',
]


def _average_story_dfs(story_dfs):
    """Element-wise mean of per-model feature frames for one story.

    Every frame is reordered to the first frame's columns. A question missing
    from any frame raises KeyError. Extra columns in later frames are dropped.
    Every frame must then have exactly the same index (asserted). This keeps the
    sum a plain element-wise mean instead of a pandas label-alignment join,
    which could silently misbehave with duplicate index labels.

    Parameters
    ----------
    story_dfs : list of pd.DataFrame
        One frame per checkpoint, in checkpoint order. Presumably rows are text
        chunks and columns are questions (TODO confirm for all settings).

    Returns
    -------
    pd.DataFrame
        Float frame with the first frame's index and column order.
    """
    assert len(story_dfs) > 0, 'need at least one frame to average'
    ref_df = story_dfs[0]
    aligned_dfs = []
    for story_df in story_dfs:
        story_df = story_df[ref_df.columns]
        assert story_df.columns.equals(ref_df.columns)
        assert story_df.index.equals(ref_df.index)
        aligned_dfs.append(story_df.astype(float))

    # accumulate left-to-right: same summation order as (s1 + s2 + s3)
    total_df = aligned_dfs[0]
    for story_df in aligned_dfs[1:]:
        total_df = total_df + story_df
    return total_df / len(aligned_dfs)


assert len(ensemble1) > 0, 'ensemble1 must list at least one checkpoint'

for setting in settings:
    print(setting)
    output_dir_ensemble = join(
        ECOG_DIR, f'features{suffix_qs}', out_checkpoint, setting)
    os.makedirs(output_dir_ensemble, exist_ok=True)

    # read in every model's cached per-story feature frames
    # (joblib.load unpickles: only point this at trusted, locally generated caches)
    ensemble_checkpoint_story_dict = {}
    for checkpoint in tqdm(ensemble1):
        checkpoint_clean = checkpoint.replace('/', '___')
        output_dir_clean = join(ECOG_DIR, f'features{suffix_qs}',
                                checkpoint_clean, setting)
        ensemble_checkpoint_story_dict[checkpoint] = {
            story_fname: joblib.load(join(output_dir_clean, story_fname))
            for story_fname in os.listdir(output_dir_clean)
            if story_fname.endswith('.pkl')
        }

    # only stories that every model has features for can be averaged
    common_stories = set.intersection(
        *[set(ensemble_checkpoint_story_dict[checkpoint])
          for checkpoint in ensemble1]
    )
    print('\tsaving avg feats for', len(common_stories), 'stories')
    for story_fname in tqdm(sorted(common_stories)):
        out_fname_pkl = join(output_dir_ensemble, story_fname)
        if os.path.exists(out_fname_pkl):
            continue  # already cached; delete the .pkl to recompute

        avg_df = _average_story_dfs(
            [ensemble_checkpoint_story_dict[checkpoint][story_fname]
             for checkpoint in ensemble1])

        # save (csv for inspection, pkl as the cache marker checked above)
        out_fname_csv = join(output_dir_ensemble,
                             os.path.splitext(story_fname)[0] + '.csv')
        avg_df.to_csv(out_fname_csv)
        avg_df.to_pickle(out_fname_pkl)

    print('\tavg feats', output_dir_ensemble, os.listdir(output_dir_ensemble))

In [None]:
# List the cached ensemble feature directory (one subdirectory per setting).
# NOTE(review): hardcoded absolute path; presumably equal to
# join(ECOG_DIR, 'features___qs_35_stable', 'ensemble1') — confirm and switch
# to IPython interpolation, e.g. `!ls {path}`.
# !ls /home/chansingh/mntv1/ecog/features/ensemble1/
!ls /home/chansingh/mntv1/ecog/features___qs_35_stable/ensemble1/

In [None]:
# Total number of rows across all cached ensemble 'words' feature frames.
# NOTE(review): hardcoded absolute path; presumably
# join(ECOG_DIR, 'features___qs_35_stable', 'ensemble1', 'words') — confirm.
d = '/home/chansingh/mntv1/ecog/features___qs_35_stable/ensemble1/words'
pkl_fnames = [fname for fname in os.listdir(d) if fname.endswith('.pkl')]
n = sum(len(pd.read_pickle(join(d, fname))) for fname in pkl_fnames)
n

In [None]:
# Upload the cached ensemble features to Box with rclone. `box:` is an rclone
# remote name and must already be configured on this machine.
# NOTE(review): hardcoded absolute local path — confirm it matches the
# ensemble output directory before uploading.
# !rclone copy /home/chansingh/mntv1/ecog/features/ensemble1/ box:DeepTune/QA/cached_qa_tree --progress
!rclone copy /home/chansingh/mntv1/ecog/features___qs_35_stable/ensemble1/ box:DeepTune/QA/cached_qa_tree___qs_35_stable --progress

# Look at question answers

In [14]:
# Count entries in the raw gpt-4o-mini 'sec_6' feature directory
# (`wc` prints line, word and byte counts; the line count is the entry count).
!ls /home/chansingh/mntv1/ecog/features_raw___qs_35_stable/gpt-4o-mini/sec_6 | wc

     32     274    2180


In [6]:
# Load the cached question-answer features of a single model for a single
# story, so the answers can be inspected by eye (nothing is written to disk).
suffix_qs = '___qs_35_stable'

settings = ['sec_1.5']
# settings = ['words']
# settings = ['sec_6']
ensemble1 = [
    # 'mistralai/Mistral-7B-Instruct-v0.2',
    # 'meta-llama/Meta-Llama-3-8B-Instruct',
    # 'google/gemma-7b-it',
    'gpt-4o-mini',
]
# story_fnames = ['ant-man.pkl']
story_fnames = ['lotr-1.pkl']

for setting in settings:
    print(setting)

    # NOTE: rebuilt on every iteration, so after the loop it only holds the
    # features for the last entry of `settings`.
    ensemble_checkpoint_story_dict = {}
    for checkpoint in tqdm(ensemble1):
        features_dir = join(ECOG_DIR, f'features{suffix_qs}',
                            checkpoint.replace('/', '___'), setting)
        ensemble_checkpoint_story_dict[checkpoint] = {
            fname: joblib.load(join(features_dir, fname))
            for fname in story_fnames
            if fname.endswith('.pkl')
        }

sec_1.5


  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.07it/s]


In [7]:
# Spot-check a few easy-to-verify questions: for each model, print the index
# labels (text chunks) that received a positive answer to each question.
easy_qs = ['Is time mentioned in the input?',
           'Does the input contain a measurement?', 'Does the input contain a number?']
for checkpoint in ensemble1:
    print(checkpoint)
    df = ensemble_checkpoint_story_dict[checkpoint][story_fnames[0]][easy_qs]
    for question, answers in df.items():
        positive_chunks = answers[answers > 0].index.tolist()
        print(question, positive_chunks)
    print()

gpt-4o-mini
Is time mentioned in the input? ['It began', 'seven to', 'It was in this moment', 'moment when', 'forever But', 'History', 'and for two and a half', 'half thousand years', 'thousand years the', 'thousand years the ring', 'years the ring passed', 'until', 'until when', 'until when chance', 'For five hundred years', 'five hundred years it', 'hundred years it poisoned', 'hundred years it poisoned his', 'it s time', 'it s time had', 's time had now', 'time had now come', 'something happened then', 'For the time', 'For the time will', 'For the time will soon', 'time will soon come', 'Twenty second', 'September in', 'September in the', 'September in the year', 'in the year fourteen', 'year fourteen hundred', 'The Third Age', 'The Third Age of', 'The Third Age of this', 'Third Age of this world', 'Now', 'but today', 'today of all days', 'days it', 'days it is', 'Now', "You 're late", 'A wizard is never late', 'He arrives precisely', 'arrives precisely when', 'arrives precisely whe

In [5]:
# First 500 rows of the inspected frame (pandas truncates the rendered view).
df.iloc[:500]

Unnamed: 0,Is time mentioned in the input?,Does the input contain a measurement?,Does the input contain a number?
The,False,False,False
world,False,False,False
is,False,False,False
changed,False,False,False
I,False,False,False
...,...,...,...
fourteen,False,False,False
hundred,False,True,False
by,False,False,False
Shire,False,False,False


In [20]:
# (n_rows, n_columns) of the inspected frame
(len(df.index), len(df.columns))

(9869, 3)

In [None]:
# Back-of-the-envelope number of question answers: ~10k rows (cf. the frame
# shape above) times 35 questions (presumably the '___qs_35_stable' set — confirm).
approx_n_rows = 10_000
n_questions = 35
approx_n_rows * n_questions

350000