# Discrete SOL

In [None]:
from __future__ import division
import numpy as np
from vpython import *

import matplotlib.pyplot as plt

from math import atan2
from scipy.integrate import solve_ivp
import os
from shutil import copy
from functools import partial

import MBRLtool
import utils
try:
    import tkinter as tk
except:
    import tk
from pathlib import Path

In [None]:
class SimResults():
    """Pre-allocated history buffers for one simulation run, plus plotting/printing.

    Parameters
    ----------
    t : sequence of time stamps; every history buffer has one slot per entry.
    Lib : library object whose ``n``, ``m`` and ``_Phi_dim`` size the buffers.
    DB, SysID, Ctrl : database, identifier and controller whose state
        :meth:`printout` reports.
    output_dir_path : directory the ``fig_*.png`` files are written to.
    select : which figure groups :meth:`graph` produces (keys ``'states'``,
        ``'value'``, ``'P'``, ``'error'``). Keys that are missing default to 1
        (enabled). The dict is copied, never mutated.
    demand_flag : when true, :meth:`graph` also plots the recorded demand.
    """
    # Never mutated: __init__ always builds a fresh dict from it.
    _DEFAULT_SELECT = {'states':1,'value':1,'P':1,'error':1}

    def __init__(self,t,Lib,DB,SysID,Ctrl,output_dir_path,select=None, demand_flag = False):
        self.t=t
        len_t=len(t)
        self.Lib=Lib
        self.DB=DB
        self.SysID=SysID
        self.Ctrl=Ctrl
        self.x_s_history=np.zeros((self.Lib.n,len_t))
        self.u_history=np.zeros((self.Lib.m,len_t))
        self.P_history=np.zeros((len_t,self.Lib._Phi_dim,self.Lib._Phi_dim))
        self.V_history=np.zeros((len_t))
        self.error_history=np.zeros((len_t))
        # Fresh per-instance dict: a shared mutable default would leak changes
        # from one SimResults into every later one.
        self.select=dict(self._DEFAULT_SELECT)
        if select is not None:
            self.select.update(select)
        self.output_dir_path=output_dir_path
        self.demand_flag = demand_flag
        self.demand = np.zeros((len_t))
        self.pallet=['r','g','b','m','#E67E22','#1F618D']

    def record(self,i,x_s,u,P,V,error, demand = 0):
        """Store the quantities of time step *i* into the history buffers."""
        self.x_s_history[:,i]=x_s
        self.u_history[:,i]=u
        self.P_history[i]=P
        self.V_history[i]=V
        self.error_history[i]=error
        self.demand[i] = demand

    def _save(self,fig,stem,j):
        """Save *fig* as ``<output_dir_path>/<stem><j>.png`` and close it."""
        fig.savefig(os.path.join(self.output_dir_path,'{}{}.png'.format(stem,j)))
        plt.close(fig)

    def graph(self,j,i):
        """Write the selected figures for episode *j* using the first *i* samples.

        Figures are saved to disk and closed; nothing is displayed interactively.
        """
        px=self.Lib._Phi_dim
        t=self.t[:i]
        if self.select['P']:
            fig = plt.figure()
            for ii in range(px):
                for jj in range(px):
                    plt.plot(t,self.P_history[:i,ii,jj],'g')
            self._save(fig,'fig_P',j)
        if self.select['states']:
            fig1 = plt.figure()
            legend_1 = []
            for im in range(self.Lib.m):
                plt.plot(t, self.u_history[im,:i], self.pallet[im%len(self.pallet)]+'--')
                legend_1.append("u"+str(im))
            for ii in range(self.Lib.n):
                plt.plot(t, self.x_s_history[ii,:i], self.pallet[ii%len(self.pallet)])
                legend_1.append("x"+str(ii))
            plt.legend(legend_1, loc=1)
            plt.xlabel('t (sec)')
            plt.ylabel('States and Control')
            plt.ylim((-10, 10))
            plt.grid(color='k', linestyle=':', linewidth=1)
            self._save(fig1,'fig_states_control',j)

            # Stacked summary: value, P entries and identification error.
            fig0, axs = plt.subplots(3, 1)
            axs[0].plot(t,self.V_history[:i],'b')
            axs[0].set_ylabel('Value')
            axs[0].grid(color='k', linestyle=':', linewidth=1)
            for ii in range(px):
                for jj in range(px):
                    axs[1].plot(t,self.P_history[:i,ii,jj],'g')
            axs[1].set_ylabel('Parameters')
            axs[1].grid(color='k', linestyle=':', linewidth=1)
            axs[2].plot(t,self.error_history[:i],'r')
            axs[2].set_ylabel('Error')
            axs[2].set_ylim([0, 1])
            axs[2].grid(color='k', linestyle=':', linewidth=1)
            fig0.tight_layout()
            self._save(fig0,'fig_states_control_Value_Param_Error',j)
        if self.demand_flag:
            fig4 = plt.figure()
            plt.plot(t,self.demand[:i],'y')
            self._save(fig4,'fig_demand',j)
        if self.select['error']:
            fig2 = plt.figure()
            plt.plot(t,self.error_history[:i],'g')
            plt.ylim((0,200))
            self._save(fig2,'fig_error',j)
        if self.select['value']:
            fig3 = plt.figure()
            plt.plot(t,self.V_history[:i],'b')
            fig3.tight_layout()
            self._save(fig3,'fig_value',j)

    def printout(self,j):
        """Print the database size, identified dynamics and value function of episode *j*."""
        print('Episode {}:'.format(j+1))
        if self.DB.db_overflow:
            print('Number of samples in database : ',self.DB.db_dim)
        else:
            print('Number of samples in database : ',self.DB.db_index)
        chosen_basis_label=self.Lib._Phi_lbl
        px=self.Lib._Phi_dim
        # Identified dynamics: column jj < px multiplies Phi_jj; columns
        # beyond that multiply Phi_(jj % px) * u_(jj // px).
        for ii in range(self.Lib.n):
            handle_str='x_dot({}) = '.format(ii+1)
            for jj in range(self.DB.Theta_dim):
                w=self.SysID.Weights[ii,jj]
                if w==0:
                    continue
                if jj<px:
                    handle_str=handle_str+(' {:7.3f}*{} '.format(w,chosen_basis_label[jj]))
                else:
                    handle_str=handle_str+(' {:7.3f}*{}*u{} '.format(w,chosen_basis_label[jj%px],jj//px))
            print(handle_str)
        # Value function V = Phi^T P Phi: off-diagonal terms are doubled,
        # which assumes P is symmetric.
        handle_str='V(x) = '
        for ii in range(px):
            for jj in range(ii+1):
                if (self.Ctrl.P[ii,jj]!=0):
                    if (ii==jj):
                        handle_str=handle_str+'{:7.3f}*{}^2'.format(self.Ctrl.P[ii,jj],chosen_basis_label[jj])
                    else:
                        handle_str=handle_str+'{:7.3f}*{}*{}'.format(2*self.Ctrl.P[ii,jj],chosen_basis_label[ii],chosen_basis_label[jj])
        print(handle_str)
        print("% of non-zero elements in P: {:4.1f} %".format(100*np.count_nonzero(self.Ctrl.P)/(px**2)))

class Control():
    """Controller built on a quadratic-in-features value function V(x) = Phi(x)^T P Phi(x).

    P is advanced by integrating :meth:`P_dot` with ``solve_ivp``. The control
    for input ``im`` uses the identified model columns
    ``Wt[:, Phi_dim*(im+1):Phi_dim*(im+2)]`` and only the diagonal entry
    ``R[im, im]`` of the input weight.

    Parameters
    ----------
    h : step length; P is integrated over ``k*h`` per call.
    Objective : provides ``Q``, ``R`` and ``gamma``.
    Lib : feature library (``_Phi_``, ``_pPhi_``, ``_Phi_dim``, ``n``, ``m``,
        ``chosen_bases``).
    P_init : initial (Phi_dim x Phi_dim) P matrix.
    """
    def __init__(self,h,Objective,Lib,P_init):
        self.Objective=Objective
        self.Lib=Lib
        self.Qb=np.zeros((self.Lib._Phi_dim,self.Lib._Phi_dim))
        # Q weights the linear 'x' features. Their offset in Phi is 1 when the
        # constant basis '1' is chosen.
        # NOTE(review): this assumes 'x' directly follows '1' (or is first) in chosen_bases.
        const = int('1' in Lib.chosen_bases)
        self.Qb[const:self.Lib.n+const,const:self.Lib.n+const]=self.Objective.Q
        self.P=P_init
        self.h=h
        self.update_P=1

    def integrate_P_dot(self,x,Wt,k,sparsify):
        """Integrate P over ``[0, k*h]`` at state *x*, optionally zeroing tiny entries.

        When *sparsify* is true, entries below 0.1% of max|P| are set to 0.
        """
        # Copies: Library returns reused buffers that later calls overwrite.
        self.Phi=np.array(self.Lib._Phi_(x),dtype=float)
        self.pPhi=np.array(self.Lib._pPhi_(x),dtype=float)
        dp_dt=partial(self.P_dot, x=x,Wt=Wt)
        sol=solve_ivp(dp_dt,[0, k*self.h], self.P.flatten(), method='RK45', t_eval=None,rtol=1e-6, atol=1e-6, dense_output=False, events=None, vectorized=False)
        self.P=sol.y[...,-1].reshape((self.Lib._Phi_dim,self.Lib._Phi_dim))
        if sparsify:
            absPk=np.absolute(self.P)
            self.P[absPk<(0.001*np.amax(absPk))]=0

    def P_dot(self,t,P,x,Wt):
        """Return dP/dt (flattened) for ``solve_ivp``; uses Phi/pPhi cached by integrate_P_dot."""
        px=self.Lib._Phi_dim
        P=P.reshape((px,px))
        W=Wt[:,:px]
        SIGMA=np.zeros((px,px))
        P_pPhi_W=np.matmul(np.matmul(P,self.pPhi),W)
        for im in range(self.Lib.m):
            P_pPhi_Wcj_Phi=np.matmul(np.matmul(P,self.pPhi),np.matmul(Wt[:,px*(im+1):px*(im+2)],self.Phi))
            SIGMA+=1/self.Objective.R[im,im]*np.outer(P_pPhi_Wcj_Phi,P_pPhi_Wcj_Phi)
        return (self.Qb-SIGMA+P_pPhi_W+P_pPhi_W.T-self.Objective.gamma*P).flatten()

    def calculate(self,x,Wt,u_lim):
        """Return the control at state *x*, clipped elementwise to ``[-u_lim, u_lim]``.

        Phi is evaluated at *x* itself, so the result does not depend on which
        state the library evaluated last. That Phi is kept for :meth:`value`.
        """
        px=self.Lib._Phi_dim
        self.Phi=np.array(self.Lib._Phi_(x),dtype=float)
        P_pPhi=np.matmul(self.P,self.Lib._pPhi_(x))
        u=np.zeros((self.Lib.m))
        for im in range(self.Lib.m):
            Wc=Wt[:,px*(im+1):px*(im+2)]
            u[im]=-(1/self.Objective.R[im,im])*np.matmul(self.Phi,np.matmul(P_pPhi,np.matmul(Wc,self.Phi)))
        return np.clip(u, -u_lim, u_lim)

    def value(self):
        """Return Phi^T P Phi for the Phi stored by the last integrate_P_dot/calculate call."""
        return np.matmul(np.matmul(self.Phi,self.P),self.Phi)

class Control_full_R(Control):
    """Variant of ``Control`` for a full (possibly non-diagonal) input weight R.

    ``R^{-1}`` is computed once and used both in the P dynamics and in the
    control law. The inherited diagonal-only law would be inconsistent with
    these dynamics whenever R has off-diagonal terms.
    """
    def __init__(self,h,Objective,Lib,P_init):
        Control.__init__(self, h, Objective, Lib, P_init)
        # R does not change during a run, so invert it once.
        self.inv_R = np.linalg.inv(self.Objective.R)

    def P_dot(self,t,P,x,Wt):
        """Return dP/dt (flattened) = Qb + P pPhi W + (P pPhi W)^T - gamma P - G R^{-1} G^T."""
        P=P.reshape((self.Lib._Phi_dim,self.Lib._Phi_dim))              # np x np
        W=Wt[:,:self.Lib._Phi_dim]                                      # nx x np
        P_pPhi = np.matmul(P,self.pPhi)                                 # np x nx
        P_pPhi_W=np.matmul(P_pPhi,W)                                    # np x np
        # G: column im is P pPhi W_c,im Phi for input im.
        P_pPhi_Wcj_Phi = np.zeros( ( self.Lib._Phi_dim, self.Lib.m) )   # np x nu
        for im in range(self.Lib.m):
            P_pPhi_Wcj_Phi[:,im] = np.matmul(np.matmul(P_pPhi, Wt[:,self.Lib._Phi_dim*(im+1):self.Lib._Phi_dim*(im+2)]), self.Phi )
        PRP_ = np.matmul( P_pPhi_Wcj_Phi, np.matmul( self.inv_R, P_pPhi_Wcj_Phi.T ) )
        return  (self.Qb+P_pPhi_W+P_pPhi_W.T-self.Objective.gamma*P - PRP_).flatten()

    def calculate(self,x,Wt,u_lim):
        """Return u = -R^{-1} g at state *x*, clipped elementwise to ``[-u_lim, u_lim]``.

        g[im] = Phi^T P pPhi W_c,im Phi. This matches the G R^{-1} G^T term
        used in :meth:`P_dot`. Phi is evaluated at *x* and kept for ``value()``.
        """
        px=self.Lib._Phi_dim
        # Copy: the library reuses its output buffer across calls.
        Phi=np.array(self.Lib._Phi_(x),dtype=float)
        P_pPhi=np.matmul(self.P,self.Lib._pPhi_(x))
        g=np.zeros((self.Lib.m))
        for im in range(self.Lib.m):
            g[im]=np.matmul(Phi,np.matmul(np.matmul(P_pPhi,Wt[:,px*(im+1):px*(im+2)]),Phi))
        self.Phi=Phi
        u=-np.matmul(self.inv_R,g)
        return np.clip(u, -u_lim, u_lim)

class new_SysID():
    """Online identifier of the model x_dot = Weights @ Theta(x, u).

    Theta stacks Phi(x) and then Phi(x)*u[im] for each input im. Its length is
    ``DB.Theta_dim``. Exactly one algorithm runs per :meth:`update`: the first
    truthy flag among 'SINDy', 'RLS', 'LS', 'GD' in *select_ID_algorithm*.
    Missing flags count as disabled.

    Parameters
    ----------
    select_ID_algorithm : mapping of algorithm name -> on/off flag.
    Database : sample store (``Theta_dim``, ``db_*`` buffers, ``read()``, ``save``).
    Weights : initial weight matrix; updated in place by the 'GD' branch.
    Lib : feature library providing ``_Phi_``, ``_Phi_dim``, ``n``, ``m``.
    lam : SINDy sparsity parameter, also used as the 'GD' step size.
    """
    def __init__(self,select_ID_algorithm,Database,Weights,Lib, lam = 0.01):
        self.ID_alg=select_ID_algorithm
        self.DB=Database
        self.Weights=Weights
        self.Lib=Lib
        self.Theta=np.zeros((self.DB.Theta_dim))
        # One (Theta_dim x Theta_dim) RLS covariance per state, initialised to 1000*I.
        self.P_rls = np.zeros((self.Lib.n,self.DB.Theta_dim,self.DB.Theta_dim))
        self.lam = lam
        for i in range(self.Lib.n):
            self.P_rls[i]=np.eye(self.DB.Theta_dim)*1000

    def _uses(self,name):
        """Return True when algorithm *name* is enabled; absent keys count as off."""
        return bool(self.ID_alg.get(name,0))

    def _build_theta(self,x,u):
        """Fill and return ``self.Theta`` = [Phi(x), Phi(x)*u[0], Phi(x)*u[1], ...]."""
        px=self.Lib._Phi_dim
        _Phi_=self.Lib._Phi_(x)
        self.Theta[:px]=_Phi_
        for im in range(self.Lib.m):
            self.Theta[px*(im+1):px*(im+2)]=_Phi_*u[im]
        return self.Theta

    def update(self,x,x_dot,u):
        """Run the selected identification step and return the new Weights."""
        if self._uses('SINDy'):
            lam= self.lam
            if self.DB.db_overflow:
                self.Weights=(utils.SINDy(self.DB.db_X_dot,self.DB.db_Theta,lam))
            else:
                # Only the first db_index columns hold samples so far.
                self.Weights=(utils.SINDy(self.DB.db_X_dot[:,:self.DB.db_index],self.DB.db_Theta[:,:self.DB.db_index],lam))
        elif self._uses('RLS'):
            self._build_theta(x,u)
            # NOTE(review): a single covariance P_rls[0] is shared by every state
            # row (all rows use the same regressor Theta). P_rls[1:] stay untouched.
            self.Weights,self.P_rls[0] =utils.identification_RLS(self.Weights, self.P_rls[0], x_dot, self.Theta)
        elif self._uses('LS'):
            Phi, X_dot = self.DB.read()
            self.Weights = np.dot( X_dot, np.linalg.pinv(Phi) )
        elif self._uses('GD'):
            lam = self.lam
            prediction = self.evaluate( x, u )   # also refreshes self.Theta
            self.Weights -= lam*np.outer( (prediction-x_dot), self.Theta )
        return self.Weights

    def evaluate(self,x,u):
        """Return the model prediction Weights @ Theta(x, u)."""
        return np.matmul(self.Weights,self._build_theta(x,u))

    def save(self):
        """Persist Weights and P_rls to the database output directory (RLS only)."""
        if self._uses('RLS'):
            if (self.DB.save):
                np.save(self.DB.output_dir_path+'/Weights.npy', self.Weights)
                np.save(self.DB.output_dir_path+'/P_rls.npy', self.P_rls)


class Library():
    """Candidate basis functions Phi(x), their Jacobians, and human-readable labels.

    ``chosen_bases`` is an ordered collection of keys into :attr:`lib`.
    Phi(x) is the concatenation of the chosen bases in that order, and
    ``_pPhi_(x)`` is the matching (Phi_dim x n) Jacobian.

    Parameters
    ----------
    chosen_bases : ordered collection of basis keys (e.g. ``['1', 'x', 'x^2']``).
    measure_dim : length ``n`` of the state vector x.
    m : number of control inputs (stored for other components; unused here).
    """
    def __init__(self,chosen_bases,measure_dim,m):
        self.chosen_bases=chosen_bases
        self.n=measure_dim
        self.m=m
        n_pairs=self.n*(self.n-1)//2
        # library of bases: each maps x to lib_dims[key] values
        self.lib={'1':lambda x:[1],
                  'x':lambda x:x,
                  'sqrtx':lambda x:np.sqrt( np.abs(x)),
                  'x^2':lambda x: x**2,
                  'x^3':lambda x: x**3,
                  'sinx':lambda x:np.sin(x),
                  '(sinx)^2':lambda x:np.sin(x)**2,
                  'cosx':lambda x:np.cos(x),
                  '(cosx)^2':lambda x:np.cos(x)**2,
                  'xx':lambda x:self.build_product(x)}
        # library of the corresponding gradients (lib_dims[key] x n blocks)
        self.plib={'1':lambda x:x*0,
                   'x':lambda x:np.diag(x**0),
                   'sqrtx':lambda x:np.diag(self._sqrt_abs_grad(x)),
                   'x^2':lambda x: np.diag(2*x),
                   'x^3':lambda x: np.diag(3*(x**2)),
                   'sinx':lambda x:np.diag(np.cos(x)),
                   '(sinx)^2':lambda x:np.diag(np.multiply(2*np.sin(x),np.cos(x))),
                   'cosx':lambda x:np.diag(-np.sin(x)),
                   '(cosx)^2':lambda x:np.diag(np.multiply(-2*np.cos(x),np.sin(x))),
                   'xx':lambda x:self.build_pproduct(x)}
        # library of the corresponding labels
        self.lib_labels={'1':['1'],
                         'x':self.build_lbl('x'),
                         'sqrtx':self.build_lbl('sqrtx'),
                         'x^2':self.build_lbl('x^2'),
                         'x^3':self.build_lbl('x^3'),
                         'sinx':self.build_lbl('sinx'),
                         '(sinx)^2':self.build_lbl('(sinx)^2'),
                         'cosx':self.build_lbl('cosx'),
                         '(cosx)^2':self.build_lbl('(cosx)^2'),
                         'xx':self.build_lbl_product('xx')}
        # number of features each basis contributes
        self.lib_dims={'1':1,
                       'x':self.n,
                       'sqrtx':self.n,
                       'x^2':self.n,
                       'x^3':self.n,
                       'sinx':self.n,
                       '(sinx)^2':self.n,
                       'cosx':self.n,
                       '(cosx)^2':self.n,
                       'xx':n_pairs}

        self._Phi_lbl=[]
        for key in self.chosen_bases:
            self._Phi_lbl.extend(self.lib_labels[key])
        self._Phi_dim=len(self._Phi_lbl)
        # Reserve the memory used to evaluate Phi and pPhi. _Phi_/_pPhi_ return
        # these same buffers, so callers that keep a result must copy it.
        self._Phi_res=np.zeros((self._Phi_dim))
        self._pPhi_res=np.zeros((self._Phi_dim,self.n))

    @staticmethod
    def _sqrt_abs_grad(x):
        """Elementwise d/dx sqrt(|x|) = sign(x) / (2*sqrt(|x|)), taken as 0 where x == 0."""
        x=np.asarray(x,dtype=float)
        return np.divide(0.5*np.sign(x),np.sqrt(np.abs(x)),out=np.zeros_like(x),where=(x!=0))

    def build_product(self,x):
        """Return x[i]*x[j] for all i < j, in row-major pair order."""
        function=np.zeros((self.n*(self.n-1)//2))
        ind=0
        for i in range(self.n):
            for j in range(i+1,self.n):
                function[ind]=x[i]*x[j]
                ind+=1
        return function

    def build_pproduct(self,x):
        """Return the Jacobian of :meth:`build_product` (pairs x n)."""
        g=np.zeros((self.n*(self.n-1)//2,self.n))
        ind=0
        for i in range(self.n):
            for j in range(i+1,self.n):
                g[ind][i]=x[j]
                g[ind][j]=x[i]
                ind+=1
        return g

    def build_lbl(self,func_name):
        """Label each component, inserting '(k)' after the first 'x' (e.g. 'sinx' -> 'sinx(1)')."""
        index=func_name.find('x')
        head,tail=func_name[:index+1],func_name[index+1:]
        return [head+'({})'.format(i+1)+tail for i in range(self.n)]

    def build_lbl_product(self,func_name):
        """Label each i<j pair by indexing the first two 'x's (e.g. 'xx' -> 'x(1)x(2)')."""
        index1=func_name.find('x')
        index2=func_name.find('x',index1+1)
        lbl=[]
        for i in range(self.n):
            for j in range(i+1,self.n):
                lbl.append(func_name[:index1+1]+'({})'.format(i+1)+func_name[index1+1:index2+1]+'({})'.format(j+1)+func_name[index2+1:])
        return lbl

    def _Phi_(self,x):
        """Evaluate Phi(x) into the shared buffer ``_Phi_res`` and return it."""
        i=0
        for key in self.chosen_bases:
            temp=int(self.lib_dims[key])
            self._Phi_res[i:i+temp]=self.lib[key](x)
            i+=temp
        return self._Phi_res

    def _pPhi_(self,x):
        """Evaluate the Jacobian dPhi/dx into the shared buffer ``_pPhi_res`` and return it."""
        i=0
        for key in self.chosen_bases:
            temp=int(self.lib_dims[key])
            self._pPhi_res[i:i+temp,:]=self.plib[key](x)
            i+=temp
        return self._pPhi_res
