In [10]:
import math 
from typing import Dict,List,Tuple,Union 

import torch 
import torch.nn as nn 
import torch.nn.functional as F

In [12]:
# Random 6-D sample tensor drawn from a standard normal distribution.
# The literal shape holds 16*1*16*3*224*224 = 38,535,168 elements
# (roughly 154 MB if the default dtype is float32 -- TODO confirm).
# NOTE(review): the axis meaning is not stated anywhere in this notebook;
# presumably the trailing (3, 224, 224) is (channels, height, width) -- confirm.
x = torch.randn(16,1,16,3,224,224)

In [11]:
class TimeDistributed(nn.Module):
    """Apply a wrapped module to an input whose leading axes are folded together.

    Inputs of rank 2 or lower are handed to ``module`` unchanged.  For higher
    ranks every axis except the last is collapsed into one, ``module`` is run
    on the resulting 2-D tensor, and the result is reshaped to three axes.

    Args:
        module: The module applied to the flattened ``(N, input_size)`` tensor.
        batch_first: Selects which input axis is restored explicitly when the
            output is reshaped: axis 0 when ``True``, axis 1 when ``False``.
    """

    def __init__(self, module: nn.Module, batch_first: bool = False):
        super().__init__()
        self.module = module
        self.batch_first = batch_first

    def forward(self, x):
        """Run ``self.module`` over ``x`` and return a rank-3 result.

        Returns:
            ``self.module(x)`` directly when ``x`` has at most two axes.
            Otherwise a tensor of shape ``(x.size(0), -1, out)`` when
            ``batch_first`` is set, else ``(-1, x.size(1), out)``, where
            ``out`` is the size of the module output's last axis.
        """
        # Nothing to fold for vectors / plain 2-D batches.
        if x.dim() <= 2:
            return self.module(x)

        # Collapse every leading axis: (samples * timesteps, input_size).
        in_features = x.size(-1)
        flat_input = x.contiguous().view(-1, in_features)

        flat_output = self.module(flat_input)
        out_features = flat_output.size(-1)

        # Unfold the leading axis again; only one of the two original
        # leading sizes is pinned, the other is inferred via -1.
        if not self.batch_first:
            return flat_output.view(-1, x.size(1), out_features)
        return flat_output.contiguous().view(x.size(0), -1, out_features)
        


In [13]:
# NOTE(review): TimeDistributed.__init__ annotates its first argument as an
# nn.Module, but `x` as defined earlier in this notebook is a torch.randn
# tensor. Construction does not validate the argument, so the mistake would
# only surface when `tdd(...)` is called and tries to call the stored object.
# Presumably a layer (not the data) was meant to be wrapped here -- confirm
# which module is intended before relying on `tdd`.
tdd = TimeDistributed(x)

In [16]:
class CNN(nn.Module):
    """Two-stage convolutional feature extractor producing 320-wide rows.

    Each stage is convolution (5x5 kernel) -> 2x2 max-pool -> ReLU; the second
    stage applies ``Dropout2d`` between the convolution and the pooling.  The
    result is flattened with ``view(-1, 320)``.  As an example, a single
    1x28x28 input yields 20 maps of 4x4, i.e. exactly 320 values per sample.

    ``fc1`` and ``fc2`` are registered as sub-modules (and therefore appear in
    ``parameters()`` / ``state_dict()``) but ``forward`` does not call them.
    """

    def __init__(self):
        super().__init__()
        # Convolutional trunk: 1 -> 10 -> 20 channels.
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        # Dense layers kept on the module although forward() stops before them.
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return the flattened convolutional features of ``x`` as ``(-1, 320)``."""
        # Stage 1: conv -> pool -> relu.
        stage1 = self.conv1(x)
        stage1 = F.relu(F.max_pool2d(stage1, 2))

        # Stage 2: conv -> channel dropout -> pool -> relu.
        stage2 = self.conv2_drop(self.conv2(stage1))
        stage2 = F.relu(F.max_pool2d(stage2, 2))

        # Regroup everything into rows of 320 features.
        return stage2.view(-1, 320)

In [17]:
class Combine(nn.Module):
    """Container module that owns a single ``CNN`` instance as ``self.cnn``.

    Only the constructor is defined here; no ``forward`` is implemented in
    this cell.
    """

    def __init__(self):
        super().__init__()
        # Register the convolutional feature extractor as a sub-module.
        self.cnn = CNN()

In [20]:
# Contrast reshaping with axis reordering on a small 2x3 tensor.
# NOTE: this rebinds the notebook-level name `x`, replacing whatever tensor
# was previously bound to it.
x = torch.tensor([[1,2,3],[4,5,6]])
# view(3,2) keeps the row-major element order 1..6 and only regroups it
# into three rows of two (see first printed tensor).
print(x.view(3,2))
# permute(1,0) swaps the two axes, so the columns of x become rows
# (see second printed tensor) -- a different result from view(3,2).
print(x.permute(1,0))

tensor([[1, 2],
        [3, 4],
        [5, 6]])
tensor([[1, 4],
        [2, 5],
        [3, 6]])
