In [1]:
import numpy as np
import torch
from torch import nn
import torch.autograd.functional as F
STATELEN = 10
ACTLEN = 10
STEP_SIZE = 4

In [2]:
class Dynamics(nn.Module):

    def __init__(self, input_size, hidden_size, output_size):
        super(Dynamics, self).__init__()
        # self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()
        )
    def forward(self, input_element):
        output = self.linear_relu_stack(input_element)
        return output

In [3]:
class reward(nn.Module):

    def __init__(self, input_size, hidden_size, output_size = 1):
        super(reward, self).__init__()
        # self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, hidden_size),
            nn.Tanh(),
            nn.Linear(hidden_size, output_size),
            nn.Tanh()
        )

    def forward(self, input_element):
        output = self.linear_relu_stack(input_element)
        return output

In [4]:
def forward_step(state_col, actions, K_arr, k_arr, time_step):
    state = state_col[0]
    state_collection = torch.zeros((time_step, 1, STATELEN))
    reward_collection = torch.zeros((time_step, 1, 1))
    act_collection = torch.zeros((time_step, 1, ACTLEN))
    
    i = 0
    while i < time_step:
        state_collection[i] = state
        actions[i] = (torch.matmul(state_col[i] - state,torch.transpose((K_arr[i]),0,1)) + 
                      k_arr[i] + actions[i]
                     )
        
        act_collection[i] = actions[i]
        sa_input = torch.cat((state,actions[i]),dim = 1)
        #sa_input shape = [1,state_size + action_size]
        state = my_Dyna(sa_input)
        #state shape = [1,state_size]
        reward = my_reward(sa_input)
        reward_collection[i] = reward
        i = i + 1
    
    return state_collection, act_collection, reward_collection

def backward(state_col,action_col,reward_col,step_size,ifconv):

    i = step_size -1
    K_arr = torch.zeros(step_size,ACTLEN,STATELEN)
    k_arr = torch.zeros(step_size,1,ACTLEN)
    
    while i > -1:
        temp_input = torch.cat((state_col[i],act_col[i]),dim = 1)
        C_t = F.hessian(my_reward, temp_input.view(-1))
        #shape = [state+action, state+action]
        #print(torch.sum(C_t))
        F_t = F.jacobian(my_Dyna, temp_input.view(-1))
        #shape = [state, state+action]
        #print(torch.sum(F_t))
        c_t = F.jacobian(my_reward, temp_input.view(-1))
        #shape = [1, state+action]
        #print(torch.sum(c_t))
        
        if i == step_size-1:
            Q_t = C_t
            q_t = c_t
        else:
            Q_t = C_t + torch.matmul(torch.matmul(torch.transpose(F_t,0,1),V_t),F_t)
            q_t = c_t + torch.matmul(v_t,F_t)
            
        Q_pre1 = torch.split(Q_t,[STATELEN,ACTLEN])[0]
        Q_pre2 = torch.split(Q_t,[STATELEN,ACTLEN])[1]
        Q_xx = torch.split(Q_pre1,[STATELEN,ACTLEN],dim = 1)[0]
        Q_xu = torch.split(Q_pre1,[STATELEN,ACTLEN],dim = 1)[1]
        Q_uu = torch.split(Q_pre2,[STATELEN,ACTLEN],dim = 1)[1]
        
        try:
            invQuu = torch.linalg.inv(Q_uu - torch.eye(ACTLEN)) #regularize term
        except:
            invQuu = torch.linalg.inv(Q_uu + torch.eye(ACTLEN)*0.01)
            ifconv = 1
            
        K_t = -torch.matmul(invQuu, torch.transpose(Q_xu,0,1))
        transK_t = torch.transpose(K_t,0,1)
        #K_t shape = [actlen, statelen]
        q_t = torch.split(q_t, [STATELEN, ACTLEN], dim = 1)
        k_t = -torch.matmul(q_t[1],invQuu)
        #k_t shape = [1,actlen]
        
        V_t = (Q_xx + torch.matmul(Q_xu,K_t)*2 + 
               torch.matmul(torch.matmul(transK_t,Q_uu),K_t)
              )
        #V_t shape = [statelen, statelen]

        v_t = (q_t[0] + torch.matmul(k_t,torch.transpose(Q_xu,0,1)) + 
               torch.matmul(q_t[1],K_t) + 
               torch.matmul(k_t, torch.matmul(Q_uu,K_t)) 
              )
        #v_t shape = [1, statelen]
        K_arr[i] = K_t
        k_arr[i] = k_t
        i = i - 1
        
    return K_arr, k_arr, ifconv

In [5]:
my_Dyna = Dynamics(STATELEN + ACTLEN, STATELEN, STATELEN)
my_reward = reward(STATELEN + ACTLEN, STATELEN , 1)

In [7]:
state_col = torch.zeros((STEP_SIZE, 1, STATELEN))
state_col[0] = torch.rand((1, STATELEN))
act_col = torch.rand((STEP_SIZE, 1, ACTLEN))
K_arr = torch.zeros(STEP_SIZE, ACTLEN, STATELEN)
k_arr = torch.zeros(STEP_SIZE, 1, ACTLEN)
ifconv = 0

i = 0
while i < 10000:
    
    state_col, act_col ,reward_col = forward_step(state_col, act_col, K_arr, k_arr, STEP_SIZE)
    K_arr, k_arr, ifconv = backward(state_col, act_col, reward_col, STEP_SIZE, ifconv)
    if i%100 == 0:
        print(i)
        print(act_col[0])
    if ifconv == 1:
        break
    
    i = i + 1


0
tensor([[0.1890, 0.3997, 0.9167, 0.2740, 0.3681, 0.8574, 0.2950, 0.1964, 0.5738,
         0.8044]], grad_fn=<SelectBackward0>)
100
tensor([[-0.2412, -2.0254,  1.1644,  2.2280,  0.1140,  2.0930,  2.4859,  2.0104,
          0.8614,  1.3154]], grad_fn=<SelectBackward0>)
200
tensor([[-0.4489, -2.6839,  1.4116,  3.1717, -0.1903,  2.4350,  3.1772,  2.6563,
          1.1909,  1.4209]], grad_fn=<SelectBackward0>)
300
tensor([[-0.5847, -3.0824,  1.5803,  3.8811, -0.4266,  2.6149,  3.6221,  3.0756,
          1.4385,  1.4552]], grad_fn=<SelectBackward0>)
400
tensor([[-0.6831, -3.3743,  1.6969,  4.4486, -0.6106,  2.7421,  3.9524,  3.3994,
          1.6271,  1.4666]], grad_fn=<SelectBackward0>)
500
tensor([[-0.7562, -3.6069,  1.7827,  4.9185, -0.7597,  2.8457,  4.2133,  3.6717,
          1.7731,  1.4678]], grad_fn=<SelectBackward0>)
600
tensor([[-0.8129, -3.8025,  1.8483,  5.3167, -0.8847,  2.9365,  4.4272,  3.9101,
          1.8872,  1.4648]], grad_fn=<SelectBackward0>)
700
tensor([[-0.8589, -3.

5800
tensor([[-1.8441, -7.1688,  1.6594, 10.5816, -3.5021,  4.6822,  6.3957,  8.3156,
          1.8236,  1.1886]], grad_fn=<SelectBackward0>)
5900
tensor([[-1.8581, -7.2015,  1.6511, 10.6265, -3.5345,  4.6947,  6.4062,  8.3574,
          1.8130,  1.1803]], grad_fn=<SelectBackward0>)
6000
tensor([[-1.8718, -7.2337,  1.6432, 10.6709, -3.5663,  4.7069,  6.4167,  8.3984,
          1.8028,  1.1718]], grad_fn=<SelectBackward0>)
6100
tensor([[-1.8853, -7.2652,  1.6356, 10.7148, -3.5973,  4.7188,  6.4273,  8.4388,
          1.7928,  1.1631]], grad_fn=<SelectBackward0>)
6200
tensor([[-1.8985, -7.2960,  1.6283, 10.7582, -3.6277,  4.7304,  6.4379,  8.4784,
          1.7831,  1.1542]], grad_fn=<SelectBackward0>)
6300
tensor([[-1.9114, -7.3263,  1.6213, 10.8011, -3.6574,  4.7418,  6.4486,  8.5173,
          1.7738,  1.1451]], grad_fn=<SelectBackward0>)
6400
tensor([[-1.9240, -7.3560,  1.6145, 10.8436, -3.6865,  4.7529,  6.4593,  8.5556,
          1.7646,  1.1359]], grad_fn=<SelectBackward0>)
6500
t

In [None]:
#for param in rew.parameters():
#    print(param)