# Unet을 이용한 Image Segmentation 

레퍼런스 : https://colab.research.google.com/github/dhrim/MDC_2021/blob/master/material/deep_learning/unet_segmentation_multi_label.ipynb#scrollTo=cHnifLferK9Z

In [None]:
# from torchvision.datasets import VOCSegmentation

# a = VOCSegmentation(root='./data',year='2012',image_set='train',download=True)

# 데이터 로드 

In [1]:
%cd '/data/Pytorch_vision/data'         

/data/Pytorch_vision/data


In [None]:
import numpy as np 
import torch 
from torch.utils.data import Dataset, DataLoader
import torchvision 
import torch.nn as nn 
import torch.nn.functional as F 
from torchvision import transforms
import matplotlib.pyplot as plt 
from glob import glob
import cv2  
#데이터 로드 
%cd '/data/Pytorch_vision/data'
def data_dir_load():
    label_dirs = sorted(glob('VOCdevkit/VOC2012/SegmentationClass/*.png'))
    img_dirs = sorted(glob('VOCdevkit/VOC2012/JPEGImages/*.jpg'))[:len(label_dirs)]
    return img_dirs ,label_dirs 

class divide(nn.Module):
    def __init__(self,divide_value):
        super().__init__()
        self.divide_value = divide_value

    def forward(self,img):
        return img/self.divide_value

def image_transform():
    transform = transforms.Compose([
        divide(255),
        transforms.ToTensor(),
        transforms.Resize((256,256)),
        transforms.Normalize(mean=0.5,std=0.5)
    ])
    return transform 

class Dset(Dataset):
    def __init__(self,image_dirs,label_dirs,transform):
        super().__init__()
        self.image_dirs = image_dirs 
        self.label_dirs = label_dirs 
        self.transform = transform
    
    def __len__(self):
        return len(self.image_dirs)

    def image_transform(self,img):
        return self.transform(img)

    def image_label_load(self,img_dir,label_dir):
        img = cv2.imread(img_dir)
        label = cv2.imread(label_dir)
        return img,label

    def __getitem__(self,idx):
        self.idx = idx 
        img,label = self.image_label_load(self.image_dirs[idx],self.label_dirs[idx])

        self.image = self.image_transform(img)
        self.label = self.image_transform(label)

        return self.image, self.label 

def train_valid_split(images,labels):
    length = len(images)
    split_index = int(length*0.8)
    train_images,test_images = images[:split_index], images[split_index:]
    train_labels,test_labels = labels[:split_index], labels[split_index:]
    return train_images,test_images,train_labels,test_labels 

class Conv_Block(nn.Module):
    def __init__(self,input_c,output_c):
        super().__init__()
        self.conv_in = nn.Conv2d(in_channels=input_c,out_channels=output_c,kernel_size=3,padding=1)
        self.conv = nn.Conv2d(in_channels=output_c,out_channels=output_c,kernel_size=3,padding=1)
        self.batchnorm = nn.BatchNorm2d(output_c)
        self.maxpool = nn.MaxPool2d(2)

    def forward(self,x):
        x = self.conv_in(x)
        x = F.relu(x)
        x = self.batchnorm(x)
        x = self.conv(x)
        x = F.relu(x)
        conv = self.batchnorm(x)
        pool = self.maxpool(conv)
        return conv,pool

class Down(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = Conv_Block(3,32)
        self.conv2 = Conv_Block(32,64)
        self.conv3 = Conv_Block(64,128)
        self.conv4 = Conv_Block(128,256)
        self.conv5 = self.Conv_last(256,512)

    def Conv_last(self,input_c,output_c):
        block = nn.Sequential(
                                nn.Conv2d(input_c,output_c,3,padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(output_c),
                                nn.Conv2d(output_c,output_c,3,padding=1),
                                nn.ReLU(),
                                nn.BatchNorm2d(output_c)
        )
        return block 

    def forward(self,x):
        conv1,pool1 = self.conv1(x)
        conv2,pool2 = self.conv2(pool1)
        conv3,pool3 = self.conv3(pool2)
        conv4,pool4 = self.conv4(pool3)
        conv5 = self.conv5(pool4)
        return conv1,conv2,conv3,conv4,conv5

class Conv_up_block(nn.Module):
    def __init__(self,input_c,output_c):
        super().__init__()
        self.up_sample = nn.ConvTranspose2d(input_c,output_c,2,stride=2,padding=0)
        self.conv1 = nn.Conv2d(input_c,output_c,3,padding=1)
        self.conv2 = nn.Conv2d(output_c,output_c,3,padding=1)
        self.batchnorm = nn.BatchNorm2d(output_c)


    def forward(self,x,conv):
        self.up = self.up_sample(x)
        x = torch.cat((self.up,conv),dim=1)

        x = self.conv1(x)
        x = F.relu(x)
        x = self.batchnorm(x)

        x = self.conv2(x)
        x = F.relu(x)
        x = self.batchnorm(x)
        
        return x 

class Up(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_up1 = Conv_up_block(512,256)
        self.conv_up2 = Conv_up_block(256,128)
        self.conv_up3 = Conv_up_block(128,64)
        self.conv_up4 = Conv_up_block(64,32)
        self.conv_last = nn.Conv2d(32,3,1,padding=0)
    
    def forward(self,conv1,conv2,conv3,conv4,conv5):
        up = self.conv_up1(conv5,conv4)
        up = self.conv_up2(up,conv3)
        up = self.conv_up3(up,conv2)
        up = self.conv_up4(up,conv1)
        up = self.conv_last(up)
        up = F.softmax(up,dim=1)
        return up 


class Unet(nn.Module):
    def __init__(self):
        super().__init__()
        self.Down = Down()
        self.Up = Up()

    def forward(self,x):
        conv1,conv2,conv3,conv4,conv5 = self.Down(x)
        x = self.Up(conv1,conv2,conv3,conv4,conv5)
        return x   

In [3]:
class CFG:
    batch_size = 2
    epoch = 50

img_dirs,label_dirs = data_dir_load() 
train_img_dirs,test_img_dirs,train_label_dirs,test_label_dirs = train_valid_split(img_dirs,label_dirs)

image_transformer = image_transform()
train_dataset = Dset(train_img_dirs, train_label_dirs,image_transformer)
test_dataset = Dset(test_img_dirs,test_label_dirs,image_transformer)

train_dataloader = DataLoader(train_dataset,CFG.batch_size,shuffle=False)
test_dataloader = DataLoader(test_dataset,CFG.batch_size,shuffle=False)



In [8]:
a

Up(
  (conv_up1): Conv_up_block(
    (up_sample): ConvTranspose2d(512, 256, kernel_size=(2, 2), stride=(2, 2))
    (conv1): Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_up2): Conv_up_block(
    (up_sample): ConvTranspose2d(256, 128, kernel_size=(2, 2), stride=(2, 2))
    (conv1): Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (batchnorm): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv_up3): Conv_up_block(
    (up_sample): ConvTranspose2d(128, 64, kernel_size=(2, 2), stride=(2, 2))
    (conv1): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), paddi

In [None]:
import torch.optim as optim 
from tqdm import tqdm
device = 'cuda'
unet = Unet().to(device)
loss_fn = nn.BCELoss()
optimizer = optim.Adam(unet.parameters(),lr=1e-4)

for epoch in tqdm(range(CFG.epoch)):
    running_loss = 0.0
    unet.train()
    for i, (inputs,labels) in enumerate(train_dataloader):
        inputs,labels = inputs.to(device).type(torch.float),labels.to(device).type(torch.float)

        optimizer.zero_grad()

        outputs = unet(inputs)
        loss = loss_fn(outputs,labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f'loss : {running_loss}')

    with torch.no_grad():
        valid_loss = 0.0 
        for inputs,labels in test_dataloader:
            inputs,labels = inputs.to(device).type(torch.float),labels.to(device).type(torch.float)
            

            outputs = unet(inputs)
            loss = loss_fn(outputs,labels)
            valid_loss += loss.item() 

        print(f'valid_loss : {valid_loss}')

    
