# Unet
This is a project about Image Segmentation. The dataset of the project is from DRIVE,a Retinal Vessel Segmentation datasset, and Unet network is used to finish this task.
# Data processing

In [4]:
import os
import cv2
from torchvision import transforms
def read_image(train):
    path=''
    if train:
        img_dir=path+'training/images/'
        label_dir=path+'training/1st_manual/'
        img=os.listdir(img_dir)
        label=os.listdir(label_dir)
        images=[img_dir+i for i in img]
        labels=[label_dir+i for i in label]
    else:
        img_dir=path+'test/images/'
        label_dir=path+'test/1st_manual/'
        img=os.listdir(img_dir)
        label=os.listdir(label_dir)
        images=[img_dir+i for i in img]
        labels=[label_dir+i for i in label]
    return images,labels
def crop(img,size):
    h,w,c=img.shape
    _w,_h=size
    # h w ratio not change
    scale=min(_h/h,_w/w)
    h=int(h*scale)
    w=int(w*scale)
    img=cv2.resize(img,(w,h),interpolation=cv2.INTER_CUBIC)
    top=(_h-h)//2
    left=(_w-w)//2
    bottom=_h-h-top
    right=_w-w-left
    #create a new img with color black on edge
    new_img=cv2.copyMakeBorder(img,top,bottom,left,right,cv2.BORDER_CONSTANT,value=(0,0,0))
    return new_img
def image_transform(data,label,size):
    data=crop(data,size)
    label=crop(label,size)
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485,0.456,0.406],std=[0.229,0.224,0.225])
    ])
    data=transform(data)
    label=transform(label)
    return data,label

In [5]:
import torch
from torch.utils.data import Dataset, DataLoader
class Drive(Dataset):
    def __init__(self,train,h,w,transform):
        self.h=h
        self.w=w
        self.transform=transform
        if train:
            self.data_list, self.label_list=read_image(train=True)
        else:
            self.data_list,self.label_list=read_image(train=False)
    def __getitem__(self,index):
        img_dir=self.data_list[index]
        label_dir=self.label_list[index]
        img=cv2.imread(img_dir)
        _,label=cv2.VideoCapture(label_dir).read()
        img=cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
        label=cv2.cvtColor(label,cv2.COLOR_BGR2RGB)
        img,label=self.transform(img,label,(self.h,self.w))
        return img,label
    def __len__(self):
        return len(self.data_list)

In [6]:
h=572
w=572
train_set=Drive(train=True,h=h,w=w,transform=image_transform)
train_loader=DataLoader(train_set,batch_size=20,shuffle=True)

# Unet Model

In [7]:
import torch
import torch.nn as nn
class Doubleconv(nn.Module):
    def __init__(self,in_channel,out_channel):
        super().__init__()
        self.conv=nn.Sequential(nn.Conv2d(in_channel,out_channel,kernel_size=3,padding=1),
                               nn.BatchNorm2d(out_channel),
                               nn.Dropout(0.3),
                               nn.ReLU(inplace=True),
                               nn.Conv2d(out_channel,out_channel,kernel_size=3,padding=1),
                               nn.BatchNorm2d(out_channel),
                               nn.Dropout(0.4),
                               nn.ReLU(inplace=True))

    def forward(self,x):
        return self.conv(x)
class Down(nn.Module):
    def __init__(self,in_channel,out_channel):
        super().__init__()
        self.maxpool=nn.Sequential(nn.MaxPool2d(2,2),
                                  Doubleconv(in_channel,out_channel))
    def forward(self,x):
        return self.maxpool(x)
class Up(nn.Module):
    def __init__(self,in_channel,out_channel,bilinear=True):
        super().__init__()
        if bilinear:
            self.up=nn.Upsample(scale_factor=2,mode='bilinear',align_corners=True)
        else:
            self.up=nn.ConvTranspose2d(n_channel//2,in_channel//2,kernel_size=2,stride=2)
        self.conv=Doubleconv(in_channel,out_channel)
    def forward(self,x1,x2):
        x1=self.up(x1)
        diffY=torch.tensor([x2.size()[2]-x1.size()[2]])
        diffX=torch.tensor([x2.size()[3]-x1.size()[3]])
        x1=nn.functional.pad(x1,[diffX//2,diffX-diffX//2,diffY//2,diffY-diffY//2])
        x=torch.cat([x2,x1],dim=1)
        return self.conv(x)
class Outconv(nn.Module):
    def __init__(self,in_channel,out_channel):
        super(Outconv,self).__init__()
        self.conv=nn.Conv2d(in_channel,out_channel,kernel_size=1)
    def forward(self,x):
        return self.conv(x)

In [10]:
class Unet(nn.Module):
    def __init__(self,n_channels,bilinear=True):
        super(Unet,self).__init__()
        self.n_channel=n_channels
        self.bilinear=bilinear
        self.inc=Doubleconv(n_channels,64)
        self.down1=Down(64,128)
        self.down2=Down(128,256)
        self.down3=Down(256,512)
        self.down4=Down(512,512)
        #out_channel= next step in_channel
        self.up1=Up(1024,256,bilinear)
        self.up2=Up(512,128,bilinear)
        self.up3=Up(256,64,bilinear)
        self.up4=Up(128,64,bilinear)
        self.outc=Outconv(64,1)
        self.sigmoid=nn.Sigmoid()
    def forward(self,x):
        x1=self.inc(x)
        x2=self.down1(x1)
        x3=self.down2(x2)
        x4=self.down3(x3)
        x5=self.down4(x4)
        x=self.up1(x5,x4)
        x=self.up2(x,x3)
        x=self.up3(x,x2)
        x=self.up4(x,x1)
        x=self.outc(x)
        x=self.sigmoid(x)
        return logits


# training process

In [None]:
model=Unet(n_channels=3)
cri=torch.nn.CrossEntropyLoss()
optimizer=torch.optim.SGD(model.parameters(),lr=1e-4,weight_decay=1e-8,momentum=0.9)
best_loss=0.0
for epoch in range(10):
    total_loss=0.0
    for data,label in train_loader:
        optimizer.zero_grad()
        print(label.shape)
        output=model(data)
        print(output.shape)
        batch_loss=cri(output,label)
        batch_loss.backward()
        optimizer.step()
        
        total_loss+=batch_loss.item()
        
    print('epoch:',epoch,'loss:',total_loss/len(train_loader))

torch.Size([20, 3, 572, 572])
