In [8]:
import os
import cv2
import pandas as pd
import numpy as np
import datetime as dt


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import StratifiedKFold, KFold

from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm

import albumentations as A
import albumentations.pytorch
import wandb
from typing import List, Union
from joblib import Parallel, delayed

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [12]:
# RLE 디코딩 함수
def rle_decode(mask_rle, shape):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s [0:][::2], s [1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape [0]*shape [1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img [lo:hi] = 1
    return img.reshape(shape)

# RLE 인코딩 함수
def rle_encode(mask):
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels [1:]!= pixels [:-1])[0] + 1
    runs [1::2] -= runs [::2]
    return ' '. join(str(x) for x in runs)

In [23]:
class SatelliteDataset(Dataset):
    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv('../Data/satellite/' + csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path = self.data.iloc[idx, 1]
        image = cv2.imread('../Data/satellite/'+ img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        
        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image

        mask_rle = self.data.iloc[idx, 2]
        mask = rle_decode(mask_rle, (image.shape[0], image.shape[1]))

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

In [4]:
data2 = pd.read_csv('../Data/satellite/train.csv')
data2

Unnamed: 0,img_id,img_path,mask_rle
0,TRAIN_0000,./train_img/TRAIN_0000.png,9576 7 10590 17 11614 17 12638 17 13662 17 146...
1,TRAIN_0001,./train_img/TRAIN_0001.png,208402 1 209425 6 210449 10 211473 14 212497 1...
2,TRAIN_0002,./train_img/TRAIN_0002.png,855 34 15654 9 16678 9 16742 8 17702 9 17766 9...
3,TRAIN_0003,./train_img/TRAIN_0003.png,362 6 745 15 798 22 900 25 1385 8 1828 16 1924...
4,TRAIN_0004,./train_img/TRAIN_0004.png,34 27 1058 27 2082 27 3105 27 4129 27 5153 27 ...
...,...,...,...
7135,TRAIN_7135,./train_img/TRAIN_7135.png,193 19 882 18 985 21 1217 17 1782 2 1906 18 20...
7136,TRAIN_7136,./train_img/TRAIN_7136.png,85938 13 86962 20 87986 20 89009 21 90033 21 9...
7137,TRAIN_7137,./train_img/TRAIN_7137.png,100 59 314 28 878 28 997 20 1124 59 1338 28 19...
7138,TRAIN_7138,./train_img/TRAIN_7138.png,789 18 975 17 1814 16 2000 14 2544 2 2839 14 3...


In [27]:
transform_train = A.Compose(    [   
    A.RandomResizedCrop(p=1, height=224 ,width=224, scale=(0.25, 0.35),ratio=(0.90, 1.10)),
    A.ColorJitter(always_apply=True, p=0.5, contrast=0.2, saturation=0.3, hue=0.2),
    A.HorizontalFlip(p=0.5),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, always_apply=True, p=1.0),
     A.pytorch.transforms.ToTensorV2()
])
transform_test = A.Compose([
    A.Resize(height = 224, width = 224),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), max_pixel_value=255.0, always_apply=True, p=1.0),
     A.pytorch.transforms.ToTensorV2()
])


In [6]:
mask_rle = '9576 7 10590 17 11614 17 12638 17 13662 17 14686 17 15710 17 16734 17 17716 100 18740 100 19764 100 20788 100 21812 100 22836 100 23860 100 24884 100 25908 100 26932 100 27956 100 28980 100 30004 100 31028 100 32052 100 33076 100 34100 100 35124 100 36148 100 37172 100 38196 100 39220 100 40244 100 41268 100 42292 100 43317 99 44341 99 45365 99 46389 99 47413 99 48437 99 49461 99 50485 99 51509 99 52533 99 53557 99 54581 99 55605 99 56629 99 57653 99 58677 99 59701 99 60722 1 60725 99 61749 99 62773 99 63797 99 64821 99 65845 99 66481 45 66869 99 67505 55 67893 99 68529 55 68917 99 69553 55 69941 99 70577 55 70965 99 71601 55 71989 99 72625 55 73013 99 73649 55 74037 99 74673 55 75061 99 75697 55 76085 99 76721 55 77109 99 77745 55 78133 99 78769 55 79157 21 79793 55 80817 55 81251 1 81841 55 82865 55 83889 55 84913 55 85937 55 86961 55 87985 55 89009 55 90033 55 91057 55 92081 55 93105 55 94129 55 95153 55 96177 57 97201 57 98225 57 99249 57 100273 57 101297 57 102321 57 103345 57 104369 57 105393 57 106417 57 107440 58 108464 58 109488 58 110512 58 111536 58 112560 58 113584 58 114608 58 115632 58 116656 58 117680 58 118704 58 119728 58 120752 58 121776 58 122800 58 123824 58 124848 58 125872 58 126896 58 127920 58 128944 58 129968 58 130992 58 132016 58 133040 58 134064 58 135088 58 136112 58 137136 58 138194 24 140233 1 159671 52 160695 52 161719 52 162742 53 163766 53 164790 53 165814 53 166838 53 167862 53 168886 53 169910 53 170934 53 171958 53 172982 53 174006 53 175030 53 176054 53 177078 53 178102 53 179126 53 180150 53 181174 53 182198 53 183222 53 184246 53 185270 53 186294 53 187318 53 188342 53 189366 53 190390 53 191414 53 192438 53 193440 75 194464 75 195488 75 196512 75 197536 75 198560 75 199584 75 200608 75 201632 75 202656 75 203680 75 204704 75 205728 75 206751 76 207775 76 208799 76 209823 85 210847 98 211871 98 212895 98 213919 98 214943 98 215967 98 216991 98 218015 98 219039 98 220063 98 221087 98 222111 98 223135 98 224159 98 225183 98 226207 98 227231 98 228255 98 229279 98 230303 98 231327 98 232351 98 233375 98 234159 29 234399 98 235174 38 235423 98 236198 38 236447 98 237222 38 237471 98 238246 38 238495 98 239270 38 239519 98 240294 38 240543 98 241318 38 241567 98 242342 38 242591 98 243366 38 243615 98 244390 38 244639 98 245414 38 245663 98 246438 38 246687 98 247462 38 247711 98 248486 38 248735 98 249510 38 249759 98 250534 38 250783 98 251558 38 251807 98 252582 38 252831 98 253606 38 253855 98 254630 38 254879 98 255654 38 255903 98 256678 38 256927 98 257702 38 257951 98 258726 38 258975 98 259751 36 259999 98 261023 98 262047 98 263071 98 264095 98 265119 98 266143 98 267167 98 268191 98 269215 98 270239 98 271263 98 272287 98 273311 98 274335 98 275359 98 276383 98 277407 98 278431 98 279455 98 280479 98 281503 98 282527 98 283551 98 284575 98 285599 98 286623 98 287647 98 288671 98 289695 98 290814 3 358257 29 359281 29 360305 29 361329 29 362353 29 363377 29 364401 29 365425 29 366449 29 367473 29 368497 29 369521 29 370545 29 371569 29 372593 29 373617 29 727281 1 728303 3 729325 6 730347 8 731369 11 732391 13 733412 17 734434 19 735456 22 736478 24 737500 27 738521 30 739543 33 740565 35 741587 38 742609 40 743633 41 744658 38 745682 36 746706 34 747731 30 748755 28 749780 25 750804 23 751829 20 752853 18 753878 14 754902 12 755927 9 756951 7 757976 4 759000 2 855061 2 856083 4 857106 6 858128 8 859150 11 860172 13 861194 16 862216 18 863238 21 864262 21 865287 21 866311 21 867336 21 868360 21 869385 21 870409 22 871434 21 872458 22 873483 21 874507 22 875532 21 876556 22 877580 22 878604 23 879626 25 880650 26 881675 25 882700 25 883724 25 884749 25 885773 2 885777 21 886802 21 887826 21 888851 21 889875 21 890900 21 891924 22 892949 21 893973 22 894998 21 896022 22 897047 21 898072 21 899096 21 900121 20 901145 18 902170 15 903194 13 904219 10 905243 8 906268 5 907292 3 921888 2 922911 4 923934 6 924957 8 925980 10 927003 12 928026 15 929049 17 930073 18 931096 20 932119 22 933142 24 934167 24 935192 24 936217 24 937242 24 938266 25 939289 25 940312 25 941335 25 942358 25 943381 25 944404 25 944432 1 945427 25 946450 25 947474 25 948497 25 949522 23 950547 21 951572 19 952597 17 953622 15 954647 13 955672 11 956698 8 957723 6 958748 4 959773 2 961819 1'

In [None]:

plt.figure(figsize=(10,200))
img = cv2.imread("../Data/satellite/train_img/TRAIN_0000.png")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
#img = img[:, :, ::-1]
#mask_rle = self.data.iloc[idx, 2]
mask = rle_decode(mask_rle, (img.shape[0], img.shape[1]))
augmented = transform_train(image=img, mask=mask)
image = augmented['image']
mask = augmented['mask']

plt.subplot(40,2,1)
plt.imshow(image, cmap=plt.cm.binary)

plt.subplot(40,2,2)
plt.imshow(mask, cmap=plt.cm.binary)

#mask_rle = self.data.iloc[idx, 2]
mask = rle_decode(mask_rle, (img.shape[0], img.shape[1]))
augmented = transform_train(image=img, mask=mask)
image = augmented['image']
mask = augmented['mask']

plt.grid(False)
plt.subplot(40,2,3)
plt.imshow(image, cmap=plt.cm.binary)

plt.subplot(40,2,4)
plt.imshow(mask, cmap=plt.cm.binary)

plt.show()

In [30]:
# U-Net의 기본 구성 요소인 Double Convolution Block을 정의합니다.
def double_conv(in_channels, out_channels):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_channels, out_channels, 3, padding=1),
        nn.ReLU(inplace=True)
    )

# 간단한 U-Net 모델 정의
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.dconv_down1 = double_conv(3, 64)
        self.dconv_down2 = double_conv(64, 128)
        self.dconv_down3 = double_conv(128, 256)
        self.dconv_down4 = double_conv(256, 512)

        self.maxpool = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)        

        self.dconv_up3 = double_conv(256 + 512, 256)
        self.dconv_up2 = double_conv(128 + 256, 128)
        self.dconv_up1 = double_conv(128 + 64, 64)

        self.conv_last = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        conv1 = self.dconv_down1(x)
        x = self.maxpool(conv1)

        conv2 = self.dconv_down2(x)
        x = self.maxpool(conv2)
        
        conv3 = self.dconv_down3(x)
        x = self.maxpool(conv3)   

        x = self.dconv_down4(x)

        x = self.upsample(x)        
        x = torch.cat([x, conv3], dim=1)

        x = self.dconv_up3(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv2], dim=1)       

        x = self.dconv_up2(x)
        x = self.upsample(x)        
        x = torch.cat([x, conv1], dim=1)   

        x = self.dconv_up1(x)

        out = self.conv_last(x)

        return out

In [12]:
def calculate_dice(pred_rle, gt_rle):
    pred_mask = rle_decode(pred_rle, (224, 224))
    gt_mask = rle_decode(gt_rle, (224, 224))


    if np.sum(gt_mask) > 0 or np.sum(pred_mask) > 0:
        return dice_score(pred_mask, gt_mask)
    else:
        return None  # No valid masks found, return None
def dice_score(prediction: np.array, ground_truth: np.array, smooth=1e-7):#-> float:
    '''
    Calculate Dice Score between two binary masks.
    '''
    intersection = np.sum(prediction * ground_truth)
    return (2.0 * intersection + smooth) / (np.sum(prediction) + np.sum(ground_truth) + smooth)


In [13]:
dataset = SatelliteDataset(csv_file='./train.csv', transform=transform_train)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=8, pin_memory=True)

In [14]:
import numpy as np
import pandas as pd
from typing import List, Union
from joblib import Parallel, delayed

In [None]:
# model 초기화
model = UNet().to(device)

# loss function과 optimizer 정의
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# training loop
for epoch in range(10):  # 10 에폭 동안 학습합니다.
    model.train()
    epoch_loss = 0
    for images, masks in tqdm(dataloader):
        images = images.float().to(device)
        masks = masks.float().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks.unsqueeze(1))
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    print(f'Epoch {epoch+1}, Loss: {epoch_loss/len(dataloader)}')

In [54]:
torch.save(model.state_dict(), ('./models'+'/{}.pth').format('200_epoch_model'))

In [28]:
test_dataset = SatelliteDataset(csv_file='./test.csv', transform=transform_test, infer=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=8)

In [33]:
with torch.no_grad():
    model.eval()
    result = []
    for images in tqdm(test_dataloader):
        images = images.float().to(device)
        
        outputs = model(images)
        masks = torch.sigmoid(outputs).cpu().numpy()
        masks = np.squeeze(masks, axis=1)
        masks = (masks > 0.35).astype(np.uint8) # Threshold = 0.35
        
        for i in range(len(images)):
            mask_rle = rle_encode(masks[i])
            if mask_rle == '': # 예측된 건물 픽셀이 아예 없는 경우 -1
                result.append(-1)
            else:
                result.append(mask_rle)
    print(mask_rle)

100%|███████████████████████████████████████| 3790/3790 [05:57<00:00, 10.60it/s]

1 50176





In [None]:
result

In [35]:
submit = pd.read_csv('../Data/satellite/sample_submission.csv')
submit['mask_rle'] = result

In [36]:
submit

Unnamed: 0,img_id,mask_rle
0,TEST_00000,1 50176
1,TEST_00001,1 50176
2,TEST_00002,1 50176
3,TEST_00003,1 50176
4,TEST_00004,1 50176
...,...,...
60635,TEST_60635,1 50176
60636,TEST_60636,1 50176
60637,TEST_60637,1 50176
60638,TEST_60638,1 50176


In [51]:
prediction_df = submit[submit.iloc[:, 0].isin(submit.iloc[:, 0])]
prediction_df.index = range(prediction_df.shape[0])
prediction_df

Unnamed: 0,img_id,mask_rle
0,TEST_00000,5562 1 5785 2 6009 2 6233 2 6679 4 6903 10 712...
1,TEST_00001,152 6 186 10 370 1 372 3 378 5 410 10 593 14 6...
2,TEST_00002,32 4 936 1 1160 2 1385 3 1402 2 1610 2 1626 4 ...
3,TEST_00003,5 8 14 31 79 1 81 3 230 4 238 32 305 2 453 5 4...
4,TEST_00004,7404 2 7618 2 7627 6 7635 1 7842 2 7852 2 7859...
...,...,...
60635,TEST_60635,78 5 103 9 298 10 326 9 522 9 550 10 747 4 757...
60636,TEST_60636,129 9 141 7 189 20 353 9 365 6 411 18 577 9 58...
60637,TEST_60637,32 10 56 25 85 10 98 9 172 34 257 6 280 25 309...
60638,TEST_60638,11 9 22 21 44 4 130 9 197 9 235 10 246 21 421 ...


### pred_mask_rle = prediction_df.iloc[:, 1]
pred_mask_rle[0]

In [65]:
submit.to_csv('./submit.csv', index=False)

In [49]:
submit = pd.read_csv('./submit.csv')

In [44]:
data

Unnamed: 0,img_id,mask_rle
0,TEST_00000,5562 1 5785 2 6009 2 6233 2 6679 4 6903 10 712...
1,TEST_00001,152 6 186 10 370 1 372 3 378 5 410 10 593 14 6...
2,TEST_00002,32 4 936 1 1160 2 1385 3 1402 2 1610 2 1626 4 ...
3,TEST_00003,5 8 14 31 79 1 81 3 230 4 238 32 305 2 453 5 4...
4,TEST_00004,7404 2 7618 2 7627 6 7635 1 7842 2 7852 2 7859...
...,...,...
60635,TEST_60635,78 5 103 9 298 10 326 9 522 9 550 10 747 4 757...
60636,TEST_60636,129 9 141 7 189 20 353 9 365 6 411 18 577 9 58...
60637,TEST_60637,32 10 56 25 85 10 98 9 172 34 257 6 280 25 309...
60638,TEST_60638,11 9 22 21 44 4 130 9 197 9 235 10 246 21 421 ...


In [16]:
test_df =pd.read_csv('../Data/satellite/test.csv')

In [17]:
test_df.iloc[:, 0]

0        TEST_00000
1        TEST_00001
2        TEST_00002
3        TEST_00003
4        TEST_00004
            ...    
60635    TEST_60635
60636    TEST_60636
60637    TEST_60637
60638    TEST_60638
60639    TEST_60639
Name: img_id, Length: 60640, dtype: object

In [14]:
def dice_score(prediction: np.array, ground_truth: np.array, smooth=1e-7) -> float:
    '''
    Calculate Dice Score between two binary masks.
    '''
    intersection = np.sum(prediction * ground_truth)
    return (2.0 * intersection + smooth) / (np.sum(prediction) + np.sum(ground_truth) + smooth)


def calculate_dice_scores(ground_truth_df, prediction_df, img_shape=(224, 224)) -> List[float]:
    '''
    Calculate Dice scores for a dataset.
    '''


    # Keep only the rows in the prediction dataframe that have matching img_ids in the ground truth dataframe
    prediction_df = prediction_df[prediction_df.iloc[:, 0].isin(ground_truth_df.iloc[:, 0])]
    prediction_df.index = range(prediction_df.shape[0])


    # Extract the mask_rle columns
    pred_mask_rle = prediction_df.iloc[:, 2]
    gt_mask_rle = ground_truth_df.iloc[:, 2]


    def calculate_dice(pred_rle, gt_rle):
        pred_mask = rle_decode(pred_rle, img_shape)
        gt_mask = rle_decode(gt_rle, img_shape)


        if np.sum(gt_mask) > 0 or np.sum(pred_mask) > 0:
            return dice_score(pred_mask, gt_mask)
        else:
            return None  # No valid masks found, return None


    dice_scores = Parallel(n_jobs=-1)(
        delayed(calculate_dice)(pred_rle, gt_rle) for pred_rle, gt_rle in zip(pred_mask_rle, gt_mask_rle)
    )


    dice_scores = [score for score in dice_scores if score is not None]  # Exclude None values


    return np.mean(dice_scores)

In [10]:
tmp = pd.read_csv('../Data/satellite/train.csv')
tmp2 = pd.read_csv('../Data/satellite/train.csv')


In [15]:
calculate_dice_scores(tmp, tmp2)

1.0

In [21]:
pred_mask_rle= tmp.iloc[:, 2]
pred_mask_rle[1][0]

'2'