In [45]:
import os
import numpy as np
import h5py  # NOTE(review): only referenced by a commented-out loader in the data cell
import torch
import torch.nn as nn  # NOTE(review): not used by any visible cell
import scipy.io
from scipy.io import savemat
# Move into the NeuralMAE repo root before importing the local `neuralmae` package,
# presumably so it resolves relative to the working directory. Keep this line ABOVE the
# neuralmae imports. NOTE(review): absolute, user-specific path.
os.chdir('/home/cyberspace007/mpicek/NeuralMAE')
# Model factories for the multimodal checkpoints (used when `multimodal` is True).
import neuralmae.neural_models.models_multimodal_neuralmae_up2 as models_mae_multimodal
# Model factories for the vanilla BrainGPT checkpoints (used when `multimodal` is False).
import neuralmae.neural_models.models_neuralmae_bsi as models_mae


### Load dataset

In [46]:
# --- Configuration -----------------------------------------------------------
# Recording folder: features are read from all.mat and the latents are written back here.
folder = '/media/cyberspace007/T7/martin/dt5/results/day_9/rec2/'
input_path = os.path.join(folder, 'all.mat')
# NOTE(review): the recorded output of the save cell reports '..._30.mat', which does not
# match this path, so the saved outputs are stale. Restart & run all before trusting them.
output_path = os.path.join(folder, 'DINO_SSL_small_weight_10.mat')
# True  -> encode with the `brainGPT` part of the multimodal checkpoint;
# False -> encode with the vanilla BrainGPT checkpoint.
multimodal = True

# --- Load --------------------------------------------------------------------
# The recomputed files are readable with scipy.io.loadmat and store the features under 'x'.
# Axes are reordered (0, 3, 2, 1) so samples come first.
# FIXME: other exports were read with h5py instead (presumably v7.3 / HDF5 .mat files):
#   all_data = h5py.File(input_path)['X']
all_data = scipy.io.loadmat(input_path)['x'].transpose(0, 3, 2, 1)

print(f"dataset shape: {all_data.shape}")
# expected shape: (n_samples, 32, 24, 10) .. n_samples = 10*seconds

dataset shape: (10962, 32, 24, 10)


### Load model

In [47]:

def prepare_pretrained_brainGPT(chkpt_dir, arch='mae_neut_base_patch245_1implant'):
    """Build a vanilla BrainGPT model and load pretrained weights into it.

    Args:
        chkpt_dir: path to a checkpoint *file* (despite the name). Its ``'model'`` entry
            holds the state dict.
        arch: name of a model factory function exposed by ``models_mae``.

    Returns:
        The constructed model with the checkpoint weights loaded non-strictly.
    """
    model_factory = getattr(models_mae, arch)
    net = model_factory()

    state = torch.load(chkpt_dir, map_location='cpu')
    # strict=False tolerates missing/unexpected keys. The load report is printed so that
    # any mismatch is still visible in the cell output.
    load_report = net.load_state_dict(state['model'], strict=False)
    print(load_report)
    return net

def prepare_pretrained_brainGPT_multimodal(chkpt_dir):
    """Build the multimodal BrainGPT (DINO variant) and load a fine-tuned checkpoint.

    The architecture is ``mae_neut_conf_tiny_multimodal_mlp_delta_DINO`` from
    ``models_mae_multimodal``, built with the fixed hyper-parameters below. The state dict
    is loaded strictly, so any key/shape mismatch with the checkpoint raises.

    Args:
        chkpt_dir: path to the checkpoint file. Its ``'model'`` entry is the state dict.

    Returns:
        The model with the checkpoint weights loaded.
    """
    # Alternative architecture previously used for the accelerometer checkpoints:
    #   'mae_neut_conf_tiny_multimodal_mlp_accelerometer'
    arch_factory = vars(models_mae_multimodal)['mae_neut_conf_tiny_multimodal_mlp_delta_DINO']
    hyperparams = {
        'norm_pix_loss': False,
        'norm_session_loss': True,
        'uniformity_loss': False,
        'lamb': 0.01,
        # input_size / patch_size are left at the factory defaults.
        'use_projector': False,
        'projector_dim': 64,
        'freeze_brainGPT': False,
    }
    net = arch_factory(**hyperparams)

    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    net.load_state_dict(checkpoint['model'])
    return net


In [48]:

if not multimodal:
    # Vanilla BrainGPT checkpoint.
    chkpt_dir = '/home/cyberspace007/mpicek/NeuralMAE/pretrained_brainGPT/checkpoint-14_up2001.pth'
    model_mae = prepare_pretrained_brainGPT(chkpt_dir, 'mae_neut_base_patch245_1implant')
    print('Vanilla BrainGPT model loaded.')
else:
    # Multimodal BrainGPT. Per the folder name, the DINO-prediction run with BrainGPT not frozen.
    # Earlier accelerometer variant:
    #   /media/cyberspace007/T7/tmp/training_logs/neuralmae/checkpoints/MLP_predict_ACC_NOT_frozen_BrainGPT/checkpoint-10.pth
    chkpt_dir = "/media/cyberspace007/T7/tmp/training_logs/neuralmae/checkpoints/MLP_predict_DINO_NOT_frozen_BrainGPT_25mask_small_weight/checkpoint-10.pth"
    model_mae = prepare_pretrained_brainGPT_multimodal(chkpt_dir)
    print('Multimodal BrainGPT model loaded.')


Number of parameters of the BrainGPT: 110.970265 M
Number of parameters of the DINO MLP Projector: 6.039552 M
Multimodal BrainGPT model loaded.


### Extract latents

In [49]:
# extract latents
# One forward pass per sample through the encoder with masking disabled (mask_ratio=0).
# The per-sample latents are stacked into one (n_samples, embed_dim) numpy array.
from tqdm.auto import tqdm

# Prefer the first GPU; fall back to CPU so the cell still runs without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model_mae = model_mae.to(device)
# Mixed precision only on CUDA (CPU autocast would switch to bfloat16 and change numerics).
use_amp = device.type == 'cuda'
# Multimodal checkpoints wrap the backbone: encode with its `brainGPT` part only.
encoder = model_mae.brainGPT if multimodal else model_mae

model_mae.eval()
latents = []
with torch.no_grad():
    for i in tqdm(range(all_data.shape[0]), desc='extracting latents'):
        # Per-sample (32, 24, 10) -> (24, 32, 10), keeping at most 32 entries on axis 1.
        epoch_wavelet = np.transpose(all_data[i], (1, 0, 2))[:, :32, :]
        samp = torch.from_numpy(epoch_wavelet).to(device, dtype=torch.float32).unsqueeze(0)

        # torch.cuda.amp.autocast() is deprecated; torch.autocast is the supported API.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            lat = encoder.transform(samp, mask_ratio=0)

        # transform()'s return type is not visible here. If it is a tensor (possibly fp16
        # under autocast and/or on the GPU), convert it to float32 numpy, since
        # np.concatenate cannot read CUDA tensors and savemat has no float16 type.
        if torch.is_tensor(lat):
            lat = lat.detach().float().cpu().numpy()
        latents.append(lat)

latents = np.concatenate(latents, axis=0)
assert latents.shape[0] == all_data.shape[0], f"expected one latent row per sample, got {latents.shape}"

100%|██████████| 10962/10962 [01:36<00:00, 113.90it/s]


In [50]:
def rankme(Z, eps=1e-7):
    """
    RankMe smooth rank estimation
    from: https://arxiv.org/abs/2210.02885

    Returns exp(H(p)), where p are the singular values of Z normalised to sum to 1, plus
    ``eps`` to keep the log finite. This is an effective rank, roughly between 1 and
    min(N, K).

    Args:
        Z: (N, K) torch tensor or array-like (e.g. numpy array). N: nb samples,
           K: embed dim. N = 25000 is a good approximation in general.
        eps: smoothing constant added to the normalised singular values.

    Returns:
        0-dim torch tensor holding the effective rank. An all-zero Z gives NaN
        (0/0 during normalisation).
    """
    Z = torch.as_tensor(Z)
    # torch.linalg.svdvals only supports float/double/cfloat/cdouble. Upcast half/bfloat16
    # (e.g. latents produced under autocast) and integer/bool inputs to float32.
    if Z.dtype not in (torch.float32, torch.float64, torch.complex64, torch.complex128):
        Z = Z.float()

    S = torch.linalg.svdvals(Z)  # singular values
    S_norm1 = torch.linalg.norm(S, 1)

    p = S / S_norm1 + eps  # normalize sum to 1
    entropy = -torch.sum(p * torch.log(p))
    return torch.exp(entropy)

# Effective (RankMe) rank of the (n_samples, embed_dim) latent matrix, then its shape.
# NOTE(review): an earlier note here read "~2x4x10"; its meaning is unclear (an expected rank?). Confirm.
latent_rank = rankme(torch.from_numpy(latents))
print(f'rank latent: {latent_rank}')
print(latents.shape)

rank latent: 408.8719482421875
(10962, 768)


### Save latents as .mat

In [51]:

# Store the latents transposed, i.e. as (embed_dim, n_samples), under the MATLAB
# variable name 'xLatent'.
struct_name = 'xLatent'
mat_payload = {struct_name: latents.T}
savemat(output_path, mat_payload)
print(f'Latent features saved to {output_path}')

Latent features saved to /media/cyberspace007/T7/martin/dt5/results/day_9/rec2/DINO_SSL_small_weight_30.mat
