In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
torch.manual_seed(0)

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    # image_shifted = (image_tensor + 1) / 2\
    image_shifted = image_tensor
    image_unflat = image_shifted.detach().cpu().view(-1, *size)
    image_grid = make_grid(image_unflat[:num_images], nrow=4)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [2]:
class EncoderBlock(nn.Module):
    def __init__(self , inn ) :
        super(EncoderBlock , self).__init__()
        self.ContractingBlock = nn.Sequential(
            nn.Conv2d(inn , 2*inn , 3 ), nn.ReLU() ,
            nn.Conv2d(2*inn , 2*inn , 3) , nn.ReLU(), 
            nn.MaxPool2d(kernel_size=(2,2), stride=2)
        )
    def forward(self , images):
        return self.ContractingBlock(images)
    def get_Self(self):
        return self

In [11]:
for i in range(4):
        e = EncoderBlock(64*(2**i))
        t = torch.randn((1, 64*(2**i), 256, 256))
        print((e.ContractingBlock(t)).shape)

torch.Size([1, 128, 126, 126])
torch.Size([1, 256, 126, 126])
torch.Size([1, 512, 126, 126])
torch.Size([1, 1024, 126, 126])


In [25]:
shapes = [(1, 128, 126, 126) , (1, 256, 126, 126) , (1, 512, 126, 126) ,(1, 1024, 126, 126)]


In [33]:
class DecoderBlock(nn.Module):
    def __init__(self , inn ):
        super(DecoderBlock , self).__init__()
        self.upsampling = nn.Sequential(
            nn.Upsample(scale_factor=2 , mode = 'bilinear' , align_corners=True) , 
            nn.Conv2d(inn , inn//2 , 2 ),nn.ReLU() , 
        )
        
        self.ExpandingBlock = nn.Sequential( 
            nn.Conv2d(inn , inn//2 , 3 ), nn.ReLU(),
            nn.Conv2d(inn//2 , inn//2 , 3) , nn.ReLU()
        )

    def crop_features(self ,image , new_shape):

        x  , y = image.shape[2] - new_shape[2]   , image.shape[3] -new_shape[3]
        x,y= x//2 , y // 2
        image = image[...,x:x+new_shape[2],y:y+new_shape[3]]

        return image

    def forward(self , image ,encoder_feature):
        x = self.upsampling(image)
        print(x.shape , encoder_feature.shape)

        x = torch.cat([x , self.crop_features(encoder_feature , x.shape)] , axis = 1)
        
        return self.ExpandingBlock(x)
    def get_Self(self):
        return self




In [35]:
for i in range(4,0,-1):
        e = DecoderBlock(64*(2**i))
        t = torch.randn((1, 64*(2**i), 126, 126))
        print(i , t.shape)
        t2 = torch.randn(shapes[i-2])
        e = e(t , t2)

4 torch.Size([1, 1024, 126, 126])
torch.Size([1, 512, 251, 251]) torch.Size([1, 512, 126, 126])


RuntimeError: ignored

In [4]:

class final_head(nn.Module):
    def __init__(self , inn , out):
        super(final_head , self).__init__()
        self.Map = nn.Conv2d(inn , out , 1)
                             
    def forward(self , x):
        x = self.Map(x)
        return x

In [9]:
class U_Net(nn.Module):
    def __init__(self , inn , out ,hidden_channels = 64 , blocks = 4):
        super(U_Net  , self).__init__()
        self.Map_up = final_head(inn ,hidden_channels)
        self.contractingPath = nn.ModuleList([EncoderBlock(hidden_channels* (2**i)) for i in range(blocks)])
        self.ExpandingPath = nn.ModuleList([DecoderBlock(hidden_channels*(2**i)) for i in range(blocks,0,-1)])
        self.Map_down = final_head(hidden_channels , out)
        self.blocks = blocks
        self.e = {}
    def forward(self , image):
        image = self.Map_up(image)
        enc_features = {}
        for i in range(self.blocks):
            image  = self.contractingPath[i](image)
            enc_features[f"contracting_block {i}"] = image
        for i in range(self.blocks):
            image = self.ExpandingPath[i](image ,enc_features[f"contracting_block {2-i}"])
        image = self.Map_down(image)
        self.e = enc_features
        return image
    def get_self(self):
        return self


In [10]:

unet = U_Net(1,3)
x = torch.randn(1, 1, 256, 256)
unet(x).shape

#>> (1, 3, 117, 117)

torch.Size([1, 512, 23, 23]) torch.Size([1, 512, 28, 28])


RuntimeError: ignored

In [None]:
c = unet.get_self()
c = c.e
c.keys()


dict_keys([])