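# Testing my U-Net implementation

This notebook checks the correctness of my U-Net implementation (encoder, decoder, and full model) by passing random tensors through each component and verifying the resulting tensor shapes.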

## Testing Encoder block

In [39]:
import sys
from pathlib import Path

# Make the project's ``src`` directory importable.
# The relative path is resolved once against the current working directory
# (assumes the kernel was started from the notebook's folder — TODO confirm).
# It is only added if missing, so re-running this cell does not keep growing sys.path.
SRC_DIR = str(Path("../src").resolve())
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)
from encoder import Encoder

In [40]:
import torch

# Seed the RNG so weight initialisation and the random test inputs are reproducible.
torch.manual_seed(0)

# Input channels first, then the output width of each encoder block:
# block i maps channels[i] -> channels[i + 1].
encoder = Encoder(channels=(3, 64, 128, 256, 512, 1024))

In [41]:
# Display the module structure via the cell's rich output instead of print().
encoder

Encoder(
  (encoder_blocks): ModuleList(
    (0): Block(
      (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): Block(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): Block(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    )
    (3): Block(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    )
    (4): Block(
      (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
    )
  )
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_

In [42]:

x = torch.randn(1, 3, 572, 572)  # Batch size of 1, 3 input channels, 572x572 image
with torch.no_grad():  # shape check only — no need to build an autograd graph
    features = encoder(x)
for f in features:
    print(f.shape)

# One feature map per encoder block, shallowest first.
expected_encoder_shapes = [
    (1, 64, 568, 568),
    (1, 128, 280, 280),
    (1, 256, 136, 136),
    (1, 512, 64, 64),
    (1, 1024, 28, 28),
]
actual_encoder_shapes = [tuple(f.shape) for f in features]
assert actual_encoder_shapes == expected_encoder_shapes, actual_encoder_shapes

torch.Size([1, 64, 568, 568])
torch.Size([1, 128, 280, 280])
torch.Size([1, 256, 136, 136])
torch.Size([1, 512, 64, 64])
torch.Size([1, 1024, 28, 28])


## Testing Decoder block

In [43]:
from decoder import Decoder

In [44]:
# Bottleneck width first, then the width of each upsampling stage
# (each consecutive pair, e.g. 1024 -> 512, is one upconv + Block).
decoder = Decoder(
    channels=(1024, 512, 256, 128, 64),
)

In [47]:
# Display the module structure via the cell's rich output instead of print().
decoder

Decoder(
  (upconvs): ModuleList(
    (0): ConvTranspose2d(1024, 512, kernel_size=(2, 2), stride=(2, 2))
    (1): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (2): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (3): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
  )
  (decoder_blocks): ModuleList(
    (0): Block(
      (conv1): Conv2d(1024, 512, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
    )
    (1): Block(
      (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
    )
    (2): Block(
      (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    )
    (3): Block(
      (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1))
      (relu): ReLU()
      (conv2): 

In [49]:
# The deepest encoder output (the bottleneck) is what the decoder starts from.
features[-1].shape

torch.Size([1, 1024, 28, 28])

In [57]:
# The decoder consumes feature maps deepest-first: feature_maps[0] is the bottleneck
# (features[-1]), and the rest are the skip connections from lowest to highest resolution.
feature_maps = features[::-1]

In [58]:
# One line per decoder input, bottleneck first.
for fmap in feature_maps:
    print(fmap.shape)

torch.Size([1, 1024, 28, 28])
torch.Size([1, 512, 64, 64])
torch.Size([1, 256, 136, 136])
torch.Size([1, 128, 280, 280])
torch.Size([1, 64, 568, 568])


In [59]:
# Feed the bottleneck plus the skip connections through the decoder.
with torch.no_grad():  # shape check only — no need to build an autograd graph
    output = decoder(feature_maps[0], feature_maps[1:])
print(output.shape)
# The unpadded 3x3 convolutions shrink each stage, so the 28x28 bottleneck ends at 388x388.
assert tuple(output.shape) == (1, 64, 388, 388), f"unexpected decoder output shape: {tuple(output.shape)}"

torch.Size([1, 64, 388, 388])


## Testing UNet

In [61]:
from unet import UNet

In [63]:
# NOTE(review): `ouput_dim` (sic) is the keyword UNet currently accepts (this call succeeds
# and produces a 256x256 output). If the constructor argument is renamed to `output_dim`,
# update this call as well.
model = UNet(
    ouput_dim=(256, 256),
    encoder_channels=(3, 64, 128, 256, 512, 1024),
    decoder_channels=(1024, 512, 256, 128, 64),
)
model

UNet(
  (encoder): Encoder(
    (encoder_blocks): ModuleList(
      (0): Block(
        (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
      )
      (1): Block(
        (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
      )
      (2): Block(
        (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1))
      )
      (3): Block(
        (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
      )
      (4): Block(
        (conv1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1))
        (relu): ReLU()
        (conv2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1))
      )
    )


In [64]:

x = torch.randn(1, 3, 572, 572)  # Example input tensor: batch of 1, 3 channels, 572x572
with torch.no_grad():  # inference-only shape check
    output = model(x)
print(output.shape)
# The model was built with ouput_dim=(256, 256); the observed output is a single-channel 256x256 map.
assert tuple(output.shape) == (1, 1, 256, 256), f"unexpected UNet output shape: {tuple(output.shape)}"


torch.Size([1, 1, 256, 256])
