In [1]:
import numpy as np 
import pandas as pd
import SAGEnet.data 
import SAGEnet.tools
from SAGEnet.models import pSAGEnet
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import glob
import os
import pysam

#### In this notebook, we go through the process of initializing a PersonalGenomeDataset and using this to train a p-SAGE-net model with the GEUVADIS dataset (publicly available).

#### Before you run this notebook: 
-- follow the steps under "installation" in the main README to install dependencies and the SAGEnet package.

-- download the pre-processed GEUVADIS expression data 'tpm_pca_annot.csv.gz' from Rastogi et al.: https://github.com/ni-lab/finetuning-enformer/tree/main/process_geuvadis_data/tpm  

-- download the GEUVADIS VCF for chromosome 21 (to use as an example) from here: https://www.ebi.ac.uk/biostudies/arrayexpress/studies/E-GEUV-1. The file is called 'GEUVADIS.chr21.PH1PH2_465.IMPFRQFILT_BIALLELIC_PH.annotv2.genotypes.vcf.gz'  

-- run the following lines of code to be able to use pysam with this VCF file:   
```bash
gunzip GEUVADIS.chr21.PH1PH2_465.IMPFRQFILT_BIALLELIC_PH.annotv2.genotypes.vcf.gz # decompress 
bgzip GEUVADIS.chr21.PH1PH2_465.IMPFRQFILT_BIALLELIC_PH.annotv2.genotypes.vcf # recompress using bgzip
tabix -p vcf GEUVADIS.chr21.PH1PH2_465.IMPFRQFILT_BIALLELIC_PH.annotv2.genotypes.vcf.gz # index   
```

-- download the hg19 genome (to be consistent with GEUVADIS variant calls):    
```bash
curl -O https://hgdownload.soe.ucsc.edu/goldenPath/hg19/bigZips/hg19.fa.gz  
gunzip hg19.fa.gz
```

#### Change these paths based on your SAGEnet repo location, where you have saved the files downloaded in the previous step, and where you want to save your model results. 

In [4]:
# Root of the local SAGEnet checkout; used to locate the gene annotation file shipped with the repo.
base_repo_path='/homes/gws/aspiro17/SAGEnet/'

# Input files downloaded/prepared in the previous step:
# hg19 reference FASTA, pre-processed GEUVADIS expression table, and the bgzip-compressed, tabix-indexed VCF.
hg19_path='/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/hg19.fa'
expr_path='/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/tpm_pca_annot.csv.gz'
vcf_file_path='/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/GEUVADIS.vcf.gz'

# Output directory for model checkpoints. Written as an absolute path, consistent with the
# input paths above; without the leading '/' the directory tree would be created relative to
# whatever the kernel's current working directory happens to be.
model_save_dir='/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/res/'
os.makedirs(model_save_dir, exist_ok=True)  # no error if the directory already exists

#### Use the tss_data_path provided in this GitHub repository. Since the GEUVADIS variant calls are with respect to hg19 (not hg38, as was the case with our ROSMAP and GTEx data), update the tss and chr columns in gene_meta_info to reflect this. 

In [5]:
# Gene annotation table bundled with the SAGEnet repository (tab-separated).
tss_data_path=f'{base_repo_path}input_data/gene-ids-and-positions.tsv'

# Overwrite the 'chr' and 'tss' columns with their hg19 counterparts:
#   chr -> 'chr_hg19' with the literal 'chr' prefix removed (e.g. 'chr21' -> '21')
#   tss -> 'tss_hg19' parsed as numbers; unparseable entries become missing, and the nullable
#          'Int64' dtype keeps the remaining values integral despite those missing entries
gene_meta_info = pd.read_csv(tss_data_path, sep='\t').assign(
    chr=lambda meta: meta['chr_hg19'].str.replace('chr', '', regex=False),
    tss=lambda meta: pd.to_numeric(meta['tss_hg19'], errors='coerce').astype('Int64'),
)
gene_meta_info

Unnamed: 0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand
0,DDX11L1,ENSG00000223972,1,11869,14409,+,11869,chr1,11869.0,11869,1,ENSG00000223972,+
1,WASH7P,ENSG00000227232,1,14404,29570,-,29570,chr1,29570.0,29570,1,ENSG00000227232,-
2,MIR6859-1,ENSG00000278267,1,17369,17436,-,17436,chr1,17436.0,17436,1,ENSG00000278267,-
3,MIR1302-2HG,ENSG00000243485,1,29554,31109,+,29554,chr1,29554.0,29554,1,ENSG00000243485,+
4,MIR1302-2,ENSG00000284332,1,30366,30503,+,30366,chr1,30366.0,30366,1,ENSG00000284332,+
...,...,...,...,...,...,...,...,...,...,...,...,...,...
58297,AC240274.1,ENSG00000271254,KI270711.1,4612,29626,-,29626,,,,,ENSG00000271254,-
58298,U1,ENSG00000275405,KI270713.1,21861,22024,-,22024,,,,,ENSG00000275405,-
58299,U1,ENSG00000275987,KI270713.1,30437,30580,-,30580,,,,,ENSG00000275987,-
58300,AC213203.1,ENSG00000277475,KI270713.1,31698,32528,-,32528,,,,,ENSG00000277475,-


#### Load the preprocessed expression data:

In [6]:
# Read the pre-processed GEUVADIS expression table from expr_path.
# The result is kept unmodified as orig_expr_df; later cells derive reformatted views from it.
orig_expr_df = pd.read_csv(expr_path)
orig_expr_df

Unnamed: 0,TargetID,Gene_Symbol,Chr,Coord,HG00096,HG00097,HG00099,HG00100,HG00101,HG00102,...,NA20828,stable_id,gencode_v12_gene_name,our_gene_name,EUR_eGene,YRI_eGene,top_EUR_eqtl_rsid,top_YRI_eqtl_rsid,top_EUR_eqtl_distance,top_YRI_eqtl_distance
0,ENSG00000257527.1,ENSG00000257527.1,16,18505708,0.921724,0.848552,0.475993,0.318905,1.814351,0.845431,...,0.333369,ENSG00000257527,rp11-1212a22.6,,False,False,,,,
1,ENSG00000151503.7,ENSG00000151503.7,11,134095348,39.400926,35.462012,50.285482,45.931011,37.510652,38.106420,...,50.635764,ENSG00000151503,ncapd3,,False,False,,,,
2,ENSG00000254681.2,ENSG00000254681.2,16,18495797,8.417257,11.753510,10.750892,9.576572,11.132875,11.370822,...,7.895669,ENSG00000254681,rp11-1212a22.3,,False,False,,,,
3,ENSG00000228477.1,ENSG00000228477.1,1,40428352,267.057590,214.369516,287.589232,394.840340,279.594044,201.982560,...,177.177176,ENSG00000228477,rp3-342p20.2,,False,False,,,,
4,ENSG00000159733.9,ENSG00000159733.9,4,2420390,0.227618,0.203908,0.712240,0.433100,0.631554,0.564494,...,0.945074,ENSG00000159733,zfyve28,zfyve28,True,False,rs4974687,,9347.0,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23717,ENSG00000137709.4,ENSG00000137709.4,11,120107349,-0.438372,1.930072,0.681640,1.860306,1.169887,1.314488,...,0.944036,ENSG00000137709,pou2f3,,False,False,,,,
23718,ENSG00000006007.7,ENSG00000006007.7,16,19533467,19.452829,20.194346,16.056744,15.899420,16.054821,21.502305,...,20.063232,ENSG00000006007,gde1,,False,False,,,,
23719,ENSG00000172297.6,ENSG00000172297.6,Y,27600708,0.166371,0.139706,0.150852,0.218521,0.039107,0.129921,...,0.111516,ENSG00000172297,golga2p3y,,False,False,,,,
23720,ENSG00000125266.5,ENSG00000125266.5,13,107187462,0.139229,0.213003,0.287849,0.086688,0.251640,0.245102,...,0.079666,ENSG00000125266,efnb2,,False,False,,,,


#### Select some example individuals from the GEUVADIS dataset to use as a train, validation, and test set (make sure these sample names also exist in the expression data): 

In [9]:
# Sample IDs are taken from the VCF header.
with pysam.VariantFile(vcf_file_path) as vcf:
    vcf_samps = list(vcf.header.samples)

# Keep only the individuals that also appear as a column of the expression table.
expr_columns = orig_expr_df.columns
samps_in_expr_data = [sample_id for sample_id in vcf_samps if sample_id in expr_columns]
print(f'total n samps: {len(samps_in_expr_data)}')

# Shuffle with a fixed seed so that the split below is reproducible.
np.random.seed(12)
shuffled_indices = np.random.permutation(len(samps_in_expr_data))
shuffled_individs = np.array(samps_in_expr_data)[shuffled_indices]  # numpy array so it can be indexed by the permutation

# Small example split: first 10 shuffled individuals for training, next 5 for validation, next 5 for testing.
train_samps = shuffled_individs[0:10]
val_samps = shuffled_individs[10:15]
test_samps = shuffled_individs[15:20]

# All selected individuals, in train -> val -> test order.
all_samps = np.concatenate((train_samps, val_samps, test_samps))

print(f'train samps: {train_samps}')
print(f'val samps: {val_samps}')
print(f'test samps: {test_samps}')

total n samps: 462
train samps: ['NA20803' 'NA12830' 'HG00308' 'NA11892' 'NA12249' 'NA20804' 'NA12383'
 'NA12005' 'NA18498' 'NA19204']
val samps: ['NA18933' 'NA20536' 'HG00366' 'HG00277' 'NA20514']
test samps: ['HG00325' 'NA12272' 'HG00100' 'HG00332' 'NA20506']


#### Select an example gene set to use for model training. Make sure that these genes are on chromosome 21 (the chromosome covered by the example VCF we are using). Usually we would split our gene set into train, validation, and test by chromosome, but since all of these example genes are from chromosome 21, we instead split them arbitrarily, by position in the sorted list of gene IDs. 

In [10]:
# Gene IDs annotated on chromosome 21 (the 'chr' column holds strings without a 'chr' prefix).
on_chr21 = gene_meta_info['chr'] == '21'
chr21_genes = gene_meta_info.loc[on_chr21, 'gene_id'].values

# Restrict to genes that also have expression measurements.
# np.intersect1d returns sorted unique values, so the slices below are deterministic.
chr21_genes_in_expr_data = np.intersect1d(orig_expr_df['stable_id'], chr21_genes)

# Example split: 10 training genes, 5 validation genes, 5 test genes.
train_genes = chr21_genes_in_expr_data[0:10]
val_genes = chr21_genes_in_expr_data[10:15]
test_genes = chr21_genes_in_expr_data[15:20]

print(train_genes)

['ENSG00000141956' 'ENSG00000141959' 'ENSG00000142149' 'ENSG00000142156'
 'ENSG00000142166' 'ENSG00000142168' 'ENSG00000142173' 'ENSG00000142178'
 'ENSG00000142185' 'ENSG00000142188']


#### Put the expression data into the format required by PersonalGenomeDataset (indexed by gene IDs, columns are sample names):

In [11]:
# Index rows by the gene ID in 'stable_id', then keep only the columns of the selected
# individuals (train, val, test, in that order). orig_expr_df itself is left untouched.
expr_df = orig_expr_df.set_index('stable_id')[all_samps]
expr_df

Unnamed: 0_level_0,NA20803,NA12830,HG00308,NA11892,NA12249,NA20804,NA12383,NA12005,NA18498,NA19204,NA18933,NA20536,HG00366,HG00277,NA20514,HG00325,NA12272,HG00100,HG00332,NA20506
stable_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1
ENSG00000257527,0.435554,0.328876,0.542901,0.270288,0.906700,0.191970,0.380016,0.272329,0.597474,0.205384,0.619697,0.871186,0.538308,0.950538,0.301588,0.411008,0.557917,0.318905,0.678781,0.402872
ENSG00000151503,36.769189,32.191375,44.727513,44.521700,22.168308,26.766314,44.226758,29.253736,44.163776,31.429208,20.655892,47.044519,44.118198,37.670312,43.087610,48.363630,44.780661,45.931011,37.577450,38.544274
ENSG00000254681,7.911369,11.156102,6.884866,10.338751,3.927643,10.582101,9.203352,10.578331,8.509440,20.849868,17.335315,8.645044,7.983753,8.625198,6.845965,8.974603,11.913209,9.576572,16.101424,8.348114
ENSG00000228477,221.705080,251.016985,267.960374,303.305799,234.466966,221.156782,142.271447,328.554436,293.008224,222.329632,210.143495,301.558543,272.610373,233.220437,327.936792,279.922783,168.551674,394.840340,266.975709,263.083702
ENSG00000159733,0.445662,1.788335,0.662519,1.522028,1.977999,1.391980,1.540233,0.664821,1.555032,0.680732,0.167948,0.627373,0.862779,1.957021,1.285814,0.560031,0.981346,0.433100,0.355617,1.587291
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000137709,0.962992,2.417636,2.572387,1.517089,2.405431,1.521566,0.596592,2.520589,1.161636,1.646100,0.983438,0.377228,1.894994,-0.649883,2.475048,2.395376,1.771172,1.860306,0.496738,0.938476
ENSG00000006007,17.547194,19.545290,18.314184,20.772988,18.689712,20.295341,18.497651,16.902593,18.332126,17.768566,17.696008,21.477662,30.556087,16.989199,17.506841,16.702770,15.528529,15.899420,21.537412,19.416601
ENSG00000172297,0.124084,0.173294,0.203024,0.117988,0.101695,0.098348,0.200327,0.120292,0.074053,0.006652,0.112481,0.123586,0.271811,0.300667,0.105189,0.118355,0.135799,0.218521,0.104980,0.117371
ENSG00000125266,0.119904,0.015309,0.189959,0.295348,0.184744,0.136891,0.215636,0.123038,0.241018,0.058625,0.391730,0.127234,0.171226,0.073759,0.203049,0.366541,0.200787,0.086688,0.216478,0.113176


#### Select train and validation gene meta information:

In [12]:
# Rows of the gene annotation table belonging to the training / validation gene sets.
gene_ids = gene_meta_info['gene_id']
train_gene_meta = gene_meta_info.loc[gene_ids.isin(train_genes)]
val_gene_meta = gene_meta_info.loc[gene_ids.isin(val_genes)]
train_gene_meta

Unnamed: 0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand
57800,SOD1,ENSG00000142168,21,31659622,31668931,+,31659622,chr21,33031935.0,33031935,21,ENSG00000142168,+
57807,HUNK,ENSG00000142149,21,31873315,32044633,+,31873315,chr21,33245628.0,33245628,21,ENSG00000142149,+
57849,IFNAR1,ENSG00000142166,21,33324477,33359862,+,33324477,chr21,34696782.0,34696782,21,ENSG00000142166,+
57852,TMEM50B,ENSG00000142188,21,33432485,33480011,-,33480011,chr21,34852318.0,34852318,21,ENSG00000142188,-
58022,PRDM15,ENSG00000141956,21,41798225,41879482,-,41879482,chr21,43299591.0,43299591,21,ENSG00000141956,-
58066,SIK1,ENSG00000142178,21,43414515,43427128,-,43427128,chr21,44847008.0,44847008,21,ENSG00000142178,-
58099,PFKL,ENSG00000141959,21,44300051,44327376,+,44300051,chr21,45719934.0,45719934,21,ENSG00000141959,+
58105,TRPM2,ENSG00000142185,21,44350163,44443081,+,44350163,chr21,45770046.0,45770046,21,ENSG00000142185,+
58178,COL6A1,ENSG00000142156,21,45981737,46005050,+,45981737,chr21,47401651.0,47401651,21,ENSG00000142156,+
58184,COL6A2,ENSG00000142173,21,46098097,46132849,+,46098097,chr21,47518011.0,47518011,21,ENSG00000142173,+


#### Initialize training and validation PersonalGenomeDatasets. For more information on how to adjust the parameters to PersonalGenomeDataset to suit your needs, see https://github.com/mostafavilabuw/SAGEnet/blob/main/example_usage.ipynb and the class documentation. 

In [13]:
# Arguments shared by every dataset below.
# NOTE: the parameter is called `hg38_file_path`, but it receives the hg19 reference here so
# that the reference sequence matches the coordinate system of the GEUVADIS variant calls.
# contig_prefix='' -- presumably because contig names in this VCF carry no 'chr' prefix; verify against your VCF.
common_dataset_kwargs = dict(
    vcf_file_path=vcf_file_path,
    hg38_file_path=hg19_path,
    expr_data=expr_df,
    contig_prefix='',
)

# Training set: training genes x training individuals.
train_dataset = SAGEnet.data.PersonalGenomeDataset(gene_metadata=train_gene_meta, sample_list=train_samps, **common_dataset_kwargs)

# Validation set 0: training genes x *validation* individuals (held-out subjects).
# This must use val_samps; with train_samps it would be an exact copy of the training set and
# the "train_gene_val_sub" validation metrics would be computed on training data.
val_subs_dataset = SAGEnet.data.PersonalGenomeDataset(gene_metadata=train_gene_meta, sample_list=val_samps, **common_dataset_kwargs)

# Validation set 1: validation genes x validation individuals (held-out genes and subjects).
val_genes_dataset = SAGEnet.data.PersonalGenomeDataset(gene_metadata=val_gene_meta, sample_list=val_samps, **common_dataset_kwargs)

# Only the training loader is shuffled; validation order stays fixed.
train_dataloader = DataLoader(train_dataset,shuffle=True,num_workers=1)
val_subs_dataloader = DataLoader(val_subs_dataset, shuffle=False,num_workers=1)
val_genes_dataloader = DataLoader(val_genes_dataset,shuffle=False,num_workers=1)

#### Initialize a p-SAGE-net model. For more information on how to adjust the parameters to pSAGEnet to suit your needs, see https://github.com/mostafavilabuw/SAGEnet/blob/main/example_usage.ipynb and the class documentation.

In [14]:
# Instantiate p-SAGE-net; model_save_dir is the only argument passed explicitly,
# every other constructor parameter keeps its default value.
my_model = pSAGEnet(model_save_dir=model_save_dir)

#### Set up for model training:

In [15]:
# Order matters: index 0 is the held-out-subjects loader, index 1 the held-out-genes loader.
# The early-stopping monitor below refers to dataloader index 0.
val_dataloaders = [val_subs_dataloader, val_genes_dataloader]

# Stop once the monitored validation loss has not improved (decreased) for 5 checks.
es = EarlyStopping(
    monitor="train_gene_val_sub_diff_loss/dataloader_idx_0",
    patience=5,
    mode='min',
)

# Record the learning rate once per epoch.
lr_monitor = LearningRateMonitor(logging_interval='epoch')

# Keep a checkpoint for every epoch (save_top_k=-1 disables pruning of older checkpoints).
all_epoch_checkpoint_callback = ModelCheckpoint(
    dirpath=model_save_dir,
    filename="{epoch}",
    every_n_epochs=1,
    save_top_k=-1,
    save_last=False,
)

# Periodically refresh a "last" checkpoint (every 300 training steps) so that training
# can be resumed if the job is killed; save_top_k=0 means no "best" checkpoints are kept here.
last_checkpoint_callback = ModelCheckpoint(
    dirpath=model_save_dir,
    filename="last",
    every_n_train_steps=300,
    save_top_k=0,
    save_last=True,
)

# Checkpoint callbacks only -- `es` and `lr_monitor` are not part of this list and
# only take effect if they are also handed to the Trainer.
ckpt_list = [all_epoch_checkpoint_callback, last_checkpoint_callback]

#### Set up a trainer: 

In [16]:
# Weights & Biases logger; resume="allow" lets a re-run with the same id continue the existing run.
wandb_logger = WandbLogger(project='test_project_name', name='test_job_name', id='test_job_name', resume="allow") # change these based on your logging preferences
device=2 # index of the GPU to train on; set to None to not pin a specific GPU
num_nodes=1 # single node training
max_epochs=1

# Compare against None explicitly: a plain truthiness test (`if device`) would treat
# GPU index 0 as "no device given" and silently fall into the un-pinned / ddp branch.
use_specific_gpu = device is not None

trainer = pl.Trainer(
    accelerator="gpu",
    devices=[int(device)] if use_specific_gpu else 1,  # pin the chosen GPU, otherwise use one GPU
    num_nodes=num_nodes,
    strategy='auto' if use_specific_gpu else "ddp",
    # Early stopping and LR monitoring only take effect when passed here together with
    # the checkpoint callbacks; passing ckpt_list alone leaves them unused.
    callbacks=[es, lr_monitor, *ckpt_list],
    max_epochs=max_epochs,
    benchmark=False,
    profiler='simple',  # prints a timing report when fit() finishes
    gradient_clip_val=1,
    logger=wandb_logger,
    log_every_n_steps=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


#### Train! 

In [17]:
# Train the model. The validation loaders are passed as a list, so their metrics are
# distinguished by dataloader index (0 = val_subs_dataloader, 1 = val_genes_dataloader).
trainer.fit(my_model, train_dataloader, val_dataloaders=val_dataloaders)

You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /homes/gws/aspiro17/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mspiroannae[0m ([33mspiroannae-university-of-washington[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name               | Type       | Params | Mode 
----------------------------------------------------------
0 | conv0              | Sequential | 36.9 K | train
1 | convlayers         | ModuleList | 2.5 M  | train
2 | dilated_convlayers | ModuleList | 0      | train
3 | fc0                | Sequential | 65.8 K | train
4 | fclayers           | ModuleList | 65.8 K | train
5 | diff_fclayers      | ModuleList | 65.8 K | train
6 | diff_out           | Sequential | 257    | train
7 | ref_out            | Sequential | 257    | train
----------------------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.811    Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/homes/gws/aspiro17/miniconda3/envs/SAGEnet/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

  return F.conv1d(input, weight, bias, self.stride,


                                                                           

  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
/homes/gws/aspiro17/miniconda3/envs/SAGEnet/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 0: 100%|██████████| 100/100 [00:35<00:00,  2.83it/s, v_num=name]

  c /= stddev[:, None]
`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 100/100 [00:36<00:00,  2.77it/s, v_num=name]


FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                           	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                            	|  -              	|  6474         