## Imports

In [1]:
# Automatically reload imported modules before executing each cell, so edits to
# library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys

import haiku as hk
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from synbio_morpher.utils.data.data_format_tools.manipulate_fasta import load_seq_from_FASTA
from synbio_morpher.utils.evolution.mutation import get_mutation_type_mapping
from synbio_morpher.utils.misc.string_handling import convert_liststr_to_list
from synbio_morpher.utils.results.analytics.naming import get_true_names_analytics


if __package__ is None:
    # No package context (bare notebook kernel): expose the parent directory on
    # sys.path so the `src.*` imports that follow can resolve, and use that
    # directory's name as the package name.
    module_path = os.path.abspath(os.pardir)
    sys.path.append(module_path)
    __package__ = os.path.basename(module_path)
    

from src.models.nucleotide_transformer import NucleotideTransformerConfig, build_nucleotide_transformer_fn
from src.models.pretrained import FixedSizeNucleotidesKmersTokenizer

root_dir = os.pardir  # project root relative to the working dir; presumably the notebook runs one level below it

# Processing

In [31]:
# Directory name of the ensemble_mutation_effect_analysis run to process.
# Other runs used previously: '2023_06_05_164913', '2023_05_16_174613'.
RUN_ID = '2023_07_17_105328'
rel_fn = f'data/raw/ensemble_mutation_effect_analysis/{RUN_ID}/tabulated_mutation_info.csv'
fn = os.path.join(root_dir, rel_fn)

In [32]:
# Load the tabulated per-mutation analytics and preview the first rows.
data = pd.read_csv(filepath_or_buffer=fn)
data.head(n=5)

Unnamed: 0,circuit_name,mutation_name,mutation_num,mutation_type,mutation_positions,path_to_template_circuit,index,name,interacting,self_interacting,...,RMSE_diff_to_base_circuit,steady_states_diff_to_base_circuit,fold_change_ratio_from_mutation_to_base,initial_steady_states_ratio_from_mutation_to_base,max_amount_ratio_from_mutation_to_base,min_amount_ratio_from_mutation_to_base,overshoot_ratio_from_mutation_to_base,RMSE_ratio_from_mutation_to_base,steady_states_ratio_from_mutation_to_base,sample_name
0,toy_mRNA_circuit_0,ref_circuit,0,[],[],data/ensemble_mutation_effect_analysis/2023_07...,0.0,toy_mRNA_circuit_0,[[0 1]],[[1 1]],...,0.0,0.0,1.0,1.0,1.0,1.0,1.0,inf,1.0,RNA_0
1,toy_mRNA_circuit_0,ref_circuit,0,[],[],data/ensemble_mutation_effect_analysis/2023_07...,0.0,toy_mRNA_circuit_0,[[0 1]],[[1 1]],...,0.0,0.0,1.0,1.0,1.0,1.0,1.0,inf,1.0,RNA_1
2,toy_mRNA_circuit_0,ref_circuit,0,[],[],data/ensemble_mutation_effect_analysis/2023_07...,0.0,toy_mRNA_circuit_0,[[0 1]],[[1 1]],...,0.0,0.0,1.0,1.0,1.0,1.0,inf,inf,1.0,RNA_2
3,toy_mRNA_circuit_0,RNA_0_m1-0,1,[10],[14],data/ensemble_mutation_effect_analysis/2023_07...,0.0,toy_mRNA_circuit_0,[[0 1]],[[1 1]],...,8.53363,8.585205,0.976079,1.0472,1.0,1.0472,0.307851,inf,1.02215,RNA_0
4,toy_mRNA_circuit_0,RNA_0_m1-0,1,[10],[14],data/ensemble_mutation_effect_analysis/2023_07...,0.0,toy_mRNA_circuit_0,[[0 1]],[[1 1]],...,5.955922,5.909309,1.033255,1.031187,1.031187,1.087942,0.473072,inf,1.065479,RNA_1


Preprocess

In [33]:
# These columns hold list-valued entries stored as strings in the CSV (e.g. '[10]');
# convert them back into lists, one column at a time.
for list_col in ('mutation_type', 'mutation_positions'):
    data[list_col] = convert_liststr_to_list(data[list_col].str)

In [7]:
# Disabled one-off step: copy each template-circuit file from a local checkout
# (`source_dir`) into <root_dir>/data/raw/, keeping the part of its path after 'data/'.
# NOTE(review): re-enabling this needs `import shutil`, and `source_dir` is a
# hard-coded absolute path on one machine — make it configurable before reuse.
# source_dir = '/home/wadh6511/Kode/gene-circuit-glitch-prediction'
# paths = data['path_to_template_circuit'].unique()
# for p in paths:
#     if type(p) != str:
#         continue
#     fp = os.path.join(source_dir, p)
    
#     p_spl = str(p.split('data/')[-1])
#     dst = os.path.join(root_dir, 'data', 'raw', p_spl)
#     # print(dst)
#     shutil.copyfile(fp, dst)


## Mode 1 - string of RNA

Add the actual pre-mutation sequence as a field.

In [34]:
# Species targeted by each mutation. Mutation names look like '<species>_m<...>'
# (e.g. 'RNA_0_m1-0' -> 'RNA_0'); splitting on the last '_m' rather than taking the
# first 5 characters keeps multi-digit species such as 'RNA_10' intact instead of
# silently turning them into 'RNA_1'.
data['mutation_species'] = data['mutation_name'].str.rsplit('_m', n=1).str[0]
# Reference (unmutated) circuits have no mutated species.
data.loc[data['mutation_name'].str.startswith('ref_c', na=False), 'mutation_species'] = None



Load RNA sequences.

In [36]:
mutation_species = data['mutation_species'].unique()
circuit_name = data['circuit_name'].unique()

# Map each circuit to its template path explicitly. Zipping two independently
# de-duplicated arrays only lines up if both happen to share the same order with
# exactly one path per circuit.
# Only mutated rows are used, as before; presumably reference rows do not carry a
# usable template path — TODO confirm.
mutated_rows = data.loc[data['mutation_num'] > 0, ['circuit_name', 'path_to_template_circuit']]
assert (mutated_rows.groupby('circuit_name')['path_to_template_circuit'].nunique() == 1).all(), \
    'expected exactly one template path per circuit'
circuit_to_template = (
    mutated_rows
    .drop_duplicates(subset='circuit_name')
    .set_index('circuit_name')['path_to_template_circuit']
)
path_to_template_circuit = circuit_to_template.to_list()
# Template paths are resolved against ../../synbio_morpher.
circuit_paths = [os.path.join('..', '..', 'synbio_morpher', p) for p in path_to_template_circuit]
fastas = [load_seq_from_FASTA(cp, as_type='dict') for cp in circuit_paths]
fasta_d = dict(zip(circuit_to_template.index, fastas))

In [37]:
# data['src_sequence'] = jax.tree_util.tree_map(lambda cn, ms, sn: fasta_d[cn][ms] if ms != 'ref_c' else fasta_d[cn][sn], data['circuit_name'].to_list(), data['mutation_species'].to_list(), data['sample_name'].to_list())

Simplify the `mutation_type` codes into a smaller set of integer labels.

In [38]:


mutation_type_mapping = get_mutation_type_mapping('RNA')
# Give every outer key of the nested mapping a 1-based integer label.
mutation_type_mapping_simp = dict(zip(mutation_type_mapping.keys(), np.arange(1, len(mutation_type_mapping) + 1)))
# Translate each full mutation code (a leaf of the nested mapping) into the label of
# its inner key, looked up among the outer-key labels built above.
mutation_map_translation = {
    code: mutation_type_mapping_simp[inner_key.key]
    for (_outer_key, inner_key), code in jax.tree_util.tree_flatten_with_path(mutation_type_mapping)[0]
}

# Apply the translation element-wise to every row's list of mutation codes.
data['mutation_types_simp'] = data['mutation_type'].apply(
    lambda types: jax.tree_util.tree_map(lambda t: mutation_map_translation[t] if types else [], types))
mutation_map_translation


{0: 2, 1: 3, 2: 4, 3: 1, 4: 3, 5: 4, 6: 1, 7: 2, 8: 4, 9: 1, 10: 2, 11: 3}

In [39]:
def apply_values(sequence, indices, values):
    """Return a float array of length ``len(sequence)`` with ``values`` placed at ``indices``.

    Positions not listed in ``indices`` are 0. Only the length of ``sequence`` is
    used, not its contents. Pairs are consumed in order via ``zip``: a repeated
    index keeps the last value, and surplus items in the longer argument are ignored.

    NOTE: a later cell redefines ``apply_values`` with string-substitution
    semantics; that later definition is the one in effect afterwards.
    """
    result = np.zeros(len(sequence))
    for idx, val in zip(indices, values):
        result[idx] = val
    return result

sequence = 'ABDSAFD'
indices = (0, 5, 6)
values = (1, 1, 3)

output = apply_values(sequence, indices, values)
np.testing.assert_array_equal(output, [1., 0., 0., 0., 0., 1., 3.])
output

array([1., 0., 0., 0., 0., 1., 3.])

In [40]:
# data['mutation_types_seq'] = jax.tree_util.tree_map(lambda seq, typs, pos: apply_values(seq, pos, typs) if typs else np.zeros(len(seq)), data['src_sequence'].to_list(), data['mutation_types_simp'].to_list(), data['mutation_positions'].to_list())

### Combine into input

Each circuit is represented as $n$ RNA sequences.

In [41]:


def reverse_mut_mapping(mut_encoding: int, sequence_type: str = 'RNA'):
    """Return the mutation whose integer encoding equals ``mut_encoding``.

    Searches the nested mapping ``{outer_key: {mutation: encoding}}`` returned by
    ``get_mutation_type_mapping(sequence_type)`` and returns the first inner key
    (in iteration order) whose encoding matches.

    Raises:
        ValueError: if no inner entry has that encoding.
    """
    # One pass over the inner dicts; no separate membership pre-check needed.
    for encodings in get_mutation_type_mapping(sequence_type).values():
        for mut, enc in encodings.items():
            if enc == mut_encoding:
                return mut
    raise ValueError(
        f'Could not find mutation for mapping key {mut_encoding}.')
    
def apply_values(sequence, indices, values):
    """Return ``sequence`` with the character at each index replaced by the matching value.

    Pairs are consumed in order via ``zip``: a repeated index keeps the last value,
    and surplus items in the longer of ``indices``/``values`` are ignored. Negative
    indices count from the end, as with normal Python indexing.

    Each value must be a single-character string so the sequence length is
    preserved; anything else raises instead of being silently truncated.

    Raises:
        ValueError: if a replacement value is not a single-character string.
        IndexError: if an index is out of range for ``sequence``.
    """
    chars = list(sequence)
    for idx, val in zip(indices, values):
        if not (isinstance(val, str) and len(val) == 1):
            raise ValueError(
                f'Expected a single-character replacement, got {val!r} at index {idx}.')
        chars[idx] = val
    return ''.join(chars)

# Inverse of get_mutation_type_mapping('RNA'): integer encoding -> mutation (inner key).
# Built in a single pass over the mapping instead of re-fetching it per code;
# setdefault keeps the first match in iteration order, matching reverse_mut_mapping.
mutation_type_mapping_rev = {}
for encodings in get_mutation_type_mapping('RNA').values():
    for mut, enc in encodings.items():
        mutation_type_mapping_rev.setdefault(enc, mut)



In [None]:
# Quick look at the columns used to build `sample_seq`. Columns that do not exist yet
# are skipped ('mutation_type_explicit' is only created by the next cell), so this
# cell no longer raises KeyError on a fresh Restart & Run All.
preview_cols = ['circuit_name', 'sample_name', 'mutation_species', 'mutation_positions', 'mutation_type_explicit']
[data[x].iloc[:5].to_list() for x in preview_cols if x in data.columns]

In [43]:
# Explicit mutation (inner mapping key) for every encoded mutation type; tree_map keeps
# the nested list-of-lists structure of the column.
data['mutation_type_explicit'] = jax.tree_util.tree_map(lambda mt: mutation_type_mapping_rev[mt], data['mutation_type'].to_list())

# Per-row sample sequence: the template sequence, with the row's mutations applied only
# when this sample is the mutated species. A plain zip is used for the row-wise
# iteration (tree_map would treat any None in the first column as an empty subtree).
seq_cols = ['circuit_name', 'sample_name', 'mutation_species', 'mutation_positions', 'mutation_type_explicit']
data['sample_seq'] = [
    apply_values(fasta_d[cn][sn], mp, mt) if sn == ms else fasta_d[cn][sn]
    for cn, sn, ms, mp, mt in zip(*(data[x].to_list() for x in seq_cols))
]

In [None]:


# data['sample_seq_mut'] = jax.tree_util.tree_map(
#     lambda ms, sn, seq, mp, mt: apply_values(seq, mp, [mutation_type_mapping_rev[t] for t in mt]) if mp and ms == sn else seq, 
#     *[data[x].to_list() for x in ['mutation_species', 'sample_name', 'sample_seq', 'mutation_positions', 'mutation_type']])


## Labels

For now, we will discretise the labels on a log scale to accommodate sensitivity and precision. 

In [44]:
data.loc[:, 'sensitivity_wrt_species-6_ratio_from_mutation_to_base']

0          1.000000
1          1.000000
2          1.000000
3          1.000000
4          0.385668
             ...   
1628995    1.000000
1628996    1.000000
1628997    1.000000
1628998    1.000000
1628999    1.000000
Name: sensitivity_wrt_species-6_ratio_from_mutation_to_base, Length: 1629000, dtype: float64

In [None]:
# `get_true_names_analytics` is imported with the other synbio_morpher helpers at the top.
get_true_names_analytics(data)

In [46]:
# Mirror the raw file's location under data/processed, e.g.
# <root>/data/raw/<a>/<b>/file.csv -> <root>/data/processed/<a>/<b>/file.csv.
# relpath + os.sep avoids mixing a hard-coded '/' split with an os.sep split.
raw_dir = os.path.join(root_dir, 'data', 'raw')
dirlist = [root_dir, 'data', 'processed'] + os.path.relpath(fn, raw_dir).split(os.sep)
out_dir = os.path.join(*dirlist[:-1])
if not os.path.isdir(out_dir):
    print('Making directory ', out_dir)
    # makedirs creates any missing intermediate directories in one call.
    os.makedirs(out_dir, exist_ok=True)

Making directory  ../data/processed/ensemble_mutation_effect_analysis/2023_07_17_105328


In [47]:
# Identifier / sequence columns written alongside the analytics.
id_columns = ['circuit_name', 'mutation_name', 'mutation_species', 'mutation_num', 'sample_seq', 'sample_name']
# Analytic columns to export; commented-out entries are currently excluded.
analytic_columns = [
    'fold_change',
    'initial_steady_states',
    'max_amount',
    'min_amount',
    'overshoot',
    'RMSE',
    'steady_states',
    'response_time_wrt_species-6',
    # 'response_time_wrt_species-6_diff_to_base_circuit',
    'response_time_wrt_species-6_ratio_from_mutation_to_base',
    'precision_wrt_species-6',
    # 'precision_wrt_species-6_diff_to_base_circuit',
    'precision_wrt_species-6_ratio_from_mutation_to_base',
    'sensitivity_wrt_species-6',
    # 'sensitivity_wrt_species-6_diff_to_base_circuit',
    'sensitivity_wrt_species-6_ratio_from_mutation_to_base',
    # 'fold_change_diff_to_base_circuit',
    # 'initial_steady_states_diff_to_base_circuit',
    # 'max_amount_diff_to_base_circuit',
    # 'min_amount_diff_to_base_circuit',
    # 'overshoot_diff_to_base_circuit',
    # 'RMSE_diff_to_base_circuit',
    # 'steady_states_diff_to_base_circuit',
    'fold_change_ratio_from_mutation_to_base',
    'initial_steady_states_ratio_from_mutation_to_base',
    'max_amount_ratio_from_mutation_to_base',
    'min_amount_ratio_from_mutation_to_base',
    'overshoot_ratio_from_mutation_to_base',
    # 'RMSE_ratio_from_mutation_to_base',
    'steady_states_ratio_from_mutation_to_base',
]
export_path = os.path.join(*dirlist)
data[id_columns + analytic_columns].to_csv(export_path)