In [7]:
import argparse
from Bio.Seq import Seq
from datasets import load_dataset
from gpn.data import Genome, load_dataset_from_file_or_dir
import grelu.resources
from grelu.sequence.format import strings_to_one_hot
import numpy as np
import os
import pandas as pd
import tempfile
import torch
import torch.nn.functional as F
from transformers import Trainer, TrainingArguments


class VEPModel(torch.nn.Module):
	"""Wrap a sequence model to produce per-track variant-effect scores.

	For each variant, the wrapped model is run on reference and alternate
	one-hot inputs (the ``*_fwd`` and ``*_rev`` pairs, each optionally shifted
	by a few bases). Each (ref, alt) pair is turned into a score by one of the
	``get_*scores`` methods. The scores are averaged over every
	orientation/shift combination.

	Args:
		model: the wrapped model; how it is called depends on the flags below.
		personalized_enformer: if True, call
			``model(x, return_base_predictions=True, base_predictions_head=None)``
			and score the concatenated ``'human'`` and ``'mouse'`` outputs.
		personalized_enformer_diff: if True (and ``personalized_enformer`` is
			False), score with ``model.predict_step(...)["Y_diff"]``.
		shifts: iterable of integer shifts in bases (0 = no shift). A copy is
			stored, so later changes to the caller's list have no effect.
	"""

	def __init__(self, model, personalized_enformer=False, personalized_enformer_diff=False, shifts=(0,)):
		super().__init__()
		self.model = model
		# Defensive copy: never alias a caller-owned (or default) list.
		self.shifts = list(shifts)
		self.personalized_enformer = personalized_enformer
		self.personalized_enformer_diff = personalized_enformer_diff

	def get_personalized_enformer_diff_scores(self, x_ref, x_alt):
		"""Return ``model.predict_step(...)["Y_diff"]`` for a (ref, alt) pair.

		Inputs are permuted from (batch, 4, length) to (batch, length, 4). Each
		sequence is then duplicated along a new dim 1 (presumably the two
		haplotypes the personalized model expects; TODO confirm).
		"""
		x_ref = x_ref.permute(0, 2, 1)
		x_alt = x_alt.permute(0, 2, 1)
		seq1 = torch.stack([x_ref, x_ref], dim=1)
		seq2 = torch.stack([x_alt, x_alt], dim=1)
		batch = {"seq1": seq1, "seq2": seq2}
		with torch.no_grad():
			outputs = self.model.predict_step(batch, batch_idx=0)
		return outputs["Y_diff"]

	def get_personalized_enformer_scores(self, x_ref, x_alt):
		"""L2 norm over dim 1 of the log2 fold change of the human+mouse outputs.

		Inputs are permuted from (batch, 4, length) to (batch, length, 4) before
		calling the model. The ``'human'`` and ``'mouse'`` outputs are
		concatenated on dim 2 and must have shape (batch, 896, 5313 + 1643).

		Returns:
			Tensor of shape (batch, 5313 + 1643).
		"""
		x_ref = x_ref.permute(0, 2, 1)
		x_alt = x_alt.permute(0, 2, 1)

		outputs_ref = self.model(
			x_ref,
			return_base_predictions=True,
			base_predictions_head=None,
		)
		outputs_alt = self.model(
			x_alt,
			return_base_predictions=True,
			base_predictions_head=None,
		)
		y_ref = torch.cat([outputs_ref['human'], outputs_ref['mouse']], dim=2)
		assert y_ref.shape[1:] == (896, 5313 + 1643), f"y_ref.shape: {y_ref.shape}; outputs_ref['human'].shape: {outputs_ref['human'].shape}; outputs_ref['mouse'].shape: {outputs_ref['mouse'].shape}"
		y_alt = torch.cat([outputs_alt['human'], outputs_alt['mouse']], dim=2)
		assert y_alt.shape[1:] == (896, 5313 + 1643), f"y_alt.shape: {y_alt.shape}; outputs_alt['human'].shape: {outputs_alt['human'].shape}; outputs_alt['mouse'].shape: {outputs_alt['mouse'].shape}"
		lfc = torch.log2(1 + y_alt) - torch.log2(1 + y_ref)  # [bs, 896, 5313 + 1643]
		# Reduce over the 896-long axis (dim 1).
		return torch.linalg.norm(lfc, dim=1)  # [bs, 5313 + 1643]

	def get_scores(self, x_ref, x_alt):
		"""L2 norm over the last axis of the log2 fold change (alt vs ref).

		Assumes ``self.model(x)`` returns (batch, tracks, length), e.g.
		(bs, 5313, 896) for the human Enformer head (TODO confirm for others).

		Returns:
			Tensor of shape (batch, tracks).
		"""
		y_ref = self.model(x_ref)
		y_alt = self.model(x_alt)
		lfc = torch.log2(1 + y_alt) - torch.log2(1 + y_ref)
		return torch.linalg.norm(lfc, dim=2)

	def shift(self, x, shift_size):
		"""Shift ``x`` along its last axis, filling vacated positions with 0.25.

		Args:
			x: tensor of shape (batch, 4, length).
			shift_size: positive values move the sequence right (pad on the
				left, drop bases on the right); negative values move it left.
				0 returns ``x`` itself.

		Returns:
			Tensor with the same shape as ``x``.
		"""
		if shift_size == 0:
			return x

		original_shape = x.shape
		assert x.shape[1] == 4, f"x.shape: {x.shape}"

		n_pad = abs(shift_size)
		# 0.25 in each of the 4 channels for every padded position.
		Ns = torch.zeros(x.shape[0], 4, n_pad, device=x.device, dtype=x.dtype) + 0.25

		if shift_size > 0:
			# Pad on the left, then trim the same number of positions on the right.
			x = torch.cat([Ns, x], dim=2)[:, :, :-n_pad]
		else:
			# Pad on the right, then trim the same number of positions on the left.
			x = torch.cat([x, Ns], dim=2)[:, :, n_pad:]

		assert x.shape == original_shape, f"x.shape: {x.shape}, original_shape: {original_shape}, shift_size: {shift_size}"
		return x

	def forward(
		self,
		x_ref_fwd=None,
		x_alt_fwd=None,
		x_ref_rev=None,
		x_alt_rev=None,
	):
		"""Average the selected score over both input pairs and all shifts.

		All four inputs are required despite their ``None`` defaults. The
		keyword names match the keys of the batches fed in by the dataset
		transform.

		Returns:
			Mean over the 2 * len(self.shifts) score tensors.
		"""
		if self.personalized_enformer:
			score_fn = self.get_personalized_enformer_scores
		elif self.personalized_enformer_diff:
			score_fn = self.get_personalized_enformer_diff_scores
		else:
			score_fn = self.get_scores

		scores = []
		for shift in self.shifts:
			scores.append(score_fn(self.shift(x_ref_fwd, shift), self.shift(x_alt_fwd, shift)))
			scores.append(score_fn(self.shift(x_ref_rev, shift), self.shift(x_alt_rev, shift)))

		return torch.mean(torch.stack(scores), dim=0)


def run_vep(
	variants,
	genome,
	window_size,
	model,
	per_device_batch_size=8,
	dataloader_num_workers=0,
):
	"""Score every variant in ``variants`` with ``model`` through a HF ``Trainer``.

	For each variant a ``window_size``-bp window is read from ``genome`` on both
	strands, with the variant at index ``window_size // 2`` of the forward
	sequence. Reference and alternate one-hot encodings are passed to ``model``
	as ``x_ref_fwd``, ``x_alt_fwd``, ``x_ref_rev`` and ``x_alt_rev``.

	Args:
		variants: dataset with ``chrom``, ``pos`` (1-based, as in VCF), ``ref``
			and ``alt`` columns (presumably a Hugging Face ``Dataset``). NOTE:
			its transform is replaced in place via ``set_transform``.
		genome: object providing ``get_seq_fwd_rev(chrom, start, end)``.
		window_size: input window length in bp (even or odd).
		model: module whose forward accepts the four keyword arguments above.
		per_device_batch_size: evaluation batch size per device.
		dataloader_num_workers: number of DataLoader worker processes.

	Returns:
		``Trainer.predict(...).predictions``.
	"""
	def transform(V):
		# VCF positions are 1-based; Genome uses 0-based coordinates.
		chrom = np.array(V["chrom"])
		pos = np.array(V["pos"]) - 1
		start = pos - window_size // 2
		# start + window_size keeps the window exactly window_size long for odd
		# sizes too (pos + window_size // 2 would be one base short).
		end = start + window_size
		seq_fwd, seq_rev = zip(
			*(genome.get_seq_fwd_rev(c, s, e) for c, s, e in zip(chrom, start, end))
		)
		seq_fwd = np.array([list(seq.upper()) for seq in seq_fwd], dtype="object")
		seq_rev = np.array([list(seq.upper()) for seq in seq_rev], dtype="object")
		assert seq_fwd.shape[1] == window_size
		assert seq_rev.shape[1] == window_size
		ref_fwd = np.array(V["ref"])
		alt_fwd = np.array(V["alt"])
		ref_rev = np.array([str(Seq(x).reverse_complement()) for x in ref_fwd])
		alt_rev = np.array([str(Seq(x).reverse_complement()) for x in alt_fwd])
		pos_fwd = window_size // 2
		# Index of the same base after reversal: window_size - 1 - pos_fwd.
		pos_rev = pos_fwd - 1 if window_size % 2 == 0 else pos_fwd

		def prepare_output(seq, pos, ref, alt):
			# Sanity check: the genome base at the variant must equal ref.
			assert (seq[:, pos] == ref).all(), f"{seq[:, pos]}, {ref}"
			seq_alt = seq.copy()
			seq_alt[:, pos] = alt
			return (
				strings_to_one_hot(["".join(x) for x in seq]),
				strings_to_one_hot(["".join(x) for x in seq_alt]),
			)

		res = {}
		res["x_ref_fwd"], res["x_alt_fwd"] = prepare_output(seq_fwd, pos_fwd, ref_fwd, alt_fwd)
		res["x_ref_rev"], res["x_alt_rev"] = prepare_output(seq_rev, pos_rev, ref_rev, alt_rev)
		return res

	variants.set_transform(transform)
	# Keep the TemporaryDirectory alive for the whole prediction and clean it up
	# afterwards; taking only ``.name`` lets it be deleted immediately and leaks
	# anything Trainer recreates there.
	with tempfile.TemporaryDirectory() as output_dir:
		training_args = TrainingArguments(
			output_dir=output_dir,
			per_device_eval_batch_size=per_device_batch_size,
			dataloader_num_workers=dataloader_num_workers,
			remove_unused_columns=False,
			report_to="none",  # disables all reporting, including wandb
		)
		trainer = Trainer(model=model, args=training_args)
		return trainer.predict(test_dataset=variants).predictions

In [3]:
# Download the TraitGym test split once; later runs reuse the local copy.
dataset = "complex_traits_matched_9"
local_test_path = f"results/dataset/{dataset}/test.parquet"
if not os.path.exists(local_test_path):
	os.makedirs(f"results/dataset/{dataset}", exist_ok=True)
	pd.read_parquet(f"hf://datasets/songlab/TraitGym/{dataset}/test.parquet").to_parquet(local_test_path)

In [4]:
# Reference genome from which run_vep extracts the window around each variant.
genome_fasta_path = "results/genome.fa.gz"
genome = Genome(genome_fasta_path)

In [5]:
# Load the local test split as the variant dataset to score.
test_parquet_path = f"results/dataset/{dataset}/test.parquet"
variants = load_dataset_from_file_or_dir(test_parquet_path, split="test", is_file=True)

In [6]:
# Set to a small integer (e.g. 100) for a quick debugging run on the first N
# variants; None scores every variant.
N_DEBUG_VARIANTS = None
if N_DEBUG_VARIANTS is not None:
	variants = variants.select(range(N_DEBUG_VARIANTS))

In [8]:
# grelu.resources.load_model(project="enformer", model_name="human") fetches
# the checkpoint through wandb, which is blocked here. Instead, on a machine
# where wandb is not blocked, download the artifact manually:
#
#   import wandb
#   api = wandb.Api()
#   art = api.artifact('grelu/enformer/human:latest')
#   art.download("<local download dir>")
#
# then copy the checkpoint to data/ckpts/wandb_human_enformer_latest.ckpt
# (relative to the notebook's working directory, see cpt_dir / cpt_name).

from grelu.lightning import LightningModel

cpt_dir = "data/ckpts/"
cpt_name = "wandb_human_enformer_latest.ckpt"
cpt_path = os.path.join(cpt_dir, cpt_name)
model = LightningModel.load_from_checkpoint(cpt_path, map_location="cpu")

  rank_zero_warn(


In [9]:
# Per-task metadata of the loaded model (one row per output task). It is
# written only once; an existing file is left untouched.
metadata_output_path = "results/metadata/Enformer.csv"
metadata = pd.DataFrame(model.data_params["tasks"])
already_saved = os.path.exists(metadata_output_path)
if not already_saved:
	os.makedirs(os.path.dirname(metadata_output_path), exist_ok=True)
	metadata.to_csv(metadata_output_path, index=False)

In [14]:
# Score every variant, averaging over both strands and the shifts below, then
# save one column per model output task.
columns = model.data_params['tasks']["name"]
window_size = model.data_params["train"]["seq_len"]
shifts = [-3,-6,0,3,6]
model_full = VEPModel(model.model, shifts=shifts)

per_device_batch_size = 4
dataloader_num_workers = 16

pred = run_vep(
	variants, genome, window_size, model_full,
	per_device_batch_size=per_device_batch_size,
	dataloader_num_workers=dataloader_num_workers,
)

# Multi-shift runs get a "-<n_shifts>" suffix so they don't overwrite the
# single-shift features.
save_model_name = cpt_name if len(shifts) <= 1 else f"{cpt_name}-{len(shifts)}"
output_path = f"results/dataset/{dataset}/features/{save_model_name}_L2.parquet"

directory = os.path.dirname(output_path)
if directory and not os.path.exists(directory):
	os.makedirs(directory)
pd.DataFrame(pred, columns=columns).to_parquet(output_path, index=False)

In [5]:
import pandas as pd

# Peek at the single-shift feature table (one row per variant, one column per track).
single_shift_features_path = "results/dataset/complex_traits_matched_9/features/wandb_human_enformer_latest.ckpt_L2.parquet"
df2 = pd.read_parquet(single_shift_features_path)
df2.head()

Unnamed: 0,ENCFF833POA,ENCFF110QGM,ENCFF880MKD,ENCFF463ZLQ,ENCFF890OGQ,ENCFF996AEF,ENCFF660YSU,ENCFF787MSC,ENCFF568LMQ,ENCFF685MZL,...,CNhs14551,CNhs14618,CNhs14226,CNhs14229,CNhs14238,CNhs14239,CNhs14240,CNhs14241,CNhs14244,CNhs14245
0,0.007151,0.007426,0.010048,0.005994,0.006915,0.006616,0.00752,0.009041,0.008568,0.007695,...,0.006877,0.004066,0.001862,0.001971,0.00147,0.001695,0.001498,0.001478,0.003121,0.00272
1,0.012779,0.011033,0.013261,0.008563,0.009205,0.010085,0.008803,0.010515,0.0096,0.011925,...,0.009894,0.006118,0.003727,0.004223,0.002492,0.001939,0.001603,0.002188,0.005033,0.004664
2,0.013707,0.014915,0.016609,0.013668,0.012713,0.012658,0.011917,0.014117,0.013406,0.014505,...,0.011026,0.006833,0.004081,0.003441,0.002514,0.002845,0.001811,0.002418,0.006415,0.004842
3,0.042233,0.042491,0.050276,0.036298,0.047151,0.05523,0.030327,0.047163,0.043771,0.048341,...,0.030421,0.019885,0.010315,0.015707,0.012041,0.011184,0.013122,0.014038,0.017381,0.020587
4,0.009971,0.016547,0.032334,0.007006,0.007911,0.00869,0.007243,0.012781,0.008166,0.019799,...,0.030627,0.022836,0.007717,0.008418,0.006228,0.004725,0.004778,0.006216,0.021264,0.025121


In [10]:
import pandas as pd

# Check that the single-shift and 5-shift runs describe the same tracks in the
# same order (feature columns and metadata), so they can be compared directly.
features_dir = "results/dataset/complex_traits_matched_9/features"
df1 = pd.read_parquet(f"{features_dir}/wandb_human_enformer_latest.ckpt_L2.parquet")
df2 = pd.read_parquet(f"{features_dir}/wandb_human_enformer_latest.ckpt-5_L2.parquet")
# Assert rather than just evaluate, so a mismatch stops the notebook. The
# .equals methods return False (instead of erroring) when lengths differ.
assert df1.columns.equals(df2.columns)

md1 = pd.read_csv("results/metadata/wandb_human_enformer_latest.ckpt.csv")
md2 = pd.read_csv("results/metadata/wandb_human_enformer_latest.ckpt-5.csv")
assert md1.columns.equals(md2.columns)
assert md1.index.equals(md2.index)
assert md1["name"].equals(md2["name"])
assert md1["assay"].equals(md2["assay"])


In [13]:
# Save the track metadata under this feature set's name
# (results/metadata/<save_model_name>.csv). It is built from the model's task
# table, which is also the source of the prediction columns, so the rows line
# up with the feature columns. This also avoids depending on a CSV that this
# notebook never writes.
metadata = pd.DataFrame(model.data_params["tasks"])
os.makedirs("results/metadata", exist_ok=True)
metadata.to_csv(f"results/metadata/{save_model_name}.csv", index=False)

In [None]:
np.shape(pred)  # (n_variants, n_tracks); (11400, 5313) for complex_traits_matched_9

(11400, 5313)

In [None]:
# To compute evaluation metrics (AUPRC by chromosome, weighted average) from the saved features, run:
# snakemake -s workflow/Snakefile --cores 4 results/dataset/complex_traits_matched_9/AUPRC_by_chrom_weighted_average/all/Enformer_L2_L2.plus.all.csv

# job                                     count
# ------------------------------------  -------
# dataset_subset_all                          1
# get_metric_by_block                         1
# get_metric_by_block_weighted_average        1
# grelu_aggregate_assay                       1
# unsupervised_pred                           1
# total                                       5
