In [None]:
# NOTE(review): this cell duplicates the 3D scatter at the end of the notebook but
# sits before the imports and before `infom` is built, so on a fresh kernel it
# raised NameError and halted "Run All". Only draw it once both names exist.
if 'px' in globals() and 'infom' in globals():
    fig = px.scatter_3d(
        infom,
        x='sensitivity_wrt_species-6' + '_null', y='precision_wrt_species-6' + '_null', z='sp_distance',
        color='mutation_num')
    fig.show()



In [None]:
# Re-import edited project modules (e.g. src.*) automatically before each cell runs,
# so changes to helper code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import os
import corner
import sys
from copy import deepcopy

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import plotly.express as px


if __package__ is None:
    # Running as a notebook rather than as part of a package: put the parent
    # directory (presumably the repository root) on sys.path so the `src.*` and
    # `tests_local.*` imports below resolve, and record its name as the package.
    module_path = os.path.abspath('..')
    sys.path.append(module_path)
    __package__ = os.path.basename(module_path)


from src.srv.sequence_exploration.sequence_analysis import b_tabulate_mutation_info
from src.utils.data.data_format_tools.common import load_json_as_dict
from src.utils.common.setup_new import prepare_config
from src.utils.circuit.agnostic_circuits.circuit_manager_new import CircuitModeller
from src.utils.evolution.evolver import Evolver
from src.utils.misc.numerical import count_monotonic_group_lengths, find_monotonic_group_idxs, is_within_range
from src.utils.misc.string_handling import string_to_tuple_list
from src.utils.misc.type_handling import flatten_listlike, get_first_elements
from src.utils.results.analytics.naming import get_analytics_types_all, get_true_names_analytics, get_true_interaction_cols
from tests_local.shared import create_test_inputs, CONFIG, TEST_CONFIG, five_circuits, mutate, simulate

# Alternative config source:
#   config = load_json_as_dict('../tests_local/configs/simple_circuit.json')
SEQ_LENGTH = 20  # NOTE(review): not referenced in the visible cells; confirm it is still needed
# Private copy, so anything this notebook changes in `config` never leaks back
# into the shared CONFIG object imported from tests_local.shared.
config = deepcopy(CONFIG)

In [None]:
# Tabulated per-mutation results. The run directory name looks like a timestamp
# (2023-04-11 19:20:13), presumably identifying the ensemble run analysed here.
fn = '../data/ensemble_mutation_effect_analysis/2023_04_11_192013/summarise_simulation/tabulated_mutation_info.csv'
info = pd.read_csv(filepath_or_buffer=fn)

## Process summary

### Define new fields

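Define a distance metric to the robust adaptation region (sensitivity ≥ 1 and precision ≥ 10): `sp_distance` is 0 inside the region and grows with the shortfall outside it.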

In [None]:
# Robust-adaptation region: sensitivity >= SENSITIVITY_MIN and precision >= PRECISION_MIN.
# sp_distance is the Euclidean distance from each point to that quadrant:
#   - 0 inside it;
#   - the gap along one axis when only one criterion fails;
#   - the distance to the corner (SENSITIVITY_MIN, PRECISION_MIN) when both fail.
# The column is float from the start (assigning floats into an int column via .loc
# is deprecated in pandas). Missing metrics now give NaN instead of silently
# counting as "inside the region" (distance 0), which ranked them as best circuits.
SENSITIVITY_MIN = 1
PRECISION_MIN = 10

# Shortfall below each threshold, clipped at 0 when the threshold is met.
# np.maximum propagates NaN.
sensitivity_gap = np.maximum(SENSITIVITY_MIN - info['sensitivity_wrt_species-6'], 0)
precision_gap = np.maximum(PRECISION_MIN - info['precision_wrt_species-6'], 0)
info['sp_distance'] = np.sqrt(np.square(sensitivity_gap) + np.square(precision_gap))


Binding-site groups and their locations for each species pair

In [None]:
# Column names for the per-species-pair binding-site summaries. They are derived
# from the energies columns, so every list below follows the same species-pair order.
energy_cols = get_true_interaction_cols(info, 'energies')
num_group_cols = [col.replace('energies', 'binding_sites_groups') for col in energy_cols]
num_bs_cols = [col.replace('energies', 'binding_sites_count') for col in energy_cols]
bs_idxs_cols = [col.replace('energies', 'binding_sites_idxs') for col in energy_cols]
bs_range_cols = [col.replace('energies', 'binding_site_group_range') for col in energy_cols]

pair_columns = zip(get_true_interaction_cols(info, 'binding_sites'),
                   num_group_cols, num_bs_cols, bs_idxs_cols, bs_range_cols)
for sites_col, groups_col, count_col, idxs_col, range_col in pair_columns:
    # Parse the stored binding-site strings into tuple lists and take their first
    # elements (via get_first_elements; empty entries become []).
    parsed_sites = [string_to_tuple_list(raw) for raw in info[sites_col]]
    first = get_first_elements(parsed_sites, empty_replacement=[])
    info[count_col] = [count_monotonic_group_lengths(f) for f in first]
    info[idxs_col] = [find_monotonic_group_idxs(f) for f in first]
    info[groups_col] = info[count_col].apply(len)
    # (first, last) entry of every group. These ranges are what the mutation
    # position checks further down compare against.
    info[range_col] = [[(run[0], run[-1]) for run in runs] for runs in info[idxs_col]]

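Ratiometric change relative to the unmutated circuit: within each (circuit, sample) group, every numerical column becomes $|\log(x_0 / x)|$, where $x_0$ is the value in the `mutation_num == 0` row.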

In [None]:
# Rows used to probe each column's type: a mutated circuit (mutation_num > 0) with
# eqconstants_0-0 > 1. The filter is computed once here instead of re-filtering
# `info` for every column.
ref_rows = info[(info['mutation_num'] > 0) & (info['eqconstants_0-0'] > 1)]
binding_site_cols = set(get_true_interaction_cols(info, 'binding_sites'))
numerical_cols = [
    c for c in info.columns
    if type(ref_rows[c].iloc[0]) != str
    and type(info[c].iloc[0]) != list
    and c not in binding_site_cols
]
key_cols = ['circuit_name', 'interacting', 'mutation_name', 'name', 'sample_name']

# One group per (circuit, sample).
grouped = info.groupby(['circuit_name', 'sample_name'], as_index=False)

# Absolute log-ratio of every numerical column relative to the group's
# mutation_num == 0 row: |log(x_ref / x)|. This assumes exactly one reference row
# per group: squeeze() turns it into a Series that broadcasts across the group.
mutation_log = grouped[numerical_cols].apply(
    lambda x: np.abs(np.log(x.loc[x['mutation_num'] == 0].squeeze() / x)))

# Put back raw values for mutation_num (its own log-ratio is meaningless), RMSE
# and sp_distance, plus the identifying key columns.
# NOTE(review): assignment aligns on the row index; this relies on mutation_log
# keeping info's index after groupby.apply. Confirm with the installed pandas.
for c in ['mutation_num', 'RMSE', 'sp_distance'] + key_cols:
    mutation_log[c] = info[c]

### Expand DataFrames

Melt energies:

In [None]:
# Scalar (non-per-species-pair) columns, kept as id_vars when the per-pair columns
# are melted into long format below. A set filter replaces the previous list
# comprehension that was evaluated only for its .remove() side effects. The
# original column order is preserved.
per_pair_cols = set(
    get_true_interaction_cols(info, 'binding_rates_dissociation')
    + get_true_interaction_cols(info, 'eqconstants')
    + get_true_interaction_cols(info, 'energies')
    + get_true_interaction_cols(info, 'binding_sites')
    + num_group_cols
    + num_bs_cols
)
good_cols = [c for c in info.columns if c not in per_pair_cols]

# Species-pair suffix of each energies column (e.g. '0-0') -> its position in
# the energies column list.
binding_idx_map = {e.replace('energies_', ''): i for i, e in enumerate(get_true_interaction_cols(info, 'energies'))}


In [None]:
# Long format: one row per (info row, species pair). Every melt below uses the same
# id_vars, and value_vars that follow the energies-column order. melt stacks
# value_vars in order, so the melted frames line up row for row and their columns
# can be copied across by index.
infom = info.melt(good_cols, value_vars=get_true_interaction_cols(info, 'energies'),
                  var_name='energies_idx', value_name='energies')

groups_long = info.melt(good_cols, value_vars=num_group_cols,
                        var_name='num_groups_idx', value_name='num_groups')
infom['idx_species_binding'] = groups_long['num_groups_idx'].apply(
    lambda col_name: binding_idx_map[col_name.replace('binding_sites_groups_', '')])
infom['num_groups'] = groups_long['num_groups']

counts_long = info.melt(good_cols, value_vars=num_bs_cols,
                        var_name='num_bs_idx', value_name='num_bs')
infom['num_bs'] = counts_long['num_bs']

for k in ['binding_sites', 'binding_rates_dissociation', 'eqconstants']:
    infom[k] = info.melt(good_cols, value_vars=get_true_interaction_cols(info, k),
                         var_name=f'{k}_idx', value_name=k)[k]

Melt mutation logs:

In [None]:
# id_vars for the log-ratio melt: numerical and key columns, minus the
# per-species-pair columns that get melted. The exclusion set is built once.
per_pair_log_cols = set(
    get_true_interaction_cols(mutation_log, 'energies')
    + get_true_interaction_cols(mutation_log, 'binding_rates_dissociation')
    + get_true_interaction_cols(mutation_log, 'eqconstants')
    + get_true_interaction_cols(mutation_log, 'binding_sites_groups')
)
mutation_cols = [c for c in numerical_cols + key_cols if c not in per_pair_log_cols]
mutation_logm = mutation_log.melt(mutation_cols, value_vars=get_true_interaction_cols(mutation_log, 'energies'),
                                  var_name='energies_idx', value_name='energies')

for k in ['binding_rates_dissociation', 'eqconstants', 'binding_sites_groups']:
    long_k = mutation_log.melt(mutation_cols, value_vars=get_true_interaction_cols(mutation_log, k),
                               var_name=f'{k}_idx', value_name=k)
    mutation_logm[k] = long_k[k]

# Copy the species-pair annotations from infom by index (both frames get a fresh
# RangeIndex from melt).
# NOTE(review): this assumes mutation_log rows are in the same order as info rows.
# Confirm that the groupby.apply that built mutation_log preserves that order.
for c in ['idx_species_binding', 'num_groups', 'num_bs']:
    mutation_logm[c] = infom[c]

Differences in energies, dissociation rates and equilibrium constants relative to the first row of each circuit

In [None]:
def _diffs_from_first_row(frame, prefix):
    """Long-format differences of the `prefix` interaction columns from each circuit's first row.

    Within every circuit_name group, the group's first row is subtracted from each
    row. The per-species-pair columns are then stacked (melted), in
    get_true_interaction_cols order, into a single '<prefix>_diffs' Series.
    NOTE(review): the reference is whichever row comes first in the group,
    presumably the unmutated circuit. Confirm the row ordering guarantees this.
    """
    cols = get_true_interaction_cols(frame, prefix)
    diff_name = f'{prefix}_diffs'
    relative = frame.groupby(['circuit_name'])[cols].apply(lambda grp: grp - grp.iloc[0])
    return relative.melt(value_vars=cols, var_name='idx', value_name=diff_name)[diff_name]


for k in ['binding_rates_dissociation', 'eqconstants', 'energies']:
    infom[f'{k}_diffs'] = _diffs_from_first_row(info, k)
    mutation_logm[f'{k}_diffs'] = _diffs_from_first_row(mutation_log, k)


### Standard deviations

In [None]:
# Columns summarised per (circuit, mutation count, sample).
# Commented-out entries are candidate metrics that are currently excluded.
relevant_cols = [
    'fold_change',
    # 'initial_steady_states',
    # 'max_amount', 'min_amount',
    'overshoot',
    'RMSE',
    'steady_states',
    # 'response_time_wrt_species-6',
    # 'response_time_wrt_species-6_diff_to_base_circuit',
    # 'response_time_wrt_species-6_ratio_from_mutation_to_base',
    'precision_wrt_species-6',
    'precision_wrt_species-6_diff_to_base_circuit',
    'precision_wrt_species-6_ratio_from_mutation_to_base',
    'sensitivity_wrt_species-6',
    'sensitivity_wrt_species-6_diff_to_base_circuit',
    'sensitivity_wrt_species-6_ratio_from_mutation_to_base',
    'fold_change_diff_to_base_circuit',
    # 'initial_steady_states_diff_to_base_circuit',
    # 'max_amount_diff_to_base_circuit', 'min_amount_diff_to_base_circuit',
    'overshoot_diff_to_base_circuit',
    # 'RMSE_diff_to_base_circuit',
    'steady_states_diff_to_base_circuit',
    'fold_change_ratio_from_mutation_to_base',
    # 'initial_steady_states_ratio_from_mutation_to_base',
    # 'max_amount_ratio_from_mutation_to_base',
    # 'min_amount_ratio_from_mutation_to_base',
    # 'overshoot_ratio_from_mutation_to_base',
    # 'RMSE_ratio_from_mutation_to_base',
    'steady_states_ratio_from_mutation_to_base',
    # 'num_groups',
    'energies',
    'binding_rates_dissociation',
    'eqconstants',
    'energies_diffs',
    'binding_rates_dissociation_diffs',
    'eqconstants_diffs',
]


def _std_over_mean_floor_one(x):
    """Population std (np.std, ddof=0) of `x` divided by max(1, mean(x)).

    Flooring the denominator at 1 means the ratio is never larger than the std itself.
    """
    return np.std(x) / np.max([1, np.mean(x)])


# For every relevant column: <col>_std, <col>_mean and <col>_std_normed_by_mean.
# NOTE(review): '<col>_std' uses pandas' sample std (ddof=1), while
# '<col>_std_normed_by_mean' uses np.std (ddof=0). Confirm the mismatch is intended.
named_aggs = {}
for col in relevant_cols:
    named_aggs[col + '_std'] = pd.NamedAgg(column=col, aggfunc="std")
    named_aggs[col + '_mean'] = pd.NamedAgg(column=col, aggfunc="mean")
    named_aggs[col + '_std_normed_by_mean'] = pd.NamedAgg(column=col, aggfunc=_std_over_mean_floor_one)
# Summary statistics of infom per (circuit, mutation count, sample).
# NOTE(review): info_summ is recomputed from mutation_logm in the
# "Mutational disruption" section, which overwrites this result.
info_summ = (
    infom
    .groupby(['circuit_name', 'mutation_num', 'sample_name'], as_index=False)
    .agg(**named_aggs)
)

# Visualisations

In [None]:
# Absolute fold-change log-ratio, one step histogram per mutation count (log-scaled counts).
sns.histplot(
    mutation_log, x='fold_change', hue='mutation_num',
    log_scale=[False, True], element='step', bins=100,
)


In [None]:
# Sensitivity and precision log-ratios against the raw sp_distance (mutation_log
# keeps sp_distance un-transformed).
fig = px.scatter_3d(
    mutation_log,
    x='sensitivity_wrt_species-6',
    y='precision_wrt_species-6',
    z='sp_distance',
    color='mutation_num',
)
fig.show()

In [None]:
import plotly.graph_objects as go

# np.histogram2d cannot auto-detect a range over inf/NaN, and the log-ratios in
# mutation_log can be infinite (log of a zero ratio). Keep only rows where both
# metrics are finite.
xv = np.asarray(mutation_log['sensitivity_wrt_species-6'], dtype=float)
yv = np.asarray(mutation_log['precision_wrt_species-6'], dtype=float)
finite = np.isfinite(xv) & np.isfinite(yv)
xv, yv = xv[finite], yv[finite]

# h = (counts with shape (50, 50), x bin edges, y bin edges)
h = np.histogram2d(xv, yv, bins=50)


def bin_centers(bins):
    """Return the midpoints of consecutive histogram bin edges.

    Works for uniform and non-uniform edges alike.

    Parameters
    ----------
    bins : array-like of float, shape (n + 1,)
        Monotonic bin edges, e.g. as returned by ``np.histogram2d``.

    Returns
    -------
    numpy.ndarray, shape (n,)
        Centre of each bin; empty when fewer than two edges are given.
    """
    edges = np.asarray(bins, dtype=float)
    return (edges[:-1] + edges[1:]) / 2


# Bin centres (the previous cell immediately overwrote them with left edges).
x_bins_centers = bin_centers(h[1])
y_bins_centers = bin_centers(h[2])

# Log counts. Empty bins become NaN (gaps in the surface) instead of log(0) = -inf.
counts = h[0]
log_counts = np.full(counts.shape, np.nan)
np.log(counts, out=log_counts, where=counts > 0)

# counts[i, j] is sensitivity bin i / precision bin j. go.Surface places z[row, col]
# at (x[col], y[row]), so transpose to put sensitivity on x and precision on y, and
# pass the bin centres as real coordinates.
fig = go.Figure(data=[go.Surface(z=log_counts.T, x=x_bins_centers, y=y_bins_centers)])
fig.update_layout(scene=dict(xaxis_title='sensitivity_wrt_species-6',
                             yaxis_title='precision_wrt_species-6',
                             zaxis_title='log(count)'))
fig.show()

In [None]:

# Dissociation-rate log-ratio per species pair, one step histogram per mutation count.
ax = sns.histplot(
    mutation_logm, x='binding_rates_dissociation', hue='mutation_num',
    element='step', log_scale=[False, True], bins=100,
)
ax.set_title('Ratiometric change from reference to mutated circuit')

## Mutational disruption

Parse the mutation positions and types (stored as stringified lists) and explode them so that each mutation gets its own row.

In [None]:

# Summary statistics of the log-ratio frame per (circuit, mutation count, sample).
info_summ = mutation_logm.groupby(
    by=['circuit_name', 'mutation_num', 'sample_name'], as_index=False,
).agg(**named_aggs)

def _parse_int_list(value):
    """Parse a stringified list such as '[3, 17]' into [3, 17].

    Values that are not strings are returned unchanged. These are already-parsed
    lists from a previous run of this cell, or NaN. Re-running the cell is
    therefore safe; previously it crashed on the second run.
    """
    if not isinstance(value, str):
        return value
    return [int(item) for item in value.strip('[]').split(',') if item]


for list_col in ['mutation_type', 'mutation_positions']:
    info[list_col] = info[list_col].apply(_parse_int_list)

# One row per individual mutation. A row with an empty list (presumably the
# unmutated circuit) is kept as a single row of NaN. explode() keeps info's index,
# so exploded rows can be grouped back onto their source row.
info_e = info.explode(column=['mutation_type', 'mutation_positions'])

# Per species pair, flag each mutation position that is within one of that pair's
# binding-site ranges (per is_within_range). Also flag positions exactly on the
# first/last entry of a range. The loop names `isb`, `ise` and `r` are kept
# because the sanity-check expression below this block reads them.
mut_in_bs_cols = [e.replace('energies', 'is_mutation_in_binding_site') for e in get_true_interaction_cols(info, 'energies')]
for isb, r in zip(mut_in_bs_cols, bs_range_cols):
    info_e[isb] = [
        any([is_within_range(pos, site_range) for site_range in site_ranges])
        for pos, site_ranges in zip(info_e['mutation_positions'], info_e[r])
    ]

mut_in_edge_cols = [e.replace('energies', 'is_mutation_on_edge') for e in get_true_interaction_cols(info, 'energies')]
for ise, r in zip(mut_in_edge_cols, bs_range_cols):
    info_e[ise] = [
        any([(pos == site_range[0]) or (pos == site_range[-1]) for site_range in site_ranges])
        for pos, site_ranges in zip(info_e['mutation_positions'], info_e[r])
    ]

# Sanity check: unmutated rows (mutation_num == 0) flagged as having a mutation in
# a binding site. It uses `isb` / `r` left over from the flag loops, so only the
# last species pair is checked.
# NOTE(review): as a mid-cell expression this result is never displayed. Move it
# to its own cell or wrap it in display(...) to inspect it.
info_e.loc[(info_e[isb] == True) & (info_e['mutation_num'] == 0), ['mutation_positions', isb, r]]

# Fraction of each source row's mutations that fall inside a binding site, per
# species pair.
#  - explode() kept info's index, so grouping on it (level=0) maps every exploded
#    row back to its source row. The result has exactly one row per info row, in
#    info order.
#  - The previous grouping by (circuit_name, mutation_num, sample_name) merged
#    different mutants with the same mutation count and sorted the groups, so the
#    melted series neither had infom's length nor its row order.
#  - mut_in_bs_cols follows the energies-column order used to melt infom, so the
#    long series below lines up row for row with infom.
frac_in_bs = info_e[mut_in_bs_cols].astype(float).groupby(level=0).mean()
infom['frac_muts_in_binding_site'] = frac_in_bs.melt(
    value_vars=mut_in_bs_cols, var_name='idx', value_name='frac_muts_in_binding_site',
)['frac_muts_in_binding_site']

# Mutated circuits only: binding-site hit fraction against interaction energy.
mutated = infom[infom['mutation_num'] > 0]
sns.jointplot(mutated, x='frac_muts_in_binding_site', y='energies', hue='mutation_num')



Corner plot of the mutation log-ratios

In [None]:
# Columns for the corner plot. Drop the diff/ratio-to-base variants, name and
# index columns, and the structural columns listed below.
skip_tags = ('diff_to_base_circuit', 'ratio_from_mutation', 'name', 'idx')
corner_cols = [c for c in mutation_logm.columns if not any(tag in c for tag in skip_tags)]
for excluded in ['num_bs', 'num_groups', 'final_deriv', 'energies', 'energies_diffs',
                 'num_interacting', 'num_self_interacting', 'binding_sites_groups', 'interacting']:
    corner_cols.remove(excluded)
# NOTE(review): this drops whichever column happens to be sixth at this point.
# Replace it with an explicit name so the plot does not silently change when
# columns are added or reordered.
corner_cols.remove(corner_cols[5])

# Exclude the RNA_0 sample and rows with infinite precision (or a -inf precision diff).
corner_data = mutation_logm[
    (mutation_logm['sample_name'] != 'RNA_0')
    & (mutation_logm['precision_wrt_species-6'] != np.inf)
    & (mutation_logm['precision_wrt_species-6_diff_to_base_circuit'] != -np.inf)
][corner_cols]
fig = corner.corner(corner_data, labels=corner_cols, show_titles=True, title_kwargs={"fontsize": 12}, color='b')


## Removing null circuit

In [None]:
# The "null" circuit: the first (circuit, mutation) whose mean interaction energy
# is exactly 0 (presumably no interactions at all). null_circuit holds all of its
# infom rows.
mean_energy = infom.groupby(['circuit_name', 'mutation_name'], as_index=False).agg({'energies': 'mean'})
null_key = mean_energy[mean_energy['energies'] == 0].iloc[0]
null_circuit = infom[
    (infom['circuit_name'] == null_key['circuit_name'])
    & (infom['mutation_name'] == null_key['mutation_name'])
]

# NOTE(review): null_cols is not used below; the normalisation loops over
# relevant_cols. Confirm which list was intended.
null_cols = [
    'overshoot',
    'RMSE',
    'steady_states',
    'initial_steady_states',
    'max_amount',
    'min_amount',
    'response_time_wrt_species-6',
]

# Express every relevant column as log(value / null-circuit value), matched on
# sample_name. The reference is the first null-circuit row for that sample.
info_null = deepcopy(infom)
sample_names = info['sample_name'].unique()
for s in sample_names:
    # Loop-invariant per sample: hoisted out of the column loop.
    in_sample = info_null['sample_name'] == s
    null_in_sample = null_circuit[null_circuit['sample_name'] == s]
    for c in relevant_cols:
        info_null.loc[in_sample, c] = np.log(info_null[in_sample][c] / null_in_sample[c].values[0])

for c in relevant_cols:
    infom[c + '_null'] = info_null[c]

info_summ_null = info_null.groupby(['circuit_name', 'mutation_num', 'sample_name'], as_index=False).agg(**named_aggs)


## Visualisations without null circuit


In [None]:
# Log-ratio sensitivity (mutation_logm) against null-normalised sensitivity (infom)
# for the energies_0-0 species pair only. infom is filtered explicitly as well.
# Previously a filtered x was mixed with unfiltered y/hue. With jointplot's
# dropna=False default, the y marginal histogram then included every species pair.
m = mutation_logm[mutation_logm['energies_idx'] == 'energies_0-0']
infom_pair = infom[infom['energies_idx'] == 'energies_0-0']
sns.jointplot(x=m['sensitivity_wrt_species-6'],
              y=infom_pair['sensitivity_wrt_species-6' + '_null'],
              hue=infom_pair['mutation_num'])


# %%
# Null-normalised sensitivity, one step histogram per mutation count (log-scaled counts).
sns.histplot(
    data=infom, x='sensitivity_wrt_species-6_null', hue='mutation_num',
    log_scale=[False, True], element='step', bins=100,
)

# %%
def _plot_sens_vs_prec(hue):
    """Log-log joint plot of sensitivity vs. precision for the energies_0-0 rows of infom, coloured by `hue`."""
    sns.jointplot(infom[infom['energies_idx'] == 'energies_0-0'],
                  x='sensitivity_wrt_species-6', y='precision_wrt_species-6', hue=hue)
    plt.xscale('log')
    plt.yscale('log')


_plot_sens_vs_prec('fold_change_null')
# Outline of the robust-adaptation region used for sp_distance:
#  - sensitivity = 1 for precision >= 10;
#  - precision = 10 for sensitivity from 1 to 10.
plt.plot(np.ones(100), np.linspace(10, 1e7, 100), color='r')
plt.plot(np.linspace(1, 10, 40), np.ones(40) * 10, color='r')

# %%
_plot_sens_vs_prec('sensitivity_wrt_species-6_null')


# %%
_plot_sens_vs_prec('steady_states_null')


# %%
# Null-normalised steady state against raw sensitivity, per sample, for energies_0-0.
pair_rows = infom[infom['energies_idx'] == 'energies_0-0']
sns.jointplot(pair_rows, x='steady_states_null', y='sensitivity_wrt_species-6', hue='sample_name')

# %%
# Log-ratio sensitivity for the energies_0-0 pair, one step histogram per mutation
# count. Alternative: the same plot of
# infom[infom['energies_idx'] == 'energies_0-0'] with x='sensitivity_wrt_species-6_null'.
sns.histplot(
    mutation_logm[mutation_logm['energies_idx'] == 'energies_0-0'],
    x='sensitivity_wrt_species-6', hue='mutation_num',
    log_scale=[False, True], element='step', bins=100,
)


# %%
# Null-normalised sensitivity and precision against distance to the
# robust-adaptation region.
fig = px.scatter_3d(
    infom,
    x='sensitivity_wrt_species-6_null',
    y='precision_wrt_species-6_null',
    z='sp_distance',
    color='mutation_num',
)
fig.show()



# %%
# Rows closest to the robust-adaptation region (smallest sp_distance), excluding
# the RNA_0 sample; show the best eight.
top_circuits = info.loc[info['sample_name'] != 'RNA_0'].sort_values(by='sp_distance')
top_circuits.head(8)

# %%



