In [None]:
import os
# Remember the directory the notebook started in the first time this cell runs,
# so that re-running the cell does not keep climbing one level per execution.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
# Work from the parent directory: relative paths used below
# (e.g. "analysis/newfitness.h5") are resolved from there.
os.chdir(os.path.dirname(NOTEBOOK_DIR))

In [None]:
import h5py
import pandas as pd
import numpy as np
import seaborn as sb
import matplotlib as mpl
import matplotlib.pyplot as plt
from glob import glob

In [None]:
from TravelAndMutate.datamanager import getHeteroAttributes, collectAttributeFromGroup, filterGroupmembersWithParams
from TravelAndMutate.trees import Tree

In [None]:
# Open the simulation results read-only; the cells below only read from it.
# The mode is explicit because h5py < 3.0 defaulted to "a" (read/write,
# create if missing), which could silently create an empty file on a bad path.
file = h5py.File("analysis/newfitness.h5", "r")

In [None]:
# Attribute table of the simulation groups in `file` (index = group names,
# as used with .loc and file[...] below), ordered by
# fitness_delta-, then fitness_delta+, then fitness_p.
attrs = (
	getHeteroAttributes(file)
	.sort_values(["fitness_delta-", "fitness_delta+", "fitness_p"])
)
attrs

In [None]:
# Quick overview: the distinct values taken by every attribute column.
for name, column in attrs.items():
	print(name, column.unique())

In [None]:
def _plotMetricHeatmap(data, ax, alongrows, alongcolumns, value, cmapname, title,
		norm=None, mask_zeros=False, invert_y=False):
	"""Draw column `value` of `data` as a heatmap on `ax` (shared by the plot* functions).

	Parameters
	----------
	data : pandas.DataFrame
		One row per simulation group; must contain the columns
		`alongrows`, `alongcolumns` and `value`.
	ax : matplotlib Axes to draw on.
	alongrows : column whose values become the heatmap x-axis (pivot columns).
	alongcolumns : column whose values become the heatmap y-axis (pivot index).
	value : column holding the metric to colour-code.
	cmapname : name of a matplotlib colormap.
	title : axes title.
	norm : optional matplotlib Normalize; not passed to seaborn at all when None.
	mask_zeros : if True, cells equal to 0 are masked in addition to NaN cells.
	invert_y : if True, flip the y-axis after drawing.

	Returns
	-------
	The same `ax`.

	Axis labels are looked up in the module-level `labels` dict.
	"""
	toplot = data.pivot(index=alongcolumns, columns=alongrows, values=value)
	# NaN (and masked) cells are shown with the colormap's "bad" colour, yellow here.
	cmap = mpl.colormaps[cmapname].with_extremes(bad="yellow")
	heatmap_kwargs = {"ax": ax, "cmap": cmap}
	if mask_zeros:
		heatmap_kwargs["mask"] = (toplot == 0)
	if norm is not None:
		heatmap_kwargs["norm"] = norm
	sb.heatmap(toplot, **heatmap_kwargs)
	ax.set_xlabel(labels[alongrows])
	ax.set_ylabel(labels[alongcolumns])
	ax.set_title(title)
	if invert_y:
		ax.invert_yaxis()
	return ax

# Each wrapper builds its own Normalize per call: a Normalize without limits
# autoscales on first use, so sharing one instance would leak limits between panels.

def plotSurvived(data, ax, alongrows, alongcolumns):
	"""Heatmap of "survivalrate" (fraction of survived epidemics) on a [0, 1] scale; zero cells are masked."""
	return _plotMetricHeatmap(data, ax, alongrows, alongcolumns, "survivalrate", "Greys",
		"Fraction of survived epidemics",
		norm=mpl.colors.Normalize(vmax=1,vmin=0), mask_zeros=True)

def plotInfByHaplosMax(data, ax, alongrows, alongcolumns):
	"""Heatmap of "InfByHaplos_max" on a [0, 1] scale."""
	return _plotMetricHeatmap(data, ax, alongrows, alongcolumns, "InfByHaplos_max", "Reds",
		"Fraction of N° infections caused by most infectious haplo",
		norm=mpl.colors.Normalize(vmax=1,vmin=0))

def plotTreeDepthMean(data, ax, alongrows, alongcolumns):
	"""Heatmap of "TreeDepth_mean" with an autoscaled colour range."""
	return _plotMetricHeatmap(data, ax, alongrows, alongcolumns, "TreeDepth_mean", "Greens",
		"Mean depth of full mutation tree",
		norm=mpl.colors.Normalize())

def plotNChildrenMax(data, ax, alongrows, alongcolumns):
	"""Heatmap of "nChildren_max" on a [0, 1] scale."""
	return _plotMetricHeatmap(data, ax, alongrows, alongcolumns, "nChildren_max", "Blues",
		"Fraction of N° children generated by best parent haplo",
		norm=mpl.colors.Normalize(vmax=1,vmin=0))

def plotB2Norm(data, ax, alongrows, alongcolumns):
	"""Heatmap of "B2Norm" (normalized B2 index) on a [0, 1] scale."""
	return _plotMetricHeatmap(data, ax, alongrows, alongcolumns, "B2Norm", "Oranges",
		r"$B_2$-index normalized",
		norm=mpl.colors.Normalize(vmax=1,vmin=0))

def plotCopheneticNorm(data, ax, alongrows, alongcolumns):
	"""Heatmap of "CopheneticNorm" with seaborn's default colour range and a flipped y-axis."""
	# NOTE(review): this is the only panel that inverts the y-axis, so its y direction
	# is opposite to the other panels of the same figure -- confirm this is intended.
	return _plotMetricHeatmap(data, ax, alongrows, alongcolumns, "CopheneticNorm", "Purples",
		r"Cophenetic-index normalized",
		invert_y=True)

In [None]:
# Metric name (a field of each group's "single_quantities" dataset and a column
# added to `data`) -> function drawing its panel; order = panel order in the figures.
metrics = dict([
	("survivalrate", plotSurvived),
	("InfByHaplos_max", plotInfByHaplosMax),
	("TreeDepth_mean", plotTreeDepthMean),
	("nChildren_max", plotNChildrenMax),
	("B2Norm", plotB2Norm),
	("CopheneticNorm", plotCopheneticNorm),
])
# LaTeX labels for parameter columns, used for axis labels and figure titles.
labels = dict([
	("fitness_p", r"$p_{\Delta\varphi^-}$"),
	("fitness_delta+", r"$\Delta\varphi^+$"),
	("fitness_delta-", r"$\Delta\varphi^-$"),
	("mutation_rate", r"mean($\eta$)"),
])

# Combination 1 ($\Delta\varphi^- = -0.1$, mean $\eta = 0.015$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.1, "mutation_rate": 0.015}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 2 ($\Delta\varphi^- = -0.1$, mean $\eta = 0.025$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.1, "mutation_rate": 0.025}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 3 ($\Delta\varphi^- = -0.4$, mean $\eta = 0.015$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.4, "mutation_rate": 0.015}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 4 ($\Delta\varphi^- = -0.4$, mean $\eta = 0.025$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.4, "mutation_rate": 0.025}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 5 ($\Delta\varphi^- = -0.4$, mean $\eta = 0.008$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.4, "mutation_rate": 0.008}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 6 ($\Delta\varphi^- = -0.1$, mean $\eta = 0.008$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.1, "mutation_rate": 0.008}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 7 ($\Delta\varphi^- = -0.4$, mean $\eta = 0.005$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.4, "mutation_rate": 0.005}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

# Combination 8 ($\Delta\varphi^- = -0.1$, mean $\eta = 0.005$)

In [None]:
# Parameters held fixed for this figure; every key must also appear in `labels`.
fixed_params = {"fitness_delta-": -0.1, "mutation_rate": 0.005}

In [None]:
groupnames = filterGroupmembersWithParams(file, fixed_params)

# .copy(): metric columns are added below, so `data` must be an independent frame, not a slice of `attrs`.
data = attrs.loc[groupnames].copy()
if data.empty:
	raise ValueError(f"No simulation group matches {fixed_params}")

# Mean of each metric field of the group's "single_quantities" dataset;
# groups without that dataset get NaN.
for metric in metrics:
	data[metric] = [
		file[groupname]["single_quantities"].fields(metric)[:].mean()
		if "single_quantities" in file[groupname]
		else np.nan  # the np.NaN alias was removed in NumPy 2.0
		for groupname in data.index
	]

In [None]:
nrows = int(np.ceil(len(metrics) / 2))
fig, axs = plt.subplots(nrows, 2, figsize=(10,2+3*nrows))
axs = axs.flatten()

# Suptitle: one "<label> = <value>" line per fixed parameter.
title = "\n".join(f"{labels[key]} = {val}" for key, val in fixed_params.items())
fig.suptitle(title, y=1)

# One heatmap panel per metric: x = fitness_delta+, y = fitness_p.
for ax, plotfunc in zip(axs, metrics.values()):
	plotfunc(data, ax, "fitness_delta+", "fitness_p")
# Hide the spare panel left over when the number of metrics is odd.
for ax in axs[len(metrics):]:
	ax.set_visible(False)

fig.tight_layout()
# savefig does not create missing directories.
outdir = "analysis/single_quantities/"
os.makedirs(outdir, exist_ok=True)
# NOTE(review): the name joins every group name and grows with the grid size;
# it may exceed the filesystem's filename-length limit for large sweeps.
filename = outdir + "-".join(data.index) + ".png"
fig.savefig(filename)

In [None]:
# Close the HDF5 file opened at the start of the notebook; run this only after
# every cell that reads from `file`.
file.close()