In [1]:
import numpy as np 
import pandas as pd
import SAGEnet.data 
import SAGEnet.tools
from SAGEnet.models import pSAGEnet
from torch.utils.data import DataLoader
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import glob
import os
import pysam

  warn(


#### In this notebook, we go through the process of initializing a ReferenceGenomeDataset and using this to train a r-SAGE-net model with the GEUVADIS dataset (publicly available).

#### Before you run this notebook: 
-- follow the steps under "installation" in the main READme to install dependencies and the SAGEnet package.

-- download the pre-processed GEUVADIS expression data 'tpm_pca_annot.csv.gz' from Rastogi et al.: https://github.com/ni-lab/finetuning-enformer/tree/main/process_geuvadis_data/tpm  

-- download the hg19 genome (to be consisent with GEUVADIS variant calls):    
```bash
curl -O https://hgdownload.soe.ucsc.edu/goldenPath/hg19/bigZips/hg19.fa.gz  
gunzip hg19.fa.gz
```

#### Change these paths based on your SAGEnet repo locaiton, where you have saved the files downloaded in the previous step, and where you want to save your model results. 

In [3]:
base_repo_path='/homes/gws/aspiro17/SAGEnet/'

hg19_path='/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/hg19.fa'
expr_path='/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/corrected_log_tpm.annot.csv.gz'

model_save_dir='data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/res/'
os.makedirs(model_save_dir, exist_ok=True)

#### Use the tss_data_path provided in this github. Since the GEUVADIS variant calls are with respect to hg19 (not hg38, as was the case with our ROSMAP and GTEx data), update the tss and chr columns in gene_meta_info to reflect this. 

In [4]:
tss_data_path=f'{base_repo_path}input_data/gene-ids-and-positions.tsv'
gene_meta_info = pd.read_csv(tss_data_path, sep="\t",index_col='region_id')
gene_meta_info['chr'] = gene_meta_info['chr_hg19'].str.replace('chr', '', regex=False)
gene_meta_info['tss'] = pd.to_numeric(gene_meta_info['tss_hg19'], errors='coerce').astype('Int64')
gene_meta_info

Unnamed: 0_level_0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand,pos
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
ENSG00000223972,DDX11L1,ENSG00000223972,1,11869,14409,+,11869,chr1,11869.0,11869,1,ENSG00000223972,+,11869
ENSG00000227232,WASH7P,ENSG00000227232,1,14404,29570,-,29570,chr1,29570.0,29570,1,ENSG00000227232,-,29570
ENSG00000278267,MIR6859-1,ENSG00000278267,1,17369,17436,-,17436,chr1,17436.0,17436,1,ENSG00000278267,-,17436
ENSG00000243485,MIR1302-2HG,ENSG00000243485,1,29554,31109,+,29554,chr1,29554.0,29554,1,ENSG00000243485,+,29554
ENSG00000284332,MIR1302-2,ENSG00000284332,1,30366,30503,+,30366,chr1,30366.0,30366,1,ENSG00000284332,+,30366
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000271254,AC240274.1,ENSG00000271254,KI270711.1,4612,29626,-,29626,,,,,ENSG00000271254,-,29626
ENSG00000275405,U1,ENSG00000275405,KI270713.1,21861,22024,-,22024,,,,,ENSG00000275405,-,22024
ENSG00000275987,U1,ENSG00000275987,KI270713.1,30437,30580,-,30580,,,,,ENSG00000275987,-,30580
ENSG00000277475,AC213203.1,ENSG00000277475,KI270713.1,31698,32528,-,32528,,,,,ENSG00000277475,-,32528


#### Load the preprocessed expression data:

In [5]:
orig_expr_df = pd.read_csv(expr_path)
orig_expr_df

Unnamed: 0,TargetID,Gene_Symbol,Chr,Coord,HG00096,HG00097,HG00099,HG00100,HG00101,HG00102,...,NA20828,stable_id,gencode_v12_gene_name,our_gene_name,EUR_eGene,YRI_eGene,top_EUR_eqtl_rsid,top_YRI_eqtl_rsid,top_EUR_eqtl_distance,top_YRI_eqtl_distance
0,ENSG00000257527.1,ENSG00000257527.1,16,18505708,-0.057361,-0.313160,-0.684395,-1.209085,-0.012644,-0.270612,...,-1.127696,ENSG00000257527,rp11-1212a22.6,,False,False,,,,
1,ENSG00000151503.7,ENSG00000151503.7,11,134095348,3.653703,3.555238,3.969966,3.832266,3.620463,3.682108,...,3.984807,ENSG00000151503,ncapd3,,False,False,,,,
2,ENSG00000254681.2,ENSG00000254681.2,16,18495797,2.088882,2.326419,2.128807,2.199625,2.331783,2.627187,...,1.565265,ENSG00000254681,rp11-1212a22.3,,False,False,,,,
3,ENSG00000228477.1,ENSG00000228477.1,1,40428352,5.579332,5.352685,5.758683,6.045576,5.563191,5.176924,...,5.187391,ENSG00000228477,rp3-342p20.2,,False,False,,,,
4,ENSG00000159733.9,ENSG00000159733.9,4,2420390,-0.984586,-1.124469,-0.433654,-1.025796,-0.705150,-1.333362,...,0.044033,ENSG00000159733,zfyve28,zfyve28,True,False,rs4974687,,9347.0,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
23717,ENSG00000137709.4,ENSG00000137709.4,11,120107349,-1.573193,0.184419,-1.035174,0.130528,-0.713233,0.147750,...,-0.414130,ENSG00000137709,pou2f3,,False,False,,,,
23718,ENSG00000006007.7,ENSG00000006007.7,16,19533467,2.938774,2.976678,2.681771,2.732348,2.782939,3.024868,...,3.017779,ENSG00000006007,gde1,,False,False,,,,
23719,ENSG00000172297.6,ENSG00000172297.6,Y,27600708,-1.760798,-1.955373,-1.859498,-1.676782,-4.246593,-2.011718,...,-2.351444,ENSG00000172297,golga2p3y,,False,False,,,,
23720,ENSG00000125266.5,ENSG00000125266.5,13,107187462,-1.917913,-1.706478,-1.403048,-2.520733,-1.504605,-1.454579,...,-2.262784,ENSG00000125266,efnb2,,False,False,,,,


#### Randomly select 80% of individuals as training individuals 

In [33]:
all_expr_data_individauls=orig_expr_df.columns[4:-9]
print(all_expr_data_individauls)
print(all_expr_data_individauls.shape)

Index(['HG00096', 'HG00097', 'HG00099', 'HG00100', 'HG00101', 'HG00102',
       'HG00103', 'HG00104', 'HG00105', 'HG00106',
       ...
       'NA20810', 'NA20811', 'NA20812', 'NA20813', 'NA20814', 'NA20815',
       'NA20816', 'NA20819', 'NA20826', 'NA20828'],
      dtype='object', length=462)
(462,)


In [35]:
n_train_individuals = int(.8*len(all_expr_data_individauls))
np.random.seed(12)
shuffled_indices = np.random.permutation(len(all_expr_data_individauls))
shuffled_individs = np.array(all_expr_data_individauls)[shuffled_indices]  # convert to numpy array for indexing
train_individuals = shuffled_individs[:n_train_individuals]
print(train_individuals)
print(len(train_individuals))

['NA20803' 'NA18499' 'HG00103' 'NA12156' 'NA12489' 'NA20804' 'NA12776'
 'HG00120' 'NA18867' 'NA19210' 'NA19114' 'NA20536' 'NA10847' 'HG00342'
 'HG00139' 'HG00367' 'NA12546' 'HG00145' 'HG00377' 'NA20507' 'HG00366'
 'NA11831' 'HG00097' 'NA18858' 'NA12775' 'HG00109' 'NA18487' 'NA18489'
 'NA20520' 'HG00244' 'HG00115' 'HG00136' 'NA12154' 'NA19152' 'HG00344'
 'HG00128' 'NA20530' 'HG00252' 'HG00134' 'NA20799' 'HG00112' 'NA19138'
 'HG00313' 'HG00245' 'HG00364' 'NA20771' 'NA20800' 'NA12827' 'NA20540'
 'NA18909' 'NA20787' 'NA20544' 'NA19236' 'HG00182' 'NA12872' 'HG00309'
 'NA20797' 'NA20801' 'NA19209' 'HG00096' 'HG00355' 'NA20816' 'NA12751'
 'HG00138' 'HG00246' 'NA20766' 'HG00325' 'HG00239' 'NA12716' 'NA19206'
 'HG00235' 'NA19185' 'HG00110' 'NA19198' 'NA20815' 'NA12399' 'HG00155'
 'NA12005' 'HG00365' 'HG00327' 'NA18917' 'NA18502' 'NA19171' 'NA20516'
 'NA19117' 'HG00341' 'NA19113' 'NA11892' 'HG00152' 'NA20802' 'NA19137'
 'HG00181' 'NA18908' 'NA18510' 'HG00142' 'NA12778' 'HG00330' 'NA20531'
 'NA20

#### Put the expression data into the format required by ReferenceGenomeDataset (indexed by gene IDs, column values are sample names):

In [37]:
expr_df = orig_expr_df[all_expr_data_individauls]
expr_df.index=orig_expr_df['stable_id']
expr_df

Unnamed: 0_level_0,HG00096,HG00097,HG00099,HG00100,HG00101,HG00102,HG00103,HG00104,HG00105,HG00106,...,NA20810,NA20811,NA20812,NA20813,NA20814,NA20815,NA20816,NA20819,NA20826,NA20828
stable_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ENSG00000257527,-0.057361,-0.313160,-0.684395,-1.209085,-0.012644,-0.270612,-0.930251,-1.402520,-0.765358,-0.563563,...,-0.884948,0.191873,-0.661033,-0.629318,-1.685557,-1.953357,-0.282678,-0.768272,-0.785180,-1.127696
ENSG00000151503,3.653703,3.555238,3.969966,3.832266,3.620463,3.682108,3.862410,3.840910,3.058031,3.774548,...,3.378918,3.867792,3.782963,3.834467,3.758226,3.751885,3.677478,3.152041,4.041150,3.984807
ENSG00000254681,2.088882,2.326419,2.128807,2.199625,2.331783,2.627187,1.608311,2.100229,1.791579,1.899732,...,1.643761,2.015555,1.730166,2.568180,1.656951,1.329385,1.048834,1.974319,2.417701,1.565265
ENSG00000228477,5.579332,5.352685,5.758683,6.045576,5.563191,5.176924,5.579479,5.563419,5.659788,5.583868,...,5.536799,5.675933,5.688522,6.027998,5.421165,5.949244,5.711467,5.544912,5.555267,5.187391
ENSG00000159733,-0.984586,-1.124469,-0.433654,-1.025796,-0.705150,-1.333362,-0.532541,-0.944133,-1.028620,-0.693211,...,-0.430091,-1.074794,-0.169165,-0.077494,0.551200,0.226508,0.234683,-0.253473,-0.135414,0.044033
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000137709,-1.573193,0.184419,-1.035174,0.130528,-0.713233,0.147750,-1.028686,-0.111465,-0.286314,-0.138425,...,-0.693062,-1.475142,-0.073097,0.876906,-0.347311,-0.161878,-0.696457,-0.597816,0.043586,-0.414130
ENSG00000006007,2.938774,2.976678,2.681771,2.732348,2.782939,3.024868,2.889567,2.972804,2.820010,2.894324,...,2.922630,3.057945,3.437837,2.909016,2.904415,2.767006,2.886100,2.523124,3.025187,3.017779
ENSG00000172297,-1.760798,-1.955373,-1.859498,-1.676782,-4.246593,-2.011718,-1.821764,-4.070596,-2.487116,-1.982339,...,-2.256593,-4.199391,-2.110590,-2.359366,-1.676287,-2.269958,-1.997345,-2.070705,-1.583471,-2.351444
ENSG00000125266,-1.917913,-1.706478,-1.403048,-2.520733,-1.504605,-1.454579,-1.827309,-1.193953,-1.797129,-2.323232,...,-1.162982,-2.053384,-1.186172,-2.471141,-1.972355,-2.437253,-1.811572,-2.152038,-2.628604,-2.262784


In [None]:
avg_y_data = np.array(np.median(np.array(expr_df.loc['ENSG00000125266',self.train_subs])))


#### Restrict genes to protein-coding genes that are present in the expression data and in the metadata

In [19]:
protein_coding_gene_list = np.loadtxt(f'{base_repo_path}input_data/protein_coding_genes.csv',delimiter=',',dtype=str)
protein_coding_gene_list.shape

(19805,)

In [20]:
use_gene_list  = np.intersect1d(np.intersect1d(protein_coding_gene_list, orig_expr_df['stable_id']), gene_meta_info.index)
use_gene_list.shape

(14769,)

#### Split by chromosome to get train, validaiton, and test gene sets 

In [22]:
train_genes, val_genes, test_genes = SAGEnet.tools.get_train_val_test_genes(use_gene_list,tss_data_path=tss_data_path, use_enformer_gene_assignments=False)
print(len(train_genes))
print(len(val_genes))
print(len(test_genes))

selecting train/val/test gene sets based on chromosome split
11173
1580
1495


#### Select train and validation gene meta information:

In [38]:
train_gene_meta = gene_meta_info.loc[train_genes]
val_gene_meta = gene_meta_info.loc[val_genes]
train_gene_meta

Unnamed: 0_level_0,gene_name,gene_id,chr_hg38,start_hg38,end_hg38,strand_hg38,tss_hg38,chr_hg19,tss_hg19,tss,chr,ensg,strand,pos
region_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
ENSG00000000457,SCYL3,ENSG00000000457,1,169849631,169894267,-,169894267,chr1,169863408.0,169863408,1,ENSG00000000457,-,169894267
ENSG00000000460,C1orf112,ENSG00000000460,1,169662007,169854080,+,169662007,chr1,169631245.0,169631245,1,ENSG00000000460,+,169662007
ENSG00000000938,FGR,ENSG00000000938,1,27612064,27635277,-,27635277,chr1,27961788.0,27961788,1,ENSG00000000938,-,27635277
ENSG00000001036,FUCA2,ENSG00000001036,6,143494811,143511690,-,143511690,chr6,143832827.0,143832827,6,ENSG00000001036,-,143511690
ENSG00000001084,GCLC,ENSG00000001084,6,53497341,53616970,-,53616970,chr6,53481768.0,53481768,6,ENSG00000001084,-,53616970
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
ENSG00000262209,PCDHGB3,ENSG00000262209,5,141370264,141512979,+,141370264,chr5,140749831.0,140749831,5,ENSG00000262209,+,141370264
ENSG00000262246,CORO7,ENSG00000262246,16,4354542,4425705,-,4425705,chr16,4475706.0,4475706,16,ENSG00000262246,-,4425705
ENSG00000262576,PCDHGA4,ENSG00000262576,5,141355025,141512979,+,141355025,chr5,140734592.0,140734592,5,ENSG00000262576,+,141355025
ENSG00000262621,AC025283.2,ENSG00000262621,16,3365099,3479550,+,3365099,chr16,3415099.0,3415099,16,ENSG00000262621,+,3365099


#### Initialize datasets and dataloaders 

In [54]:
input_len=40000

train_dataset = SAGEnet.data.ReferenceGenomeDataset(metadata=train_gene_meta, hg38_file_path=hg19_path, y_data=expr_df, input_len=input_len,majority_seq=False,train_subs=train_individuals)
val_dataset = SAGEnet.data.ReferenceGenomeDataset(metadata=val_gene_meta, hg38_file_path=hg19_path, y_data=expr_df, input_len=input_len,majority_seq=False,train_subs=train_individuals)
train_dataloader = DataLoader(train_dataset,  shuffle=True)
val_dataloader = DataLoader(val_dataset, shuffle=False)

#### Initialize an r-SAGE-net model

In [55]:
my_model = SAGEnet.models.rSAGEnet(input_length=input_len)

#### Set up for model training 

In [56]:
device=1 # which GPU 
max_epochs=50
wandb_logger = WandbLogger(project='test_project_name', name='test_job_name', id='test_job_name', resume="allow") # change these based on your logging preferences 

es = EarlyStopping(monitor="val_pearson", patience=10,mode='max')
checkpoint_callback = ModelCheckpoint(dirpath=model_save_dir, monitor="val_pearson", save_top_k=1, mode="max", save_last=True, every_n_epochs=1)
lr_monitor = LearningRateMonitor(logging_interval='epoch')
callbacks=[es,checkpoint_callback,lr_monitor]

trainer = pl.Trainer(
accelerator="gpu", 
devices=[int(device)] if device else 1, 
num_nodes=1, 
strategy="ddp" if not device else 'auto', 
callbacks=callbacks, 
max_epochs=max_epochs, 
benchmark=False, 
profiler='simple', 
gradient_clip_val=1, 
logger=wandb_logger, 
log_every_n_steps=10)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


#### Train! 

In [57]:
trainer.fit(my_model, train_dataloader, val_dataloaders=val_dataloader)

/homes/gws/aspiro17/miniconda3/envs/SAGEnet/lib/python3.10/site-packages/pytorch_lightning/loggers/wandb.py:397: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.
/homes/gws/aspiro17/miniconda3/envs/SAGEnet/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /homes/gws/aspiro17/SAGEnet/example_notebooks/data/mostafavilab/personal_genome_expr/revisions/GEUVADIS/res exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name               | Type       | Params | Mode 
----------------------------------------------------------
0 | conv0              | Sequential | 36.9 K | train
1 | convlayers         | ModuleList | 2.5 M  | train
2 | dilated_convlayers | ModuleList | 0      | train
3 | fc0                | Sequential | 65.8 K | train
4 | fclayers           | ModuleList | 65.8 K 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/homes/gws/aspiro17/miniconda3/envs/SAGEnet/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
  return F.conv1d(
/homes/gws/aspiro17/miniconda3/envs/SAGEnet/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x7f10285149d0>> (for post_run_cell), with arguments args (<ExecutionResult object at 7f101274e680, execution_count=57 error_before_exec=None error_in_exec=name 'exit' is not defined info=<ExecutionInfo object at 7f10122936d0, raw_cell="trainer.fit(my_model, train_dataloader, val_datalo.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Bchelan.cs.washington.edu/homes/gws/aspiro17/SAGEnet/example_notebooks/GEUVADIS_example_rsagenet.ipynb#X61sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe