## Import

In [48]:
import os
import cv2
from PIL import Image
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


## Utils

In [49]:
# RLE 인코딩 함수
def rle_encode(mask):
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

## Custom Dataset

In [50]:
class CustomDataset(Dataset):
    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image
        
        mask_path = self.data.iloc[idx, 2]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask[mask == 255] = 12 #배경을 픽셀값 12로 간주

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

## Data Loader

In [51]:
transform = A.Compose(
    [   
        A.Resize(224, 224),
        A.Normalize(),
        ToTensorV2()
    ]
)

dataset = CustomDataset(csv_file='/home/work/CPS_Project/Samsung AI-Challenge/open/train_source.csv', transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4)

## Define Model

In [52]:
# U-Net의 기본 구성 요소인 Double Convolution Block을 정의합니다.
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )

# 간단한 U-Net 모델 정의
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, 13, 1) # 12개 class + 1 background

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)

        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)   

        x = self.dconv_down4(x)

        x = self.upsample(x)        
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv2], dim=1)

        x = self.dconv_up2(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv1], dim=1)

        x = self.dconv_up1(x)

        out = self.conv_last(x)

        return out

## Model Train

In [53]:
# model 초기화
model = UNet().to(device)

# loss function과 optimizer 정의
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# training loop
for epoch in range(50):  # 20 에폭 동안 학습합니다.
    model.train()
    epoch_loss = 0
    for images, masks in tqdm(dataloader):
        images = images.float().to(device)
        masks = masks.long().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks.squeeze(1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {epoch_loss/len(dataloader)}')

100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 1, Loss: 1.3125518685665682


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 2, Loss: 0.676990296529687


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]


Epoch 3, Loss: 0.5239679325317991


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 4, Loss: 0.4739754418100136


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 5, Loss: 0.39278969989306683


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]


Epoch 6, Loss: 0.36777145145595935


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 7, Loss: 0.30487373859986017


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 8, Loss: 0.2861807427328566


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 9, Loss: 0.2559204589629519


100%|██████████| 138/138 [00:50<00:00,  2.74it/s]


Epoch 10, Loss: 0.2403058582457943


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 11, Loss: 0.22084568386924439


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 12, Loss: 0.20470657523559488


100%|██████████| 138/138 [00:50<00:00,  2.74it/s]


Epoch 13, Loss: 0.18426448214745175


100%|██████████| 138/138 [00:53<00:00,  2.57it/s]


Epoch 14, Loss: 0.16793571909268698


100%|██████████| 138/138 [00:51<00:00,  2.70it/s]


Epoch 15, Loss: 0.1674319269756476


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 16, Loss: 0.14808482023468916


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 17, Loss: 0.13777764098367828


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 18, Loss: 0.1461373482817325


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 19, Loss: 0.1410599599821844


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 20, Loss: 0.12456740514523741


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 21, Loss: 0.1139332607280517


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 22, Loss: 0.10545625084120294


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 23, Loss: 0.10634188877715581


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 24, Loss: 0.09725065869481667


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 25, Loss: 0.09281508613755737


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 26, Loss: 0.09071694589827371


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 27, Loss: 0.0860726121219172


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 28, Loss: 0.08693458741881709


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 29, Loss: 0.0853214672078257


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 30, Loss: 0.5525819098819857


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 31, Loss: 0.4720541480658711


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 32, Loss: 0.3351316988684129


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 33, Loss: 0.27725337758876273


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]


Epoch 34, Loss: 0.23868685496458109


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]


Epoch 35, Loss: 0.2155541345693063


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 36, Loss: 0.20271371089030002


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 37, Loss: 0.21399874456118848


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]


Epoch 38, Loss: 0.17295238375663757


100%|██████████| 138/138 [00:50<00:00,  2.74it/s]


Epoch 39, Loss: 0.15946338968216509


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 40, Loss: 0.14176314502306606


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 41, Loss: 0.15257369392160056


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 42, Loss: 0.1297832510933496


100%|██████████| 138/138 [00:50<00:00,  2.76it/s]


Epoch 43, Loss: 0.11796494654339293


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]


Epoch 44, Loss: 0.11132358408708504


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 45, Loss: 0.10485957087813944


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 46, Loss: 0.10162398300093153


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 47, Loss: 0.09676240082236304


100%|██████████| 138/138 [00:50<00:00,  2.75it/s]


Epoch 48, Loss: 0.09248496764811917


100%|██████████| 138/138 [00:50<00:00,  2.74it/s]


Epoch 49, Loss: 0.08750946317677913


100%|██████████| 138/138 [00:49<00:00,  2.76it/s]

Epoch 50, Loss: 0.08442675295299377





## Inference

In [58]:
test_dataset = CustomDataset(csv_file='/home/work/CPS_Project/Samsung AI-Challenge/open/test.csv', transform=transform, infer=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=4)

In [60]:
with torch.no_grad():
    model.eval()
    result = []
    for images in tqdm(test_dataloader):
        images = images.float().to(device)
        outputs = model(images)
        outputs = torch.softmax(outputs, dim=1).cpu()
        outputs = torch.argmax(outputs, dim=1).numpy()
        # batch에 존재하는 각 이미지에 대해서 반복
        for pred in outputs:
            pred = pred.astype(np.uint8)
            pred = Image.fromarray(pred) # 이미지로 변환
            pred = pred.resize((960, 540), Image.NEAREST) # 960 x 540 사이즈로 변환
            pred = np.array(pred) # 다시 수치로 변환
            # class 0 ~ 11에 해당하는 경우에 마스크 형성 / 12(배경)는 제외하고 진행
            for class_id in range(12):
                class_mask = (pred == class_id).astype(np.uint8)
                if np.sum(class_mask) > 0: # 마스크가 존재하는 경우 encode
                    mask_rle = rle_encode(class_mask)
                    result.append(mask_rle)
                else: # 마스크가 존재하지 않는 경우 -1
                    result.append(-1)

  pred = pred.resize((960, 540), Image.NEAREST) # 960 x 540 사이즈로 변환
100%|██████████| 119/119 [01:26<00:00,  1.38it/s]


## Submission

In [61]:
submit = pd.read_csv('/home/work/CPS_Project/Samsung AI-Challenge/open/sample_submission.csv')
submit['mask_rle'] = result
submit

Unnamed: 0,id,mask_rle
0,TEST_0000_class_0,218405 5 219365 5 220317 21 221277 21 222232 3...
1,TEST_0000_class_1,230075 5 231035 5 231995 17 232955 17 233915 1...
2,TEST_0000_class_2,1 81 742 300 1702 300 2654 308 3614 308 4574 3...
3,TEST_0000_class_3,233932 13 234892 13 235852 13 236812 22 237772...
4,TEST_0000_class_4,-1
...,...,...
22771,TEST_1897_class_7,922 13 1882 13 2842 13 3802 13 4762 13 5714 34...
22772,TEST_1897_class_8,104 536 678 120 858 13 1064 536 1638 120 1818 ...
22773,TEST_1897_class_9,33545 5 34505 5 202240 4 203200 4
22774,TEST_1897_class_10,210485 5 211445 5 212405 5 213365 5 214325 5


In [62]:
submit.to_csv('./baseline_submit.csv', index=False)