In [1]:
import numpy as np 
import pandas as pd
import SAGEnet.data 
import SAGEnet.tools
from SAGEnet.models import pSAGEnet
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import glob
import os

  warn(


#### Notes:   
-- Because of data protections, the data we used in our analyses (ROSMAP and GTEx WGS and expression data) are not provided in this GitHub repository. Please see the paper for specifics on how to request that data.   
-- Instead, we provide toy data in input_data/example_data/ to demonstrate the required data formats for this tutorial.  
-- In addition, we do not include the hg38 genome in this repository due to its size -- but you can go to input_data/ and run download_genome.sh to save a copy of the hg38 genome to use.  
-- Finally, before running this notebook, please follow the steps under "installation" in the main README to install dependencies and the SAGEnet package.

## Initializing a PersonalGenomeDataset

#### Change this path to where you saved the hg38 genome: 

In [2]:
# Location of the hg38 reference genome FASTA.
# Set the HG38_FILE_PATH environment variable to point at your own copy without
# editing the notebook; otherwise the original default path is used.
hg38_file_path = os.environ.get(
    'HG38_FILE_PATH',
    '/data/tuxm/project/Decipher-multi-modality/data/genome/hg38.fa',
)

#### Change this path to the location of your SAGEnet repo: 

In [3]:
# Root of the local SAGEnet checkout.
# Set the SAGENET_REPO_PATH environment variable to override the default.
# Later cells build paths as f'{base_repo_path}input_data/...', so the value must
# end with a path separator; os.path.join(path, '') appends one only if it is missing.
base_repo_path = os.path.join(
    os.environ.get('SAGENET_REPO_PATH', '/homes/gws/aspiro17/SAGEnet/'),
    '',
)

#### Define these paths to the necessary input data (all provided in the github, you do not need to change these): 

In [4]:
# Build the bundled input-data paths with os.path.join so they resolve correctly
# whether or not base_repo_path ends with a path separator.
example_data_dir = os.path.join(base_repo_path, 'input_data', 'example_data')

# Toy VCF holding the variants for the example individuals.
example_vcf_file_path = os.path.join(example_data_dir, 'example_vcf.vcf.gz')
# Gene metadata table (tab-separated) with TSS positions.
tss_data_path = os.path.join(base_repo_path, 'input_data', 'gene-ids-and-positions.tsv')
# Comma-separated list of sample names.
sample_names_path = os.path.join(example_data_dir, 'example_individuals.csv')
# Toy expression table.
expr_data_path = os.path.join(example_data_dir, 'example_expression.csv')

#### Then, load your data: 
-- example_individuals is a list of sample names, as they appear in the VCF. For this toy data, we include three sample individuals.   
-- example_expression_data is a DataFrame of expression data indexed by gene names, with sample names as columns.  For this toy data, we use a single gene, "ENSG00000013573".  
-- gene_meta_info is a DataFrame of gene metadata containing the columns "chr", "tss", and "strand". 

In [5]:
# Read the comma-separated sample names as strings.
# ndmin=1 guarantees an array with at least one dimension: without it, a file
# holding a single sample name would load as a 0-d array that cannot be iterated.
example_individuals = np.loadtxt(sample_names_path, delimiter=',', dtype=str, ndmin=1)
example_individuals

array(['SAMPLE1', 'SAMPLE2', 'SAMPLE3'], dtype='<U7')

In [6]:
# Load the expression table; the first CSV column becomes the row index.
example_expression_data = pd.read_csv(
    expr_data_path,
    index_col=0,
)
example_expression_data

Unnamed: 0,SAMPLE1,SAMPLE2,SAMPLE3
ENSG00000013573,2.1,0.3,0.7


In [7]:
# Load the tab-separated gene metadata table, indexed by its 'region_id' column.
gene_meta_info = pd.read_csv(
    tss_data_path,
    index_col='region_id',
    sep='\t',
)
gene_meta_info

Unnamed: 0_level_0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand,pos
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
ENSG00000223972,DDX11L1,ENSG00000223972,1,11869,14409,+,11869,chr1,11869.0,11869,1,ENSG00000223972,+,11869
ENSG00000227232,WASH7P,ENSG00000227232,1,14404,29570,-,29570,chr1,29570.0,29570,1,ENSG00000227232,-,29570
ENSG00000278267,MIR6859-1,ENSG00000278267,1,17369,17436,-,17436,chr1,17436.0,17436,1,ENSG00000278267,-,17436
ENSG00000243485,MIR1302-2HG,ENSG00000243485,1,29554,31109,+,29554,chr1,29554.0,29554,1,ENSG00000243485,+,29554
ENSG00000284332,MIR1302-2,ENSG00000284332,1,30366,30503,+,30366,chr1,30366.0,30366,1,ENSG00000284332,+,30366
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000271254,AC240274.1,ENSG00000271254,KI270711.1,4612,29626,-,29626,,,29626,KI270711.1,ENSG00000271254,-,29626
ENSG00000275405,U1,ENSG00000275405,KI270713.1,21861,22024,-,22024,,,22024,KI270713.1,ENSG00000275405,-,22024
ENSG00000275987,U1,ENSG00000275987,KI270713.1,30437,30580,-,30580,,,30580,KI270713.1,ENSG00000275987,-,30580
ENSG00000277475,AC213203.1,ENSG00000277475,KI270713.1,31698,32528,-,32528,,,32528,KI270713.1,ENSG00000277475,-,32528


#### Select gene meta information for our example gene, "ENSG00000013573": 

In [10]:
example_gene_id = 'ENSG00000013573'

# Selecting with a one-element list keeps the result a one-row DataFrame
# instead of collapsing it to a Series.
example_gene_meta_info = gene_meta_info.loc[[example_gene_id]]
example_gene_meta_info

Unnamed: 0_level_0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand,pos
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
ENSG00000013573,DDX11,ENSG00000013573,12,31073845,31104791,+,31073845,chr12,31226779.0,31073845,12,ENSG00000013573,+,31073845


#### Now that we have our example_expression_data, example_gene_meta_info, example_vcf_file_path (which contains variant info for our example gene), and hg38_file_path, we can initialize a basic PersonalGenomeDataset:  

Important parameters to change to suit your needs:   
-- "contig_prefix" gives the string before chromosome number in VCF file. Since our example VCF presents the chromosome as "chr12", we used the input contig_prefix='chr' to PersonalGenomeDataset. If the format were instead "12", we would use the input contig_prefix=''. Default: ''.   
-- "split_expr" controls whether expression is to be decomposed into mean and difference from mean. If True, idx 1 of the expression output represents difference from mean expression. If False, idx 1 of the expression output represents personal gene expression (used to train non-contrastive model). Default: True.   
-- "input_len" controls sequence length.   
-- "expr_data_zscore" determines whether (if "split_expr" is True), z-scores will be used as difference from mean. If not, the difference used will be simply personal expression - mean expression. Default: True.   
-- "only_snps" determines whether only SNPs (rather than variants of any length) are inserted into the personal genome. By default, all variants are inserted. Default: False.   
-- "expr_data" must be provided if you want to output personal expression values. If it is not provided, zeros will be used in place of expression values. This can be useful if you are just evaluating a model, not training. Default: None.   
-- "maf_threshold" determines the threshold MAF for a variant to be inserted. If less than zero, all variants will be inserted. Default: -1.   

To initialize a PersonalGenomeDataset to use to train a non-contrastive model, set split_expr=False. 

For more insight into parameters and how to adapt PersonalGenomeDataset, please see the class documentation.   

In [11]:
# Build the toy dataset: one gene x three individuals.
# contig_prefix='chr' is passed because the example VCF names its contig "chr12".
personal_dataset = SAGEnet.data.PersonalGenomeDataset(
    metadata=example_gene_meta_info,
    vcf_file_path=example_vcf_file_path,
    hg38_file_path=hg38_file_path,
    sample_list=example_individuals,
    y_data=example_expression_data,
    contig_prefix='chr',
)

acceptable maf range: -1<maf<2
avg is mean


#### As we can see, the number of datapoints in the PersonalGenomeDataset is equal to the number of genes x the number of individuals (here, 1x3). Each item in the PersonalGenomeDataset is a tuple containing: 
-- One-hot-encoded tensor of genomic sequence of shape [2,8,40000]. This contains personal genomic sequence from each haplotype (maternal: [0,:4,:], paternal: [0,4:,:]) and reference sequence ([1,:,:]).  
-- Expression tensor containing [mean expression, personal difference from mean expression]  
-- Gene index  
-- Sample index  

Gene index and sample index are used to get correlation (across gene and across individual) metrics at epoch ends. 

In [12]:
# Walk through every datapoint and show the shape of its first element
# followed by elements 1-3 as-is.
for datapoint in personal_dataset:
    print('new datapoint:')
    print(datapoint[0].shape)
    for field_idx in (1, 2, 3):
        print(datapoint[field_idx])
    print('---')

new datapoint:
torch.Size([2, 8, 40000])
tensor([1.0333, 1.0667])
0
0
---
new datapoint:
torch.Size([2, 8, 40000])
tensor([ 1.0333, -0.7333])
0
1
---
new datapoint:
torch.Size([2, 8, 40000])
tensor([ 1.0333, -0.3333])
0
2
---


#### There should be no lag time when iterating through the PersonalGenomeDataset. 
#### The process for creating a ReferenceGenomeDataset or VariantDataset is very similar; please see the README in SAGEnet/ and the class documentation for details. 

## Initializing a pSAGEnet model 

#### A pSAGEnet model with default parameters can be initialized using: 

In [13]:
# Instantiate pSAGEnet with no arguments, i.e. with all of its default parameters.
my_model = pSAGEnet()

Important parameters to change to suit your needs:   
-- many parameters ("input_length", "first_layer_kernel_number", "int_layers_kernel_number", etc.) control aspects of model architecture. Please see class documentation for details.   
-- "lam_ref" and "lam_diff" control the weights on each component of the loss function. "lam_ref" controls the weight on the 0 idx output (mean) and "lam_diff" controls the weight on the 1 idx output (difference).   
-- "split_expr": if False, model "difference" output (idx 1) is predicted straight from personal sequence (no intermediate subtraction with reference).  

To train a "non-contrastive" model, set lam_ref=0, split_expr=False (along with setting split_expr=False in the PersonalGenomeDataset).

For more insight into parameters and how to adapt pSAGEnet, please see the class documentation.   

#### The process for initializing a rSAGEnet model is very similar, please see class documentation for details. 

## Training a pSAGEnet model

#### The full code required to train a pSAGEnet model can be found in script/train_model/train_model.py. For a simplified walk-through of key sections, see below. 

#### Load in your training and validation individuals. The train/validation splits we use for ROSMAP are found in input_data/individual_sets/. 

In [14]:
# Directory holding the individual splits. Built with os.path.join so it resolves
# whether or not base_repo_path ends with a separator; the trailing '' keeps the
# trailing separator that the original string form of sub_data_dir had.
sub_data_dir = os.path.join(base_repo_path, 'input_data', 'individual_sets', '')

# ndmin=1 keeps each result at least one-dimensional even if a split file lists
# a single individual (np.loadtxt would otherwise return a 0-d array).
train_subs = np.loadtxt(os.path.join(sub_data_dir, 'ROSMAP', 'train_subs.csv'), delimiter=',', dtype=str, ndmin=1)
val_subs = np.loadtxt(os.path.join(sub_data_dir, 'ROSMAP', 'val_subs.csv'), delimiter=',', dtype=str, ndmin=1)

#### Each of these is just a list of individuals (with names that match what's found in the VCF & the expression data): 

In [15]:
# Report how many individuals are in each split, then preview a few validation IDs.
for subs in (train_subs, val_subs):
    print(subs.shape)
print(val_subs[:5])

(689,)
(85,)
['P10315029_P10315029' 'P35286551_P35286551' 'P84642424_P84642424'
 'P50401390_P50401390' 'P20249897_P20249897']


#### Get list of genes to use in training and validation. In this example, let's use the top 1000 genes based on the predixcan ranking. 

In [16]:
# Path to the predixcan ranking used to pick genes. Built with os.path.join so it
# resolves whether or not base_repo_path ends with a path separator.
predixcan_res_path = os.path.join(base_repo_path, 'results_data', 'predixcan', 'rosmap_pearson_corr.csv')
rand_genes = 0  # set to 0 because we are using "top genes" not random genes
num_top_genes = 1000  # based on predixcan ranking
gene_idx_start = 0  # use genes rank 0 -- rank 999
gene_list = SAGEnet.tools.select_region_set(
    enet_path=predixcan_res_path,
    rand_regions=rand_genes,
    num_regions=num_top_genes,
    region_idx_start=gene_idx_start,
)
print(gene_list.shape)

selecting regions from 0 to 1000
(1000,)


#### Split this list of genes into train, validation, and test sets based on the chromosome split (defined in SAGEnet.tools)

In [17]:
# Partition the selected genes into train / validation / test sets.
train_gene_list, val_gene_list, test_gene_list = SAGEnet.tools.get_train_val_test_genes(
    gene_list,
    tss_data_path=tss_data_path,
)
# Show the size of each split.
for split_genes in (train_gene_list, val_gene_list, test_gene_list):
    print(split_genes.shape)

selecting train/val/test gene sets based on chromosome split
(741,)
(118,)
(117,)


#### Get the TSS information associated with each gene set: 

In [18]:
# Pull the metadata rows for the train genes and for the validation genes.
train_gene_meta, val_gene_meta = (
    gene_meta_info.loc[genes] for genes in (train_gene_list, val_gene_list)
)
train_gene_meta

Unnamed: 0_level_0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand,pos
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
ENSG00000144820,ADGRG7,ENSG00000144820,3,100609589,100695479,+,100609589,chr3,100328433.0,100609589,3,ENSG00000144820,+,100609589
ENSG00000128617,OPN1SW,ENSG00000128617,7,128772491,128775790,-,128775790,chr7,128415844.0,128775790,7,ENSG00000128617,-,128775790
ENSG00000198502,HLA-DRB5,ENSG00000198502,6,32517343,32530287,-,32530287,chr6,32498064.0,32530287,6,ENSG00000198502,-,32530287
ENSG00000147813,NAPRT,ENSG00000147813,8,143574785,143578649,-,143578649,chr8,144660819.0,143578649,8,ENSG00000147813,-,143578649
ENSG00000120675,DNAJC15,ENSG00000120675,13,43023203,43114224,+,43023203,chr13,43597339.0,43023203,13,ENSG00000120675,+,43023203
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000176393,RNPEP,ENSG00000176393,1,201982372,202006147,+,201982372,chr1,201951500.0,201982372,1,ENSG00000176393,+,201982372
ENSG00000219435,CATSPERZ,ENSG00000219435,11,64300391,64304770,+,64300391,chr11,64067863.0,64300391,11,ENSG00000219435,+,64300391
ENSG00000028203,VEZT,ENSG00000028203,12,95217746,95302790,+,95217746,chr12,95611522.0,95217746,12,ENSG00000028203,+,95217746
ENSG00000102858,MGRN1,ENSG00000102858,16,4616493,4690974,+,4616493,chr16,4666494.0,4616493,16,ENSG00000102858,+,4616493


#### Following the guidelines from earlier in the notebook, create a PersonalGenomeDataset and Dataloaders for train and validation: 
-- For data protection reasons described at the top of the notebook, the following cells cannot be run because they need the actual ROSMAP vcf_file_path and expr_data, which we cannot add to the github.   
-- train_dataset is train individuals, train genes. val_subs_dataset is val individuals, train genes. val_genes_dataset is val individuals, val genes. 

In [None]:
# Arguments common to all three datasets (same VCF, reference genome and expression table).
shared_dataset_kwargs = dict(
    vcf_file_path=vcf_file_path,
    hg38_file_path=hg38_file_path,
    y_data=expr_data,
)

# train individuals x train genes
train_dataset = SAGEnet.data.PersonalGenomeDataset(
    metadata=train_gene_meta, sample_list=train_subs, **shared_dataset_kwargs
)
# validation individuals x train genes
val_subs_dataset = SAGEnet.data.PersonalGenomeDataset(
    metadata=train_gene_meta, sample_list=val_subs, **shared_dataset_kwargs
)
# validation individuals x validation genes
val_genes_dataset = SAGEnet.data.PersonalGenomeDataset(
    metadata=val_gene_meta, sample_list=val_subs, **shared_dataset_kwargs
)

# Only the training loader shuffles; the validation loaders keep a fixed order.
train_dataloader = DataLoader(train_dataset, shuffle=True)
val_subs_dataloader = DataLoader(val_subs_dataset, shuffle=False)
val_genes_dataloader = DataLoader(val_genes_dataset, shuffle=False)

#### Set up for model training:

In [None]:
# Directory that receives every checkpoint written during training.
model_save_dir = 'test_model_save_dir/'

# Validation loader order matters: index 0 is the loader the early-stopping monitor refers to.
val_dataloaders = [val_subs_dataloader, val_genes_dataloader]

# Stop training once the monitored validation loss has not improved for 5 checks.
es = EarlyStopping(monitor="train_gene_val_sub_diff_loss/dataloader_idx_0", patience=5, mode='min')
# Log the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# used to save every model epoch
all_epoch_checkpoint_callback = ModelCheckpoint(
    dirpath=model_save_dir,
    filename="{epoch}",
    save_top_k=-1,
    every_n_epochs=1,
    save_last=False
)

# save last ckpt to be able to resume model training if job is killed
last_checkpoint_callback = ModelCheckpoint(
    dirpath=model_save_dir,
    filename="last",
    save_top_k=0,
    every_n_train_steps=300,
    save_last=True
)

# This list is what gets handed to the Trainer as its callbacks. The early-stopping
# and learning-rate callbacks were previously created but never registered, so they
# had no effect; they are included here so that they actually run.
ckpt_list = [all_epoch_checkpoint_callback, last_checkpoint_callback, es, lr_monitor]

#### Resume if an earlier training run was interrupted: 

In [None]:
# Resume only when the "last" checkpoint itself exists. Checking for any *.ckpt file
# is not enough: a directory holding only per-epoch checkpoints would otherwise make
# training try to resume from a last.ckpt that is not there.
# os.path.join also avoids the doubled slash produced by string concatenation.
last_ckpt_path = os.path.join(model_save_dir, "last.ckpt")
last_checkpoint = last_ckpt_path if os.path.isfile(last_ckpt_path) else None

#### Set up a trainer (in our case, using distributed training across multiple nodes):
-- To train on a single GPU, set device equal to the integer label of the GPU you want to use. You can also edit the "accelerator" argument to run on CPU only. 

In [None]:
# Weights & Biases logger; resume="allow" lets a restarted job attach to the same run id.
wandb_logger = WandbLogger(project='test_project_name', name='test_job_name', id='test_job_name', resume="allow")
device = None  # None indicates multi-node training; otherwise the integer label of the single GPU to use
num_nodes = 8  # number of nodes used for multi-node training
max_epochs = 10

# Compare against None explicitly: a plain truthiness test would treat GPU label 0
# as "no device given" and silently fall back to the multi-node configuration.
single_gpu = device is not None

trainer = pl.Trainer(
    accelerator="gpu",
    devices=[int(device)] if single_gpu else 1,
    # A single-GPU run lives on one node; num_nodes only applies to multi-node training.
    num_nodes=1 if single_gpu else num_nodes,
    strategy='auto' if single_gpu else "ddp",
    callbacks=ckpt_list,
    max_epochs=max_epochs,
    benchmark=False,
    profiler='simple',
    gradient_clip_val=1,
    logger=wandb_logger,
    log_every_n_steps=10)

#### Optionally initialize p-SAGE-net from a pre-trained model (in our case, r-SAGE-net). 

In [None]:
# Checkpoint of the pre-trained reference model to initialize from.
ref_model_ckpt_path = 'reference_model.ckpt'

print(f'loading rSAGEnet weights from {ref_model_ckpt_path} into pSAGEnet model')
# Rebind my_model to whatever init_model_from_ref returns for the existing model and checkpoint.
my_model = SAGEnet.tools.init_model_from_ref(
    my_model,
    ref_model_ckpt_path,
)

#### Train! 

In [None]:
# Fit from scratch, or resume from the saved checkpoint when one was found.
resuming = last_checkpoint is not None
print(f'fitting model from ckpt={last_checkpoint}' if resuming else 'fitting model')

fit_kwargs = {'val_dataloaders': val_dataloaders}
if resuming:
    # ckpt_path is only passed when there is a checkpoint to resume from.
    fit_kwargs['ckpt_path'] = last_checkpoint

trainer.fit(my_model, train_dataloader, **fit_kwargs)

#### For a completely runnable example, see GEUVADIS_example.ipynb. Also see the READMEs in each directory for more examples of how to use the repo, and see the code documentation for an explanation of each parameter.