Script to evaluate trained PDGrapher models on chemical or genetic perturbation data

In [None]:
# Imports and global runtime configuration for the evaluation notebook.
import torch
import pandas as pd
import sys
import numpy as np
import os
import os.path as osp
# Number GPUs by PCI bus id (matches nvidia-smi). NOTE(review): this only takes
# effect if set before CUDA is initialised — confirm nothing touches CUDA earlier.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Fix RNG seeds so repeated runs are reproducible.
torch.manual_seed(0)
np.random.seed(0)
# NOTE(review): os / os.path / sys are imported twice in this cell; the re-imports are harmless.
import os
import os.path as osp
# Cap torch's intra-op CPU thread pool at 20 threads.
torch.set_num_threads(20)
from datetime import datetime
from pdgrapher import PDGrapher
from pdgrapher import Trainer, Dataset
import sys
from glob import glob
from pdgrapher._utils import get_thresholds

Choose which dataset type to evaluate: "chemical" or "genetic" (lowercase)

In [None]:
# Perturbation type to evaluate: "chemical" or "genetic". Must be lowercase because
# later cells compare against lowercase strings and interpolate it into data paths.
data_type = "chemical"

Define the cell lines to evaluate for the chosen data type

In [None]:
# Cell lines evaluated for each perturbation type. data_type is normalised to
# lowercase so that a value such as "Chemical" still selects the right branch
# (the same normalised value is used when building data paths later on).
data_type = data_type.lower()
if data_type == "chemical":
    cell_lines = ['A549', 'MCF7', 'PC3', 'VCAP', 'MDAMB231', 'BT20', 'HT29', 'A375', 'HELA']
elif data_type == "genetic":
    cell_lines = ['PC3', 'YAPC', 'AGS', 'A375', 'HT29', 'A549', 'BICR6', 'U251MG', 'ES2', 'MCF7']
else:
    # Fail here with a clear message instead of a NameError on cell_lines later.
    raise ValueError(f"data_type must be 'chemical' or 'genetic', got {data_type!r}")


Set up parameters

In [None]:
# Experiment flags mirroring the training configuration.
use_backward_data = True
use_supervision = True  # whether to use supervision loss
use_intervention_data = True  # whether to use cycle loss
current_date = datetime.now()
n_layers_nn = 1  # value for PDGrapher's "n_layers_nn" model kwarg
# Default for every cell line; the evaluation loop switches it off for some cell lines.
# (A module-level `global` statement is a no-op, so none is needed here.)
use_forward_data = True

Load the trained models and evaluate them on the test and validation sets of each fold

In [None]:
# Root of the lab storage holding the processed LINCS data, CV splits and trained models.
_LAB_ROOT = "/n/holystore01/LABS/mzitnik_lab/Lab/xianglin226"

# Cell lines for which use_forward_data is switched off, per data type.
_NO_FORWARD_CELL_LINES = {
    "chemical": {'HA1E', 'HT29', 'A375', 'HELA'},
    "genetic": {'ES2', 'BICR6', 'YAPC', 'AGS', 'U251MG', 'HT29', 'A375'},
}

# Folds are numbered 1.._N_FOLDS in the checkpoint file names.
_N_FOLDS = 5

# Cut-offs k for recall@k.
_RECALL_KS = (1, 10, 100, 1000)

# Report rows in output order: (metric label, decimals used for mean±std).
_REPORT_METRICS = (
    ('recall@1', 4),
    ('recall@10', 4),
    ('recall@100', 4),
    ('recall@1000', 4),
    ('percentage of samples with partially accurate predictions', 2),
    ('ranking score', 2),
)


def _load_cell_line(data_type, cell_line):
    """Return (dataset, edge_index, model_dirs) for one cell line of ``data_type``."""
    data_dir = f"{_LAB_ROOT}/lincs_processed/{data_type}/pt_data/real_lognorm"
    dataset = Dataset(
        forward_path=f"{data_dir}/data_forward_{cell_line}.pt",
        backward_path=f"{data_dir}/data_backward_{cell_line}.pt",
        splits_path=f"{_LAB_ROOT}/PDgrapher/data/split_guada/{data_type}/{data_type}/{cell_line}/random/5fold/splits.pt",
    )
    edge_index = torch.load(f"{data_dir}/edge_index_{cell_line}.pt")
    # Modify based on folder name. Sorted so runs visit model directories in a stable order.
    model_dirs = sorted(glob(f"{_LAB_ROOT}/PDgrapher/output/{data_type}/{cell_line}_corrected_pos_emb/*"))
    return dataset, edge_index, model_dirs


def _build_model(edge_index, dataset, model_dir, fold, n_layers_gnn, device):
    """Instantiate PDGrapher, restore the fold's checkpoints from ``model_dir`` and ready it for inference."""
    model = PDGrapher(
        edge_index,
        model_kwargs={"n_layers_nn": n_layers_nn, "n_layers_gnn": n_layers_gnn, "num_vars": dataset.get_num_vars()},
        response_kwargs={'train': True},
        perturbation_kwargs={'train': True},
    )
    # The model is only moved to `device` below, so load checkpoints onto the CPU
    # rather than onto whatever device they were saved from.
    checkpoint = torch.load(osp.join(model_dir, '_fold_{}_response_prediction.pt'.format(fold)), map_location="cpu")
    model.response_prediction.load_state_dict(checkpoint["model_state_dict"])
    checkpoint = torch.load(osp.join(model_dir, '_fold_{}_perturbation_discovery.pt'.format(fold)), map_location="cpu")
    model.perturbation_discovery.load_state_dict(checkpoint["model_state_dict"])

    model.response_prediction.edge_index = model.response_prediction.edge_index.to(device)
    model.perturbation_discovery.edge_index = model.perturbation_discovery.edge_index.to(device)
    # Only perturbation discovery is run during evaluation, so only it is moved to the device.
    model.perturbation_discovery = model.perturbation_discovery.to(device)
    model.perturbation_discovery.eval()
    model.response_prediction.eval()
    return model


def _evaluate(model, loader, thresholds, device):
    """Score perturbation-discovery predictions over ``loader``.

    Returns a dict keyed by the labels in _REPORT_METRICS: mean recall@k, the
    percentage of batches with at least one true intervention among the
    top-|true| predictions, and the mean ranking score (1 - rank / num_nodes of
    each true intervention).
    """
    recalls = {k: [] for k in _RECALL_KS}
    rankings = []
    n_partially_accurate = 0
    with torch.no_grad():  # inference only: avoid building autograd graphs
        for data in loader:
            out = model.perturbation_discovery(
                torch.concat([data.diseased.view(-1, 1).to(device), data.treated.view(-1, 1).to(device)], 1),
                data.batch.to(device),
                mutilate_mutations=data.mutations.to(device),
                threshold_input=thresholds,
            )
            num_nodes = int(data.num_nodes / len(torch.unique(data.batch)))

            correct = set(torch.where(data.intervention.detach().cpu().view(-1, num_nodes))[1].tolist())
            # Node indices ordered by predicted score; only the first graph of the
            # batch is ranked, which covers everything when batch_size=1.
            predicted = torch.argsort(out.detach().cpu().view(-1, num_nodes), descending=True)[0, :].tolist()

            for ci in correct:
                rankings.append(1 - (predicted.index(ci) / num_nodes))
            for k in _RECALL_KS:
                recalls[k].append(len(set(predicted[:k]) & correct) / len(correct))
            # "Partially accurate" = non-empty overlap between the true set and the
            # top-|true| predictions (same criterion as a non-zero Jaccard score).
            if correct.intersection(predicted[:len(correct)]):
                n_partially_accurate += 1

    metrics = {f'recall@{k}': np.mean(values) for k, values in recalls.items()}
    metrics['percentage of samples with partially accurate predictions'] = 100 * n_partially_accurate / len(loader)
    metrics['ranking score'] = np.mean(rankings)
    return metrics


def _write_split_report(log, title, per_fold):
    """Write mean±std and the raw per-fold values of every metric for one split."""
    log.write(f'\n\n{title}\n')
    for label, decimals in _REPORT_METRICS:
        values = per_fold[label]
        log.write(f'{label}: {np.mean(values):.{decimals}f}±{np.std(values):.{decimals}f}\n')
    log.write('--------------------------\n')
    log.write('All metric datapoints:\n')
    for label, _ in _REPORT_METRICS:
        log.write(f'{label}: {per_fold[label]}\n')


device = torch.device('cuda')
outdir = './results_pdgrapher/{}/val/'.format(data_type)
os.makedirs(outdir, exist_ok=True)  # open() below fails if the directory is missing

for cell_line in cell_lines:
    print(f"Processing cell line: {cell_line}")

    # Recomputed for every cell line; previously it was only ever switched off, so
    # every cell line after a listed one inherited False.
    # NOTE(review): use_forward_data is not read anywhere in this evaluation cell.
    use_forward_data = cell_line not in _NO_FORWARD_CELL_LINES[data_type]

    dataset, edge_index, paths = _load_cell_line(data_type, cell_line)

    for path in paths:
        # The GNN depth is the third '_'-separated field of the model directory name.
        n_layers_gnn = int(osp.basename(path).split('_')[2])

        # results[split][metric label] -> one value per fold
        results = {split: {label: [] for label, _ in _REPORT_METRICS} for split in ('test', 'val')}

        for fold in range(1, _N_FOLDS + 1):
            model = _build_model(edge_index, dataset, path, fold, n_layers_gnn, device)

            # Loads the fold-specific dataset and its thresholds.
            dataset.prepare_fold(fold)
            thresholds = get_thresholds(dataset)
            thresholds = {k: v.to(device) if v is not None else v for k, v in thresholds.items()}

            (
                _train_loader_forward, _train_loader_backward,
                _val_loader_forward, val_loader_backward,
                _test_loader_forward, test_loader_backward,
            ) = dataset.get_dataloaders(num_workers=20, batch_size=1)

            # Each split is normalised by its own loader length (the validation
            # percentage used to be divided by the test loader length).
            for split, loader in (('test', test_loader_backward), ('val', val_loader_backward)):
                for label, value in _evaluate(model, loader, thresholds, device).items():
                    results[split][label].append(value)
            print('fold {}/{}'.format(fold, _N_FOLDS))

        report_path = osp.join(outdir, f'{cell_line}_{n_layers_gnn}_final_performance_metrics_within.txt')
        # Explicit UTF-8 because the report contains '±'.
        with open(report_path, 'w', encoding='utf-8') as log:
            _write_split_report(log, 'VALIDATION SET', results['val'])
            log.write('\n\n----------------------\n')
            _write_split_report(log, 'TEST SET', results['test'])