## Import

In [36]:
import os
import cv2
from PIL import Image
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
from albumentations.core.transforms_interface import DualTransform
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

## Utils

In [37]:
# RLE encoding function
def rle_encode(mask):
    """Run-length encode ``mask`` into a space-separated string.

    The mask is flattened, padded with a single 0 on both ends, and every
    position at which two neighbouring values differ is recorded (1-based).
    Every second recorded position is then replaced by its distance to the
    preceding one, so for a binary mask the result reads as
    ``start length start length ...``.

    Args:
        mask: array-like object exposing ``flatten()`` (e.g. a numpy array).

    Returns:
        The encoded runs joined by single spaces; an empty string when the
        padded mask contains no value changes.
    """
    padded = np.concatenate([[0], mask.flatten(), [0]])
    # Indices (shifted to 1-based) where the value changes between neighbours.
    boundaries = np.where(padded[1:] != padded[:-1])[0] + 1
    # Turn each "end" boundary into the length of the run that precedes it.
    boundaries[1::2] = boundaries[1::2] - boundaries[::2]
    return ' '.join(map(str, boundaries))

## Custom Dataset

In [38]:
class CustomDataset(Dataset):
    """Segmentation dataset driven by a CSV file of image / mask paths.

    The CSV is read with pandas; the image path is taken from the column at
    position 1 and (in training mode) the mask path from the column at
    position 2.

    Args:
        csv_file: path of the CSV file listing the samples.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            (or ``transform(image=...)`` in inference mode) that returns a
            mapping with ``'image'`` / ``'mask'`` entries.
        infer: when True, ``__getitem__`` returns only the image and no mask
            is loaded.
    """

    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            The (optionally transformed) RGB image when ``infer`` is True,
            otherwise an ``(image, mask)`` pair.

        Raises:
            FileNotFoundError: if the image or mask file cannot be read.
        """
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread(img_path)
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the offending path rather than inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image

        mask_path = self.data.iloc[idx, 2]
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")
        mask[mask == 255] = 12  # treat pixel value 255 as the background class (12)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [39]:
# Crop divisor: after warping, h/scale rows and w/scale columns are trimmed
# from every border of the result. Read at call time by the functions below.
scale = 5.5


def _fisheye_warp(image, interpolation, border_value):
    """Warp ``image`` through a synthetic camera model and crop the borders.

    A camera matrix is built from the image size (focal length ``w / 4``,
    principal point at the image centre) together with the distortion
    coefficients ``[0, 0.5, 0, 0]`` (OpenCV order: k1, k2, p1, p2). The map
    returned by ``cv2.initUndistortRectifyMap`` is applied with ``cv2.remap``
    and the result is cropped by ``1 / scale`` of the height / width on each
    side.

    Args:
        image: array whose first two dimensions are height and width.
        interpolation: OpenCV interpolation flag passed to ``cv2.remap``.
        border_value: constant used for pixels that map outside the source.

    Returns:
        The warped and cropped array (a slice of the remapped array).
    """
    h, w = image.shape[:2]

    focal_length = w / 4
    center_x = w / 2
    center_y = h / 2
    camera_matrix = np.array(
        [[focal_length, 0, center_x],
         [0, focal_length, center_y],
         [0, 0, 1]],
        dtype=np.float32,
    )
    dist_coeffs = np.array([0, 0.5, 0, 0], dtype=np.float32)

    map_x, map_y = cv2.initUndistortRectifyMap(
        camera_matrix, dist_coeffs, None, None, (w, h), cv2.CV_32FC1
    )
    warped = cv2.remap(
        image,
        map_x,
        map_y,
        interpolation=interpolation,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=border_value,
    )
    return warped[int(h / scale):int(h - h / scale), int(w / scale):int(w - w / scale)]


def fisheye_distortion_rm(image):
    """Warp and crop an image: bilinear interpolation, out-of-range pixels set to 0."""
    return _fisheye_warp(image, cv2.INTER_LINEAR, 0)


def fisheye_distortion_mask_rm(image):
    """Warp and crop a label mask.

    Nearest-neighbour interpolation is used so no new label values are
    produced, and pixels that map outside the source are filled with 12.
    The result is returned as ``uint8``.
    """
    return _fisheye_warp(image, cv2.INTER_NEAREST, 12).astype(np.uint8)

class Fisheye(DualTransform):
    """Albumentations transform applying the fisheye warp to an image and its mask.

    Images go through ``fisheye_distortion_rm`` and masks through
    ``fisheye_distortion_mask_rm``, so both receive the same geometric warp
    and crop.

    Args:
        p: probability of applying the transform, forwarded to the
            albumentations base class. The default of 0.5 is the base-class
            default that the argument-less constructor already relied on, so
            ``Fisheye()`` behaves as before: the warp is applied to roughly
            half of the samples. Pass ``p=1.0`` to warp every sample, e.g.
            for deterministic inference.
    """

    def __init__(self, p=0.5):
        super(Fisheye, self).__init__(p=p)

    def apply(self, img, **params):
        """Warp the image."""
        return fisheye_distortion_rm(img)

    def apply_to_mask(self, mask, **params):
        """Warp the mask with the mask-specific interpolation and border fill."""
        return fisheye_distortion_mask_rm(mask)

    def get_transform_init_args_names(self):
        """No constructor arguments besides the base-class ones.

        Declaring this lets albumentations build a repr / serialised form of
        a pipeline that contains this transform.
        """
        return ()
    

## Data Loader

In [40]:
# Preprocessing pipeline: fisheye warp -> 512x512 resize -> normalisation -> tensor.
transform = A.Compose([
    Fisheye(),
    A.Resize(512, 512),
    A.Normalize(),
    ToTensorV2(),
])

# Training samples listed in train_source.csv, batched and reshuffled every epoch.
dataset = CustomDataset(
    csv_file='./train_source.csv',
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

## Define Model

In [41]:
# Double convolution block, the basic building unit of the U-Net.
def double_conv(in_channels, out_channels):
    """Build two 3x3 convolutions (padding 1), each followed by an in-place ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by both convolutions.

    Returns:
        An ``nn.Sequential`` of ``Conv2d -> ReLU -> Conv2d -> ReLU``; spatial
        size is preserved because of the padding.
    """
    layers = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

# Simple U-Net model definition
class UNet(nn.Module):
    """Small U-Net: three 2x down-sampling stages, a bottleneck, three up-sampling stages.

    The network takes a 3-channel input and emits 13 output channels
    (12 classes + 1 background) at the input resolution. Skip connections
    concatenate each encoder output with the up-sampled decoder features.
    """

    def __init__(self):
        super().__init__()

        # Encoder / bottleneck feature extractors.
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        # Parameter-free resampling layers shared by all stages.
        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Decoder blocks; input width = up-sampled channels + skip channels.
        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        # 1x1 convolution producing per-pixel scores: 12 classes + 1 background.
        self.conv_last = nn.Conv2d(64, 13, 1)

    def forward(self, x):
        """Return per-pixel class scores for the batch ``x``."""
        # Encoder: keep each stage's output for the skip connections.
        skip1 = self.dconv_down1(x)
        skip2 = self.dconv_down2(self.maxpool(skip1))
        skip3 = self.dconv_down3(self.maxpool(skip2))

        # Bottleneck at 1/8 of the input resolution.
        features = self.dconv_down4(self.maxpool(skip3))

        # Decoder: up-sample, concatenate the matching skip, then convolve.
        features = torch.cat((self.upsample(features), skip3), dim=1)
        features = self.dconv_up3(features)

        features = torch.cat((self.upsample(features), skip2), dim=1)
        features = self.dconv_up2(features)

        features = torch.cat((self.upsample(features), skip1), dim=1)
        features = self.dconv_up1(features)

        return self.conv_last(features)

## Model Train

In [17]:

# training loop
# NOTE(review): `model`, `criterion` and `optimizer` are not defined in any
# cell visible in this notebook; they must exist in the kernel before this
# cell runs.

# Create the checkpoint directory up front so a missing folder is noticed
# before training starts rather than when torch.save runs at the very end.
os.makedirs('./save_path', exist_ok=True)

for epoch in range(50):  # train for 50 epochs
    model.train()
    epoch_loss = 0
    for images, masks in tqdm(dataloader):
        images = images.float().to(device)
        masks = masks.long().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # squeeze(1) drops a singleton channel dimension from the masks if one is present.
        loss = criterion(outputs, masks.squeeze(1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean of the per-batch losses over the epoch.
    print(f'Epoch {epoch+1}, Loss: {epoch_loss/len(dataloader)}')

torch.save(model.state_dict(),'./save_path/baselinefish_save50.pth')

100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 1, Loss: 1.6633301835129226


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 2, Loss: 0.9308408801106439


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 3, Loss: 0.6407064311746238


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 4, Loss: 0.5287067259567372


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 5, Loss: 0.4645630639532338


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 6, Loss: 0.4755633585694907


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 7, Loss: 0.37407413384188776


100%|██████████| 69/69 [00:57<00:00,  1.21it/s]


Epoch 8, Loss: 0.33937658823054767


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 9, Loss: 0.3043020717475725


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 10, Loss: 0.282417270368424


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 11, Loss: 0.2715987191684004


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 12, Loss: 0.2459740936756134


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 13, Loss: 0.2270699707062348


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 14, Loss: 0.21602274045564127


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 15, Loss: 0.20893793391144794


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 16, Loss: 0.19844070789606674


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 17, Loss: 0.18161127770292587


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 18, Loss: 0.17193472277427066


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 19, Loss: 0.16775086932424185


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 20, Loss: 0.1640603401954623


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 21, Loss: 0.1502240647872289


100%|██████████| 69/69 [00:57<00:00,  1.21it/s]


Epoch 22, Loss: 0.14023071106361307


100%|██████████| 69/69 [00:57<00:00,  1.21it/s]


Epoch 23, Loss: 0.13474664739940478


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 24, Loss: 0.12955447729083075


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 25, Loss: 0.12277502318223317


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 26, Loss: 0.11648498572733092


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 27, Loss: 0.1124882290976635


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 28, Loss: 0.10841789388138315


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 29, Loss: 0.11034300642600958


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 30, Loss: 0.10608381616032642


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 31, Loss: 0.10000637076471162


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 32, Loss: 0.09510342463635016


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 33, Loss: 0.09160587731478871


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 34, Loss: 0.09039122848839


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 35, Loss: 0.08767027839802313


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 36, Loss: 0.30848876084538474


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 37, Loss: 0.2038141385368679


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 38, Loss: 0.14091523293999658


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 39, Loss: 0.11204705300970354


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 40, Loss: 0.09730236415845761


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 41, Loss: 0.08925580006578694


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 42, Loss: 0.08418388394773871


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 43, Loss: 0.080176940538745


100%|██████████| 69/69 [00:55<00:00,  1.24it/s]


Epoch 44, Loss: 0.07780708901692128


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]


Epoch 45, Loss: 0.07410551078509593


100%|██████████| 69/69 [00:56<00:00,  1.22it/s]


Epoch 46, Loss: 0.07284118650832037


100%|██████████| 69/69 [00:57<00:00,  1.21it/s]


Epoch 47, Loss: 0.0715793824714163


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 48, Loss: 0.06995246850925943


100%|██████████| 69/69 [00:56<00:00,  1.21it/s]


Epoch 49, Loss: 0.06787389101109643


100%|██████████| 69/69 [00:56<00:00,  1.23it/s]

Epoch 50, Loss: 0.06597072263990623





## Inference

In [43]:
# Inference data: infer=True makes the dataset return images only (no masks).
test_dataset = CustomDataset(
    csv_file='./test.csv',
    transform=transform,
    infer=True,
)
# One image per batch, in CSV order.
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)

In [44]:
# One colour per class id 0-11; the last entry (index 12) is the background.
# NOTE(review): these look like RGB triples, while cv2.imwrite interprets
# channels as BGR - confirm the intended colour order.
palette = [[0, 94, 135], [242, 255, 97], [165, 42, 42], [0, 0, 192],
                [197, 226, 255], [0, 60, 100], [0, 0, 142], [62, 200, 71],
                [255,207,157], [0, 187, 255], [255, 102, 163],[166,97,247],[0,0,0]]

# uint8 keeps the colourised mask in the same dtype as the image OpenCV loads,
# so the stacked preview is a plain 8-bit image on every platform.
palette = np.array(palette, dtype=np.uint8)

# Loop-invariant work: read the test metadata once, and make sure the output
# directory exists (cv2.imwrite does not create missing directories).
test = pd.read_csv('./test.csv')
os.makedirs('./mask_save_base', exist_ok=True)

with torch.no_grad():
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    model.load_state_dict(torch.load('./save_path/baselinefish_save50.pth', map_location=device))
    model.eval()
    # One entry per (image, class) pair, consumed by the submission cell.
    result = []
    # Running index of the current image in test.csv; valid for any batch size
    # because the loader iterates in CSV order (shuffle=False).
    sample_idx = 0
    for images in tqdm(test_dataloader):
        images = images.float().to(device)
        outputs = model(images)
        outputs = torch.softmax(outputs, dim=1).cpu()
        outputs = torch.argmax(outputs, dim=1).numpy()
        # iterate over every image in the batch
        for pred in outputs:
            pred = pred.astype(np.uint8)
            pred = Image.fromarray(pred)  # convert to a PIL image
            pred = pred.resize((960, 540), Image.NEAREST)  # resize to 960 x 540
            pred = np.array(pred)  # back to a numeric array

            # Encode a mask for classes 0-11; class 12 (background) is skipped.
            for class_id in range(12):
                class_mask = (pred == class_id).astype(np.uint8)
                if np.sum(class_mask) > 0:  # class present: RLE-encode it
                    result.append(rle_encode(class_mask))
                else:  # class absent: -1
                    result.append(-1)

            # Save a side-by-side preview: original image | colourised prediction.
            mask_img = palette[pred]
            img_name = test['id'][sample_idx]
            img_path = test['img_path'][sample_idx]
            img_org = cv2.imread(img_path)
            if img_org is None:
                raise FileNotFoundError(f"Could not read image file: {img_path}")
            img_org = cv2.resize(img_org, (960, 540))
            side_by_side = np.hstack((img_org, mask_img))
            cv2.imwrite(f"./mask_save_base/{img_name}.png", side_by_side)
            sample_idx += 1

  4%|▍         | 79/1898 [00:15<05:48,  5.22it/s]


KeyboardInterrupt: 

## Submission

In [21]:
# Load the sample submission and fill its mask_rle column with the values
# collected in `result`, then display the frame.
submit = pd.read_csv('./sample_submission.csv').assign(mask_rle=result)
submit

Unnamed: 0,id,mask_rle
0,TEST_0000_class_0,456845 5 457805 5 458761 13 459721 13 460681 1...
1,TEST_0000_class_1,-1
2,TEST_0000_class_2,1 133 644 51 747 347 1604 51 1707 351 2534 8 2...
3,TEST_0000_class_3,450168 4 451128 4 452071 21 453031 21 453991 2...
4,TEST_0000_class_4,169810 4 170770 4 171730 4 172690 4 173650 4 1...
...,...,...
22771,TEST_1897_class_7,63322 9 64282 9 65242 9 66198 9 67158 9 68114 ...
22772,TEST_1897_class_8,95 545 674 128 1055 545 1634 128 2020 540 2594...
22773,TEST_1897_class_9,169870 8 169882 5 170830 8 170842 5 171777 30 ...
22774,TEST_1897_class_10,203898 4 203920 4 204858 4 204880 4 205818 4 2...


In [22]:
# Write the submission frame to CSV without the DataFrame index column.
submit.to_csv('./baselinefish50_submit.csv', index=False)