# Mount tới Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive/Training_Model/Segmentation

Mounted at /content/drive
/content/drive/MyDrive/Training_Model/Segmentation


# Import thư viện cần thiết

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms

import os
import cv2
import numpy as np
import albumentations as A
import matplotlib.pyplot as plt
import time


from tqdm import tqdm
from copy import deepcopy
from datetime import datetime
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

# Configuration

In [23]:

TQDM_BAR_FORMAT = '{l_bar}{bar:10}{r_bar}'  # Cần thiết cho tqdm

# DATASET
images_file = 'data/images' # Tập chứa ảnh training
segs_file = 'data/labels'  # Tập chứa label của ảnh training
test_file = 'data/test'   # Tập chứa ảnh test
save_dir = 'train_val'   # Path lưu kết quả của tập test
ckpt_file = "best.pt"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
val_size = 0.1          # Tỷ lệ chia tập val

# AUGMENTATION
augmentation = True   # True nếu muốn tăng cường dữ liệu
hflip = True
vflip = True
rotate = True

# TRAINING
lr = 5e-2         # learning rate
num_epochs = 30  # Số lượng epoch training
batch_size = 16   # Batch size

# Chuẩn bị dữ liệu (Data Loader)


In [24]:
class CustomDataLoader(Dataset):
    def __init__(self, data,
                     images_file,
                     segs_file,
                     is_val,
                     augmentation=False,
                     hflip=False,
                     vflip=False,
                     rotate=False):

        self.ims_file = data
        self.images_file = images_file
        self.segs_file = segs_file

        self.Tensor = transforms.ToTensor()

        if is_val:
            self.aug = [False, hflip, vflip, rotate]
        else:
            self.aug = [augmentation, hflip, vflip, rotate]

        self.do_aug = [A.HorizontalFlip(p=1), A.VerticalFlip(p=1), A.Rotate(limit=15, p=1.0)]



    def __len__(self):
        return (len(self.ims_file))

    def img2seg(self, path):
        return path.replace(self.images_file, self.segs_file).replace('.jpg', '.png')


    def preprocess(self, img, seg):
        h0, w0 = img.shape[:2]

        img = cv2.resize(img, (160, 80), interpolation = cv2.INTER_LINEAR)
        seg = cv2.resize(seg, (160, 80), interpolation = cv2.INTER_LINEAR)

        img = img/255
        img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        img = np.ascontiguousarray(img)  # contiguous

        seg = seg/255
        seg = seg.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
        seg = np.ascontiguousarray(seg)  # contiguous

        return img, seg


    def __getitem__(self, index):

        img = cv2.imread(self.ims_file[index])
        seg = cv2.imread(self.img2seg(self.ims_file[index]))


        if self.aug[0]:
            img_pre, seg_pre = self.preprocess(img, seg)
            img_out, seg_out = [torch.from_numpy(img_pre)], [seg_pre]
            for i, f in enumerate(self.do_aug):
                if self.aug[i+1]:
                    au = f(image=img, mask=seg)
                    img_aug, seg_aug = au['image'], au['mask']

                    img_aug, seg_aug = self.preprocess(img_aug, seg_aug)

                    img_out.append(torch.from_numpy(img_aug))
                    seg_out.append(torch.from_numpy(seg_aug))

            return img_out, seg_out

        else:
            img_pre, seg_pre = self.preprocess(img, seg)
            img_out, seg_out = [torch.from_numpy(img_pre)], [torch.from_numpy(seg_pre)]

            return img_out, seg_out

In [25]:
def split_data(images_file, segs_file, val_size):
    pbar = tqdm(os.listdir(images_file), total=len(os.listdir(images_file)), desc='Loading data', bar_format=TQDM_BAR_FORMAT)
    f = []
    for name in pbar:
        f.append(os.path.join(images_file, name))

    train, val = train_test_split(f, test_size=val_size, train_size=(1-val_size))

    return train, val


train, val = split_data(images_file, segs_file, val_size)

train_dataset = CustomDataLoader(train,
                                 images_file,
                                 segs_file,
                                 is_val=False,
                                 augmentation=augmentation,
                                 hflip=hflip,
                                 vflip=vflip,
                                 rotate=rotate)

val_dataset = CustomDataLoader(val,
                               images_file,
                               segs_file,
                               is_val=True,
                               augmentation=augmentation,
                               hflip=hflip,
                               vflip=vflip,
                               rotate=rotate)

print(f'Training size: {len(train_dataset)}')
print(f'Val size: {len(val_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=batch_size, pin_memory=True, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, pin_memory=True, shuffle=False)

Loading data: 100%|██████████| 434/434 [00:00<00:00, 227228.55it/s]

Training size: 390
Val size: 44





# Khởi tạo Model

In [26]:
def double_conv(in_ch, out_ch):
    conv_op = nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True)
    )
    return conv_op

class UNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(UNet, self).__init__()
        self.conv1 = double_conv(in_ch, 8)
        self.conv2 = double_conv(8, 16)
        self.conv3 = double_conv(16, 32)
        self.conv4 = double_conv(32, 64)

        self.conv5 = double_conv(96, 32)
        self.conv6 = double_conv(48, 16)
        self.conv7 = double_conv(24, 8)
        self.pooling = nn.MaxPool2d(kernel_size=2)

        self.upsample1 = nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=2, stride=2)
        self.upsample2 = nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=2, stride=2)
        self.upsample3 = nn.ConvTranspose2d(in_channels=16, out_channels=16, kernel_size=2, stride=2)

        self.conv0 = nn.Conv2d(in_channels=8, out_channels=out_ch, kernel_size=1)


    def forward(self, x):
        #Encoder
        down1 = self.conv1(x)
        pool1 = self.pooling(down1)
        down2 = self.conv2(pool1)
        pool2 = self.pooling(down2)
        down3 = self.conv3(pool2)
        pool3 = self.pooling(down3)
        down4 = self.conv4(pool3)

        #Decoder
        upsample1 = self.upsample1(down4)
        cat1 = torch.cat([down3, upsample1], dim=1)
        up1 = self.conv5(cat1)
        upsample2 = self.upsample2(up1)
        cat2 = torch.cat([down2, upsample2], dim=1)
        up2 = self.conv6(cat2)
        upsample3 = self.upsample3(up2)
        cat3 = torch.cat([down1, upsample3], dim=1)
        up3 = self.conv7(cat3)

        outputs = self.conv0(up3)

        return outputs

img = torch.rand(1, 3, 80, 160)
model = UNet()
total_params = sum(p.numel() for p in model.parameters())

print(f'Input size: {img.size()}')
print(f'Total params: {total_params}')

Input size: torch.Size([1, 3, 80, 160])
Total params: 144433


# Training

In [27]:
def dice_loss(input: torch.Tensor, target: torch.Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
#     assert input.size() == target.size()
#     assert input.dim() == 3 or not reduce_batch_first

#     sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

#     print(input.size())
#     print(target.size())
    inter = 2 * (input * target)
    sets_sum = input + target
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return 1.0 - dice.mean()


def poly_lr_scheduler(lr, max_epochs, optimizer, epoch, power=2):
    lr = round(lr * (1 - epoch / max_epochs) ** power, 8)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return lr


def evaluate(model, val_loader):
    model.eval()
    num_val_batches = len(val_loader)
    dice_score = 0

    for i, (img_l, target_l) in enumerate(val_loader):

        for img, target in zip(img_l, target_l):
            img = img.to(device).float()
            true_masks = target.to(device).float()
            true_masks = torch.mean(true_masks, dim=1, keepdim=True)
            mask_pred = model(img)
            mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()


            dice_score += dice_loss(mask_pred, true_masks, reduce_batch_first=False)


    model.train()

    return dice_score / max(num_val_batches, 1)

In [28]:
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr, (0.9, 0.999), eps=1e-08, weight_decay=5e-4)
model = model.to(device)

best = 10000
dice_score = 0

for epoch in range(num_epochs):

    poly_lr_scheduler(lr, num_epochs, optimizer, epoch)
    for param_group in optimizer.param_groups:
        lr = param_group['lr']

    model.train()

    print(('\n' + '%11s' * 4) % ('Epoch', 'Loss', 'Score', 'Lr'))
    pbar = enumerate(train_loader)
    total_batch = len(train_loader)
    pbar = tqdm(pbar, total=total_batch, bar_format=TQDM_BAR_FORMAT)

    for i, (img_l, target_l) in pbar:
        for img, target in zip(img_l, target_l):
            img = img.to(device).float()
            true_masks = target.to(device).float()
            true_masks = torch.mean(true_masks, dim=1, keepdim=True)
            mask_pred = model(img)

            optimizer.zero_grad()
            loss = criterion(mask_pred, true_masks)
            loss += dice_loss(torch.sigmoid(mask_pred), true_masks, reduce_batch_first=True)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            last = loss.item()

            pbar.set_description(('%13s' * 1 + '%13.4g'*3) %
                                     (f'{epoch}/{num_epochs - 1}', last, dice_score, lr))

            dice_score = evaluate(model, val_loader)

            if last < best:
                best = last

            ckpt = {
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'loss': best,
                    'dice_score': dice_score,
                    'date': datetime.now().isoformat()
            }

            torch.save(ckpt, ckpt_file)



      Epoch       Loss      Score         Lr


         0/29       0.2293       0.2038         0.05: 100%|██████████| 25/25 [00:57<00:00,  2.30s/it]



      Epoch       Loss      Score         Lr


         1/29       0.3782        0.202      0.04672: 100%|██████████| 25/25 [00:56<00:00,  2.25s/it]



      Epoch       Loss      Score         Lr


         2/29       0.2249       0.1267       0.0407: 100%|██████████| 25/25 [00:56<00:00,  2.27s/it]



      Epoch       Loss      Score         Lr


         3/29       0.1554       0.1699      0.03297: 100%|██████████| 25/25 [00:56<00:00,  2.26s/it]



      Epoch       Loss      Score         Lr


         4/29      0.07945      0.07318      0.02476: 100%|██████████| 25/25 [00:57<00:00,  2.28s/it]



      Epoch       Loss      Score         Lr


         5/29      0.08092      0.03032       0.0172: 100%|██████████| 25/25 [00:56<00:00,  2.25s/it]



      Epoch       Loss      Score         Lr


         6/29      0.08028      0.02932      0.01101: 100%|██████████| 25/25 [00:55<00:00,  2.24s/it]



      Epoch       Loss      Score         Lr


         7/29      0.07953      0.02248     0.006469: 100%|██████████| 25/25 [00:56<00:00,  2.25s/it]



      Epoch       Loss      Score         Lr


         8/29      0.06749       0.0194     0.003479: 100%|██████████| 25/25 [00:56<00:00,  2.25s/it]



      Epoch       Loss      Score         Lr


         9/29      0.09795      0.01595     0.001705: 100%|██████████| 25/25 [01:00<00:00,  2.43s/it]



      Epoch       Loss      Score         Lr


        10/29      0.07371      0.01462    0.0007576: 100%|██████████| 25/25 [00:58<00:00,  2.35s/it]



      Epoch       Loss      Score         Lr


        11/29      0.08395      0.01385    0.0003039: 100%|██████████| 25/25 [00:56<00:00,  2.28s/it]



      Epoch       Loss      Score         Lr


        12/29      0.09085      0.01433    0.0001094: 100%|██████████| 25/25 [01:01<00:00,  2.44s/it]



      Epoch       Loss      Score         Lr


        13/29      0.07034      0.01364    3.513e-05: 100%|██████████| 25/25 [00:56<00:00,  2.26s/it]



      Epoch       Loss      Score         Lr


        14/29        0.159      0.01356     9.99e-06: 100%|██████████| 25/25 [00:55<00:00,  2.24s/it]



      Epoch       Loss      Score         Lr


        15/29      0.06462      0.01337      2.5e-06: 100%|██████████| 25/25 [00:56<00:00,  2.25s/it]



      Epoch       Loss      Score         Lr


        16/29      0.05442      0.01347      5.4e-07: 100%|██████████| 25/25 [00:56<00:00,  2.25s/it]



      Epoch       Loss      Score         Lr


        17/29       0.1919      0.01372        1e-07: 100%|██████████| 25/25 [00:59<00:00,  2.38s/it]



      Epoch       Loss      Score         Lr


        18/29       0.2086      0.01476        2e-08: 100%|██████████| 25/25 [00:57<00:00,  2.29s/it]



      Epoch       Loss      Score         Lr


        19/29      0.04966      0.01327            0: 100%|██████████| 25/25 [00:56<00:00,  2.26s/it]



      Epoch       Loss      Score         Lr


        20/29      0.08735      0.01307            0: 100%|██████████| 25/25 [00:56<00:00,  2.27s/it]



      Epoch       Loss      Score         Lr


        21/29        0.116      0.01419            0: 100%|██████████| 25/25 [00:56<00:00,  2.26s/it]



      Epoch       Loss      Score         Lr


        22/29      0.04922      0.01327            0:  48%|████▊     | 12/25 [00:27<00:30,  2.33s/it]


KeyboardInterrupt: ignored

# Đánh giá với dữ liệu test

In [30]:
model = UNet()
checkpoint = torch.load(ckpt_file)
model.load_state_dict(checkpoint['state_dict'])
model.to(device)
model.eval()

images = os.listdir(test_file)

for name in images:
    img = cv2.imread(os.path.join(test_file, name))
    img = cv2.resize(img, (160, 80), interpolation = cv2.INTER_LINEAR)

    img = img/255
    img = img.transpose((2, 0, 1))[::-1]  # HWC to CHW, BGR to RGB
    img = np.ascontiguousarray(img)  # contiguous

    img = torch.from_numpy(img)
    img = img.to(device).unsqueeze(0).float()

#     print(img.shape)

    with torch.no_grad():
        mask_pred = model(img)

    mask_pred = (torch.sigmoid(mask_pred) > 0.5).float()
    to_save = mask_pred.squeeze(0).squeeze(0)
    to_save = (to_save.cpu().numpy()*255).astype(np.uint8)

    save = f'train_val/{name}.jpg'
    cv2.imwrite(save, to_save)

tensor(0.9996, device='cuda:0')
