In [1]:
import os
import pickle
import pandas as pd
import numpy as np
from itertools import combinations
import matplotlib.pyplot as plt
from numpy import nan

In [2]:
# Result pickles of the heatmap sweep. os.listdir returns names in arbitrary
# order, and that order fixes the row order of result_df (and hence which row
# is picked later if a parameter combination appears more than once), so sort
# the names to make every run reproducible.
files = sorted(x for x in os.listdir('heatmap_ERF_omicron/') if 'results' in x)

In [3]:
# Concatenate the 'results' records of every result pickle into one DataFrame.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
all_results = []
for fname in files:
    with open('heatmap_ERF_omicron/' + fname, 'rb') as fh:
        all_results.extend(pickle.load(fh)['results'])

result_df = pd.DataFrame(all_results)

In [4]:
# Despite the .txt extension this file is read as a pickle (trusted source only).
# NOTE(review): param_grid is not used in the visible cells of this notebook.
paramgrid_path = 'heatmap_ERF_omicron/paramgrid_params_omicron.txt'
with open(paramgrid_path, 'rb') as fh:
    param_grid = pickle.load(fh)

In [5]:
# Sorted distinct values of 'ap' present in the results.
# (The next cell recomputes along_ap and removes its baseline value.)
along_ap = sorted(set(result_df['ap']))

In [6]:
# For each parameter: its sorted distinct values in the results, minus its
# baseline value (the value it is held at when it is not on a heatmap axis).
all_names_params = ['m', 'n', 'ap', 'ttr', 'qd']
_baseline_vals = {'m': 20000, 'n': 100, 'ap': 6.03, 'ttr': 5, 'qd': 1-3e-6}

_sweeps = {}
for _p in all_names_params:
    _vals = sorted(set(result_df[_p]))
    _vals.remove(_baseline_vals[_p])  # ValueError if the baseline is absent
    _sweeps[_p] = _vals

along_m = _sweeps['m']
along_n = _sweeps['n']
along_ap = _sweeps['ap']
along_ttr = _sweeps['ttr']
along_qd = _sweeps['qd']
all_along_params = [(_p, _sweeps[_p]) for _p in all_names_params]
del _baseline_vals, _sweeps, _p, _vals

In [7]:
# Every unordered pair of swept parameters defines one heatmap.
# all_colormap_names[i] holds just the two parameter names of all_colormap_params[i].
all_colormap_params = list(combinations(all_along_params, 2))
all_colormap_names = [(first[0], second[0]) for first, second in all_colormap_params]

In [8]:
# Baseline value of each parameter: any parameter that is not on a heatmap's
# two axes is held at this value when selecting results.
dict_other_vals = dict(m=20000, n=100, ap=6.03, qd=1-3e-6, ttr=5)

#### 1) Generating all matrices for colormaps

In [None]:
# Build one matrix per parameter pair: entry [i, j] is Y_treatment / Y_control
# for the i-th value of the first parameter and the j-th value of the second,
# with every other parameter held at its baseline (dict_other_vals).
all_matrices_for_cmap = []

for comb in all_colormap_params:
    (x_name, x_values), (y_name, y_values) = comb
    params_axes = (x_name, y_name)
    print(params_axes)
    other_params = set(all_names_params) - set(params_axes)

    # Hold every parameter that is not on an axis at its baseline value. This
    # filter does not depend on the grid cell, so apply it once per colormap.
    df_fixed = result_df
    for p in other_params:
        df_fixed = df_fixed[df_fixed[p] == dict_other_vals[p]]

    # Rows follow the first parameter's values, columns the second's.
    colormap_params = np.zeros((len(x_values), len(y_values)))

    for xind, x_val in enumerate(x_values):
        # The column count comes from the SECOND parameter's values.
        for yind, y_val in enumerate(y_values):
            df_axes = df_fixed[(df_fixed[x_name] == x_val) & (df_fixed[y_name] == y_val)]
            if df_axes.empty:
                raise ValueError(
                    f"No result for {x_name}={x_val!r}, {y_name}={y_val!r} "
                    "with the other parameters at their baseline values")
            # If several rows match, the first one is used.
            first = df_axes.iloc[0]
            colormap_params[xind, yind] = float(first['Y_treatment']) / float(first['Y_control'])

    all_matrices_for_cmap.append(colormap_params)

('m', 'n')
('m', 'ap')
('m', 'ttr')


#### 2) Plotting all 10 colormaps (one per pair of swept parameters)

##### Subplot positions for each colormap

In [None]:
# Colormap index k (0..9) -> (row, column) of its panel in the 5x2 subplot
# grid, filled row by row.
indices_colormaps = {k: divmod(k, 2) for k in range(10)}

##### Tick positions and labels for the axes

In [None]:
def calculate_ticks(ticklabels, truerange, tickrange):
    """Linearly map tick values from a data range onto a position range.

    A value equal to truerange[0] lands on tickrange[0] and a value equal to
    truerange[1] lands on tickrange[1]; other values are interpolated (or
    extrapolated) on the straight line through those two points.

    Parameters
    ----------
    ticklabels : sequence of numbers
        Tick values to place.
    truerange : (lo, hi)
        Data values corresponding to the two ends of tickrange (must differ).
    tickrange : (start, end)
        Positions (e.g. heatmap pixel indices) the ends of truerange map to.

    Returns
    -------
    numpy.ndarray
        One position per tick value.
    """
    values = np.asarray(ticklabels)
    data_lo, data_hi = truerange[0], truerange[1]
    pos_lo, pos_hi = tickrange[0], tickrange[1]
    return ((values - data_lo) * (pos_hi - pos_lo)) / (data_hi - data_lo) + pos_lo

def calculate_ticks_log(ticklabels, truerange, tickrange):
    """Map tick values onto a position range on a logarithmic scale.

    A value equal to truerange[0] lands on tickrange[0] and a value equal to
    truerange[1] lands on tickrange[1]; other values are placed linearly in
    log(value), so the geometric mean of the two bounds lands halfway.

    NOTE(review): this is the right mapping only if the swept values (qd here)
    are log-spaced on the heatmap grid -- confirm against param_grid.

    Parameters
    ----------
    ticklabels : sequence of float
        Tick values to place; must be > 0 (log of a non-positive value is nan/-inf).
    truerange : (lo, hi)
        Data values corresponding to the two ends of tickrange; both > 0 and distinct.
    tickrange : (start, end)
        Positions (e.g. heatmap pixel indices) the ends of truerange map to.

    Returns
    -------
    numpy.ndarray of float
        One position per tick value.
    """
    # Interpolate the logs (not the raw values) between the log bounds.
    log_ticks = np.log(np.asarray(ticklabels, dtype=float))
    log_lo = np.log(truerange[0])
    log_hi = np.log(truerange[1])
    span = tickrange[1] - tickrange[0]
    return (log_ticks - log_lo) * span / (log_hi - log_lo) + tickrange[0]

In [None]:
# Tick values drawn along each parameter's heatmap axis.
# NOTE: the 'qd' ticks are qd values, whereas param_ranges['qd'] is expressed
# in 1 - qd (the quantity printed on the qd tick labels); the tick code
# converts with 1 - ticks before mapping.
param_ticks = dict(
    m=[1500, 5000, 10000, 15000, 20000, 25000, 29900],
    n=[1, 200, 400, 600, 800, 1000],
    ap=[7.7, 8.0, 8.3, 8.6, 8.8],
    qd=[1-1e-6, 1-3e-6, 1-5e-6, 1-7.5e-6, 1-1e-5],
    ttr=[0, 3, 7, 10, 14],
)

# Data value at each end of a parameter's heatmap axis (low end, high end).
param_ranges = dict(
    m=[1500, 29900],
    n=[1, 1000],
    ap=[7.7, 8.8],
    qd=[1e-6, 1e-5],
    ttr=[0, 14],
)

In [None]:
# Tick positions and labels for every heatmap. In each pair (x_param, y_param)
# of all_colormap_names, x_param indexes the matrix rows and is drawn on the
# plot's VERTICAL axis; y_param indexes the columns (horizontal axis).
# NOTE(review): positions are mapped onto pixel indices 0..19, which assumes
# 20 swept values per parameter -- TODO confirm against the along_* lists.
_PIXEL_RANGE = [0, 19]


def _axis_tick_positions(param):
    """Pixel positions of param's ticks along a heatmap axis, in tick order."""
    ticks = np.array(param_ticks[param])
    if param == 'qd':
        # qd is placed in 1 - qd units, the units param_ranges['qd'] uses.
        return calculate_ticks_log(1 - ticks, param_ranges[param], _PIXEL_RANGE)
    return calculate_ticks(ticks, param_ranges[param], _PIXEL_RANGE)


def _axis_tick_labels(param):
    """Tick labels for param, in the same order as param_ticks[param]."""
    if param == 'qd':
        return ["{:.1e}".format(1 - x) for x in np.array(param_ticks[param])]
    if param == 'ap':
        return [np.round(x, 2) for x in param_ticks[param]]
    return [int(np.round(x)) for x in param_ticks[param]]


all_xticks, all_yticks = [], []
all_xticklabels, all_yticklabels = [], []

for x_param, y_param in all_colormap_names:
    # The rows are plotted flipped (matrix[::-1]), so mirror the row-parameter
    # positions within 0..19 and reverse them to keep them ascending.
    all_xticks.append(_PIXEL_RANGE[1] - _axis_tick_positions(x_param)[::-1])
    all_yticks.append(_axis_tick_positions(y_param))
    all_xticklabels.append(_axis_tick_labels(x_param))
    all_yticklabels.append(_axis_tick_labels(y_param))

In [None]:
# Human-readable axis label for each swept parameter.
params_description = dict(
    m='Number of lethal positions $m$',
    n='Number of gain-of-function positions $n$',
    ap='Clearance rate in clearance phase $a_1$',
    qd='Viral mutation rate under treatment $u_1$',
    ttr='Time of start of treatment',
)

In [None]:
# One heatmap per parameter pair on a 5x2 grid (panel positions from
# indices_colormaps). Colour = Y_treatment / Y_control on the 'seismic' map;
# with vmin=0 and vmax=2, a ratio of 1 sits at the centre of the colour scale.
fig, ax = plt.subplots(5, 2, figsize=(8, 18))

for i, (row_axis, col_axis) in enumerate(all_colormap_params):
    panel = ax[indices_colormaps[i]]
    row_param, col_param = row_axis[0], col_axis[0]

    # Reverse the rows so the smallest value of the row parameter is at the bottom.
    image = all_matrices_for_cmap[i][::-1]
    if col_param == 'qd':
        # qd columns are in ascending qd, i.e. descending 1 - qd; flip them so
        # they line up with the ascending 1 - qd tick positions.
        image = np.flip(image, 1)
    # Each panel is drawn exactly once.
    im = panel.imshow(image, cmap='seismic', vmin=0., vmax=2)

    # Matrix columns (second parameter of the pair) run horizontally ...
    panel.set_xticks(all_yticks[i])
    panel.set_xticklabels(all_yticklabels[i])
    # ... and matrix rows (first parameter) vertically; their positions were
    # mirrored for the row flip, so the labels are reversed to match.
    panel.set_yticks(all_xticks[i])
    panel.set_yticklabels(all_xticklabels[i][::-1])

    panel.set_ylabel(params_description[row_param])
    panel.set_xlabel(params_description[col_param])

plt.tight_layout(pad=1.0)
#plt.savefig('colormaps_gradualclearance_all_params_FEB1.jpg', dpi = 600)