<a href="https://colab.research.google.com/github/honghanhh/icdar_2024_SAM/blob/main/L3i%2B%2BFewShotLayoutSegmentation_Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
import cv2

import glob
from skimage.filters import threshold_sauvola

In [2]:
# Mount Google Drive at /content/drive/ using the Colab runtime helper.
# NOTE(review): none of the cells shown in this notebook read from the
# mount point -- confirm whether this cell is still required.
from google.colab import drive
drive.mount('/content/drive/')

Drive already mounted at /content/drive/; to attempt to forcibly remount, call drive.mount("/content/drive/", force_remount=True).


In [3]:
# Remove any previous checkout and clone the project repository again.
# The dataset, checkpoint and metric-script paths used further down all start
# with /content/icdar_2024_SAM, so this presumably runs with /content as the
# working directory -- confirm when running outside Colab.
!rm -rf icdar_2024_SAM
!git clone https://github.com/honghanhh/icdar_2024_SAM.git

Cloning into 'icdar_2024_SAM'...
remote: Enumerating objects: 261, done.[K
remote: Counting objects: 100% (19/19), done.[K
remote: Compressing objects: 100% (13/13), done.[K
remote: Total 261 (delta 6), reused 13 (delta 4), pack-reused 242[K
Receiving objects: 100% (261/261), 44.64 MiB | 13.79 MiB/s, done.
Resolving deltas: 100% (26/26), done.


# Utility functions for processing the data

In [4]:
def convertRGB_to_label(image):
    """
    Convert an RGB colour-coded mask into a 2-D array of class indices.

    The class index of a colour is its position in the colour table below:
    0 Background, 1 Paratext, 2 Decoration, 3 Main Text, 4 Title,
    5 Chapter Headings. Pixels whose colour is not in the table keep the
    initial value 0.

    Args:
        image: array-like in RGB channel order with shape (H, W) or
            (H, W, C). A 2-D input is treated as grey, i.e. its value is
            repeated on three channels. Channels beyond the first three
            (for example an alpha channel) are ignored.

    Returns:
        np.ndarray of shape (H, W) and dtype int8 holding the class indices.
    """
    # Define RGB color values; insertion order defines the class index
    colors = {
        (0, 0, 0): "Background",
        (255, 255, 0): "Paratext",
        (0, 255, 255): "Decoration",
        (255, 0, 255): "Main Text",
        (255, 0, 0): "Title",
        (0, 255, 0): "Chapter Headings"
    }

    # Convert image to numpy array if it's not already
    image = np.asarray(image)

    # Convert image to 3D if it's grayscale
    if image.ndim == 2:
        image = np.stack((image,) * 3, axis=-1)

    # Drop alpha / extra channels so the comparison with 3-component
    # colours below broadcasts correctly
    image = image[..., :3]

    # Initialize labels array with the spatial shape of the input image
    labels = np.zeros(image.shape[:2], dtype=np.int8)

    # Assign labels based on color
    for class_index, color in enumerate(colors):
        labels[np.all(image == np.array(color), axis=-1)] = class_index

    return labels

def convertLabel_to_RGB(labels):
    """
    Render a 2-D array of class indices as an RGB colour image.

    Args:
        labels: 2-D numpy array of class indices (0..5).

    Returns:
        np.ndarray of shape (H, W, 3) and dtype uint8. Indices outside 0..5
        are left black.
    """
    # Position in this tuple is the class index
    palette = (
        (0, 0, 0),        # Background
        (255, 255, 0),    # Paratext
        (0, 255, 255),    # Decoration
        (255, 0, 255),    # Main Text
        (255, 0, 0),      # Title
        (0, 255, 0),      # Chapter Headings
    )

    height, width = labels.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint every class with its colour
    for class_index, rgb in enumerate(palette):
        canvas[labels == class_index] = rgb

    return canvas


# Dataset

In [5]:
class UDIADS_Validation(Dataset):
    """
    Dataset for simple Evaluation and Testing.

    Images and masks are read from disk on every access. Item ``idx`` pairs
    ``imagePaths[idx]`` with ``maskPaths[idx]``, so both sequences must be
    in the same order.

    Each item is a tuple ``(img, mask, (h, w))``:
        img:  tensor of shape (3, H, W) with values min-max scaled to [-1, 1]
        mask: long tensor of shape (H, W) with class indices
        (h, w): height and width of the image as read from disk
    """
    def __init__(
            self,
            imagePaths,
            maskPaths,

    ):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths

    def __getitem__(self, idx):
        """
        Load, normalise and return the sample at position ``idx``.

        Raises:
            FileNotFoundError: if the image or the mask cannot be read.
        """
        img_path = self.imagePaths[idx]
        mask_path = self.maskPaths[idx]

        # cv2.imread returns None instead of raising on unreadable paths
        bgr_img = cv2.imread(img_path)
        if bgr_img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        read_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        h, w = read_img.shape[0], read_img.shape[1]

        # Min-max scale to [-1, 1]; a constant image would divide by zero,
        # so map it to the lower bound instead of producing NaNs
        lo, hi = read_img.min(), read_img.max()
        if hi > lo:
            img = 2*((read_img - lo) / (hi - lo)) - 1
        else:
            img = np.full(read_img.shape, -1.0)

        bgr_mask = cv2.imread(mask_path)
        if bgr_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")
        mask = convertRGB_to_label(cv2.cvtColor(bgr_mask, cv2.COLOR_BGR2RGB))

        # To tensor: HWC float array -> CHW tensor
        img = transforms.ToTensor()(img)
        mask = torch.from_numpy(mask).long()

        return img, mask, (h, w)

    def __len__(self):
        """Number of samples, taken from the image path list."""
        return len(self.imagePaths)

# Model

In [6]:
def dil_block(in_c, out_c):
    """
    Build four consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    The first two convolutions use dilation 1 and the last two dilation 2.
    Padding always equals the dilation, so height and width are unchanged.
    Only the first convolution maps ``in_c`` to ``out_c``; the rest keep
    ``out_c`` channels.
    """
    stages = []
    channels = in_c
    for rate in (1, 1, 2, 2):
        stages.append(
            nn.Conv2d(channels, out_c, kernel_size=3, stride=1, padding=rate, dilation=rate)
        )
        stages.append(nn.BatchNorm2d(out_c))
        stages.append(nn.ReLU(inplace=True))
        channels = out_c
    return nn.Sequential(*stages)


def encoding_block(in_c, out_c):
    """
    Build a single 3x3 Conv -> BatchNorm -> ReLU stage mapping ``in_c``
    channels to ``out_c`` channels without changing height or width.
    """
    stage = [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*stage)

def encoding_block1(in_c, out_c):
    """
    Build two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    The first convolution maps ``in_c`` to ``out_c`` channels and the second
    keeps ``out_c``. Height and width are unchanged.
    """
    stages = []
    for channels in (in_c, out_c):
        stages.extend([
            nn.Conv2d(channels, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ])
    return nn.Sequential(*stages)

class unet_model(nn.Module):
    """
    Small U-Net with dilated-convolution encoder blocks.

    The encoder has four ``dil_block`` + 2x2 max-pool stages, followed by a
    bottleneck, and the decoder has four transposed-conv upsampling stages
    with skip connections to the matching encoder output.

    Args:
        out_channels: number of channels produced by the final 1x1 conv.
        features: ``features[0]`` is the width of every encoder/decoder
            stage; ``features[1]`` is the channel count after concatenating
            an upsampled tensor with its skip connection, so it must equal
            ``2 * features[0]``.

    Layer attribute names and their creation order are kept stable so that
    existing ``state_dict`` checkpoints continue to load.
    """

    # Four 2x2 poolings: the padded input size must be a multiple of 2**4
    _SIZE_MULTIPLE = 16

    def __init__(self,out_channels=4,features=(16, 32)):
        super(unet_model,self).__init__()

        # Encoder
        self.dil1 = dil_block(3,features[0])
        self.pool1 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.dil2 = dil_block(features[0],features[0])
        self.pool2 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.dil3 = dil_block(features[0],features[0])
        self.pool3 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))
        self.dil4 = dil_block(features[0],features[0])
        self.pool4 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))

        # Bottleneck
        self.bott = encoding_block1(features[0],features[0])

        # Decoder: upsample, concatenate with skip, fuse
        self.tconv1 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv1 = encoding_block(features[1],features[0])
        self.tconv2 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv2 = encoding_block(features[1],features[0])
        self.tconv3 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv3 = encoding_block(features[1],features[0])
        self.tconv4 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv4 = encoding_block1(features[1],features[0])

        # Per-pixel classification head
        self.final_layer = nn.Conv2d(features[0],out_channels, kernel_size=1)

    def forward(self,x):
        """
        Compute per-pixel class scores.

        Args:
            x: tensor of shape (N, 3, H, W).

        Returns:
            Tensor of shape (N, out_channels, H, W).

        When H or W is not a multiple of 16, the input is padded on the
        bottom/right by edge replication and the output is cropped back to
        (H, W). Without this, the upsampled tensors would not match their
        skip connections and ``torch.cat`` would fail.
        """
        in_h, in_w = x.shape[-2], x.shape[-1]
        pad_h = (-in_h) % self._SIZE_MULTIPLE
        pad_w = (-in_w) % self._SIZE_MULTIPLE
        padded = bool(pad_h or pad_w)
        if padded:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        # Encoder path; the dil_* outputs are kept for the skip connections
        dil_1 = self.dil1(x)
        pool_1 = self.pool1(dil_1)
        dil_2 = self.dil2(pool_1)
        pool_2 = self.pool2(dil_2)
        dil_3 = self.dil3(pool_2)
        pool_3 = self.pool3(dil_3)
        dil_4 = self.dil4(pool_3)
        pool_4 = self.pool4(dil_4)

        bott = self.bott(pool_4)

        # Decoder path, deepest skip connection first
        tconv_1 = self.tconv1(bott)
        concat1 = torch.cat((tconv_1, dil_4), dim=1)
        conv_1 = self.conv1(concat1)

        tconv_2 = self.tconv2(conv_1)
        concat2 = torch.cat((tconv_2, dil_3), dim=1)
        conv_2 = self.conv2(concat2)

        tconv_3 = self.tconv3(conv_2)
        concat3 = torch.cat((tconv_3, dil_2), dim=1)
        conv_3 = self.conv3(concat3)

        tconv_4 = self.tconv4(conv_3)
        concat4 = torch.cat((tconv_4, dil_1), dim=1)
        conv_4 = self.conv4(concat4)

        x = self.final_layer(conv_4)

        if padded:
            # Remove the rows/columns that were added by the padding
            x = x[..., :in_h, :in_w]

        return x

class finetuning_unet_model(nn.Module):
    """
    Wrap an existing U-Net and replace its classification head.

    The wrapped model's ``final_layer`` is replaced in place by a new 1x1
    convolution with ``out_channels`` outputs, so the object passed in is
    modified.

    Args:
        unet_model: model instance exposing a ``final_layer`` attribute.
        out_channels: number of classes of the new head.
        features: fallback for the head's input width (``features[0]``),
            used only when the wrapped model has no ``nn.Conv2d`` head to
            read the width from.
    """
    def __init__(self, unet_model, out_channels=10, features=(16, 32)):
        super(finetuning_unet_model,self).__init__()
        self.unet_model = unet_model

        old_head = getattr(unet_model, 'final_layer', None)
        if isinstance(old_head, nn.Conv2d):
            # Take the input width from the head being replaced so backbones
            # built with non-default widths get a matching head, and keep
            # the new head on the backbone's device and dtype
            new_head = nn.Conv2d(old_head.in_channels, out_channels, kernel_size=1)
            new_head = new_head.to(device=old_head.weight.device, dtype=old_head.weight.dtype)
        else:
            new_head = nn.Conv2d(features[0],out_channels, kernel_size=1)
        self.unet_model.final_layer = new_head

    def forward(self,x):
        """Run the wrapped model, including the replaced head, on ``x``."""
        return self.unet_model(x)

In [7]:
# Collection to evaluate; selects the dataset folders and the checkpoint
# directory used by the cells below.
collection_name = 'Latin16746FS'

# Resolution tag that appears in the checkpoint file name and in the result
# directory name as "{h_}x{w_}".
# NOTE(review): presumably the patch size used during training -- confirm.
h_ = w_ = 256

In [8]:
# Locate the validation images and their pixel-level ground-truth masks
img_DIR = f'/content/icdar_2024_SAM/U-DIADS-Bib-FS/{collection_name}/img-{collection_name}/'
mask_DIR = f'/content/icdar_2024_SAM/U-DIADS-Bib-FS/{collection_name}/pixel-level-gt-{collection_name}/'
x_valid_dir = os.path.join(img_DIR, 'validation')
y_valid_dir = os.path.join(mask_DIR, 'validation')

# Sorting both lists is what pairs image i with mask i
val_img_paths = sorted(glob.glob(os.path.join(x_valid_dir, "*.jpg")))
val_mask_paths = sorted(glob.glob(os.path.join(y_valid_dir, "*.png")))
print(val_img_paths[:5])
print(val_mask_paths[:5])

# The dataset pairs paths by position, so a count mismatch would silently
# misalign images and masks
if len(val_img_paths) != len(val_mask_paths):
    raise ValueError(
        f"Found {len(val_img_paths)} images but {len(val_mask_paths)} masks "
        f"for collection {collection_name}"
    )

valid_dataset = UDIADS_Validation(
    val_img_paths,
    val_mask_paths,

)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the backbone and swap in a 6-class head before loading the weights
pretrained_model = unet_model().to(device)
model = finetuning_unet_model(pretrained_model,out_channels=6)

# map_location lets a checkpoint saved on a GPU load on a CPU-only runtime,
# which is the fallback selected above
ckpt_path = f'/content/icdar_2024_SAM/checkpoints/ckpt_finetune_{collection_name}/best_val_f1score_{h_}x{w_}.pth'
pretrained_ckpt = torch.load(ckpt_path, map_location=device)
# load model weights state_dict
model.load_state_dict(pretrained_ckpt['model_state_dict'])
model.eval()
model = model.to(device)

['/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/img-Latin16746FS/validation/024.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/img-Latin16746FS/validation/107.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/img-Latin16746FS/validation/116.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/img-Latin16746FS/validation/122.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/img-Latin16746FS/validation/147.jpg']
['/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/pixel-level-gt-Latin16746FS/validation/024.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/pixel-level-gt-Latin16746FS/validation/107.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/pixel-level-gt-Latin16746FS/validation/116.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/pixel-level-gt-Latin16746FS/validation/122.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin16746FS/pixel-level-gt-Latin16746FS/validation/147.png']


# Generate the output images

In [9]:
test_dir = '/content/icdar_2024_SAM/U-DIADS-Bib-FS/'

# Collections whose predictions get the Sauvola background clean-up below
SAUVOLA_COLLECTIONS = ("Latin16746FS",)

# NOTE(review): `model` holds the checkpoint loaded for `collection_name`.
# Extending this list to other collections reuses that same checkpoint --
# confirm that is intended before doing so.
#for DS in ["Latin2FS", "Latin14396FS", "Latin16746FS", "Syr341FS"]:
for DS in ["Latin16746FS"]:
    out_results = f'Unet-based/{DS}/result-{h_}x{w_}'
    os.makedirs(out_results,exist_ok=True)
    current_path = os.path.join(test_dir,DS,'img-'+DS,'validation')

    # Sorted so the processing order is the same on every run
    list_img = sorted(glob.glob(current_path+'/*'))
    for im in list_img:
        bgr_img = cv2.imread(im)
        if bgr_img is None:
            # cv2.imread returns None instead of raising on unreadable files
            raise FileNotFoundError(f"Could not read image: {im}")
        img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

        # Min-max scale to [-1, 1]; a constant image would divide by zero
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = 2*((img - lo) / (hi - lo)) - 1
        else:
            img = np.full(img.shape, -1.0)

        inputs = transforms.ToTensor()(img).to(device).unsqueeze(0)

        # Inference only: skip autograd bookkeeping to save memory
        with torch.no_grad():
            outputs = model(inputs.float())
        preds = torch.argmax(outputs, 1)

        # (1, H, W) -> (H, W) class-index map
        t_np = preds[0].cpu().numpy()
        rgb_image = convertLabel_to_RGB(t_np)
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)

        # Post-processing applied only to the collections listed above:
        # pixels that Sauvola thresholding classifies as page background
        # are reset to the Background colour
        if DS in SAUVOLA_COLLECTIONS:
            gray_img = cv2.imread(im,cv2.IMREAD_GRAYSCALE)
            thresh_ = threshold_sauvola(gray_img, window_size=25)
            bin_img = (gray_img > thresh_).astype(np.uint8) * 255
            rgb_image[bin_img==255]=0

        # splitext handles extensions of any length (e.g. ".jpeg")
        out_name = os.path.splitext(os.path.basename(im))[0] + '.png'
        out_path = os.path.join(out_results, out_name)
        if not cv2.imwrite(out_path, rgb_image):
            raise IOError(f"Could not write prediction: {out_path}")

# Compute evaluation metrics

In [11]:
# Score the predictions with the repository's metric script.
# NOTE(review): this cell's execution count (11) skips 10, and the recorded
# output reports four collections although the generation cell above only
# processes Latin16746FS. Re-run the notebook from a fresh kernel to confirm
# that the scores below correspond to the current predictions.
!python /content/icdar_2024_SAM/Unet-based/metric.py

############## Latin14396FS Scores ##############
Precision:  0.8987766214915855
Recall:  0.9807385430738644
F1 score:  0.9351480699325356
Intersection Over Union:  0.8832217572127345
############## Syr341FS Scores ##############
Precision:  0.8658239115498026
Recall:  0.9663057854370031
F1 score:  0.9083066809791462
Intersection Over Union:  0.8406053073668961
############## Latin2FS Scores ##############
Precision:  0.8667847410296519
Recall:  0.9669577086926089
F1 score:  0.9094735542039508
Intersection Over Union:  0.8429133151410223
############## Latin16746FS Scores ##############
  _warn_prf(average, modifier, msg_start, len(result))
Precision:  0.19691445236409105
Recall:  0.21272169813690678
F1 score:  0.20400796468778884
Intersection Over Union:  0.1902492325682581
############## Final Scores ##############
Final result of Intersection Over Union:  0.6892474030722278
