In [2]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import random
import json
import gc

import numpy as np
import pandas as pd
import torch
import torchio as tio
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import nibabel as nib
from einops import rearrange
from scipy import ndimage
from fracridge import FracRidgeRegressorCV
from sklearn.model_selection import KFold


# Make modules that live above the notebook folder importable: `dir2` is the
# directory two levels up from the working directory and `dir1` is its parent.
# `dir1` is appended only once, so re-running the cell adds no duplicates.
dir2 = os.path.abspath(os.path.join(os.pardir, os.pardir))
dir1 = os.path.dirname(dir2)
if dir1 not in sys.path:
    sys.path.append(dir1)

In [3]:
# Root of the local NSD copy. Set the NSD_DATASET_PATH environment variable to
# point elsewhere; the default is the Windows path this notebook used before.
dataset_path = Path(os.environ.get('NSD_DATASET_PATH', 'D:\\Datasets\\NSD\\'))
derivatives_path = dataset_path / 'derivatives'
betas_path = dataset_path / 'nsddata_betas' / 'ppdata'
ppdata_path = dataset_path / 'nsddata' / 'ppdata'

In [4]:
# Ids of the 1000 images shown to every participant (same id space as the
# `73KID` column of the response tables, which they are intersected with later).
shared_1000_path = dataset_path / 'nsddata' / 'stimuli' / 'nsd' / 'shared1000.tsv'
shared_1000 = set(pd.read_csv(shared_1000_path, sep='\t', header=None)[0])

In [5]:
# One record per NSD participant (subj01 ... subj08) holding its behavioural
# responses, open per-session beta files, brain mask and T1 path.
subjects = {'subj0' + str(i): {} for i in range(1, 9)}

for subject_name, subject_data in subjects.items():
    responses_file_path = ppdata_path / subject_name / 'behav' / 'responses.tsv'
    subject_data['responses'] = pd.read_csv(responses_file_path, sep='\t')

    # The final three sessions are held out for the Algonauts challenge;
    # drop their rows.
    session_ids = subject_data['responses']['SESSION']
    held_out_mask = session_ids > session_ids.max() - 3
    subject_data['responses'] = subject_data['responses'].loc[~held_out_mask]

    # One read-only HDF5 handle per remaining session (files are numbered from
    # 1); the handles are kept open for later cells.
    subject_betas_path = betas_path / subject_name / 'func1pt8mm' / 'betas_fithrf_GLMdenoise_RR'
    num_sessions = subject_data['responses']['SESSION'].max()
    subject_data['sessions'] = list(map(
        lambda session: h5py.File(subject_betas_path / f'betas_session{session:02}.hdf5', 'r'),
        range(1, num_sessions + 1),
    ))

    # Disabled: a single pre-concatenated betas file.
    #subject_data['betas'] = h5py.File(subject_betas_path / f'betas_sessions.hdf5', 'r')

    subject_data['brainmask'] = nib.load(ppdata_path / subject_name / 'func1pt8mm' / 'brainmask.nii.gz')
    subject_data['t1_path'] = ppdata_path / subject_name / 'func1pt8mm' / 'T1_to_func1pt8mm.nii.gz'

In [None]:
# Define a train-test-validation split
# Test set: every shared-1000 image that *all* subjects saw three times, topped
# up to N_test with a seeded random choice of the subject's other
# three-repetition images. Validation: the next N_validation of those shuffled
# images. Every other response is training data.
split_name = 'split-01'
N_test = 1000
N_validation = 1000
seed = 0

for subject_name, subject_data in subjects.items():
    responses = subject_data['responses']

    image_ids = responses['73KID'].to_numpy()
    unique_image_ids, unique_counts = np.unique(image_ids, return_counts=True)
    three_repetition_ids = unique_image_ids[unique_counts == 3]
    subject_data['three_repetition_ids'] = set(three_repetition_ids)
    print(f'{subject_name} {image_ids.shape=}, {len(three_repetition_ids)=}')

shared_1000_three_repetitions = set.intersection(
    shared_1000,
    *[subject_data['three_repetition_ids']
    for subject_data in subjects.values()]
)
print(f'{len(shared_1000_three_repetitions)=}')

# This depends on the shared set, so it can only be computed once that set exists.
# It used to be computed at the top of the cell, which fails on a fresh kernel
# (or silently reuses a stale value).
N_non_shared = N_test - len(shared_1000_three_repetitions)
assert N_non_shared >= 0, f'{len(shared_1000_three_repetitions)=} exceeds {N_test=}'


def _split_image_ids(three_repetition_ids, shared_ids, n_non_shared, n_validation, seed):
    """Return ``(test_image_ids, validation_image_ids)`` arrays for one subject.

    All of ``shared_ids`` go to the test set. The subject's remaining
    three-repetition ids are shuffled with ``random.Random(seed)``. The first
    ``n_non_shared`` of them complete the test set, and the following
    ``n_validation`` form the validation set.
    """
    non_shared_ids = list(three_repetition_ids - shared_ids)
    random.Random(seed).shuffle(non_shared_ids)
    test_ids = list(shared_ids) + non_shared_ids[:n_non_shared]
    validation_ids = non_shared_ids[n_non_shared:n_non_shared + n_validation]
    return np.array(test_ids), np.array(validation_ids)


split_path = derivatives_path / 'data_splits' / f'{split_name}.hdf5'
split_path.parent.mkdir(parents=True, exist_ok=True)

with h5py.File(split_path, 'w') as f:
    for subject_name, subject_data in subjects.items():
        test_image_ids, validation_image_ids = _split_image_ids(
            subject_data['three_repetition_ids'], shared_1000_three_repetitions,
            N_non_shared, N_validation, seed,
        )
        image_ids = subject_data['responses']['73KID'].to_numpy()
        test_response_mask = np.isin(image_ids, test_image_ids)
        validation_response_mask = np.isin(image_ids, validation_image_ids)

        # In-memory copies for this session (the validation ids used to be
        # overwritten with the test ids here).
        subject_data['test_image_ids'] = test_image_ids
        subject_data['validation_image_ids'] = validation_image_ids
        subject_data['test_response_ids'] = np.flatnonzero(test_response_mask)
        subject_data['validation_response_ids'] = np.flatnonzero(validation_response_mask)

        subject = f.require_group(subject_name)
        subject['test_image_ids'] = test_image_ids
        subject['validation_image_ids'] = validation_image_ids
        subject['test_response_mask'] = test_response_mask
        subject['validation_response_mask'] = validation_response_mask


In [None]:
# Torch implementation of fractional ridge regression

def pearsonr(X, Y, dim=0, cast_dtype=torch.float64):
    """Pearson correlation between matching slices of ``X`` and ``Y`` along ``dim``.

    ``X`` and ``Y`` are broadcast against each other. The correlation is
    computed independently for every position of the remaining axes (for
    example, column by column for 2-D inputs with ``dim=0``). The arithmetic
    runs in ``cast_dtype`` and the result is cast back to ``X``'s dtype. A
    constant slice has zero norm after centring and yields NaN.
    """
    in_dtype = X.dtype
    X = X.to(cast_dtype)
    Y = Y.to(cast_dtype)

    X = X - X.mean(dim=dim, keepdim=True)
    Y = Y - Y.mean(dim=dim, keepdim=True)

    X = X / torch.norm(X, dim=dim, keepdim=True)
    Y = Y / torch.norm(Y, dim=dim, keepdim=True)

    # Dot product of the centred, unit-norm vectors along `dim`.
    # (`torch.tensordot(X, Y, dims=dim)` contracts the last `dim` axes of X
    # with the first `dim` axes of Y; for dim=0 that is an outer product,
    # not a correlation.)
    r = (X * Y).sum(dim=dim).to(in_dtype)
    return r


def rsquared(Y, Y_pred, dim=0, cast_dtype=torch.float64):
    """Coefficient of determination of ``Y_pred`` against ``Y`` along ``dim``.

    Computes ``R^2 = 1 - SS_res / SS_tot`` independently for every position of
    the other axes; ``Y`` and ``Y_pred`` are broadcast against each other.
    The arithmetic runs in ``cast_dtype`` and the result is cast back to
    ``Y``'s dtype, matching ``pearsonr``. A constant ``Y`` slice has
    ``SS_tot == 0`` and yields -inf or NaN.
    """
    in_dtype = Y.dtype
    Y = Y.to(cast_dtype)
    Y_pred = Y_pred.to(cast_dtype)

    ss_res = ((Y - Y_pred) ** 2).sum(dim=dim)
    ss_tot = ((Y - Y.mean(dim=dim, keepdim=True)) ** 2).sum(dim=dim)

    r2 = (1 - ss_res / ss_tot).to(in_dtype)
    return r2

# NOTE(review): the scipy `svd` partial and numpy `interp` bound below are not
# used by the torch implementation in this cell, which calls `torch.linalg.svd`
# and interpolates by hand. They look like leftovers from a NumPy version.
from scipy.linalg import svd
from functools import partial
svd = partial(svd, full_matrices=False)
from numpy import interp

def frac_ridge_regression(X, Y, fractions=None, tol=1e-6):
    """Fractional ridge regression solved through the SVD of ``X`` (torch).

    Each solution is specified by the *fraction* of the OLS coefficient
    vector's L2 norm it keeps, rather than by a ridge penalty. The relative
    norm is evaluated on a log-spaced alpha grid, and the alpha matching each
    requested fraction is found by linear interpolation in ``log(1 + alpha)``.

    Args:
        X: design matrix, ``(..., n, p)``.
        Y: targets, ``(..., n, b)``.
        fractions: either a 1-D tensor ``(f,)`` shared by all targets, or a
            tensor with as many dims as the per-target norm table (e.g.
            ``(1, ..., b)`` as passed by ``frac_ridge_regression_cv``).
            Defaults to 0.1, 0.2, ..., 1.0 on ``X``'s device.
        tol: OLS coefficients along singular values smaller than ``tol`` are
            zeroed.

    Returns:
        ``(coef, alpha)`` with ``coef`` of shape ``(..., f, p, b)`` and the
        matching ridge penalties ``alpha`` of shape ``(..., f, b)``.

    Fractions below the smallest relative norm on the grid are not clamped;
    the interpolation then falls back to the first grid interval. A target
    whose OLS solution is all zeros produces NaN (0/0 in the normalisation).
    """
    BIG_BIAS = 10e3
    SMALL_BIAS = 10e-3
    BIAS_STEP = 0.2

    if fractions is None:
        # Build the default grid on X's device: it is compared against
        # `newlen` below, which lives there too (a CPU grid fails for CUDA X).
        fractions = torch.arange(.1, 1.1, .1, device=X.device)

    # OLS solution in the singular-vector basis: diag(1 / S) U^T Y.
    U, S, Vt = torch.linalg.svd(X, full_matrices=False)
    Y_new = U.transpose(-1, -2) @ Y
    ols_coef = (Y_new.transpose(-1, -2) / S[..., None, :]).transpose(-1, -2)

    # Zero the components along (numerically) vanishing singular values.
    S_small = torch.broadcast_to(S < tol, ols_coef.shape[:-1])
    ols_coef[S_small, ...] = 0.

    # Alpha grid: 0 followed by powers of ten from floor(log10(val2)) up to
    # (excluding) ceil(log10(val1)) in steps of roughly BIAS_STEP decades.
    val1 = BIG_BIAS * S[..., 0] ** 2
    val2 = SMALL_BIAS * S[..., -1] ** 2

    grid_low = torch.floor(torch.log10(val2))
    grid_high = torch.ceil(torch.log10(val1))
    steps = int(torch.max(grid_high - grid_low).item() / BIAS_STEP)
    alphagrid = 10 ** torch.stack([
        (i / steps) * grid_high + (1 - i / steps) * grid_low
        for i in range(steps)
    ])
    alphagrid = torch.cat([torch.zeros_like(grid_low)[None], alphagrid])

    # Shrinkage factor S^2 / (S^2 + alpha) per grid alpha and component.
    S_squared = S ** 2
    scaling = S_squared / (S_squared + alphagrid[..., None])
    scaling_squared = scaling ** 2

    # Norm of the ridge solution at each grid alpha, relative to row 0 (alpha = 0, i.e. OLS).
    newlen = torch.sqrt(torch.einsum('g ... p, ... p b -> g ... b', scaling_squared, ols_coef ** 2))
    newlen = (newlen / newlen[0])

    # Add trailing singleton axes so `fractions` broadcasts against newlen[g].
    while len(fractions.shape) < len(newlen.shape):
        fractions = fractions[:, None]

    # For each fraction/target, locate the grid interval [g, g + 1] in which
    # the (non-increasing) relative norm drops to the fraction.
    threshold = fractions[None, :] < newlen[:, None]
    threshold = (threshold[1:] != threshold[:-1]).int()
    threshold = threshold.argmax(dim=0)

    newlen_high = torch.gather(newlen, 0, threshold)
    newlen_low = torch.gather(newlen, 0, threshold + 1)

    # Linear interpolation of log(1 + alpha) inside that interval.
    t = (newlen_high - fractions) / (newlen_high - newlen_low)
    log_alphagrid = torch.log(1 + alphagrid)
    log_alphagrid = torch.broadcast_to(log_alphagrid[..., None], newlen.shape)

    alpha_high = torch.gather(log_alphagrid, 0, threshold)
    alpha_low = torch.gather(log_alphagrid, 0, threshold + 1)
    alpha = (1. - t) * alpha_high + t * alpha_low
    alpha = torch.exp(alpha) - 1.

    # Ridge coefficients in the singular-vector basis, rotated back with Vt.
    sc = S_squared / (S_squared + rearrange(alpha, 'f ... b -> f b ... 1'))
    coef = sc * rearrange(ols_coef, '... p b -> 1 b ... p')

    coef = torch.einsum('... p i, f b ... p -> ... f i b', Vt, coef)
    alpha = rearrange(alpha, 'f ... b -> ... f b')

    return coef, alpha

def predict(X, coef):
    """Apply a stack of coefficient matrices to the design matrix ``X``.

    ``X`` is ``(..., n, p)`` and ``coef`` is ``(..., f, p, b)``, i.e. one
    ``p x b`` matrix per fraction ``f``. Returns predictions of shape
    ``(..., f, n, b)``.
    """
    # Contract the shared feature axis p separately for every fraction f.
    contraction = '... n p, ... f p b -> ... f n b'
    return torch.einsum(contraction, X, coef)


def frac_ridge_regression_cv(X, Y, fractions=None, cv=5, tol=1e-6, random_state=None):
    """Choose the best ridge fraction per target by K-fold CV, then refit.

    Args:
        X: design matrix ``(..., n, p)``.
        Y: targets ``(..., n, b)``.
        fractions: 1-D tensor of candidate fractions (default 0.1 ... 1.0 on
            ``X``'s device).
        cv: number of folds.
        tol: singular-value cutoff forwarded to ``frac_ridge_regression``.
        random_state: seed forwarded to ``KFold``. ``None`` keeps the previous
            unseeded, non-reproducible shuffling.

    Returns:
        ``(coefs, alpha, best_r2, best_fractions)`` with ``coefs`` of shape
        ``(..., p, b)`` and the other three of shape ``(..., b)``. ``best_r2``
        is the mean validation R^2 of the chosen fraction. The coefficients are
        refit on *all* samples.
    """
    if fractions is None:
        fractions = torch.arange(.1, 1.1, .1, device=X.device)

    kf = KFold(n_splits=cv, shuffle=True, random_state=random_state)
    fold_r2 = []
    for train_ids, val_ids in kf.split(np.arange(X.shape[-2])):
        X_train = X[..., train_ids, :]
        Y_train = Y[..., train_ids, :]
        X_val = X[..., val_ids, :]
        Y_val = Y[..., val_ids, :]

        coef, _ = frac_ridge_regression(X_train, Y_train, fractions, tol=tol)
        Y_val_pred = predict(X_val, coef)
        # (..., 1, n, b) against (..., f, n, b): R^2 per fraction and target.
        fold_r2.append(rsquared(Y_val[..., None, :, :], Y_val_pred, dim=-2))
    r2 = torch.stack(fold_r2).mean(dim=0)
    best_r2, best_fraction_ids = r2.max(dim=-2)
    best_fractions = fractions[best_fraction_ids]

    # Refit on every sample (not just the last fold's training split), giving
    # each target its own chosen fraction via a (1, ..., b) fractions tensor.
    best_coefs, best_alpha = frac_ridge_regression(X, Y, best_fractions[None], tol=tol)
    return best_coefs[..., 0, :, :], best_alpha[..., 0, :], best_r2, best_fractions


# Benchmark the torch CV fit on random data: N random targets processed in
# batches of `batch_size`, each against a random 21750 x embedding_size design.
N = 20000
batch_size = 100
embedding_size = 512
# The fraction grid is the same for every batch, so build it once.
fractions = torch.arange(.05, 1.05, .05).cuda()
t = time.time()
for i in tqdm(range(N // batch_size)):
    X = torch.randn(21750, embedding_size).cuda()
    Y = torch.randn(21750, batch_size).cuda()
    coefs, alpha, r2, frac = frac_ridge_regression_cv(X, Y, fractions)
# CUDA kernels run asynchronously; wait for them before reading the clock.
torch.cuda.synchronize()
print(f'{N} targets in {time.time() - t:.1f}s')

# Disabled reference benchmark using the `fracridge` package's CPU
# `FracRidgeRegressorCV`. It is kept as a string literal so it never runs, and
# the loop would `break` on its first iteration anyway.
'''
for i in tqdm(range(N // batch_size)):
    break
    X = np.random.randn(21750, embedding_size)
    Y = np.random.randn(21750, batch_size)
    fractions = np.arange(.05, 1.05, .05)
    model = FracRidgeRegressorCV(frac_grid=fractions)
    model.fit(X, Y)
    print(model.best_frac_)'''

In [None]:
# run an encoder
# For every subject, fit fractional ridge from each selected stimulus embedding
# to every in-mask voxel, using training trials only. Voxels are processed one
# slice (first volume axis) at a time, and per-voxel results are stored in
# `fracridge-parameters.hdf5` under `<subject>/<model>/<embedding>/`.

split_name = 'split-01'
split = h5py.File(derivatives_path / 'data_splits' / f'{split_name}.hdf5')
fractions = torch.arange(.05, 1.05, .05).cuda()
#alpha = 10 ** torch.linspace(1, 5, 20).cuda()

run_stimulus_embeddings = {
    'bigbigan-resnet50': ['z_mean'],
    'ViT-B=32': ['embedding',],# *(f'transformer.resblocks.{i}' for i in range(12))],
    #'biggan-128': ['z', 'y_embedding'],
    #'vqgan': ['vqgan-f16-1024-pre_quant'],
}

embedding_files = {
    model_name: h5py.File(derivatives_path / 'stimulus_embeddings' / f'{model_name}-embeddings.hdf5', 'r')
    for model_name in run_stimulus_embeddings.keys()
}

with h5py.File(derivatives_path / 'fracridge-parameters.hdf5', 'a') as f:
    for subject_name, subject_data in subjects.items():
        sessions = subject_data['sessions']
        num_sessions = len(sessions)
        shape = sessions[0]['betas'].shape
        T, W, H, D = shape

        # Training trials = responses in neither the test nor the validation split.
        subject_split = split[subject_name]
        test_mask = subject_split['test_response_mask'][:].astype(bool)
        validation_mask = subject_split['validation_response_mask'][:].astype(bool)
        training_mask = ~(test_mask | validation_mask)
        training_indices = np.where(training_mask)[0]

        responses = subject_data['responses']
        # `- 1` maps the (presumably 1-based) 73KID to a 0-based embedding row.
        training_stimulus_ids = responses['73KID'].to_numpy()[training_indices] - 1

        # One embedding row per training trial, on the GPU.
        subject_stimulus_embeddings = {
            model_name : {
                embedding_name: torch.from_numpy(embedding_files[model_name][embedding_name][:][training_stimulus_ids]).clone().cuda()
                for embedding_name in embedding_names
            }
            for model_name, embedding_names in run_stimulus_embeddings.items()
        }

        for model_name, embeddings in subject_stimulus_embeddings.items():
            for embedding_name, embedding in embeddings.items():
                keys = (subject_name, model_name, embedding_name)
                key = '/'.join(keys)
                E = embedding.shape[-1]
                group = f.require_group(key)
                group.require_dataset('coefs', shape=(W, H, D, E), dtype='f4')
                group.require_dataset('alpha', shape=(W, H, D), dtype='f4')
                group.require_dataset('r2', shape=(W, H, D), dtype='f4')
                group.require_dataset('fractions', shape=(W, H, D), dtype='f4')

                # z-score each embedding feature over the training trials and
                # store the statistics alongside the fit.
                embedding_mean = embedding.mean(dim=0, keepdims=True)
                embedding_std = embedding.std(dim=0, keepdims=True)
                embeddings[embedding_name] = (embedding - embedding_mean) / embedding_std

                # The file is opened in append mode, so on a re-run these
                # datasets already exist and plain assignment would raise.
                # Replace them instead.
                for stat_name, stat in (('embedding_mean', embedding_mean),
                                        ('embedding_std', embedding_std)):
                    if stat_name in group:
                        del group[stat_name]
                    group[stat_name] = stat.cpu().numpy()

        mask = (subjects[subject_name]['brainmask'].get_fdata() > 0.).T

        load_time = 0
        compute_time = 0
        store_time = 0

        for i in tqdm(range(W)):
            if (i + 1) % 5 == 0:
                print(f'{load_time=:03}s, {compute_time=:03}s, {store_time=:03}s')

            mask_slice = mask[i]
            if mask_slice.sum() == 0:
                continue

            t = time.time()
            # Slice i from all sessions, restricted to training trials and
            # in-mask voxels: (training trials, voxels in slice).
            Y = np.concatenate([
                session['betas'][:, i]
                for session in sessions
            ])
            Y = Y[training_indices]
            Y = Y[:, mask_slice]
            # NOTE(review): the /300 presumably undoes the integer scaling of
            # the stored betas; TODO confirm.
            Y = torch.from_numpy(Y).float() / 300
            # NOTE: a voxel with zero variance over the training trials becomes NaN here.
            Y = (Y - Y.mean(dim=0, keepdims=True)) / Y.std(dim=0, keepdims=True)
            #Y = Y.T[..., None]

            slice_indices = torch.from_numpy(np.argwhere(mask_slice))
            load_time += time.time() - t

            batch_size = 100
            Y_splits = torch.split(Y, batch_size, dim=1)
            indices_splits = torch.split(slice_indices, batch_size)
            for Y_batch, indices_batch in zip(Y_splits, indices_splits):
                Y_batch = Y_batch.cuda()
                for model_name, embeddings in subject_stimulus_embeddings.items():
                    for embedding_name, embedding in embeddings.items():
                        t = time.time()

                        gc.collect()
                        torch.cuda.empty_cache()

                        keys = (subject_name, model_name, embedding_name)
                        group = f['/'.join(keys)]
                        X = embedding[:]

                        #print(X.shape, Y_batch.shape, fractions.shape)
                        #coefs, alpha, r2 = ridge_regression_cv(X, Y_batch[..., 0].T, alpha=alpha)
                        coefs, alpha, r2, frac = frac_ridge_regression_cv(X, Y_batch, fractions)
                        coefs = coefs.cpu().numpy()
                        a = alpha.cpu().numpy()
                        r2 = r2.cpu().numpy()
                        frac = frac.cpu().numpy()

                        compute_time += time.time() - t

                        t = time.time()
                        # One HDF5 write per voxel; (j, k) are in-slice coordinates.
                        for v, (j, k) in enumerate(indices_batch):
                            group['coefs'][i, j, k] = coefs[:, v]
                            group['alpha'][i, j, k] = a[v]
                            group['r2'][i, j, k] = r2[v]
                            group['fractions'][i, j, k] = frac[v]
                        store_time += time.time() - t


In [None]:
# save encoder results
# Export every 3-D dataset of each <subject>/<model>/<embedding> group in the
# encoder parameter file as a NIfTI image, using the brain mask's affine.
# Datasets of any other rank (e.g. the 4-D `coefs`) are skipped.

encoder_name = 'fracridge'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r') as f:
    for subject_name, subject in f.items():
        subject_out_path = derivatives_path / 'images' / encoder_name / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)
        affine = subjects[subject_name]['brainmask'].affine

        for model_name, model in subject.items():
            for embedding_name, embedding in model.items():
                for image_name, image in embedding.items():
                    keys = (subject_name, encoder_name, model_name, embedding_name, image_name)
                    save_file_name = f'{"__".join(keys)}.nii.gz'
                    print(keys, image.shape)
                    if len(image.shape) != 3:
                        continue
                    # `.T` reverses the axis order (the fitting cell transposed
                    # the brain mask the same way).
                    image = nib.Nifti1Image(image[:].T, affine)
                    nib.save(image, subject_out_path / save_file_name)

In [9]:
# Add feature selection indices
# Rank every voxel by its cross-validated r2 (best first) and store the ranking
# next to the fit. It is stored both as flat indices into the (W, H, D) volume
# and as (w, h, d) coordinates.

encoder_name = 'fracridge'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'a') as f:
    for subject_name, subject in f.items():
        for model_name, model in subject.items():
            for embedding_name, embedding in model.items():
                r2 = embedding['r2'][:]
                # NaN r2 (e.g. 0/0 from zero-variance voxels) sorts after every
                # number, so after the reversal it would rank as the best voxel.
                # Rank it last instead.
                r2_rankable = np.where(np.isnan(r2), -np.inf, r2)
                sorted_indices_flat = r2_rankable.argsort(axis=None)[::-1]
                # Same C-order coordinates as indexing np.argwhere(np.ones_like(r2)),
                # but without materialising the full coordinate grid.
                sorted_indices = np.stack(np.unravel_index(sorted_indices_flat, r2.shape), axis=1)

                # The file is opened in append mode: replace results from a
                # previous run rather than failing on "name already exists".
                for name, data in (('sorted_indices_flat', sorted_indices_flat),
                                   ('sorted_indices', sorted_indices)):
                    if name in embedding:
                        del embedding[name]
                    embedding[name] = data

In [None]:
encoder_results = h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r')


def _concatenated_betas(subject_name):
    """Return the open HDF5 file with a subject's session-concatenated betas.

    Reuses ``subjects[subject_name]['betas']`` if it is set. Otherwise it opens
    (and caches there) the ``betas_sessions.hdf5`` file written by the
    "concatenate sessions" cell. That file's ``betas`` dataset is
    ``(time, W * H * D)``, with one column per flattened voxel.
    """
    subject_data = subjects[subject_name]
    if 'betas' not in subject_data:
        betas_file = (derivatives_path / 'betas' / subject_name / 'func1pt8mm'
                      / 'betas_fithrf_GLMdenoise_RR' / 'betas_sessions.hdf5')
        subject_data['betas'] = h5py.File(betas_file, 'r')
    return subject_data['betas']


# Interactive browser: pick subject -> model -> embedding. It then loads the
# betas of the 2500 best-r2 voxels and previews where the top `num` voxels lie
# in slice `w`.
@interact(subject=encoder_results.items())
def select_subject(subject):
    subject_name = subject.name[1:]
    sessions = subjects[subject_name]['sessions']
    # No visible cell stores a 'betas' entry in `subjects` (indexing it raised
    # KeyError), so fall back to the concatenated file.
    betas = _concatenated_betas(subject_name)
    print(sessions[0]['betas'].shape)

    @interact(model=subject.items())
    def select_model(model):

        @interact(embedding=model.items())
        def select_embedding(embedding):

            # Voxels ranked by r2, best first: flat indices and (w, h, d) coordinates.
            r2 = embedding['r2'][:]
            sorted_indices_flat = r2.argsort(axis=None)[::-1]
            grid = np.argwhere(np.ones_like(r2, dtype=bool))
            sorted_indices = grid[sorted_indices_flat]
            # `grid` holds one row per voxel of the volume; its shape is enough.
            print(grid.shape)
            t = time.time()

            # Time series of the top 2500 voxels, one column per flat voxel index.
            Y = np.stack([
                betas['betas'][:, i]
                for i in sorted_indices_flat[:2500]
            ], axis=1)

            print(time.time() - t)
            print(Y.shape)

            @interact(w=(0, r2.shape[0]-1), num=(0, 50000))
            def show_top(w, num):
                selection_map = np.zeros_like(r2)
                i, j, k = list(sorted_indices[:num].T)
                selection_map[i, j, k] = 1
                plt.imshow(selection_map[w].T)
                

In [10]:
# concatenate sessions
# Merge each subject's per-session betas into one int16 dataset of shape
# (all trials, W * H * D). It is chunked one voxel column at a time, so a
# voxel's full time series is a single chunk.

for subject_name, subject_data in subjects.items():
    # Skipped subjects: presumably converted in an earlier run (TODO confirm).
    if subject_name in ('subj01', 'subj02', 'subj03', 'subj04', ):
        continue
    print(subject_name)

    path = derivatives_path / 'betas' / subject_name / 'func1pt8mm' / 'betas_fithrf_GLMdenoise_RR'
    path.mkdir(parents=True, exist_ok=True)

    with h5py.File(path / 'betas_sessions.hdf5', 'a') as f:
        sessions = subject_data['sessions']
        num_sessions = len(sessions)
        T, W, H, D = shape = sessions[0]['betas'].shape
        T_full = T * num_sessions
        slice_size = H * D

        f.require_dataset('betas', shape=(T_full, W * H * D), dtype=np.int16, chunks=(T_full, 1))
        for i in tqdm(range(W)):
            # Slice i of every session stacked along time, then flattened to
            # one column per (h, d) voxel.
            Y = np.concatenate([session['betas'][:, i] for session in sessions])
            f['betas'][:, slice_size * i:slice_size * (i + 1)] = Y.reshape(Y.shape[0], -1)
        

subj05


  0%|          | 0/78 [00:00<?, ?it/s]

subj06


  0%|          | 0/83 [00:00<?, ?it/s]

subj07


  0%|          | 0/81 [00:00<?, ?it/s]

subj08


  0%|          | 0/78 [00:00<?, ?it/s]

In [None]:
# Save feature selection
# TODO: unfinished. This cell previously ended at the `with` header, which is a
# SyntaxError that aborts "Run All". The body below is only a placeholder until
# the intended export is written.

encoder_name = 'fracridge'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r') as f:
    pass  # TODO: implement the feature-selection export.

In [None]:
# Encoder experiments on a slice

subject_name = 'subj01'

subject = subjects[subject_name]
sessions = subject['sessions']
t1 = nib.load(subject['t1_path']).get_fdata().T

# Training trials = responses in neither the test nor the validation split.
split_name = 'split-01'
split = h5py.File(derivatives_path / 'data_splits' / f'{split_name}.hdf5')
subject_split = split[subject_name]
test_mask = subject_split['test_response_mask'][:].astype(bool)
validation_mask = subject_split['validation_response_mask'][:].astype(bool)
training_mask = ~test_mask & ~validation_mask
training_indices = np.flatnonzero(training_mask)

run_stimulus_embeddings = {
    'bigbigan-resnet50': ['z_mean'],
    'ViT-B=32': ['embedding'],
}

embedding_files = {
    model_name: h5py.File(
        derivatives_path / 'stimulus_embeddings' / f'{model_name}-embeddings.hdf5', 'r'
    )
    for model_name in run_stimulus_embeddings
}

# Region of interest, indexed as (axial, coronal, saggital): one axial slice,
# all saggital positions, and every coronal position except the last 80.
W, H, D = t1.shape
axial_slice = 37
saggital_slice = slice(None)
coronal_slice = slice(None, -80)

# Betas of that region for the training trials, scaled by 1/300 and z-scored
# per voxel over trials (the epsilon guards constant voxels).
betas_slice = np.concatenate(
    [session['betas'][:, axial_slice, coronal_slice, saggital_slice] for session in sessions]
)[training_indices]
betas_slice = torch.from_numpy(betas_slice).float().div(300)
betas_slice = betas_slice.sub(betas_slice.mean(dim=0, keepdims=True))
betas_slice = betas_slice.div(betas_slice.std(dim=0, keepdims=True) + 1e-7)

mask = (subject['brainmask'].get_fdata() > 0.).T[axial_slice, coronal_slice, saggital_slice]

# (training trials, in-mask voxels of the region)
Y = betas_slice[:, mask]

responses = subject['responses']
training_stimulus_ids = responses['73KID'].to_numpy()[training_indices] - 1

# One embedding row per training trial, kept on the CPU.
subject_stimulus_embeddings = {
    model_name: {
        embedding_name: torch.from_numpy(
            embedding_files[model_name][embedding_name][:][training_stimulus_ids]
        )
        for embedding_name in embedding_names
    }
    for model_name, embedding_names in run_stimulus_embeddings.items()
}

# NOTE(review): disabled preview widget, kept as a string literal. It would not
# run as written: `t1` is already a NumPy array in this cell (`.get_fdata().T`),
# so `t1.get_fdata()` fails. Its index order (axial, saggital, coronal) also
# differs from the one used for `betas_slice`.
'''
T = Y.shape[0]
@interact(t=(0, T-1))
def show(t):
    plt.imshow(t1.get_fdata()[axial_slice, saggital_slice, coronal_slice],  cmap='gray')'''

# Show the T1 image over the same (axial, coronal, saggital) region as
# `betas_slice`, flipped along both displayed axes.
# NOTE(review): the `w`, `h`, `d` sliders feed only the commented-out crop, so
# moving them redraws the same image. Confirm which view is intended.
width=30
@interact(w=(0, W-1), h=(width, H-width-1), d=(width, D-width-1),)
def show(w, h, d):
    #plt.imshow(t1[w, h-width:h+width, d-width:d+width],  cmap='gray')
    plt.imshow(t1[axial_slice, coronal_slice, saggital_slice][::-1, ::-1], cmap='gray')

In [None]:
# Collect unreachable Python objects first (freeing any tensors they hold), then
# release the CUDA caching allocator's unoccupied cached memory.
gc.collect(); torch.cuda.empty_cache()

In [None]:
def to_device(device, *tensors):
    """Lazily yield each of ``tensors`` moved to ``device``.

    Returns a generator, so the result is meant to be unpacked or iterated
    once, e.g. ``a, b = to_device('cuda', a, b)``.
    """
    for tensor in tensors:
        yield tensor.to(device)

X_bigbigan = subject_stimulus_embeddings['bigbigan-resnet50']['z_mean']
X_clip = subject_stimulus_embeddings['ViT-B=32']['embedding']

# Design matrix: the CLIP embedding z-scored per feature (the epsilon guards
# constant features). X_clip itself is left unchanged.
X = X_clip - X_clip.mean(dim=0, keepdims=True)
X = X / (X.std(dim=0, keepdims=True) + 1e-7)

# Free cached GPU memory, fit on the GPU, and bring the results back to the CPU.
gc.collect()
torch.cuda.empty_cache()
fractions = torch.arange(.05, 1.05, .05)
coefs, alpha, r2, frac = to_device(
    'cpu', *frac_ridge_regression_cv(*to_device('cuda', X, Y, fractions))
)

# Disabled reference comparison: per-voxel fits with the `fracridge` package's
# `FracRidgeRegressorCV`, kept as a string literal so it does not run. It
# starts at voxel 912 and expects the `frac`/`alpha`/`r2`/`coefs` lists to
# already exist (their initialisation is commented out).
'''
#frac = []
#alpha = []
#r2 = []
#coefs = []
fractions = np.arange(.05, 1.05, .05)
for i in tqdm(range(912, Y.shape[1])):
    model = FracRidgeRegressorCV(frac_grid=fractions, cv=5)
    model.fit(X.numpy(), Y[:, i].numpy())
    frac.append(model.best_frac_)
    alpha.append(model.alpha_.item())
    r2.append(model.best_score_)
    coefs.append(model.coef_)
frac = np.array(frac)
alpha = np.array(alpha)
r2 = np.array(r2)
coefs = np.stack(coefs)'''

In [None]:
# Best cross-validated r2 among the slice's voxels, from the fractional-ridge CV fit.
r2.max()

In [None]:
# Scatter the per-voxel r2 back into the slice's brain mask and display it,
# flipped along both axes.
plot_slice = np.zeros_like(mask, dtype=float)
plot_slice[mask] = r2
print(plot_slice.max())
plt.imshow(np.flip(plot_slice, axis=(0, 1)), cmap='jet', vmin=0, vmax=plot_slice.max())

In [None]:
# Difference map `r2_reg - r2` on the same slice, drawn with a symmetric diverging colour scale.
# NOTE(review): `r2_reg` is not defined in any visible cell. It is presumably
# the r2 from a plain-ridge fit (`ridge_regression_cv`) run in a deleted cell,
# so this cell fails on a fresh kernel. It also reuses `plot_slice`/`mask` from
# the previous cells.
r2_diff = r2_reg - r2
max_diff = torch.max(r2_diff.abs())

plot_slice[mask] = r2_diff

plt.imshow(plot_slice[::-1, ::-1], vmin=-max_diff, vmax=max_diff, cmap='bwr')

In [None]:
# Strip plot of the first 100 rows of the z-scored BigBiGAN `z_mean` embedding:
# x = feature index, y = value.
X = X_bigbigan[:100]
X = X - X.mean(dim=0, keepdims=True)
X = X / (X.std(dim=0, keepdims=True) + 1e-7)

# Feature index for every entry of X (same layout as X).
y = np.tile(np.arange(X.shape[1]), (X.shape[0], 1))

plt.figure(figsize=(24, 8))
plt.scatter(y.flatten(), X.flatten())

In [None]:
# Shape of the BigBiGAN `z_mean` embedding matrix (one row per training trial).
X_bigbigan.shape

In [None]:
# Collect unreachable Python objects first (freeing any tensors they hold), then
# release the CUDA caching allocator's unoccupied cached memory.
gc.collect(); torch.cuda.empty_cache()

In [None]:
# Torch implementation of ridge

from numbers import Number
from typing import Sequence

def rsquared(Y, Y_pred, dim=0, cast_dtype=torch.float64):
    """Coefficient of determination of ``Y_pred`` against ``Y`` along ``dim``.

    Computes ``R^2 = 1 - SS_res / SS_tot`` independently for every position of
    the other axes; ``Y`` and ``Y_pred`` are broadcast against each other.
    The arithmetic runs in ``cast_dtype`` and the result is cast back to
    ``Y``'s dtype. A constant ``Y`` slice has ``SS_tot == 0`` and yields
    -inf or NaN.
    """
    in_dtype = Y.dtype
    Y = Y.to(cast_dtype)
    Y_pred = Y_pred.to(cast_dtype)

    ss_res = ((Y - Y_pred) ** 2).sum(dim=dim)
    ss_tot = ((Y - Y.mean(dim=dim, keepdim=True)) ** 2).sum(dim=dim)

    r2 = (1 - ss_res / ss_tot).to(in_dtype)
    return r2


def ridge_regression(X, Y, alpha=None,):
    """Solve ``(X^T X + alpha I) w = X^T Y`` with ``torch.linalg.lstsq``.

    With ``alpha=None`` this solves the plain normal equations (OLS).
    ``alpha`` may be a scalar or a tensor that broadcasts against
    ``(..., p, p)``, e.g. ``(A, 1, 1)`` to solve for A penalties at once.
    """
    gram = X.transpose(-2, -1) @ X
    moment = X.transpose(-2, -1) @ Y
    if alpha is not None:
        gram = gram + alpha * torch.eye(gram.shape[-2], device=X.device)
    return torch.linalg.lstsq(gram, moment).solution


def ridge_regression_cv(X, Y, alpha=None, cv=5, tol=1e-6, random_state=None):
    """Choose a ridge penalty per target by K-fold CV, then refit on all samples.

    Args:
        X: design matrix ``(n, p)``.
        Y: targets ``(n, b)``.
        alpha: 1-D tensor of candidate penalties. Defaults to
            ``10 ** linspace(1, 5, 20)`` created on ``X``'s device.
        cv: number of folds.
        tol: unused; kept for signature compatibility.
        random_state: seed forwarded to ``KFold``. ``None`` keeps the previous
            unseeded shuffling.

    Returns:
        ``(coefs, best_alpha, best_r2)`` with ``coefs`` of shape ``(p, b)`` and
        ``best_alpha``/``best_r2`` of shape ``(b,)``. ``best_r2`` is the mean
        validation R^2 of the chosen penalty.
    """
    if alpha is None:
        alpha = 10 ** torch.linspace(1, 5, 20, device=X.device)

    kf = KFold(n_splits=cv, shuffle=True, random_state=random_state)
    r2 = []
    for train_ids, val_ids in kf.split(np.arange(X.shape[-2])):
        X_train = X[..., train_ids, :]
        Y_train = Y[..., train_ids, :]
        X_val = X[..., val_ids, :]
        Y_val = Y[..., val_ids, :]

        # Add a new dimension for alphas (try every alpha vs every target)
        coefs = ridge_regression(X_train[None], Y_train[None], alpha[:, None, None])
        Y_val_pred = X_val[None] @ coefs

        r2.append(rsquared(Y_val[None], Y_val_pred, dim=-2))

    r2 = torch.stack(r2).mean(dim=0)
    best_r2, best_alpha_ids = r2.max(dim=-2)
    best_alpha = alpha[best_alpha_ids]

    # Refit on *all* samples (previously only the last fold's training split
    # was used). There is one system per target with that target's own
    # penalty: X[None] (1, n, p) against Y^T[..., None] (b, n, 1).
    best_coefs = ridge_regression(X[None], Y.transpose(-2, -1)[..., None], best_alpha[:, None, None])
    best_coefs = best_coefs[..., 0].transpose(-1, -2)
    return best_coefs, best_alpha, best_r2

# Disabled smoke test/benchmark of `ridge_regression_cv` on random GPU data.
# It is kept as a string literal so it does not run, and it would `break`
# after the first batch anyway.
'''
N = int(10000)
batch_size = 100
embedding_size = 512
t = time.time()
for i in tqdm(range(N // batch_size)):
    X = torch.randn(21750, embedding_size).cuda()
    Y = torch.randn(21750, batch_size).cuda()
    alpha = 10 ** torch.linspace(1, 5, 20).cuda()
    coefs, alpha, r2 = ridge_regression_cv(X, Y, alpha=alpha)
    print(f'{coefs.shape=}, {alpha.shape=}, {r2.shape=}')
    break'''

In [None]:
# Collect unreachable Python objects first (freeing any tensors they hold), then
# release the CUDA caching allocator's unoccupied cached memory.
gc.collect(); torch.cuda.empty_cache()

In [None]:
# NOTE(review): `solution` is not defined in any visible cell. It is presumably
# left over from interactive debugging and fails on a fresh kernel.
solution.shape