In [7]:
import torch
import torch.nn as  nn
import torch.nn.functional as F
import torchvision 
from torchvision.models import resnet50,resnet18, ResNet50_Weights
    
class Convolution(nn.Module):
    """Conv2d -> BatchNorm2d -> LeakyReLU building block.

    The convolution uses a 3x3 kernel with the default stride 1 and no padding,
    so each spatial dimension of the output is 2 smaller than the input.

    Args:
        input: number of input channels (default 3).
        output: number of output channels (default 64).
    """

    def __init__(self, input=3, output=64):
        super().__init__()
        in_channels, out_channels = input, output
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.LeakyReLU()

    def forward(self, x):
        """Apply conv, batch norm and LeakyReLU to ``x``, in that order."""
        features = self.conv(x)
        normalized = self.batch_norm(features)
        return self.relu(normalized)
    
class TwoStreamCNN(nn.Module):
    """Two-stream CNN that fuses two image streams and classifies them with ResNet-50.

    Each stream is passed through its own ``Convolution`` block (3 -> 64
    channels). The two feature maps are summed, and the sum is concatenated
    along the channel dimension with one stream's features (64 + 64 = 128
    channels). The result is fed to a ResNet-50 whose stem conv is replaced to
    accept 128 channels.

    Args:
        num_classes: number of outputs of the final linear layer.
        type: fusion variant. ``'tsma'`` concatenates the sum with stream A's
            features; ``'tsmb'`` concatenates it with stream B's features.
        pretrained: if True (default), initialise the ResNet-50 backbone with
            ``ResNet50_Weights.IMAGENET1K_V1``, the weights the legacy
            ``pretrained=True`` flag selected. Set False to skip the download.

    Raises:
        ValueError: if ``type`` is not one of ``'tsma'`` / ``'tsmb'``.
    """

    _FUSION_TYPES = ('tsma', 'tsmb')

    def __init__(self, num_classes, type='tsma', pretrained=True):
        super().__init__()
        # Fail fast: an unknown type used to surface only later, in forward(),
        # as an UnboundLocalError on the fused tensor.
        if type not in self._FUSION_TYPES:
            raise ValueError(
                f"type must be one of {self._FUSION_TYPES}, got {type!r}"
            )
        self.conv1 = Convolution()
        self.conv2 = Convolution()
        # NOTE(review): conv3 (and relu below) are never used in forward(); they
        # are kept only so existing state_dicts/checkpoints still load with
        # strict=True. Their parameters never receive gradients.
        self.conv3 = Convolution(128, 3)
        self.type = type
        self.num_classes = num_classes
        self.pretrained = pretrained
        self.relu = nn.LeakyReLU()
        self.block()

    def block(self):
        """Build the ResNet-50 backbone and store it in ``self.model``.

        The stem conv is replaced to take the 128-channel fused tensor. The
        classifier head is replaced to emit ``self.num_classes``
        log-probabilities. All backbone parameters are trainable, so the whole
        network is fine-tuned.
        """
        use_pretrained = getattr(self, 'pretrained', True)
        weights = ResNet50_Weights.IMAGENET1K_V1 if use_pretrained else None
        # Freshly constructed modules already have requires_grad=True on every
        # parameter, so no explicit unfreezing loop is needed.
        self.model = resnet50(weights=weights)

        fc_inputs = self.model.fc.in_features
        # The new 128-channel stem is randomly initialised. The pretrained
        # 3-channel conv1 weights are discarded.
        self.model.conv1 = nn.Conv2d(128, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model.fc = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(fc_inputs, self.num_classes),
            nn.LogSoftmax(dim=1),  # log-probabilities: pair with nn.NLLLoss
        )

    def forward(self, streamA, streamB):
        """Fuse the two streams and return class log-probabilities.

        Both inputs must have 3 channels (the default ``Convolution`` input).
        They must also have matching shapes so their features can be summed.
        Presumably each input is (N, 3, H, W); TODO confirm.

        Raises:
            ValueError: if ``self.type`` was changed to an unsupported value.
        """
        ht = self.conv1(streamA)
        ht1 = self.conv2(streamB)
        z = ht + ht1

        if self.type == 'tsma':
            y = torch.cat((z, ht), dim=1)
        elif self.type == 'tsmb':
            y = torch.cat((z, ht1), dim=1)
        else:
            raise ValueError(
                f"Unknown fusion type {self.type!r}; expected one of {self._FUSION_TYPES}"
            )
        return self.model(y)

In [8]:
# Install torchinfo if it's not available, import it if it is
model_0 = TwoStreamCNN(29)
try: 
    import torchinfo
except:
    !pip install torchinfo
    import torchinfo
    
from torchinfo import summary
summary(model_0, input_size=[(1,3,226,226),(1,3,226,226)], col_names=['input_size', 'output_size', 'num_params',"params_percent", "kernel_size", "mult_adds", "trainable"]) # do a test pass through of an example input size 

Layer (type:depth-idx)                        Input Shape               Output Shape              Param #                   Param %                   Kernel Shape              Mult-Adds                 Trainable
TwoStreamCNN                                  [1, 3, 226, 226]          [1, 29]                   3,465                       0.01%                   --                        --                        True
├─Convolution: 1-1                            [1, 3, 226, 226]          [1, 64, 224, 224]         --                             --                   --                        --                        True
│    └─Conv2d: 2-1                            [1, 3, 226, 226]          [1, 64, 224, 224]         1,792                       0.01%                   [3, 3]                    89,915,392                True
│    └─BatchNorm2d: 2-2                       [1, 64, 224, 224]         [1, 64, 224, 224]         128                         0.00%                   --               