In [20]:
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from mphelper import ProcessWrapPool
import numpy as np
from biodata.delimited import DelimitedReader
import itertools
import biodataplot.metaplot as bpm
from biodata.bigwig import BigWigIReader
from biodata.bed import BEDXReader
from genomictools import GenomicCollection

In [2]:
# Root of the TRE_directionality project and its sub-directories for PRO-cap data (_d),
# external resources (_r) and analysis outputs (_o). Each path ends with "/" because
# file names are concatenated directly onto it below.
PROJECT_ROOT = "/fs/cbsuhy02/storage/yc2553/yc2553/projects/TRE_directionality/"
PROJECT_DIR_d = PROJECT_ROOT + "PROcap/"
PROJECT_DIR_r = PROJECT_ROOT + "resources/"
PROJECT_DIR_o = PROJECT_ROOT + "output/"

# 1. Generate feature matrices

## 1.1 Compute individual feature matrices

In [11]:
s = "C1"  # sample name, used as the file-name prefix everywhere
ks = ["pl", "mn"]  # strand suffixes of the PRO-cap coverage tracks
ps = ["divergent", "unidirectional"]
ds = ["distal", "proximal"]
cs = ["", "_control"]  # file-name suffix: TRE elements vs. control regions
# Groups are the TRE sets only. Cells that need control files append the suffix from
# `cs` themselves (`g + c`), so folding `cs` in here would produce "..._control_control"
# paths and 8 instead of 4 groups (the job count is 4 groups * 2 * 4 tracks = 32).
gs = [f"{p}_{d}" for p, d in itertools.product(ps, ds)]

In [7]:
# Signal tracks keyed by the label used in output file names. PRO-cap comes as two
# strand-specific coverage files ("_pl" / "_mn"); DNase and H3K27ac are ENCODE bigWigs
# that the feature-matrix step treats as strand-insensitive.
procap_coverage = f"{PROJECT_DIR_d}Analysis/{s}_dedup_chr1-22-X_R2_coverage"
encode_dir = f"{PROJECT_DIR_r}ENCODE/"
bws = {
	"PROcap_pl": procap_coverage + "_pl.bw",
	"PROcap_mn": procap_coverage + "_mn.bw",
	"DNase": encode_dir + "ENCFF414OGC.bigWig",
	"H3K27ac": encode_dir + "ENCFF849TDM.bigWig",
}

In [21]:
# Exploratory (work in progress): for better visualization and easier interpretation,
# flip elements so that the side with more reads ends up on the right.
# Elements are split by the strand of their stranded position ("+" -> non_flip,
# "-" -> flip), and strand-aware profiles (fixed_size=1001) are computed from both
# PRO-cap tracks for each part.
# NOTE(review): the recorded run of this cell raised
#   AttributeError: 'BEDX' object has no attribute 'strand'
# The filters only read `.stranded_genomic_pos.strand`, so the error presumably comes from
# bpm.generate_signal_profile(..., use_strand=True) reading `.strand` on the BEDX regions
# themselves — TODO confirm against the biodataplot API before relying on this cell.

for g in gs:
	bed = f"{PROJECT_DIR_d}bed_plot/{s}_{g}.bed"
	regions = BEDXReader.read_all(GenomicCollection, bed)
	non_flip_regions = list(filter(lambda a: a.stranded_genomic_pos.strand == "+", regions))
	flip_regions = list(filter(lambda a: a.stranded_genomic_pos.strand == "-", regions))
	# NOTE(review): every call below opens a new BigWigIReader; consider one reader per track.
	pl_non_flip = bpm.generate_signal_profile(non_flip_regions, BigWigIReader(bws["PROcap_pl"]), fixed_size=1001, use_strand=True)
	pl_flip = bpm.generate_signal_profile(flip_regions, BigWigIReader(bws["PROcap_pl"]), fixed_size=1001, use_strand=True)
	mn_non_flip = bpm.generate_signal_profile(non_flip_regions, BigWigIReader(bws["PROcap_mn"]), fixed_size=1001, use_strand=True)
	mn_flip = bpm.generate_signal_profile(flip_regions, BigWigIReader(bws["PROcap_mn"]), fixed_size=1001, use_strand=True)
	# NOTE(review): debugging stop after the first group; remove once the cell works.
	break

AttributeError: 'BEDX' object has no attribute 'strand'

In [8]:
# Features plotted in the figures (two rows each); "PROcap" stands for both PROcap_* tracks.
ms = "PROcap DNase H3K27ac".split()

In [8]:
pwpool = ProcessWrapPool(20)

# Submit one feature-matrix job per (group, TRE/control, track). Each job writes a wide
# matrix (output1) and a reformatted long table (output2). PRO-cap tracks are stranded,
# with the strand taken from the key suffix ("pl"/"mn"); the other tracks are not.
for g in gs:
	for c in cs:
		element_set = g + c
		bed_in = PROJECT_DIR_d + "bed_plot/" + "_".join([s, element_set]) + ".bed"
		for m, bw_path in bws.items():
			output1 = PROJECT_DIR_o + "temp/" + "_".join([s, m, element_set]) + ".txt"
			output2 = PROJECT_DIR_o + "temp/" + "_".join([s, m, element_set, "reformat"]) + ".txt"
			strand = m.split("_")[1] if "PROcap" in m else None
			strand_sensitive = strand is not None
			pwpool.run(utils.get_feature_matrix, args=[bed_in, bw_path, output1, output2, strand_sensitive, strand])

In [12]:
# Expected: len(gs) groups * len(cs) (TRE/control) * len(bws) tracks = 4*2*4 = 32 jobs
n_finished = len(pwpool.finished_tasks)
n_finished

32

In [13]:
# Shut down the feature-matrix worker pool once the submitted jobs have finished (32 above).
pwpool.close()

In [14]:
# Check output1 (wide matrix: Element plus one column per position) for one job.
# The job is chosen explicitly rather than through the loop variables m/g/c leaked
# from the submission cell, because later cells rebind those names.
check_m, check_g, check_c = list(bws)[-1], gs[-1], cs[-1]
inputfile = PROJECT_DIR_o + "temp/" + "_".join([s, check_m, check_g + check_c]) + ".txt"
df = pd.read_table(inputfile)
df.head(2)

Unnamed: 0,Element,0,1,2,3,4,5,6,7,8,...,991,992,993,994,995,996,997,998,999,1000
0,"('chr1', '2967063', '2967563')",1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,...,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906,1.01906
1,"('chr1', '4538064', '4538564')",0.80969,0.80969,0.80969,0.80969,0.80969,0.80969,0.80969,0.80969,0.80969,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [15]:
# Check output2 (reformatted long table with Element, Position and Feature columns)
# for one job. The job is chosen explicitly rather than through the loop variables
# m/g/c leaked from the submission cell, because later cells rebind those names.
check_m, check_g, check_c = list(bws)[-1], gs[-1], cs[-1]
inputfile = PROJECT_DIR_o + "temp/" + "_".join([s, check_m, check_g + check_c, "reformat"]) + ".txt"
df = pd.read_table(inputfile)
df.head(2)

Unnamed: 0,Element,Position,Feature
0,"('chr1', '2967063', '2967563')",0,1.01906
1,"('chr1', '4538064', '4538564')",0,0.80969


## 1.2 Combine dataframes for plotting

In [16]:
# Heatmap input: for each feature, the wide matrices of all groups stacked, tagged with a
# "Group" column. For PRO-cap, the pl and mn matrices of a group are placed side by side
# (pl columns first, then mn) and share one Element column.
dfs_heatmap = {}
for m in ms:
	dfs = []
	for g in gs:
		if m == "PROcap":
			inputfiles = [PROJECT_DIR_o + "temp/" + "_".join([s, m, k, g]) + ".txt" for k in ks]
			df1, df2 = [pd.read_table(inputfile) for inputfile in inputfiles]
			# Column-wise concatenation pairs rows by index. If the two strands listed
			# different elements (or in a different order), rows would be silently
			# mismatched or NaN-padded, so fail loudly instead.
			if not df1["Element"].equals(df2["Element"]):
				raise ValueError(f"pl/mn feature matrices list different elements for group {g!r}")
			elements = df1["Element"]
			df = pd.concat([df1.drop("Element", axis=1), df2.drop("Element", axis=1)], axis=1)
			df["Element"] = elements
		else:
			inputfile = PROJECT_DIR_o + "temp/" + "_".join([s, m, g]) + ".txt"
			df = pd.read_table(inputfile)
		df["Group"] = g
		dfs.append(df)
	dfs_heatmap[m] = pd.concat(dfs, ignore_index=True)

In [17]:
# Map each element to the sign in the 8th (last) column of its group's BED file, keyed
# like the "Element" column of the feature tables: str((chrom, start, end)).
# NOTE(review): start/end are taken from 0-based columns 3 and 4 of each 8-column row,
# presumably the element coordinates used in the feature tables — confirm against the
# bed_plot files.
signs = {}
for g in gs:
	bed_path = PROJECT_DIR_d + "bed_plot/" + "_".join([s, g]) + ".bed"
	with DelimitedReader(bed_path) as reader:
		for chrom, _, _, start, end, _, _, sign in reader:
			signs[str((chrom, start, end))] = sign

In [18]:
def PROcap_label(row, k, signs):
	"""
	Label PRO-cap strand `k` of an element as its maximum or minimum TSS side.

	`signs` maps row["Element"] to the element's sign. For a "+" element the "pl"
	strand is "maxTSS" and any other strand is "minTSS"; for every other sign the
	roles are swapped. Raises KeyError if the element is not in `signs`.
	"""
	pl_is_max = signs[row["Element"]] == "+"
	is_pl = k == "pl"
	return "maxTSS" if is_pl == pl_is_max else "minTSS"

In [19]:
def _procap_labels(elements, k, signs):
	"""
	Vectorised equivalent of PROcap_label for a whole Element column.

	Returns an array holding "maxTSS" where strand `k` is the element's maximum-TSS side
	("pl" for "+" elements, the other strand otherwise) and "minTSS" elsewhere.
	Raises KeyError if an element has no entry in `signs`, as the row-wise version does.
	"""
	element_signs = elements.map(signs)
	missing = element_signs.isna()
	if missing.any():
		raise KeyError(f"Elements without a sign: {elements[missing].unique()[:5].tolist()}")
	return np.where((element_signs == "+") == (k == "pl"), "maxTSS", "minTSS")


# Metaplot input: the long (Element, Position, Feature) tables of all groups per feature,
# tagged with "Group" and a hue "Label".
dfs_metaplot = {}
for m in ms:
	dfs = []
	for g in gs:
		if m == "PROcap":
			# show maximum and minimum TSS separately, otherwise the signals can be cancelled out
			for k in ks:
				inputfile = PROJECT_DIR_o + "temp/" + "_".join([s, m, k, g, "reformat"]) + ".txt"
				df = pd.read_table(inputfile)
				df["Group"] = g
				# These tables have one row per element and position, so a row-wise
				# DataFrame.apply is very slow here; label the whole column at once.
				df["Label"] = _procap_labels(df["Element"], k, signs)
				dfs.append(df)
		else:
			for c in cs:
				inputfile = PROJECT_DIR_o + "temp/" + "_".join([s, m, g+c, "reformat"]) + ".txt"
				df = pd.read_table(inputfile)
				df["Group"] = g
				df["Label"] = "Control" if c else "TRE"
				dfs.append(df)
	dfs_metaplot[m] = pd.concat(dfs, ignore_index=True)

# 2. Metaplots & heatmaps

## 2.1 Settings

### 2.1.1 ylims

In [20]:
# Decide on the ylims:
# - divergent & unidirectional groups share one range per (distance class, feature)
# - TREs & controls (maxTSS & minTSS for PRO-cap) share that range too
# - the same value ranges are used for metaplots and heatmaps

def get_ylims(ms, distances=None, patterns=None, metaplot_dfs=None):
	"""
	Print the extreme values of the mean metaplot profiles for each distance class and feature.

	For every distance class d and feature m, the mean Feature per Position is computed
	for each group (p + "_" + d) and each label ("maxTSS"/"minTSS" for "PROcap",
	"TRE"/"Control" otherwise). Then `d m <max of maxima> <min of minima>` is printed.
	Nothing is returned.

	Parameters
	----------
	ms : list of str
		Features to report.
	distances, patterns : list of str, optional
		Distance classes and group patterns; default to the globals `ds` and `ps`.
	metaplot_dfs : dict, optional
		Feature -> long DataFrame with Group, Label, Position and Feature columns;
		defaults to the global `dfs_metaplot`.
	"""
	if distances is None:
		distances = ds
	if patterns is None:
		patterns = ps
	if metaplot_dfs is None:
		metaplot_dfs = dfs_metaplot

	for d in distances:
		for m in ms:
			labels = ["maxTSS", "minTSS"] if m == "PROcap" else ["TRE", "Control"]
			feature_df = metaplot_dfs[m]
			maxs = []
			mins = []
			for p in patterns:
				df = feature_df[feature_df["Group"] == p + "_" + d]
				for label in labels:
					# Compute the mean profile once and take both extremes from it.
					profile = df[df["Label"] == label].groupby("Position")["Feature"].mean()
					maxs.append(profile.max())
					mins.append(profile.min())
			print(d, m, max(maxs), min(mins))

In [21]:
# Print max/min of the mean profiles per distance class and feature; these values
# informed the hand-picked ylims, yticks and PRO-cap color ranges used below.
get_ylims(ms)

distal PROcap 63.31777441508555 -15.302875101850773
distal DNase 3.313858404529022 0.042020992283974934
distal H3K27ac 16.178052635240952 0.6100446112028132
proximal PROcap 481.4731309664271 -108.06896921591762
proximal DNase 5.482331557567373 0.04443297767064024
proximal H3K27ac 21.700563841246552 0.6094325536293212


### 2.1.2 cmap

In [22]:
# PRO-cap heatmap colormaps: diverging blue-white-red with white pinned at 0. The value
# range per distance class spans the same extremes as the PRO-cap yticks used in the figures.
procap_colors = ["#2c7bb6", "#ffffff", "#d7191c"]
customized_cmaps = {}
for d in ds:
	cvals = [-250, 0, 500] if d == "proximal" else [-40, 0, 80]
	norm = plt.Normalize(min(cvals), max(cvals))
	anchors = [(norm(value), color) for value, color in zip(cvals, procap_colors)]
	customized_cmaps[d] = matplotlib.colors.LinearSegmentedColormap.from_list("", anchors)

In [23]:
# Heatmap colormaps per distance class, ordered like `ms`: the custom diverging map for
# PRO-cap (the first feature) and "viridis" for every remaining feature.
cmaps = {d: [customized_cmaps[d]] + ["viridis"] * (len(ms) - 1) for d in ds}

### 2.1.3 Sorting

In [24]:
# Per-element distance, used to order the rows of the heatmaps:
# - divergent: distance between the two prominent TSSs
# - unidirectional: distance between the prominent TSS and the center of overlapping DNase peaks
# (computed by utils.get_element_distance over the BED files of all TRE groups)
inputfiles = [f"{PROJECT_DIR_d}bed_plot/{s}_{p}_{d}.bed" for p, d in itertools.product(ps, ds)]
outputfile = f"{PROJECT_DIR_o}labels/{s}_distance.txt"
utils.get_element_distance(inputfiles, outputfile)

In [25]:
# Check the distance table: one Value per Element

inputfile = f"{PROJECT_DIR_o}labels/{s}_distance.txt"
df = pd.read_table(inputfile)
df.head(2)

Unnamed: 0,Element,Value
0,"('chr1', '1021175', '1021375')",51
1,"('chr1', '1058364', '1058581')",82


## 2.2 Generate metaplots & heatmaps

In [31]:
def generate_metaplot_and_heatmap(d, ms, sort_file, ylims, yticks, cmaps, test, outputfile, xlabel="Distance to center (kb)", fontsize=25, labelpad=7.5, labelsize=20, y_align= -0.2, wspace=0.25, hspace=0.2):
	"""
	Generate metaplots and heatmaps for a list of features and save the figure.

	The grid has two rows per feature in `ms` (metaplot, then heatmap) and one column
	per group pattern in the global `ps`. A narrow extra column holds the heatmap
	colorbars and is hidden on metaplot rows.

	Parameters
	----------
	d : str
		Distance class. Rows of the global `dfs_metaplot` / `dfs_heatmap` tables are
		selected where Group == p + "_" + d.
	ms : list of str
		Features to plot, top to bottom.
	sort_file : str
		Passed through to `utils.generate_feature_heatmap` (element ordering file).
	ylims, yticks : list
		Per-feature y-limits and ticks, in the order of `ms`. The ticks are reused as
		heatmap colorbar ticks.
	cmaps : list
		Per-feature heatmap colormaps, in the order of `ms`.
	test : bool
		Passed through to the `utils` plotting helpers.
	outputfile : str
		Where the figure is saved (dpi=300, bbox_inches="tight").
	xlabel, fontsize, labelpad, labelsize, y_align, wspace, hspace
		Axis label text, font and tick sizes, y-label x offset, and subplot spacing.

	Notes
	-----
	Legends are drawn only at (row 0, column 1) and on row 2. The non-PRO-cap palette
	assumes the first column (`ps[0]`) is the divergent group.
	"""
	n_groups = len(ps)
	n_rows = len(ms) * 2
	cbar_col = n_groups  # index of the narrow colorbar column
	width_ratios = [12]*n_groups + [1]
	height_ratios = [10,10]*len(ms)
	fig, axes = plt.subplots(n_rows, n_groups+1, figsize=(7*len(ps), 7*len(ms)), gridspec_kw={'width_ratios': width_ratios, 'height_ratios': height_ratios})

	# X axis: data positions 0..1000 are labelled -0.5..0.5 (kb from the center)
	xtick_list = [0,500,1000]
	xticklabel_list = ["-0.5", "0", "0.5"]

	for i, m in enumerate(ms):
		feature_ylim = ylims[i]
		feature_yticks = yticks[i]
		feature_cmap = cmaps[i]
		for col, p in enumerate(ps):
			group = p + "_" + d
			for n in range(2):
				row = i * 2 + n
				ax = axes[row, col]
				if n == 0:
					# Metaplot
					if m == "PROcap":
						hue_order = ["maxTSS", "minTSS"]
						palette = ["#d7191c", "#2c7bb6"]
					else:
						hue_order = ["Control", "TRE"]
						# Divergent (first column) vs. unidirectional
						palette = ["#969696", "#313695"] if col == 0 else ["#969696", "#de77ae"]

					df = dfs_metaplot[m][dfs_metaplot[m]["Group"]==group]
					utils.generate_feature_metaplot(df, palette, hue_order, ax, test)

					# Y axis: label and tick labels only in the first column
					ax.set_ylim(feature_ylim)
					ax.set_yticks(feature_yticks)
					if col == 0:
						ax.set_ylabel(m, fontsize=fontsize, fontweight="bold")
						ax.tick_params(labelsize=labelsize, pad=labelpad)
						ax.get_yaxis().set_label_coords(y_align, 0.5)
					else:
						ax.set_ylabel("")
						ax.set_yticklabels([])

					if row == 0:
						ax.set_title(p.capitalize(), fontsize=40, fontweight="bold", pad=20)

					# Legend
					if (row == 0 and col == 1) or row == 2:
						ax.legend(loc="upper right", fontsize=16)
					else:
						ax.legend([],[], frameon=False)
				else:
					# Heatmap; one colorbar per row, drawn next to the last group column
					cbar = col == n_groups-1
					cbar_ax = axes[row, cbar_col] if cbar else None
					cbar_kws = {"ticks": feature_yticks}

					df = dfs_heatmap[m][dfs_heatmap[m]["Group"]==group]
					utils.generate_feature_heatmap(df, sort_file, feature_yticks, feature_cmap, cbar, cbar_ax, cbar_kws, ax, test)

					if cbar:
						cbar_ax.set_yticklabels(feature_yticks)
						cbar_ax.tick_params(axis="y", labelsize=labelsize, pad=labelpad)

				# X axis: tick labels and axis label only on the bottom row
				ax.set_xlim([0, 1000])
				ax.set_xticks(xtick_list)
				if row == n_rows-1:
					ax.set_xticklabels(xticklabel_list)
					ax.set_xlabel(xlabel, fontsize=fontsize, fontweight="bold")
					ax.tick_params(labelsize=labelsize, pad=labelpad)
				else:
					ax.set_xticklabels([])
					ax.set_xlabel("")

				# Metaplot rows have no colorbar, so hide their cell in the colorbar column.
				# This must index the colorbar column (len(ps)); len(ks) only matched it by coincidence.
				if row % 2 == 0:
					axes[row, cbar_col].set_visible(False)

	fig.subplots_adjust(wspace=wspace, hspace=hspace)
	fig.savefig(outputfile, bbox_inches = 'tight', dpi=300)

In [32]:
pwpool = ProcessWrapPool(2)

# Set to True for a test run; the flag is passed through to the utils plotting helpers.
test = False
sort_file = f"{PROJECT_DIR_o}labels/{s}_distance.txt"
# (ylims, yticks, output path) per distance class. Inner lists follow the order of `ms`
# (PROcap, DNase, H3K27ac). The distal grid is Fig. 1d; the proximal one is Supp. Fig. 1c.
figure_settings = {
	"distal": (
		[[-50, 90], [-0.5, 5], [-2, 20]],
		[[-40, 0, 40, 80], [0, 1.5, 3, 4.5], [0, 6, 12, 18]],
		PROJECT_DIR_o + "figures/Fig1d.png",
	),
	"proximal": (
		[[-300, 550], [-0.5, 6.5], [-2, 26]],
		[[-250, 0, 250, 500], [0, 2, 4, 6], [0, 8, 16, 24]],
		PROJECT_DIR_o + "supp_figures/SuppFig1c.png",
	),
}
for d in ds:
	# Any class other than "distal" falls back to the proximal settings.
	ylims, yticks, outputfile = figure_settings["distal" if d == "distal" else "proximal"]
	pwpool.run(generate_metaplot_and_heatmap, args=[d, ms, sort_file, ylims, yticks, cmaps[d], test, outputfile])

In [35]:
# Expected: one figure job per distance class in `ds` (2)
n_finished = len(pwpool.finished_tasks)
n_finished

2

In [36]:
# Shut down the plotting worker pool once both figure jobs have finished (2 above).
pwpool.close()