In [1]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
%matplotlib notebook
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

In [142]:
class MomentumObject:
    def __init__(self, w, h):
        self.w, self.h = w,h
        self.position = np.array([w/2, h/2])
        self.blank_state = np.zeros((w,h))
        self.speed = np.array([2,3])
    def _jump_bound(self, position):
        if position[0] >= self.w:
            position[0] = position[0] - self.w
        if position[1] >= self.h:
            position[1] = position[1] - self.w
    def step(self):
        self.position += self.speed
        self._jump_bound(self.position)
    def render(self, predict_position=None):
        state = self.blank_state.copy()
        position = np.round(self.position).astype(int)
        state[position[0], position[1]] = 1
        if predict_position is not None:
            pred_x = np.round(predict_position[0]*(self.w-1)).astype(int)
            pred_y = np.round(predict_position[1]*(self.h-1)).astype(int)
            state[pred_x, pred_y] = 0.2
        return state
w,h=10,10
momentum_object = MomentumObject(w, h)
momentum_object.render()

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [143]:
num_buffer = 64
dimension_state = 2
num_state = momentum_object.position.shape[0] + momentum_object.speed.shape[0]
states_actions = torch.zeros((num_buffer, num_state))

def state_FIFO(states_actions):
    new_state_action = np.concatenate((momentum_object.position, momentum_object.speed))
    new_state_action = torch.unsqueeze(torch.tensor(new_state_action), 0)
    states_actions = states_actions[:-1]
    return torch.cat((new_state_action, states_actions))
for i in range(num_buffer):
    momentum_object.step()
    states_actions=state_FIFO(states_actions)
states_actions

tensor([[3., 7., 2., 3.],
        [1., 4., 2., 3.],
        [9., 1., 2., 3.],
        [7., 8., 2., 3.],
        [5., 5., 2., 3.],
        [3., 2., 2., 3.],
        [1., 9., 2., 3.],
        [9., 6., 2., 3.],
        [7., 3., 2., 3.],
        [5., 0., 2., 3.],
        [3., 7., 2., 3.],
        [1., 4., 2., 3.],
        [9., 1., 2., 3.],
        [7., 8., 2., 3.],
        [5., 5., 2., 3.],
        [3., 2., 2., 3.],
        [1., 9., 2., 3.],
        [9., 6., 2., 3.],
        [7., 3., 2., 3.],
        [5., 0., 2., 3.],
        [3., 7., 2., 3.],
        [1., 4., 2., 3.],
        [9., 1., 2., 3.],
        [7., 8., 2., 3.],
        [5., 5., 2., 3.],
        [3., 2., 2., 3.],
        [1., 9., 2., 3.],
        [9., 6., 2., 3.],
        [7., 3., 2., 3.],
        [5., 0., 2., 3.],
        [3., 7., 2., 3.],
        [1., 4., 2., 3.],
        [9., 1., 2., 3.],
        [7., 8., 2., 3.],
        [5., 5., 2., 3.],
        [3., 2., 2., 3.],
        [1., 9., 2., 3.],
        [9., 6., 2., 3.],
        [7.,

In [145]:
def train(encoder, states_actions, optimizer_encoder):
    encoder.train()
    states_actions = states_actions.to(device)
    optimizer_encoder.zero_grad()
    
    next_state = torch.clone(states_actions[:-1, :2])
    current_state = torch.clone(states_actions[1:, :2])
    
    next_state[:, 0] = next_state[:, 0]/(w-1)
    next_state[:, 1] = next_state[:, 1]/(h-1)
    
    state_prediction = encoder(current_state)
    loss = states_diff_loss(state_prediction, next_state)
    print("loss:", loss)
    
    loss.backward()
    optimizer_encoder.step()

torch.manual_seed(1)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = nn.Sequential(
                nn.Linear(2,16),
                nn.ReLU(),
                nn.Linear(16,8),
                nn.ReLU(),
                nn.Linear(8,2),
                nn.Sigmoid()).to(device).to(torch.float64)

optimizer_encoder = optim.Adam(encoder.parameters(), lr=0.003)
states_diff_loss = nn.MSELoss()

for i in range(4000):
    momentum_object.step()
    states_actions=state_FIFO(states_actions)
    train(encoder, states_actions, optimizer_encoder)


loss: tensor(0.1066, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.1026, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.1000, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0990, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0956, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0921, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0898, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0913, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0921, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0916, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0904, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0895, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0895, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0912, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0902, dtype=torch.f

loss: tensor(0.0414, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0417, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0413, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0411, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0406, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0417, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0423, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0412, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0391, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0381, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0381, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0385, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0381, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0377, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0371, dtype=torch.f

loss: tensor(0.0096, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0091, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0088, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0087, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0086, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0084, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0084, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0083, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0088, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0089, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0087, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0082, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0080, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0079, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0078, dtype=torch.f

loss: tensor(0.0016, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0016, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0016, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0016, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0017, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0017, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0017, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0016, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0015, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0015, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0015, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0015, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0015, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0015, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0016, dtype=torch.f

loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0007, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0006, dtype=torch.f

loss: tensor(0.0004, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0003, dtype=torch.f

loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.f

loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0002, dtype=torch.f

loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.f

loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(9.7576e-05, dtype=tor

loss: tensor(9.3211e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(9.5153e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(9.8741e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.6757e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1910e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.7835e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tens

loss: tensor(7.0832e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(7.1974e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(7.5660e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1931e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1387e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1308e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1964e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(9.3477e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(9.8960e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(9.0857e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.3748e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1110e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.0733e-05, dtype=torch.float64, grad_fn=<

loss: tensor(5.8976e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8708e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.9032e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.0874e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.4934e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.3938e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.2025e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.7295e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8086e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8551e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8495e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8216e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8532e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.0350e-05, dtype=torch.float64, grad_

loss: tensor(5.3616e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.7200e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.6247e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.4485e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.0725e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.2217e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.2721e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.2322e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.1404e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.1645e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.3502e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.7075e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.6295e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.5061e-05, dtype=torch.float64, grad_

loss: tensor(4.3986e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.3593e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.3975e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.6008e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.9768e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.9837e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.0345e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.1650e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.2245e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.9670e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(7.2170e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.6989e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.6473e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.8801e-05, dtype=torch.float64, grad_

loss: tensor(5.2116e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.3549e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.3552e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.1508e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.4458e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.8697e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.0447e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.5964e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.8171e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.5825e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.3778e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.3414e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.9675e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.5460e-05, dtype=torch.float64, grad_

loss: tensor(3.7832e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.5297e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.4203e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.4774e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.6134e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.5094e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.3414e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1036e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1954e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.2093e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1934e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1340e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1436e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.2827e-05, dtype=torch.float64, grad_

loss: tensor(3.8605e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.0942e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.8810e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.3510e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.0225e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.9341e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.0045e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.9649e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.8958e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.7855e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.9849e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.0682e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.0356e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.8926e-05, dtype=torch.float64, grad_

loss: tensor(2.4920e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.6038e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.5750e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.4934e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.3776e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.3676e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.4974e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.6972e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.7198e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.6676e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.5608e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.7076e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.6676e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.5337e-05, dtype=torch.float64, grad_

loss: tensor(2.1830e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.2196e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.2174e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.2983e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.5291e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.8492e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.9488e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1973e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.7056e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.0776e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.8198e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.7378e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.6037e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.9473e-05, dtype=torch.float64, grad_

loss: tensor(9.1993e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(0.0001, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(8.1882e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(6.0922e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.8055e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.0595e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.3102e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.7184e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.2974e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.7511e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.7493e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.5433e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.0893e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.3545e-05, dtype=torch.float64, grad_fn=<

loss: tensor(5.4607e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(5.9123e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.9974e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.0507e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.0328e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.7394e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.9562e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.3614e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.8080e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.4318e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.2297e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(4.2076e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.1404e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.0321e-05, dtype=torch.float64, grad_

loss: tensor(3.2642e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.9320e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.6279e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.1240e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.7677e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5563e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5578e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5370e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5504e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5280e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5712e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.6971e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.8174e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.8133e-05, dtype=torch.float64, grad_

loss: tensor(1.4655e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.6430e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.8196e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.0353e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.3478e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.9923e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(3.0711e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(2.6635e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.9270e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5253e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.4589e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.5732e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.7432e-05, dtype=torch.float64, grad_fn=<MseLossBackward>)
loss: tensor(1.9039e-05, dtype=torch.float64, grad_

In [146]:
fig = plt.figure()
im = plt.imshow(momentum_object.render(), animated=True)
def updatefig(*args):
    position = torch.tensor([momentum_object.position])
    predict_position = encoder(position).detach().numpy()[0]
    momentum_object.step()
    im.set_array(momentum_object.render(predict_position))
ani = FuncAnimation(fig, updatefig, interval=500, blit=True)
# plt.show()

<IPython.core.display.Javascript object>

In [120]:
for i in range(100):
    momentum_object.step()
    position = torch.tensor([momentum_object.position])
    predict_position = encoder(position).detach().numpy()[0]
    print(predict_position)
    im.set_array(momentum_object.render(predict_position))

[1.11755375e-12 1.00000000e+00]


IndexError: index 100 is out of bounds for axis 1 with size 100