In [46]:
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from biodatatools.utils.common import json_load
from biodata.delimited import DelimitedReader, DelimitedWriter
import sys
import itertools
import glob
import pybedtools
import os
import numpy as np
from collections import defaultdict
import seaborn as sns
import pickle
from mphelper import ProcessWrapPool
from Bio import SeqIO
import matplotlib.image as mpimg
from pathlib import Path
import logomaker

In [2]:
# Register the bundled Arial fonts (../font relative to the working directory)
# with matplotlib and make Arial the default family. bfontsize / sfontsize are
# the large / small font sizes used by the plotting helpers below.
font_dir = Path.cwd().parent / "font"
for font_file in ("Arial.ttf", "Arial_Bold.ttf"):
    matplotlib.font_manager.fontManager.addfont(font_dir / font_file)
matplotlib.rcParams["font.family"] = "Arial"
bfontsize, sfontsize = 12, 9

In [3]:
# Absolute project locations on the cluster (edit these when running elsewhere):
#   _d: databases / sample metadata, _o: analysis outputs (procapnet/, figures/),
#   _s: ProCapNet code (helper modules and slurm scripts), _r: shared resources (genomes).
PROJECT_DIR_d = "/home/yc2553/projects/HEA/databases/"
PROJECT_DIR_o = "/home/yc2553/projects/HEA/output/"
PROJECT_DIR_s = "/fs/cbsuhy02/storage/yc2553/yc2553/projects/3.Human_atlas/procapnet/"
PROJECT_DIR_r = "/fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/resources/"

In [28]:
# Make the ProCapNet helper modules importable: each directory is appended to
# sys.path right before the module imported from it.
sys.path.append(f"{PROJECT_DIR_s}2_train_models/")
from data_loading import extract_observed_profiles, one_hot_encode
sys.path.append(f"{PROJECT_DIR_s}5_modisco/")
from modiscolite_utils import load_modisco_results

# Observed patterns

In [17]:
# Samples analyzed here: DeepSHAP output folders whose name is listed under
# "normal_tissues" in the sample metadata.
# NOTE: glob order is filesystem-dependent, and the `revs` cell indexes
# `samples` by position, so do not reorder this list without updating `revs`.
groups = json_load(f"{PROJECT_DIR_d}PROcap/metainfo/samples.json")
folders = glob.glob(f"{PROJECT_DIR_o}procapnet/deepshap_out/*")
samples = [
    sample_name
    for sample_name in map(os.path.basename, folders)
    if sample_name in groups["normal_tissues"]
]

In [18]:
# Sample classification metadata (not referenced by the cells shown here)
labels = json_load(f"{PROJECT_DIR_d}PROcap/metainfo/classifications.json")

In [19]:
# Motif -> modisco pattern lookup, keyed by (sample, task); the inner value is
# used below to build "pattern_<value>" keys. Despite the .json extension, the
# file is a pickle, so only load it from a trusted location.
inputfile = f"{PROJECT_DIR_o}procapnet/modisco_out/all_motifs.json"
motifs = pickle.loads(Path(inputfile).read_bytes())

In [20]:
# Model variant used for modisco/prediction paths, and the two tasks whose
# finemo hits and modisco results are analyzed separately.
model_type = "strand_merged_umap"
tasks = ["counts", "profile"]

In [21]:
# Fine-mapped motif hits for every (sample, task) pair.
# Built with a comprehension so the loop variables stay local: the previous
# for-loop left globals `s` / `task` set to the last (sample, "profile") pair,
# a hidden-state trap for any later cell that reads `task` before resetting it.
df_motifs = {
    (sample, task_name): pd.read_table(
        f"{PROJECT_DIR_o}procapnet/finemo/{sample}/{task_name}/hits_with_motif_names.bed"
    )
    for sample, task_name in itertools.product(samples, tasks)
}

In [68]:
# Get motif instances for each motif
# If a motif instance found in multiple peaks, select the one with the highest "hit_importance"
# Note for the same motif, different models may output different fwd/rev patterns; will fix it in later sections

def get_motif_pos(s, task):
    """Write one BED of unique motif instances per motif for a (sample, task).

    Rows of ``df_motifs[(s, task)]`` containing any missing value are dropped.
    For each motif, rows sharing the same (#chr, start, end) are collapsed to
    the one with the largest ``hit_importance``. Each surviving instance is
    written as ``#chr, start, end, strand`` to
    ``{PROJECT_DIR_o}procapnet/instances/{s}_{task}_{motif}.bed``.
    """
    hits = df_motifs[(s, task)].dropna()
    for motif, motif_hits in hits.groupby("motif"):
        best_idx = motif_hits.groupby(['#chr', 'start', 'end'])['hit_importance'].idxmax()
        best_hits = motif_hits.loc[best_idx, ['#chr', 'start', 'end', 'strand']]
        outputfile = f"{PROJECT_DIR_o}procapnet/instances/{s}_{task}_{motif}.bed"
        with DelimitedWriter(outputfile) as dw:
            for record in best_hits.itertuples(index=False, name=None):
                dw.write(list(record))

In [69]:
# Write the per-motif instance BEDs for every (sample, task) in parallel.
pwpool = ProcessWrapPool(10)
for s, task in df_motifs.keys():
    pwpool.run(get_motif_pos, args=[s, task])

In [51]:
# Progress check: presumably the number of pool jobs that have completed
len(pwpool.finished_tasks)

In [27]:
# Close the worker pool (run once the submitted jobs have finished)
pwpool.close()

In [29]:
def get_matrix(bed_file, data, tss, rev, strands=None):
    """Orient per-instance two-row profiles to their motif and extract one row.

    Parameters
    ----------
    bed_file : str
        BED file of motif instances; each record's strand is read from its last
        column. Ignored when ``strands`` is given.
    data : sequence of 2-D arrays
        One profile per BED record, indexed ``data[n][row, position]``; row 0 is
        returned for ``tss == "fwd"`` and row 1 otherwise. ``data`` is NOT
        modified.
    tss : str
        ``"fwd"`` returns row 0. Any other value returns row 1 with its sign
        flipped, so that signal is drawn below zero.
    rev : bool
        When False, "-" instances are flipped. When True, "+" instances are
        flipped instead.
    strands : sequence of str, optional
        Pre-computed strands, one per profile. Lets callers that query several
        ``tss`` values avoid re-reading the BED file.

    Returns
    -------
    pandas.DataFrame
        One row per instance, one column per position.
    """
    if strands is None:
        strands = [interval.fields[-1] for interval in pybedtools.BedTool(bed_file)]
    flip_strand = "+" if rev else "-"
    row_idx = 0 if tss == "fwd" else 1
    results = []
    for n, strand in enumerate(strands):
        profile = data[n]
        if strand == flip_strand:
            # Reverse both rows and positions so the profile matches the motif
            # orientation. Work on a local view instead of writing back into
            # `data`: callers reuse the same `data` for "fwd" and "rev", and an
            # in-place flip would be undone by the second call.
            profile = profile[::-1, ::-1]
        values = list(profile[row_idx])
        if row_idx == 1:
            values = [-v for v in values]
        results.append(values)
    return pd.DataFrame(results)

In [127]:
def generate_individual_obs_metaplot(s, df, ax, title=True, legend=True, xlabel=True, ylabel=None, ylims=None):
    """Draw a mean observed-signal metaplot for one sample on ``ax``.

    ``df`` is long format with "Position", "Feature" and "Label" columns. One
    line is drawn per label in the notebook global ``tss_types``. ``s`` is
    accepted for call-site symmetry but unused. Falsy ``title`` / ``ylabel`` /
    ``ylims`` are skipped. ``legend`` toggles the legend, and ``xlabel``
    toggles the x tick labels and axis label.
    """
    sns.lineplot(
        data=df,
        x="Position",
        y="Feature",
        estimator=np.mean,
        errorbar=None,
        hue="Label",
        hue_order=tss_types,
        palette=["#fb8072", "#80b1d3"],
        ax=ax,
    )
    # Open-frame look: hide top/right spines and offset the remaining two
    ax.spines[["right", "top"]].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_position(("outward", 10))
    ax.axhline(y=0, ls="--", c="grey")

    if title:
        ax.set_title(title, fontsize=bfontsize, pad=5)
    if not legend:
        # empty handles/labels: effectively removes the legend
        ax.legend([], [], frameon=False)
    else:
        ax.legend(fontsize=sfontsize, frameon=False)

    if not ylabel:
        ax.set_ylabel("")
    else:
        ax.set_ylabel(ylabel, fontsize=bfontsize)
    if ylims:
        ax.set_ylim(ylims)

    # 400-bp window; tick 200 (the window center) is labelled 0
    ticks = list(range(0, 401, 100))
    ax.set_xlim([0, 400])
    ax.set_xticks(ticks)
    if not xlabel:
        ax.set_xticklabels([])
        ax.set_xlabel("")
    else:
        ax.set_xticklabels([str(tick - 200) for tick in ticks])
        ax.set_xlabel("Distance to motif (bp)", fontsize=bfontsize)

    ax.tick_params(labelsize=sfontsize)

In [101]:
def generate_obs_metaplots(motif, df_metaplots, outputfile):
    """Stack one observed-profile metaplot per sample vertically and save it.

    Parameters
    ----------
    motif : str
        Title of the top panel.
    df_metaplots : dict
        Sample name -> long-format DataFrame for
        ``generate_individual_obs_metaplot``. Iteration order sets panel order,
        and each key is used as that panel's y label.
    outputfile : str
        Destination passed to ``Figure.savefig``.
    """
    n_rows = len(df_metaplots)
    # squeeze=False always returns a 2-D Axes array, so one panel needs no
    # special case.
    fig, axes = plt.subplots(n_rows, 1, figsize=(4, n_rows * 2), squeeze=False)
    for row, (s, df) in enumerate(df_metaplots.items()):
        is_first = row == 0
        generate_individual_obs_metaplot(
            s, df, axes[row, 0],
            title=motif if is_first else False,
            legend=is_first,
            xlabel=row == n_rows - 1,
            ylabel=s,
        )
    fig.savefig(outputfile, bbox_inches="tight")
    # This runs once per motif inside pool workers; close the figure so open
    # figures (and their memory) do not accumulate.
    plt.close(fig)

In [104]:
def obs_profiles_across_samples(task, motif, samples, outputfile):
    """Plot observed PRO-cap metaplots around one motif's instances in each sample.

    For each sample:
    - extract the observed two-row signal in a 400-bp window around every
      instance in ``instances/{s}_{task}_{motif}.bed``;
    - orient it to the motif with ``get_matrix`` for each direction in
      ``tss_types``;
    - reshape it to long format ("Position", "Feature", "Label").

    All samples are then passed to ``generate_obs_metaplots``, which writes
    ``outputfile``.

    Relies on notebook globals ``PROJECT_DIR_o``, ``bws``, ``revs`` and
    ``tss_types``. A task absent from ``revs`` is treated as having no
    flipped samples.
    """
    task_revs = revs.get(task, {})
    df_metaplots = {}
    for s in samples:
        bed = f"{PROJECT_DIR_o}procapnet/instances/{s}_{task}_{motif}.bed"
        obs_profs = extract_observed_profiles(bws[s][0], bws[s][1], bed, out_window=400)
        rev = s in task_revs.get(motif, [])
        dfs = []
        for tss in tss_types:
            # get_matrix may flip profiles in place; give it a fresh copy per TSS
            # direction so each direction is oriented exactly once.
            df_features = get_matrix(bed, np.copy(obs_profs), tss, rev)
            df_reformat = df_features.melt(var_name="Position", value_name="Feature")
            df_reformat["Label"] = tss
            dfs.append(df_reformat)
        df_metaplots[s] = pd.concat(dfs).reset_index(drop=True)

    generate_obs_metaplots(motif, df_metaplots, outputfile)

In [33]:
# Per-sample signal files: bws[s][0] and bws[s][1] are the two tracks passed to extract_observed_profiles
bws = json_load(f"{PROJECT_DIR_d}PROcap/metainfo/sample_bws.json")

In [39]:
# task -> motif -> samples whose modisco pattern for that motif is in the
# opposite orientation. Their instances are flipped in get_matrix, and their
# CWM is reverse-complemented for the Fig. 5a logo.
# FIXME: the CREB entry indexes `samples` by position, but `samples` follows
# glob order, which is filesystem-dependent. Replace the indices with explicit
# sample names so this mapping cannot silently change.
revs = dict(
    counts=dict(
        CREB=[samples[n] for n in (1, 2, 3, 6, 7, 8, 11, 14)],
        MEF2=["EN23", "GT22"],
        SRF=["GT22"],
        ZNF384=["EN6", "EN3", "GT22", "EN4", "GT24"],
        YY1=["EN6", "EN23"],
    )
)

In [None]:
# TSS directions read by get_matrix: "fwd" = profile row 0, "rev" = row 1 (plotted negated).
# Read as a global by the metaplot helpers, so run this cell before calling them.
tss_types = ["fwd", "rev"]

In [92]:
# Union of motif names found in any sample, per task (key[1] of each motifs key)
motif_types = defaultdict(set)
for key, motif_map in motifs.items():
    motif_types[key[1]].update(motif_map)
for task_name, names in motif_types.items():
    print(task_name, len(names))

counts 25
profile 24


In [132]:
# For each counts-task motif, keep the samples for which an instance BED exists
task = "counts"
instance_dir = f"{PROJECT_DIR_o}procapnet/instances/"
motif_samples = {}
for motif in motif_types[task]:
    motif_samples[(task, motif)] = [
        s for s in samples
        if os.path.exists(f"{instance_dir}{s}_{task}_{motif}.bed")
    ]

In [106]:
# Render one observed-profile PDF per motif, in parallel
pwpool = ProcessWrapPool(10)
for motif in motif_types[task]:
    outputfile = f"{PROJECT_DIR_o}procapnet/temp/{task}_{motif}.pdf"
    job_args = [task, motif, motif_samples[(task, motif)], outputfile]
    pwpool.run(obs_profiles_across_samples, args=job_args)

In [51]:
# Progress check: presumably the number of pool jobs that have completed
len(pwpool.finished_tasks)

In [27]:
# Close the worker pool (run once the submitted jobs have finished)
pwpool.close()

# In silico mutagenesis

## Run prediction

In [12]:
# Sequence variants: wild type and motif-deleted ("mt"); the list index is passed to predict.sh
mut_types = ["wt", "mt"]

In [131]:
# Load the GRCh38 no-alt analysis set into memory as {record id: SeqRecord}
inputfile = f"{PROJECT_DIR_r}genomes/human/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
fdict = SeqIO.to_dict(SeqIO.parse(inputfile, "fasta"))

In [167]:
# Centered on motif: for every instance, one-hot encode a 1000-bp window
# (500 bp each side of the motif center). The "wt" set keeps the sequence
# as is; every other entry of mut_types gets an in-silico deletion, i.e.
# the motif bases are set to [0, 0, 0, 0].
# Each set is saved as motif_prediction/{s}_{task}_{motif}_{t}.npy.
half_window = 500
for motif in motif_types[task]:
    for s in motif_samples[(task, motif)]:
        # Read the BED and FASTA windows once per sample; wt and mt share them.
        bed = pybedtools.BedTool(f"{PROJECT_DIR_o}procapnet/instances/{s}_{task}_{motif}.bed")
        seqs = {t: [] for t in mut_types}
        for interval in bed:
            chrom = interval.fields[0]
            m_start, m_end = int(interval.fields[1]), int(interval.fields[2])
            center = (m_start + m_end) // 2
            # NOTE(review): windows near chromosome ends come back shorter,
            # which would make the stacked array ragged.
            window = fdict[chrom][center - half_window:center + half_window].seq.upper()
            wt_seq = one_hot_encode(window)
            # Genomic position p sits at p - (center - half_window) in the window.
            # The motif spans [m_start, m_end). The end offset must use m_end:
            # the old symmetric (center - m_start) span missed the last base of
            # odd-length motifs.
            mt_seq = wt_seq.copy()
            mt_seq[half_window - (center - m_start):half_window + (m_end - center)] = 0
            for t in mut_types:
                seqs[t].append((wt_seq if t == "wt" else mt_seq).T)
        for t in mut_types:
            outputfile = f"{PROJECT_DIR_o}procapnet/motif_prediction/{s}_{task}_{motif}_{t}.npy"
            np.save(outputfile, np.array(seqs[t]))

In [171]:
# Print one predict.sh command per (motif, sample, sequence type). The final
# argument is the index of the sequence type in mut_types.
script = f"{PROJECT_DIR_s}slurm/predict.sh"
model_type = "strand_merged_umap"
for motif in motif_types[task]:
    jobs = itertools.product(motif_samples[(task, motif)], enumerate(mut_types))
    for s, (t_idx, t) in jobs:
        out_prefix = f"{PROJECT_DIR_o}procapnet/motif_prediction/{s}_{task}_{motif}_{t}."
        inputfile = f"{out_prefix}npy"
        commands = [
            "bash", script, s, model_type,
            f"{PROJECT_DIR_o}procapnet/",
            inputfile,
            out_prefix,
            str(t_idx),
        ]
        print(" ".join(commands))

## Predicted profiles

In [34]:
def get_scaled_profiles(s, task, motif, t):
    """Load one prediction set and scale its profiles by the predicted counts.

    Reads ``<prefix>.pred_counts.npy`` and ``<prefix>.pred_profiles.npy``, where
    ``<prefix>`` is ``motif_prediction/{s}_{task}_{motif}_{t}``. Returns
    ``profiles * exp(counts)``, with a trailing axis added to the counts so
    they broadcast against the profiles.
    """
    prefix = f"{PROJECT_DIR_o}procapnet/motif_prediction/{s}_{task}_{motif}_{t}"
    counts = np.load(f"{prefix}.pred_counts.npy")
    profiles = np.load(f"{prefix}.pred_profiles.npy")
    return profiles * np.expand_dims(np.exp(counts), -1)

In [128]:
def generate_individual_pred_metaplot(s, df, ax, title=True, legend=True, xlabel=True, ylabel=None, legend_loc="right", ylims=None):
    """Draw mean predicted metaplots (every TSS direction x sequence type) on ``ax``.

    ``df`` is long format with "Position", "Feature" and "Label" columns, where
    labels look like "fwd (wt)". One line is drawn per combination of the
    notebook globals ``tss_types`` and ``mut_types``. ``s`` is unused.
    ``legend_loc="right"`` places the legend beside the axes; any other value
    draws a 4-column legend below them. Falsy ``title`` / ``ylabel`` /
    ``ylims`` are skipped.
    """
    hue_order = [f"{tss} ({mut})" for tss in tss_types for mut in mut_types]
    # With tss_types=["fwd","rev"], mut_types=["wt","mt"]: light/dark red for
    # fwd wt/mt, light/dark blue for rev wt/mt.
    palette = ["#fb8072", "#a50f15", "#80b1d3", "#08519c"]
    sns.lineplot(data=df, x="Position", y="Feature", estimator=np.mean, errorbar=None,
                 hue="Label", hue_order=hue_order, palette=palette, ax=ax)

    ax.spines[["right", "top"]].set_visible(False)
    for side in ("left", "bottom"):
        ax.spines[side].set_position(("outward", 10))
    ax.axhline(y=0, ls="--", c="grey")

    if title:
        ax.set_title(title, fontsize=bfontsize, pad=5)
    if legend:
        if legend_loc == "right":
            placement = {"loc": "upper left", "bbox_to_anchor": (1.02, 1)}
        else:
            placement = {"loc": "lower center", "bbox_to_anchor": (1.1, -0.8), "ncols": 4}
        ax.legend(fontsize=sfontsize, frameon=False, **placement)
    else:
        ax.legend([], [], frameon=False)

    if not ylabel:
        ax.set_ylabel("")
    else:
        ax.set_ylabel(ylabel, fontsize=bfontsize)
    if ylims:
        ax.set_ylim(ylims)

    # Positions 50-450 are shown; tick 250 is labelled 0 (the motif center)
    ticks = list(range(50, 451, 100))
    ax.set_xlim([50, 450])
    ax.set_xticks(ticks)
    if not xlabel:
        ax.set_xticklabels([])
        ax.set_xlabel("")
    else:
        ax.set_xticklabels([str(tick - 250) for tick in ticks])
        ax.set_xlabel("Distance to motif (bp)", fontsize=bfontsize)

    ax.tick_params(labelsize=sfontsize)

In [108]:
def generate_pred_metaplots(motif, df_metaplots, outputfile, legend_loc="right"):
    """Stack one predicted-profile metaplot per sample vertically and save it.

    Parameters
    ----------
    motif : str
        Title of the top panel.
    df_metaplots : dict
        Sample name -> long-format DataFrame for
        ``generate_individual_pred_metaplot``. Iteration order sets panel order,
        and each key is used as that panel's y label.
    outputfile : str
        Destination passed to ``Figure.savefig``.
    legend_loc : str
        Forwarded to ``generate_individual_pred_metaplot`` (legend placement).
    """
    n_rows = len(df_metaplots)
    # squeeze=False always returns a 2-D Axes array, so one panel needs no
    # special case.
    fig, axes = plt.subplots(n_rows, 1, figsize=(4, n_rows * 2), squeeze=False)
    for row, (s, df) in enumerate(df_metaplots.items()):
        is_first = row == 0
        generate_individual_pred_metaplot(
            s, df, axes[row, 0],
            title=motif if is_first else False,
            legend=is_first,
            xlabel=row == n_rows - 1,
            ylabel=s,
            legend_loc=legend_loc,
        )
    fig.savefig(outputfile, bbox_inches="tight")
    # This runs once per motif inside pool workers; close the figure so open
    # figures (and their memory) do not accumulate.
    plt.close(fig)

In [112]:
def pred_profiles_across_samples(task, motif, samples, outputfile, legend_loc="right"):
    """Plot predicted (wt vs. motif-deleted) metaplots around one motif in each sample.

    For each sample, and for each sequence type in ``mut_types``:
    - take the scaled predicted profiles from ``get_scaled_profiles``;
    - orient them with ``get_matrix`` for every direction in ``tss_types``;
    - label them "<tss> (<mut>)".

    All samples are then passed to ``generate_pred_metaplots``, which writes
    ``outputfile``.

    Relies on notebook globals ``PROJECT_DIR_o``, ``revs``, ``mut_types`` and
    ``tss_types``. A task absent from ``revs`` is treated as having no
    flipped samples.
    """
    task_revs = revs.get(task, {})
    df_metaplots = {}
    for s in samples:
        bed = f"{PROJECT_DIR_o}procapnet/instances/{s}_{task}_{motif}.bed"
        rev = s in task_revs.get(motif, [])
        dfs = []
        for mut in mut_types:
            pred_profs = get_scaled_profiles(s, task, motif, mut)
            for tss in tss_types:
                # get_matrix may flip profiles in place; a fresh copy per call
                # keeps the second TSS direction from seeing already-flipped data.
                df_features = get_matrix(bed, np.copy(pred_profs), tss, rev)
                df_reformat = df_features.melt(var_name="Position", value_name="Feature")
                df_reformat["Label"] = f"{tss} ({mut})"
                dfs.append(df_reformat)
        df_metaplots[s] = pd.concat(dfs).reset_index(drop=True)

    generate_pred_metaplots(motif, df_metaplots, outputfile, legend_loc)

In [114]:
# Render one predicted-profile PDF per motif, in parallel
pwpool = ProcessWrapPool(10)
for motif in motif_types[task]:
    outputfile = f"{PROJECT_DIR_o}procapnet/temp/{task}_{motif}_pred.pdf"
    job_args = [task, motif, motif_samples[(task, motif)], outputfile]
    pwpool.run(pred_profiles_across_samples, args=job_args)

In [51]:
# Progress check: presumably the number of pool jobs that have completed
len(pwpool.finished_tasks)

In [27]:
# Close the worker pool (run once the submitted jobs have finished)
pwpool.close()

# Broad vs. focused effect curves

In [40]:
# Highlight these motifs: one representative sample per motif. Insertion
# order sets the Fig. 5a column order and must match obs_ylims / pred_ylims.
example_samples = dict(CREB="EN55", TATA="BCT5", SRF="GT22", MEF2="EN23")

In [84]:
# Collect long-format observed and predicted (wt vs. mt) profiles for each
# highlighted example, keyed by sample, for the Fig. 5a panels below.
# NOTE: keyed by sample, so two motifs sharing one example sample would overwrite each other.
df_obs_metaplots = {}
df_pred_metaplots = {}

for motif, s in example_samples.items():
    rev = s in revs.get(task, {}).get(motif, [])
    bed = f"{PROJECT_DIR_o}procapnet/instances/{s}_{task}_{motif}.bed"

    # Observed profiles
    obs_profs = extract_observed_profiles(bws[s][0], bws[s][1], bed, out_window=400)
    dfs = []
    for tss in tss_types:
        # get_matrix may flip profiles in place; a fresh copy per call keeps the
        # second TSS direction from seeing already-flipped data.
        df_features = get_matrix(bed, np.copy(obs_profs), tss, rev)
        df_reformat = df_features.melt(var_name="Position", value_name="Feature")
        df_reformat["Label"] = tss
        dfs.append(df_reformat)
    df_obs_metaplots[s] = pd.concat(dfs).reset_index(drop=True)

    # Predicted profiles
    dfs = []
    for mut in mut_types:
        pred_profs = get_scaled_profiles(s, task, motif, mut)
        for tss in tss_types:
            df_features = get_matrix(bed, np.copy(pred_profs), tss, rev)
            df_reformat = df_features.melt(var_name="Position", value_name="Feature")
            df_reformat["Label"] = f"{tss} ({mut})"
            dfs.append(df_reformat)
    df_pred_metaplots[s] = pd.concat(dfs).reset_index(drop=True)

== In Extract Profiles ==
Peak filepath: /home/yc2553/projects/HEA/output/procapnet/instances/EN55_counts_CREB.bed
Profile length: 400
Num. Examples: 56422
== In Extract Profiles ==
Peak filepath: /home/yc2553/projects/HEA/output/procapnet/instances/BCT5_counts_TATA.bed
Profile length: 400
Num. Examples: 30713
== In Extract Profiles ==
Peak filepath: /home/yc2553/projects/HEA/output/procapnet/instances/GT22_counts_SRF.bed
Profile length: 400
Num. Examples: 1416
== In Extract Profiles ==
Peak filepath: /home/yc2553/projects/HEA/output/procapnet/instances/EN23_counts_MEF2.bed
Profile length: 400
Num. Examples: 10251


In [41]:
# Refer to codes in https://github.com/jmschrei/tfmodisco-lite/blob/main/modiscolite/report.py and https://github.com/kundajelab/ProCapNet/blob/main/src/figure_notebooks/other_motif_utils.py

def trim_motif(cwm, trim_threshold=0.3, pad=2):
    """Trim a CWM to the rows spanning its informative positions, plus padding.

    Parameters
    ----------
    cwm : numpy.ndarray
        Position-by-base contribution matrix (rows are positions).
    trim_threshold : float
        Fraction of the matrix maximum that a cell must reach for its row to
        count as informative.
    pad : int
        Extra rows kept on each side of the informative span, clipped to the
        matrix bounds.

    Returns
    -------
    numpy.ndarray
        A row slice of ``cwm``. If no cell reaches the threshold (possible when
        the maximum is negative), ``cwm`` is returned unchanged instead of
        raising on an empty index set.
    """
    trim_thresh = np.max(cwm) * trim_threshold
    # Row indices of every cell at/above the threshold (column indices dropped)
    pass_inds = np.where(cwm >= trim_thresh)[0]
    if pass_inds.size == 0:
        return cwm
    start = max(pass_inds.min() - pad, 0)
    end = min(pass_inds.max() + pad + 1, len(cwm))
    return cwm[start:end]

def plot_weights(array, ax):
    """Draw a sequence logo of a position-by-base (A, C, G, T) matrix on ``ax``.

    The y-range runs from the deepest negative stack to the tallest positive
    stack, so positions with mixed-sign contributions are not clipped.
    Returns the ``logomaker.Logo`` object.
    """
    df = pd.DataFrame(array, columns=['A', 'C', 'G', 'T'])
    df.index.name = 'pos'
    crp_logo = logomaker.Logo(df, ax=ax)
    crp_logo.style_spines(visible=False)
    # logomaker stacks positive values upward and negative values downward at
    # each position, so the extent comes from the per-sign sums. The net row sum
    # underestimates both ends whenever a position has mixed signs.
    ymax = df.clip(lower=0).sum(axis=1).max()
    ymin = df.clip(upper=0).sum(axis=1).min()
    ax.set_ylim(min(ymin, 0), ymax)
    return crp_logo

In [123]:
# Fig. 5a y-limits, one entry per column in example_samples order (CREB, TATA, SRF, MEF2)
obs_ylims = [[-7,7], [-7,7], [-70,70], [-3,3]]
pred_ylims = [[-3,3], [-3,3], [-25,25], [-1.8,1.8]]

In [129]:
# Fig. 5a: one column per highlighted motif. Rows: trimmed CWM logo, observed
# metaplot, predicted metaplot (wt vs. in-silico deletion).
fig, ax = plt.subplots(3, len(example_samples), figsize=(9.5, 4.5),
                       gridspec_kw={'height_ratios': [1, 2, 2]})
for col, (motif, s) in enumerate(example_samples.items()):
    # Row 0: sequence logo
    modisco_results = load_modisco_results(
        f"{PROJECT_DIR_o}procapnet/modisco_out/{s}/{model_type}/merged/{task}_modisco_results.hd5")
    pattern = modisco_results["pos_patterns"][f"pattern_{motifs[(s, task)][motif]}"]
    cwm_fwd = np.array(pattern['contrib_scores'][:])
    rev = motif in revs[task] and s in revs[task][motif]
    # Reversing both axes reverse-complements the CWM (columns are A, C, G, T),
    # matching the orientation used for the profile rows.
    cwm = cwm_fwd[::-1, ::-1] if rev else cwm_fwd
    logo_ax = ax[0, col]
    plot_weights(trim_motif(cwm), logo_ax)
    logo_ax.set_xticks([])
    logo_ax.set_yticks([])
    logo_ax.set_title(motif, fontsize=bfontsize)

    # Row 1: observed profiles (no title / legend / x label); y label on column 0 only
    generate_individual_obs_metaplot(
        s, df_obs_metaplots[s], ax[1, col], "", False, False,
        "Observed" if col == 0 else False, obs_ylims[col])

    # Row 2: predicted profiles before and after in silico mutagenesis; the
    # shared legend is drawn once, below the second column.
    generate_individual_pred_metaplot(
        s, df_pred_metaplots[s], ax[2, col], "", col == 1, True,
        "Predicted" if col == 0 else False, "lower", pred_ylims[col])

fig.subplots_adjust(wspace=0.4, hspace=0.3)
fig.savefig(f"{PROJECT_DIR_o}figures/Fig5a.pdf", bbox_inches="tight", transparent=True)