In [None]:
%load_ext autoreload
%autoreload 2
import os
import matplotlib.pyplot as plt
from os.path import join
from tqdm import tqdm
import pandas as pd
from os.path import expanduser
import sys
from typing import List
import numpy as np
import joblib
from pprint import pprint
import imodelsx.util
from os.path import dirname
import pickle as pkl
import json
from copy import deepcopy
from numpy.linalg import norm
from math import ceil
# from imodelsx.qaemb.qaemb import QAEmb, get_sample_questions_and_examples
from neuro.ecog.config import STORIES_POPULAR, STORIES_UNPOPULAR, ECOG_DIR

### Save ensemble features

In [None]:
# suffix_qs = ''
suffix_qs = '___qs_35_stable'

# Feature granularities to process; each one is a subdirectory under every
# checkpoint's feature directory.
settings = ['words', 'sec_1.5', 'sec_3', 'sec_6']
# settings = ['sec_1.5', 'sec_6']
# settings = ['words', 'sec_3']

# Subdirectory (under features{suffix_qs}) that receives the averaged
# features. This must be defined here: the loop below reads it, and on a
# fresh kernel nothing else sets it.
out_checkpoint = 'ensemble1'
ensemble1 = [
    'mistralai/Mistral-7B-Instruct-v0.2',
    'meta-llama/Meta-Llama-3-8B-Instruct',
    'google/gemma-7b-it',
]


def _average_story_dfs(story_dfs):
    """Return the element-wise mean of per-checkpoint feature DataFrames.

    Every frame is first restricted/reordered to the reference frame's
    columns and checked to share the reference frame's index, so the
    arithmetic below can never silently introduce NaNs through pandas label
    alignment.

    Params
    ------
    story_dfs: non-empty list of DataFrames, one per checkpoint.

    Returns
    -------
    DataFrame of floats with the reference frame's columns and index.

    Raises
    ------
    KeyError if a frame lacks one of the reference columns.
    AssertionError if a frame's columns or index differ from the reference.
    """
    # The second frame supplies the column order (falling back to the only
    # frame when there is just one); this is the order the previous version
    # of this cell produced, so newly written files match older ones.
    ref_df = story_dfs[min(1, len(story_dfs) - 1)]

    aligned_dfs = []
    for i, story_df in enumerate(story_dfs):
        story_df = story_df[ref_df.columns]
        assert story_df.columns.equals(ref_df.columns), \
            f'columns of checkpoint #{i} do not match the reference'
        assert story_df.index.equals(ref_df.index), \
            f'index of checkpoint #{i} does not match the reference'
        aligned_dfs.append(story_df.astype(float))

    # accumulate left-to-right, then divide by the number of checkpoints
    total_df = aligned_dfs[0]
    for aligned_df in aligned_dfs[1:]:
        total_df = total_df + aligned_df
    return total_df / len(aligned_dfs)


for setting in settings:
    print(setting)
    output_dir_ensemble = join(
        ECOG_DIR, f'features{suffix_qs}', out_checkpoint, setting)
    os.makedirs(output_dir_ensemble, exist_ok=True)

    # read in ensemble feats: checkpoint -> {story filename -> features}
    ensemble_checkpoint_story_dict = {}
    for checkpoint in tqdm(ensemble1):
        checkpoint_clean = checkpoint.replace('/', '___')
        output_dir_clean = join(ECOG_DIR, f'features{suffix_qs}',
                                checkpoint_clean, setting)
        # NOTE(review): joblib.load unpickles arbitrary objects -- only point
        # this at feature files you generated yourself.
        checkpoint_story_dict = {
            story_fname: joblib.load(join(output_dir_clean, story_fname))
            for story_fname in os.listdir(output_dir_clean)
            if story_fname.endswith('.pkl')
        }
        # a fresh dict is built on every iteration, so no copy is needed
        ensemble_checkpoint_story_dict[checkpoint] = checkpoint_story_dict

    # save avg feats (only for stories that every checkpoint has)
    common_stories = set.intersection(
        *[set(ensemble_checkpoint_story_dict[checkpoint].keys())
          for checkpoint in ensemble1]
    )
    print('\tsaving avg feats for', len(common_stories), 'stories')
    # sorted() only makes the processing order reproducible between runs
    for story_fname in tqdm(sorted(common_stories)):
        out_fname_pkl = join(output_dir_ensemble, story_fname)
        if os.path.exists(out_fname_pkl):
            # already averaged on an earlier run
            continue

        # avg over all checkpoints
        avg_df = _average_story_dfs([
            ensemble_checkpoint_story_dict[checkpoint][story_fname]
            for checkpoint in ensemble1
        ])

        # save the csv first: the pickle doubles as the "done" marker checked
        # above, so an interrupted run redoes this story
        avg_df.to_csv(join(output_dir_ensemble,
                           story_fname.replace('.pkl', '.csv')))
        avg_df.to_pickle(out_fname_pkl)

    print('\tavg feats', output_dir_ensemble, os.listdir(output_dir_ensemble))

In [None]:
# List the contents of the gpt-4o-mini 'sec_3' feature directory; the
# commented lines do the same for the ensemble1 directories.
# !ls /home/chansingh/mntv1/ecog/features/ensemble1/
# !ls /home/chansingh/mntv1/ecog/features___qs_35_stable/ensemble1/
!ls /home/chansingh/mntv1/ecog/features___qs_35_stable/gpt-4o-mini/sec_3

In [None]:
# Total number of rows across every pickled frame saved for the 'words'
# setting of ensemble1.
d = '/home/chansingh/mntv1/ecog/features___qs_35_stable/ensemble1/words'
n = 0
for k in os.listdir(d):
    # the directory also holds non-pickle files; only pickles are counted
    if k.endswith('.pkl'):
        df = pd.read_pickle(join(d, k))
        n += len(df)
n

In [None]:
# Copy the gpt-4o-mini feature directory to the `box` rclone remote; the
# commented lines are the corresponding copies for the ensemble1 directories.
# !rclone copy /home/chansingh/mntv1/ecog/features/ensemble1/ box:DeepTune/QA/cached_qa_tree_ensemble1 --progress
# !rclone copy /home/chansingh/mntv1/ecog/features___qs_35_stable/ensemble1/ box:DeepTune/QA/cached_qa_tree___qs_35_stable_ensemble1 --progress
!rclone copy /home/chansingh/mntv1/ecog/features___qs_35_stable/gpt-4o-mini/ box:DeepTune/QA/cached_qa_tree___qs_35_stable_gpt-4o-mini --progress

# Look at question answers

In [None]:
suffix_qs = '___qs_35_stable'

# Pick one setting / model / story whose question answers will be inspected.
settings = ['sec_1.5']
# alternatives: ['words'], ['sec_6']
ensemble1 = [
    'gpt-4o-mini',
    # alternatives:
    # 'mistralai/Mistral-7B-Instruct-v0.2',
    # 'meta-llama/Meta-Llama-3-8B-Instruct',
    # 'google/gemma-7b-it',
]
story_fnames = ['___podcasts-story___.pkl']
# alternatives: ['ant-man.pkl'], ['lotr-1.pkl']

for setting in settings:
    print(setting)

    # checkpoint -> {story filename -> loaded features}; rebuilt for each
    # setting, so only the last setting's data survives the loop
    ensemble_checkpoint_story_dict = {}
    for checkpoint in tqdm(ensemble1):
        checkpoint_clean = checkpoint.replace('/', '___')
        output_dir_clean = join(
            ECOG_DIR, f'features{suffix_qs}', checkpoint_clean, setting)
        checkpoint_story_dict = {
            story_fname: joblib.load(join(output_dir_clean, story_fname))
            for story_fname in story_fnames
            if story_fname.endswith('.pkl')
        }
        ensemble_checkpoint_story_dict[checkpoint] = deepcopy(
            checkpoint_story_dict)

In [None]:
# Questions whose answers are easy to eyeball; for each one, print the index
# labels of the rows where the stored answer is positive.
easy_qs = [
    'Is time mentioned in the input?',
    'Does the input contain a measurement?',
    'Does the input contain a number?',
]
for checkpoint in ensemble1:
    print(checkpoint)
    # `df` is left in the namespace on purpose: the following cells read it
    df = ensemble_checkpoint_story_dict[checkpoint][story_fnames[0]][easy_qs]
    for question in df.columns:
        answers = df[question]
        print(question, answers[answers > 0].index.tolist())
    print()

In [None]:
# Show the first 500 rows of `df`, the easy-question frame left over from the
# last checkpoint of the previous cell's loop.
df.head(500)

In [None]:
# Dimensions of `df` as left by the easy-question cell.
df.shape

In [None]:
# Scratch arithmetic. NOTE(review): presumably ~10,000 rows x 35 questions
# (cf. the '___qs_35_stable' suffix) -- confirm what this estimates.
10_000 * 35