In [None]:
# Mount Google Drive at /content/drive; the dataset folders and checkpoints used
# below are addressed through this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import os
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import cv2
import glob
from torch.utils.data import DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
from tqdm import tqdm
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import statistics
from torch.utils.tensorboard import SummaryWriter
import time

In [None]:
# Load the TensorBoard notebook extension
%load_ext tensorboard
# NOTE(review): albumentations is already imported by the import cell above, so
# this pinned install presumably only takes effect after a runtime restart —
# confirm, and consider running the install before the imports.
!pip install albumentations==0.4.6

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting albumentations==0.4.6
  Downloading albumentations-0.4.6.tar.gz (117 kB)
[K     |████████████████████████████████| 117 kB 7.3 MB/s 
Collecting imgaug>=0.4.0
  Downloading imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
[K     |████████████████████████████████| 948 kB 51.1 MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.4.6-py3-none-any.whl size=65174 sha256=82603317883cfa4454a45ebcda2a3fd7107c74886714b849efac1142573b6a5d
  Stored in directory: /root/.cache/pip/wheels/cf/34/0f/cb2a5f93561a181a4bcc84847ad6aaceea8b5a3127469616cc
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Attempting uninstall: imgaug
    Found existing installation: imgaug 0.2.9
    Uninstalling imgaug-0.2.9:
      Successfully uni

In [None]:
def create_mask(img_h, img_w, label_path,LINE_HEIGHT_RATIO):
    """Rasterise the two-point ``table`` annotations of a label file.

    Each line of the label file is split into fields: on ``[`` when the line
    contains a bracket (every piece gets its ``[`` back), otherwise on single
    spaces. Only lines whose first field contains ``table`` and that have
    exactly three fields are considered. The two bracketed fields are parsed
    as integer lists; when the first list holds exactly two values, the two
    lists are drawn as the two vertices of a closed polyline.

    Args:
        img_h: Mask height in pixels.
        img_w: Mask width in pixels.
        label_path: Path of the text label file.
        LINE_HEIGHT_RATIO: Accepted for interface compatibility; not used.

    Returns:
        ``float64`` array of shape ``(img_h, img_w)``: zeros, with value 1 on
        the drawn polylines (thickness 6).
    """
    canvas = np.zeros((img_h, img_w))
    with open(label_path, 'r') as label_file:
        for raw_line in label_file.readlines():
            stripped = raw_line.strip()
            if '[' in raw_line:
                fields = ['[' + piece for piece in stripped.split('[')]
            else:
                fields = stripped.split(' ')

            # Only three-field "table" entries (type + two bracketed lists) are drawn.
            if 'table' not in fields[0]:
                continue
            if len(fields) != 3:
                continue

            first_body = fields[1].strip()[1:-1]
            second_body = fields[2].strip()[1:-1]
            sep = ', ' if ', ' in first_body else ' '
            first_values = np.fromstring(first_body, sep=sep).astype(np.int32)
            second_values = np.fromstring(second_body, sep=sep).astype(np.int32)

            if first_values.shape[0] != 2:
                continue

            # Row 0 is the first list, row 1 the second list: each list is one vertex.
            vertices = np.stack([first_values, second_values], axis=0)
            cv2.polylines(canvas, [vertices], True, 1, 6)

    return canvas

In [None]:
class SegmentDataset(Dataset):
    """Image / table-mask pairs for segmentation training.

    Images are read from ``image_dir``. The label file of an image is located
    by replacing ``"img"`` with ``"label"`` and ``"png"`` with ``"txt"`` in the
    full image path; ``mask_dir`` is stored but not used for the lookup.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Kept for interface compatibility (see above).
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    # Number of pixels trimmed from every spatial border of each sample.
    BORDER_CROP = 7

    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        # os.listdir order is arbitrary; sorting keeps index -> file stable across runs.
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    @classmethod
    def _crop_border(cls, image, mask):
        """Trim ``BORDER_CROP`` pixels from each spatial edge of image and mask.

        Tensors are treated as channel-first (C, H, W); anything else (the raw
        numpy image) as channel-last (H, W, C). The original code always used
        the channel-first slice, which cut the channel axis of untransformed
        numpy images down to nothing.
        """
        crop = cls.BORDER_CROP
        if torch.is_tensor(image):
            image = image[:, crop:-crop, crop:-crop]
        else:
            image = image[crop:-crop, crop:-crop]
        return image, mask[crop:-crop, crop:-crop]

    def __getitem__(self, index):
        img_path = os.path.join(self.image_dir, self.images[index])
        mask_path = img_path.replace("img", "label").replace("png", "txt")

        image = np.array(Image.open(img_path).convert("RGB"))
        img_h, img_w, _ = image.shape

        mask = create_mask(img_h=img_h,img_w=img_w,label_path=mask_path,LINE_HEIGHT_RATIO=0.5)

        if self.transform:
            augmentations = self.transform(image=image, mask=mask)
            image = augmentations['image']
            mask = augmentations['mask']

        return self._crop_border(image, mask)

In [None]:
#parameter data
pin_memory = True
batch_size = 1
num_workers = 2
image_height = 512
image_width = 512

# Every dataset folder lives under the Drive mount point created by
# drive.mount('/content/drive'). Several entries previously pointed at
# '/content/gdrive', which this notebook never mounts.
segmentation_root = "/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation"
train_img_dir = segmentation_root + "/img_train"
train_mask_dir = segmentation_root + "/label_train"
val_img_dir = segmentation_root + "/img_val"
val_mask_dir = segmentation_root + "/label_val"
test_mask_dir = segmentation_root + "/label_test"
test_img_dir = segmentation_root + "/img_test"


# Training pipeline: resize, random rotation (always applied, up to 90 degrees),
# random flips, two noise augmentations, then normalisation and tensor conversion.
# Resize and Rotate are both given nearest-neighbour interpolation.
_train_steps = [
    A.Resize(height=image_height, width=image_width, interpolation=cv2.INTER_NEAREST),
    A.Rotate(limit=90, p=1.0, interpolation=cv2.INTER_NEAREST),
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.IAAAdditiveGaussianNoise(
        loc=0,
        scale=(0.01, 0.05),
        per_channel=False,
        always_apply=False,
        p=0.5,
    ),
    A.GaussNoise(),
    A.Normalize(),
    ToTensorV2(),
]
train_transform = A.Compose(_train_steps)

# Validation pipeline: the deterministic subset of the training steps.
_val_steps = [
    A.Resize(height=image_height, width=image_width, interpolation=cv2.INTER_NEAREST),
    A.Normalize(),
    ToTensorV2(),
]
val_transform = A.Compose(_val_steps)

def get_loader(
            train_dir,
            train_maskdir,
            val_dir,
            val_maskdir,
            batch_size,
            train_transform=True,
            val_transform=True,
            num_workers=4,
            pin_memory=True):
    """Build the training and validation ``DataLoader`` objects.

    Args:
        train_dir: Directory with the training images.
        train_maskdir: Mask directory handed to the training dataset.
        val_dir: Directory with the validation images.
        val_maskdir: Mask directory handed to the validation dataset.
        batch_size: Batch size used by both loaders.
        train_transform: Callable applied to training samples. Non-callable
            values (including the legacy default ``True``) mean "no transform".
        val_transform: Same as ``train_transform`` for the validation split.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to both loaders.

    Returns:
        ``(train_loader, val_loader)``. Only the training loader shuffles;
        validation runs in a fixed order so per-index outputs are reproducible.
    """
    def _usable(transform):
        # The legacy default `True` is not callable and used to crash in
        # Dataset.__getitem__; treat it (and any non-callable) as "no transform".
        return transform if callable(transform) else None

    def _make_loader(image_dir, mask_dir, transform, shuffle):
        dataset = SegmentDataset(
            image_dir=image_dir,
            mask_dir=mask_dir,
            transform=_usable(transform),
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=pin_memory,
            shuffle=shuffle,
        )

    train_loader = _make_loader(train_dir, train_maskdir, train_transform, shuffle=True)
    val_loader = _make_loader(val_dir, val_maskdir, val_transform, shuffle=False)

    return train_loader,val_loader

# Instantiate both loaders from the configuration defined above.
train_loader, val_loader = get_loader(
    train_dir=train_img_dir,
    train_maskdir=train_mask_dir,
    val_dir=val_img_dir,
    val_maskdir=val_mask_dir,
    batch_size=batch_size,
    train_transform=train_transform,
    val_transform=val_transform,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [None]:
from torch.nn.modules.dropout import Dropout
class DoubleConv(nn.Module):
    """Two (3x3 conv without bias -> batch norm -> ReLU) stages, then dropout (p=0.25).

    The layers live in ``self.conv`` in a fixed order, so parameter names in
    saved state dicts (``conv.0``, ``conv.1``, ``conv.3``, ``conv.4``) stay stable.
    """

    def __init__(self,in_channels,out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for source_channels in (in_channels, out_channels):
            stages += [
                nn.Conv2d(source_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        stages.append(nn.Dropout(0.25))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stacked layers to ``x``."""
        return self.conv(x)


class UNet(nn.Module):
    """U-Net built from ``DoubleConv`` blocks.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the produced logit map.
        features: Channel widths of the encoder levels, shallowest first.
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super().__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per level, each feeding the next.
        level_input = in_channels
        for width in features:
            self.downs.append(DoubleConv(level_input, width))
            level_input = width

        # Decoder: alternating (transpose-conv upsample, DoubleConv fuse) pairs, deepest first.
        for width in features[::-1]:
            self.ups.append(nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2))
            self.ups.append(DoubleConv(width * 2, width))

        deepest = features[-1]
        self.bottleneck = DoubleConv(deepest, deepest * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Return the logit map for the input batch ``x``."""
        encoder_outputs = []
        for encoder in self.downs:
            x = encoder(x)
            encoder_outputs.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder pairs together with the skips, deepest skip first.
        for stage, skip in enumerate(reversed(encoder_outputs)):
            upsample = self.ups[2 * stage]
            fuse = self.ups[2 * stage + 1]
            x = upsample(x)

            # Resize when the upsampled map does not match the skip tensor.
            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])

            x = fuse(torch.cat((skip, x), dim=1))

        return self.final_conv(x)

In [None]:
def save_checkpoint(state, filename="/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/my_checkpoint_table_10.pth"):
    """Announce the save on stdout, then serialise ``state`` to ``filename`` with ``torch.save``."""
    print("=>Saving checkpoint")
    torch.save(obj=state, f=filename)


def load_checkpoint(checkpoint, model):
    """Announce the load on stdout, then copy ``checkpoint['state_dict']`` into ``model``."""
    print("=>Loading checkpoint")
    weights = checkpoint['state_dict']
    model.load_state_dict(weights)




def save_predictions_as_imgs(loader, model, folder="/content/gdrive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/save_img_train", device="cuda"):
    """Write thresholded predictions and their ground-truth masks as PNG files.

    For batch ``idx`` of ``loader`` two files are written inside ``folder``:
    ``pred_<idx>.png`` (sigmoid output thresholded at 0.5) and ``<idx>.png``
    (the target mask). The model is evaluated in eval mode and its previous
    train/eval mode is restored afterwards.

    Args:
        loader: Iterable yielding ``(image_batch, mask_batch)`` pairs.
        model: Network returning logits for an image batch.
        folder: Output directory; created when missing.
        device: Device the image batch is moved to before the forward pass.
    """
    os.makedirs(folder, exist_ok=True)
    was_training = model.training
    model.eval()
    try:
        for idx, (x, y) in enumerate(loader):
            x = x.to(device=device)

            with torch.no_grad():
                preds = torch.sigmoid(model(x))
                preds = (preds > 0.5).float()
            torchvision.utils.save_image(
                preds, os.path.join(folder, f"pred_{idx}.png")
            )
            # The target used to be written to f"{folder}{idx}.png": without a
            # separator the file landed next to `folder` instead of inside it.
            torchvision.utils.save_image(
                y.unsqueeze(1), os.path.join(folder, f"{idx}.png")
            )
    finally:
        # Do not leave the caller's model silently stuck in eval mode.
        model.train(was_training)


In [None]:
# Optimisation hyper-parameters.
num_epochs = 100
learning_rate = 3e-4
# When True, the training cell restores the saved checkpoint before training.
load_model = True

# Run on the GPU when one is available, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single-channel logit output paired with a logits-based binary cross-entropy loss.
model = UNet(in_channels=3,out_channels=1).to(device)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


In [None]:
def train(loader, model, optimizer, loss_fn, epoch=None, log_every=100):
    """Run one training epoch and log windowed metrics to TensorBoard.

    Fixes over the previous version:
      * the model is put into training mode first (an earlier evaluation call
        may have left it in eval mode, disabling dropout and batch-norm updates);
      * accuracy is the fraction of pixels whose thresholded prediction matches
        the target, instead of comparing raw logits with the targets;
      * the loss is averaged over the batches actually accumulated;
      * training values are no longer logged under "val" tags;
      * the summary writer is closed when the epoch ends.

    Args:
        loader: DataLoader yielding ``(image_batch, mask_batch)``.
        model: Network returning logits.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Loss called as ``loss_fn(logits, targets)``.
        epoch: Zero-based epoch index used for the TensorBoard step. When
            omitted, the notebook-level ``epoch`` loop variable is used
            (0 if it does not exist), matching the previous behaviour.
        log_every: Number of batches per logging window.

    Returns:
        Mean training loss over the epoch (0.0 for an empty loader).
    """
    if epoch is None:
        # Backward compatibility: callers used to rely on the global loop variable.
        epoch = globals().get("epoch", 0)

    # Move data to wherever the model lives rather than relying on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    writer = SummaryWriter('runs/seg_table')
    loop = tqdm(loader)
    steps_per_epoch = len(loader)

    epoch_loss = 0.0
    batches_seen = 0
    window_loss = 0.0
    window_correct = 0
    window_pixels = 0
    window_batches = 0
    global_step = epoch * steps_per_epoch

    def _flush_window():
        # Write the averages of the current window at the latest global step.
        writer.add_scalar('training loss', window_loss / window_batches, global_step)
        writer.add_scalar('training accuracy', window_correct / window_pixels, global_step)

    try:
        for batch_index, (data, targets) in enumerate(loop):
            data = data.to(device=run_device)
            targets = targets.float().unsqueeze(1).to(device=run_device)

            # forward
            predictions = model(data)
            loss = loss_fn(predictions, targets)

            # backward
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            loop.set_postfix(loss=batch_loss)

            epoch_loss += batch_loss
            batches_seen += 1
            window_loss += batch_loss
            with torch.no_grad():
                # sigmoid(logit) > 0.5 is equivalent to logit > 0.
                predicted_mask = predictions > 0
                window_correct += (predicted_mask == (targets > 0.5)).sum().item()
            window_pixels += targets.numel()
            window_batches += 1
            global_step = epoch * steps_per_epoch + batch_index

            if window_batches == log_every:
                _flush_window()
                window_loss, window_correct, window_pixels, window_batches = 0.0, 0, 0, 0

        # Log whatever is left from a partially filled last window.
        if window_batches:
            _flush_window()
    finally:
        writer.close()

    return epoch_loss / batches_seen if batches_seen else 0.0
        


In [None]:
def dice_coeff(loader,model,device):
    """Print and return the mean per-batch Dice score of ``model`` on ``loader``.

    Predictions are the sigmoid of the model output thresholded at 0.5. The
    model is evaluated in eval mode and its previous train/eval mode is
    restored afterwards, so a following training epoch is not silently run
    with dropout and batch-norm updates disabled.

    Args:
        loader: Sized iterable yielding ``(image_batch, mask_batch)``.
        model: Network returning logits.
        device: Device the image batch is moved to.

    Returns:
        Mean Dice score as a float; ``nan`` when ``loader`` is empty.
    """
    smooth = 1e-8  # keeps the ratio defined when prediction and target are both empty
    dice_score = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for data, target in loader:
                data = data.to(device=device)
                # The score is computed with numpy on the CPU, so the target never needs the device.
                target = target.float().unsqueeze(1).cpu().numpy()
                pred = torch.sigmoid(model(data))
                pred = (pred > 0.5).float().cpu().numpy()

                intersection = (pred * target).sum()
                dice_score += ((2. * intersection) + smooth) / (pred.sum() + target.sum() + smooth)
    finally:
        model.train(was_training)

    num_batches = len(loader)
    mean_dice = float(dice_score / num_batches) if num_batches else float("nan")
    print(f"Dice score:{mean_dice}")
    return mean_dice

In [None]:
# Restore the stored weights (mapped to CPU on load) and report the validation Dice score.
stored_checkpoint = torch.load(
    "/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/my_checkpoint_table_10.pth",
    map_location=torch.device('cpu'),
)
load_checkpoint(stored_checkpoint, model)
dice_coeff(val_loader, model, device)

=>Loading checkpoint
Dice score:0.7924194109598858


In [None]:
# Optionally resume from the stored checkpoint (weights are mapped to CPU on load).
if load_model:
    resume_state = torch.load(
        "/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/my_checkpoint_table_10.pth",
        map_location=torch.device('cpu'),
    )
    load_checkpoint(resume_state, model)

# `epoch` has to stay a notebook-level name: train() reads it for the TensorBoard step.
for epoch in range(num_epochs):
    train(train_loader, model, optimizer, loss_fn)

    # Persist model and optimizer state after every epoch.
    checkpoint = dict(
        state_dict=model.state_dict(),
        optimizer=optimizer.state_dict(),
    )
    save_checkpoint(checkpoint)

    # Validation Dice score for this epoch.
    dice_coeff(val_loader, model, device)

# Export predictions for the validation set once training has finished.
save_predictions_as_imgs(
    val_loader,
    model,
    folder="/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/save_img_train",
    device=device,
)
%tensorboard --logdir=runs   

=>Loading checkpoint


100%|██████████| 187/187 [00:48<00:00,  3.89it/s, loss=0.0355]


=>Saving checkpoint
Dice score:0.8000049999440781


100%|██████████| 187/187 [00:49<00:00,  3.81it/s, loss=0.0448]


=>Saving checkpoint
Dice score:0.8013131348367308


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0205]


=>Saving checkpoint
Dice score:0.7969011772904584


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0625]


=>Saving checkpoint
Dice score:0.794164731119331


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0242]


=>Saving checkpoint
Dice score:0.8010096407971105


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0434]


=>Saving checkpoint
Dice score:0.7883194869918528


100%|██████████| 187/187 [00:49<00:00,  3.80it/s, loss=0.087]


=>Saving checkpoint
Dice score:0.8038737127211016


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0624]


=>Saving checkpoint
Dice score:0.7950514140888955


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0915]


=>Saving checkpoint
Dice score:0.8027347667522537


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.11]


=>Saving checkpoint
Dice score:0.8010353422564652


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0666]


=>Saving checkpoint
Dice score:0.8025442981531041


100%|██████████| 187/187 [00:48<00:00,  3.85it/s, loss=0.17]


=>Saving checkpoint
Dice score:0.801017244480034


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.0374]


=>Saving checkpoint
Dice score:0.7991162311774266


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.13]


=>Saving checkpoint
Dice score:0.8024167678782038


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.036]


=>Saving checkpoint
Dice score:0.801854897218121


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.116]


=>Saving checkpoint
Dice score:0.7938427091484811


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0519]


=>Saving checkpoint
Dice score:0.7976658558743859


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0278]


=>Saving checkpoint
Dice score:0.801846540669589


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.0283]


=>Saving checkpoint
Dice score:0.7997113134290726


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0725]


=>Saving checkpoint
Dice score:0.8055929343870711


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0849]


=>Saving checkpoint
Dice score:0.8030548957466825


100%|██████████| 187/187 [00:49<00:00,  3.81it/s, loss=0.0677]


=>Saving checkpoint
Dice score:0.7949141346728217


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0369]


=>Saving checkpoint
Dice score:0.7954821091005869


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0583]


=>Saving checkpoint
Dice score:0.8009775818087149


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.124]


=>Saving checkpoint
Dice score:0.796781910725272


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.1]


=>Saving checkpoint
Dice score:0.78574828683323


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0526]


=>Saving checkpoint
Dice score:0.8028519228388159


100%|██████████| 187/187 [00:49<00:00,  3.81it/s, loss=0.0507]


=>Saving checkpoint
Dice score:0.7960869836195943


100%|██████████| 187/187 [00:48<00:00,  3.85it/s, loss=0.124]


=>Saving checkpoint
Dice score:0.8009866712546597


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0413]


=>Saving checkpoint
Dice score:0.7953844292436834


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0464]


=>Saving checkpoint
Dice score:0.7981496134001052


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.0692]


=>Saving checkpoint
Dice score:0.8075429961928754


100%|██████████| 187/187 [00:48<00:00,  3.85it/s, loss=0.127]


=>Saving checkpoint
Dice score:0.8018409705937166


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0834]


=>Saving checkpoint
Dice score:0.8015778897033216


100%|██████████| 187/187 [00:48<00:00,  3.84it/s, loss=0.0338]


=>Saving checkpoint
Dice score:0.7904494062659655


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0495]


=>Saving checkpoint
Dice score:0.7926253516182434


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0614]


=>Saving checkpoint
Dice score:0.7971446356828397


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0456]


=>Saving checkpoint
Dice score:0.7931817756404234


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.105]


=>Saving checkpoint
Dice score:0.7949963982603527


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.048]


=>Saving checkpoint
Dice score:0.7977775619987478


100%|██████████| 187/187 [00:49<00:00,  3.81it/s, loss=0.0832]


=>Saving checkpoint
Dice score:0.7993173771010875


100%|██████████| 187/187 [00:48<00:00,  3.83it/s, loss=0.0436]


=>Saving checkpoint
Dice score:0.7944774562698846


100%|██████████| 187/187 [00:48<00:00,  3.82it/s, loss=0.0559]


=>Saving checkpoint
Dice score:0.7977960668702428


 25%|██▌       | 47/187 [00:12<00:36,  3.85it/s, loss=0.106]

In [None]:
import logging
import os
from typing import Tuple

import cv2
import gdown
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """Two consecutive (3x3 conv -> batch norm -> ReLU) stages.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the second stage.
        mid_channels: Channels between the two stages; falls back to
            ``out_channels`` when omitted (or otherwise falsy).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        first_stage = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x):
        """Apply both stages to ``x``."""
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DoubleConv``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pooling = nn.MaxPool2d(2)
        convolutions = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pooling, convolutions)

    def forward(self, x):
        """Pool ``x`` and run the double convolution on the result."""
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample, pad to the skip tensor's size, concatenate, double conv.

    Args:
        in_channels: Channels of the concatenated (skip + upsampled) tensor.
        out_channels: Channels produced by the step.
        bilinear: When true, upsample by bilinear interpolation and let the
            convolutions reduce the channel count; otherwise upsample with a
            transposed convolution that halves the channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half_channels = in_channels // 2

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half_channels)
        else:
            self.up = nn.ConvTranspose2d(
                in_channels, half_channels, kernel_size=2, stride=2
            )
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2`` and fuse both.

        Tensors are indexed as (N, C, H, W): the size gaps are taken on
        dimensions 2 and 3.
        """
        upsampled = self.up(x1)

        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        left, top = gap_w // 2, gap_h // 2

        # Split each gap between the two sides; the remainder goes right/bottom.
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    """Output head: a single 1x1 convolution from ``in_channels`` to ``out_channels``."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Project ``x`` with the 1x1 convolution."""
        projected = self.conv(x)
        return projected

In [None]:
from torch import nn
from torch.nn.modules.batchnorm import BatchNorm2d

def conv(ni, nf, ks=3, stride=1, act=True, bn=True):
    """Build a bias-free convolution, optionally followed by ReLU and batch norm.

    Args:
        ni: Input channels.
        nf: Output channels.
        ks: Kernel size; padding is ``(ks - 1) // 2``.
        stride: Convolution stride.
        act: Append an in-place ReLU when truthy.
        bn: Append ``BatchNorm2d(nf)`` when truthy (after the ReLU, if any).

    Returns:
        ``nn.Sequential`` holding the requested layers in that order.
    """
    convolution = nn.Conv2d(
        ni, nf, kernel_size=ks, padding=(ks - 1) // 2, stride=stride, bias=False
    )
    extras = []
    if act:
        extras.append(nn.ReLU(inplace=True))
    if bn:
        extras.append(BatchNorm2d(nf))
    return nn.Sequential(convolution, *extras)


def conv_block(ni, nf, stride):
    """Bottleneck stack: 1x1 reduce to ``nf // 4``, 3x3 at ``stride``, 1x1 expand to ``nf``.

    The final 1x1 layer has batch norm but no activation.
    """
    squeezed = nf // 4
    reduce = conv(ni, squeezed, ks=1)
    spatial = conv(squeezed, squeezed, stride=stride)
    expand = conv(squeezed, nf, ks=1, bn=True, act=False)
    return nn.Sequential(reduce, spatial, expand)


def _resnet_stem(*sizes):
    """Chain 3x3 ``conv`` layers through consecutive channel counts in ``sizes``.

    Only the first layer uses stride 2; all following layers use stride 1.
    """
    layers = []
    for position, (channels_in, channels_out) in enumerate(zip(sizes, sizes[1:])):
        layer_stride = 2 if position == 0 else 1
        layers.append(conv(channels_in, channels_out, ks=3, stride=layer_stride))
    return nn.Sequential(*layers)


def noop(x):
    """Identity function: hand back the argument unchanged."""
    return x


def block(ni, nf, idx, stride=2, nblocks=2):
    """Stack ``nblocks`` ``ResBlock`` modules.

    The first block maps ``ni`` to ``nf`` channels with the given ``stride``;
    the remaining blocks keep ``nf`` channels at stride 1. ``idx`` is accepted
    but not used.
    """
    stages = []
    for position in range(nblocks):
        is_first = position == 0
        stages.append(
            ResBlock(
                ni if is_first else nf,
                nf,
                stride=stride if is_first else 1,
            )
        )
    return nn.Sequential(*stages)


class ResBlock(nn.Module):
    """Residual block: bottleneck convolutions plus a shortcut branch.

    The shortcut average-pools the input when ``stride`` is not 1 and projects
    it with a 1x1 ``conv`` (no activation) when ``ni`` differs from ``nf``;
    otherwise the corresponding step is the identity.
    """

    def __init__(self, ni, nf, stride=1):
        super().__init__()
        self.convs = conv_block(ni, nf, stride)

        if ni == nf:
            self.idconv = noop
        else:
            self.idconv = conv(ni, nf, 1, act=None)

        if stride == 1:
            self.pool = noop
        else:
            self.pool = nn.AvgPool2d(2, ceil_mode=True)

    def forward(self, x):
        """Return the sum of the convolution branch and the shortcut branch."""
        main_branch = self.convs(x)
        shortcut = self.idconv(self.pool(x))
        return main_branch + shortcut


class ResUnet(nn.Module):
    """U-Net style decoder on top of a residual encoder.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
        bilinear: Forwarded to every ``Up`` step.
    """

    def __init__(self, n_classes, bilinear=True):
        super().__init__()
        self.n_classes = n_classes
        self.bilinear = bilinear

        # Encoder: strided stem, max pooling, then four residual stages.
        self.stem = _resnet_stem(3, 32, 32, 64)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.block1 = block(64, 64, 0, stride=1, nblocks=3)
        self.block2 = block(64, 128, 0, stride=2, nblocks=4)
        self.block3 = block(128, 256, 0, stride=2, nblocks=6)
        self.block4 = block(256, 512, 0, stride=2, nblocks=3)

        # Decoder: each Up receives (previous decoder channels + skip channels).
        self.up6 = Up(512 + 256, 256, bilinear)
        self.up5 = Up(256 + 128, 128, bilinear)
        self.up4 = Up(128 + 64, 64, bilinear)
        self.up3 = Up(64 + 64, 64, bilinear)
        self.up2 = Up(64 + 64, 32, bilinear)
        self.up1 = Up(32 + 3, 32, bilinear)

        self.outconv = OutConv(32, n_classes)

    def forward(self, x):
        """Return the logit map for the image batch ``x``."""
        stem_out = self.stem(x)          # 64 channels
        pooled = self.pool(stem_out)

        stage1 = self.block1(pooled)     # 64 channels
        stage2 = self.block2(stage1)     # 128 channels
        stage3 = self.block3(stage2)     # 256 channels
        stage4 = self.block4(stage3)     # 512 channels

        # Decode from the deepest stage back up, ending with the raw input as last skip.
        decoded = self.up6(stage4, stage3)
        decoded = self.up5(decoded, stage2)
        decoded = self.up4(decoded, stage1)
        decoded = self.up3(decoded, pooled)
        decoded = self.up2(decoded, stem_out)
        decoded = self.up1(decoded, x)

        return self.outconv(decoded)

In [None]:
def load_model_unet(weight_path: str, device: torch.device) -> torch.nn.Module:
    """Create a single-class ``ResUnet`` on ``device`` and load its weights.

    Args:
        weight_path (str): file holding a state dict accepted by
            ``ResUnet.load_state_dict``
        device (torch.device): device for the model and for mapping the
            loaded tensors
    Returns:
        torch.nn.Module: the ``ResUnet`` with the loaded weights
    """
    net = ResUnet(n_classes=1).to(device=device)
    pretrained_state = torch.load(weight_path, map_location=device)
    net.load_state_dict(pretrained_state)
    return net

In [None]:
def predict(
    img: np.ndarray,
    scale_factor: float = 0.25,
    out_threshold: float = 0.5,
    mask_path: str=None,
) -> np.ndarray:
    """Segment a table image and optionally score the result against a label file.

    The notebook-level ``model`` and ``device`` are used for inference.
    Debug images are written to the working directory: ``img.jpg`` (input),
    ``test.jpg`` (predicted mask) and, when a label file is given,
    ``test1.jpg`` (ground-truth mask).

    Args:
        img (np.array): table image of shape (H, W, C)
        scale_factor (float, optional): factor for downscaling the image
            before inference. Defaults to 0.25.
        out_threshold (float, optional): confidence threshold. Defaults to 0.5.
        mask_path (str, optional): label file used to compute the Dice score.
            When omitted, the Dice evaluation is skipped (previously this
            case crashed while opening the label file).
    Returns:
        numpy.ndarray: uint8 mask resized back to the size of the (padded)
        input image, 1 where the thresholded prediction is positive.
        Previously nothing was returned despite the annotation.
    """
    # Border removed from both masks before the debug dump and the Dice score.
    eval_border = 50

    h, w, _ = img.shape
    model.eval()
    padding_pil_img, preprocessed_img, pad = _preprocess(
        img=img,
        scale=scale_factor,
    )
    print(padding_pil_img.size, preprocessed_img.shape, pad)
    padding_pil_img.save('img.jpg')
    ts_img = torch.from_numpy(preprocessed_img)
    ts_img = ts_img.unsqueeze(0)
    ts_img = ts_img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = model(ts_img)
        probs = torch.sigmoid(output)
        probs = probs.squeeze(0)

    # Round-trip through a PIL image, as before, so thresholding sees the same values.
    tf = transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
        ]
    )
    probs = tf(probs.cpu())
    full_mask = probs.squeeze().cpu().numpy()
    full_size_mask = _normalize(padding_pil_img, mask_img=full_mask > out_threshold)

    mask = np.array(full_size_mask[eval_border:-eval_border, eval_border:-eval_border])
    cv2.imwrite('test.jpg', mask * 255)

    if mask_path is not None:
        mask_truth_new = create_mask(img_h=h,img_w=w,label_path=mask_path,LINE_HEIGHT_RATIO=0.5)
        mask_truth = _normalize(padding_pil_img, mask_img=mask_truth_new)
        mask_truth = np.array(mask_truth[eval_border:-eval_border, eval_border:-eval_border])
        cv2.imwrite('test1.jpg', mask_truth * 255)

        smooth = 1e-8
        intersection = (mask* mask_truth).sum()
        dice_score = ((2.*intersection)+ smooth)/(mask.sum() + mask_truth.sum() + smooth)
        print(dice_score)

    return full_size_mask

def _preprocess(
    img: np.ndarray,
    scale: float,
    pad: int = 0,
):
    """Pad a table image with a white border, then build the resized network input.

    Args:
        img (np.array): table image of shape (H, W, 3)
        scale (float): scale factor applied to the padded image
        pad (int, optional): white border width in pixels. Defaults to 0.
    Returns:
        PIL.Image.Image: the padded image at its original resolution, kept
            for size-recovering purposes
        numpy.array: resized image in channel-first layout, divided by 255
            when its maximum exceeds 1
        int: the pad value, for size-recovering purposes
    """
    src_h, src_w, _ = img.shape
    assert pad >= 0, "Pad must great than 0"

    # White canvas with the source image placed inside the border.
    padded_h, padded_w = src_h + pad * 2, src_w + pad * 2
    canvas = np.full((padded_h, padded_w, 3), 255, dtype=np.uint8)
    canvas[pad : pad + src_h, pad : pad + src_w, :] = img
    padded_pil = Image.fromarray(canvas)

    # Downscale (or upscale) by the requested factor.
    target_w, target_h = int(scale * padded_w), int(scale * padded_h)
    assert target_w > 0 and target_h > 0, "Scale is too small"
    resized = np.array(padded_pil.resize((target_w, target_h)))

    # HWC -> CHW, with 8-bit values brought into [0, 1].
    channel_first = resized.transpose((2, 0, 1))
    if channel_first.max() > 1:
        channel_first = channel_first / 255
    return padded_pil, channel_first, pad

def _normalize(img: Image.Image, mask_img: np.ndarray) -> np.ndarray:
    """Resize a mask to the height and width of ``img``.

    Args:
        img (PIL.Image.Image): reference image; only its height and width
            are used
        mask_img (numpy.ndarray): 2-D mask (H1, W1); it is cast to uint8
            before resizing
    Returns:
        numpy.ndarray: uint8 mask of shape (H, W), where H and W are the
        height and width of ``img``
    """
    mask = np.asarray(mask_img)
    img_h, img_w = np.asarray(img).shape[:2]
    mask = mask.reshape(mask.shape[0], mask.shape[1]).astype(np.uint8)

    # The flag must be passed by keyword: the third positional parameter of
    # cv2.resize is `dst`, so the positional form ignored INTER_AREA entirely.
    return cv2.resize(mask, (img_w, img_h), interpolation=cv2.INTER_AREA)

In [None]:
# Load the pretrained ResUnet. Fall back to the CPU when no GPU is visible so
# the cell does not fail on CPU-only runtimes.
model = load_model_unet(
    weight_path="resnet_weight.pth",
    device="cuda" if torch.cuda.is_available() else "cpu",
)

In [None]:
# Pretrained ResUnet weights on Drive. The notebook mounts Drive at
# /content/drive, so the path uses that mount point rather than /content/gdrive.
MODEL_PATH = "/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/resnet_weight.pth"
# Download location of the same weight file.
WEIGHT_URL = "https://drive.google.com/u/0/uc?id=18YEiAzUs9NXz0FwBuU0JicEWc_F2V7tq"

In [None]:
# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
learning_rate = 3e-4
num_epochs = 1

# Checkpoint of the ResUnet fine-tuning run. The path uses the /content/drive
# mount point created at the top of the notebook (it pointed at /content/gdrive).
resnet_checkpoint_path = "/content/drive/MyDrive/Báo cáo thực tập - Nguyễn Hữu Khải/Segmentation/my_checkpoint_table_repo_10.pth"

# Map the stored tensors onto the active device instead of a hard-coded 'cuda'.
load_checkpoint(torch.load(resnet_checkpoint_path, map_location=device), model)
loss_fn = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
    train(train_loader, model, optimizer, loss_fn)
    dice_coeff(val_loader, model, device)
    checkpoint={
        "state_dict": model.state_dict(),
    }
    # Save back to the ResUnet checkpoint: the default filename of
    # save_checkpoint is the UNet checkpoint, which would otherwise be
    # overwritten with incompatible ResUnet weights.
    save_checkpoint(checkpoint, filename=resnet_checkpoint_path)

In [None]:
import glob

# Run the prediction on the first training image returned by glob only.
training_images = glob.glob("img_train/*.png")
if training_images:
    img_path = training_images[0]
    print(img_path)
    image = np.array(cv2.imread(img_path))
    cv2.imwrite('test_img.jpg', image)
    # The label file mirrors the image path: "img" -> "label", "png" -> "txt".
    label_path = img_path.replace('img', 'label').replace('png','txt')
    predict(img=image,mask_path=label_path)
