In [1]:
from mphelper import ProcessWrapPool
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from biodata.delimited import DelimitedReader
from Bio import SeqIO
import json
import itertools

In [2]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))
import utils

In [3]:
import matplotlib

# Register the bundled Arial faces and make Arial the default font family
font_dir = Path.cwd().parent / "font"
for font in ("Arial.ttf", "Arial_Bold.ttf"):
    matplotlib.font_manager.fontManager.addfont(font_dir / font)
matplotlib.rcParams["font.family"] = "Arial"

# Font sizes used by the plotting code: big (titles / axis labels), small (ticks / legend)
bfontsize, sfontsize = 12, 9

In [4]:
# All project directories live under one root on the shared storage
_PROJECT_ROOT = "/fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/"
PROJECT_DIR_s = f"{_PROJECT_ROOT}softwares/procapnet/"  # ProCapNet code (training modules, slurm scripts)
PROJECT_DIR_d = f"{_PROJECT_ROOT}PROcap/"               # PRO-cap derived BED files
PROJECT_DIR_r = f"{_PROJECT_ROOT}resources/"            # reference resources (genome FASTA)
PROJECT_DIR_o = f"{_PROJECT_ROOT}output/"               # analysis outputs

In [5]:
sys.path.append(f"{PROJECT_DIR_s}2_train_models/")
from data_loading import one_hot_encode

# Get sequences

In [6]:
s = "HCT116"                          # sample / cell line
ks = ["pl", "mn"]                     # presumably plus / minus strand keys
ps = ["divergent", "unidirectional"]  # element classes
ds = ["distal", "proximal"]           # element locations
gs = ["_".join(pair) for pair in itertools.product(ps, ds)]  # class_location group names
types = ["wt", "mt"]                  # original vs. motif-deleted sequences
ts = ["minTSS", "maxTSS"]

In [7]:
# CTCF motif positions per element id ("chrom_start_end"); get_sequences reads
# entry [0] as the motif center and entry [1] as the motif start
inputfile = f"{PROJECT_DIR_o}FIMO/{s}_CTCF_pos.json"
with open(inputfile) as f:
	motifs = json.load(f)

In [8]:
# Whole GRCh38 genome (no alt contigs) held in memory: chromosome name -> record
inputfile = f"{PROJECT_DIR_r}genomes/human/GRCh38_no_alt_analysis_set_GCA_000001405.15.fasta"
with open(inputfile) as f:
	records = SeqIO.parse(f, "fasta")
	fdict = SeqIO.to_dict(records)

In [9]:
# For comparison with Fig.4
# Generate 1-kb sequences (original vs. mutant) anchored on motif center for prediction

def get_sequences(inputfiles, outputfile, fdict, t, motifs, flank=500):
	"""Build one-hot windows centered on CTCF motifs and save them with np.save.

	For every row of ``inputfiles`` whose element id ("chrom_start_end") has
	an entry in ``motifs``, the genomic window [center - flank, center + flank)
	around the motif center is extracted and one-hot encoded. For ``t == "wt"``
	the window is kept unchanged. For any other ``t``, the motif is deleted in
	silico by zeroing its one-hot rows.

	Parameters
	----------
	inputfiles : list of str
		BED-like files; columns 1-3 are chrom/start/end and the last column
		is stored as the element's sign (presumably strand "+"/"-").
	outputfile : str
		Destination of the stacked array (one (4, 2*flank) entry per element,
		assuming one_hot_encode returns a (length, 4) array).
	fdict : dict
		Chromosome name -> sliceable record exposing ``.seq``.
	t : str
		"wt" for the original sequence; anything else produces the mutant.
	motifs : dict
		Element id -> [motif center, motif start, ...].
	flank : int, default 500
		Half-width of the window; 500 gives the 1-kb model input.

	Returns
	-------
	dict
		Element id -> sign, in the same order as the rows of the saved array.
		Each element id is emitted at most once (first occurrence wins), so
		the mapping stays row-aligned with the array.

	Raises
	------
	ValueError
		If the window would extend past either end of the chromosome.
	"""
	seqs = []
	signs = {}
	for inputfile in inputfiles:
		with DelimitedReader(inputfile) as dr:
			for cols in dr:
				chrom, start, end = cols[:3]
				e = "_".join([chrom, start, end])
				# Only elements with a motif hit. Repeats are skipped so that signs
				# stays one-to-one with seqs (get_matrix relies on that alignment).
				if e not in motifs or e in signs:
					continue
				center = motifs[e][0]
				win_start = center - flank
				# A negative slice start would silently wrap around; fail loudly instead
				if win_start < 0:
					raise ValueError(f"{e}: {2 * flank}-bp window around {center} starts before the beginning of {chrom}")
				seq = fdict[chrom][win_start:center + flank].seq.upper()
				if len(seq) != 2 * flank:
					raise ValueError(f"{e}: {2 * flank}-bp window around {center} runs past the end of {chrom}")
				one_hot = one_hot_encode(seq)
				if t != "wt":
					# In-silico deletion: zero the rows of the motif span, taken
					# symmetric around the center (center - motif start on each side).
					# Assumes the motif half-width does not exceed flank.
					half = center - motifs[e][1]
					one_hot[flank - half:flank + half + 1] = 0
				signs[e] = cols[-1]
				seqs.append(one_hot.T)
	seqs = np.array(seqs)
	np.save(outputfile, seqs)
	return signs

In [10]:
# Anchor on motif center, generate 1-kb sequences (original vs. mutant) for prediction

# signs[(location, type)]: element id -> sign, row-aligned with the saved .npy file
signs = {}
for d, t in itertools.product(ds, types):
	inputfiles = [f"{PROJECT_DIR_d}bed_plot/{s}_{p}_{d}.bed" for p in ps]
	outputfile = f"{PROJECT_DIR_o}procapnet/prediction/{s}_{d}_CTCF_{t}.npy"
	signs[(d, t)] = get_sequences(inputfiles, outputfile, fdict, t, motifs)

# Prediction

In [11]:
# Add GPU at the end

script = f"{PROJECT_DIR_s}slurm/predict.sh"
scale = "True"
model_type = "strand_merged_umap"
prediction_dir = f"{PROJECT_DIR_o}procapnet/prediction/"

# Print (not submit) one sbatch command per (location, type) input array
for d, t in itertools.product(ds, types):
	prefix = f"{prediction_dir}{s}_{d}_CTCF_{t}"
	inputfile = f"{prefix}.npy"
	output_prefix = f"{prefix}."
	commands = [
		"sbatch", script, s, model_type, f"{PROJECT_DIR_o}procapnet/",
		inputfile, output_prefix, scale, str(ds.index(d)),
	]
	print(" ".join(commands))

sbatch /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/softwares/procapnet/slurm/predict.sh HCT116 strand_merged_umap /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/output/procapnet/ /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/output/procapnet/prediction/HCT116_distal_CTCF_wt.npy /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/output/procapnet/prediction/HCT116_distal_CTCF_wt. True 0
sbatch /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/softwares/procapnet/slurm/predict.sh HCT116 strand_merged_umap /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/output/procapnet/ /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/output/procapnet/prediction/HCT116_distal_CTCF_mt.npy /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/output/procapnet/prediction/HCT116_distal_CTCF_mt. True 0
sbatch /fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/softwares/procapnet/slurm/

# Generate feature matrix

## Get individual matrix

In [12]:
def get_matrix(data, signs, tss):
	"""Orient per-element predicted profiles relative to each element's sign.

	Parameters
	----------
	data : numpy.ndarray
		Indexed as data[n][channel] -> 1-D profile (channel 0 / 1, presumably
		plus / minus strand). Row n must correspond to the n-th entry of
		``signs`` (the order returned by get_sequences).
	signs : dict
		Element id -> sign; "+" is one orientation, any other value is
		treated as the opposite orientation.
	tss : str
		"maxTSS": channel 0 for "+" elements, reversed channel 1 otherwise.
		Any other value (e.g. "minTSS"): negated channel 1 for "+" elements,
		negated and reversed channel 0 otherwise.

	Returns
	-------
	pandas.DataFrame
		One row per element (indexed by element id), one column per position.

	Raises
	------
	ValueError
		If the number of profiles in ``data`` differs from ``len(signs)``.
	"""
	# Materialize keys/values once instead of rebuilding the list per row
	ids = list(signs.keys())
	element_signs = list(signs.values())
	if len(ids) != data.shape[0]:
		raise ValueError(f"data has {data.shape[0]} profiles but {len(ids)} signs were given")
	results = []
	for n, sign in enumerate(element_signs):
		if tss == "maxTSS":
			values = list(data[n][0]) if sign == "+" else list(data[n][1][::-1])
		else:
			values = [-v for v in data[n][1]] if sign == "+" else [-v for v in data[n][0][::-1]]
		results.append(values)
	return pd.DataFrame(results, index=ids)

In [13]:
# df_features[(location, type, tss)]: oriented predicted profiles, one row per element
df_features = {}
for d, t in itertools.product(ds, types):
	inputfile = f"{PROJECT_DIR_o}procapnet/prediction/{s}_{d}_CTCF_{t}.scaled_profiles.npy"
	data = np.load(inputfile)
	df_features.update({(d, t, tss): get_matrix(data, signs[(d, t)], tss) for tss in ts})

In [14]:
# 10-bp bins

df_bins = {key: utils.bin_values(df_feature) for key, df_feature in df_features.items()}

In [15]:
# Wide (element x bin) -> long format: one row per (element "index", Position, Feature)
df_reformat = {}
for k, df_binned in df_bins.items():
	df = df_binned.reset_index()
	df_reformat[k] = df.melt(
		id_vars="index",
		value_vars=list(df.columns[1:]),
		var_name="Position",
		value_name="Feature",
	)

## Combine dataframes for plotting

In [16]:
# DE_labels[tss]: element id -> label (the combine step selects "Up" / "Down")
inputfile = f"{PROJECT_DIR_o}labels/{s}_CTCF_DE.json"
with open(inputfile) as f:
	DE_labels = json.load(f)

In [17]:
# df_metaplots[(location, tss)]: wt + mt long-format rows for the selected elements,
# tagged with Label = "<tss> (<type>)"
df_metaplots = {}
for d, tss in itertools.product(ds, ts):
	# minTSS panels show "Up" elements, maxTSS panels show "Down" elements
	label = "Up" if tss == "minTSS" else "Down"
	selected = [k for k, v in DE_labels[tss].items() if v == label]
	dfs = []
	for t in types:
		df = df_reformat[(d, t, tss)]
		dfs.append(df[df["index"].isin(selected)].assign(Label=f"{tss} ({t})"))
	df_metaplots[(d, tss)] = pd.concat(dfs).reset_index(drop=True)

# Generate metaplots

## Settings

In [18]:
# Range of the per-position mean across both conditions; used to pick ylims/yticks
for d, tss in itertools.product(ds, ts):
	df = df_metaplots[(d, tss)]
	means = [df[df["Label"] == f"{tss} ({t})"].groupby("Position")["Feature"].mean() for t in types]
	print(d, tss, max(m.max() for m in means), min(m.min() for m in means))

distal minTSS -0.09394923 -14.459512
distal maxTSS 9.159374 0.11017028
proximal minTSS -0.20903422 -28.572908
proximal maxTSS 20.437168 0.26015943


## Generate metaplots

In [19]:
def generate_metaplot(d, tss, df_metaplots, ylims, yticks, outputfile, xlabel="Distance (bp)"):
	"""Draw the wt vs. mt predicted metaplot for one (location, TSS) panel and save it.

	Parameters
	----------
	d : str
		Location key (e.g. "distal" / "proximal") into ``df_metaplots``.
	tss : str
		"minTSS" (blue palette, titled "Up") or any other value (red palette,
		titled "Down").
	df_metaplots : dict
		(d, tss) -> long-format DataFrame with "index", "Position", "Feature"
		and "Label" columns; labels are f"{tss} ({t})" for each t in the
		module-level ``types``.
	ylims, yticks : list
		y-axis limits and tick positions.
	outputfile : str
		Path the figure is saved to (300 dpi, tight bounding box).
	xlabel : str
		x-axis label.

	Notes
	-----
	Reads the module-level globals ``types``, ``bfontsize`` and ``sfontsize``.
	The figure is closed after saving so repeated calls (e.g. inside worker
	processes) do not accumulate open figures.
	"""
	fig, ax = plt.subplots(figsize=(4.5,2))
	labelpad = 2

	hue_order = [f"{tss} ({t})" for t in types]
	if tss == "minTSS":
		palette = ["#a6cee3", "#08519c"]
		label = "Up"
	else:
		palette = ["#fb9a99", "#a50f15"]
		label = "Down"

	df = df_metaplots[(d,tss)]
	# Mean profile per condition with standard-error band
	sns.lineplot(data=df, x="Position", y="Feature", hue="Label", hue_order=hue_order, palette=palette, ax=ax, errorbar="se")

	ax.spines[["right", "top"]].set_visible(False)
	ax.spines['left'].set_position(('outward', 10))
	ax.spines['bottom'].set_position(('outward', 10))

	# n = number of distinct elements in the panel
	tot = len(set(df["index"]))
	ax.set_title(f"{label} (n={tot})", fontsize=bfontsize, pad=labelpad+5, fontweight="bold")
	ax.legend(fontsize=sfontsize)

	ax.set_ylim(ylims)
	ax.set_yticks(yticks)
	ax.set_yticklabels([str(y) for y in yticks])
	ax.set_ylabel("Predicted", fontsize=bfontsize, fontweight="bold")
	# The minTSS panel gets a slightly larger offset (presumably for its negative tick labels)
	if tss == "minTSS":
		ax.yaxis.set_label_coords(-0.15, 0.5)
	else:
		ax.yaxis.set_label_coords(-0.12, 0.5)

	# Ticks every 5 bins; only the ends and the center (bin 25) are labelled
	xticklabel_list = ["-250", "", "", "", "", "0", "", "", "", "", "250"]
	xtick_list = [n*5 for n in range(11)]
	ax.set_xticks(xtick_list)
	ax.set_xticklabels(xticklabel_list)
	ax.set_xlabel(xlabel, fontsize=bfontsize, fontweight="bold")

	ax.tick_params(labelsize=sfontsize, pad=labelpad)

	ax.axhline(y=0, color="#bdbdbd", ls="--")
	# Center line extends below the axes (ymin < 0 with clip_on=False); this rendering
	# is matplotlib-version dependent
	ax.axvline(x=25, ymin=-0.8, ymax=1, color="#bdbdbd", ls="--", lw=1, clip_on=False, zorder=1)

	fig.savefig(outputfile, bbox_inches="tight", dpi=300)
	# Release the figure; otherwise every call leaves an open figure behind
	plt.close(fig)

In [20]:
# Other analyses in this project use matplotlib v3.9.2.
# This figure requires matplotlib v3.8.3 or lower; otherwise the vertical line drawn
# by "ax.axvline(...)" (which extends below the axes with clip_on=False) is not
# displayed as expected.

In [21]:
# One worker per (location, TSS) panel
pwpool = ProcessWrapPool(len(ds)*len(ts))

for d, tss in itertools.product(ds, ts):
	# distal panels are supplementary figures; proximal ones go to other_figures
	folder = "supp_figures/" if d == "distal" else "other_figures/"
	distal = d == "distal"
	# Axis settings chosen from the mean ranges printed above
	if tss == "minTSS":
		panel = "a"
		ylims, yticks = ([-17, 2], [-15, -10, -5, 0]) if distal else ([-33, 3], [-30, -20, -10, 0])
	else:
		panel = "b"
		ylims, yticks = ([-2, 11], [0, 3, 6, 9]) if distal else ([-2, 26], [0, 8, 16, 24])
	outputfile = f"{PROJECT_DIR_o}{folder}suppFig11{panel}.pdf"
	pwpool.run(generate_metaplot, args=[d, tss, df_metaplots, ylims, yticks, outputfile])

In [23]:
# Number of plotting tasks finished so far (len(ds) * len(ts) tasks were submitted)
len(pwpool.finished_tasks)

4

In [None]:
# Close the process pool created for the metaplots
pwpool.close()