In [24]:
import logging
import os
import sys

import numpy as np
import torch
from torch import nn
from tqdm import tqdm

# sys.path.append('/home/gayal/ssl-project/gpatchTST')
from configs import Config

In [17]:
# Log to both the notebook output and NW_test.log.
# force=True (Python 3.8+) removes and closes any handlers already attached to
# the root logger before installing these. Without it, basicConfig is a silent
# no-op when handlers already exist (e.g. when this cell is re-run). Each re-run
# would also still construct a new FileHandler, which opens NW_test.log and is
# never used or closed.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler('NW_test.log'),
    ],
    force=True,
)

# Root directory that pretrained-model paths are resolved against.
BASE_PATH = '/mnt/Helium/neeraj/ssl_feature_probes'

### Setup

In [19]:
# Run on the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda:0


In [20]:
def get_checkpoint_dirs(base_path):
    """Find checkpoint files under ``base_path``, keyed by checkpoint number.

    A checkpoint is any file whose name starts with ``checkpoint`` and ends with
    ``.pth``. Its number is the integer after the last ``_`` in the file stem
    (e.g. ``checkpoint_012.pth`` -> 12). Matching files without an integer
    suffix (e.g. ``checkpoint.pth``, ``checkpoint_best.pth``) are skipped with a
    warning instead of aborting the whole scan with ``ValueError``.

    Args:
        base_path: Directory to search recursively.

    Returns:
        dict mapping checkpoint number (int) to the file's full path, ordered by
        ascending checkpoint number. If two files share a number, the one
        encountered later in the walk wins and a warning is logged.
    """
    logging.info(f"Scanning for checkpoints in: {base_path}")
    found = {}
    for root, _dirs, files in os.walk(base_path):
        for file in files:
            if not (file.startswith('checkpoint') and file.endswith('.pth')):
                continue
            suffix = os.path.splitext(file)[0].rsplit('_', 1)[-1]
            try:
                checkpoint_num = int(suffix)
            except ValueError:
                logging.warning(f"Skipping checkpoint without numeric suffix: {os.path.join(root, file)}")
                continue
            path = os.path.join(root, file)
            if checkpoint_num in found:
                logging.warning(f"Duplicate checkpoint number {checkpoint_num}: {found[checkpoint_num]} replaced by {path}")
            found[checkpoint_num] = path
    logging.info(f"Found {len(found)} checkpoints.")
    # Sorted so callers iterate checkpoints in a deterministic, ascending order
    # regardless of filesystem walk order.
    return dict(sorted(found.items()))

In [21]:
from utils.pyt_dataset_utils import get_dataloader

def _find_config_file(checkpoint_dir):
    """Return the path of the ``.yaml`` config stored next to a checkpoint.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` contains no ``.yaml`` file.
    """
    yaml_files = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith('.yaml'))
    if not yaml_files:
        raise FileNotFoundError(f"No .yaml config file found in {checkpoint_dir}")
    if len(yaml_files) > 1:
        # Sorted so the choice is deterministic rather than dependent on listdir order.
        logging.warning(f"Multiple .yaml configs in {checkpoint_dir}; using {yaml_files[0]}")
    return os.path.join(checkpoint_dir, yaml_files[0])


def _build_default_loaders(data_config, model_config):
    """Build the train/val/test loaders described by the checkpoint's config.

    NOTE(review): ``get_tuh_dataloaders_old_splits`` is not imported anywhere in
    this notebook; the last captured run failed with a NameError for it. Import
    it in the imports cell, or pass ``loaders=`` to ``save_embeddings_for_data``.
    The partially written ``get_dataloader``-based replacement was dropped
    because it was a syntax error and referenced undefined ``dataset``/``args``.
    """
    train_loader, val_loader, test_loader = get_tuh_dataloaders_old_splits(
        data_config['root_path'],
        data_config['data_path'],
        data_config['csv_path'],
        batch_size=data_config['batch_size'],
        num_workers=data_config['num_workers'],
        prefetch_factor=data_config['prefetch_factor'],
        pin_memory=data_config['pin_memory'],
        drop_last=False,  # keep the final partial batch so no samples are skipped
        size=[model_config['seq_len'],
              model_config['target_dim'],
              model_config['patch_length']],
    )
    return {'train': train_loader, 'val': val_loader, 'test': test_loader}


def _load_pretrained_model(checkpoint, data_config, model_config, device):
    """Build the PatchTST model described by the config, load the checkpoint weights, set eval mode."""
    model = get_patchTST_model(num_variates=data_config['n_vars'],
                               forecast_length=model_config['target_dim'],
                               patch_len=model_config['patch_length'],
                               stride=model_config['stride'],
                               num_patch=(model_config['seq_len'] - model_config['patch_length']) // model_config['stride'] + 1,
                               n_layers=model_config['num_layers'],
                               d_model=model_config['d_model'],
                               n_heads=model_config['num_heads'],
                               shared_embedding=model_config['shared_embedding'],
                               d_ff=model_config['d_ff'],
                               norm=model_config['norm'],
                               attn_dropout=model_config['attn_dropout'],
                               dropout=model_config['dropout'],
                               activation=model_config['activation'],
                               res_attention=model_config['res_attention'],
                               pe=model_config['pe'],
                               learn_pe=model_config['learn_pe'],
                               head_dropout=model_config['head_dropout'],
                               head_type=model_config['head_type'],
                               use_cls_token=model_config['use_cls_token'],
                               ).to(device)

    state = torch.load(checkpoint, map_location=device)
    # NOTE(review): the checkpoint layout is not visible from this notebook.
    # Accept a raw state_dict, a dict wrapping it under a common key, or a
    # pickled module. Confirm against the training script.
    if isinstance(state, dict):
        for key in ('model_state_dict', 'state_dict', 'model'):
            if key in state:
                state = state[key]
                break
    if isinstance(state, nn.Module):
        state = state.state_dict()
    # Strict loading (the default): a config/checkpoint mismatch fails loudly
    # instead of leaving layers randomly initialised.
    model.load_state_dict(state)
    model.eval()
    return model


def _revin_normalize(x, eps=1e-5):
    """Instance-normalise each sample and variate over time (RevIN 'norm' step, no affine).

    Assumes ``x`` is [bs x seq_len x n_vars], i.e. time on dim 1 — TODO confirm
    against the dataloader. NOTE(review): if pretraining used RevIN with
    ``affine=True``, its learned affine weights must also be applied here.
    """
    mean = x.mean(dim=1, keepdim=True)
    std = torch.sqrt(x.var(dim=1, keepdim=True, unbiased=False) + eps)
    return (x - mean) / std


def save_embeddings_for_data(checkpoint, new_dataset_path, save_dir, device, loaders=None):
    """Extract backbone embeddings with a pretrained checkpoint and save one ``.npy`` per sample.

    The ``.yaml`` config next to ``checkpoint`` describes the model and the data.
    Its ``root_path``/``csv_path`` are redirected to ``new_dataset_path`` before
    any loaders are built.

    Args:
        checkpoint: Path to the ``.pth`` checkpoint file.
        new_dataset_path: Root of the data to embed. ``file_lengths_map.csv`` is
            expected inside it.
        save_dir: Output directory (created if missing). Each sample is written
            to ``<save_dir>/<basename of its filename>.npy``.
        device: torch device that the model and batches are moved to.
        loaders: Optional mapping of split name -> iterable of batch dicts with
            ``'past_values'`` and ``'filename'`` entries. Defaults to the
            train/val/test loaders built from the checkpoint's config.

    Raises:
        FileNotFoundError: if no ``.yaml`` config sits next to ``checkpoint``.
    """
    logging.info(f"Processing checkpoint: {checkpoint}")
    config = Config(config_file=_find_config_file(os.path.dirname(checkpoint))).get()

    data_config = config['data']
    data_config['root_path'] = new_dataset_path
    data_config['csv_path'] = os.path.join(new_dataset_path, 'file_lengths_map.csv')
    model_config = config['model']
    use_revin = model_config['revin']  # a config flag, not a callable RevIN module
    patch_len = model_config['patch_length']
    stride = model_config['stride']

    if loaders is None:
        loaders = _build_default_loaders(data_config, model_config)

    model = _load_pretrained_model(checkpoint, data_config, model_config, device)
    os.makedirs(save_dir, exist_ok=True)

    with torch.no_grad():
        for split_name, loader in loaders.items():
            logging.info(f"Processing loader: {split_name}")
            for batch in tqdm(loader, desc=f"Extracting embeddings from {split_name}"):
                data = batch['past_values'].to(device)
                filenames = batch['filename']

                if use_revin:
                    data = _revin_normalize(data)

                input_patches, _ = create_patches(data, patch_len, stride)

                output = model.backbone(input_patches)  # [bs x nvars x d_model x (num_patch+1 or num_patch)]
                # Position 0 on the last axis: presumably the CLS token when
                # use_cls_token is set, otherwise the first patch.
                # NOTE(review): confirm this is the intended summary without a CLS token.
                embeddings = output[:, :, :, 0].cpu().numpy()  # [bs x nvars x d_model]

                for sample_filename, sample_embeddings in zip(filenames, embeddings):
                    # basename keeps the output inside save_dir even if the
                    # dataset yields absolute paths (os.path.join would otherwise
                    # discard save_dir).
                    save_path = os.path.join(save_dir, f"{os.path.basename(str(sample_filename))}.npy")
                    np.save(save_path, sample_embeddings)

### Run inference

In [25]:
# Preprocessed data used for model inference.
new_data_path = os.path.join(BASE_PATH, 'preprocessed_eeg_data')

# Pretrained runs to embed with, keyed by patch length (1 -> 1 s = 1000 ms).
# Paths are resolved relative to BASE_PATH.
pretrained_paths = {
    1: './pretrained_models/patchtst_pretrained_1s_patchlen_seq_10_sec/2025-04-17_21-01-03',
}

# Root for all saved embeddings. It is loop-invariant, so create it once.
embeddings_save_path = os.path.join(BASE_PATH, 'saved_embeddings')
os.makedirs(embeddings_save_path, exist_ok=True)

# Inference loop over every checkpoint of every pretrained run.
for patch_len, pretrained_path in pretrained_paths.items():
    logging.info(f"Processing pretrained path for patch length {patch_len}: {pretrained_path}")
    pretrained_base_path = os.path.join(BASE_PATH, pretrained_path)
    checkpoints = get_checkpoint_dirs(pretrained_base_path)
    if not checkpoints:
        # A wrong path would otherwise silently produce no embeddings.
        logging.warning(f"No checkpoints found under {pretrained_base_path}")
        continue

    patch_len_save_path = os.path.join(embeddings_save_path, f'patch_len_{patch_len}')
    os.makedirs(patch_len_save_path, exist_ok=True)

    # Ascending checkpoint order, so runs are reproducible and easy to resume.
    for checkpoint_num in sorted(checkpoints):
        checkpoint_path = checkpoints[checkpoint_num]
        logging.info(f"Processing checkpoint {checkpoint_num}: {checkpoint_path}")
        checkpoint_save_path = os.path.join(patch_len_save_path, f'checkpoint_{checkpoint_num:03d}')
        os.makedirs(checkpoint_save_path, exist_ok=True)
        save_embeddings_for_data(checkpoint_path, new_data_path, checkpoint_save_path, device)

Found 10 checkpoints


NameError: name 'get_tuh_dataloaders_old_splits' is not defined