In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
%matplotlib notebook
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [2]:
w,h=10,10

In [3]:
def delete_index(T, index):
    return torch.cat([T[:i], T[i+1:]])

In [4]:
def standardize_state(state):
    state = np.copy(state)
    state[0] = (state[0] - ((w-1)/2)) / ((w-1)/2)
    state[1] = (state[1] - ((h-1)/2)) / ((h-1)/2)
    return state
standardize_state([4., 9.]), standardize_state([0., 5.]), standardize_state([4.5, 4.5])

(array([-0.11111111,  1.        ]),
 array([-1.        ,  0.11111111]),
 array([0., 0.]))

In [5]:
def readable_state(state):
    state = np.copy(state)
    state[0] = state[0] * ((w-1)/2) + ((w-1)/2)
    state[1] = state[1] * ((h-1)/2) + ((h-1)/2)
    return state
readable_state(standardize_state([4., 9.])), \
readable_state(standardize_state([0., 5.])), \
readable_state(standardize_state([4.5, 4.5]))

(array([4., 9.]), array([0., 5.]), array([4.5, 4.5]))

In [24]:
class MomentumObject:
    def __init__(self, w, h):
        self.w, self.h = w,h
        self.position = np.array([w/2, h/2])
        self.blank_state = np.zeros((w,h))
        self.speed = np.random.randint(-2, 2, size=2)
    def _jump_bound(self, position):
        if position[0] >= self.w:
            position[0] = position[0] - self.w
        elif position[0] < 0:
            position[0] = self.w + position[0]
        if position[1] >= self.h:
            position[1] = position[1] - self.w
        elif position[1] < 0:
            position[1] = self.h + position[1]
    def get_state_action(self):
        return np.concatenate((self.position, self.speed))
    def step(self):
        self.position += self.speed
        self._jump_bound(self.position)
        self.speed = np.random.randint(-2, 2, size=2)
    def render(self, predict_position=None):
        state = self.blank_state.copy()
        position = np.round(self.position).astype(int)
        state[position[0], position[1]] = 1
        if predict_position is not None:
            pred_x = np.round(predict_position[0]*(self.w-1)).astype(int)
            pred_y = np.round(predict_position[1]*(self.h-1)).astype(int)
            state[pred_x, pred_y] = 0.2
        return state
momentum_object = MomentumObject(w, h)
momentum_object.render()

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [25]:
num_buffer = 128
dimension_state = 2
num_state = momentum_object.position.shape[0] + momentum_object.speed.shape[0]
states_actions = torch.zeros((num_buffer, num_state))

def update_state_buffer(states_actions, remove_index=None):
    new_state_action = torch.unsqueeze(torch.tensor(momentum_object.get_state_action()), 0)
    if remove_index is None:
        states_actions = states_actions[:-1]
    else:
        states_actions = delete_index(states_actions, remove_index)
    return torch.cat((new_state_action, states_actions))
for i in range(num_buffer):
    momentum_object.step()
    states_actions=update_state_buffer(states_actions)
states_actions

tensor([[ 4.,  5.,  1., -1.],
        [ 6.,  4., -2.,  1.],
        [ 6.,  4.,  0.,  0.],
        [ 7.,  3., -1.,  1.],
        [ 9.,  2., -2.,  1.],
        [ 9.,  4.,  0., -2.],
        [ 1.,  6., -2., -2.],
        [ 1.,  5.,  0.,  1.],
        [ 1.,  5.,  0.,  0.],
        [ 3.,  4., -2.,  1.],
        [ 5.,  4., -2.,  0.],
        [ 6.,  4., -1.,  0.],
        [ 7.,  3., -1.,  1.],
        [ 7.,  5.,  0., -2.],
        [ 6.,  7.,  1., -2.],
        [ 7.,  8., -1., -1.],
        [ 7.,  8.,  0.,  0.],
        [ 7.,  9.,  0., -1.],
        [ 6.,  9.,  1.,  0.],
        [ 8.,  1., -2., -2.],
        [ 7.,  3.,  1., -2.],
        [ 9.,  4., -2., -1.],
        [ 9.,  5.,  0., -1.],
        [ 9.,  5.,  0.,  0.],
        [ 9.,  4.,  0.,  1.],
        [ 1.,  4., -2.,  0.],
        [ 1.,  5.,  0., -1.],
        [ 1.,  7.,  0., -2.],
        [ 3.,  6., -2.,  1.],
        [ 4.,  7., -1., -1.],
        [ 5.,  6., -1.,  1.],
        [ 5.,  7.,  0., -1.],
        [ 5.,  8.,  0., -1.],
        [ 

In [26]:
def train(i, encoder, states_actions, optimizer_encoder):
    encoder.train()
    states_actions = states_actions.to(device)
    optimizer_encoder.zero_grad()
    
    current_state = torch.clone(states_actions[1:]).to(device).to(device)
    next_state = torch.clone(states_actions[:-1, :2]).to(device)
    
    current_state[:, 0] = current_state[:, 0] - ( (w-1)/2)
    current_state[:, 1] = current_state[:, 1] - ((h-1)/2)
    next_state[:, 0] = next_state[:, 0]/(w-1)
    next_state[:, 1] = next_state[:, 1]/(h-1)
    
    state_prediction = encoder(current_state)
    loss = mse_loss(state_prediction, next_state)
    print("loss[", i,"]:", loss.mean())
    
    loss.mean().backward()
    optimizer_encoder.step()
    
    return torch.argmin(loss.mean(axis=1)).type(torch.int)
    
torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = nn.Sequential(
                nn.Linear(4,16),
                nn.ReLU(),
                nn.Linear(16,32),
                nn.ReLU(),
                nn.Linear(32,8),
                nn.ReLU(),
                nn.Linear(8,2),
                nn.Sigmoid()).to(device).to(torch.float64)

optimizer_encoder = optim.Adam(encoder.parameters(), lr=0.001)
mse_loss = nn.MSELoss(reduction='none')

lowest_loss_index = None
for i in range(5000):
    momentum_object.step()
    states_actions=update_state_buffer(states_actions, lowest_loss_index)
    lowest_loss_index = train(i, encoder, states_actions, optimizer_encoder)


loss[ 0 ]: tensor(0.0959, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1 ]: tensor(0.0953, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2 ]: tensor(0.0949, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 3 ]: tensor(0.0946, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 4 ]: tensor(0.0943, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 5 ]: tensor(0.0936, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 6 ]: tensor(0.0940, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 7 ]: tensor(0.0935, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 8 ]: tensor(0.0932, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 9 ]: tensor(0.0927, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 10 ]: tensor(0.0924, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 11 ]: tensor(0

loss[ 112 ]: tensor(0.0600, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 113 ]: tensor(0.0622, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 114 ]: tensor(0.0621, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 115 ]: tensor(0.0626, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 116 ]: tensor(0.0633, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 117 ]: tensor(0.0634, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 118 ]: tensor(0.0619, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 119 ]: tensor(0.0616, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 120 ]: tensor(0.0602, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 121 ]: tensor(0.0592, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 122 ]: tensor(0.0591, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 237 ]: tensor(0.0536, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 238 ]: tensor(0.0533, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 239 ]: tensor(0.0531, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 240 ]: tensor(0.0529, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 241 ]: tensor(0.0529, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 242 ]: tensor(0.0529, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 243 ]: tensor(0.0526, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 244 ]: tensor(0.0524, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 245 ]: tensor(0.0521, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 246 ]: tensor(0.0518, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 247 ]: tensor(0.0517, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 366 ]: tensor(0.0361, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 367 ]: tensor(0.0358, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 368 ]: tensor(0.0355, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 369 ]: tensor(0.0353, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 370 ]: tensor(0.0351, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 371 ]: tensor(0.0349, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 372 ]: tensor(0.0346, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 373 ]: tensor(0.0343, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 374 ]: tensor(0.0340, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 375 ]: tensor(0.0338, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 376 ]: tensor(0.0335, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 494 ]: tensor(0.0210, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 495 ]: tensor(0.0209, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 496 ]: tensor(0.0208, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 497 ]: tensor(0.0209, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 498 ]: tensor(0.0208, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 499 ]: tensor(0.0207, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 500 ]: tensor(0.0206, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 501 ]: tensor(0.0204, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 502 ]: tensor(0.0203, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 503 ]: tensor(0.0204, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 504 ]: tensor(0.0202, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 607 ]: tensor(0.0118, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 608 ]: tensor(0.0117, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 609 ]: tensor(0.0116, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 610 ]: tensor(0.0115, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 611 ]: tensor(0.0114, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 612 ]: tensor(0.0113, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 613 ]: tensor(0.0112, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 614 ]: tensor(0.0113, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 615 ]: tensor(0.0112, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 616 ]: tensor(0.0111, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 617 ]: tensor(0.0110, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 731 ]: tensor(0.0061, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 732 ]: tensor(0.0060, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 733 ]: tensor(0.0060, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 734 ]: tensor(0.0059, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 735 ]: tensor(0.0059, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 736 ]: tensor(0.0058, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 737 ]: tensor(0.0060, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 738 ]: tensor(0.0059, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 739 ]: tensor(0.0059, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 740 ]: tensor(0.0058, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 741 ]: tensor(0.0058, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 849 ]: tensor(0.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 850 ]: tensor(0.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 851 ]: tensor(0.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 852 ]: tensor(0.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 853 ]: tensor(0.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 854 ]: tensor(0.0036, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 855 ]: tensor(0.0035, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 856 ]: tensor(0.0035, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 857 ]: tensor(0.0035, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 858 ]: tensor(0.0035, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 859 ]: tensor(0.0035, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 966 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 967 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 968 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 969 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 970 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 971 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 972 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 973 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 974 ]: tensor(0.0026, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 975 ]: tensor(0.0025, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 976 ]: tensor(0.0025, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)

loss[ 1083 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1084 ]: tensor(0.0019, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1085 ]: tensor(0.0019, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1086 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1087 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1088 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1089 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1090 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1091 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1092 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1093 ]: tensor(0.0020, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1200 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1201 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1202 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1203 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1204 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1205 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1206 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1207 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1208 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1209 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1210 ]: tensor(0.0016, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1317 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1318 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1319 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1320 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1321 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1322 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1323 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1324 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1325 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1326 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1327 ]: tensor(0.0013, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1434 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1435 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1436 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1437 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1438 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1439 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1440 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1441 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1442 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1443 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1444 ]: tensor(0.0011, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1549 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1550 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1551 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1552 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1553 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1554 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1555 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1556 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1557 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1558 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1559 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1666 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1667 ]: tensor(0.0009, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1668 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1669 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1670 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1671 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1672 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1673 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1674 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1675 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1676 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1783 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1784 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1785 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1786 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1787 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1788 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1789 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1790 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1791 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1792 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1793 ]: tensor(0.0008, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 1897 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1898 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1899 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1900 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1901 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1902 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1903 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1904 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1905 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1906 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 1907 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2011 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2012 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2013 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2014 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2015 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2016 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2017 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2018 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2019 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2020 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2021 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2125 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2126 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2127 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2128 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2129 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2130 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2131 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2132 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2133 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2134 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2135 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2237 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2238 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2239 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2240 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2241 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2242 ]: tensor(0.0007, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2243 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2244 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2245 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2246 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2247 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2349 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2350 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2351 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2352 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2353 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2354 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2355 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2356 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2357 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2358 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2359 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2462 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2463 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2464 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2465 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2466 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2467 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2468 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2469 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2470 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2471 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2472 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2574 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2575 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2576 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2577 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2578 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2579 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2580 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2581 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2582 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2583 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2584 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2692 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2693 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2694 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2695 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2696 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2697 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2698 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2699 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2700 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2701 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2702 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

loss[ 2806 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2807 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2808 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2809 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2810 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2811 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2812 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2813 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2814 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2815 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<MeanBackward0>)
loss[ 2816 ]: tensor(0.0006, device='cuda:0', dtype=torch.float64, grad_fn=<Mean

KeyboardInterrupt: 

In [None]:
fig = plt.figure()
im = plt.imshow(momentum_object.render(), animated=True)
def updatefig(*args):
    position = torch.tensor([momentum_object.get_state_action()]).to(device)
    position[:, 0] = position[:, 0] - ( (w-1)/2)
    position[:, 1] = position[:, 1] - ((h-1)/2)
    
    predict_position = encoder(position).cpu().detach().numpy()[0]
    momentum_object.step()
    im.set_array(momentum_object.render(predict_position))
ani = FuncAnimation(fig, updatefig, interval=500, blit=True)
# plt.show()

In [10]:
for i in range(100):
    momentum_object.step()
    position = torch.tensor([momentum_object.get_state_action()]).to(device)
    predict_position = encoder(position).cpu().detach().numpy()[0]
    print(predict_position)
    im.set_array(momentum_object.render(predict_position))

[0.99989771 0.76936865]
[0.96353125 0.67160851]
[0.89724227 0.67147141]
[0.77119219 0.6711651 ]
[0.77014911 0.55693823]
[0.55611583 0.32529832]
[0.55597464 0.32526446]
[0.32164272 0.32699519]
[0.43527096 0.32574755]
[0.55602211 0.32501891]
[0.43671023 0.22514227]
[0.4457642 0.110193 ]
[0.32904494 0.9991435 ]
[0.44018009 0.99019879]
[0.43995323 0.99014715]
[0.43995323 0.99014715]
[0.32564901 0.99014945]
[0.43370031 0.00382012]
[0.44558486 0.11052904]
[0.33303576 0.00247311]
[0.44810503 1.        ]
[0.2262158  0.88820255]
[0.32401082 0.7721383 ]
[0.22933378 0.67344814]
[0.23175572 0.43709529]
[0.32794304 0.22605561]
[0.33464876 0.11232607]
[0.33126836 0.00242148]
[0.44810503 1.        ]
[0.55872918 0.77143021]
[0.43599322 0.88779671]
[0.43769416 0.77212218]
[0.44032525 0.67214079]
[0.3260151  0.67307647]
[0.23122742 0.55511823]
[0.0013097  0.33116823]
[1.         0.33138632]
[0.67170247 0.43864389]
[0.55754912 0.43880582]
[0.32661393 0.22198879]
[0.32732766 0.22497006]
[0.22432116 0.0025