In [1]:
import train as proteo_train
import os
import torch
import torch.nn.functional as F
from proteo.datasets.ftd import FTDDataset, reverse_log_transform
import torch.nn.functional as F
import pytorch_lightning as pl
from scipy.stats import zscore


def load_checkpoint(relative_checkpoint_path):
    '''Load the checkpoint as a module. Note levels_up depends on the directory structure of the ray_results folder'''
    relative_checkpoint_path = os.path.join(relative_checkpoint_path, 'checkpoint.ckpt')
    # Check if the file exists to avoid errors
    if not os.path.isfile(relative_checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {relative_checkpoint_path}")
    module = proteo_train.Proteo.load_from_checkpoint(relative_checkpoint_path)
    return module

# Load in the datasets from the config
def load_config(module):
    '''Load the config from the module  and return it'''
    config = module.config
    return config


def load_model_and_predict(module, config, device = 'cuda'):
    '''Run the module with the correct train and test datasets and return the predictions and targets'''
    module.to(device)
    module.eval()
    #pl.seed_everything(config.seed)
    train_dataset, test_dataset = proteo_train.construct_datasets(config)
    train_loader, test_loader = proteo_train.construct_loaders(config, train_dataset, test_dataset)
    # Get predictions and targets for the training set
    train_preds, train_targets = [], []
    for batch in train_loader:
        batch.to(device)
        with torch.no_grad():
        # Forward pass
            pred = module(batch)
            target = batch.y.view(pred.shape)
        
        # Store predictions and targets
        train_preds.append(pred.cpu())
        train_targets.append(target.cpu())
    train_preds = torch.cat(train_preds)
    train_targets = torch.cat(train_targets)
    
    # Calculate MSE for training set
    train_mse = F.mse_loss(train_preds, train_targets).item()
    
    # Get predictions and targets for the validation set
    val_preds, val_targets = [], []
    for batch in test_loader:
        batch.to(device)
        # Forward pass
        pred = module(batch)
        target = batch.y.view(pred.shape)
        
        # Store predictions and targets
        val_preds.append(pred.cpu())
        val_targets.append(target.cpu())
    val_preds = torch.cat(val_preds)
    val_targets = torch.cat(val_targets)
    
    # Calculate MSE for validation set
    val_mse = F.mse_loss(val_preds, val_targets).item()
    print("Normalized Val MSE:", val_mse)
    print("Normalized train MSE:", train_mse)
    return train_preds, train_targets, train_mse, val_preds, val_targets, val_mse

def full_load_and_run_and_convert(relative_checkpoint_path, device, mean, std):
    '''Call all the functions to load the checkpoint, run the model and convert the predictions back to the original units'''
    module = load_checkpoint(relative_checkpoint_path)
    config = load_config(module)
    train_preds, train_targets, train_mse, val_preds, val_targets, val_mse = load_model_and_predict(module, config, device)
    train_preds = reverse_log_transform(train_preds, mean, std)
    train_targets = reverse_log_transform(train_targets, mean, std)
    train_mse = F.mse_loss(train_preds, train_targets)
    train_rmse = torch.sqrt(train_mse)
    val_preds = reverse_log_transform(val_preds, mean, std)
    val_targets = reverse_log_transform(val_targets, mean, std)
    val_mse = F.mse_loss(val_preds, val_targets)
    val_rmse = torch.sqrt(val_mse)
    print(val_preds.view(-1).detach().cpu().numpy())
    val_z_scores = zscore(val_preds.view(-1).detach().cpu().numpy() - val_targets.view(-1).detach().cpu().numpy(), ddof=1)
    #print("Original Units Train preds:", train_preds)
    #print("Original Units Train targets:", train_targets)
    #print("Original Units Train MSE:", train_mse)
    #print("Original Units Train RMSE:", train_rmse)
    #print("Original Units Val preds:", val_preds)
    #print("Original Units Val targets:", val_targets)
    print("Original Units Val MSE:", val_mse)
    print("Original Units Val RMSE:", val_rmse)
    print("Val Z scores:", val_z_scores)
    return [train_preds, train_targets, train_mse, train_rmse, val_preds, val_targets, val_mse, val_rmse, val_z_scores]

def process_checkpoints(checkpoint_paths, mean_dict, std_dict, device):
    results = []
    i = 1
    for checkpoint_path in checkpoint_paths:
        print(f"Loading checkpoint from: {checkpoint_path}")
        module = load_checkpoint(checkpoint_path)
        config = load_config(module)
        #print("Config being used:", config)
        print(f"{i} best checkpoint for {config.sex} and {config.modality}")
        key = f"{config.sex}_{config.modality}"
        mean = mean_dict[key]
        std = std_dict[key]
        
        result = full_load_and_run_and_convert(checkpoint_path, device, mean, std)
        results.append(result)
        i += 1
    return results

In [18]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
c9_mean_dict = {"['M']_csf":2.20139473218633, 
                "['F']_plasma":2.5069125020915246,
                "['F']_csf":2.3905483112831987,
                "['M', 'F']_plasma":2.4382370774886417,
                "['M', 'F']_csf":2.323617044833538}
c9_std_dict = {"['M']_csf":0.9414006476156331,
                "['F']_plasma":0.9801098341235991,
                "['F']_csf":0.95108017948172,
                "['M', 'F']_plasma":0.9639665529956777,
                "['M', 'F']_csf":0.951972757962228}

c9_BEST_RUNS_M=[
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=3784_347_act=elu,adj_thresh=0.0500,batch_size=8,dropout=0.1000,l1_lambda=0.0010,lr=0.0007,lr_scheduler=CosineAnn_2024-08-01_11-17-45/checkpoint_000101',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=36658_797_act=leaky_relu,adj_thresh=0.7000,batch_size=16,dropout=0.1000,l1_lambda=0.0002,lr=0.0010,lr_scheduler=_2024-08-01_13-15-10/checkpoint_000132',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=43041_59_act=leaky_relu,adj_thresh=0.1000,batch_size=32,dropout=0.2000,l1_lambda=0.0033,lr=0.0122,lr_scheduler=C_2024-08-01_10-30-18/checkpoint_000006'
]
c9_BEST_RUNS_F= [
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=15140_461_act=tanh,adj_thresh=0.5000,batch_size=8,dropout=0.1000,l1_lambda=0.0001,lr=0.0002,lr_scheduler=CosineA_2024-08-01_11-48-59/checkpoint_000016',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=29159_100_act=sigmoid,adj_thresh=0.1000,batch_size=50,dropout=0.1000,l1_lambda=0.0004,lr=0.0005,lr_scheduler=Lam_2024-08-01_10-30-18/checkpoint_000249',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=44912_137_act=leaky_relu,adj_thresh=0.9000,batch_size=8,dropout=0,l1_lambda=0.0037,lr=0.0229,lr_scheduler=StepLR_2024-08-01_10-30-18/checkpoint_000009'
]
c9_BEST_RUNS_M_F= [
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=40910_881_act=tanh,adj_thresh=0.7000,batch_size=16,dropout=0.1000,l1_lambda=0.0000,lr=0.0037,lr_scheduler=Cosine_2024-08-01_13-32-59/checkpoint_000039',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=42404_611_act=leaky_relu,adj_thresh=0.9000,batch_size=16,dropout=0.2000,l1_lambda=0.0000,lr=0.0061,lr_scheduler=_2024-08-01_12-24-45/checkpoint_000146',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=40995_232_act=relu,adj_thresh=0.9000,batch_size=8,dropout=0,l1_lambda=0.0001,lr=0.0004,lr_scheduler=ReduceLROnPl_2024-08-01_10-40-25/checkpoint_000027'
]
process_checkpoints(c9_BEST_RUNS_M, c9_mean_dict, c9_std_dict, device)
process_checkpoints(c9_BEST_RUNS_F, c9_mean_dict, c9_std_dict, device)
process_checkpoints(c9_BEST_RUNS_M_F, c9_mean_dict, c9_std_dict, device)

Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=3784_347_act=elu,adj_thresh=0.0500,batch_size=8,dropout=0.1000,l1_lambda=0.0010,lr=0.0007,lr_scheduler=CosineAnn_2024-08-01_11-17-45/checkpoint_000101
1 best checkpoint for ['M'] and csf
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.05_num_nodes_150_mutation_C9orf72_csf_sex_M_train.pt
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.05_num_nodes_150_mutation_C9orf72_csf_sex_M_test.pt
Normalized Val MSE: 0.06449513137340546
Normalized train MSE: 0.15109355747699738
[ 3.5153    7.955807  5.813492 11.785532 17.050365]
Original Units Val MSE: tensor(4.6682, grad_fn=<MseLossBackward0>)
Original Units Val RMSE: tensor(2.1606, grad_fn=<SqrtBackward0>)
Val Z scores: [ 0.65391517 -1.2735126   0.33508742 -0.80584943  1.0903596 ]
Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30

[[tensor([[  5.8184],
          [ 70.6321],
          [  7.2476],
          [  6.9193],
          [  5.3588],
          [  6.3820],
          [ 10.0662],
          [  5.8440],
          [  5.7913],
          [  4.2404],
          [  6.5051],
          [  4.4764],
          [ 15.7411],
          [  6.2532],
          [  8.6289],
          [ 69.8403],
          [  4.7678],
          [192.6585],
          [  6.4075],
          [113.3991],
          [  8.4937],
          [  5.6313],
          [ 31.8740],
          [  6.1263],
          [  4.2468],
          [  4.3882],
          [  5.5273],
          [  4.8352],
          [  4.3488],
          [  5.4444],
          [ 89.9066],
          [  9.1867],
          [  4.6255],
          [  4.2695],
          [  7.8667],
          [ 65.0458],
          [  4.1354],
          [  5.5514],
          [  5.8448],
          [133.3592],
          [  7.5410],
          [  9.0281],
          [  7.0236],
          [  7.2100],
          [ 97.0730],
          

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
MAPT_mean_dict = {"['M']_csf":2.080694213697065, 
                "['M']_plasma":2.1657279439973016,
                "['F']_csf":2.0152637385189265,
                "['M', 'F']_csf": 2.0454624193703754}
MAPT_std_dict = {"['M']_csf":0.6213240141321779,
                "['M']_plasma":0.6840496344783593,
                "['F']_csf":0.7999340389937927,
                "['M', 'F']_csf":0.7237378322971036}

MAPT_BEST_RUNS_M=[
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=62542_686_act=leaky_relu,adj_thresh=0.0500,batch_size=8,dropout=0.2000,l1_lambda=0.0000,lr=0.0053,lr_scheduler=R_2024-08-01_12-44-53/checkpoint_000024',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=8990_704_act=leaky_relu,adj_thresh=0.7000,batch_size=32,dropout=0.2000,l1_lambda=0.0002,lr=0.0011,lr_scheduler=L_2024-08-01_12-49-21/checkpoint_000027',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=16742_1047_act=sigmoid,adj_thresh=0.5000,batch_size=16,dropout=0.0500,l1_lambda=0.0000,lr=0.0000,lr_scheduler=La_2024-08-01_14-01-42/checkpoint_000000'
]
    
MAPT_BEST_RUNS_F= [
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=37235_729_act=leaky_relu,adj_thresh=0.9000,batch_size=32,dropout=0,l1_lambda=0.0001,lr=0.0007,lr_scheduler=Reduc_2024-08-01_12-56-46/checkpoint_000015',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=46905_207_act=sigmoid,adj_thresh=0.1000,batch_size=32,dropout=0.3000,l1_lambda=0.0003,lr=0.0006,lr_scheduler=Ste_2024-08-01_10-31-43/checkpoint_000028',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=14399_99_act=tanh,adj_thresh=0.7000,batch_size=50,dropout=0.1000,l1_lambda=0.0000,lr=0.0019,lr_scheduler=CosineA_2024-08-01_10-30-18/checkpoint_000027'
]

MAPT_BEST_RUNS_M_F= [
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=23432_69_act=sigmoid,adj_thresh=0.9000,batch_size=50,dropout=0,l1_lambda=0.0003,lr=0.0057,lr_scheduler=ReduceLRO_2024-08-01_10-30-18/checkpoint_000009',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=35171_951_act=elu,adj_thresh=0.9000,batch_size=16,dropout=0.0500,l1_lambda=0.0001,lr=0.0019,lr_scheduler=StepLR,_2024-08-01_13-46-49/checkpoint_000008',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=25447_249_act=sigmoid,adj_thresh=0.7000,batch_size=32,dropout=0.2000,l1_lambda=0.0001,lr=0.0101,lr_scheduler=Ste_2024-08-01_10-47-31/checkpoint_000002'
]

process_checkpoints(MAPT_BEST_RUNS_M, MAPT_mean_dict, MAPT_std_dict, device)
process_checkpoints(MAPT_BEST_RUNS_F, MAPT_mean_dict, MAPT_std_dict, device)
process_checkpoints(MAPT_BEST_RUNS_M_F, MAPT_mean_dict, MAPT_std_dict, device)

Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=62542_686_act=leaky_relu,adj_thresh=0.0500,batch_size=8,dropout=0.2000,l1_lambda=0.0000,lr=0.0053,lr_scheduler=R_2024-08-01_12-44-53/checkpoint_000024
1 best checkpoint for ['M'] and plasma
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.05_num_nodes_150_mutation_MAPT_plasma_sex_M_train.pt
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.05_num_nodes_150_mutation_MAPT_plasma_sex_M_test.pt
Normalized Val MSE: 0.3830680847167969
Normalized train MSE: 0.18608273565769196
Original Units Val MSE: tensor(60.0437, grad_fn=<MseLossBackward0>)
Original Units Val RMSE: tensor(7.7488, grad_fn=<SqrtBackward0>)
Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=8990_704_act=leaky_relu,adj_thresh=0.7000,batch_size=32,dropout=0.2000,l1_lambda=0.0002,lr=0.0011,lr_sc

[(tensor([[ 5.3603],
          [ 5.6851],
          [ 6.6921],
          [ 5.3087],
          [ 6.5727],
          [ 9.3294],
          [ 5.5381],
          [ 5.8862],
          [ 6.4908],
          [ 5.1970],
          [ 7.6585],
          [ 8.3916],
          [ 5.0310],
          [ 5.0351],
          [ 7.3955],
          [10.9128],
          [ 6.7094],
          [ 6.4397],
          [ 9.3152],
          [ 8.3094],
          [ 5.8805],
          [ 5.8522],
          [ 8.1659],
          [ 5.6992],
          [ 5.1975],
          [ 8.6050],
          [ 5.4215],
          [ 6.4946],
          [ 6.6584],
          [ 9.4137],
          [ 5.3420],
          [ 5.0834],
          [ 9.2575],
          [ 5.9569],
          [ 6.1477],
          [ 9.2359],
          [ 6.7914],
          [ 7.8718],
          [ 5.4672],
          [ 5.6368],
          [ 6.1177]]),
  tensor([[ 2.7373],
          [ 1.4759],
          [ 5.9444],
          [ 1.8875],
          [13.3588],
          [22.1398],
          [

In [16]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
59173
GRN_mean_dict = {"['M']_csf": 2.178815827183045,
                "['F']_plasma":3.120974866855634,
                "['F']_csf": 3.2586357196385385,
                "['M', 'F']_csf":2.752470145050026}
GRN_std_dict = {"['M']_csf":0.7776541040264751,
                "['F']_plasma":1.2401561087499366,
                "['F']_csf":1.1764975422138229,
                "['M', 'F']_csf":1.1441881493582908}

GRN_BEST_RUNS_M=[
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=52208_883_act=relu,adj_thresh=0.7000,batch_size=8,dropout=0,l1_lambda=0.0002,lr=0.0010,lr_scheduler=ReduceLROnPl_2024-08-01_13-33-15/checkpoint_000020',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=10511_829_act=sigmoid,adj_thresh=0.9000,batch_size=8,dropout=0.3000,l1_lambda=0.0052,lr=0.0010,lr_scheduler=Redu_2024-08-01_13-24-19/checkpoint_000003',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=7148_163_act=sigmoid,adj_thresh=0.1000,batch_size=32,dropout=0.0500,l1_lambda=0.0183,lr=0.0192,lr_scheduler=Step_2024-08-01_10-30-18/checkpoint_000005',

]
GRN_BEST_RUNS_F= [
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_17-06-03/model=gat-v4,seed=6059_182_act=relu,adj_thresh=0.1000,batch_size=16,dropout=0.3000,l1_lambda=0.0001,lr=0.0011,lr_scheduler=CosineA_2024-08-01_17-06-04/checkpoint_000065',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_17-06-03/model=gat-v4,seed=42177_32_act=leaky_relu,adj_thresh=0.0500,batch_size=8,dropout=0.1000,l1_lambda=0.0000,lr=0.0072,lr_scheduler=La_2024-08-01_17-06-04/checkpoint_000104',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_17-06-03/model=gat-v4,seed=41068_56_act=tanh,adj_thresh=0.0500,batch_size=50,dropout=0,l1_lambda=0.0011,lr=0.0096,lr_scheduler=CosineAnneal_2024-08-01_17-06-04/checkpoint_000066',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=59173_241_act=tanh,adj_thresh=0.1000,batch_size=32,dropout=0.0500,l1_lambda=0.0000,lr=0.0064,lr_scheduler=Lambda_2024-08-01_10-44-20/checkpoint_000017'
]

GRN_BEST_RUNS_M_F= [
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=39163_985_act=leaky_relu,adj_thresh=0.9000,batch_size=8,dropout=0.0500,l1_lambda=0.0000,lr=0.0041,lr_scheduler=R_2024-08-01_13-53-51/checkpoint_000019',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=63998_355_act=elu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0001,lr=0.0108,lr_scheduler=ReduceLROnPla_2024-08-01_11-20-27/checkpoint_000003',
    '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=14600_751_act=tanh,adj_thresh=0.7000,batch_size=32,dropout=0.0500,l1_lambda=0.0011,lr=0.0236,lr_scheduler=Cosine_2024-08-01_12-59-43/checkpoint_000042'
]
process_checkpoints(GRN_BEST_RUNS_M, GRN_mean_dict, GRN_std_dict, device)
process_checkpoints(GRN_BEST_RUNS_F, GRN_mean_dict, GRN_std_dict, device)
process_checkpoints(GRN_BEST_RUNS_M_F, GRN_mean_dict, GRN_std_dict, device)

Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=52208_883_act=relu,adj_thresh=0.7000,batch_size=8,dropout=0,l1_lambda=0.0002,lr=0.0010,lr_scheduler=ReduceLROnPl_2024-08-01_13-33-15/checkpoint_000020
1 best checkpoint for ['M'] and csf
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.7_num_nodes_10_mutation_GRN_csf_sex_M_train.pt
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.7_num_nodes_10_mutation_GRN_csf_sex_M_test.pt
Normalized Val MSE: 0.0497010201215744
Normalized train MSE: 0.7081002593040466
Original Units Val MSE: tensor(0.2582, grad_fn=<MseLossBackward0>)
Original Units Val RMSE: tensor(0.5081, grad_fn=<SqrtBackward0>)
Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-01_10-30-17/model=gat-v4,seed=10511_829_act=sigmoid,adj_thresh=0.9000,batch_size=8,dropout=0.3000,l1_lambda=0.0052,lr=0.0010,lr_scheduler=Redu_2024-08

[(tensor([[38.1163],
          [ 4.3769],
          [ 9.2230],
          [ 7.9793],
          [33.8475],
          [13.0684],
          [43.1009],
          [32.2196],
          [ 0.3669],
          [15.1122],
          [20.5845],
          [14.0804],
          [ 4.1669],
          [48.3165],
          [ 0.6697],
          [ 2.2397],
          [ 0.7753],
          [36.7248],
          [ 2.0329],
          [ 5.7276],
          [ 0.9056],
          [40.8088],
          [ 3.1494],
          [ 5.8064],
          [39.7984]]),
  tensor([[80.2171],
          [ 7.9675],
          [11.2010],
          [10.8412],
          [92.7933],
          [ 6.9760],
          [44.1702],
          [35.0793],
          [ 2.8600],
          [ 6.0449],
          [21.8106],
          [11.1943],
          [ 3.2937],
          [84.4435],
          [ 2.2781],
          [ 9.1830],
          [ 4.5831],
          [73.4552],
          [ 6.6160],
          [13.1766],
          [ 5.1185],
          [95.4446],
          [

In [2]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
#Best run passing in sex, mutation and age before encoder
full_load_and_run_and_convert('/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-13_15-49-20/model=gat-v4,seed=31061_269_act=sigmoid,adj_thresh=0.1000,batch_size=8,dropout=0.1000,l1_lambda=0.0008,lr=0.0000,lr_scheduler=Lamb_2024-08-13_16-58-56/checkpoint_000005', device, 2.124088581365514, 0.8733420033790319)
#full_load_and_run_and_convert('/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-13_15-49-20/model=gat-v4,seed=55118_1133_act=relu,adj_thresh=0.9000,batch_size=32,dropout=0.2000,l1_lambda=0.0001,lr=0.0000,lr_scheduler=Lambd_2024-08-14_01-14-01/checkpoint_000001',device, 2.124088581365514, 0.8733420033790319)


full_load_and_run_and_convert('/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-15_10-15-54/model=gat-v4,seed=35068_564_act=sigmoid,adj_thresh=0.7000,batch_size=8,dropout=0.2000,l1_lambda=0.0006,lr=0.0000,lr_scheduler=Lamb_2024-08-15_13-19-27/checkpoint_000065', device, 2.124088581365514, 0.8733420033790319)



AttributeError: 'Config' object has no attribute 'use_master_nodes'

In [5]:
# Best run using sex, mutation, age, masternodes
full_load_and_run_and_convert('/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-08-23_13-20-08/model=gat-v4,seed=44609_31_act=relu,adj_thresh=0.9000,batch_size=50,dropout=0,l1_lambda=0.0104,lr=0.0034,lr_scheduler=ReduceLROnPl_2024-08-23_13-20-08/checkpoint_000006', device, 2.124088581365514, 0.8733420033790319)

Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.9_num_nodes_7247_mutation_GRN,MAPT,C9orf72,CTL_csf_sex_M,F_masternodes_True_train.pt
Loading data from: /home/data/data_louisa/processed/ftd_y_val_nfl_adj_thresh_0.9_num_nodes_7247_mutation_GRN,MAPT,C9orf72,CTL_csf_sex_M,F_masternodes_True_test.pt
Normalized Val MSE: 0.3589880168437958
Normalized train MSE: 0.6138961911201477
[15.2708235 10.195642  18.776457   9.160423   3.0481133  9.250263
 11.133699  11.667741   4.350158   5.136927   8.905613  18.445415
  3.127379   4.4976735 13.419723  13.967364  14.09947    7.299478
 17.701893  11.746225  11.734119   4.4216065  3.7322218  8.885336
  6.459698   6.8190637 17.22844   18.026674   6.6333866  6.278182
 15.459351   8.390783   3.2082236 11.944394   4.344727  12.127326
  3.2634714 18.682575   6.1436768 17.546104  12.951021   7.9535956
  5.470645   4.4759927 13.552354 ]
Original Units Val MSE: tensor(164.0844, grad_fn=<MseLossBackward0>)
Original Units Val RMSE: t

[tensor([[ 5.6601],
         [ 7.3133],
         [19.0930],
         [ 3.6284],
         [ 9.7957],
         [12.2536],
         [ 3.2121],
         [16.1121],
         [12.6195],
         [ 3.2712],
         [ 4.2896],
         [ 6.0337],
         [ 3.5932],
         [18.2804],
         [16.0355],
         [12.7171],
         [18.5303],
         [ 5.6577],
         [ 8.6207],
         [ 3.6887],
         [ 3.7167],
         [ 8.2238],
         [12.4337],
         [ 3.0844],
         [ 3.8381],
         [ 3.2121],
         [10.0568],
         [15.2929],
         [ 3.2598],
         [16.5107],
         [ 4.2044],
         [ 7.1950],
         [13.6706],
         [ 6.8542],
         [ 4.7521],
         [15.2225],
         [ 7.7305],
         [18.1262],
         [18.8649],
         [11.3500],
         [11.7557],
         [16.3371],
         [ 3.1791],
         [ 7.7107],
         [13.3683],
         [ 8.1163],
         [ 3.6262],
         [ 6.0099],
         [10.8881],
         [16.0550],


In [14]:
#Sanity check
def compute_manual_mse(val_preds, val_targets):
    """
    Manually computes the Mean Squared Error (MSE) for the given predictions and targets.

    Parameters:
    val_preds (list of list of torch.Tensor): The predicted values.
    val_targets (list of list of torch.Tensor): The true target values.

    Returns:
    float: The computed Mean Squared Error.
    """
   # Compute the squared differences
    squared_diffs = (val_preds - val_targets) ** 2

    # Compute the mean of the squared differences
    mse = squared_diffs.mean().item()

    return mse

print(compute_manual_mse(val_preds, val_targets))

716.7650146484375


In [13]:
import os
import torch
import torch.nn.functional as F
import train as proteo_train

# Define a function to load the checkpoint and calculate MSE
def load_checkpoint_and_calculate_mse(relative_checkpoint_path, levels_up=5):
    # Get the current script directory
    current_directory = os.getcwd()
    
    # Navigate up the specified number of levels
    for _ in range(levels_up):
        current_directory = os.path.dirname(current_directory)
    
    # Construct the full path to the checkpoint
    checkpoint_path = os.path.join(current_directory, relative_checkpoint_path)
    print(f"Loading checkpoint from: {checkpoint_path}")

    # Check if the file exists to avoid errors
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_path}")

    # Load the checkpoint
    checkpoint = torch.load(checkpoint_path)
    print("Checkpoint keys:", checkpoint.keys())
    print("checkpoint state_dict keys:", checkpoint['state_dict'].keys())

    module = proteo_train.Proteo.load_from_checkpoint(checkpoint_path)

    # Access the attributes
    # best_val_pred = module.best_val_pred
    # print("best_val_pred:", best_val_pred)
    # # print("min_val loss:", module.val_loss)
    # best_val_target = module.best_val_target
    # best_train_pred = module.best_train_pred
    # best_train_target = module.best_train_target

    # # Calculate MSE for validation and training
    # mse_val = F.mse_loss(best_val_pred, best_val_target).item()
    # mse_train = F.mse_loss(best_train_pred, best_train_target).item()

    return module, checkpoint

# Example usage
relative_checkpoint_path = '/scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-07-31_16-47-02/model=gat-v4,seed=19543_0_act=relu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0000,lr=0.1000,lr_scheduler=LambdaLR,modal_2024-07-31_16-47-02/checkpoint_000001/checkpoint.cpkt'
module, checkpoint = load_checkpoint_and_calculate_mse(relative_checkpoint_path)
# print(f"MSE Loss for validation set: {mse_val}")
# print(f"MSE Loss for training set: {mse_train}")


Loading checkpoint from: /scratch/lcornelis/outputs/ray_results/TorchTrainer_2024-07-31_15-31-35/model=gat-v4,seed=19543_0_act=relu,adj_thresh=0.1000,batch_size=8,dropout=0,l1_lambda=0.0000,lr=0.1000,lr_scheduler=LambdaLR,modal_2024-07-31_15-31-36/checkpoint_000001/checkpoint.ckpt
Checkpoint keys: dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
checkpoint state_dict keys: odict_keys(['model.convs.0.att_src', 'model.convs.0.att_dst', 'model.convs.0.bias', 'model.convs.0.lin.weight', 'model.convs.1.att_src', 'model.convs.1.att_dst', 'model.convs.1.bias', 'model.convs.1.lin.weight', 'model.pools.0.weight', 'model.pools.0.bias', 'model.pools.1.weight', 'model.pools.1.bias', 'model.layer_norm.weight', 'model.layer_norm.bias', 'model.encoder.0.0.weight', 'model.encoder.0.0.bias', 'model.encoder.1.0.weight', 'model.encoder.1.0.bias', 'model.encoder.2.0.weight', 'model.

In [19]:
module.__dict__

{'training': True,
 '_parameters': OrderedDict(),
 '_buffers': OrderedDict(),
 '_non_persistent_buffers_set': set(),
 '_backward_pre_hooks': OrderedDict(),
 '_backward_hooks': OrderedDict(),
 '_is_full_backward_hook': None,
 '_forward_hooks': OrderedDict(),
 '_forward_hooks_with_kwargs': OrderedDict(),
 '_forward_hooks_always_called': OrderedDict(),
 '_forward_pre_hooks': OrderedDict(),
 '_forward_pre_hooks_with_kwargs': OrderedDict(),
 '_state_dict_hooks': OrderedDict(),
 '_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_pre_hooks': OrderedDict(),
 '_load_state_dict_post_hooks': OrderedDict(),
 '_modules': OrderedDict([('model',
               GATv4(
                 (convs): ModuleList(
                   (0): CustomGATConv(1, 8, heads=2)
                   (1): CustomGATConv(16, 16, heads=3)
                 )
                 (pools): ModuleList(
                   (0): Linear(in_features=16, out_features=1, bias=True)
                   (1): Linear(in_features=48, out_fea

In [14]:
checkpoint

{'epoch': 1,
 'global_step': 8,
 'pytorch-lightning_version': '2.3.3',
 'state_dict': OrderedDict([('model.convs.0.att_src',
               tensor([[[ 0.6818,  0.0561,  0.6685,  0.3070,  0.7125,  0.4150,  0.2690,
                         -0.2861],
                        [-0.2914,  0.9814, -0.0518, -0.1408,  0.3510, -0.2540,  0.5184,
                         -0.2630]]], device='cuda:0')),
              ('model.convs.0.att_dst',
               tensor([[[ 0.8912, -0.2167,  0.5399, -0.1660,  0.3122, -0.2951,  0.1650,
                          0.5460],
                        [ 1.0074,  0.5066,  0.3665,  1.0729,  0.2071,  1.0349,  0.2690,
                          1.0922]]], device='cuda:0')),
              ('model.convs.0.bias',
               tensor([-0.0016,  0.2387,  0.2148,  0.2318,  0.1968,  0.2223,  0.2042,  0.2067,
                        0.2075, -0.0101,  0.2075,  0.2105,  0.2058,  0.2272,  0.2135,  0.2171],
                      device='cuda:0')),
              ('model.convs.0.li

In [None]:
# load in train and test datasets using config
# run model and get val_targets val_preds train_targets train_preds
# find loss for each