In [1]:
import numpy as np
import torch
from torch import nn
import torch.autograd.functional as F
# Length of the state vector; used below to size the networks' inputs/outputs.
STATELEN = 10
# Length of the action vector; the networks below take STATELEN + ACTLEN inputs.
ACTLEN = 10
# NOTE(review): not referenced by any cell visible in this notebook -- confirm
# whether it is still needed.
STEP_SIZE = 4
# The iLQR implementation below is based on:
# https://homes.cs.washington.edu/~todorov/papers/TassaIROS12.pdf

In [2]:
class Dynamics(nn.Module):
    """Three-layer MLP: Linear -> ReLU -> Linear -> ReLU -> Linear -> Tanh.

    Because of the final Tanh, every output component lies in [-1, 1].

    Args:
        input_size: number of features in the last dimension of the input.
        hidden_size: width of both hidden layers.
        output_size: number of features produced per input row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (0, 2, 4 hold the Linear layers) stay stable.
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh(),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, input_element):
        """Run ``input_element`` through the layer stack and return the result."""
        return self.linear_relu_stack(input_element)

In [3]:
class reward(nn.Module):
    """Three-layer MLP with Tanh after every Linear layer.

    Because of the final Tanh, every output component lies in [-1, 1].

    Args:
        input_size: number of features in the last dimension of the input.
        hidden_size: width of both hidden layers.
        output_size: number of features produced per input row (default 1).
    """

    def __init__(self, input_size, hidden_size, output_size = 1):
        super().__init__()
        # Layers are created in the same order as they are applied, so the
        # Sequential indices (0, 2, 4 hold the Linear layers) stay stable.
        layers = [
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh(),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, input_element):
        """Run ``input_element`` through the layer stack and return the result."""
        return self.linear_relu_stack(input_element)

In [None]:
# Both networks take a concatenated (state, action) vector as input and use
# STATELEN hidden units; the first outputs STATELEN values, the second one.
my_Dyna = Dynamics(
    input_size=STATELEN + ACTLEN,
    hidden_size=STATELEN,
    output_size=STATELEN,
)
my_reward = reward(
    input_size=STATELEN + ACTLEN,
    hidden_size=STATELEN,
    output_size=1,
)

In [None]:
class ilqr:
    """Iterative LQR over a differentiable dynamics model and reward model.

    The solver keeps a nominal trajectory (``S``, ``A``) of length ``ts`` and
    alternates between a forward rollout (``_forward``) that applies the
    current feedback controller and a backward pass (``_backward``) that
    recomputes the controller gains from local derivatives of ``dyn`` and
    ``re``.

    Attributes:
        S: states, shape ``(ts, 1, sl)``; randomly initialised.
        A: actions, shape ``(ts, 1, al)``; randomly initialised.
        R: rewards recorded by the last rollout, shape ``(ts, 1, 1)``.
        K_arr: feedback gains, shape ``(ts, al, sl)``.
        k_arr: feed-forward action offsets, shape ``(ts, 1, al)``.
        ifconv: stop flag for ``fit``; set to 1 when the backward pass had to
            fall back to the alternative regularisation of ``Q_uu``.
    """

    def __init__(self, ts, dyn, re, sl, al):
        """
        Args:
            ts: number of time steps in the trajectory
            dyn: dynamics model; maps a (state, action) vector to the next state
            re: reward model; maps a (state, action) vector to a single value
            sl: state length
            al: action length
        """
        self.ts = ts
        self.dyn = dyn
        self.re = re
        self.sl = sl
        self.al = al

        self.S = torch.rand((self.ts, 1, self.sl))
        self.A = torch.rand((self.ts, 1, self.al))
        self.R = torch.empty((self.ts, 1, 1))
        self.K_arr = torch.zeros(self.ts, self.al, self.sl)
        self.k_arr = torch.zeros(self.ts, 1, self.al)
        self.ifconv = 0

    def _forward(self):
        """Roll the trajectory forward from ``S[0]`` under the current controller.

        At each step the action is
        ``(s - s_nominal) @ K^T + k + a_nominal`` where the nominal values are
        the trajectory as it was before this call. ``S``, ``A`` and ``R`` are
        overwritten in place with the new rollout.
        """
        # Snapshot the nominal trajectory. Without the copy these would alias
        # self.S / self.A, and because self.S[i] is overwritten before the
        # deviation is computed, (s - nominal) would always be zero and the
        # feedback gains would never act.
        nominal_S = self.S.clone().detach()
        nominal_A = self.A.clone().detach()
        s = nominal_S[0].clone()

        # The rollout needs values only; gradients are computed separately in
        # the backward pass, so avoid growing an autograd graph across calls.
        with torch.no_grad():
            for i in range(self.ts):
                self.S[i] = s
                feedback = torch.matmul(
                    s - nominal_S[i], torch.transpose(self.K_arr[i], 0, 1)
                )
                a = feedback + self.k_arr[i] + nominal_A[i]
                self.A[i] = a

                sa_in = torch.cat((s, a), dim=1)
                # sa_in shape = [1, state_size + action_size]

                s = self.dyn(sa_in)
                # state shape = [1, state_size]

                self.R[i] = self.re(sa_in)

    def _backward(self):
        """Recompute ``K_arr`` and ``k_arr`` by sweeping from the last step to the first.

        Uses the Hessian and Jacobian of ``re`` and the Jacobian of ``dyn`` at
        each (state, action) pair of the current trajectory. Sets
        ``self.ifconv = 1`` if inverting the regularised ``Q_uu`` fails and
        the fallback regularisation is used instead.
        """
        self.K_arr = torch.zeros(self.ts, self.al, self.sl)
        self.k_arr = torch.zeros(self.ts, 1, self.al)

        # Value-function terms of the following time step; None at the final step.
        V_t = None
        v_t = None

        for i in reversed(range(self.ts)):
            sa_in = torch.cat((self.S[i], self.A[i]), dim=1)
            flat_in = sa_in.view(-1)

            C_t = F.hessian(self.re, flat_in)
            # shape = [state+action, state+action]
            F_t = F.jacobian(self.dyn, flat_in)
            transF_t = torch.transpose(F_t, 0, 1)
            # F_t shape = [state, state+action]
            c_t = F.jacobian(self.re, flat_in)
            # shape = [1, state+action]

            if V_t is None:
                # Final time step: no future value to propagate.
                Q_t = C_t
                q_t = c_t
            else:
                Q_t = C_t + torch.matmul(torch.matmul(transF_t, V_t), F_t)
                # eq 5[c~e]
                q_t = c_t + torch.matmul(v_t, F_t)
                # eq 5[a~b]

            # Partition Q_t into its state/action blocks.
            Q_rows_x, Q_rows_u = torch.split(Q_t, [self.sl, self.al])
            Q_xx, Q_xu = torch.split(Q_rows_x, [self.sl, self.al], dim=1)
            Q_ux, Q_uu = torch.split(Q_rows_u, [self.sl, self.al], dim=1)

            Q_x, Q_u = torch.split(q_t, [self.sl, self.al], dim=1)

            eye = torch.eye(self.al, dtype=Q_uu.dtype, device=Q_uu.device)
            try:
                invQuu = torch.linalg.inv(Q_uu - eye)  # regularize term
                # eq [9]
            except RuntimeError:
                # torch.linalg.inv raises a RuntimeError (LinAlgError in newer
                # releases) when the matrix is not invertible. Retry with the
                # alternative regularisation and flag it so fit() stops.
                invQuu = torch.linalg.inv(Q_uu + eye * 0.01)
                self.ifconv = 1

            K_t = -torch.matmul(invQuu, Q_ux)
            transK_t = torch.transpose(K_t, 0, 1)
            # K_t shape = [actlen, statelen]

            k_t = -torch.matmul(Q_u, invQuu)
            # k_t shape = [1, actlen]

            V_t = (Q_xx + torch.matmul(Q_xu, K_t) +
                   torch.matmul(transK_t, Q_ux) +
                   torch.matmul(torch.matmul(transK_t, Q_uu), K_t)
                  )
            # eq 11c
            # V_t shape = [statelen, statelen]

            v_t = (Q_x + torch.matmul(k_t, Q_ux) +
                   torch.matmul(Q_u, K_t) +
                   torch.matmul(k_t, torch.matmul(Q_uu, K_t))
                  )
            # eq 11b
            # v_t shape = [1, statelen]

            self.K_arr[i] = K_t
            self.k_arr[i] = k_t

    def fit(self, max_iter=100):
        """Alternate forward and backward passes and return the action sequence.

        Iterates until ``self.ifconv`` becomes 1 or ``max_iter`` iterations
        have run. ``ifconv`` is not reset here, so once it has been set a
        later call returns without iterating.

        Args:
            max_iter: upper bound on forward/backward iterations (default 100).

        Returns:
            ``self.A``, the action tensor of shape ``(ts, 1, al)``.
        """
        i = 0
        while self.ifconv != 1 and i < max_iter:
            i = i + 1
            self._forward()
            self._backward()

        return self.A

In [None]:
#for param in rew.parameters():
#    print(param)