<a href="https://colab.research.google.com/github/lvllvl/segmentation10k/blob/main/segmentation10k.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch.nn as nn
import torch
from pathlib import Path 

In [3]:
! git clone https://github.com/commaai/comma10k.git

Cloning into 'comma10k'...
remote: Enumerating objects: 60941, done.[K
remote: Counting objects: 100% (532/532), done.[K
remote: Compressing objects: 100% (484/484), done.[K
remote: Total 60941 (delta 157), reused 352 (delta 48), pack-reused 60409[K
Receiving objects: 100% (60941/60941), 7.46 GiB | 38.10 MiB/s, done.
Resolving deltas: 100% (8190/8190), done.
Checking out files: 100% (20598/20598), done.


In [29]:
dir_comma10k = Path( '/content/comma10k/')

# masks, imgs folders --> use both of those folders 

In [13]:
def double_conv( in_c, out_c ): 
    # input channel, output channel 
    conv = nn.Sequential( 
        nn.Conv2d( in_c, out_c, kernel_size=3), 
        nn.ReLU( inplace=True ), # what does this do?
        
        nn.Conv2d( out_c, out_c, kernel_size=3), 
        nn.ReLU( inplace=True ) # what does this do?
    )
    return conv

def crop_img( tensor, target_tensor ): 
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2 
    return tensor[ :,:,delta:tensor_size-delta, delta:tensor_size-delta ]


class UNet( nn.Module ): 

    def __init__( self ): 
        super( UNet, self ).__init__() # what does this do?

        self.max_pool_2x2 = nn.MaxPool2d( kernel_size=2, stride=2 ) # define this 1x, use it multiple times 
    
        self.down_conv_1 = double_conv( 1, 64)
        self.down_conv_2 = double_conv( 64, 128)
        self.down_conv_3 = double_conv( 128, 256)
        self.down_conv_4 = double_conv( 256, 512)
        self.down_conv_5 = double_conv( 512, 1024)

        self.up_trans_1 = nn.ConvTranspose2d( 
            in_channels = 1024,
            out_channels = 512, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_1 = double_conv( 1024, 512 )
             
        self.up_trans_2 = nn.ConvTranspose2d( 
            in_channels = 512,
            out_channels = 256, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_2 = double_conv( 512, 256 )
        
        self.up_trans_3 = nn.ConvTranspose2d( 
            in_channels = 256,
            out_channels = 128, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_3 = double_conv( 256, 128 )
        
        self.up_trans_4 = nn.ConvTranspose2d( 
            in_channels = 128,
            out_channels = 64, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_4 = double_conv( 128, 64 )


        self.out = nn.Conv2( 
            in_channels = 64, 
            out_channels = 5, # increase out channels based on how many classes youwant to segement 
            kernel_size = 1 # 2d conv, w kernel size of 1
        )

    def forward( self, image ):
        # bs, c, h, w --> batch size, channel, height, width 
        # Encoder file 
        x1 = self.down_conv_1( image ) # 
        x2 = self.max_pool_2x2( x1 )
        x3 = self.down_conv_2( x2 ) # 
        x4 = self.max_pool_2x2( x3 ) 
        x5 = self.down_conv_3( x4 ) # 
        x6 = self.max_pool_2x2( x5 ) 
        x7 = self.down_conv_4( x6 ) # 
        x8 = self.max_pool_2x2( x7 ) 
        x9 = self.down_conv_5( x8 ) # 
        print( x9.size() )

        # decoder part 
        x =  self.up_trans_1( x9 ) 
        y = crop_img( x7, x ) 
        x = self.up_conv_1( torch.cat( [x, y], 1) )
        
        x =  self.up_trans_2( x ) 
        y = crop_img( x5, x ) 
        x = self.up_conv_2( torch.cat( [x, y], 1) )

        x =  self.up_trans_3( x ) 
        y = crop_img( x3, x ) 
        x = self.up_conv_3( torch.cat( [x, y], 1) )

        x =  self.up_trans_4( x9 ) 
        y = crop_img( x1, x ) 
        x = self.up_conv_4( torch.cat( [x, y], 1) )

        x = self.out( x ) 
        print( x.size() ) 
        return x 


if __name__ == "__main__": 
    image = torch.rand( ( 1, 1, 572, 572 )) 
    model = UNet() 

    print( model( image ) ) 





torch.Size([1, 1024, 28, 28])
torch.Size([1, 512, 56, 56])
torch.Size([1, 512, 64, 64])
None
