In [1]:
import os
import pandas as pd
import numpy as np
import cv2
from torchvision.io import read_image
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, random_split, DataLoader
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2

from torchvision.transforms import ToTensor
from PIL import Image
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision 
from torchvision import transforms
from torchinfo import summary
import timm
import segmentation_models_pytorch as smp
import wandb
from torch.optim import lr_scheduler

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
class DatasetCustom(Dataset):
    def __init__(self, img_dir, label_dir, resize=None, transform=None):
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.resize = resize
        self.transform = transform
        self.images = os.listdir(self.img_dir)

    def __len__(self):
        return len(self.images)
    
    def read_mask(self, mask_path):
        image = cv2.imread(mask_path)
        image = cv2.resize(image, self.resize)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)

        lower_red1 = np.array([0, 100, 20])
        upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([160,100,20])
        upper_red2 = np.array([179,255,255])
        
        lower_mask_red = cv2.inRange(image, lower_red1, upper_red1)
        upper_mask_red = cv2.inRange(image, lower_red2, upper_red2)
        
        red_mask = lower_mask_red + upper_mask_red
        red_mask[red_mask != 0] = 1

        green_mask = cv2.inRange(image, (36, 25, 25), (70, 255, 255))
        green_mask[green_mask != 0] = 2

        full_mask = cv2.bitwise_or(red_mask, green_mask)
        full_mask = np.expand_dims(full_mask, axis=-1) 
        full_mask = full_mask.astype(np.uint8)
        
        return full_mask

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.images[idx])
        label_path = os.path.join(self.label_dir, self.images[idx])
        image = cv2.imread(img_path)  #  BGR
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # Convert to RGB
        label = self.read_mask(label_path)  
        image = cv2.resize(image, self.resize)
        if self.transform:
            image = self.transform(image)
            
        return image, label

In [4]:
images_path = "../data/train/"
TRAIN_DIR = "../data/train/"

image_path = []
for root, dirs, files in os.walk(TRAIN_DIR):
    for file in files:
        path = os.path.join(root,file)
        image_path.append(path)
        
len(image_path)

1000

In [5]:
mask_path = []
TRAIN_MASK_DIR = '../data/train_gt'
for root, dirs, files in os.walk(TRAIN_MASK_DIR):
    for file in files:
        path = os.path.join(root,file)
        mask_path.append(path)
        
len(mask_path)

1000

In [6]:
dataset = DatasetCustom(img_dir= TRAIN_DIR,
                             label_dir= TRAIN_MASK_DIR,
                             resize= (256,256),
                             transform = None)

In [7]:
batch_size = 8
images_data = []
labels_data = []
for x,y in dataset:
    images_data.append(x)
    labels_data.append(y)

In [8]:
import segmentation_models_pytorch as smp

model = smp.UnetPlusPlus(
    encoder_name="resnet34",        
    encoder_weights="imagenet",     
    in_channels=3,                  
    classes=3     
)

In [9]:
class CustomDataset(Dataset):
    def __init__(self, data, targets, transform=None):
        self.data = data
        self.targets = targets
        self.transform = transform

    def __getitem__(self, index):
        image = self.data[index]
        label = self.targets[index]
        assert image.shape[:2] == label.shape[:2]
        if self.transform:
            transformed = self.transform(image=image, mask=label)
            image = transformed['image'].float()
            label = transformed['mask'].float()
            label = label.permute(2, 0, 1)
        return image, label
    
    def __len__(self):
        return len(self.data)

In [10]:
train_transformation = A.Compose([
    A.HorizontalFlip(p=0.4),
    A.VerticalFlip(p=0.4),
    A.RandomGamma (gamma_limit=(70, 130), eps=None, always_apply=False, p=0.2),
    A.RGBShift(p=0.3, r_shift_limit=10, g_shift_limit=10, b_shift_limit=10),
    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

val_transformation = A.Compose([
    A.Normalize(mean=(0.485, 0.456, 0.406),std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

  A.RandomGamma (gamma_limit=(70, 130), eps=None, always_apply=False, p=0.2),


In [11]:
train_size = int(0.8 * len(images_data))
val_size = len(images_data) - train_size
train_dataset = CustomDataset(images_data[:train_size], labels_data[:train_size], transform=train_transformation)
val_dataset = CustomDataset(images_data[train_size:], labels_data[train_size:], transform=val_transformation)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

In [12]:
learning_rate = 0.01
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.6)

In [13]:
color_dict= {0: (0, 0, 0),
             1: (255, 0, 0),
             2: (0, 255, 0)}
def mask_to_rgb(mask, color_dict):
    output = np.zeros((mask.shape[0], mask.shape[1], 3))

    for k in color_dict.keys():
        output[mask==k] = color_dict[k]

    return np.uint8(output)    

In [14]:
# wandb.login(
#     # set the wandb project where this run will be logged
# #     project= "PolypSegment", 
#     key = '4',
# )
wandb.init(
    project = "ProjectNeo"
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mvubkk67[0m ([33mvubkk67-hanoi-university-of-science-and-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [15]:
from tqdm import tqdm
import time

num_epochs = 250

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)
criterion = nn.CrossEntropyLoss()
best_val_loss = 999

epoch_bar = tqdm(total=num_epochs, desc='Total Progress')

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        labels = labels.squeeze(dim=1).long()
        outputs = model(images)
    
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
    model.eval()
    with torch.no_grad():
        val_loss = 0
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            labels = labels.squeeze(dim=1).long()
            
            outputs = model(images)

            val_loss += criterion(outputs.float(),labels.long()).item()

    
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {val_loss/len(val_loader):.10f}")
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        checkpoint = { 
            'epoch': epoch,
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': val_loss,
        }
        save_path = f'model_step.pth'
        torch.save(checkpoint, save_path)
        
    lr_scheduler.step()
    
    
    epoch_bar.update(1)
    wandb.log({'Val_loss': val_loss/len(val_loader),'Train_loss': train_loss/len(train_loader)})
epoch_bar.close()

Total Progress:   0%|          | 0/250 [00:00<?, ?it/s]

Epoch [1/250], Loss: 0.2381437624


Total Progress:   0%|          | 1/250 [00:31<2:11:58, 31.80s/it]

Epoch [2/250], Loss: 0.1747382361


Total Progress:   1%|          | 3/250 [01:34<2:08:43, 31.27s/it]

Epoch [3/250], Loss: 0.1810319456
Epoch [4/250], Loss: 0.1585404760


Total Progress:   2%|▏         | 4/250 [02:05<2:08:32, 31.35s/it]

Epoch [5/250], Loss: 0.1460301036


Total Progress:   2%|▏         | 6/250 [03:07<2:06:31, 31.11s/it]

Epoch [6/250], Loss: 0.1636492619


Total Progress:   3%|▎         | 7/250 [03:38<2:05:42, 31.04s/it]

Epoch [7/250], Loss: 0.1516215104
Epoch [8/250], Loss: 0.1454635930


Total Progress:   4%|▎         | 9/250 [04:40<2:04:26, 30.98s/it]

Epoch [9/250], Loss: 0.1681075114
Epoch [10/250], Loss: 0.1296186107


Total Progress:   4%|▍         | 11/250 [05:42<2:03:56, 31.12s/it]

Epoch [11/250], Loss: 0.1343302751
Epoch [12/250], Loss: 0.1269480386


Total Progress:   5%|▍         | 12/250 [06:14<2:03:31, 31.14s/it]

Epoch [13/250], Loss: 0.1213922960


Total Progress:   6%|▌         | 14/250 [07:16<2:02:09, 31.06s/it]

Epoch [14/250], Loss: 0.1264961490


Total Progress:   6%|▌         | 15/250 [07:47<2:01:19, 30.97s/it]

Epoch [15/250], Loss: 0.1245736897
Epoch [16/250], Loss: 0.1176146117


Total Progress:   6%|▋         | 16/250 [08:18<2:01:02, 31.04s/it]

Epoch [17/250], Loss: 0.1165895495


Total Progress:   7%|▋         | 17/250 [08:49<2:00:48, 31.11s/it]

Epoch [18/250], Loss: 0.1141020125


Total Progress:   7%|▋         | 18/250 [09:20<2:00:38, 31.20s/it]

Epoch [19/250], Loss: 0.1083343625


Total Progress:   8%|▊         | 20/250 [10:23<1:59:30, 31.18s/it]

Epoch [20/250], Loss: 0.1122269472


Total Progress:   8%|▊         | 21/250 [10:54<1:58:29, 31.04s/it]

Epoch [21/250], Loss: 0.1092552704
Epoch [22/250], Loss: 0.1034992626


Total Progress:   9%|▉         | 23/250 [11:55<1:57:06, 30.95s/it]

Epoch [23/250], Loss: 0.1035807699
Epoch [24/250], Loss: 0.1032020143


Total Progress:  10%|▉         | 24/250 [12:26<1:56:45, 31.00s/it]

Epoch [25/250], Loss: 0.1005651391


Total Progress:  10%|█         | 26/250 [13:29<1:56:01, 31.08s/it]

Epoch [26/250], Loss: 0.1039748405
Epoch [27/250], Loss: 0.0978643930


Total Progress:  11%|█         | 27/250 [14:00<1:55:38, 31.11s/it]

Epoch [28/250], Loss: 0.0939109805


Total Progress:  11%|█         | 28/250 [14:31<1:55:13, 31.14s/it]

Epoch [29/250], Loss: 0.0924526004


Total Progress:  12%|█▏        | 29/250 [15:03<1:54:57, 31.21s/it]

Epoch [30/250], Loss: 0.0917557710


Total Progress:  12%|█▏        | 30/250 [15:34<1:54:45, 31.30s/it]

Epoch [31/250], Loss: 0.0909858853


Total Progress:  12%|█▏        | 31/250 [16:06<1:54:20, 31.33s/it]

Epoch [32/250], Loss: 0.0882511285


Total Progress:  13%|█▎        | 32/250 [16:37<1:54:19, 31.46s/it]

Epoch [33/250], Loss: 0.0862649795


Total Progress:  13%|█▎        | 33/250 [17:09<1:53:39, 31.42s/it]

Epoch [34/250], Loss: 0.0853711578


Total Progress:  14%|█▎        | 34/250 [17:40<1:53:20, 31.48s/it]

Epoch [35/250], Loss: 0.0836021286


Total Progress:  14%|█▍        | 35/250 [18:12<1:52:33, 31.41s/it]

Epoch [36/250], Loss: 0.0813662465


Total Progress:  14%|█▍        | 36/250 [18:43<1:51:56, 31.39s/it]

Epoch [37/250], Loss: 0.0808849268


Total Progress:  15%|█▌        | 38/250 [19:45<1:50:13, 31.20s/it]

Epoch [38/250], Loss: 0.0810697620


Total Progress:  16%|█▌        | 39/250 [20:16<1:49:13, 31.06s/it]

Epoch [39/250], Loss: 0.0826689726
Epoch [40/250], Loss: 0.0799581149


Total Progress:  16%|█▌        | 40/250 [20:47<1:48:42, 31.06s/it]

Epoch [41/250], Loss: 0.0786469612


Total Progress:  16%|█▋        | 41/250 [21:18<1:48:36, 31.18s/it]

Epoch [42/250], Loss: 0.0782107586


Total Progress:  17%|█▋        | 43/250 [22:20<1:47:17, 31.10s/it]

Epoch [43/250], Loss: 0.0785660766


Total Progress:  18%|█▊        | 44/250 [22:51<1:46:30, 31.02s/it]

Epoch [44/250], Loss: 0.0787437195
Epoch [45/250], Loss: 0.0773919325


Total Progress:  18%|█▊        | 45/250 [23:23<1:46:18, 31.11s/it]

Epoch [46/250], Loss: 0.0758047028


Total Progress:  18%|█▊        | 46/250 [23:54<1:46:01, 31.18s/it]

Epoch [47/250], Loss: 0.0753553486


Total Progress:  19%|█▉        | 48/250 [24:56<1:45:06, 31.22s/it]

Epoch [48/250], Loss: 0.0761983375


Total Progress:  20%|█▉        | 49/250 [25:27<1:44:19, 31.14s/it]

Epoch [49/250], Loss: 0.0762974203
Epoch [50/250], Loss: 0.0753344807


Total Progress:  20%|██        | 51/250 [26:30<1:43:23, 31.17s/it]

Epoch [51/250], Loss: 0.0758359765
Epoch [52/250], Loss: 0.0750648776


Total Progress:  21%|██        | 52/250 [27:01<1:43:04, 31.24s/it]

Epoch [53/250], Loss: 0.0749859259


Total Progress:  21%|██        | 53/250 [27:33<1:42:42, 31.28s/it]

Epoch [54/250], Loss: 0.0747172067


Total Progress:  22%|██▏       | 54/250 [28:04<1:42:16, 31.31s/it]

Epoch [55/250], Loss: 0.0745754164


Total Progress:  22%|██▏       | 56/250 [29:07<1:41:08, 31.28s/it]

Epoch [56/250], Loss: 0.0752752249


Total Progress:  23%|██▎       | 57/250 [29:38<1:40:16, 31.17s/it]

Epoch [57/250], Loss: 0.0754263444


Total Progress:  23%|██▎       | 58/250 [30:09<1:39:36, 31.13s/it]

Epoch [58/250], Loss: 0.0749814218


Total Progress:  24%|██▎       | 59/250 [30:39<1:38:50, 31.05s/it]

Epoch [59/250], Loss: 0.0752688316


Total Progress:  24%|██▍       | 60/250 [31:10<1:38:15, 31.03s/it]

Epoch [60/250], Loss: 0.0754422912


Total Progress:  24%|██▍       | 61/250 [31:41<1:37:38, 31.00s/it]

Epoch [61/250], Loss: 0.0755719675


Total Progress:  25%|██▍       | 62/250 [32:12<1:37:14, 31.03s/it]

Epoch [62/250], Loss: 0.0746522328


Total Progress:  25%|██▌       | 63/250 [32:44<1:36:58, 31.12s/it]

Epoch [63/250], Loss: 0.0755354084


Total Progress:  26%|██▌       | 64/250 [33:15<1:36:19, 31.07s/it]

Epoch [64/250], Loss: 0.0769176827
Epoch [65/250], Loss: 0.0745372160


Total Progress:  26%|██▋       | 66/250 [34:17<1:35:21, 31.09s/it]

Epoch [66/250], Loss: 0.0754941066
Epoch [67/250], Loss: 0.0744648868


Total Progress:  27%|██▋       | 68/250 [35:20<1:34:27, 31.14s/it]

Epoch [68/250], Loss: 0.0745096193


Total Progress:  28%|██▊       | 69/250 [35:51<1:34:05, 31.19s/it]

Epoch [69/250], Loss: 0.0750259422


Total Progress:  28%|██▊       | 70/250 [36:22<1:33:29, 31.17s/it]

Epoch [70/250], Loss: 0.0752294335


Total Progress:  28%|██▊       | 71/250 [36:53<1:32:49, 31.11s/it]

Epoch [71/250], Loss: 0.0745788445
Epoch [72/250], Loss: 0.0743121116


Total Progress:  29%|██▉       | 73/250 [37:55<1:31:49, 31.13s/it]

Epoch [73/250], Loss: 0.0746732931


Total Progress:  30%|██▉       | 74/250 [38:26<1:31:24, 31.16s/it]

Epoch [74/250], Loss: 0.0746867257


Total Progress:  30%|███       | 75/250 [38:57<1:30:43, 31.11s/it]

Epoch [75/250], Loss: 0.0753854963


Total Progress:  30%|███       | 76/250 [39:28<1:30:01, 31.04s/it]

Epoch [76/250], Loss: 0.0750858375


Total Progress:  31%|███       | 77/250 [40:00<1:29:37, 31.09s/it]

Epoch [77/250], Loss: 0.0752473456


Total Progress:  31%|███       | 78/250 [40:31<1:29:03, 31.07s/it]

Epoch [78/250], Loss: 0.0744327100


Total Progress:  32%|███▏      | 79/250 [41:02<1:28:35, 31.08s/it]

Epoch [79/250], Loss: 0.0750304785


Total Progress:  32%|███▏      | 80/250 [41:33<1:28:01, 31.07s/it]

Epoch [80/250], Loss: 0.0747873893


Total Progress:  32%|███▏      | 81/250 [42:04<1:27:23, 31.03s/it]

Epoch [81/250], Loss: 0.0750277229


Total Progress:  33%|███▎      | 82/250 [42:35<1:26:55, 31.05s/it]

Epoch [82/250], Loss: 0.0761634630


Total Progress:  33%|███▎      | 83/250 [43:06<1:26:41, 31.15s/it]

Epoch [83/250], Loss: 0.0747374557


Total Progress:  34%|███▎      | 84/250 [43:37<1:26:13, 31.17s/it]

Epoch [84/250], Loss: 0.0746153918


Total Progress:  34%|███▍      | 85/250 [44:08<1:25:40, 31.15s/it]

Epoch [85/250], Loss: 0.0749258806


Total Progress:  34%|███▍      | 86/250 [44:39<1:24:56, 31.08s/it]

Epoch [86/250], Loss: 0.0751883648


Total Progress:  35%|███▍      | 87/250 [45:10<1:24:25, 31.08s/it]

Epoch [87/250], Loss: 0.0743743290


Total Progress:  35%|███▌      | 88/250 [45:41<1:23:44, 31.02s/it]

Epoch [88/250], Loss: 0.0743819010


Total Progress:  36%|███▌      | 89/250 [46:12<1:23:07, 30.98s/it]

Epoch [89/250], Loss: 0.0748504703


Total Progress:  36%|███▌      | 90/250 [46:43<1:22:30, 30.94s/it]

Epoch [90/250], Loss: 0.0756037848


Total Progress:  36%|███▋      | 91/250 [47:14<1:22:06, 30.98s/it]

Epoch [91/250], Loss: 0.0754282184


Total Progress:  37%|███▋      | 92/250 [47:45<1:21:28, 30.94s/it]

Epoch [92/250], Loss: 0.0749997871


Total Progress:  37%|███▋      | 93/250 [48:16<1:20:54, 30.92s/it]

Epoch [93/250], Loss: 0.0745922750


Total Progress:  38%|███▊      | 94/250 [48:47<1:20:28, 30.95s/it]

Epoch [94/250], Loss: 0.0747389320


Total Progress:  38%|███▊      | 95/250 [49:18<1:20:04, 30.99s/it]

Epoch [95/250], Loss: 0.0751930372


Total Progress:  38%|███▊      | 96/250 [49:49<1:19:30, 30.98s/it]

Epoch [96/250], Loss: 0.0760664272


Total Progress:  39%|███▉      | 97/250 [50:20<1:19:05, 31.01s/it]

Epoch [97/250], Loss: 0.0756342818


Total Progress:  39%|███▉      | 98/250 [50:51<1:18:39, 31.05s/it]

Epoch [98/250], Loss: 0.0744197749


Total Progress:  40%|███▉      | 99/250 [51:22<1:18:09, 31.06s/it]

Epoch [99/250], Loss: 0.0751483127


Total Progress:  40%|████      | 100/250 [51:53<1:17:42, 31.09s/it]

Epoch [100/250], Loss: 0.0752213334


Total Progress:  40%|████      | 101/250 [52:24<1:17:03, 31.03s/it]

Epoch [101/250], Loss: 0.0747448784


Total Progress:  41%|████      | 102/250 [52:55<1:16:25, 30.98s/it]

Epoch [102/250], Loss: 0.0748386395


Total Progress:  41%|████      | 103/250 [53:26<1:15:51, 30.96s/it]

Epoch [103/250], Loss: 0.0744390535


Total Progress:  42%|████▏     | 104/250 [53:57<1:15:21, 30.97s/it]

Epoch [104/250], Loss: 0.0747660345


Total Progress:  42%|████▏     | 105/250 [54:28<1:14:47, 30.95s/it]

Epoch [105/250], Loss: 0.0754875788


Total Progress:  42%|████▏     | 106/250 [54:59<1:14:13, 30.93s/it]

Epoch [106/250], Loss: 0.0745589775


Total Progress:  43%|████▎     | 107/250 [55:30<1:13:38, 30.90s/it]

Epoch [107/250], Loss: 0.0751098812


Total Progress:  43%|████▎     | 108/250 [56:01<1:13:16, 30.96s/it]

Epoch [108/250], Loss: 0.0751617797


Total Progress:  44%|████▎     | 109/250 [56:32<1:12:51, 31.00s/it]

Epoch [109/250], Loss: 0.0746250084


Total Progress:  44%|████▍     | 110/250 [57:03<1:12:26, 31.04s/it]

Epoch [110/250], Loss: 0.0746526477


Total Progress:  44%|████▍     | 111/250 [57:34<1:12:02, 31.10s/it]

Epoch [111/250], Loss: 0.0752949868


Total Progress:  45%|████▍     | 112/250 [58:05<1:11:32, 31.11s/it]

Epoch [112/250], Loss: 0.0749678226


Total Progress:  45%|████▌     | 113/250 [58:36<1:10:58, 31.08s/it]

Epoch [113/250], Loss: 0.0750661781
Epoch [114/250], Loss: 0.0742294760


Total Progress:  46%|████▌     | 115/250 [59:39<1:10:09, 31.18s/it]

Epoch [115/250], Loss: 0.0748701386


Total Progress:  46%|████▋     | 116/250 [1:00:10<1:09:27, 31.10s/it]

Epoch [116/250], Loss: 0.0750709753


Total Progress:  47%|████▋     | 117/250 [1:00:41<1:08:50, 31.05s/it]

Epoch [117/250], Loss: 0.0752064739


Total Progress:  47%|████▋     | 118/250 [1:01:12<1:08:30, 31.14s/it]

Epoch [118/250], Loss: 0.0750775501


Total Progress:  48%|████▊     | 119/250 [1:01:43<1:08:01, 31.16s/it]

Epoch [119/250], Loss: 0.0746861626


Total Progress:  48%|████▊     | 120/250 [1:02:15<1:07:31, 31.17s/it]

Epoch [120/250], Loss: 0.0750993198


Total Progress:  48%|████▊     | 121/250 [1:02:46<1:06:54, 31.12s/it]

Epoch [121/250], Loss: 0.0747344609
Epoch [122/250], Loss: 0.0741944875


Total Progress:  49%|████▉     | 123/250 [1:03:49<1:06:14, 31.29s/it]

Epoch [123/250], Loss: 0.0747812903


Total Progress:  50%|████▉     | 124/250 [1:04:20<1:05:52, 31.37s/it]

Epoch [124/250], Loss: 0.0748486987
Epoch [125/250], Loss: 0.0741056119


Total Progress:  50%|█████     | 126/250 [1:05:23<1:04:52, 31.39s/it]

Epoch [126/250], Loss: 0.0749356213


Total Progress:  51%|█████     | 127/250 [1:05:54<1:04:10, 31.30s/it]

Epoch [127/250], Loss: 0.0745696174


Total Progress:  51%|█████     | 128/250 [1:06:25<1:03:30, 31.23s/it]

Epoch [128/250], Loss: 0.0745619409


Total Progress:  52%|█████▏    | 129/250 [1:06:57<1:02:57, 31.22s/it]

Epoch [129/250], Loss: 0.0747280374


Total Progress:  52%|█████▏    | 130/250 [1:07:28<1:02:27, 31.23s/it]

Epoch [130/250], Loss: 0.0747298819


Total Progress:  52%|█████▏    | 131/250 [1:07:59<1:01:47, 31.16s/it]

Epoch [131/250], Loss: 0.0749492687


Total Progress:  53%|█████▎    | 132/250 [1:08:30<1:01:08, 31.09s/it]

Epoch [132/250], Loss: 0.0748299336


Total Progress:  53%|█████▎    | 133/250 [1:09:01<1:00:36, 31.08s/it]

Epoch [133/250], Loss: 0.0744758178


Total Progress:  54%|█████▎    | 134/250 [1:09:32<1:00:13, 31.15s/it]

Epoch [134/250], Loss: 0.0749971867


Total Progress:  54%|█████▍    | 135/250 [1:10:03<59:44, 31.17s/it]  

Epoch [135/250], Loss: 0.0745606306


Total Progress:  54%|█████▍    | 136/250 [1:10:35<59:16, 31.20s/it]

Epoch [136/250], Loss: 0.0759682123


Total Progress:  55%|█████▍    | 137/250 [1:11:06<58:43, 31.18s/it]

Epoch [137/250], Loss: 0.0752742456


Total Progress:  55%|█████▌    | 138/250 [1:11:37<58:11, 31.17s/it]

Epoch [138/250], Loss: 0.0744793774


Total Progress:  56%|█████▌    | 139/250 [1:12:08<57:49, 31.26s/it]

Epoch [139/250], Loss: 0.0743927754


Total Progress:  56%|█████▌    | 140/250 [1:12:39<57:09, 31.18s/it]

Epoch [140/250], Loss: 0.0749689679


Total Progress:  56%|█████▋    | 141/250 [1:13:10<56:27, 31.08s/it]

Epoch [141/250], Loss: 0.0748808654


Total Progress:  57%|█████▋    | 142/250 [1:13:41<55:51, 31.04s/it]

Epoch [142/250], Loss: 0.0750296648


Total Progress:  57%|█████▋    | 143/250 [1:14:12<55:16, 30.99s/it]

Epoch [143/250], Loss: 0.0747282153


Total Progress:  58%|█████▊    | 144/250 [1:14:43<54:43, 30.98s/it]

Epoch [144/250], Loss: 0.0750341481


Total Progress:  58%|█████▊    | 145/250 [1:15:14<54:17, 31.02s/it]

Epoch [145/250], Loss: 0.0745397888


Total Progress:  58%|█████▊    | 146/250 [1:15:45<53:47, 31.03s/it]

Epoch [146/250], Loss: 0.0752869026


Total Progress:  59%|█████▉    | 147/250 [1:16:16<53:14, 31.02s/it]

Epoch [147/250], Loss: 0.0750453071


Total Progress:  59%|█████▉    | 148/250 [1:16:47<52:39, 30.97s/it]

Epoch [148/250], Loss: 0.0755355242


Total Progress:  60%|█████▉    | 149/250 [1:17:18<52:08, 30.98s/it]

Epoch [149/250], Loss: 0.0755902231


Total Progress:  60%|██████    | 150/250 [1:17:49<51:44, 31.04s/it]

Epoch [150/250], Loss: 0.0749593215


Total Progress:  60%|██████    | 151/250 [1:18:20<51:09, 31.01s/it]

Epoch [151/250], Loss: 0.0744568558


Total Progress:  61%|██████    | 152/250 [1:18:51<50:37, 30.99s/it]

Epoch [152/250], Loss: 0.0749203204


Total Progress:  61%|██████    | 153/250 [1:19:22<50:08, 31.02s/it]

Epoch [153/250], Loss: 0.0747505632


Total Progress:  62%|██████▏   | 154/250 [1:19:53<49:46, 31.11s/it]

Epoch [154/250], Loss: 0.0744573997


Total Progress:  62%|██████▏   | 155/250 [1:20:24<49:11, 31.06s/it]

Epoch [155/250], Loss: 0.0748056275


Total Progress:  62%|██████▏   | 156/250 [1:20:55<48:33, 30.99s/it]

Epoch [156/250], Loss: 0.0748830351


Total Progress:  63%|██████▎   | 157/250 [1:21:26<47:57, 30.94s/it]

Epoch [157/250], Loss: 0.0747957401


Total Progress:  63%|██████▎   | 158/250 [1:21:57<47:29, 30.98s/it]

Epoch [158/250], Loss: 0.0749024014


Total Progress:  64%|██████▎   | 159/250 [1:22:28<47:00, 31.00s/it]

Epoch [159/250], Loss: 0.0750788189


Total Progress:  64%|██████▍   | 160/250 [1:22:59<46:31, 31.01s/it]

Epoch [160/250], Loss: 0.0748211683


Total Progress:  64%|██████▍   | 161/250 [1:23:30<46:01, 31.02s/it]

Epoch [161/250], Loss: 0.0743250278


Total Progress:  65%|██████▍   | 162/250 [1:24:02<45:40, 31.14s/it]

Epoch [162/250], Loss: 0.0748745568


Total Progress:  65%|██████▌   | 163/250 [1:24:33<45:06, 31.11s/it]

Epoch [163/250], Loss: 0.0746070035


Total Progress:  66%|██████▌   | 164/250 [1:25:04<44:35, 31.11s/it]

Epoch [164/250], Loss: 0.0748054038


Total Progress:  66%|██████▌   | 165/250 [1:25:35<44:11, 31.19s/it]

Epoch [165/250], Loss: 0.0749641646


Total Progress:  66%|██████▋   | 166/250 [1:26:06<43:38, 31.17s/it]

Epoch [166/250], Loss: 0.0743394765


Total Progress:  67%|██████▋   | 167/250 [1:26:37<43:07, 31.18s/it]

Epoch [167/250], Loss: 0.0745894104


Total Progress:  67%|██████▋   | 168/250 [1:27:09<42:39, 31.22s/it]

Epoch [168/250], Loss: 0.0754344618


Total Progress:  68%|██████▊   | 169/250 [1:27:40<42:01, 31.12s/it]

Epoch [169/250], Loss: 0.0747314833


Total Progress:  68%|██████▊   | 170/250 [1:28:11<41:28, 31.10s/it]

Epoch [170/250], Loss: 0.0760465114


Total Progress:  68%|██████▊   | 171/250 [1:28:42<41:00, 31.15s/it]

Epoch [171/250], Loss: 0.0746722379


Total Progress:  69%|██████▉   | 172/250 [1:29:13<40:30, 31.16s/it]

Epoch [172/250], Loss: 0.0748110645


Total Progress:  69%|██████▉   | 173/250 [1:29:44<39:55, 31.11s/it]

Epoch [173/250], Loss: 0.0748387682


Total Progress:  70%|██████▉   | 174/250 [1:30:15<39:21, 31.08s/it]

Epoch [174/250], Loss: 0.0746956227


Total Progress:  70%|███████   | 175/250 [1:30:46<38:52, 31.11s/it]

Epoch [175/250], Loss: 0.0750816320


Total Progress:  70%|███████   | 176/250 [1:31:18<38:24, 31.14s/it]

Epoch [176/250], Loss: 0.0750057325


Total Progress:  71%|███████   | 177/250 [1:31:49<37:52, 31.13s/it]

Epoch [177/250], Loss: 0.0741359378


Total Progress:  71%|███████   | 178/250 [1:32:20<37:18, 31.08s/it]

Epoch [178/250], Loss: 0.0750511025


Total Progress:  72%|███████▏  | 179/250 [1:32:50<36:41, 31.01s/it]

Epoch [179/250], Loss: 0.0752971633


Total Progress:  72%|███████▏  | 180/250 [1:33:21<36:05, 30.93s/it]

Epoch [180/250], Loss: 0.0756252688


Total Progress:  72%|███████▏  | 181/250 [1:33:52<35:34, 30.93s/it]

Epoch [181/250], Loss: 0.0748715897


Total Progress:  73%|███████▎  | 182/250 [1:34:23<35:01, 30.90s/it]

Epoch [182/250], Loss: 0.0742054440


Total Progress:  73%|███████▎  | 183/250 [1:34:54<34:34, 30.96s/it]

Epoch [183/250], Loss: 0.0749079555


Total Progress:  74%|███████▎  | 184/250 [1:35:25<33:59, 30.90s/it]

Epoch [184/250], Loss: 0.0751662663


Total Progress:  74%|███████▍  | 185/250 [1:35:55<33:23, 30.82s/it]

Epoch [185/250], Loss: 0.0745786686


Total Progress:  74%|███████▍  | 186/250 [1:36:26<32:48, 30.76s/it]

Epoch [186/250], Loss: 0.0750359067


Total Progress:  75%|███████▍  | 187/250 [1:36:57<32:14, 30.71s/it]

Epoch [187/250], Loss: 0.0748139752


Total Progress:  75%|███████▌  | 188/250 [1:37:27<31:42, 30.68s/it]

Epoch [188/250], Loss: 0.0746909525


Total Progress:  76%|███████▌  | 189/250 [1:37:58<31:10, 30.67s/it]

Epoch [189/250], Loss: 0.0749895109


Total Progress:  76%|███████▌  | 190/250 [1:38:29<30:39, 30.65s/it]

Epoch [190/250], Loss: 0.0743791118


Total Progress:  76%|███████▋  | 191/250 [1:38:59<30:08, 30.65s/it]

Epoch [191/250], Loss: 0.0744327876


Total Progress:  77%|███████▋  | 192/250 [1:39:30<29:37, 30.64s/it]

Epoch [192/250], Loss: 0.0746158771


Total Progress:  77%|███████▋  | 193/250 [1:40:00<29:05, 30.63s/it]

Epoch [193/250], Loss: 0.0744783396


Total Progress:  78%|███████▊  | 194/250 [1:40:31<28:35, 30.63s/it]

Epoch [194/250], Loss: 0.0744110803


Total Progress:  78%|███████▊  | 195/250 [1:41:02<28:03, 30.61s/it]

Epoch [195/250], Loss: 0.0753861074


Total Progress:  78%|███████▊  | 196/250 [1:41:32<27:32, 30.60s/it]

Epoch [196/250], Loss: 0.0748603767


Total Progress:  79%|███████▉  | 197/250 [1:42:03<27:01, 30.59s/it]

Epoch [197/250], Loss: 0.0745733066


Total Progress:  79%|███████▉  | 198/250 [1:42:33<26:29, 30.58s/it]

Epoch [198/250], Loss: 0.0751165259


Total Progress:  80%|███████▉  | 199/250 [1:43:04<25:59, 30.57s/it]

Epoch [199/250], Loss: 0.0745317368


Total Progress:  80%|████████  | 200/250 [1:43:34<25:28, 30.57s/it]

Epoch [200/250], Loss: 0.0745485161


Total Progress:  80%|████████  | 201/250 [1:44:05<24:57, 30.57s/it]

Epoch [201/250], Loss: 0.0747243559


Total Progress:  81%|████████  | 202/250 [1:44:36<24:27, 30.56s/it]

Epoch [202/250], Loss: 0.0744620359


Total Progress:  81%|████████  | 203/250 [1:45:06<23:56, 30.56s/it]

Epoch [203/250], Loss: 0.0748332182


Total Progress:  82%|████████▏ | 204/250 [1:45:37<23:25, 30.56s/it]

Epoch [204/250], Loss: 0.0752467449


Total Progress:  82%|████████▏ | 205/250 [1:46:07<22:55, 30.56s/it]

Epoch [205/250], Loss: 0.0743897696


Total Progress:  82%|████████▏ | 206/250 [1:46:38<22:24, 30.56s/it]

Epoch [206/250], Loss: 0.0750210522


Total Progress:  83%|████████▎ | 207/250 [1:47:08<21:53, 30.56s/it]

Epoch [207/250], Loss: 0.0748803775


Total Progress:  83%|████████▎ | 208/250 [1:47:39<21:23, 30.56s/it]

Epoch [208/250], Loss: 0.0763710225


Total Progress:  84%|████████▎ | 209/250 [1:48:09<20:52, 30.55s/it]

Epoch [209/250], Loss: 0.0743979637


Total Progress:  84%|████████▍ | 210/250 [1:48:40<20:22, 30.56s/it]

Epoch [210/250], Loss: 0.0761539553


Total Progress:  84%|████████▍ | 211/250 [1:49:11<19:52, 30.57s/it]

Epoch [211/250], Loss: 0.0748415655


Total Progress:  85%|████████▍ | 212/250 [1:49:41<19:21, 30.58s/it]

Epoch [212/250], Loss: 0.0747621219


Total Progress:  85%|████████▌ | 213/250 [1:50:12<18:51, 30.58s/it]

Epoch [213/250], Loss: 0.0750173324


Total Progress:  86%|████████▌ | 214/250 [1:50:42<18:20, 30.58s/it]

Epoch [214/250], Loss: 0.0746167444


Total Progress:  86%|████████▌ | 215/250 [1:51:13<17:50, 30.58s/it]

Epoch [215/250], Loss: 0.0743893592


Total Progress:  86%|████████▋ | 216/250 [1:51:43<17:19, 30.57s/it]

Epoch [216/250], Loss: 0.0746676546


Total Progress:  87%|████████▋ | 217/250 [1:52:14<16:48, 30.57s/it]

Epoch [217/250], Loss: 0.0748110604


Total Progress:  87%|████████▋ | 218/250 [1:52:45<16:18, 30.57s/it]

Epoch [218/250], Loss: 0.0749356943


Total Progress:  88%|████████▊ | 219/250 [1:53:15<15:47, 30.58s/it]

Epoch [219/250], Loss: 0.0744022888


Total Progress:  88%|████████▊ | 220/250 [1:53:46<15:17, 30.58s/it]

Epoch [220/250], Loss: 0.0747493419


Total Progress:  88%|████████▊ | 221/250 [1:54:16<14:46, 30.58s/it]

Epoch [221/250], Loss: 0.0747000082


Total Progress:  89%|████████▉ | 222/250 [1:54:47<14:16, 30.59s/it]

Epoch [222/250], Loss: 0.0741844712


Total Progress:  89%|████████▉ | 223/250 [1:55:18<13:45, 30.58s/it]

Epoch [223/250], Loss: 0.0753473799


Total Progress:  90%|████████▉ | 224/250 [1:55:48<13:15, 30.59s/it]

Epoch [224/250], Loss: 0.0753677428


Total Progress:  90%|█████████ | 225/250 [1:56:19<12:44, 30.59s/it]

Epoch [225/250], Loss: 0.0746712665


Total Progress:  90%|█████████ | 226/250 [1:56:49<12:14, 30.58s/it]

Epoch [226/250], Loss: 0.0746539435


Total Progress:  91%|█████████ | 227/250 [1:57:20<11:43, 30.59s/it]

Epoch [227/250], Loss: 0.0754669374


Total Progress:  91%|█████████ | 228/250 [1:57:50<11:12, 30.58s/it]

Epoch [228/250], Loss: 0.0748372577


Total Progress:  92%|█████████▏| 229/250 [1:58:21<10:42, 30.58s/it]

Epoch [229/250], Loss: 0.0748704793


Total Progress:  92%|█████████▏| 230/250 [1:58:52<10:11, 30.58s/it]

Epoch [230/250], Loss: 0.0754999031


Total Progress:  92%|█████████▏| 231/250 [1:59:22<09:41, 30.58s/it]

Epoch [231/250], Loss: 0.0750504039


Total Progress:  93%|█████████▎| 232/250 [1:59:53<09:10, 30.58s/it]

Epoch [232/250], Loss: 0.0749923341


Total Progress:  93%|█████████▎| 233/250 [2:00:23<08:39, 30.58s/it]

Epoch [233/250], Loss: 0.0757219657


Total Progress:  94%|█████████▎| 234/250 [2:00:54<08:09, 30.58s/it]

Epoch [234/250], Loss: 0.0746497700


Total Progress:  94%|█████████▍| 235/250 [2:01:24<07:38, 30.58s/it]

Epoch [235/250], Loss: 0.0750485204


Total Progress:  94%|█████████▍| 236/250 [2:01:55<07:08, 30.58s/it]

Epoch [236/250], Loss: 0.0748986642


Total Progress:  95%|█████████▍| 237/250 [2:02:26<06:37, 30.58s/it]

Epoch [237/250], Loss: 0.0743508677


Total Progress:  95%|█████████▌| 238/250 [2:02:56<06:06, 30.58s/it]

Epoch [238/250], Loss: 0.0747296803


Total Progress:  96%|█████████▌| 239/250 [2:03:27<05:36, 30.58s/it]

Epoch [239/250], Loss: 0.0746049853


Total Progress:  96%|█████████▌| 240/250 [2:03:57<05:05, 30.58s/it]

Epoch [240/250], Loss: 0.0747199711


Total Progress:  96%|█████████▋| 241/250 [2:04:28<04:36, 30.68s/it]

Epoch [241/250], Loss: 0.0741809860


Total Progress:  97%|█████████▋| 242/250 [2:04:59<04:06, 30.75s/it]

Epoch [242/250], Loss: 0.0744139282


Total Progress:  97%|█████████▋| 243/250 [2:05:30<03:35, 30.83s/it]

Epoch [243/250], Loss: 0.0746966252


Total Progress:  98%|█████████▊| 244/250 [2:06:01<03:04, 30.83s/it]

Epoch [244/250], Loss: 0.0749279180


Total Progress:  98%|█████████▊| 245/250 [2:06:32<02:33, 30.79s/it]

Epoch [245/250], Loss: 0.0750731352


Total Progress:  98%|█████████▊| 246/250 [2:07:03<02:03, 30.83s/it]

Epoch [246/250], Loss: 0.0758291602


Total Progress:  99%|█████████▉| 247/250 [2:07:34<01:32, 30.83s/it]

Epoch [247/250], Loss: 0.0755235031


Total Progress:  99%|█████████▉| 248/250 [2:08:05<01:01, 30.91s/it]

Epoch [248/250], Loss: 0.0748184144


Total Progress: 100%|█████████▉| 249/250 [2:08:36<00:30, 30.92s/it]

Epoch [249/250], Loss: 0.0754373924


Total Progress: 100%|██████████| 250/250 [2:09:06<00:00, 30.99s/it]

Epoch [250/250], Loss: 0.0751394673





In [16]:
checkpoint = torch.load('model_step.pth')
model.load_state_dict(checkpoint['model'])
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)
model.load_state_dict(checkpoint['model'])
device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
model.to(device)

  checkpoint = torch.load('model_step.pth')


UnetPlusPlus(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=Tru

In [17]:
model.eval()
for i in os.listdir("../data/test"):
    img_path = os.path.join("../data/test", i)
    ori_img = cv2.imread(img_path)
    ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
    ori_w = ori_img.shape[0]
    ori_h = ori_img.shape[1]
    img = cv2.resize(ori_img, (256, 256))
    transformed = val_transformation(image=img)
    input_img = transformed["image"]
    input_img = input_img.unsqueeze(0).to(device)
    with torch.no_grad():
        output_mask = model.forward(input_img).squeeze(0).cpu().numpy().transpose(1,2,0)
    mask = cv2.resize(output_mask, (ori_h, ori_w))
    mask = np.argmax(mask, axis=2)
    mask_rgb = mask_to_rgb(mask, color_dict)
    mask_rgb = cv2.cvtColor(mask_rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite("../data/prediction/{}".format(i), mask_rgb) 

In [18]:
def rle_to_string(runs):
    return ' '.join(str(x) for x in runs)

def rle_encode_one_mask(mask):
    pixels = mask.flatten()
    pixels[pixels > 225] = 255
    pixels[pixels <= 225] = 0
    use_padding = False
    if pixels[0] or pixels[-1]:
        use_padding = True
        pixel_padded = np.zeros([len(pixels) + 2], dtype=pixels.dtype)
        pixel_padded[1:-1] = pixels
        pixels = pixel_padded
    rle = np.where(pixels[1:] != pixels[:-1])[0] + 2
    if use_padding:
        rle = rle - 1
    rle[1::2] = rle[1::2] - rle[:-1:2]
    
    return rle_to_string(rle)

def rle2mask(mask_rle, shape=(3,3)):
    s = mask_rle.split()
    starts, lengths = [np.asarray(x, dtype=int) for x in (s[0:][::2], s[1:][::2])]
    starts -= 1
    ends = starts + lengths
    img = np.zeros(shape[0]*shape[1], dtype=np.uint8)
    for lo, hi in zip(starts, ends):
        img[lo:hi] = 1
    return img.reshape(shape).T

def mask2string(dir):
    strings = []
    ids = []
    ws, hs = [[] for i in range(2)]
    for image_id in os.listdir(dir):
        id = image_id.split('.')[0]
        path = os.path.join(dir, image_id)
        print(path)
        img = cv2.imread(path)[:,:,::-1]
        h, w = img.shape[0], img.shape[1]
        for channel in range(2):
            ws.append(w)
            hs.append(h)
            ids.append(f'{id}_{channel}')
            string = rle_encode_one_mask(img[:,:,channel])
            strings.append(string)
    r = {
        'ids': ids,
        'strings': strings,
    }
    return r


MASK_DIR_PATH = '../data/prediction'
dir = MASK_DIR_PATH
res = mask2string(dir)
df = pd.DataFrame(columns=['Id', 'Expected'])
df['Id'] = res['ids']
df['Expected'] = res['strings']

df.to_csv(r'../data/output_1.csv', index=False)

../data/prediction/1531871f2fd85a04faeeb2b535797395.jpeg
../data/prediction/d694539ef2424a9218697283baa3657e.jpeg
../data/prediction/0af3feff05dec1eb3a70b145a7d8d3b6.jpeg
../data/prediction/780fd497e1c0e9082ea2c193ac8d551c.jpeg
../data/prediction/6b83ef461c2a337948a41964c1d4f50a.jpeg
../data/prediction/5026b3550534bca540e24f489284b8e6.jpeg
../data/prediction/ea42b4eebc9e5a87e443434ac60af150.jpeg
../data/prediction/0fca6a4248a41e8db8b4ed633b456aaa.jpeg
../data/prediction/6240619ebebe9e9c9d00a4262b4fe4a5.jpeg
../data/prediction/dc70626ab4ec3d46e602b296cc5cfd26.jpeg
../data/prediction/41ed86e58224cb76a67d4dcf9596154e.jpeg
../data/prediction/e1e0ae936de314f2d95e6c487ffa651b.jpeg
../data/prediction/0a5f3601ad4f13ccf1f4b331a412fc44.jpeg
../data/prediction/c22268d4b4ef4d95ceea11957998906d.jpeg
../data/prediction/2d9e593b6be1ac29adbe86f03d900fd1.jpeg
../data/prediction/b21960c94b0aab4c024a573c692195f8.jpeg
../data/prediction/285e26c90e1797c77826f9a7021bab9f.jpeg
../data/prediction/f14e1e0ae936