In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader
from torchinfo import summary

In [11]:
class ConvBlock(nn.Module):
    """A 2-D convolution followed by a ReLU activation.

    Any extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...)
    are forwarded unchanged to :class:`torch.nn.Conv2d`.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, **kwargs)

    def forward(self, x):
        pre_activation = self.conv(x)
        return F.relu(pre_activation)

In [10]:
class Inception_Block(nn.Module):
    """GoogLeNet-style Inception module with four parallel branches.

    Every branch preserves the spatial size of its input, so the branch
    outputs can be concatenated along the channel dimension:

    1. 1x1 convolution -> ``output_1x1`` channels.
    2. 1x1 reduction to ``output_1x1_block2`` channels, then a 3x3 convolution
       (padding 1) -> ``output_3x3`` channels.
    3. 1x1 reduction to ``output_5x5_reduce`` channels, then a 5x5 convolution
       (padding 2) -> ``output_5x5`` channels.
    4. 3x3 max-pool (stride 1, padding 1), then a 1x1 convolution
       -> ``output_pool`` channels.

    ``forward`` returns the four branch outputs concatenated along dim 1, i.e.
    ``output_1x1 + output_3x3 + output_5x5 + output_pool`` channels.
    """

    def __init__(
        self,
        in_channels,
        output_1x1,
        output_1x1_block2,
        output_3x3,
        output_5x5_reduce,
        output_5x5,
        output_pool,
    ):
        super().__init__()
        self.branch1 = ConvBlock(in_channels, output_1x1, kernel_size=1)
        self.branch2 = nn.Sequential(
            ConvBlock(in_channels, out_channels=output_1x1_block2, kernel_size=1),
            ConvBlock(output_1x1_block2, output_3x3, kernel_size=3, padding=1),
        )
        self.branch3 = nn.Sequential(
            ConvBlock(in_channels, out_channels=output_5x5_reduce, kernel_size=1),
            ConvBlock(output_5x5_reduce, output_5x5, kernel_size=5, padding=2),
        )
        self.branch4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
            ConvBlock(in_channels, output_pool, kernel_size=1),
        )

    def forward(self, x):
        # Branches run in order 1..4 and are stacked in that order along channels.
        branches = (self.branch1, self.branch2, self.branch3, self.branch4)
        return torch.cat([branch(x) for branch in branches], dim=1)
         

In [34]:
class x_model(nn.Module):
    """Single-input, GoogLeNet-style feature extractor.

    Stem: 7x7/2 conv -> ReLU -> 3x3/2 max-pool -> LRN -> 1x1 conv -> ReLU
    -> 3x3 conv -> ReLU -> LRN -> 3x3/2 max-pool. It is followed by Inception
    blocks 3a and 3b, a 3x3/2 max-pool and Inception block 4a. For a
    ``(1, 1, 256, 256)`` input the notebook output shows a ``(1, 256, 14, 14)``
    result.

    Args:
        in_channels: number of channels of the input image. Defaults to 1,
            the value that used to be hard-coded, so existing callers
            (``x_model()``) are unaffected.
    """

    def __init__(self, in_channels=1):
        super(x_model, self).__init__()
        self.conv_7x7 = nn.Conv2d(in_channels=in_channels, out_channels=6, kernel_size=7, stride=2)
        # max_pool and lrn_norm have no parameters; the same instances are
        # reused at several points in forward().
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2)
        self.lrn_norm = nn.LocalResponseNorm(size=5, alpha=0.0001, beta=0.75)
        self.conv_1x1 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=1)
        self.conv3_3x3 = nn.Conv2d(in_channels=12, out_channels=6, kernel_size=3)
        # Each Inception block emits 64 + 128 + 32 + 32 = 256 channels, which is
        # why 3b and 4a take in_channels=256.
        self.inception3a = Inception_Block(in_channels=6, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)
        self.inception3b = Inception_Block(in_channels=256, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)
        self.inception4a = Inception_Block(in_channels=256, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)

    def forward(self, x):
        # Shape comments are for a (1, 1, 256, 256) input, as reported by the
        # notebook's summary / dimension-check outputs.
        x = self.max_pool(F.relu(self.conv_7x7(x)))   # (1, 6, 62, 62)
        x = self.lrn_norm(x)
        x = F.relu(self.conv_1x1(x))                  # (1, 12, 62, 62)
        x = F.relu(self.conv3_3x3(x))                 # (1, 6, 60, 60)
        x = self.max_pool(self.lrn_norm(x))           # (1, 6, 29, 29)
        x = self.inception3a(x)                       # (1, 256, 29, 29)
        x = self.inception3b(x)                       # (1, 256, 29, 29)
        x = self.max_pool(x)                          # (1, 256, 14, 14)
        x = self.inception4a(x)                       # (1, 256, 14, 14)
        return x

## Testing the output dimensions of `x_model` and its Inception blocks

In [35]:
# Push one random single-channel 256x256 image through x_model and verify the
# output shape, so this cell fails loudly if the architecture changes.
hep_data = torch.randn(1, 1, 256, 256)
model_x = x_model()
output = model_x(hep_data)
assert output.shape == (1, 256, 14, 14), output.shape
output.shape

torch.Size([1, 256, 14, 14])

In [36]:
# Layer-by-layer output shapes and parameter counts for one 1x256x256 image.
hep_input_size = (1, 1, 256, 256)
summary(model_x, input_size=hep_input_size)

Layer (type:depth-idx)                   Output Shape              Param #
x_model                                  [1, 256, 14, 14]          --
├─Conv2d: 1-1                            [1, 6, 125, 125]          300
├─MaxPool2d: 1-2                         [1, 6, 62, 62]            --
├─LocalResponseNorm: 1-3                 [1, 6, 62, 62]            --
├─Conv2d: 1-4                            [1, 12, 62, 62]           84
├─Conv2d: 1-5                            [1, 6, 60, 60]            654
├─LocalResponseNorm: 1-6                 [1, 6, 60, 60]            --
├─MaxPool2d: 1-7                         [1, 6, 29, 29]            --
├─Inception_Block: 1-8                   [1, 256, 29, 29]          --
│    └─ConvBlock: 2-1                    [1, 64, 29, 29]           --
│    │    └─Conv2d: 3-1                  [1, 64, 29, 29]           448
│    └─Sequential: 2-2                   [1, 128, 29, 29]          --
│    │    └─ConvBlock: 3-2               [1, 96, 29, 29]           672
│    │    └

In [33]:
# Shape-check the Inception stages of x_model one at a time.
# For a 256x256 input the stem of x_model produces a (1, 6, 29, 29) tensor
# (see the torchinfo summary above). A random tensor of that shape stands in
# for it here. The previous version fed in `output`, which after
# Restart & Run All is x_model's *final* (1, 256, 14, 14) map. That map has
# 256 channels rather than the 6 inception3a expects, so the cell crashed.
stem_output = torch.randn(1, 6, 29, 29)

inception3a = Inception_Block(in_channels=6, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)
output3a = inception3a(stem_output)
print(f'Inception 3a Dimensions: {output3a.shape}')
assert output3a.shape == (1, 256, 29, 29)

inception3b = Inception_Block(in_channels=256, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)
output3b = inception3b(output3a)
print(f'Inception 3b Dimensions: {output3b.shape}')
assert output3b.shape == (1, 256, 29, 29)

pool = nn.MaxPool2d(stride=2, kernel_size=3)
max_pool_output = pool(output3b)
print(f'Output after MaxPool: {max_pool_output.shape}')
assert max_pool_output.shape == (1, 256, 14, 14)

inception4a = Inception_Block(in_channels=256, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)
output4a = inception4a(max_pool_output)
print(f'Inception 4a Dimensions: {output4a.shape}')
assert output4a.shape == (1, 256, 14, 14)

Inception 3a Dimensions: torch.Size([1, 256, 29, 29])
Inception 3b Dimensions: torch.Size([1, 256, 29, 29])
Output after MaxPool: torch.Size([1, 256, 14, 14])
Inception 4a Dimensions: torch.Size([1, 256, 14, 14])


In [56]:
class combineXY(nn.Module):
    """Two-branch network that fuses features from two input images.

    Each input goes through its own ``x_model`` instance; the two branches are
    separate modules and do not share weights. Their 256-channel feature maps
    are concatenated along dim 1 (512 channels), passed through a final
    Inception block (256 output channels) and average-pooled. For two
    ``(1, 1, 256, 256)`` inputs the notebook output shows a ``(1, 256, 2, 2)``
    result.
    """

    def __init__(self):
        super(combineXY, self).__init__()
        self.x_model = x_model()
        self.y_model = x_model()
        # Concatenating both branch outputs gives 256 + 256 = 512 channels.
        self.final_inception = Inception_Block(in_channels=512, output_1x1=64, output_1x1_block2=96, output_3x3=128, output_5x5_reduce=16, output_5x5=32, output_pool=32)
        # NOTE(review): AvgPool2d's stride defaults to its kernel size. A (6, 5)
        # kernel on the 14x14 map therefore yields 2x2 and never looks at the
        # last 2 rows or last 4 columns. If global average pooling was
        # intended, nn.AdaptiveAvgPool2d(1) is the usual choice. Confirm before
        # changing it, because that alters the output shape.
        self.avg_pooling = nn.AvgPool2d(kernel_size=(6, 5))

    def forward(self, x_data, y_data):
        """Run each input through its branch, fuse the features and pool them.

        Args:
            x_data: input for the ``x_model`` branch.
            y_data: input for the ``y_model`` branch; its batch and spatial size
                must match ``x_data`` so the branch outputs can be concatenated.

        Returns:
            The average-pooled output of the final Inception block.
        """
        x = self.x_model(x_data)
        y = self.y_model(y_data)
        concat = torch.cat([x, y], dim=1)
        combined_data = self.final_inception(concat)
        return self.avg_pooling(combined_data)

In [58]:
# One random single-channel 256x256 image per branch of combineXY.
x_hep_data = torch.randn((1, 1, 256, 256))
y_hep_data = torch.randn((1, 1, 256, 256))

combined_model = combineXY()
output_combined = combined_model(x_data=x_hep_data, y_data=y_hep_data)

In [59]:
# Shape of the fused, pooled output.
output_combined.size()

torch.Size([1, 256, 2, 2])

In [44]:
# Layer-by-layer summary of the two-branch model: one 1x256x256 image per branch.
# NOTE(review): the rendered output below comes from In [44], which ran before
# combineXY was last redefined (In [56]). Its top-level shape is stale; re-run
# this cell to refresh it.
hep_input_size = (1, 1, 256, 256)
summary(combined_model, input_size=[hep_input_size] * 2)

Layer (type:depth-idx)                   Output Shape              Param #
combineXY                                [1, 512, 14, 14]          --
├─x_model: 1-1                           [1, 256, 14, 14]          --
│    └─Conv2d: 2-1                       [1, 6, 125, 125]          300
│    └─MaxPool2d: 2-2                    [1, 6, 62, 62]            --
│    └─LocalResponseNorm: 2-3            [1, 6, 62, 62]            --
│    └─Conv2d: 2-4                       [1, 12, 62, 62]           84
│    └─Conv2d: 2-5                       [1, 6, 60, 60]            654
│    └─LocalResponseNorm: 2-6            [1, 6, 60, 60]            --
│    └─MaxPool2d: 2-7                    [1, 6, 29, 29]            --
│    └─Inception_Block: 2-8              [1, 256, 29, 29]          --
│    │    └─ConvBlock: 3-1               [1, 64, 29, 29]           448
│    │    └─Sequential: 3-2              [1, 128, 29, 29]          111,392
│    │    └─Sequential: 3-3              [1, 32, 29, 29]           12,944
│  