In [1]:
import torch
import torch.nn as nn 

In [2]:
class CnnEncoder(nn.Module):
    def __init__(self,input_channel, height, width, embedding_dim):
        super().__init__()
        self.features = nn.Sequential(
            #Block 1 
            nn.Conv2d(input_channel,32,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            #Block 2
            nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            #Block 3 
            nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
        )
        with torch.no_grad():
            test = torch.randn(1,input_channel,height,width)
            output = self.features(test)
            flatten_size = output.view(1, -1).shape[1]
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flatten_size,embedding_dim),
            nn.ReLU()
        ) 
    def forward (self,x):
        out = self.features(x)
        out = self.fc(out)
        return out


In [3]:
class cnn(nn.Module):
    def __init__(self, input_channel, height, width, embedding_dim):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channel,32,kernel_size=3, stride=1, padding=1)
        self.act1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1)
        self.act2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1)
        self.act3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(2)
        with torch.no_grad():
            test = torch.randn(1,input_channel,height,width)
            out = self.pool1(self.act1(self.conv1(test)))
            out = self.pool2(self.act2(self.conv2(out)))
            out = self.pool3(self.act3(self.conv3(out)))
            flaten_size = out.view(1, -1).shape[1]
        self.fl = nn.Flatten()
        self.ln = nn.Linear(flaten_size,256)
        self.act4 = nn.ReLU()
        self.ln2 = nn.Linear(256,embedding_dim)
    def forward(self,x):
        out = self.pool1(self.act1(self.conv1(x)))
        out = self.pool2(self.act2(self.conv2(out)))
        out = self.pool3(self.act3(self.conv3(out)))
        out = self.fl(out)
        out = self.ln(out)
        out = self.act4(out)
        out = self.ln2(out)
        return out

In [4]:
# Test
torch.manual_seed(42) 
img = torch.randn(3, 2, 200, 50)  
encoder = CnnEncoder(2,200,50,256)
# print(img)
output = encoder.forward(img)  
print("Output shape:", output)  


Output shape: tensor([[0.1319, 0.0000, 0.0000, 0.0000, 0.0000, 0.0659, 0.0000, 0.1941, 0.0000,
         0.0000, 0.0965, 0.0613, 0.0000, 0.0000, 0.0000, 0.0000, 0.0949, 0.0000,
         0.0160, 0.0000, 0.0000, 0.1027, 0.1919, 0.0000, 0.0000, 0.0000, 0.0000,
         0.1687, 0.0000, 0.0737, 0.0000, 0.1156, 0.0000, 0.0000, 0.1841, 0.0000,
         0.0000, 0.0909, 0.0000, 0.0000, 0.0000, 0.0000, 0.0741, 0.0000, 0.0366,
         0.1921, 0.0053, 0.0272, 0.0223, 0.1074, 0.0000, 0.0000, 0.0185, 0.0000,
         0.0239, 0.0000, 0.0000, 0.0456, 0.1598, 0.0000, 0.0000, 0.0000, 0.1913,
         0.2194, 0.0000, 0.0009, 0.0479, 0.1017, 0.0000, 0.0000, 0.0000, 0.0000,
         0.2490, 0.0344, 0.0000, 0.0000, 0.0967, 0.0266, 0.0390, 0.0125, 0.0535,
         0.0000, 0.0000, 0.0320, 0.0000, 0.1389, 0.0312, 0.0000, 0.0521, 0.0758,
         0.0317, 0.0661, 0.0576, 0.2269, 0.0979, 0.0477, 0.1148, 0.0000, 0.0000,
         0.0000, 0.0000, 0.0246, 0.0000, 0.0000, 0.0997, 0.0000, 0.0333, 0.0500,
         0.098