In [None]:
import os
import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

# Create model

In [None]:
from pypoptim.model import CardiacModel

In [None]:
# Directory holding the compiled Koivumaki model and its legend CSVs. It
# defaults to the original author's checkout; set the KOIVUMAKI_DIR environment
# variable to point elsewhere instead of editing this hard-coded path.
dirname = os.environ.get(
    'KOIVUMAKI_DIR',
    '/home/andrey/WORK/HPL/Code/pypoptim/src/model_ctypes/_koivumaki',
)
filename_so = os.path.join(dirname, 'koivumaki.so')

koivumaki = CardiacModel(filename_so)

# Values of the model's constants, states and algebraic variables taken from
# the legend CSVs, each as a Series indexed by variable name.
legend_constants = pd.read_csv(os.path.join(dirname, "legend_constants.csv"), index_col='name')['value']
legend_states = pd.read_csv(os.path.join(dirname, "legend_states.csv"), index_col='name')['value']
legend_algebraic = pd.read_csv(os.path.join(dirname, "legend_algebraic.csv"), index_col='name')['value']

# Working copies used for the quick test run of the model.
S = legend_states.copy()
R = S * 0  # zero Series with the same index as S; not referenced in the visible cells
C = legend_constants.copy()
A = legend_algebraic.copy()

In [None]:
# Keyword arguments forwarded to CardiacModel.run for the single test run in
# the next cell.
# NOTE(review): n_beats_save (4) is larger than n_beats (1) — confirm how
# run() handles saving more beats than are simulated.
kwargs_model = {
    'stim_period_legend_name': 'STIM_PERIOD',
    't_sampling': 0.0001,
    'n_beats': 1,
    'n_beats_save': 4,
    'stim_protocol': None,
}

In [None]:
# Single test run of the model with the legend states and constants.
result = koivumaki.run(S, C, **kwargs_model)
# Displayed as the cell output: the model's status after the run
# (SolModel.is_valid treats 2 as the valid value).
koivumaki._status

In [None]:
# Membrane voltage trace of the single test run.
fig, ax = plt.subplots()
ax.plot(result.V)
ax.set(title='Koivumaki model: test run', xlabel='sample', ylabel='V');

# Parse config

In [None]:
from pypoptim.helpers import strip_comments, random_value_from_bounds
import json

In [None]:
from pypoptim.cardio import create_genes_dict_from_config, \
                            create_constants_dict_from_config, \
                            generate_bounds_gammas_mask_multipliers

In [None]:
def prepare_config(config_filename):
    """Load a JSON config file and attach derived runtime data to it.

    The file text is passed through ``strip_comments`` before ``json.loads``.
    Every file name inside the config is resolved relative to the directory
    that contains the config file.

    The returned dict gains ``config['runtime']`` with:
    ``genes_dict``, ``constants_dict``, ``m_index`` (MultiIndex with levels
    ``ec_name``/``g_name``, one entry per gene), ``legend`` (dict of ``states``
    and ``constants`` Series indexed by name), ``n_organisms``,
    ``states_initial`` (DataFrame with one column per experimental condition),
    ``bounds``, ``gammas`` and ``mask_multipliers``.

    Each experimental condition except ``'common'`` is updated in place with
    ``phenotype`` (DataFrame), ``initial_state`` (Series indexed like the state
    legend) and ``stim_protocol`` (DataFrame, or None when the config has no
    ``column_stim_protocol``). Its ``filename_*`` entries are rewritten to the
    normalised paths that were actually read.

    Parameters
    ----------
    config_filename : str
        Path to the JSON config.

    Returns
    -------
    dict
        The parsed and augmented config.
    """
    # Paths inside the config are interpreted relative to the config file.
    config_path = os.path.dirname(os.path.realpath(config_filename))

    with open(config_filename) as f:
        text = f.read()
    config = json.loads(strip_comments(text))

    config['runtime'] = dict()

    config['runtime']['genes_dict'] = create_genes_dict_from_config(config)
    config['runtime']['constants_dict'] = create_constants_dict_from_config(config)

    # One (experimental condition, gene name) pair per gene, in genes_dict order.
    m_index_tuples = [(exp_cond_name, gene_name)
                      for exp_cond_name, gene in config['runtime']['genes_dict'].items()
                      for gene_name in gene]
    m_index = pd.MultiIndex.from_tuples(m_index_tuples)
    m_index.names = ['ec_name', 'g_name']

    config['runtime']['m_index'] = m_index

    legend = dict()
    legend['states'] = pd.read_csv(os.path.join(config_path, config["filename_legend_states"]),
                                   usecols=['name', 'value'], index_col='name')['value']  # Series
    legend['constants'] = pd.read_csv(os.path.join(config_path, config["filename_legend_constants"]),
                                      usecols=['name', 'value'], index_col='name')['value']  # Series
    config['runtime']['legend'] = legend

    config['runtime']['n_organisms'] = config['n_organisms']

    # Loop-invariant: decides whether stimulation protocols are loaded at all.
    column_stim_protocol = config.get('column_stim_protocol')

    for exp_cond_name, exp_cond in config['experimental_conditions'].items():

        if exp_cond_name == 'common':
            continue

        filename_phenotype = os.path.normpath(os.path.join(config_path, exp_cond['filename_phenotype']))
        exp_cond['phenotype'] = pd.read_csv(filename_phenotype)
        exp_cond['filename_phenotype'] = filename_phenotype

        # The state file holds one value per state variable, in legend order.
        filename_state = os.path.normpath(os.path.join(config_path, exp_cond['filename_state']))
        exp_cond['initial_state'] = pd.Series(np.loadtxt(filename_state), index=legend['states'].index)
        exp_cond['filename_state'] = filename_state

        if column_stim_protocol is not None:
            filename_stim_protocol = os.path.normpath(os.path.join(config_path, exp_cond['filename_stim_protocol']))
            exp_cond['stim_protocol'] = pd.read_csv(filename_stim_protocol)
            exp_cond['filename_stim_protocol'] = filename_stim_protocol
        else:
            exp_cond['stim_protocol'] = None

    states_initial = pd.DataFrame(data={exp_cond_name: exp_cond['initial_state'].copy()
                                        for exp_cond_name, exp_cond in config['experimental_conditions'].items()
                                        if exp_cond_name != 'common'})

    config['runtime']['states_initial'] = states_initial

    bounds, gammas, mask_multipliers = generate_bounds_gammas_mask_multipliers(config['runtime']['genes_dict'])
    config['runtime']['bounds'] = bounds
    config['runtime']['gammas'] = gammas
    config['runtime']['mask_multipliers'] = mask_multipliers

    return config

In [None]:
# Path relative to the notebook's working directory; files referenced inside
# the config are resolved relative to the config file itself by prepare_config.
config_filename = "../configs/koivumaki/10kHz_gKb/config_G2C1_test.json"

config = prepare_config(config_filename)

In [None]:
# Reuse the genes dict computed inside prepare_config: genes_m_index was built
# from that same dict, so the values below are guaranteed to line up with it.
# (prepare_config mutates config afterwards, so recomputing could diverge.)
genes_dict = config['runtime']['genes_dict']
state = config['runtime']['states_initial']
genes_m_index = config['runtime']['m_index']

# A gene that names a state variable starts from that condition's initial state.
# Every other gene is drawn at random from its bounds (log scale for
# multipliers). 'common' has no column in `state`, so its genes are always drawn.
# TODO(review): no random seed is set, so this draw is not reproducible.
genes = [random_value_from_bounds(gene_params['bounds'], log_scale=gene_params['is_multiplier'])
         if gene_name not in state.index or exp_cond_name not in state.columns
         else state[exp_cond_name][gene_name]
         for exp_cond_name, exp_cond_genes in genes_dict.items()
         for gene_name, gene_params in exp_cond_genes.items()]

genes = pd.Series(data=genes, index=genes_m_index)
# All genes of the 'common' condition are reset to 1.
genes['common'] = 1.

In [None]:
from pypoptim.cardio import update_S_C_from_genes_current, update_genes_from_state, calculate_n_samples_per_stim

In [None]:
def calculate_loss(pred, config):
    """Sum over experimental conditions of the RMSE between control and model.

    For every experimental condition except ``'common'``:
    1. Take the last ``n_samples_per_stim + 1`` rows of the control phenotype
       and of the model phenotype.
    2. Compare them position by position over their common length.
    3. Add the root-mean-square error to the total.

    Parameters
    ----------
    pred : mapping
        Must provide ``pred['phenotype'][exp_cond_name]`` as a DataFrame that
        contains the model columns (a ``SolModel`` does after ``update``).
    config : dict
        Prepared config; ``config['experimental_conditions'][name]['phenotype']``
        is the control DataFrame.

    Returns
    -------
    float
        Sum of the per-condition RMSE values (0 when there are no conditions).
    """
    loss = 0

    columns_control = ['V']
    columns_model = ['V']

    for exp_cond_name, exp_cond in config['experimental_conditions'].items():

        if exp_cond_name == 'common':
            continue

        n_samples_per_stim = calculate_n_samples_per_stim(exp_cond_name, config)

        # Work on plain arrays. Subtracting two DataFrames aligns on index
        # labels, and the control and model tails generally carry different
        # labels, which would compare the wrong samples (or only NaNs).
        phenotype_control = exp_cond['phenotype'][columns_control].to_numpy()[-n_samples_per_stim - 1:]
        phenotype_model = pred['phenotype'][exp_cond_name][columns_model].to_numpy()[-n_samples_per_stim - 1:]

        # Compare sample by sample from the start of both tails.
        n_common = min(len(phenotype_control), len(phenotype_model))
        residuals = phenotype_control[:n_common] - phenotype_model[:n_common]

        loss += float(np.sqrt(np.mean(residuals ** 2)))

    return loss

In [None]:
from pypoptim.algorythm import Solution


class SolModel(Solution):
    """GA solution that evaluates a gene vector with the Koivumaki model.

    Relies on notebook-level globals: ``koivumaki``, ``kwargs_model``,
    ``config``, ``genes_m_index``, ``calculate_loss``,
    ``update_S_C_from_genes_current`` and ``update_genes_from_state``.
    Before ``update`` is called, ``self['state']`` must hold a DataFrame with
    one column of initial states per experimental condition.
    """

    def __init__(self, x, **kwargs_data):
        super().__init__(x, **kwargs_data)
        self._model = koivumaki
        # Read from the notebook global at construction time, so redefining
        # `kwargs_model` later only affects solutions created afterwards.
        self['kwargs_model'] = kwargs_model

    def update(self):
        """Simulate every experimental condition and recompute the loss.

        Side effects:
        - fills ``self['phenotype']`` with one model output per condition;
        - replaces each condition's column of ``self['state']`` with the last
          simulated row;
        - refreshes ``self.x`` via ``update_genes_from_state``;
        - sets ``self._y`` to ``calculate_loss(self, config)``.
        """
        model = self._model

        self['phenotype'] = {}

        legend = config['runtime']['legend']
        # Same lookup as prepare_config: a missing key means "no stimulation
        # protocol column" rather than a KeyError.
        column_stim_protocol = config.get('column_stim_protocol')
        # stim_protocol / column_stim_protocol are passed explicitly below;
        # drop them from the stored kwargs so model.run never receives the same
        # keyword twice (the first kwargs_model cell does set stim_protocol).
        run_kwargs = {key: value for key, value in self['kwargs_model'].items()
                      if key not in ('stim_protocol', 'column_stim_protocol')}

        for exp_cond_name, exp_cond in config['experimental_conditions'].items():

            if exp_cond_name == 'common':
                continue

            C = legend['constants'].copy()
            S = self['state'][exp_cond_name].copy()

            genes = pd.Series(self.x, index=genes_m_index)
            update_S_C_from_genes_current(S, C, genes, exp_cond_name, config)

            pred = model.run(S, C, stim_protocol=exp_cond['stim_protocol'],
                             column_stim_protocol=column_stim_protocol,
                             **run_kwargs)

            # NOTE(review): genes are refreshed from self['state'] before this
            # condition's state is replaced by the end-of-run row below;
            # confirm that this ordering is intended.
            update_genes_from_state(genes=genes, state=self['state'],
                                    config=config, exp_cond_name=exp_cond_name)
            self.x = genes.values

            self['phenotype'][exp_cond_name] = pred.copy()
            self['state'][exp_cond_name] = self['phenotype'][exp_cond_name].iloc[-1]

        self._y = calculate_loss(self, config)

    def is_valid(self):
        """Return True if the model status is 2 and no phenotype contains NaN."""
        # NOTE(review): `_status` lives on the shared model object, so it
        # reflects the most recent `run` call (the last experimental condition,
        # possibly of another solution). Confirm this is the intended check.
        flag_valid = True
        flag_valid &= self._model._status == 2
        # Check this solution's own phenotypes, not a notebook-level variable.
        flag_valid &= all(not np.any(np.isnan(p)) for p in self['phenotype'].values())
        return flag_valid

In [None]:
# Run settings stored on every SolModel created from here on. stim_protocol is
# intentionally absent: SolModel.update passes it explicitly per condition.
kwargs_model = {
    'stim_period_legend_name': 'STIM_PERIOD',
    't_sampling': 0.0001,
    'n_beats': 1,
    'n_beats_save': 4,
}

In [None]:
# Initial gene vector, indexed by (ec_name, g_name).
genes

In [None]:
# Hand-built solution: the initial genes with the second entry forced to 1,
# starting from its own copy of the initial states.
x0 = genes.to_numpy(copy=True)
x0[1] = 1
sol = SolModel(x0, state=state.copy())

In [None]:
# Simulate every experimental condition for `sol`, then display `sol.y`
# (presumably exposes the loss that update() stores in `_y` — confirm).
sol.update()
sol.y

In [None]:
# Validity of `sol`: model status check plus a NaN check on the phenotypes.
sol.is_valid()

In [None]:
# Overlay the model output and the control phenotype for one condition.
exp_cond_to_plot = '1032'

fig, ax = plt.subplots()
ax.plot(sol['phenotype'][exp_cond_to_plot].V, label='model')
ax.plot(config['experimental_conditions'][exp_cond_to_plot]['phenotype'].V, label='control')
ax.set(title=f'Experimental condition {exp_cond_to_plot}', xlabel='sample', ylabel='V')
ax.legend();

In [None]:
from pypoptim.algorythm.ga import GA

In [None]:
# GA optimiser over SolModel using the per-gene bounds from the config.
# 'state' is listed in keys_data_transmit (presumably carried over between
# generations — confirm against GA).
bounds = config['runtime']['bounds']
ga_optim = GA(
    SolModel,
    bounds,
    keys_data_transmit=['state'],
)

In [None]:
# Small trial population for the bounds/filter checks in the following cells;
# every solution starts from its own copy of the initial states. The loop
# variable is deliberately not named `sol`, so the hand-built solution created
# earlier keeps that name.
population = ga_optim.generate_population(10)
for solution in population:
    solution['state'] = state.copy()

In [None]:
# Evaluate the whole trial population via the GA helper
# (presumably calls update() on each solution — confirm).
ga_optim.update_population(population)

In [None]:
# Number of solutions currently in the population.
len(population)

In [None]:
# Deliberately overwrite the first gene of population[0] with a large value,
# presumably to push it outside its bounds for the checks in the next cells.
population[0].x[0] = 1000

In [None]:
# Bounds check for the perturbed solution; expected False if 1000 exceeds the
# first gene's upper bound.
ga_optim._is_solution_inside_bounds(population[0])

In [None]:
# Bounds check for `sol`, for comparison with the perturbed solution.
# NOTE(review): confirm which object `sol` refers to here; the name is defined
# for the hand-built solution but may be rebound by loops over `population`.
ga_optim._is_solution_inside_bounds(sol)

In [None]:
# Keep only the solutions accepted by filter_population (presumably valid and
# inside bounds — confirm) and show how many remain.
population = ga_optim.filter_population(population)
len(population)

In [None]:
from tqdm.auto import tqdm

In [None]:
parallel = True

n_solutions = 24
n_elites = 2
n_epochs = 10

population = ga_optim.generate_population(n_solutions)

# Loop variables are named `solution` so the notebook-level `sol` is untouched.
for solution in population:
    solution['state'] = state.copy()

loss = []  # best (minimum) loss of each epoch

for i in tqdm(range(n_epochs), desc='GA epochs'):

    # NOTE(review): this flag looks inverted. parallel=True runs a plain
    # sequential loop, while parallel=False delegates to
    # ga_optim.update_population. Confirm which branch is meant to be parallel.
    if parallel:
        for solution in population:
            solution.update()
    else:
        ga_optim.update_population(population)

    population = ga_optim.filter_population(population)
    if len(population) == 0:
        raise RuntimeError(f'epoch {i}: filter_population rejected every solution')

    loss.append(min(population).y)

    elites = ga_optim.get_elites(population, n_elites)
    mutants = ga_optim.get_mutants(population, n_solutions - n_elites)

    population = elites + mutants

# NOTE(review): after the last epoch `population` holds elites plus newly
# produced mutants that were not re-evaluated inside this loop, so their `.y`
# may not correspond to their current `x`.

In [None]:
# Loss of every solution in the final population as one rich-displayed Series
# (instead of one print per solution). The comprehension variable keeps the
# notebook-level `sol` untouched.
pd.Series([solution.y for solution in population], name='y')

In [None]:
# Size of the final population (elites + mutants from the last epoch).
len(population)

In [None]:
# Best loss per GA epoch on a logarithmic axis.
fig, ax = plt.subplots()
ax.semilogy(loss)
ax.set(title='Best loss per GA epoch', xlabel='epoch', ylabel='loss');