In [20]:
import train as proteo_train
import os
import torch
import torch.nn.functional as F
from proteo.datasets.ftd import FTDDataset, reverse_log_transform
import torch.nn.functional as F


def load_checkpoint(relative_checkpoint_path, levels_up=5):
    '''Restore a ``Proteo`` module from a checkpoint stored under the ray_results folder.

    The base directory is found by climbing ``levels_up`` parent directories
    from the current working directory, so the right value depends on where
    the notebook is run relative to the ray_results folder.

    Parameters:
    relative_checkpoint_path (str): Checkpoint path relative to
        ``scratch/lcornelis/outputs/ray_results/``.
    levels_up (int): Number of parent directories to climb from the current
        working directory before appending the ray_results location.

    Returns:
    The object produced by ``proteo_train.Proteo.load_from_checkpoint``.

    Raises:
    FileNotFoundError: If no file exists at the resolved checkpoint path.
    '''
    # Start at the working directory and strip one path component per level.
    base_directory = os.getcwd()
    for _level in range(levels_up):
        base_directory = os.path.dirname(base_directory)

    checkpoint_path = os.path.join(
        base_directory,
        'scratch/lcornelis/outputs/ray_results/',
        relative_checkpoint_path,
    )
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Only hand the path to the loader once we know the file is really there.
    if os.path.isfile(checkpoint_path):
        return proteo_train.Proteo.load_from_checkpoint(checkpoint_path)
    raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

# Load the config stored on the module
def load_config(module):
    '''Print and return the ``config`` attribute of ``module``.

    Parameters:
    module: Any object exposing a ``config`` attribute (here, the module
        returned by ``load_checkpoint``).

    Returns:
    The value of ``module.config``, unchanged.
    '''
    active_config = module.config
    # Echo the configuration so the notebook output records what was used.
    print("Config being used:", active_config)
    return active_config


def _collect_predictions(module, loader, device):
    '''Run ``module`` over every batch of ``loader`` and return (predictions, targets) concatenated on the CPU.'''
    preds, targets = [], []
    # Evaluation only: without no_grad every forward pass builds an autograd
    # graph that is kept alive by the stored predictions.
    with torch.no_grad():
        for batch in loader:
            # Rebind the result: `.to` is not guaranteed to move the batch in place.
            batch = batch.to(device)
            pred = module(batch)
            # Targets are reshaped to the prediction's shape so the loss is element-wise.
            target = batch.y.view(pred.shape)
            preds.append(pred.cpu())
            targets.append(target.cpu())
    return torch.cat(preds), torch.cat(targets)


def load_model_and_predict(module, config, device = 'cuda'):
    '''Evaluate ``module`` on the train and test datasets built from ``config``.

    The datasets and loaders are constructed with
    ``proteo_train.construct_datasets`` and ``proteo_train.construct_loaders``.
    The loader returned as "test" is what this function reports as validation.
    The forward passes run under ``torch.no_grad()``, so the returned tensors
    carry no autograd history.

    Parameters:
    module: The model to evaluate; it is moved to ``device`` and put in eval mode.
    config: Configuration passed to the dataset and loader constructors.
    device (str): Device on which the forward passes run. Defaults to ``'cuda'``.

    Returns:
    tuple: ``(train_preds, train_targets, train_mse, val_preds, val_targets, val_mse)``
    where the predictions/targets are CPU tensors and the MSE values are Python floats.
    '''
    module.to(device)
    module.eval()
    train_dataset, test_dataset = proteo_train.construct_datasets(config)
    train_loader, test_loader = proteo_train.construct_loaders(config, train_dataset, test_dataset)

    # Training set
    train_preds, train_targets = _collect_predictions(module, train_loader, device)
    train_mse = F.mse_loss(train_preds, train_targets).item()

    # Validation set (the loader returned as "test")
    val_preds, val_targets = _collect_predictions(module, test_loader, device)
    val_mse = F.mse_loss(val_preds, val_targets).item()

    return train_preds, train_targets, train_mse, val_preds, val_targets, val_mse

def full_load_and_run_and_convert(relative_checkpoint_path, device, mean, std, levels_up=5):
    '''Load a checkpoint, run it on the train/validation data and report errors in original units.

    Chains ``load_checkpoint`` -> ``load_config`` -> ``load_model_and_predict``,
    then passes every prediction/target tensor through
    ``reverse_log_transform(tensor, mean, std)`` before recomputing MSE and RMSE.

    Parameters:
    relative_checkpoint_path (str): Forwarded to ``load_checkpoint``.
    device (str): Device used for the forward passes.
    mean, std: Forwarded unchanged to ``reverse_log_transform``.
    levels_up (int): Forwarded to ``load_checkpoint``.

    Returns:
    tuple: ``(train_preds, train_targets, train_mse, train_rmse,
    val_preds, val_targets, val_mse, val_rmse)``; the MSE/RMSE entries are
    tensors (``.item()`` is not called on them).
    '''
    module = load_checkpoint(relative_checkpoint_path, levels_up)
    config = load_config(module)

    # The MSE values returned here are computed before the reverse transform;
    # they are discarded and recomputed in original units below.
    (train_preds, train_targets, _,
     val_preds, val_targets, _) = load_model_and_predict(module, config, device)

    # Training split, converted back to original units.
    train_preds = reverse_log_transform(train_preds, mean, std)
    train_targets = reverse_log_transform(train_targets, mean, std)
    train_mse = F.mse_loss(train_preds, train_targets)
    train_rmse = train_mse.sqrt()

    # Validation split, converted back to original units.
    val_preds = reverse_log_transform(val_preds, mean, std)
    val_targets = reverse_log_transform(val_targets, mean, std)
    val_mse = F.mse_loss(val_preds, val_targets)
    val_rmse = val_mse.sqrt()

    print("Checkpoint path", relative_checkpoint_path)
    report = (
        ("Original Units Train MSE:", train_mse),
        ("Original Units Train RMSE:", train_rmse),
        ("Original Units Val MSE:", val_mse),
        ("Original Units Val RMSE:", val_rmse),
    )
    for label, value in report:
        print(label, value)

    return (
        train_preds, train_targets, train_mse, train_rmse,
        val_preds, val_targets, val_mse, val_rmse,
    )

In [22]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Checkpoint location, relative to the ray_results folder that load_checkpoint prepends.
relative_checkpoint_path = 'TorchTrainer_2024-07-31_10-47-21/model=gat-v4,seed=19543_0_act=relu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0000,lr=0.1000,lr_scheduler=LambdaLR,modal_2024-07-31_10-47-21/checkpoint_000002/checkpoint.ckpt'

# mean/std are forwarded to reverse_log_transform; presumably the statistics of
# the log-transformed target used during training — TODO confirm their origin.
(
    train_preds, train_targets, train_mse, train_rmse,
    val_preds, val_targets, val_mse, val_rmse,
) = full_load_and_run_and_convert(
    relative_checkpoint_path,
    device,
    mean=2.3840692826511245,
    std=0.9650973053482799,
)

Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-07-31_10-47-21/model=gat-v4,seed=19543_0_act=relu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0000,lr=0.1000,lr_scheduler=LambdaLR,modal_2024-07-31_10-47-21/checkpoint_000002/checkpoint.ckpt
Config being used: gat_v4_weight_initializer=['uniform'] gat_hidden_channels=[8, 32, 128, 256] device=[0] root_dir='/home/lcornelis/code/proteo' checkpoint_every_n_epochs_train=1 l1_lambda=1e-05 pin_memory=True sex=['M'] wgcna_mergeCutHeight=0.25 y_val='nfl' log_every_n_steps=10 num_samples=1 gat_v4_heads=[[2, 3]] cpu_per_worker=16 nodes_count=1 adj_thresh=0.1 gat_heads=[1, 2, 4, 8] optimizer='Adam' checkpoint_dir='/scratch/lcornelis/outputs/checkpoints' num_to_keep=3 seed=19543 modality_choices=['plasma'] l1_lambda_min=1e-05 gat-v4={'hidden_channels': [8, 16], 'heads': [2, 3], 'use_layer_norm': True, 'which_layer': ['layer1', 'layer2', 'layer3'], 'fc_dim': [64, 128, 128, 32], 'fc_dropout': 0.1, 'fc_act': 're

In [14]:
#Sanity check
def compute_manual_mse(val_preds, val_targets):
    """
    Compute the Mean Squared Error (MSE) by hand, as a cross-check on ``F.mse_loss``.

    Parameters:
    val_preds (torch.Tensor): The predicted values.
    val_targets (torch.Tensor): The true target values; must support
        element-wise subtraction from ``val_preds``.

    Returns:
    float: The mean of the squared element-wise differences.
    """
    # Element-wise error, squared, averaged over all elements, then unwrapped to a Python float.
    residuals = val_preds - val_targets
    return residuals.pow(2).mean().item()

# val_preds / val_targets are notebook globals. If they come from full_load_and_run_and_convert,
# this value should match the "Original Units Val MSE" it printed — verify, since the cells
# were executed out of order.
print(compute_manual_mse(val_preds, val_targets))

716.7650146484375


In [20]:
import os
import torch
import torch.nn.functional as F
import train as proteo_train

# Define a function to load the checkpoint (the MSE calculation inside it is currently commented out)
def load_checkpoint_and_calculate_mse(relative_checkpoint_path, levels_up=5):
    '''Load a checkpoint both as a raw dictionary and as a ``Proteo`` module.

    Despite its name, this function does not compute any MSE: it only loads
    the checkpoint and prints its top-level and ``state_dict`` keys. The name
    is kept so existing callers continue to work.

    Parameters:
    relative_checkpoint_path (str): Checkpoint path relative to the directory
        reached by climbing ``levels_up`` parents from the current working directory.
    levels_up (int): Number of parent directories to climb.

    Returns:
    tuple: ``(module, checkpoint)`` where ``module`` comes from
    ``proteo_train.Proteo.load_from_checkpoint`` and ``checkpoint`` is the
    object returned by ``torch.load``.

    Raises:
    FileNotFoundError: If no file exists at the resolved checkpoint path.
    '''
    # Climb `levels_up` parents starting from the working directory.
    current_directory = os.getcwd()
    for _ in range(levels_up):
        current_directory = os.path.dirname(current_directory)

    # Construct the full path to the checkpoint
    checkpoint_path = os.path.join(current_directory, relative_checkpoint_path)
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Check if the file exists to avoid errors
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # The checkpoint tensors were saved on a CUDA device; without a map_location
    # torch.load fails on machines that have no GPU. Keep the default placement
    # when CUDA is available and remap to the CPU otherwise.
    map_location = None if torch.cuda.is_available() else 'cpu'
    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    print("Checkpoint keys:", checkpoint.keys())
    print("checkpoint state_dict keys:", checkpoint['state_dict'].keys())

    module = proteo_train.Proteo.load_from_checkpoint(checkpoint_path)

    # TODO: the MSE computation from the module's stored best_val_pred /
    # best_val_target / best_train_pred / best_train_target attributes is
    # disabled; re-enable it once those attributes are confirmed to be restored
    # by load_from_checkpoint.

    return module, checkpoint

# Example usage
# Path is resolved relative to the directory reached by climbing `levels_up`
# parents from the current working directory.
relative_checkpoint_path = (
    'scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-07-31_10-47-21/model=gat-v4,seed=19543_0_act=relu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0000,lr=0.1000,lr_scheduler=LambdaLR,modal_2024-07-31_10-47-21/checkpoint_000003/checkpoint.ckpt'
)
module, checkpoint = load_checkpoint_and_calculate_mse(
    relative_checkpoint_path=relative_checkpoint_path,
)
# load_checkpoint_and_calculate_mse returns no MSE values, so these prints stay disabled.
# print(f"MSE Loss for validation set: {mse_val}")
# print(f"MSE Loss for training set: {mse_train}")


Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-07-31_10-47-21/model=gat-v4,seed=19543_0_act=relu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0000,lr=0.1000,lr_scheduler=LambdaLR,modal_2024-07-31_10-47-21/checkpoint_000003/checkpoint.ckpt
Checkpoint keys: dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
checkpoint state_dict keys: odict_keys(['model.convs.0.att_src', 'model.convs.0.att_dst', 'model.convs.0.bias', 'model.convs.0.lin.weight', 'model.convs.1.att_src', 'model.convs.1.att_dst', 'model.convs.1.bias', 'model.convs.1.lin.weight', 'model.pools.0.weight', 'model.pools.0.bias', 'model.pools.1.weight', 'model.pools.1.bias', 'model.layer_norm.weight', 'model.layer_norm.bias', 'model.encoder.0.0.weight', 'model.encoder.0.0.bias', 'model.encoder.1.0.weight', 'model.encoder.1.0.bias', 'model.encoder.2.0.weight', 'model.

In [19]:
# Debugging aid: dump the loaded module's instance attributes (training flag, hooks, submodules).
module.__dict__

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('model',
               GATv4(
                 (convs): ModuleList(
                   (0): CustomGATConv(1, 8, heads=2)
                   (1): CustomGATConv(16, 16, heads=3)
                 )
                 (pools): ModuleList(
                   (0): Linear(in_features=16, out_features=1, bias=True)
                   (1): Linear(in_features=48, out_fea

In [22]:
# Debugging aid: display the raw checkpoint dictionary returned by load_checkpoint_and_calculate_mse.
# The output is very large because it includes every tensor in the state_dict.
checkpoint

{'epoch': 3,
 'global_step': 16,
 'pytorch-lightning_version': '2.3.3',
 'state_dict': OrderedDict([('model.convs.0.att_src',
               tensor([[[ 0.8699, -0.1098,  0.5025,  0.1256,  0.5469,  0.2400,  0.1043,
                         -0.4770],
                        [-0.4824,  1.1463, -0.2332, -0.3317,  1.0416, -0.4451,  0.3276,
                         -0.4540]]], device='cuda:0')),
              ('model.convs.0.att_dst',
               tensor([[[ 0.8882, -0.2228,  0.5346, -0.4571,  0.3048, -0.5855,  0.1434,
                          0.3646],
                        [ 1.1702,  0.3941,  0.3302,  1.2372, -0.0151,  1.1996,  0.4364,
                          1.2554]]], device='cuda:0')),
              ('model.convs.0.bias',
               tensor([ 0.0028,  0.4199,  0.3917,  0.4135,  0.3705,  0.4011,  0.3794,  0.3819,
                        0.3859, -0.0071,  0.3830,  0.3899, -0.0170,  0.4089,  0.3918,  0.3970],
                      device='cuda:0')),
              ('model.convs.0.l

In [None]:
# TODO:
# 1. Load the train and test datasets using the config.
# 2. Run the model to obtain val_targets, val_preds, train_targets and train_preds.
# 3. Compute the loss for each.