In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
import time
import random
import json
import gc

import numpy as np
import pandas as pd
import torch
import torchio as tio
import h5py
from ipywidgets import interact
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import nibabel as nib
from einops import rearrange
from scipy import ndimage
from fracridge import FracRidgeRegressorCV


dir2 = os.path.abspath('../..')
dir1 = os.path.dirname(dir2)
if not dir1 in sys.path:
    sys.path.append(dir1)
    
from research.models.regression_torch import (
    pearsonr, 
    rsquared,
    frac_ridge_regression,
    frac_ridge_regression_cv,
    ridge_regression,
    ridge_regression_cv,
)

from research.data.natural_scenes import (
    NaturalScenesDataset
)

C:\Users\Cefir\anaconda3\envs\Neurophysiological-Data-Decoding\lib\site-packages\numpy\.libs\libopenblas.EL2C6PLE4ZYW3ECEVIV3OXXGRN2NRFM2.gfortran-win_amd64.dll
C:\Users\Cefir\anaconda3\envs\Neurophysiological-Data-Decoding\lib\site-packages\numpy\.libs\libopenblas.GK7GX5KEQ4F6UYO3P26ULGBQYHGQO7J4.gfortran-win_amd64.dll
C:\Users\Cefir\anaconda3\envs\Neurophysiological-Data-Decoding\lib\site-packages\numpy\.libs\libopenblas.WCDJNK7YVMPZQ2ME2ZZHJJRJ3JIKNDB7.gfortran-win_amd64.dll
C:\Users\Cefir\anaconda3\envs\Neurophysiological-Data-Decoding\lib\site-packages\numpy\.libs\libopenblas.xwydx2ikjw2nmtwsfyngfuwkqu3lytcz.gfortran-win_amd64.dll


In [None]:
dataset_path = Path('D:\\Datasets\\NSD\\')
dataset = NaturalScenesDataset(dataset_path)

derivatives_path = dataset_path / 'derivatives'
betas_path = dataset_path / 'nsddata_betas' / 'ppdata'
ppdata_path = dataset_path / 'nsddata' / 'ppdata'
subjects = dataset.subjects

In [None]:
subjects = {f'subj0{i}': {} for i in range(1, 9)}

for subject_name, subject_data in subjects.items():
    responses_file_path = ppdata_path / subject_name / 'behav' / 'responses.tsv'
    subject_data['responses'] = pd.read_csv(responses_file_path, sep='\t',)
    
    # The last 3 sessions are currently held-out for the algonauts challenge
    # remove them for now.
    session_ids = subject_data['responses']['SESSION']
    held_out_mask = session_ids > (np.max(session_ids) - 3)
    subject_data['responses'] = subject_data['responses'][~held_out_mask]
    
    subject_betas_path = derivatives_path / subject_name / 'func1pt8mm' / 'betas_fithrf_GLMdenoise_RR'
    num_sessions = np.max(subject_data['responses']['SESSION'])
    
    #subject_data['sessions'] = [
    #    h5py.File(subject_betas_path / f'betas_session{i:02}.hdf5', 'r')
    #    for i in range(1, num_sessions + 1)
    #]
    
    subject_data['betas'] = h5py.File(subject_betas_path / f'betas_sessions.hdf5', 'r')
    
    subject_data['brainmask'] = nib.load(ppdata_path / subject_name / 'func1pt8mm' / 'brainmask.nii.gz')
    subject_data['t1_path'] = ppdata_path / subject_name / 'func1pt8mm' / 'T1_to_func1pt8mm.nii.gz'

In [None]:
# run an encoder

split_name = 'split-01'
split = h5py.File(derivatives_path / 'data_splits' / f'{split_name}.hdf5')
fractions = torch.arange(.05, 1.05, .05).cuda()
#alpha = 10 ** torch.linspace(1, 5, 20).cuda()

run_stimulus_embeddings = {
    #'bigbigan-resnet50': ['z_mean'],
    #'ViT-B=32': [f'transformer.resblocks.{i}' for i in range(12)],
    #'ViT-B=32': ['embedding', *(f'transformer.resblocks.{i}' for i in range(12))],
    'DPT_Large': ['depth-pyramid-24'],
    #'biggan-128': ['z', 'y_embedding'],
    #'vqgan': ['vqgan-f16-1024-pre_quant'],
}

embedding_files = {
    model_name: h5py.File(derivatives_path / 'stimulus_embeddings' / f'{model_name}.hdf5', 'r')
    for model_name in run_stimulus_embeddings.keys()
}

seed = 0
max_features = 1e9 # 512
normalize_embeddings = True

with h5py.File(derivatives_path / 'fracridge-parameters.hdf5', 'a') as f:
    for subject_name, subject_data in list(subjects.items())[:1]:
        sessions = subject_data['sessions']
        num_sessions = len(sessions)
        shape = sessions[0]['betas'].shape
        T, W, H, D = shape
        
        subject_split = split[subject_name]
        test_mask = subject_split['test_response_mask'][:].astype(bool)
        validation_mask = subject_split['validation_response_mask'][:].astype(bool)
        training_mask = ~(test_mask | validation_mask)
        training_indices = np.where(training_mask)[0]

        responses = subject_data['responses']
        training_stimulus_ids = responses['73KID'].to_numpy()[training_indices] - 1
        
        subject_stimulus_embeddings = {}
        for model_name, embedding_names in run_stimulus_embeddings.items():
            subject_stimulus_embeddings[model_name] = {}
            for embedding_name in embedding_names:
                keys = (subject_name, model_name, embedding_name)
                print('loading', keys)
                
                embedding = embedding_files[model_name][embedding_name][:][training_stimulus_ids]
                embedding = torch.from_numpy(embedding)
                embedding = embedding.flatten(start_dim=1)
                if embedding.shape[1] > max_features:
                    np.random.seed(seed)
                    choice = np.random.choice(max_features, size=max_features)
                    embedding = embedding[:, choice]
                embedding = embedding.float().cuda()
                embedding_mean = embedding.mean(dim=0, keepdims=True)
                embedding_std = embedding.std(dim=0, keepdims=True)
                if normalize_embeddings:
                    embedding = (embedding - embedding_mean) / embedding_std
                subject_stimulus_embeddings[model_name][embedding_name] = embedding.float().cuda()
                
                key = '/'.join(keys)
                E = embedding.shape[-1]
                group = f.require_group(key)
                group.require_dataset('coefs', shape=(W, H, D, E), dtype='f4')
                group.require_dataset('alpha', shape=(W, H, D), dtype='f4')
                group.require_dataset('r2', shape=(W, H, D), dtype='f4')
                group.require_dataset('fractions', shape=(W, H, D), dtype='f4')
                
                group.require_dataset('embedding_mean', embedding_mean.shape, 'f4')
                group.require_dataset('embedding_std', embedding_std.shape, 'f4')
                group['embedding_mean'][:] = embedding_mean.cpu().numpy()
                group['embedding_std'][:] = embedding_std.cpu().numpy()

        mask = (subjects[subject_name]['brainmask'].get_fdata() > 0.).T
        
        load_time = 0
        compute_time = 0
        store_time = 0
        
        for i in tqdm(range(32, W)):
            if (i + 1) % 5 == 0:
                print(f'{load_time=:03}s, {compute_time=:03}s, {store_time=:03}s')
    
            mask_slice = mask[i]
            if mask_slice.sum() == 0:
                continue
            
            t = time.time()
            Y = np.concatenate([
                session['betas'][:, i] 
                for session in sessions
            ])
            Y = Y[training_indices]
            Y = Y[:, mask_slice]
            Y = torch.from_numpy(Y).float() / 300
            Y = (Y - Y.mean(dim=0, keepdims=True)) / Y.std(dim=0, keepdims=True)
            #Y = Y.T[..., None]
            
            slice_indices = torch.from_numpy(np.argwhere(mask_slice))
            load_time += time.time() - t
            
            batch_size = 100
            Y_splits = torch.split(Y, batch_size, dim=1)
            indices_splits = torch.split(slice_indices, batch_size)
            for Y_batch, indices_batch in zip(Y_splits, indices_splits):
                Y_batch = Y_batch.cuda()
                for model_name, embeddings in subject_stimulus_embeddings.items():
                    for embedding_name, embedding in embeddings.items():
                        t = time.time()
                        
                        gc.collect()
                        torch.cuda.empty_cache()

                        keys = (subject_name, model_name, embedding_name)
                        group = f['/'.join(keys)]
                        X = embedding[:]
                        
                        #print(X.shape, Y_batch.shape, fractions.shape)
                        #coefs, alpha, r2 = ridge_regression_cv(X, Y_batch[..., 0].T, alpha=alpha)
                        coefs, alpha, r2, frac = frac_ridge_regression_cv(X, Y_batch, fractions)
                        coefs = coefs.cpu().numpy()
                        a = alpha.cpu().numpy()
                        r2 = r2.cpu().numpy()
                        frac = frac.cpu().numpy()
                        
                        compute_time += time.time() - t

                        t = time.time()
                        for v, (j, k) in enumerate(indices_batch):
                            group['coefs'][i, j, k] = coefs[:, v]
                            group['alpha'][i, j, k] = a[v]
                            group['r2'][i, j, k] = r2[v]
                            group['fractions'][i, j, k] = frac[v]
                        store_time += time.time() - t


In [None]:
# save encoder results

encoder_name = 'fracridge'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r') as f:
    for subject_name, subject in f.items():
        subject_out_path = derivatives_path / 'images' / encoder_name / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)
        affine = subjects[subject_name]['brainmask'].affine
        
        for model_name, model in subject.items():
            for embedding_name, embedding in model.items():
                for image_name, image, in embedding.items():
                    keys = (subject_name, encoder_name, model_name, embedding_name, image_name)
                    save_file_name = f'{"__".join(keys)}.nii.gz'
                    print(keys, image.shape)
                    if len(image.shape) == 3:
                        image = nib.Nifti1Image(image[:].T, affine)
                    elif len(image.shape) == 4:
                        continue
                        image = nib.Nifti1Image(image[:].T, affine)
                    else:
                        continue
                    nib.save(image, subject_out_path / save_file_name)

In [None]:
# Save sequence

encoder_name = 'fracridge'
model_name = 'ViT-B=32'
embedding_names = [*(f'transformer.resblocks.{i}' for i in range(12)), 'embedding']
image_names = ['alpha', 'fractions', 'r2']
sequence_name = 'depth_sequence'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r') as f:
    for subject_name, subject in f.items():
        subject_out_path = derivatives_path / 'images' / encoder_name / subject_name
        subject_out_path.mkdir(parents=True, exist_ok=True)
        affine = subjects[subject_name]['brainmask'].affine
        
        model = subject[model_name]
        for image_name in image_names:
            image_data = np.stack([
                model[embedding_name][image_name][:]
                for embedding_name in embedding_names
            ]).T
            keys = (subject_name, encoder_name, model_name, sequence_name, image_name)
            save_file_name = f'{"__".join(keys)}.nii.gz'
            image = nib.Nifti1Image(image_data, affine)
            nib.save(image, subject_out_path / save_file_name)

In [None]:
# Add feature selection indices

def require_dataset(group, name, data):
    group.require_dataset(name, shape=data.shape, dtype=data.dtype)
    group[name][:] = data

encoder_name = 'fracridge'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'a') as f:
    for subject_name, subject in f.items():
        for model_name, model in subject.items():
            for embedding_name, embedding in model.items():
                r2 = embedding['r2'][:]
                sorted_indices_flat = r2.argsort(axis=None)[::-1]
                grid = np.argwhere(np.ones_like(r2, dtype=bool))
                sorted_indices = grid[sorted_indices_flat]
                
                require_dataset(embedding, 'sorted_indices_flat', sorted_indices_flat)
                require_dataset(embedding, 'sorted_indices', sorted_indices)

In [None]:
encoder_name = 'fracridge'
encoder_results = h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r')

@interact(subject=encoder_results.items())
def select_subject(subject):
    subject_name = subject.name[1:]
    #sessions = subjects[subject_name]['sessions']
    betas = subjects[subject_name]['betas']
    
    @interact(model=subject.items())
    def select_model(model):
        
        @interact(embedding=model.items())
        def select_embedding(embedding):
            
            r2 = embedding['r2'][:]
            sorted_indices_flat = r2.argsort(axis=None)[::-1]
            grid = np.argwhere(np.ones_like(r2, dtype=bool))
            sorted_indices = grid[sorted_indices_flat]
            print(grid)
            t = time.time()
            
            Y = np.stack([
                betas[:, i] 
                for i in sorted_indices_flat[:2500]
            ], axis=1)
            
            #for i, j, k in sorted_indices[:2500]:
            #    Y = np.concatenate([
            #        session['betas'][:, i, j, k]
            #        for session in sessions
            #    ])
            print(time.time() - t)
            print(Y.shape)
        
            @interact(w=(0, r2.shape[0]-1), num=(0, 50000))
            def show_top(w, num):
                selection_map = np.zeros_like(r2)
                i, j, k = list(sorted_indices[:num].T)
                selection_map[i, j, k] = 1
                plt.imshow(selection_map[w].T)
                

In [None]:
# concatenate sessions

for subject_name, subject_data in subjects.items():
    #if subject_name in ('subj01', 'subj02', 'subj03', 'subj04', ):
    #    continue
    print(subject_name)
    path = derivatives_path / 'betas' / subject_name / 'func1pt8mm' / 'betas_fithrf_GLMdenoise_RR'
    path.mkdir(parents=True, exist_ok=True)
    with h5py.File(path / 'betas_sessions_new.hdf5', 'a') as f:

        sessions = subject_data['sessions']
        num_sessions = len(sessions)
        shape = sessions[0]['betas'].shape
        T, W, H, D = shape
        T_full = T * len(sessions)
        
        f.require_dataset('betas', shape=(T_full, W * H * D), dtype=np.int16, chunks=(T_full, 1))
        for i in tqdm(range(W)):
            Y = np.concatenate([
                session['betas'][:, i]
                for session in sessions
            ])
            slice_size = H * D
            f['betas'][:, slice_size * i:slice_size * (i + 1)] = rearrange(Y, 't ... -> t (...)')
    break
        

In [25]:
rearrange(np.ones((500, 500)), '(s b) v -> s b v', s=100).shape

(100, 5, 500)

In [32]:
# Compute and cache mean and std of betas for each session

dataset_path = Path('D:\\Datasets\\NSD\\')
space = 'func1pt8mm'
glm = 'betas_fithrf_GLMdenoise_RR'

betas_scale = 300
betas_per_session = 750

for i in range(1, 9):
    subject_name = f'subj0{i}'
    subject_betas_path = dataset_path / 'derivatives' / 'betas' / subject_name / space / glm
    betas_file_path = subject_betas_path / 'betas_sessions.hdf5'
    
    original_session_path = dataset_path / 'nsddata_betas' / 'ppdata' / subject_name / space / glm
    original_session = h5py.File(original_session_path / 'betas_session01.hdf5', 'r')
    
    with h5py.File(betas_file_path, 'a') as f:
        print(num_voxels)
        num_betas = f['betas'].shape[0]
        num_voxels = f['betas'].shape[1]
        
        num_sessions = int(num_betas / betas_per_session)
        spatial_shape = original_session['betas'].shape[1:]
        
        f['betas'].attrs['spatial_shape'] = spatial_shape
        f.require_dataset('mean', shape=(num_sessions, num_voxels), dtype=np.float32)
        f.require_dataset('std', shape=(num_sessions, num_voxels), dtype=np.float32)
        indices = np.argwhere(np.ones(shape=spatial_shape, dtype=bool)).astype(int)
        f.require_dataset('indices', shape=indices.shape, dtype=indices.dtype)
        f['indices'][:] = indices
        
        num_batches = 100
        for voxel_indices in tqdm(np.array_split(np.arange(num_voxels), num_batches)):
            betas = f['betas'][:, voxel_indices]
            betas = rearrange(betas, '(s b) v -> s b v', s=num_sessions, b=betas_per_session)
            betas = betas.astype(float) / betas_scale
            f['mean'][:, voxel_indices] = betas.mean(axis=1)
            f['std'][:, voxel_indices] = betas.std(axis=1)

673200


  0%|          | 0/100 [00:00<?, ?it/s]

699192


  0%|          | 0/100 [00:00<?, ?it/s]

730128


  0%|          | 0/100 [00:00<?, ?it/s]

704052


  0%|          | 0/100 [00:00<?, ?it/s]

673200


  0%|          | 0/100 [00:00<?, ?it/s]

597714


  0%|          | 0/100 [00:00<?, ?it/s]

797215


  0%|          | 0/100 [00:00<?, ?it/s]

600210


  0%|          | 0/100 [00:00<?, ?it/s]

In [33]:
# save noise ceilings for voxel selection

subjects

NameError: name 'subjects' is not defined

In [None]:
# Save feature selection

encoder_name = 'fracridge'

with h5py.File(derivatives_path / f'{encoder_name}-parameters.hdf5', 'r') as f:

In [None]:
# Encoder experiments on a slice

subject_name = 'subj01'

subject = subjects[subject_name]
sessions = subject['sessions']
t1 = nib.load(subject['t1_path']).get_fdata().T

split_name = 'split-01'
split = h5py.File(derivatives_path / 'data_splits' / f'{split_name}.hdf5')

subject_split = split[subject_name]
test_mask = subject_split['test_response_mask'][:].astype(bool)
validation_mask = subject_split['validation_response_mask'][:].astype(bool)
training_mask = ~(test_mask | validation_mask)
training_indices = np.where(training_mask)[0]

run_stimulus_embeddings = {
    'bigbigan-resnet50': ['z_mean'],
    'ViT-B=32': ['embedding', *(f'transformer.resblocks.{i}' for i in range(12))],
    #'biggan-128': ['z', 'y_embedding'],
    #'vqgan': ['vqgan-f16-1024-pre_quant'],
}

embedding_files = {
    model_name: h5py.File(derivatives_path / 'stimulus_embeddings' / f'{model_name}-embeddings.hdf5', 'r')
    for model_name in run_stimulus_embeddings.keys()
}

W, H, D = t1.shape
axial_slice = 37
saggital_slice = slice(None)
coronal_slice = slice(-80)

# Extract an interesting slice from the visual cortex
# (time, axial, coronal, saggital)
betas_slice = np.concatenate([
    session['betas'][:, axial_slice, coronal_slice, saggital_slice]
    for session in sessions
])[training_indices]

betas_slice = torch.from_numpy(betas_slice).float() / 300
betas_slice = betas_slice - betas_slice.mean(dim=0, keepdims=True)
betas_slice = betas_slice / (betas_slice.std(dim=0, keepdims=True) + 1e-7)

mask = (subject['brainmask'].get_fdata() > 0.).T
mask = mask[axial_slice, coronal_slice, saggital_slice]

Y = betas_slice[:, mask]

responses = subject['responses']
training_stimulus_ids = responses['73KID'].to_numpy()[training_indices] - 1

max_features = 512
seed = 0

subject_stimulus_embeddings = {}
for model_name, embedding_names in run_stimulus_embeddings.items():
    subject_stimulus_embeddings[model_name] = {}
    for embedding_name in embedding_names:
        keys = (subject_name, model_name, embedding_name)
        print('loading', keys)

        embedding = embedding_files[model_name][embedding_name][:][training_stimulus_ids]
        embedding = torch.from_numpy(embedding)
        embedding = embedding.flatten(start_dim=1)
        if embedding.shape[1] > max_features:
            np.random.seed(seed)
            choice = np.random.choice(max_features, size=max_features)
            embedding = embedding[:, choice]
        embedding = embedding.float()
        subject_stimulus_embeddings[model_name][embedding_name] = embedding.float().cuda()

'''
T = Y.shape[0]
@interact(t=(0, T-1))
def show(t):
    plt.imshow(t1.get_fdata()[axial_slice, saggital_slice, coronal_slice],  cmap='gray')'''

width=30
@interact(w=(0, W-1), h=(width, H-width-1), d=(width, D-width-1),)
def show(w, h, d):
    #plt.imshow(t1[w, h-width:h+width, d-width:d+width],  cmap='gray')
    plt.imshow(t1[axial_slice, coronal_slice, saggital_slice][::-1, ::-1], cmap='gray')

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
X.shape, Y.shape

In [None]:
def to_device(device, *tensors):
    return [tensor.to(device) for tensor in tensors]

X_bigbigan = subject_stimulus_embeddings['bigbigan-resnet50']['z_mean']
X_clip = subject_stimulus_embeddings['ViT-B=32']['transformer.resblocks.3']

X = X_clip.clone()
X = X - X.mean(dim=0, keepdims=True)
X = X / (X.std(dim=0, keepdims=True) + 1e-7)


gc.collect()
torch.cuda.empty_cache()

'''
alpha = 10 ** torch.linspace(1, 5, 20)
results = [
    to_device('cpu', *ridge_regression_cv(*to_device('cuda', X, y, alpha)))
    for y in tqdm(torch.split(Y, 100, dim=1))
]

coefs, alpha, r2 = [
    torch.cat([elem for elem in val], dim=-1)
    for val in list(zip(*results))
]'''


fractions = torch.arange(.05, 1.05, .05)
results = [
    to_device('cpu', *frac_ridge_regression_cv(*to_device('cuda', X, y, fractions)))
    for y in tqdm(torch.split(Y, 100, dim=1))
]

#results = [
#    frac_ridge_regression_cv2(X.cpu(), y.cpu(), fractions)
#    for y in tqdm(torch.split(Y, 100, dim=1))
#]

coefs, alpha, r2, frac = [
    torch.cat([elem for elem in val], dim=-1)
    for val in list(zip(*results))
]

'''
frac = []
alpha = []
r2 = []
coefs = []
fractions = np.arange(.05, 1.05, .05)
for i in tqdm(range(Y.shape[1])):
    model = FracRidgeRegressorCV(frac_grid=fractions, cv=5)
    model.fit(X.cpu().numpy(), Y[:, i].cpu().numpy())
    frac.append(model.best_frac_)
    alpha.append(model.alpha_.item())
    r2.append(model.best_score_)
    coefs.append(model.coef_)
frac = np.array(frac)
alpha = np.array(alpha)
r2 = np.array(r2)
coefs = np.stack(coefs)'''

In [None]:
[coefs[:, i].isnan().sum() for i in range(coefs.shape[1])]

In [None]:
plot_slice = np.zeros_like(mask, dtype=float)
plot_slice[mask] = frac

print(plot_slice.max())
plt.imshow(plot_slice[::-1, ::-1], vmin=0, vmax=plot_slice.max(), cmap='jet')

In [None]:
r2.max()

In [None]:
t = time.time()
coefs, alpha = frac_ridge_regression(*to_device('cuda', X, Y[:, 0:1], fractions))
print(time.time() - t)

In [None]:
#from fracridge import fracridge

t = time.time()
coefs, alphas = fracridge(X.cpu().numpy(), Y.cpu().numpy(), fractions.cpu().numpy())
print(time.time() - t)


In [None]:
from fracridge.fracridge import _do_svd
BIG_BIAS = 10e3
SMALL_BIAS = 10e-3
BIAS_STEP = 0.2
from numpy import interp
from sklearn.model_selection import KFold
from einops import rearrange
    

def fracridge(X, y, fracs=None, tol=1e-10, jit=True):
    """
    Approximates alpha parameters to match desired fractions of OLS length.
    Parameters
    ----------
    X : ndarray, shape (n, p)
        Design matrix for regression, with n number of
        observations and p number of model parameters.
    y : ndarray, shape (n, b)
        Data, with n number of observations and b number of targets.
    fracs : float or 1d array, optional
        The desired fractions of the parameter vector length, relative to
        OLS solution. If 1d array, the shape is (f,). This input is required
        to be sorted. Otherwise, raises ValueError.
        Default: np.arange(.1, 1.1, .1).
    jit : bool, optional
        Whether to speed up computations by using a just-in-time compiled
        version of core computations. This may not work well with very large
        datasets. Default: True
    Returns
    -------
    coef : ndarray, shape (p, f, b)
        The full estimated parameters across units of measurement for every
        desired fraction.
    alphas : ndarray, shape (f, b)
        The alpha coefficients associated with each solution
    Examples
    --------
    Generate random data:
    >>> np.random.seed(0)
    >>> y = np.random.randn(100)
    >>> X = np.random.randn(100, 10)
    Calculate coefficients with naive OLS:
    >>> coef = np.linalg.inv(X.T @ X) @ X.T @ y
    >>> print(np.linalg.norm(coef))  # doctest: +NUMBER
    0.35
    Call fracridge function:
    >>> coef2, alpha = fracridge(X, y, 0.3)
    >>> print(np.linalg.norm(coef2))  # doctest: +NUMBER
    0.10
    >>> print(np.linalg.norm(coef2) / np.linalg.norm(coef))  # doctest: +NUMBER
    0.3
    Calculate coefficients with naive RR:
    >>> alphaI = alpha * np.eye(X.shape[1])
    >>> coef3 = np.linalg.inv(X.T @ X + alphaI) @ X.T @ y
    >>> print(np.linalg.norm(coef2 - coef3))  # doctest: +NUMBER
    0.0
    """
    if fracs is None:
        fracs = np.arange(.1, 1.1, .1)

    if hasattr(fracs, "__len__"):
        if np.any(np.diff(fracs) < 0):
            raise ValueError("The `frac` inputs to the `fracridge` function ",
                             f"must be sorted. You provided: {fracs}")

    else:
        fracs = [fracs]
    fracs = np.array(fracs)

    nn, pp = X.shape
    if len(y.shape) == 1:
        y = y[:, np.newaxis]

    bb = y.shape[-1]
    ff = fracs.shape[0]
    print(y.dtype)

    # Calculate the rotation of the data
    selt, v_t, ols_coef = _do_svd(X, y, jit=jit)
    print(f'{selt.dtype=}, {v_t.dtype=}, {ols_coef.dtype=}')

    # Set solutions for small eigenvalues to 0 for all targets:
    isbad = selt < tol
    if np.any(isbad):
        warnings.warn("Some eigenvalues are being treated as 0")

    ols_coef[isbad, ...] = 0

    # Limits on the grid of candidate alphas used for interpolation:
    val1 = BIG_BIAS * selt[0] ** 2
    val2 = SMALL_BIAS * selt[-1] ** 2

    # Generates the grid of candidate alphas used in interpolation:
    alphagrid = np.concatenate(
        [np.array([0]),
         10 ** np.arange(np.floor(np.log10(val2)),
                         np.ceil(np.log10(val1)), BIAS_STEP)])
    print(f'{alphagrid.dtype=}')

    # The scaling factor applied to coefficients in the rotated space is
    # lambda**2 / (lambda**2 + alpha), where lambda are the singular values
    seltsq = selt**2
    sclg = seltsq / (seltsq + alphagrid[:, None])
    sclg_sq = sclg**2
    print(f'{sclg_sq.dtype=} {sclg.dtype=}')

    # Prellocate the solution:
    if nn >= pp:
        first_dim = pp
    else:
        first_dim = nn

    coef = np.empty((first_dim, ff, bb))
    alphas = np.empty((ff, bb))
    print(f'{coef.dtype=} {alphas.dtype=}')

    # The main loop is over targets:
    for ii in range(y.shape[-1]):
        # Applies the scaling factors per alpha
        newlen = np.sqrt(sclg_sq @ ols_coef[..., ii]**2).T
        # Normalize to the length of the unregularized solution,
        # because (alphagrid[0] == 0)
        newlen = (newlen / newlen[0])
        # Perform interpolation in a log transformed space (so it behaves
        # nicely), avoiding log of 0.
        temp = interp(fracs, newlen[::-1], np.log(1 + alphagrid)[::-1])
        print(f'{temp.dtype=} {newlen.dtype=}')
        # Undo the log transform from the previous step
        targetalphas = np.exp(temp) - 1
        # Allocate the alphas for this target:
        alphas[:, ii] = targetalphas
        # Calculate the new scaling factor, based on the interpolated alphas:
        sc = seltsq / (seltsq + targetalphas[np.newaxis].T)
        # Use the scaling factor to calculate coefficients in the rotated
        # space:
        coef[..., ii] = (sc * ols_coef[..., ii]).T

    # After iterating over all targets, we unrotate using the unitary v
    # matrix and reshape to conform to desired output:
    coef = np.reshape(v_t.T @ coef.reshape((first_dim, ff * bb)),
                      (pp, ff, bb))

    return coef.squeeze(), alphas

def frac_ridge_regression_cv2(X, Y, fractions=None, cv=5, tol=1e-6):
    if fractions is None:
        fractions = torch.arange(.1, 1.1, .1)

    kf = KFold(n_splits=cv, shuffle=True)
    r2 = []
    for train_ids, val_ids in kf.split(np.arange(X.shape[-2])):
        X_train = X[..., train_ids, :]
        Y_train = Y[..., train_ids, :]
        X_val = X[..., val_ids, :]
        Y_val = Y[..., val_ids, :]
        
        coef, alpha = fracridge(X_train.numpy(), Y_train.numpy(), fractions.numpy())
        coef = torch.from_numpy(coef).float()
        alpha = torch.from_numpy(alpha).float()
        #print(X_val.shape, coef.shape)
        
        Y_val_pred = torch.einsum('... n p, ... p f b -> ... f n b', X_val, coef)
        r2.append(rsquared(Y_val[..., None, :, :], Y_val_pred, dim=-2))
    r2 = torch.stack(r2).mean(dim=0)
    best_r2, best_fraction_ids = r2.max(dim=-2)
    best_fractions = fractions[best_fraction_ids]

    #best_coefs, best_alpha = fracridge(X.numpy(), Y.numpy(), best_fractions[None].numpy())
    return coef, alpha, best_r2, best_fractions

In [None]:
r2_diff = r2_reg - r2
max_diff = torch.max(r2_diff.abs())

plot_slice[mask] = r2_diff

plt.imshow(plot_slice[::-1, ::-1], vmin=-max_diff, vmax=max_diff, cmap='bwr')

In [None]:
X = X_clip[:100]
X = X - X.mean(dim=0, keepdims=True)
X = X / (X.std(dim=0, keepdims=True) + 1e-7)

y = np.stack([np.arange(X.shape[1]) for _ in range(X.shape[0])])

plt.figure(figsize=(24, 8))
plt.scatter(y.flatten(), X.flatten().cpu())

In [None]:
X_bigbigan.shape

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Torch implementation of ridge


'''
N = int(10000)
batch_size = 100
embedding_size = 512
t = time.time()
for i in tqdm(range(N // batch_size)):
    X = torch.randn(21750, embedding_size).cuda()
    Y = torch.randn(21750, batch_size).cuda()
    alpha = 10 ** torch.linspace(1, 5, 20).cuda()
    coefs, alpha, r2 = ridge_regression_cv(X, Y, alpha=alpha)
    print(f'{coefs.shape=}, {alpha.shape=}, {r2.shape=}')
    break'''

In [None]:
gc.collect()
torch.cuda.empty_cache()

In [None]:
solution.shape