# Denoising Testing 2D
- CNN - regular convolutional neural network
- ConvNN_2D - all samples
- ConvNN_2D - random samples
- ConvNN_2D_spatial - spatial samples
- Branching Network - CNN + ConvNN_2D

In [2]:
# Torch
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch import optim 


# Train + Data 
import sys 
sys.path.append('../Layers')
from Conv1d_NN import *
from Conv2d_NN import *

from Conv1d_NN_spatial import * 
from Conv2d_NN_spatial import * 

sys.path.append('../Data')
from CIFAR10 import CIFAR10_denoise


sys.path.append('../Train')
from train2d import * 


## I. Models

In [None]:
# `summary` is otherwise only imported in a later cell; import it here so this
# cell runs on its own.
from torchsummary import summary

# Five stacked BranchingNetwork stages. Each stage emits
# (out_ch1 + out_ch2) // 2 channels, which matches the next stage's in_ch,
# giving the channel chain 1 -> 16 -> 8 -> 4 -> 2 -> 1.
# NOTE(review): BranchingNetwork is defined further down in this notebook;
# that cell must be executed before this one or this raises NameError.
branching_denoiser = nn.Sequential(
    BranchingNetwork(in_ch=1, out_ch1=16, out_ch2=16, kernel_size=3),
    BranchingNetwork(in_ch=16, out_ch1=8, out_ch2=8, kernel_size=3),
    BranchingNetwork(in_ch=8, out_ch1=4, out_ch2=4, kernel_size=3),
    BranchingNetwork(in_ch=4, out_ch1=2, out_ch2=2, kernel_size=3),
    BranchingNetwork(in_ch=2, out_ch1=1, out_ch2=1, kernel_size=3),
)

# Input size is (channels=1, length=40) because the stages are 1D (Conv1d)
# layers. Nothing here moves the model to a GPU, so pin torchsummary to the
# CPU rather than relying on its default device="cuda".
summary(branching_denoiser, (1, 40), device='cpu')

In [None]:
# Regular CNN 2D denoising.
# Fully convolutional: a (3, H, W) image goes in and a (3, H, W) image comes
# out (every conv uses kernel_size=3, stride=1, padding=1, so the spatial
# size is preserved). The output can therefore be compared pixel-wise with a
# clean target image.
# NOTE(review): assumes CIFAR10_denoise yields (noisy, clean) image pairs;
# confirm against the training loop in train2d.
# A Flatten -> Linear(..., 10) head (a 10-class classifier) cannot produce an
# image, so the last layer projects the features back to 3 channels instead.
CNN_denoising = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),

    nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
    nn.ReLU(),

    # Map the 32 feature maps back to 3 image channels. No activation, so
    # the output is not clamped to non-negative values.
    nn.Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
).to('cpu')

from torchsummary import summary

# CNN_denoising is placed on the CPU via .to('cpu'). torchsummary defaults to
# device="cuda" and then builds CUDA input tensors whenever a GPU is
# available, which would mismatch the CPU weights; pin the device explicitly.
summary(CNN_denoising, (3, 32, 32), device='cpu')



class BranchingNetwork(nn.Module):
    """Run two 1D branches in parallel, align their lengths, and fuse them.

    The first branch is an unpadded ``nn.Conv1d`` followed by ReLU. The
    second branch is a ``Conv1d_NN`` layer (with ``K`` and ``stride`` both set
    to ``kernel_size``) followed by ReLU.

    In ``forward``, the first branch's output is resized along its last axis
    to the second branch's length. It is zero-padded when shorter, and
    cropped when longer (negative padding). The two outputs are then
    concatenated on the channel axis, and a 1x1 convolution reduces
    ``out_ch1 + out_ch2`` channels to ``(out_ch1 + out_ch2) // 2``.

    Args:
        in_ch: number of input channels.
        out_ch1: output channels of the plain convolution branch.
        out_ch2: output channels of the Conv1d_NN branch.
        kernel_size: kernel size of the plain convolution; also used as
            Conv1d_NN's ``K`` and ``stride``.
    """

    def __init__(self, in_ch, out_ch1, out_ch2, kernel_size):
        super().__init__()
        self.kernel_size = kernel_size
        merged_ch = out_ch1 + out_ch2

        # Submodules are created in this fixed order (branch1, branch2,
        # reduce_channels) so parameter initialisation consumes the RNG the
        # same way every time.
        self.branch1 = nn.Sequential(
            nn.Conv1d(in_ch, out_ch1, kernel_size),
            nn.ReLU(),
        )
        self.branch2 = nn.Sequential(
            Conv1d_NN(in_ch, out_ch2, K=kernel_size, stride=kernel_size),
            nn.ReLU(),
        )
        self.reduce_channels = nn.Conv1d(merged_ch, merged_ch // 2, 1)

    def forward(self, x):
        conv_out = self.branch1(x)
        nn_out = self.branch2(x)

        # Length difference along dim 2. The left side gets diff // 2 and the
        # right side gets the remainder. A negative diff makes F.pad crop.
        diff = nn_out.size(2) - conv_out.size(2)
        left = diff // 2
        conv_out = F.pad(conv_out, (left, diff - left), 'constant', 0)

        fused = torch.cat((conv_out, nn_out), dim=1)
        return self.reduce_channels(fused)

In [3]:
# ConvNN 2d all sample

In [4]:
# ConvNN 2d random sample


In [5]:
# ConvNN 2d spatial sample 

In [6]:
# Branching Network - CNN + ConvNN 2d