In [1]:
import sys
sys.path.insert(0, "../")

import numpy as np
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical

from ConvLSTM_pytorch.convlstm import ConvLSTM

# Test dataset

Testing memory: <br>
Given a sequence of x randomly generated, y is going to be the sum of all the pixels of the 2 frames before this one. 

x_0, x_1, x_2, x_3, ..., x_T  <br>
y_0 = 0  <br>
y_1 = x_0.sum()  <br>
y_2 = x_0.sum() + x_1.sum()  <br>
y_3 = x_1.sum() + x_2.sum() <br>

and so on

In [55]:
N = 10000 # sequences in dataset
T = 10 # length of each sequence
C = 1 # number of channels
W = 10
x = torch.rand(N, T, C, W, W)

In [56]:
x_padded = torch.cat([torch.zeros(N, 1, C, W, W),x], dim=1)
x_padded.shape

torch.Size([10000, 11, 1, 10, 10])

In [57]:
def compute_labels(x):
    y = torch.zeros(x.shape[0],x.shape[1]-1)
    for t in range(x.shape[1]-1):
        y[:,t] = x[:,t:t+2,...].sum(axis=(1,2,3,4))
    return y

In [58]:
y = compute_labels(x_padded)

In [67]:
class Net(nn.Module):
    def __init__(self, C, W):
        super(Net, self).__init__()
        self.conv_lstm = ConvLSTM(
                     input_dim=1, 
                     hidden_dim=1, 
                     kernel_size=(3,3), 
                     num_layers=1,
                     batch_first=True,
                     bias=True,
                     return_all_layers=False
                    )
        
        self.MLP = nn.Sequential(
            nn.Linear(C*W**2, 256),
            nn.ReLU(),
            nn.Linear(256,1)
            )
        
    def forward(self, x, hidden=None):
        B = x.shape[0]
        T = x.shape[1]
        # x: (b,T,C,W,W)
        x, state = self.conv_lstm(x, hidden)
        # x: (b,T,C,W,W)
        x = x[0].view(B,T,-1)
        #print(len(x))
        #print(x.shape)
        out = self.MLP(x)
        #print(out.shape)
        return out

In [68]:
net = Net(C,W)

In [69]:
x_train = x[:8000]
x_test = x[8000:]
y_train = y[:8000]
y_test = y[8000:]

In [70]:
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [71]:
n_epochs = 10
epoch_losses = []
for e in range(n_epochs):
    n_batch = len(x_train)//10
    losses = []
    for b in range(n_batch):
        x = x_train[b*10:(b+1)*10]
        y = y_train[b*10:(b+1)*10]
        y_pred = net(x).squeeze()
        loss = F.mse_loss(y,y_pred)
        losses.append(loss.item())
        optim.zero_grad()
        loss.backward()
        optim.step()
    print("Epoch %d - loss %.4f"%(e+1,np.mean(losses)))
    epoch_losses.append(np.mean(losses))

Epoch 1 - loss 9309.8496
Epoch 2 - loss 9306.8770
Epoch 3 - loss 9304.0684
Epoch 4 - loss 9301.4180
Epoch 5 - loss 9298.9229
Epoch 6 - loss 9296.5811
Epoch 7 - loss 9294.3867
Epoch 8 - loss 9292.3340
Epoch 9 - loss 9290.4170
Epoch 10 - loss 9288.6260


In [64]:
y_pred = net(x_test[:1]).squeeze()

In [65]:
y_pred

tensor([ 51.5265,  99.7310,  99.5614, 103.5376,  98.9052, 101.7952,  94.5668,
        102.7352, 101.3337,  98.5705], grad_fn=<SqueezeBackward0>)

In [66]:
y_test[0]

tensor([ 50.5683, 102.1472, 100.3817, 104.2558, 104.5072, 101.9911,  96.8353,
         97.0628, 106.3598, 103.2412])

In [53]:
help(F.mse_loss)

Help on function mse_loss in module torch.nn.functional:

mse_loss(input, target, size_average=None, reduce=None, reduction='mean')
    mse_loss(input, target, size_average=None, reduce=None, reduction='mean') -> Tensor
    
    Measures the element-wise mean squared error.
    
    See :class:`~torch.nn.MSELoss` for details.

