# Determine kon and dG from observed yield vs time data #

In [None]:
# Make sure the Jupyter path is correct for loading local modules
import sys
# Path to the KineticAssembly package used by this notebook.
# NOTE(review): this is an absolute, machine-specific path (not relative to the
# notebook); adjust it to where the package lives on your system.
sys.path.append("/home/adip/mjohn218_KineticAssembly")
import copy

In [None]:
from KineticAssembly_AD import ReactionNetwork, VectorizedRxnNetExp, VecSim, Optimizer, EquilibriumSolver,OptimizerExp
import networkx as nx
import torch
from torch import DoubleTensor as Tensor
import numpy as np

## Setup Reaction Network
#### Read the corresponding input file and call the ReactionNetwork class

In [None]:
# Input file describing the reaction network ('path_to_input' appears to be a
# placeholder that must be replaced before running).
base_input = 'path_to_input'

# Build the network in one-step mode, then resolve its reaction tree.
rn = ReactionNetwork(base_input, one_step=True)
rn.resolve_tree()


## Checking reaction network
Loop over all network nodes to check that every species has been created, and build a dictionary for later reference. The dictionary maps each edge, given as a (reactant node, product node) pair, to the index (uid) of the corresponding reaction.

In [None]:
# Print every species in the network and build ``uid_dict``, which maps each
# directed edge (reactant node, product node) to the reaction uid stored on it.
uid_dict = {}
sys.path.append("../../")
import numpy as np
from reaction_network import gtostr
for n in rn.network.nodes():
    print(n, "--", gtostr(rn.network.nodes[n]['struct']))
    # Only the reaction uid is needed later: it is used to index the k_on
    # tensor when the initial rates are written back onto the graph edges.
    for k, v in rn.network[n].items():
        uid_dict[(n, k)] = v['uid']

print(uid_dict)

## Set the initial parameter values 
For a tetramer model there are 22 reactions. We can set an initial value for all reaction rates as given in the next cell. 

For the RateGrowth model, the number of distinct rates decreases to only 3 values (dimer, trimer and tetramer). Code for setting the initial values of all rates in a rate-growth model is also given (commented out) in the next cell.

In [None]:
"""
Set initial rate values for all reactions
"""
import networkx as nx
#Define a new tensor array with all values initialized to zero
new_kon = torch.zeros([rn._rxn_count], requires_grad=True).double()
#Assign value for each rate 
new_kon = new_kon + Tensor([1]*np.array(1e0))

"""
For RateGrowth model, initial values are assigned differently
"""
#Define initial values for dimer,trimer and tetramer rate
# kdim=
# ktri=
# ktetra=
# rates= [kdim, ktri, ktetra]

#Assign the corresponding reaction values to it's reaction type.
# counter=0
# for k,v in rn.rxn_class.items():
#     for rid in v:
#         new_kon[v] = rates[counter]
#     counter+=1


"""
Update the reaction network with the new initial values
"""
update_kon_dict = {}
for edge in rn.network.edges:
    print(rn.network.get_edge_data(edge[0],edge[1]))
    update_kon_dict[edge] = new_kon[uid_dict[edge]]

nx.set_edge_attributes(rn.network,update_kon_dict,'k_on')





### Define the Vectorized Reaction Network class

In this class, all reaction rates, concentrations, and dG values are stored as tensors for vectorized operations.

In [None]:
# Vectorized form of the network (rates, concentrations and dG values held in
# tensors). VectorizedRxnNetExp is the class imported from KineticAssembly_AD
# in this notebook; the name VectorizedRxnNet is not imported anywhere here.
# It is presumably the counterpart of the OptimizerExp used below.
vec_rn = VectorizedRxnNetExp(rn, dev='cpu')

## Using the optimizer ##

### Define an instance of the optimizer class
#### Input Arguments:

reaction_network: The vectorized reaction network.

sim_runtime: The runtime of the kinetic simulation. It must match the time span of the experimental reaction data.

optim_iterations: Number of optimization iterations to run. Start with a low value (e.g. 100) and increase it depending on memory usage.

learning_rate: The size of the gradient-descent step used to update parameter values. It needs to be at least (1e-3 – 1e-1) × min{parameter value}. If the learning rate is too high, the optimizer can take too large a step, which sometimes leads to negative parameter values; these are unphysical. Finding the best value requires some trial runs.

device: cpu or gpu

method: The PyTorch optimizer to use for gradient descent — Adam or RMSprop.

mom: Momentum term used during gradient descent (RMSprop method only).



In [None]:
# Optimizer hyper-parameters:
#   learn_rate - gradient-descent step size
#                (list alternative, as used by the commented-out part-2
#                writer: learn_rate = [1e-3, 1e-3])
#   momentum   - momentum term, only used by the RMSprop method
#   runtime    - kinetic simulation runtime; must match the experimental data
learn_rate, momentum, runtime = 1e-3, 0.2, 10



In [None]:

# Reset the network parameters, then build the optimizer for the fit.
vec_rn.reset(reset_params=True)
optim = OptimizerExp(
    reaction_network=vec_rn,
    sim_runtime=runtime,       # must span the experimental time window
    optim_iterations=100,
    learning_rate=learn_rate,
    device='cpu',
    method="Adam",
    reg_penalty=1000000,
    mom=momentum,              # only relevant for method="RMSprop"
)


### Call the optimization method

#### Input arguments

files_range (passed as `conc_files_range`): List of concentration values to simulate for the global optimization. All values are stored as integers.

conc_files_pref: Path and filename prefix of the data files containing the true yield values at each time point.

yield_species: Node index of the species whose yield is being optimized.

yield_thresh (passed as `yield_threshmax`): Upper yield bound of the window used to compute the error between the true and predicted yield values.

yield_min (passed as `yield_threshmin`): Lower yield bound of the window used to compute the yield error.

mode: How the error is computed (not set in the call below). Two modes are available: (a) 'square' — sum of squared errors; (b) 'abs' — uses the absolute value of the error.



In [None]:
# Concentration values (integers) used for the global fit.
files_range = [100, 500, 1000, 5000, 10000]
# Yield window [yield_min, yield_thresh] over which the error is evaluated.
yield_min, yield_thresh = 0.7, 0.8

optim.rn.update_reaction_net(rn)
optim.optimize_wrt_conc_beta(
    conc_scale=1e-1,
    conc_thresh=1e-1,
    mod_bool=True,
    mod_factor=10,
    max_thresh=1e2,
    max_yield=0,
    yield_species=14,  # node index of the species whose yield is fitted
    conc_files_pref="dG_trap/ConcProfile_Time_HomoRates_",
    conc_files_range=files_range,
    yield_threshmin=yield_min,
    yield_threshmax=yield_thresh,
)

## Track the error over optimization iterations

In [None]:
import matplotlib.pyplot as plt
%matplotlib notebook
fig,ax = plt.subplots()

ax.plot(optim.mse_error)

ax.tick_params(labelsize='xx-large')

ax.set_xlabel("Iterations",fontsize=25)
ax.set_ylabel("MSE",fontsize=25)


# ax.legend(fontsize='large')

### Store parameter values obtained over the entire optimization run

In [None]:
# Copy the per-iteration history out of the optimizer:
#   yields       - yield recorded at each iteration
#   final_params - per iteration, a list with every parameter as an np.array
#   mse_error    - error recorded at each iteration
n_iters = len(optim.yield_per_iter)
yields = [optim.yield_per_iter[it] for it in range(n_iters)]
final_params = [
    [np.array(optim.parameter_history[it][idx])
     for idx in range(len(optim.parameter_history[it]))]
    for it in range(n_iters)
]
mse_error = np.array([optim.mse_error[it] for it in range(n_iters)])

# Despite the names these are not sorted (the sort_indx reordering is
# disabled); the names are kept because later cells read them.
sorted_yields = np.array(yields)
sorted_params = np.array(final_params)



### Select parameter values with min error

In [None]:
# Pick the iteration with the lowest error. np.nanargmin skips iterations
# whose error is NaN (e.g. a diverged step); plain np.argmin would return the
# position of the first NaN instead of the true minimum.
min_indx = np.nanargmin(np.asarray(mse_error, dtype=float))

min_rates = list(sorted_params[min_indx])
min_error = mse_error[min_indx]

# dG = -ln(k_on * C0 / k_off), computed from the first entry of the first two
# parameter groups at the best iteration.
# NOTE(review): assumes min_rates[0] holds k_on values, min_rates[1] holds
# k_off values and vec_rn._C0 is the standard concentration — confirm.
dG = -1*torch.log(min_rates[0][0]*vec_rn._C0/min_rates[1][0])
print("Params: ",min_rates)
print("dG: ",dG)
print("Min SSE: ",min_error )

## Storing parameter values in a file

### For part 1 - Only kon optimization ###

In [None]:
# Write every iteration's solution to a tab-separated file. The file is opened
# in append mode, so repeated runs accumulate in the same file.

klabels = ['k'+str(i) for i in range(len(vec_rn.kon))]
# The header columns match the rows written below: yield, error, then the
# parameter values.
header = '#Yield\tError\t' + "\t".join(klabels) + "\n"

# Build the string form under a separate name so the numeric files_range list
# is not overwritten with strings.
filestr = ",".join(str(f) for f in files_range)


with open("Solutions_Conc_Homorates_dGNotrap_02_part1", 'a') as fl:
    fl.write(header)
    fl.write("# Range of Concentrations: %s\n" % filestr)
    fl.write("# Learning rate: %s\n" % (str(learn_rate)))
    fl.write("# Momentum: %f\n" % (momentum))
    fl.write("# Yield thresh: %f\n" % (yield_thresh))
    for yld, err, params in zip(sorted_yields, mse_error, sorted_params):
        fl.write("%f" % yld)
        fl.write("\t%f" % err)
        # ravel() yields scalars whether each iteration's parameters are a
        # flat vector or a stack of per-group arrays.
        for value in np.ravel(params):
            fl.write("\t%f" % value)
        fl.write("\n")
                 


### Part 2 - koff optimization ###

In [None]:
# Part 2 (koff optimization) solution writer, kept commented out. Uncomment it
# when learn_rate is a list (it is joined element-wise into the "Learning rate"
# line) and each iteration's parameters are a sequence of arrays (written
# element-by-element by the nested j/k loop).
# NOTE(review): the header advertises t50/t85/t95 columns that the rows never
# contain, and each row includes an error column (mse_error) that the header
# omits — align them before relying on the header.
# NOTE(review): reassigning files_range to strings here clobbers the numeric
# list used by the optimization cell; consider a separate name.
# #Writing all solutions to a file

# klabels=['k'+str(i) for i in range(len(vec_rn.kon))]
# header = '#Yield\t' + "\t".join(klabels) + "\tt50\tt85\tt95\n"

# files_range = [str(f) for f in files_range]
# filestr = ",".join(files_range)


# with open("Solutions_Conc_Homorates_dGparam_07_part2",'a') as fl:
#     fl.write(header)
#     fl.write("# Range of Concentrations: %s\n" %filestr)
#     fl.write("# Learning rate: %s\n" %(",".join(str(lr) for lr in learn_rate)))
#     fl.write("# Momentum: %f\n" %(momentum))
#     fl.write("# Yield thresh: %f\n" %(yield_thresh))
#     for i in range(len(sorted_yields)):
#         fl.write("%f" %(sorted_yields[i]))
#         fl.write("\t%f" %(mse_error[i]))
#         for j in range((sorted_params[i].shape[0])):
#             for k in range(len(sorted_params[i][j])):
#                 fl.write("\t%f" %(sorted_params[i][j][k]))
#         fl.write("\n")
        
                 
