In [1]:
import numpy as np
import torch
from torch import nn
import torch.autograd.functional as F
# Length of the state vector; also reused as the hidden-layer width when the
# dynamics and reward networks are instantiated below.
STATELEN = 10
# Length of the action vector; the networks take STATELEN + ACTLEN input features.
ACTLEN = 10
# Horizon argument handed to forward_step / backward in the planning-loop cell.
STEP_SIZE = 4

In [2]:
class Dynamics(nn.Module):
    """Feed-forward network with two ReLU hidden layers and a Tanh output layer."""

    def __init__(self, input_size, hidden_size, output_size):
        """Assemble the layer stack.

        Args:
            input_size: number of features in each input vector.
            hidden_size: width of both hidden layers.
            output_size: number of features in each output vector.
        """
        super().__init__()
        # Layers are instantiated in forward order so parameter initialisation
        # consumes the RNG in the same sequence as a literal nn.Sequential(...).
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh(),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, input_element):
        """Pass ``input_element`` through the stack; the Tanh keeps outputs in [-1, 1]."""
        return self.linear_relu_stack(input_element)

In [3]:
class reward(nn.Module):
    """Feed-forward network with Tanh after every linear layer, including the output."""

    def __init__(self, input_size, hidden_size, output_size = 1):
        """Assemble the layer stack.

        Args:
            input_size: number of features in each input vector.
            hidden_size: width of both hidden layers.
            output_size: number of output features (defaults to 1).
        """
        super().__init__()
        # Linear layers are created in forward order so parameter initialisation
        # consumes the RNG in the same sequence as a literal nn.Sequential(...).
        widths = [(input_size, hidden_size), (hidden_size, hidden_size), (hidden_size, output_size)]
        layers = []
        for fan_in, fan_out in widths:
            layers.append(nn.Linear(fan_in, fan_out))
            layers.append(nn.Tanh())
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, input_element):
        """Pass ``input_element`` through the stack; the Tanh keeps outputs in [-1, 1]."""
        return self.linear_relu_stack(input_element)

In [None]:
# Both networks consume a concatenated [state, action] vector and use STATELEN
# as their hidden width; the reward network emits a single value.
my_Dyna = Dynamics(
    input_size=STATELEN + ACTLEN,
    hidden_size=STATELEN,
    output_size=STATELEN,
)
my_reward = reward(
    input_size=STATELEN + ACTLEN,
    hidden_size=STATELEN,
    output_size=1,
)

In [None]:
class ilqr:
    """Iterative LQR planner over a dynamics model and a reward model.

    The planner keeps a nominal trajectory (``S``, ``A``, ``R``) and a set of
    time-varying linear feedback gains (``K_arr``, ``k_arr``).  Each iteration
    rolls the trajectory forward under the current gains (``_forward``) and
    then recomputes the gains from second-order expansions of the reward and
    first-order expansions of the dynamics around that trajectory (``_backward``).
    """

    def __init__(self, ts, dyn, re, sl, al):
        """
        Args:
            ts: number of time steps in the planning horizon
            dyn: dynamics callable; maps a concatenated [state, action] tensor
                to the next state
            re: reward callable; maps a concatenated [state, action] tensor to
                a single-element tensor
            sl: state length
            al: action length
        """
        self.ts = ts
        self.dyn = dyn
        self.re = re
        self.sl = sl
        self.al = al

        # Nominal trajectory, randomly initialised; S[0] acts as the start state.
        self.S = torch.rand((self.ts, 1, self.sl))
        self.A = torch.rand((self.ts, 1, self.al))
        self.R = torch.empty((self.ts, 1, 1))
        # Feedback gains start at zero, so the first rollout replays the initial actions.
        self.K_arr = torch.zeros(self.ts, self.al, self.sl)
        self.k_arr = torch.zeros(self.ts, 1, self.al)
        # Stop flag read by fit(); set to 1 by _backward when the primary
        # regularised inverse of Q_uu fails.
        # NOTE(review): the name suggests "converged", but it is only ever set
        # on an inversion failure -- confirm the intended meaning.
        self.ifconv = 0

    def _forward(self):
        """Roll the trajectory forward under the current gains.

        Overwrites ``self.S``, ``self.A`` and ``self.R`` in place.  The action
        at step ``i`` is ``(s - S_prev[i]) @ K[i].T + k[i] + A_prev[i]``, where
        ``S_prev`` / ``A_prev`` are the trajectory from before this call.
        """
        # Snapshot the previous trajectory.  Plain references would alias the
        # tensors being overwritten, making ``s - prev_S[i]`` identically zero.
        prev_S = self.S.clone()
        prev_A = self.A.clone()
        s = prev_S[0]

        # The rollout only needs values; keeping autograd off stops the stored
        # trajectory from accumulating a graph across iterations.
        with torch.no_grad():
            for i in range(self.ts):
                self.S[i] = s
                a = (torch.matmul(s - prev_S[i], self.K_arr[i].T)
                     + self.k_arr[i] + prev_A[i])
                self.A[i] = a
                # sa_in shape = [1, state_size + action_size]
                sa_in = torch.cat((s, a), dim=1)
                self.R[i] = self.re(sa_in)
                # next state shape = [1, state_size]
                s = self.dyn(sa_in)

    def _backward(self):
        """Recompute ``K_arr`` / ``k_arr`` by a backward pass over the horizon."""
        self.K_arr = torch.zeros(self.ts, self.al, self.sl)
        self.k_arr = torch.zeros(self.ts, 1, self.al)

        # Value-function terms of the following time step; None at the final step.
        V_t = None
        v_t = None

        for i in reversed(range(self.ts)):
            sa_in = torch.cat((self.S[i], self.A[i]), dim=1).detach().reshape(-1)
            # C_t shape = [state+action, state+action]
            C_t = F.hessian(self.re, sa_in)
            # F_t shape = [state, state+action]
            F_t = F.jacobian(self.dyn, sa_in)
            # c_t shape = [1, state+action]
            c_t = F.jacobian(self.re, sa_in)

            if V_t is None:
                # Last time step: there is no future value to propagate.
                Q_t = C_t
                q_t = c_t
            else:
                Q_t = C_t + torch.matmul(torch.matmul(F_t.T, V_t), F_t)
                q_t = c_t + torch.matmul(v_t, F_t)

            # Partition into state (x) and action (u) blocks.
            Q_xx = Q_t[:self.sl, :self.sl]
            Q_xu = Q_t[:self.sl, self.sl:]
            Q_uu = Q_t[self.sl:, self.sl:]
            q_x = q_t[:, :self.sl]
            q_u = q_t[:, self.sl:]

            eye = torch.eye(self.al, dtype=Q_uu.dtype, device=Q_uu.device)
            try:
                # Regularisation term: Q_uu is shifted by -I before inverting.
                invQuu = torch.linalg.inv(Q_uu - eye)
            except RuntimeError:
                # torch.linalg.inv signals a non-invertible matrix with a
                # RuntimeError (LinAlgError subclasses it in recent releases).
                # NOTE(review): the fallback shifts in the opposite direction
                # (+0.01*I); kept as in the original -- confirm the sign.
                invQuu = torch.linalg.inv(Q_uu + eye * 0.01)
                self.ifconv = 1

            # K_t shape = [actlen, statelen]
            K_t = -torch.matmul(invQuu, Q_xu.T)
            # k_t shape = [1, actlen]
            k_t = -torch.matmul(q_u, invQuu)

            # V_t shape = [statelen, statelen]
            V_t = (Q_xx + torch.matmul(Q_xu, K_t) * 2
                   + torch.matmul(torch.matmul(K_t.T, Q_uu), K_t))

            # v_t shape = [1, statelen]
            v_t = (q_x + torch.matmul(k_t, Q_xu.T)
                   + torch.matmul(q_u, K_t)
                   + torch.matmul(k_t, torch.matmul(Q_uu, K_t)))

            self.K_arr[i] = K_t
            self.k_arr[i] = k_t

    def fit(self, max_iter=100):
        """Alternate forward and backward passes and return the planned actions.

        Args:
            max_iter: maximum number of forward/backward iterations (default 100).

        Returns:
            ``self.A``, the action sequence of shape [ts, 1, al].
        """
        iteration = 0
        while self.ifconv != 1 and iteration < max_iter:
            iteration += 1
            self._forward()
            self._backward()

        return self.A

In [6]:

# Planning loop: alternate a forward pass and a backward pass for at most
# 10000 iterations, printing the first action every 100 iterations.
# NOTE(review): forward_step, backward, state_col, act_col, K_arr, k_arr and
# ifconv are not defined in any visible cell, so this cell appears to rely on
# leftover kernel state -- confirm, or port it to the ilqr class, before a
# Restart & Run All.
i = 0
while i < 10000:
    
    state_col, act_col ,reward_col = forward_step(state_col, act_col, K_arr, k_arr, STEP_SIZE)
    K_arr, k_arr, ifconv = backward(state_col, act_col, reward_col, STEP_SIZE, ifconv)
    if i%100 == 0:
        # Progress report: iteration counter and the action planned for step 0.
        print(i)
        print(act_col[0])
    # backward() hands back the stop flag; 1 ends the loop early.
    if ifconv == 1:
        break
    
    i = i + 1


0
tensor([[0.8255, 0.8293, 0.3522, 0.4119, 0.4937, 0.9669, 0.4103, 0.9699, 0.7606,
         0.0503]], grad_fn=<SelectBackward0>)
100
tensor([[-2.0106, -2.7740,  0.6842,  2.0354,  2.4880, -0.1632,  2.4594, -1.8509,
          0.1161, -3.2119]], grad_fn=<SelectBackward0>)
200
tensor([[-2.6977, -3.2965,  0.3127,  2.6208,  3.1344, -0.9797,  3.6866, -2.7824,
          0.9371, -4.5258]], grad_fn=<SelectBackward0>)
300
tensor([[-3.1273, -3.5203,  0.0622,  2.9640,  3.4330, -1.5143,  4.4964, -3.3415,
          1.5872, -5.4317]], grad_fn=<SelectBackward0>)
400
tensor([[-3.4153, -3.6490, -0.0937,  3.1752,  3.6010, -1.8778,  5.1152, -3.7635,
          2.0480, -6.1381]], grad_fn=<SelectBackward0>)
500
tensor([[-3.6205, -3.7296, -0.1944,  3.3124,  3.7031, -2.1374,  5.6245, -4.0997,
          2.3910, -6.7335]], grad_fn=<SelectBackward0>)
600
tensor([[-3.7764, -3.7823, -0.2620,  3.4051,  3.7674, -2.3316,  6.0594, -4.3761,
          2.6594, -7.2538]], grad_fn=<SelectBackward0>)
700
tensor([[-3.8994, -3.

KeyboardInterrupt: 

In [None]:
# Debug helper: uncomment to inspect the reward network's parameters.
#for param in my_reward.parameters():
#    print(param)