In [1]:
import os

# Move up the directory tree until we reach the project root, which is marked
# by a README.md file. At the filesystem root the parent of a directory is the
# directory itself, so stop there with an error instead of looping forever.
while "README.md" not in os.listdir():
    current_dir = os.getcwd()
    parent_dir = os.path.dirname(current_dir)
    if parent_dir == current_dir:
        raise FileNotFoundError(
            "Could not find the project root: no README.md in the current "
            "directory or any of its parents"
        )
    os.chdir(parent_dir)
print(os.getcwd())

/home/ra/Codes/multilang_timescale


In [2]:
import pickle
import json

import time
import logging

import warnings

import joblib

import numpy as np
import pandas as pd

from scipy.stats import zscore

# from himalaya.ridge import RidgeCV, Ridge

from himalaya.backend import set_backend
from himalaya.kernel_ridge import (
    Kernelizer,
    ColumnKernelizer,
    MultipleKernelRidgeCV,
    WeightedKernelRidge,
)
from himalaya.scoring import r2_score_split, correlation_score_split


from sklearn.model_selection import KFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

from voxelwise_tutorials.delayer import Delayer
from voxelwise_tutorials.utils import explainable_variance
from voxelwise_tutorials.viz import (
    plot_2d_flatmap_from_mapper,
    plot_flatmap_from_mapper,
)

import matplotlib.pyplot as plt

In [3]:

from src.utils import make_delayed, load_dict
from src.config import train_stories, test_stories, timescales
from src.settings import TrainerConfig, SubjectConfig, FeatureConfig
from src.trainer import Trainer

# Config

In [4]:
config_dir = ".temp/config/"
config_subject_dir = ".temp/config/subject/"
config_train_dir = ".temp/config/train/"
config_feature_dir = ".temp/config/feature"

# Create every config directory up front. The subject directory itself must
# exist (not just its parent) because the subject config is written into it
# below. exist_ok=True keeps this cell idempotent on re-runs.
for _config_path in (config_dir, config_subject_dir, config_train_dir, config_feature_dir):
    os.makedirs(_config_path, exist_ok=True)

In [5]:
# Language-model feature set and subject used throughout this notebook.
lm_feature_type, subject_id = "BERT", "07"

In [6]:
# Subject config: which subject and task to model, and where its fMRI data and
# mappers live. The dataset directory is named once so it only needs changing here.
FMRI_DATA_DIR = "/media/data/dataset/timescale/fmri/deniz2019"

sub_config = SubjectConfig()

sub_config.sub_id = subject_id

sub_config.sub_fmri_train_path = os.path.join(
    FMRI_DATA_DIR, f"subject{subject_id}_reading_fmri_data_trn.hdf"
)
sub_config.sub_fmri_test_path = os.path.join(
    FMRI_DATA_DIR, f"subject{subject_id}_reading_fmri_data_val.hdf"
)
sub_config.sub_fmri_mapper_path = os.path.join(
    FMRI_DATA_DIR, f"subject{subject_id}_mappers.hdf"
)
sub_config.task = "reading"

In [7]:
# Persist the subject config as JSON (file name encodes subject id and task).
subject_config_fn = os.path.join(
    config_subject_dir, f"subject-{sub_config.sub_id}-{sub_config.task}.json"
)
with open(subject_config_fn, "w") as fp:
    json.dump(sub_config.__dict__, fp, indent=4)

In [8]:
def generate_feature_config(lm_feature_type=lm_feature_type):
    """Write one feature-config JSON per timescale for a language model.

    For every timescale name listed below, a ``FeatureConfig`` is filled with
    the language-model feature file selected by ``lm_feature_type`` plus the
    sensory (``numletters``, ``numwords``) and motion-energy (``"7"``) feature
    settings shared by all timescales. Each config is dumped to
    ``<config_feature_dir>/<lm_feature_type.lower()>/<timescale>-feature_config.json``.

    Parameters
    ----------
    lm_feature_type : str
        ``"BERT"`` or ``"mBERT"``. Selects the feature ``.npz`` file; its
        lower-cased form names the output sub-directories.

    Raises
    ------
    ValueError
        If ``lm_feature_type`` is not a supported model name.
    """
    # NOTE(review): this local list shadows ``timescales`` imported from
    # src.config; confirm they agree. "256+ words" (with a space) also ends up
    # in the output file name.
    timescales = [
        "2_4_words",
        "4_8_words",
        "8_16_words",
        "16_32_words",
        "32_64_words",
        "64_128_words",
        "128_256_words",
        "256+ words",
    ]

    lm_feature_paths = {
        "BERT": "/media/data/dataset/timescale/features/en/timescales_BERT_all.npz",
        "mBERT": "/media/data/dataset/timescale/features/en/timescales_mBERT_all.npz",
    }
    # Reject unknown names instead of silently falling back to mBERT features
    # (e.g. "bert" would otherwise write mBERT paths into a "bert" directory).
    if lm_feature_type not in lm_feature_paths:
        raise ValueError(
            f"Unsupported lm_feature_type {lm_feature_type!r}; "
            f"expected one of {sorted(lm_feature_paths)}"
        )
    lm_feature_path = lm_feature_paths[lm_feature_type]
    model_dir_name = lm_feature_type.lower()

    # NOTE(review): nothing in this notebook writes into this per-model train
    # sub-directory (trainer configs go directly into config_train_dir); it is
    # kept in case other consumers rely on it.
    os.makedirs(os.path.join(config_train_dir, model_dir_name), exist_ok=True)

    feature_dir = os.path.join(config_feature_dir, model_dir_name)
    os.makedirs(feature_dir, exist_ok=True)

    for timescale in timescales:
        feature_config = FeatureConfig()

        feature_config.timescale = timescale
        feature_config.lm_feature_path = lm_feature_path

        # Sensory and motion-energy settings are identical for every timescale.
        feature_config.sensory_feature_train_paths = (
            "/media/data/dataset/timescale/features/en/features_trn_NEW.hdf"
        )
        feature_config.sensory_feature_test_paths = (
            "/media/data/dataset/timescale/features/en/features_val_NEW.hdf"
        )
        feature_config.sensory_features = ["numletters", "numwords"]

        feature_config.motion_energy_feature_paths = (
            "/media/data/dataset/timescale/features/en/m_ll.npz"
        )
        feature_config.motion_energy_features = ["7"]

        fn = os.path.join(feature_dir, timescale + "-feature_config.json")
        with open(fn, "w") as fp:
            json.dump(feature_config.__dict__, fp, indent=4)

In [9]:
# Write the per-timescale feature configs for both language models.
for lm_name in ("BERT", "mBERT"):
    generate_feature_config(lm_feature_type=lm_name)

In [10]:

def generate_trainer_config(
    lm_feature_type: str,
    n_iter=1000,
    appendage: str = "trainer_config",
    backend: str = "torch_cuda",
    fit_on_mask: bool = True,
    n_targets_batch: int = 4096,
    n_targets_batch_refit: int = 1024,
):
    """Write a trainer-config JSON for one language model.

    The config is saved as
    ``<config_train_dir>/<lm_feature_type.lower()>_<appendage>.json``. It records
    ``.temp/results/<lm_feature_type.lower()>/hyperparams`` and ``.../stats``
    as the result directories.

    Parameters
    ----------
    lm_feature_type : str
        Model name (e.g. ``"BERT"``); its lower-cased form is used in the file
        name and the result paths.
    n_iter : int
        Stored as ``TrainerConfig.n_iter``.
    appendage : str
        Suffix of the output file name, so several variants (e.g. a short
        100-iteration run) can coexist.
    backend : str
        Stored as ``TrainerConfig.backend`` (presumably a himalaya backend
        name passed to ``set_backend`` -- confirm in src.trainer).
    fit_on_mask, n_targets_batch, n_targets_batch_refit
        Stored verbatim on the config; defaults match the previous hard-coded values.
    """
    model_dir_name = lm_feature_type.lower()

    trainer_config = TrainerConfig()

    trainer_config.backend = backend
    trainer_config.fit_on_mask = fit_on_mask

    trainer_config.n_iter = n_iter
    trainer_config.n_targets_batch = n_targets_batch
    trainer_config.n_targets_batch_refit = n_targets_batch_refit

    trainer_config.hyperparams_save_dir = f".temp/results/{model_dir_name}/hyperparams"
    trainer_config.stats_save_dir = f".temp/results/{model_dir_name}/stats"

    fn = os.path.join(config_train_dir, f"{model_dir_name}_{appendage}.json")
    with open(fn, "w") as fp:
        json.dump(trainer_config.__dict__, fp, indent=4)

In [11]:
# Default (1000-iteration) trainer configs for both language models.
for lm_name in ("BERT", "mBERT"):
    generate_trainer_config(lm_feature_type=lm_name)

In [12]:
# Shorter 100-iteration variants, saved under a distinct file-name suffix.
for lm_name in ("BERT", "mBERT"):
    generate_trainer_config(
        lm_feature_type=lm_name, n_iter=100, appendage="trainer_config_100"
    )

# Training

# Test training run (BERT features, 2_4_words timescale, subject 07)

In [None]:
# Paths of configs written above, derived from the same settings so that
# changing lm_feature_type or subject_id in the config cells carries through.
test_timescale = "2_4_words"

train_config_path = os.path.join(
    config_train_dir, f"{lm_feature_type.lower()}_trainer_config.json"
)
feature_config_path = os.path.join(
    config_feature_dir, lm_feature_type.lower(), f"{test_timescale}-feature_config.json"
)
sub_config_path = config_subject_dir + f"subject-{subject_id}-reading.json"

In [None]:
# Reload the JSON configs into their config classes.
def _load_config(path, config_cls):
    """Read the JSON file at ``path`` and pass its keys as kwargs to ``config_cls``."""
    with open(path) as fp:
        fields = json.load(fp)
    return config_cls(**fields)


train_config = _load_config(train_config_path, TrainerConfig)
feature_config = _load_config(feature_config_path, FeatureConfig)
sub_config = _load_config(sub_config_path, SubjectConfig)


In [None]:
# Build the trainer from the reloaded subject and feature configs.
trainer = Trainer(
    feature_config=feature_config,
    sub_config=sub_config,
)

In [None]:
# Fit using the settings in train_config. For a quicker smoke test, point
# train_config_path at one of the *_trainer_config_100.json files instead.
trainer.train(
    trainer_config=train_config,
)

In [None]:
# Refit and evaluate with the same trainer config
# (see src.trainer.Trainer.refit_and_evaluate for details).
trainer.refit_and_evaluate(
    trainer_config=train_config,
)