In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
from torch import nn
import torch.nn.functional as f
from torch.autograd import Variable
import copy
import time
from mmdet3d.models.temporal import ConvLSTM, ConvGRU

In [7]:
num_seqs = 3
hidden_size = 256
channels_img = 256
size_image = 200
max_epoch = 1
kernel_size = 3

In [8]:
conv_lstm = ConvLSTM(
    in_channels=256,
    hidden_channels=[256], # hparam
    kernel_size=[3,3],
    batch_first=True,
    bias=True,
    return_all_layers=False,
)

conv_lstm.eval()

print(conv_lstm.__repr__())

ConvLSTM(
  (cell_list): ModuleList(
    (0): ConvLSTMCell(
      (conv): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)


In [9]:
conv_gru = ConvGRU(
    in_channels=256,
    hidden_channels=[256], # hparam
    kernel_size=[3,3],
    batch_first=True,
    bias=True,
    return_all_layers=False,
)

conv_gru.eval()

print(conv_gru.__repr__())

ConvGRU(
  (cell_list): ModuleList(
    (0): ConvGRUCell(
      (conv): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (conv_ct): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
)


In [10]:
from torch.nn import GRU

torch_gru = GRU(
    input_size=256,
    hidden_size=256,
    num_layers=1,
    batch_first=True,
    bidirectional=False,
)   
print(torch_gru.__repr__())

GRU(256, 256, batch_first=True)


In [11]:
# B, C, H, W
# 1, 256, 200, 200
batch_size = 6
queue_length = 3
curr_feat = torch.rand(batch_size, 256, 200, 200)
print("curr_feat.shape", curr_feat.shape)
prev_feat = [torch.rand(batch_size, 256, 200, 200) for _ in range(queue_length - 1)]

# stack the previous features -> B, T, C, H, W
prev_feat = torch.stack(prev_feat, dim=1)
# stack the current features -> B, C, H, W -> B, 1, C, H, W
print("prev_feat.shape", prev_feat.shape)
curr_feat = curr_feat.unsqueeze(1)
print("curr_feat.shape", curr_feat.shape)
x = torch.cat((prev_feat, curr_feat), dim=1)
print("x.shape", x.shape)

curr_feat.shape torch.Size([6, 256, 200, 200])
prev_feat.shape torch.Size([6, 2, 256, 200, 200])
curr_feat.shape torch.Size([6, 1, 256, 200, 200])
x.shape torch.Size([6, 3, 256, 200, 200])


In [12]:
def get_n_params(model):
    pp=0
    for p in list(model.parameters()):
        nn=1
        for s in list(p.size()):
            nn = nn*s
        pp += nn
    return pp

In [19]:
n_params_conv_lstm = get_n_params(conv_lstm)
n_params_conv_gru = get_n_params(conv_gru)
n_params_torch_gru = get_n_params(torch_gru)
gru_decrease_perc = (n_params_conv_lstm - n_params_conv_gru) / n_params_conv_gru * 100
torch_gru_decrease_perc = (n_params_conv_lstm - n_params_torch_gru) / n_params_torch_gru * 100
# increase in percetange from gru to lstm
print("n_params_conv_lstm", n_params_conv_lstm, "100%")
print("n_params_conv_gru", n_params_conv_gru, f"{100 - gru_decrease_perc:.2f}%")
print("n_params_torch_gru", n_params_torch_gru, f"{100 - torch_gru_decrease_perc:.2f}%")

n_params_conv_lstm 4719616 100%
n_params_conv_gru 3539712 66.67%
n_params_torch_gru 394752 -995.59%


In [None]:
curr_time = time.time()
out = conv_lstm(x)
print("conv_lstm time", time.time() - curr_time)
print(out.shape)

In [None]:
curr_time = time.time()
out = conv_gru(x)
print("conv_gru time", time.time() - curr_time)
print(out.shape)

In [20]:
curr_time = time.time()
out = torch_gru(x)
print("torch_gru time", time.time() - curr_time)
print(out[0].shape)

RuntimeError: input must have 3 dimensions, got 5