<a href="https://colab.research.google.com/github/lvllvl/segmentation10k/blob/main/Seg10k_050421.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#from google.colab import drive
#drive.mount('/content/gdrive')

## Imports

In [57]:
import os 
from pathlib import Path 
import pandas as pd 
from torchvision.io import read_image

import torch 
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from collections import defaultdict 

## Create Dataset

In [None]:
class CustomImageDataset( Dataset ):

    def __init__( self, annotations_file, img_dir, mask_dir,  transform=None,
            target_transform=None ):
        self.img_labels = pd.read_csv( annotations_file ) # the whole dataframe
        self.img_dir = img_dir # directory for the images 
        self.mask_dir = mask_dir # directory for the masks 

        self.transform = transform
        self.target_transform = target_transform

    def __len__( self ):
        return len( self.img_labels )

    def __getitem__( self, idx ):
        img_path = os.path.join( self.img_dir, self.img_labels.iloc[ idx, 1 ] )
        mask_path = os.path.join( self.mask_dir, self.img_labels.iloc[ idx, 1 ]) 

        image = read_image( img_path )
        mask = read_image( mask_path )

        if self.transform:
            image = self.transform( image )
        
        if self.target_transform:
            mask = self.target_transform( mask )
        
        sample = {'image': image, "mask": mask } 
        return sample


    # def organize_files():
    #     file_dir = self.img_dir # directory to all images 
    #     file_path = Path( file_dir ) 


In [None]:
def organize_files(): 
    file_dir = 'drive/MyDrive/projects/datasets/comma10k/imgs'
    file_path = Path( file_dir )
    fn_arr = [] 
    
    # loop through files 
    for fn in file_path.iterdir():
        fn_arr.insert( 0, fn.name )
    
    
    df_dic = {'filename': fn_arr, 
            'maskpath': 'drive/MyDrive/projects/datasets/comma10k/masks', 
            'imagepath': 'drive/MyDrive/projects/datasets/comma10k/imgs'}
    
    df = pd.DataFrame( df_dic ) # convert into dataframe
    filename = 'filesData.csv'
    df.to_csv( filename ) 

    return df 

In [None]:
df = organize_files() # create dataframe, save csv  
dframe = pd.read_csv( 'filesData.csv' ) # open csv 

In [None]:
# create a dataset 
dataset = CustomImageDataset('filesData.csv', 
                             'drive/MyDrive/projects/datasets/comma10k/imgs', 
                             'drive/MyDrive/projects/datasets/comma10k/masks' )

dataset

<__main__.CustomImageDataset at 0x7f1356bc0f50>

In [None]:
dataset[3]['image']

tensor([[[ 0,  1,  0,  ..., 10, 24, 24],
         [ 0,  0,  0,  ...,  9, 24, 24],
         [ 0,  0,  0,  ..., 13, 29, 29],
         ...,
         [25, 27, 26,  ..., 37, 32, 38],
         [ 2,  4, 10,  ..., 41, 40, 43],
         [ 1,  2,  6,  ..., 42, 42, 43]],

        [[ 0,  0,  6,  ..., 25, 19, 19],
         [ 0,  0,  4,  ..., 24, 19, 19],
         [ 0,  0,  2,  ..., 21, 18, 18],
         ...,
         [ 0,  0,  5,  ...,  4,  2,  8],
         [ 0,  2, 11,  ...,  0,  0,  1],
         [ 0,  0,  7,  ...,  0,  0,  1]],

        [[11, 19, 19,  ..., 11,  5,  5],
         [10, 16, 17,  ..., 10,  5,  5],
         [10, 14, 13,  ...,  9,  3,  3],
         ...,
         [ 0,  0,  0,  ...,  4,  3,  9],
         [ 0,  0,  3,  ..., 34, 29, 32],
         [ 0,  0,  0,  ..., 35, 31, 32]]], dtype=torch.uint8)

In [None]:
dataset[3]['image'].size()

torch.Size([3, 874, 1164])

In [58]:
x = torch.tensor( [[1,2], [2,3], [5,6]])

x.size()

torch.Size([3, 2])

## Model

In [None]:
def double_conv( in_c, out_c ): 
    # input channel, output channel 
    conv = nn.Sequential( 
        nn.Conv2d( in_c, out_c, kernel_size=3), 
        nn.ReLU( inplace=True ), # what does this do?
        
        nn.Conv2d( out_c, out_c, kernel_size=3), 
        nn.ReLU( inplace=True ) # what does this do?
    )
    return conv

def crop_img( tensor, target_tensor ): 
    target_size = target_tensor.size()[2]
    tensor_size = tensor.size()[2]
    delta = tensor_size - target_size
    delta = delta // 2 
    return tensor[ :,:,delta:tensor_size-delta, delta:tensor_size-delta ]


class UNet( nn.Module ): 

    def __init__( self ): 
        super( UNet, self ).__init__() # what does this do?

        self.max_pool_2x2 = nn.MaxPool2d( kernel_size=2, stride=2 ) # define this 1x, use it multiple times 
    
        self.down_conv_1 = double_conv( 1, 64)
        self.down_conv_2 = double_conv( 64, 128)
        self.down_conv_3 = double_conv( 128, 256)
        self.down_conv_4 = double_conv( 256, 512)
        self.down_conv_5 = double_conv( 512, 1024)

        self.up_trans_1 = nn.ConvTranspose2d( 
            in_channels = 1024,
            out_channels = 512, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_1 = double_conv( 1024, 512 )
             
        self.up_trans_2 = nn.ConvTranspose2d( 
            in_channels = 512,
            out_channels = 256, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_2 = double_conv( 512, 256 )
        
        self.up_trans_3 = nn.ConvTranspose2d( 
            in_channels = 256,
            out_channels = 128, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_3 = double_conv( 256, 128 )
        
        self.up_trans_4 = nn.ConvTranspose2d( 
            in_channels = 128,
            out_channels = 64, 
            kernel_size = 2,
            stride = 2 
        )
        self.up_conv_4 = double_conv( 128, 64 )


        self.out = nn.Conv2( 
            in_channels = 64, 
            out_channels = 5, # increase out channels based on how many classes youwant to segement 
            kernel_size = 1 # 2d conv, w kernel size of 1
        )

    def forward( self, image ):
        # bs, c, h, w --> batch size, channel, height, width 
        # Encoder file 
        x1 = self.down_conv_1( image ) # 
        x2 = self.max_pool_2x2( x1 )
        x3 = self.down_conv_2( x2 ) # 
        x4 = self.max_pool_2x2( x3 ) 
        x5 = self.down_conv_3( x4 ) # 
        x6 = self.max_pool_2x2( x5 ) 
        x7 = self.down_conv_4( x6 ) # 
        x8 = self.max_pool_2x2( x7 ) 
        x9 = self.down_conv_5( x8 ) # 
        print( x9.size() )

        # decoder part 
        x =  self.up_trans_1( x9 ) 
        y = crop_img( x7, x ) 
        x = self.up_conv_1( torch.cat( [x, y], 1) )
        
        x =  self.up_trans_2( x ) 
        y = crop_img( x5, x ) 
        x = self.up_conv_2( torch.cat( [x, y], 1) )

        x =  self.up_trans_3( x ) 
        y = crop_img( x3, x ) 
        x = self.up_conv_3( torch.cat( [x, y], 1) )

        x =  self.up_trans_4( x9 ) 
        y = crop_img( x1, x ) 
        x = self.up_conv_4( torch.cat( [x, y], 1) )

        x = self.out( x ) 
        print( x.size() ) 
        return x 


if __name__ == "__main__": 
    image = torch.rand( ( 1, 1, 572, 572 )) 
    model = UNet() 

    print( model( image ) )