In [45]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions.normal import Normal
import numpy as np

class MemoryModel(nn.Module):
    """LSTM followed by a mixture-density-network (MDN) head over latent vectors.

    Each time step feeds the LSTM the concatenation of a latent vector
    (``z_dim`` values), an action vector (``action_dim`` = 3) and an extra
    vector (``vector_dim`` = 3). The LSTM output goes through two
    CELU-activated linear layers and then three linear heads that produce,
    for every latent dimension, the means, log standard deviations and mixture
    logits of ``n_gaussians`` one-dimensional Gaussians.

    Args:
        z_dim: size of the latent vector.
        hidden_size: LSTM hidden size, also the width of the two linear layers.
        lr: learning rate for the internal Adam optimizer.
        device: torch device on which all sub-modules are placed.
        n_gaussians: number of mixture components per latent dimension
            (defaults to 5, the value previously hard-coded).
    """

    def __init__(self, z_dim, hidden_size, lr, device, n_gaussians=5):
        super(MemoryModel, self).__init__()

        self.action_dim = 3  # 0 : forward, 1 : left, 2 : right
        self.vector_dim = 3

        self.z_dim = z_dim
        self.n_gaussians = n_gaussians
        input_size = self.z_dim + self.action_dim + self.vector_dim

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.lr = lr
        self.device = device

        # sqrt(2 * pi), the normalisation constant of the Gaussian density.
        self.gaussian_const = ((np.pi * 2) ** (1 / 2))

        # in : [seq, batch, input_size]
        # out: [seq, batch, hidden_size] plus (h_n, c_n)
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=1).to(device)

        # Both linear layers act on [seq * batch, hidden_size].
        self.fc1 = nn.Linear(hidden_size, hidden_size).to(device)
        self.fc2 = nn.Linear(hidden_size, hidden_size).to(device)

        # MDN heads: [seq * batch, z_dim * n_gaussians] each, later viewed as
        # [seq, batch, z_dim, n_gaussians].
        self.mu = nn.Linear(hidden_size, z_dim * n_gaussians).to(device)
        self.log_std = nn.Linear(hidden_size, z_dim * n_gaussians).to(device)
        self.prob = nn.Linear(hidden_size, z_dim * n_gaussians).to(device)

        self.optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=1e-6)

    def gaussian_pdf(self, y, mu, sigma):
        """Return the Gaussian density N(y; mu, sigma), floored at 1e-3.

        Computes ``exp(-(y - mu)^2 / (2 * sigma^2)) / (sigma * sqrt(2 * pi))``
        element-wise (with broadcasting) and clamps the result from below at
        1e-3, so a later ``log`` never sees values smaller than that.
        """
        return torch.clamp_min(torch.exp((y - mu) ** 2 / (-2) / (sigma ** 2)) / (sigma * self.gaussian_const), 1e-3)

    def gumbel_sample(self, prob, dim):
        """Sample an index along ``dim`` using the Gumbel-max trick.

        Adds Gumbel(0, 1) noise (drawn with NumPy) to ``log(prob)`` and takes
        the argmax along ``dim``. The reduced dimension is kept with size 1 so
        the result can be used directly as a ``torch.gather`` index.
        """
        z = np.random.gumbel(loc=0.0, scale=1.0, size=prob.shape)
        return torch.argmax(torch.log(prob) + torch.FloatTensor(z).to(self.device), dim=dim, keepdim=True)

    def full_forward(self, x, a, v, h, c):
        """Run the LSTM and MDN head over a whole sequence.

        Args:
            x: latent inputs, shape (seq, batch, z_dim).
            a: actions, shape (seq, batch, action_dim).
            v: extra vectors, shape (seq, batch, vector_dim).
            h: initial LSTM hidden state, shape (1, batch, hidden_size).
            c: initial LSTM cell state, shape (1, batch, hidden_size).

        Returns:
            Tuple ``(mdn_output, mu, sigma, prob, h_n, c_n)`` where
            ``mdn_output`` is one sample per latent dimension with shape
            (seq, batch, z_dim, 1) (the mixture axis is kept as a singleton),
            ``mu``/``sigma``/``prob`` have shape
            (seq, batch, z_dim, n_gaussians), and ``h_n``/``c_n`` are the
            final LSTM states.
        """
        assert len(x.shape) == 3 and x.shape[2] == self.z_dim
        assert len(a.shape) == 3 and a.shape[2] == self.action_dim
        assert len(v.shape) == 3 and v.shape[2] == self.vector_dim
        assert len(h.shape) == 3 and h.shape[2] == self.hidden_size
        assert len(c.shape) == 3 and c.shape[2] == self.hidden_size

        lstm_input = torch.cat([x, a, v], dim=2)

        lstm_output, (h_n, c_n) = self.lstm(lstm_input, (h, c))

        seq = lstm_output.shape[0]
        batch = lstm_output.shape[1]
        mixture_shape = [seq, batch, self.z_dim, self.n_gaussians]

        # Flatten time and batch so the linear layers see a 2-D input.
        lstm_output = lstm_output.view([seq * batch, self.hidden_size])

        fc_output = torch.celu(self.fc1(lstm_output))
        fc_output = torch.celu(self.fc2(fc_output))

        mu = self.mu(fc_output)
        mu = mu.view(mixture_shape)

        # log-std is clamped to [-3, 3] before exponentiation, which bounds
        # sigma to [exp(-3), exp(3)].
        log_std = self.log_std(fc_output)
        sigma = torch.exp(torch.clamp(log_std, -3, 3)).view(mixture_shape)

        prob = self.prob(fc_output).view(mixture_shape)
        prob = F.softmax(prob, dim=3)

        # Pick one mixture component per latent dimension, then sample from it.
        gaussian_index = self.gumbel_sample(prob, dim=3)

        selected_mu = torch.gather(mu, dim=3, index=gaussian_index)
        selected_sigma = torch.gather(sigma, dim=3, index=gaussian_index)

        mdn_output = selected_mu + selected_sigma * torch.randn_like(selected_sigma)

        return mdn_output, mu, sigma, prob, h_n, c_n

    def hidden_forward(self, x):
        """Placeholder; not implemented yet."""
        pass

    def train(self, z_in_batch=True, z_out_batch=None, a_batch=None, v_batch=None, mask=None, mode=None):
        """Run one optimisation step, or toggle training mode.

        This method shares its name with ``nn.Module.train(mode)``, which
        ``nn.Module.eval()`` invokes as ``self.train(False)``. To keep both
        usages working it dispatches on the first argument:

        * a tensor -> one optimisation step via :meth:`train_step`; returns
          the loss as a NumPy scalar array;
        * anything else (or ``mode=...``) -> forwarded to
          ``nn.Module.train``; returns ``self``.
        """
        if mode is not None:
            return super(MemoryModel, self).train(mode)
        if not torch.is_tensor(z_in_batch):
            return super(MemoryModel, self).train(z_in_batch)
        return self.train_step(z_in_batch, z_out_batch, a_batch, v_batch, mask)

    def train_step(self, z_in_batch, z_out_batch, a_batch, v_batch, mask):
        """Perform one gradient step on the MDN negative log-likelihood.

        Args:
            z_in_batch: input latents, shape (seq, batch, z_dim).
            z_out_batch: target latents, shape (seq, batch, z_dim).
            a_batch: actions, shape (seq, batch, action_dim).
            v_batch: extra vectors, shape (seq, batch, vector_dim).
            mask: per-step weights, shape (seq, batch); multiplied into the
                per-element loss before averaging.

        Returns:
            The scalar loss as a NumPy array (detached, on CPU).
        """
        assert len(z_in_batch.shape) == 3 and z_in_batch.shape[2] == self.z_dim
        seq, batch = z_in_batch.shape[0], z_in_batch.shape[1]
        assert z_out_batch.shape == (seq, batch, self.z_dim)
        assert a_batch.shape == (seq, batch, self.action_dim)
        assert v_batch.shape == (seq, batch, self.vector_dim)
        assert mask.shape == (seq, batch)

        mask = mask.unsqueeze(2).float()        # broadcast over z_dim
        z_out_batch = z_out_batch.unsqueeze(3)  # broadcast over mixture axis

        # The sequence always starts from zero LSTM state.
        h_0 = torch.zeros([1, batch, self.hidden_size], device=self.device)
        c_0 = torch.zeros([1, batch, self.hidden_size], device=self.device)

        _, mu, sigma, prob, _, _ = self.full_forward(z_in_batch, a_batch, v_batch, h_0, c_0)

        mixture_shape = (seq, batch, self.z_dim, self.n_gaussians)
        assert mu.shape == mixture_shape
        assert sigma.shape == mixture_shape
        assert prob.shape == mixture_shape

        p_y = self.gaussian_pdf(z_out_batch, mu, sigma)  # [seq, batch, z_dim, n_gaussians]

        # Negative log of the mixture likelihood. Masked entries contribute
        # zero, but the mean still divides by every (seq, batch, z_dim) entry.
        loss = torch.mean(-torch.log(torch.sum(p_y * prob, dim=3)) * mask)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_value_(self.parameters(), 1.0)
        self.optimizer.step()

        return loss.detach().cpu().numpy()

    def train_batch(self):
        """Placeholder; not implemented yet."""
        pass


In [46]:
# Model hyper-parameters.
z_dim = 64          # size of the latent vector
hidden_size = 512   # LSTM hidden size
lr = 1e-3           # Adam learning rate
# Use the GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
# Dimensions of the dummy tensors built below: sequence length, batch size,
# and the widths of the action and extra-vector inputs.
seq, batch = 40, 32
a_dim, v_dim = 3, 3

In [191]:
# Instantiate the model with the hyper-parameters defined above.
MM = MemoryModel(z_dim=z_dim, hidden_size=hidden_size, lr=lr, device=device)

In [198]:
# Constant dummy inputs and zero initial LSTM state, created directly on `device`.
x = torch.full((seq, batch, z_dim), 0.5, device=device)
a = torch.full((seq, batch, a_dim), 0.1, device=device)
v = torch.full((seq, batch, v_dim), 0.2, device=device)
h = torch.zeros((1, batch, hidden_size), device=device)
c = torch.zeros((1, batch, hidden_size), device=device)

In [238]:
# Full forward pass; `t` is the tuple (mdn_output, mu, sigma, prob, h_n, c_n).
t = MM.full_forward(x=x, a=a, v=v, h=h, c=c)

In [215]:
# Latent input at time step 0 for batch element 1.
x[0, 1]

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000], device='cuda:0')

In [245]:
# Sampled MDN output at time step 0, batch element 8, across all latent dimensions.
t[0][0, 8, :, 0]

tensor([0.4467, 0.4589, 0.4403, 0.5679, 0.4150, 0.5173, 0.4574, 0.4499, 0.4632,
        0.4766, 0.5393, 0.4617, 0.4283, 0.5350, 0.5314, 0.4750, 0.4849, 0.5398,
        0.3999, 0.5463, 0.5641, 0.5102, 0.4829, 0.4741, 0.5623, 0.5313, 0.5057,
        0.5588, 0.4908, 0.4740, 0.4246, 0.4838, 0.5814, 0.5378, 0.5778, 0.4472,
        0.4999, 0.4392, 0.5171, 0.5424, 0.5231, 0.4568, 0.4869, 0.4973, 0.4446,
        0.5068, 0.5175, 0.5524, 0.5508, 0.4783, 0.4975, 0.5601, 0.5686, 0.4633,
        0.4971, 0.5285, 0.5208, 0.4732, 0.4889, 0.5051, 0.4505, 0.5516, 0.4417,
        0.4075], device='cuda:0', grad_fn=<SelectBackward>)

In [254]:
def one_hot(indices, total_len):
    """Return one-hot rows of length ``total_len`` selected by ``indices``.

    Builds a ``total_len`` x ``total_len`` identity matrix and indexes it with
    ``indices``, so row ``i`` of the result has a 1 at position ``indices[i]``.
    """
    identity = np.eye(total_len)
    return identity[indices]

In [257]:
# Quick check of the helper on three indices.
one_hot(indices=[2, 0, 1], total_len=3)

array([[0., 0., 1.],
       [1., 0., 0.],
       [0., 1., 0.]])

In [202]:
# All-ones mask (every step counts), sized from the shared `seq` / `batch`
# constants rather than repeated literals so it stays in sync with the inputs.
mask = torch.ones([seq, batch], device=device)

In [213]:
# Repeatedly fit the model to reproduce its own (constant) input and print the loss.
for _step in range(10000):
    step_loss = MM.train(x, x, a, v, mask)
    print(step_loss)


-2.0804627
-2.0804682
-2.0804737
-2.0804794
-2.0804849
-2.08049
-2.080495
-2.0805004
-2.0805051
-2.08051
-2.0805154
-2.0805202
-2.0805247
-2.0805295
-2.080534
-2.0805383
-2.0805426
-2.080547
-2.0805516
-2.080556
-2.08056
-2.0805643
-2.080568
-2.0805721
-2.0805762
-2.08058
-2.080584
-2.0805879
-2.0805912
-2.080595
-2.0805986
-2.0806022
-2.0806053
-2.0806088
-2.0806124
-2.0806158
-2.0806186
-2.0806222
-2.0806253
-2.0806282
-2.0806313
-2.080634
-2.0806367
-2.0806396
-2.0806415
-2.0806425
-2.0806422
-2.0806382
-2.0806267
-2.0806003
-2.0805414
-2.080408
-2.0801094
-2.0794303
-2.077872
-2.0742981
-2.066218
-2.0492866
-2.0188131
-1.9839019
-1.9743885
-2.0205967
-2.065258
-2.0536194
-2.0294647
-2.0542524
-2.0611627
-2.0497262
-2.0614817
-2.0685291
-2.0628884
-2.0723422
-2.0720625
-2.0708008
-2.073905
-2.0710537
-2.073747
-2.0731432
-2.074592
-2.0765736
-2.0758789
-2.0775802
-2.077484
-2.0769684
-2.0776305
-2.0772736
-2.077481
-2.0778186
-2.0783145
-2.0786042
-2.0789056
-2.0792315
-2.0791187
-2

-2.0778205
-2.0774581
-2.0795276
-2.0795364
-2.078422
-2.0791156
-2.079974
-2.0795684
-2.0793355
-2.079853
-2.0801058
-2.0798728
-2.0798886
-2.0801916
-2.0802367
-2.0801418
-2.0802329
-2.0803602
-2.0803611
-2.0803616
-2.0804172
-2.080462
-2.0804813
-2.0805035
-2.0805204
-2.0805404
-2.0805786
-2.080588
-2.0805795
-2.0806134
-2.0806444
-2.0806296
-2.0806355
-2.0806699
-2.0806744
-2.0806677
-2.0806832
-2.0806983
-2.0806983
-2.080703
-2.0807126
-2.0807164
-2.0807214
-2.080729
-2.0807314
-2.080732
-2.0807405
-2.0807483
-2.080746
-2.0807455
-2.080756
-2.0807621
-2.0807583
-2.0807579
-2.0807674
-2.0807729
-2.0807695
-2.0807693
-2.080776
-2.0807812
-2.0807805
-2.0807798
-2.0807834
-2.0807884
-2.0807896
-2.080789
-2.0807908
-2.0807943
-2.080797
-2.0807977
-2.0807981
-2.0808005
-2.080803
-2.0808046
-2.0808055
-2.0808067
-2.0808089
-2.080811
-2.0808122
-2.0808132
-2.0808144
-2.080816
-2.080818
-2.0808191
-2.08082
-2.0808213
-2.0808232
-2.0808244
-2.0808256
-2.0808268
-2.0808282
-2.0808294
-2.0808

KeyboardInterrupt: 