In [109]:
from IPython.display import clear_output
from cartpole import format_state, sim, ConvNet4 as ConvNet, FCNet0 as FCNet, CartPoleTorch
from gym.envs.classic_control.cartpole import CartPoleEnv
from gym.utils import seeding
from torch import nn
from torch import optim
import numpy as np
import torch

In [110]:
def format_array(array, float_format='.4f'):
    """Render an array as a string with every floating-point element formatted uniformly.

    Args:
        array: anything ``np.asarray`` accepts (ndarray, list, CPU tensor, ...).
        float_format: format spec passed to ``format()`` for each float element.

    Returns:
        The ``np.array2string`` rendering, e.g. ``'[0.0000 1.2346]'``.
    """
    # np.array2string needs a real ndarray (it reads `.size`), so coerce first.
    array = np.asarray(array)
    # 'float_kind' covers both the 'float' and 'longfloat' formatter slots.
    formatter = {'float_kind': lambda v: format(v, float_format)}
    return np.array2string(array, formatter=formatter)

In [111]:
# Load the trained fully-connected policy and freeze it: the optimizer below
# only updates the initial state, never the network weights.
CHECKPOINT_PATH = 'FC-150-v3.pt'
# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only machine;
# later cells call .numpy() on policy outputs, which requires CPU tensors.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
checkpoint = torch.load(CHECKPOINT_PATH, map_location='cpu')
policy = FCNet().eval()
policy.load_state_dict(checkpoint['policy'])

# Freeze all parameters
for param in policy.parameters():
    param.requires_grad_(False)

In [108]:
# Watch the loaded policy act in the gym environment. `env_raw` is created here
# because the cell that otherwise defines it comes later in the notebook; without
# this line the cell raises NameError on a fresh kernel (Restart & Run All).
env_raw = CartPoleEnv()
sim(policy, env_raw, animation=True, terminal=True)

In [112]:
# Two environments: the gym implementation (handed to `sim` for playback) and the
# torch port (stepped with gradients during the initial-state search).
env_raw, env_torch = CartPoleEnv(), CartPoleTorch()
# Fixed seed so the sampled states below are reproducible.
np_random, _ = seeding.np_random(1)

In [113]:
# Draw one 4-component state uniformly from [-0.05, 0.05) and show it readably.
format_state(
    np_random.uniform(-0.05, 0.05, size=4)
)

'(x = +0.0307, dx = +0.0015, θ =  -1.7698, dθ = -0.0313)'

In [116]:
# Start the search from the all-zero state. `reset` is still called so the tensor
# comes out in whatever form CartPoleTorch produces (and, presumably, np_random
# advances as before); its values are then overwritten with zeros and the tensor
# is marked as the leaf we differentiate with respect to.
initial_state = env_torch.reset(np_random).fill_(0).requires_grad_()
print(format_state(initial_state))

# Plain SGD over the initial state itself (the policy weights are frozen).
optimizer = optim.SGD([initial_state], lr=0.005)

(x = +0.0000, dx = +0.0000, θ =  +0.0000, dθ = +0.0000)


In [130]:
# Gradient descent on the *initial state*: roll out the frozen policy from
# `initial_state` in the differentiable env, always taking action 0 ("L"), and
# minimise the negative log-likelihood of that action over the whole episode.
N_ITERATIONS = 1    # number of optimisation steps to run in this cell
MAX_STEPS = 500     # safety cap: a rollout that never terminates must not hang the kernel
FROZEN_DIMS = [2]   # initial-state components whose gradient is discarded (index 2 prints as θ)

for _ in range(N_ITERATIONS):
    clear_output(wait=True)

    state = initial_state
    prob_batch = []
    done = False
    t = 0

    print('INITIAL STATE = {}'.format(format_state(state)))
    while not done and t < MAX_STEPS:
        act_prob = policy(state.unsqueeze(0)).squeeze()
        prob_L, prob_R = act_prob.detach().numpy()
        print('#{:2d}: L = {:.4f}, R = {:.4f}, {}'.format(t, prob_L, prob_R, format_state(state)))
        # Keep the differentiable probability of action 0 ...
        prob_batch.append(act_prob[0])
        # ... and always take action 0, so gradients flow back through every step.
        state, done = env_torch.step(state, 0)
        t += 1
    if not done:
        print('WARNING: rollout stopped at MAX_STEPS = {} without terminating'.format(MAX_STEPS))
    print('END STATE     = {}'.format(format_state(state)))

    prob_batch = torch.stack(prob_batch)
    # NOTE(review): when the policy is saturated (prob ≈ 1) both the loss and its
    # gradient vanish, as in the recorded run (grad all zeros, loss 0).
    loss = torch.sum(-torch.log(prob_batch))

    optimizer.zero_grad()
    loss.backward()
    # Keep the selected components fixed by discarding their gradient before the step.
    initial_state.grad[FROZEN_DIMS] = 0.0
    print('grad', format_array(initial_state.grad.numpy()))
    optimizer.step()

    print('loss = {:.7f}'.format(loss.item()))
    print(format_state(initial_state))

INITIAL STATE = (x = -0.2751, dx = -0.8771, θ =  +0.0000, dθ = -2.3757)
# 0: L = 1.0000, R = 0.0000, (x = -0.2751, dx = -0.8771, θ =  +0.0000, dθ = -2.3757)
# 1: L = 1.0000, R = 0.0000, (x = -0.2926, dx = -1.0722, θ =  -2.7224, dθ = -2.0830)
# 2: L = 1.0000, R = 0.0000, (x = -0.3141, dx = -1.2668, θ =  -5.1093, dθ = -1.8054)
# 3: L = 1.0000, R = 0.0000, (x = -0.3394, dx = -1.4608, θ =  -7.1782, dθ = -1.5417)
# 4: L = 1.0000, R = 0.0000, (x = -0.3686, dx = -1.6542, θ =  -8.9448, dθ = -1.2906)
# 5: L = 1.0000, R = 0.0000, (x = -0.4017, dx = -1.8471, θ = -10.4238, dθ = -1.0506)
# 6: L = 1.0000, R = 0.0000, (x = -0.4387, dx = -2.0394, θ = -11.6276, dθ = -0.8201)
END STATE     = (x = -0.4795, dx = -2.2312, θ = -12.5674, dθ = -0.5975)
grad [0.0000 0.0000 0.0000 0.0000]
loss = 0.0000000
(x = -0.2751, dx = -0.8771, θ =  +0.0000, dθ = -2.3757)


In [124]:
# Show the current initial state (optimizer.step() in the loop above updates it in place).
format_state(initial_state)

'(x = -0.2751, dx = -0.8771, θ =  +0.0000, dθ = -2.3757)'

In [104]:
# Replay the optimised initial state in the gym env. The tensor is detached and
# converted to NumPy because `env_raw` is the gym implementation, not the torch
# port; limit=100 presumably caps the number of simulated steps.
sim(
    policy,
    env_raw,
    limit=100,
    initial_state=initial_state.detach().numpy(),
    animation=True,
    sleep=None,
    live_state=False,
    terminal=False,
)