In [1]:
import os
import cv2
import pandas as pd
import numpy as np
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from tqdm import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # Arrange GPU devices starting from 0
# os.environ["CUDA_VISIBLE_DEVICES"]= "2"  # Set the GPU 2 to use

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print('Device:', device)
# torch.cuda.current_device() raises on CPU-only machines, so only query it when
# CUDA is actually available; device_count() safely returns 0 without CUDA.
if torch.cuda.is_available():
    print('Current cuda device:', torch.cuda.current_device())
print('Count of using GPUs:', torch.cuda.device_count())

Device: cuda
Current cuda device: 0
Count of using GPUs: 2


In [2]:
# RLE 디코딩 함수
def rle_decode(mask_rle, shape):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape)

# RLE 인코딩 함수
def rle_encode(mask):
    pixels = mask.flatten()
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

In [3]:
class SatelliteDataset(Dataset):
    """Image / building-mask dataset backed by a CSV file.

    Columns are accessed by position: column 1 holds the image path and, for
    training data, column 2 holds the run-length-encoded mask. A non-string
    value in the mask column (e.g. NaN for an empty cell) means "no mask pixels".

    Args:
        csv_file: path of the CSV file to read with pandas.
        transform: optional albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict.
        infer: if True, ``__getitem__`` returns only the image and never
            reads the mask column.
    """

    def __init__(self, csv_file, transform=None, infer=False):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.infer = infer

    def __len__(self):
        return len(self.data)

    def _load_image(self, img_path):
        """Read ``img_path`` from disk and return it as an RGB array."""
        image = cv2.imread(img_path)
        # cv2.imread reports failure by returning None rather than raising;
        # without this check the error surfaces later as an opaque cvtColor assertion.
        if image is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return the image (infer mode) or an ``(image, mask)`` pair for row ``idx``."""
        img_path = self.data.iloc[idx, 1]
        image = self._load_image(img_path)

        if self.infer:
            if self.transform:
                image = self.transform(image=image)['image']
            return image

        height, width = image.shape[0], image.shape[1]
        mask_rle = self.data.iloc[idx, 2]
        if isinstance(mask_rle, str):
            mask = rle_decode(mask_rle, (height, width))
        else:
            # Missing RLE in the CSV -> empty mask of the image's size.
            mask = np.zeros((height, width), np.uint8)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

# Preprocessing shared by training and inference: A.Normalize() with its default
# mean/std, then ToTensorV2 to turn the HWC numpy image into a CHW tensor.
transform = A.Compose(
    [
        A.Normalize(),
        ToTensorV2()
    ]
)

VAL_FRACTION = 0.2  # share of the training CSV held out for validation
SPLIT_SEED = 42     # fixed seed so the train/val split is reproducible across runs

dataset = SatelliteDataset(csv_file='./train_256+.csv', transform=transform)
print(len(dataset))
# Derive the split sizes from the dataset length instead of hard-coding them, so the
# cell keeps working if the CSV changes (for 160650 rows this gives 128520 / 32130).
n_val = int(len(dataset) * VAL_FRACTION)
n_train = len(dataset) - n_val
train_set, val_set = torch.utils.data.random_split(
    dataset, [n_train, n_val], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_dataloader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=0)
# Validation is not shuffled: a fixed batch composition keeps the per-epoch mean
# validation loss (used for checkpoint selection) comparable across epochs.
val_dataloader = DataLoader(val_set, batch_size=8, shuffle=False, num_workers=0)
# The subsets keep their own reference to the dataset; this only drops the name.
del dataset


160650


In [4]:
class Unet_block(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; spatial size is preserved.

    Args:
        in_channels: channels of the incoming feature map.
        mid_channels: channels produced by the first convolution.
        out_channels: channels produced by the second convolution.
    """

    def __init__(self, in_channels, mid_channels, out_channels):
        super().__init__()
        # 3x3 kernel, stride 1, padding 1 keeps height and width unchanged.
        conv_opts = dict(kernel_size=3, stride=1, padding=1)
        # Attribute names are also the state_dict keys of saved checkpoints.
        self.conv1 = nn.Conv2d(in_channels, mid_channels, **conv_opts)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.conv2 = nn.Conv2d(mid_channels, out_channels, **conv_opts)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class Nested_UNet(nn.Module):
    """UNet++-style nested U-Net for segmentation.

    Encoder: five ``Unet_block`` levels (conv0_0 .. conv4_0) with 32, 64, 128,
    256 and 512 channels, separated by 2x2 max-pooling. Decoder: node
    ``convI_J`` (J >= 1) consumes the concatenation of all earlier nodes on
    level I (``xI_0 .. xI_{J-1}``) plus the 2x-upsampled node ``x{I+1}_{J-1}``
    from the level below (dense nested skip connections).

    Args:
        num_classes: output channels of the final 1x1 convolution(s).
        input_channels: channels of the input image.
        deep_supervision: if True, four 1x1-conv heads are applied to
            x0_1 .. x0_4 and their outputs are averaged; otherwise a single
            head is applied to x0_4.

    Attribute names (conv0_0, conv0_1, ..., output) are the state_dict keys of
    saved checkpoints.
    """

    def __init__(self, num_classes = 1, input_channels=3, deep_supervision=False):
        super().__init__()

        # Channels per level; only indices 0-4 are used below (768 and 1024 are unused).
        num_filter = [32, 64, 128, 256, 512, 768, 1024]
        self.deep_supervision = deep_supervision
        # Shared, parameter-free resampling ops: pool halves H/W, up doubles them.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        
        # DownSampling
        self.conv0_0 = Unet_block(input_channels, num_filter[0], num_filter[0])
        self.conv1_0 = Unet_block(num_filter[0], num_filter[1], num_filter[1])
        self.conv2_0 = Unet_block(num_filter[1], num_filter[2], num_filter[2])
        self.conv3_0 = Unet_block(num_filter[2], num_filter[3], num_filter[3])
        self.conv4_0 = Unet_block(num_filter[3], num_filter[4], num_filter[4])

        # Upsampling & Dense skip
        # Input channels = (number of same-level skips) * level width + width of the level below.
        # N to 1 skip
        self.conv0_1 = Unet_block(num_filter[0] + num_filter[1], num_filter[0], num_filter[0])
        self.conv1_1 = Unet_block(num_filter[1] + num_filter[2], num_filter[1], num_filter[1])
        self.conv2_1 = Unet_block(num_filter[2] + num_filter[3], num_filter[2], num_filter[2])
        self.conv3_1 = Unet_block(num_filter[3] + num_filter[4], num_filter[3], num_filter[3])
       
        # N to 2 skip
        self.conv0_2 = Unet_block(num_filter[0]*2 + num_filter[1], num_filter[0], num_filter[0])
        self.conv1_2 = Unet_block(num_filter[1]*2 + num_filter[2], num_filter[1], num_filter[1])
        self.conv2_2 = Unet_block(num_filter[2]*2 + num_filter[3], num_filter[2], num_filter[2])

        # N to 3 skip
        self.conv0_3 = Unet_block(num_filter[0]*3 + num_filter[1], num_filter[0], num_filter[0])
        self.conv1_3 = Unet_block(num_filter[1]*3 + num_filter[2], num_filter[1], num_filter[1])

        # N to 4 skip
        self.conv0_4 = Unet_block(num_filter[0]*4 + num_filter[1], num_filter[0], num_filter[0])

        # 1x1 conv heads mapping the level-0 features to num_classes channels.
        if self.deep_supervision:
            self.output1 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
            self.output2 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
            self.output3 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)
            self.output4 = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)

        else:
            self.output = nn.Conv2d(num_filter[0], num_classes, kernel_size=1)

        # initialize weights
        # NOTE(review): init_weights is not defined anywhere in the visible notebook;
        # this block would need it before being re-enabled. PyTorch default init is used.
        # for m in self.modules():
        #     if isinstance(m, nn.Conv2d):
        #         init_weights(m, init_type='kaiming')
        #     elif isinstance(m, nn.BatchNorm2d):
        #         init_weights(m, init_type='kaiming')


    def forward(self, x):                    # (Batch, 3, 256, 256)
        """Run the network and return raw (pre-activation) scores.

        Args:
            x: 4-D input batch (B, input_channels, H, W). Four 2x poolings are
               later undone by 2x upsampling and channel concatenation, so H and
               W presumably must be divisible by 16 for the shapes to line up.

        Returns:
            Tensor with ``num_classes`` channels at the input resolution. No
            sigmoid/softmax is applied here.
        """

        # Each group below descends one encoder level and then fills in the
        # diagonal of decoder nodes that the new level makes computable.
        x0_0 = self.conv0_0(x)               
        x1_0 = self.conv1_0(self.pool(x0_0))
        x0_1 = self.conv0_1(torch.cat([x0_0, self.up(x1_0)], dim=1))
        
        x2_0 = self.conv2_0(self.pool(x1_0))
        x1_1 = self.conv1_1(torch.cat([x1_0, self.up(x2_0)], dim=1))
        x0_2 = self.conv0_2(torch.cat([x0_0, x0_1, self.up(x1_1)], dim=1))

        x3_0 = self.conv3_0(self.pool(x2_0))
        x2_1 = self.conv2_1(torch.cat([x2_0, self.up(x3_0)], dim=1))
        x1_2 = self.conv1_2(torch.cat([x1_0, x1_1, self.up(x2_1)], dim=1))
        x0_3 = self.conv0_3(torch.cat([x0_0, x0_1, x0_2, self.up(x1_2)], dim=1))

        x4_0 = self.conv4_0(self.pool(x3_0))
        x3_1 = self.conv3_1(torch.cat([x3_0, self.up(x4_0)], dim=1))
        x2_2 = self.conv2_2(torch.cat([x2_0, x2_1, self.up(x3_1)], dim=1))
        x1_3 = self.conv1_3(torch.cat([x1_0, x1_1, x1_2, self.up(x2_2)], dim=1))
        x0_4 = self.conv0_4(torch.cat([x0_0, x0_1, x0_2, x0_3, self.up(x1_3)], dim=1))

        if self.deep_supervision:
            # Average the four level-0 heads into a single prediction.
            output1 = self.output1(x0_1)
            output2 = self.output2(x0_2)
            output3 = self.output3(x0_3)
            output4 = self.output4(x0_4)
            output = (output1 + output2 + output3 + output4) / 4
        else:
            output = self.output(x0_4)

        return output

In [5]:
class DiceLoss(nn.Module):
    """Soft Dice loss for binary segmentation.

    Returns ``1 - dice``, where a single Dice score is computed over all
    elements of the batch (not averaged per sample).

    Args:
        from_logits: if True (the default, matching the previous behaviour),
            ``inputs`` are raw logits and a sigmoid is applied first. Pass
            False when the model already outputs probabilities, instead of
            editing the code.
    """

    def __init__(self, from_logits=True):
        super().__init__()
        self.from_logits = from_logits

    def forward(self, inputs, targets, smooth=1):
        """Compute the loss.

        Args:
            inputs: predictions (logits or probabilities, see ``from_logits``).
            targets: binary targets with the same number of elements as ``inputs``.
            smooth: additive smoothing term; keeps the ratio defined when both
                prediction and target are empty.

        Returns:
            0-d tensor ``1 - dice``.
        """
        if self.from_logits:
            # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
            inputs = torch.sigmoid(inputs)

        # reshape (unlike view) also works on non-contiguous tensors.
        inputs = inputs.reshape(-1)
        targets = targets.reshape(-1)

        intersection = (inputs * targets).sum()
        dice = (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)

        return 1 - dice

In [6]:
# Model, loss function and optimizer.
LEARNING_RATE = 1e-5  # Adam step size

model = Nested_UNet().to(device)
criterion = DiceLoss()  # applies a sigmoid to the model's logits internally
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


In [7]:
# Training loop: train for NUM_EPOCHS epochs, evaluate on the validation split after
# each epoch, and checkpoint the weights whenever the mean validation loss improves.
NUM_EPOCHS = 51
folder = 'unet++_256+'
# One sub-folder per run, named after the local start time (MMDD_HH_MM_SS).
formatted = time.strftime("%m%d_%H_%M_%S", time.localtime())
path = os.path.join(folder, formatted)
os.makedirs(path)
# Start at +inf so the first epoch is always checkpointed, whatever the loss scale.
best_loss = float('inf')
# model.load_state_dict(torch.load('unet++_saved_weights_256_Dice\\0722_17_44_52\\best_model_epoch45_loss_0.16985569155516744.pth'))


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`; return the per-batch losses."""
    model.train()
    batch_losses = []
    for images, masks in tqdm(loader):
        images = images.float().to(device)
        masks = masks.float().to(device)

        optimizer.zero_grad()
        outputs = model(images)
        # unsqueeze(1) adds a channel axis so the masks line up with the model output.
        loss = criterion(outputs, masks.unsqueeze(1))
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return batch_losses


def _evaluate(model, loader, criterion, device):
    """Return the per-batch losses of `model` on `loader` (eval mode, no gradients)."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for images, masks in tqdm(loader):
            images = images.float().to(device)
            masks = masks.float().to(device)

            outputs = model(images)
            batch_losses.append(criterion(outputs, masks.unsqueeze(1)).item())
    return batch_losses


for epoch in range(NUM_EPOCHS):
    epoch_losses = _train_one_epoch(model, train_dataloader, criterion, optimizer, device)
    valid_epoch_losses = _evaluate(model, val_dataloader, criterion, device)
    train_mean = np.mean(epoch_losses)
    valid_mean = np.mean(valid_epoch_losses)

    if valid_mean < best_loss:
        best_loss = valid_mean
        # Checkpoint file names use 1-based epoch numbers; the printed logs are 0-based,
        # so the saved path is printed too to avoid ambiguity.
        saved_path = os.path.join(path, 'best_model_epoch{0}_loss_{1}.pth'.format(epoch + 1, valid_mean))
        torch.save(model.state_dict(), saved_path)
        print("Best model saved at epoch", epoch, "->", saved_path)

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {train_mean}')
    print(f'Valid_Mean_loss: {valid_mean}')


100%|██████████| 16065/16065 [46:32<00:00,  5.75it/s]
100%|██████████| 4017/4017 [05:08<00:00, 13.00it/s]


Best model saved at epoch 0
EPOCH: 0
Mean loss: 0.6066510289285993
Valid_Mean_loss: 0.4245412319727967


100%|██████████| 16065/16065 [41:41<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:56<00:00, 17.01it/s]


Best model saved at epoch 1
EPOCH: 1
Mean loss: 0.3167151631501617
Valid_Mean_loss: 0.254840088671109


100%|██████████| 16065/16065 [41:35<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.89it/s]


Best model saved at epoch 2
EPOCH: 2
Mean loss: 0.22827431687379676
Valid_Mean_loss: 0.2072008716576367


100%|██████████| 16065/16065 [41:37<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.92it/s]


Best model saved at epoch 3
EPOCH: 3
Mean loss: 0.2078226339323588
Valid_Mean_loss: 0.19930533651157728


100%|██████████| 16065/16065 [41:54<00:00,  6.39it/s]
100%|██████████| 4017/4017 [04:01<00:00, 16.60it/s]


Best model saved at epoch 4
EPOCH: 4
Mean loss: 0.1965495870115599
Valid_Mean_loss: 0.19121454706944604


100%|██████████| 16065/16065 [41:57<00:00,  6.38it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


Best model saved at epoch 5
EPOCH: 5
Mean loss: 0.18685742962074695
Valid_Mean_loss: 0.18893266783622417


100%|██████████| 16065/16065 [41:38<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.86it/s]


Best model saved at epoch 6
EPOCH: 6
Mean loss: 0.18012713232479696
Valid_Mean_loss: 0.17966746065727596


100%|██████████| 16065/16065 [41:46<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


Best model saved at epoch 7
EPOCH: 7
Mean loss: 0.17419991351123765
Valid_Mean_loss: 0.1756040692062321


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


Best model saved at epoch 8
EPOCH: 8
Mean loss: 0.16873857148369636
Valid_Mean_loss: 0.17437919331095483


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.99it/s]


Best model saved at epoch 9
EPOCH: 9
Mean loss: 0.1640653951569506
Valid_Mean_loss: 0.17003558769373267


100%|██████████| 16065/16065 [41:43<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.86it/s]


Best model saved at epoch 10
EPOCH: 10
Mean loss: 0.15965330495368346
Valid_Mean_loss: 0.16527684583039554


100%|██████████| 16065/16065 [41:45<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.95it/s]


EPOCH: 11
Mean loss: 0.15625512438550607
Valid_Mean_loss: 0.1675779360842877


100%|██████████| 16065/16065 [41:41<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.87it/s]


EPOCH: 12
Mean loss: 0.1528054823431685
Valid_Mean_loss: 0.16688952153376568


100%|██████████| 16065/16065 [41:45<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.88it/s]


Best model saved at epoch 13
EPOCH: 13
Mean loss: 0.14994244372180568
Valid_Mean_loss: 0.16024032267047483


100%|██████████| 16065/16065 [41:38<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.88it/s]


Best model saved at epoch 14
EPOCH: 14
Mean loss: 0.14653893081413932
Valid_Mean_loss: 0.1597843610560802


100%|██████████| 16065/16065 [41:35<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:55<00:00, 17.03it/s]


EPOCH: 15
Mean loss: 0.14442350951600513
Valid_Mean_loss: 0.16219736911302662


100%|██████████| 16065/16065 [41:37<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.87it/s]


Best model saved at epoch 16
EPOCH: 16
Mean loss: 0.1417384308407769
Valid_Mean_loss: 0.15628212536037162


100%|██████████| 16065/16065 [41:37<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.96it/s]


EPOCH: 17
Mean loss: 0.13971944998403668
Valid_Mean_loss: 0.15844019022574315


100%|██████████| 16065/16065 [41:47<00:00,  6.41it/s] 
100%|██████████| 4017/4017 [04:00<00:00, 16.69it/s]


Best model saved at epoch 18
EPOCH: 18
Mean loss: 0.13799907216170923
Valid_Mean_loss: 0.15462731273428906


100%|██████████| 16065/16065 [41:57<00:00,  6.38it/s]
100%|██████████| 4017/4017 [03:59<00:00, 16.76it/s]


EPOCH: 19
Mean loss: 0.1354704068154514
Valid_Mean_loss: 0.15733868371855717


100%|██████████| 16065/16065 [41:43<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:56<00:00, 17.00it/s]


EPOCH: 20
Mean loss: 0.13297428570912703
Valid_Mean_loss: 0.16023509261143395


100%|██████████| 16065/16065 [41:44<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.87it/s]


Best model saved at epoch 21
EPOCH: 21
Mean loss: 0.13095339430183628
Valid_Mean_loss: 0.15146355086486377


100%|██████████| 16065/16065 [41:43<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:56<00:00, 17.01it/s]


EPOCH: 22
Mean loss: 0.12902688105233243
Valid_Mean_loss: 0.15177019211615025


100%|██████████| 16065/16065 [41:44<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.85it/s]


Best model saved at epoch 23
EPOCH: 23
Mean loss: 0.12731109608603608
Valid_Mean_loss: 0.1500569918513565


100%|██████████| 16065/16065 [41:48<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:55<00:00, 17.03it/s]


EPOCH: 24
Mean loss: 0.12583067027571934
Valid_Mean_loss: 0.15244090080795403


100%|██████████| 16065/16065 [41:46<00:00,  6.41it/s]
100%|██████████| 4017/4017 [04:06<00:00, 16.32it/s]


Best model saved at epoch 25
EPOCH: 25
Mean loss: 0.12451154969022758
Valid_Mean_loss: 0.14840445672513122


100%|██████████| 16065/16065 [41:45<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.89it/s]


Best model saved at epoch 26
EPOCH: 26
Mean loss: 0.12261471816249032
Valid_Mean_loss: 0.1473381683130243


100%|██████████| 16065/16065 [41:44<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.92it/s]


Best model saved at epoch 27
EPOCH: 27
Mean loss: 0.12137613236032244
Valid_Mean_loss: 0.14492709079665986


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.94it/s]


EPOCH: 28
Mean loss: 0.11978636967442023
Valid_Mean_loss: 0.15124502615576096


100%|██████████| 16065/16065 [41:46<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.82it/s]


EPOCH: 29
Mean loss: 0.1183712191868155
Valid_Mean_loss: 0.1458763404512156


100%|██████████| 16065/16065 [41:45<00:00,  6.41it/s] 
100%|██████████| 4017/4017 [03:56<00:00, 16.98it/s]


EPOCH: 30
Mean loss: 0.11745258302210723
Valid_Mean_loss: 0.1475872447575566


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.86it/s]


Best model saved at epoch 31
EPOCH: 31
Mean loss: 0.1159833147709072
Valid_Mean_loss: 0.14418926027125614


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


Best model saved at epoch 32
EPOCH: 32
Mean loss: 0.11482926004221608
Valid_Mean_loss: 0.14408484964130944


100%|██████████| 16065/16065 [41:41<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.89it/s]


Best model saved at epoch 33
EPOCH: 33
Mean loss: 0.11341478425025049
Valid_Mean_loss: 0.14264913732387666


100%|██████████| 16065/16065 [41:40<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


Best model saved at epoch 34
EPOCH: 34
Mean loss: 0.11269433195628817
Valid_Mean_loss: 0.14141660906469047


100%|██████████| 16065/16065 [41:44<00:00,  6.41it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.91it/s]


Best model saved at epoch 35
EPOCH: 35
Mean loss: 0.11143487480396873
Valid_Mean_loss: 0.1412051581968441


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:56<00:00, 17.02it/s]


EPOCH: 36
Mean loss: 0.1099494730539497
Valid_Mean_loss: 0.14794248280194022


100%|██████████| 16065/16065 [41:50<00:00,  6.40it/s]
100%|██████████| 4017/4017 [03:59<00:00, 16.74it/s]


EPOCH: 37
Mean loss: 0.1091889329910872
Valid_Mean_loss: 0.14248959812501902


100%|██████████| 16065/16065 [41:42<00:00,  6.42it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.97it/s]


Best model saved at epoch 38
EPOCH: 38
Mean loss: 0.10827043445318835
Valid_Mean_loss: 0.14043823443214068


100%|██████████| 16065/16065 [41:35<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.87it/s]


Best model saved at epoch 39
EPOCH: 39
Mean loss: 0.10690576400557968
Valid_Mean_loss: 0.13976873442222743


100%|██████████| 16065/16065 [41:37<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.94it/s]


EPOCH: 40
Mean loss: 0.10606873482512122
Valid_Mean_loss: 0.14327407110262666


100%|██████████| 16065/16065 [41:33<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.95it/s]


Best model saved at epoch 41
EPOCH: 41
Mean loss: 0.10529806399100523
Valid_Mean_loss: 0.13973029939755


100%|██████████| 16065/16065 [41:32<00:00,  6.45it/s]
100%|██████████| 4017/4017 [03:55<00:00, 17.03it/s]


Best model saved at epoch 42
EPOCH: 42
Mean loss: 0.10398142804915461
Valid_Mean_loss: 0.13914746666298833


100%|██████████| 16065/16065 [41:34<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


Best model saved at epoch 43
EPOCH: 43
Mean loss: 0.10368929277726359
Valid_Mean_loss: 0.13719912275082144


100%|██████████| 16065/16065 [41:33<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]


EPOCH: 44
Mean loss: 0.10263019084856216
Valid_Mean_loss: 0.1384660553392083


100%|██████████| 16065/16065 [41:38<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.84it/s]


Best model saved at epoch 45
EPOCH: 45
Mean loss: 0.10132894620248324
Valid_Mean_loss: 0.1371867954360035


100%|██████████| 16065/16065 [41:35<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.97it/s]


EPOCH: 46
Mean loss: 0.10025856605468048
Valid_Mean_loss: 0.13946346646670474


100%|██████████| 16065/16065 [41:37<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:58<00:00, 16.86it/s]


EPOCH: 47
Mean loss: 0.09961202254816263
Valid_Mean_loss: 0.137996917314603


100%|██████████| 16065/16065 [41:38<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:56<00:00, 16.98it/s]


EPOCH: 48
Mean loss: 0.09907008446973319
Valid_Mean_loss: 0.14000851553947083


100%|██████████| 16065/16065 [41:35<00:00,  6.44it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.90it/s]


Best model saved at epoch 49
EPOCH: 49
Mean loss: 0.09803181461111021
Valid_Mean_loss: 0.13544352678686403


100%|██████████| 16065/16065 [41:38<00:00,  6.43it/s]
100%|██████████| 4017/4017 [03:57<00:00, 16.93it/s]

EPOCH: 50
Mean loss: 0.09709518626958383
Valid_Mean_loss: 0.13663760472329245





In [8]:
# Inference data: same preprocessing as training, images only (infer=True).
# shuffle=False keeps the CSV row order, which the submission relies on.
test_dataset = SatelliteDataset(
    csv_file='./test.csv',
    transform=transform,
    infer=True,
)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=4, num_workers=0)

In [9]:
#model.load_state_dict(torch.load("unet++_512+\\0819_11_34_27\\best_model_epoch51_loss_0.10628819498613497.pth"))

In [10]:
# Predict building masks for the test set and write the submission CSV.
MASK_THRESHOLD = 0.35  # sigmoid probability above which a pixel counts as building
# Optional path to a saved state_dict, e.g. the best checkpoint written by the training
# loop. None keeps the weights `model` currently holds; right after training those are
# the LAST epoch's weights, not necessarily the best validation checkpoint.
BEST_CHECKPOINT = None

if BEST_CHECKPOINT is not None:
    model.load_state_dict(torch.load(BEST_CHECKPOINT, map_location=device))

model.eval()
result = []
with torch.no_grad():
    for images in tqdm(test_dataloader):
        images = images.float().to(device)

        outputs = model(images)
        probs = torch.sigmoid(outputs).cpu().numpy()
        # Drop the singleton channel axis, then binarise with the threshold.
        masks = (np.squeeze(probs, axis=1) > MASK_THRESHOLD).astype(np.uint8)

        for mask in masks:
            mask_rle = rle_encode(mask)
            # Images with no predicted building pixels are submitted as -1.
            result.append(mask_rle if mask_rle != '' else -1)

submit = pd.read_csv('./sample_submission.csv')
# Predictions are matched to rows purely by position; this assumes test.csv and
# sample_submission.csv list the same images in the same order.
if len(result) != len(submit):
    raise ValueError(
        f"Got {len(result)} predictions but sample_submission.csv has {len(submit)} rows"
    )
submit['mask_rle'] = result

submit.to_csv('./submit.csv', index=False)

100%|██████████| 179/179 [01:18<00:00,  2.27it/s]
