In [22]:
import torch
import numpy
import torch.nn as nn

In [23]:
def non_linear_transition(x):
    x1 = x[0] **2
    x2 = x[1] **3
    x3 = x[2] **4
    x4 = torch.sin(x[3])
    reward = x.sum() #Linear Reward (can be quadratic)
    return (torch.tensor([x1, x2, x3, x4]), torch.tensor([reward]))
train = []
test = []
N = 1000
for i in range(N):
    x = torch.randn(4)
    y, r = non_linear_transition(x)
    if i % 10 == 0:
        test.append((x, y,r))
    else:
        train.append((x, y, r))

In [24]:
enc_dim = 16
state_encoder = torch.nn.Sequential(
    torch.nn.Linear(4, enc_dim//2),
    torch.nn.ReLU(),
    torch.nn.Linear(enc_dim//2, enc_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(enc_dim, enc_dim),
)

state_decoder = torch.nn.Sequential(
    torch.nn.Linear(enc_dim, enc_dim//2),
    torch.nn.ReLU(),
    torch.nn.Linear(enc_dim//2, 4)
)

A = torch.nn.Parameter(torch.randn(enc_dim, enc_dim))
Q = torch.nn.Parameter(torch.randn(enc_dim, enc_dim))
optimizer = torch.optim.Adam(list(state_encoder.parameters()) + list(state_decoder.parameters()) + [A] + [Q], lr=1e-3)

In [36]:
#train, test are lists of tuples of (x, y, r)
epochs = 10000
criterion = torch.nn.MSELoss()
for i in range(epochs):
    total_state_loss = 0
    total_reward_loss = 0
    for x, y, r in train:
        optimizer.zero_grad()
        xx = state_encoder(x)
        pred_state = A @ xx
        Q_psd = Q.T @ Q
        pred_reward = xx.T @ Q_psd @ xx
        dec_state = state_decoder(pred_state)   
        state_loss = criterion(dec_state, y)
        reward_loss = criterion(pred_reward, r)
        loss = state_loss + reward_loss
        loss.backward()
        optimizer.step()
        total_state_loss += state_loss.item()
        total_reward_loss += reward_loss.item()
    if i % 10 == 0:
        with torch.no_grad():
            total_test_state_loss = 0
            total_test_reward_loss = 0
            for x, y,r  in test:
                xx = state_encoder(x)
                pred_state = A @ xx
                Q_psd = Q.T @ Q
                pred_reward = xx.T @ Q_psd @ xx
                dec_state = state_decoder(pred_state)
                state_loss = criterion(dec_state, y)
                reward_loss = criterion(pred_reward, r)
                total_test_state_loss += state_loss.item()
                total_test_reward_loss += reward_loss.item()
            print(f"Epoch {i}: Train State Loss: {total_state_loss/N}, Train Reward Loss: {total_reward_loss/N}, Test State Loss: {total_test_state_loss/len(test)}, Test Reward Loss: {total_test_reward_loss/len(test)}")

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 0: Train State Loss: 5.842352452687919, Train Reward Loss: 31.381448214975993, Test State Loss: 4.450371186621487, Test Reward Loss: 7.446117175092513
Epoch 10: Train State Loss: 1.9988502228418366, Train Reward Loss: 2.228102112329458, Test State Loss: 1.6656546838767827, Test Reward Loss: 3.551352343352046
Epoch 20: Train State Loss: 1.1206838486506603, Train Reward Loss: 2.104244506692847, Test State Loss: 1.0893463388457894, Test Reward Loss: 3.4069696409760946
Epoch 30: Train State Loss: 0.7032914839796722, Train Reward Loss: 2.059865338200907, Test State Loss: 0.8657066050171852, Test Reward Loss: 3.368934662309475
Epoch 40: Train State Loss: 0.8388995628934354, Train Reward Loss: 2.034564663586121, Test State Loss: 1.5374239451624454, Test Reward Loss: 3.328597938924006
Epoch 50: Train State Loss: 0.8845783798075281, Train Reward Loss: 2.0142493329093702, Test State Loss: 2.9323431236017496, Test Reward Loss: 3.255285142699104
Epoch 60: Train State Loss: 0.7355981693575159

KeyboardInterrupt: 

In [37]:
t = train[0]
x, y, r = t
xx = state_encoder(x)
pred_state = A @ xx
dec_state = state_decoder(pred_state)
print(x)
print(xx)
print(y)
print(dec_state)
print("Test")
t = test[0]
x, y,r = t
xx = state_encoder(x)
pred_state = A @ xx
dec_state = state_decoder(pred_state)
print(x)
print(y)
print(dec_state)

tensor([-0.3892, -1.1420,  0.4347,  0.3648])
tensor([ 0.0351, -0.0070,  0.1394, -1.6561,  0.3626, -0.8260,  1.1684,  0.1074,
         0.7020, -0.0528, -0.3366,  0.7623, -0.2779, -0.1570,  0.3622,  0.0454],
       grad_fn=<ViewBackward0>)
tensor([ 0.1514, -1.4894,  0.0357,  0.3568])
tensor([ 0.1218, -1.5987,  0.1711,  0.1662], grad_fn=<ViewBackward0>)
Test
tensor([ 0.6592, -0.8172, -1.5943, -2.5283])
tensor([ 0.4345, -0.5458,  6.4610, -0.5755])
tensor([ 0.3680, -0.9164,  6.7711, -1.4532], grad_fn=<ViewBackward0>)
