<a href="https://colab.research.google.com/github/honghanhh/icdar_2024_SAM/blob/main/L3i%2B%2BFewShotLayoutSegmentation_Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
from PIL import Image
import cv2

import glob
from skimage.filters import threshold_sauvola

In [None]:
# Mount Google Drive at /content/drive/ (Colab-only helper from google.colab).
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# Start from a clean checkout: delete any previous clone, then fetch the repository.
!rm -rf icdar_2024_SAM
!git clone https://github.com/honghanhh/icdar_2024_SAM.git

Cloning into 'icdar_2024_SAM'...
remote: Enumerating objects: 278, done.[K
remote: Counting objects: 100% (36/36), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 278 (delta 16), reused 10 (delta 4), pack-reused 242[K
Receiving objects: 100% (278/278), 44.66 MiB | 35.47 MiB/s, done.
Resolving deltas: 100% (36/36), done.


# Utility functions for processing data

In [None]:
def convertRGB_to_label(image):
    """
    Convert an RGB colour-coded mask into a 2-D array of class indices.

    Parameters
    ----------
    image : array-like
        Mask of shape (H, W, 3) whose pixels use the colours listed below.
        A 2-D (grayscale) array is replicated across three channels first.
        Channels beyond the first three (e.g. an alpha channel) are ignored.

    Returns
    -------
    numpy.ndarray
        ``int8`` array of shape (H, W) holding the class index of each pixel:
        0 Background, 1 Paratext, 2 Decoration, 3 Main Text, 4 Title,
        5 Chapter Headings. Pixels whose colour matches no entry stay 0.
    """
    # The position of a colour in this tuple is its class index.
    palette = (
        (0, 0, 0),        # 0: Background
        (255, 255, 0),    # 1: Paratext
        (0, 255, 255),    # 2: Decoration
        (255, 0, 255),    # 3: Main Text
        (255, 0, 0),      # 4: Title
        (0, 255, 0),      # 5: Chapter Headings
    )

    image = np.asarray(image)

    # Promote a grayscale image to three identical channels.
    if image.ndim == 2:
        image = np.stack((image,) * 3, axis=-1)

    # Keep only the colour channels so that RGBA input can still be compared
    # against the 3-value palette entries.
    image = image[:, :, :3]

    labels = np.zeros(image.shape[:2], dtype=np.int8)
    for class_index, color in enumerate(palette):
        labels[np.all(image == np.array(color), axis=-1)] = class_index

    return labels

def convertLabel_to_RGB(labels):
    """
    Render a 2-D array of class indices as an RGB colour image.

    ``labels`` must be two-dimensional (H, W). The result is a ``uint8``
    array of shape (H, W, 3); any value outside 0..5 is left black.
    """
    # Row i holds the RGB colour of class i.
    palette = np.array(
        [
            [0, 0, 0],        # Background
            [255, 255, 0],    # Paratext
            [0, 255, 255],    # Decoration
            [255, 0, 255],    # Main Text
            [255, 0, 0],      # Title
            [0, 255, 0],      # Chapter Headings
        ],
        dtype=np.uint8,
    )

    height, width = labels.shape
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    # Paint one class at a time.
    for class_index, rgb in enumerate(palette):
        canvas[labels == class_index] = rgb

    return canvas


# Dataset

In [None]:
class UDIADS_Validation(Dataset):
    """
    Dataset for simple Evaluation and Testing.

    Item ``idx`` pairs ``imagePaths[idx]`` with ``maskPaths[idx]``, so both
    sequences must be in the same order and have the same length.
    """
    def __init__(
            self,
            imagePaths,
            maskPaths,

    ):
        # Fail fast: a length mismatch means images and masks cannot be paired by index.
        if len(imagePaths) != len(maskPaths):
            raise ValueError(
                f"Got {len(imagePaths)} image paths but {len(maskPaths)} mask paths"
            )
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths

    def __getitem__(self, idx):
        """
        Load one sample.

        Returns a tuple ``(img, mask, (h, w))``: the image as a channel-first
        tensor min-max scaled to [-1, 1], the mask as a ``long`` tensor of
        class indices, and the image's height and width.

        Raises ``FileNotFoundError`` if the image or the mask cannot be read.
        """
        img_path = self.imagePaths[idx]
        mask_path = self.maskPaths[idx]

        # cv2.imread signals failure by returning None instead of raising.
        bgr_img = cv2.imread(img_path)
        if bgr_img is None:
            raise FileNotFoundError(f"Could not read image: {img_path}")
        bgr_mask = cv2.imread(mask_path)
        if bgr_mask is None:
            raise FileNotFoundError(f"Could not read mask: {mask_path}")

        read_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
        h, w = read_img.shape[0], read_img.shape[1]

        # Per-image min-max scaling to [-1, 1]. A constant image has zero range;
        # map it to all zeros rather than dividing by zero and producing NaNs.
        pixels = read_img.astype(np.float64)
        lo, hi = pixels.min(), pixels.max()
        if hi > lo:
            img = 2 * ((pixels - lo) / (hi - lo)) - 1
        else:
            img = np.zeros_like(pixels)

        mask = convertRGB_to_label(cv2.cvtColor(bgr_mask, cv2.COLOR_BGR2RGB))

        # To tensor
        Transforms = transforms.Compose([transforms.ToTensor()])
        img = Transforms(img)
        mask = torch.from_numpy(mask).long()

        return img, mask, (h, w)

    def __len__(self):
        return len(self.imagePaths)

# Model

In [None]:
def dil_block(in_c, out_c):
    """
    Build four 3x3 Conv -> BatchNorm -> ReLU stages as one ``nn.Sequential``.

    The first two convolutions use dilation 1 and the last two dilation 2.
    Padding always equals the dilation, so height and width are preserved.
    Only the first convolution maps ``in_c`` to ``out_c`` channels; the
    remaining ones map ``out_c`` to ``out_c``.
    """
    stages = []
    channels = in_c
    for rate in (1, 1, 2, 2):
        stages.extend([
            nn.Conv2d(channels, out_c, kernel_size=3, stride=1, padding=rate, dilation=rate),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ])
        channels = out_c
    return nn.Sequential(*stages)


def encoding_block(in_c, out_c):
    """
    Build a single 3x3 Conv -> BatchNorm -> ReLU stage.

    Stride 1 with padding 1 keeps the spatial size; channels go from
    ``in_c`` to ``out_c``.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def encoding_block1(in_c, out_c):
    """
    Build two consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    The first convolution maps ``in_c`` to ``out_c`` channels and the second
    keeps ``out_c``. Stride 1 with padding 1 preserves the spatial size.
    """
    layers = []
    for channels in (in_c, out_c):
        layers += [
            nn.Conv2d(channels, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
    return nn.Sequential(*layers)

class unet_model(nn.Module):
    """
    Small U-Net-style segmentation network with dilated encoder blocks.

    The encoder applies four ``dil_block`` stages, each followed by 2x2 max
    pooling. After a bottleneck, the decoder upsamples four times with
    transposed convolutions and concatenates the matching encoder output
    (skip connection) before each decoder convolution. A 1x1 convolution
    produces ``out_channels`` logits per pixel.

    Parameters
    ----------
    out_channels : int
        Number of output channels of the final 1x1 convolution.
    features : sequence of two ints
        ``features[0]`` is the channel width used throughout the network.
        ``features[1]`` is the input width of the decoder convolutions, which
        receive two concatenated ``features[0]``-channel tensors; it therefore
        has to equal ``2 * features[0]``. The default list is only read,
        never mutated.
    """
    def __init__(self,out_channels=4,features=[16, 32]):
        super(unet_model,self).__init__()

        # Encoder: the input is expected to have 3 channels.
        self.dil1 = dil_block(3,features[0])

        self.pool1 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))

        self.dil2 = dil_block(features[0],features[0])

        self.pool2 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))

        self.dil3 = dil_block(features[0],features[0])

        self.pool3 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))

        self.dil4 = dil_block(features[0],features[0])

        self.pool4 = nn.MaxPool2d(kernel_size=(2,2),stride=(2,2))

        self.bott = encoding_block1(features[0],features[0])

        # Decoder: each transposed convolution doubles the spatial size.
        self.tconv1 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)

        self.conv1 = encoding_block(features[1],features[0])

        self.tconv2 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)

        self.conv2 = encoding_block(features[1],features[0])

        self.tconv3 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)

        self.conv3 = encoding_block(features[1],features[0])

        self.tconv4 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)

        self.conv4 = encoding_block1(features[1],features[0])

        self.final_layer = nn.Conv2d(features[0],out_channels, kernel_size=1)

    def forward(self,x):
        """
        Compute per-pixel logits for a batch ``x`` of shape (N, 3, H, W).

        The result has shape (N, out_channels, H, W). Inputs whose height or
        width is not a multiple of 16 are padded on the bottom/right by
        replicating the edge pixels, and the logits are cropped back to the
        input size; inputs that already are multiples of 16 are untouched.
        """
        # Four 2x2 poolings halve each side four times, so the skip
        # connections only line up when H and W are multiples of 16.
        in_h, in_w = x.shape[-2], x.shape[-1]
        pad_h = (-in_h) % 16
        pad_w = (-in_w) % 16
        if pad_h or pad_w:
            x = F.pad(x, (0, pad_w, 0, pad_h), mode='replicate')

        dil_1 = self.dil1(x)

        pool_1 = self.pool1(dil_1)

        dil_2 = self.dil2(pool_1)

        pool_2 = self.pool2(dil_2)

        dil_3 = self.dil3(pool_2)

        pool_3 = self.pool3(dil_3)

        dil_4 = self.dil4(pool_3)

        pool_4 = self.pool4(dil_4)

        bott = self.bott(pool_4)

        tconv_1 = self.tconv1(bott)

        concat1 = torch.cat((tconv_1, dil_4), dim=1)

        conv_1 = self.conv1(concat1)

        tconv_2 = self.tconv2(conv_1)

        concat2 = torch.cat((tconv_2, dil_3), dim=1)

        conv_2 = self.conv2(concat2)

        tconv_3 = self.tconv3(conv_2)

        concat3 = torch.cat((tconv_3, dil_2), dim=1)

        conv_3 = self.conv3(concat3)

        tconv_4 = self.tconv4(conv_3)

        concat4 = torch.cat((tconv_4, dil_1), dim=1)

        conv_4 = self.conv4(concat4)

        x = self.final_layer(conv_4)

        # Drop the rows/columns that only exist because of the padding above.
        if pad_h or pad_w:
            x = x[..., :in_h, :in_w]

        return x

class finetuning_unet_model(nn.Module):
    """
    Wrap an existing network and give it a new classification head.

    The wrapped module's ``final_layer`` attribute is replaced, in place, by
    a freshly initialised 1x1 convolution mapping ``features[0]`` channels to
    ``out_channels``. ``forward`` simply delegates to the wrapped module.
    """
    def __init__(self, unet_model, out_channels=10, features=[16, 32]):
        super().__init__()
        new_head = nn.Conv2d(features[0], out_channels, kernel_size=1)
        unet_model.final_layer = new_head
        self.unet_model = unet_model

    def forward(self,x):
        backbone = self.unet_model
        return backbone(x)

In [None]:
# Collection to evaluate; h_ and w_ appear in the checkpoint and result-folder names used below.
collection_name = 'Latin14396FS'
h_, w_ = 512, 512
# The other datasets use 256x256. Original note: "update ckpt of Latin14396 _aug".

In [None]:
# Locate the validation images and their pixel-level ground-truth masks.
img_DIR = f'/content/icdar_2024_SAM/U-DIADS-Bib-FS/{collection_name}/img-{collection_name}/'
mask_DIR = f'/content/icdar_2024_SAM/U-DIADS-Bib-FS/{collection_name}/pixel-level-gt-{collection_name}/'
x_valid_dir = os.path.join(img_DIR, 'validation')
y_valid_dir = os.path.join(mask_DIR, 'validation')
# Both listings are sorted so that image i and mask i are paired by position;
# this relies on images and masks sharing the same file stems.
val_img_paths = sorted(glob.glob(os.path.join(x_valid_dir, "*.jpg")))
val_mask_paths = sorted(glob.glob(os.path.join(y_valid_dir, "*.png")))
print(val_img_paths[:5])
print(val_mask_paths[:5])

valid_dataset = UDIADS_Validation(
    val_img_paths,
    val_mask_paths,

)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build the base network, then swap in a 6-class head before loading the fine-tuned weights.
pretrained_model = unet_model().to(device)
model = finetuning_unet_model(pretrained_model,out_channels=6)
# map_location keeps the load working on a CPU-only runtime even if the
# checkpoint was saved from a GPU.
pretrained_ckpt = torch.load(
    f'/content/icdar_2024_SAM/checkpoints/ckpt_finetune_{collection_name}/best_val_f1score_{h_}x{w_}.pth',
    map_location=device,
)
# load model weights state_dict
model.load_state_dict(pretrained_ckpt['model_state_dict'])
# Inference mode for BatchNorm; the replaced head was created on the CPU, so move everything to the device.
model.eval()
model = model.to(device)

['/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/img-Latin14396FS/validation/028.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/img-Latin14396FS/validation/040.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/img-Latin14396FS/validation/044.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/img-Latin14396FS/validation/064.jpg', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/img-Latin14396FS/validation/137.jpg']
['/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/pixel-level-gt-Latin14396FS/validation/028.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/pixel-level-gt-Latin14396FS/validation/040.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/pixel-level-gt-Latin14396FS/validation/044.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/pixel-level-gt-Latin14396FS/validation/064.png', '/content/icdar_2024_SAM/U-DIADS-Bib-FS/Latin14396FS/pixel-level-gt-Latin14396FS/validation/137.png']


# Code to generate output images

# Post-processing for Latin16746FS

In [None]:
def find_connected_components(image):
    """
    Label the 8-connected foreground regions of a grayscale image.

    Pixels with a value above 127 count as foreground. Returns
    ``(count, labels, stats)`` where the background is excluded from both
    ``count`` and ``stats``. ``labels`` still uses 0 for the background, so
    row ``k`` of ``stats`` describes the pixels labelled ``k + 1``.
    """
    # cv2.threshold returns (threshold_used, image); only the image is needed.
    binary_image = cv2.threshold(image, 127, 255, cv2.THRESH_BINARY)[1]

    total_labels, label_map, all_stats, _centroids = cv2.connectedComponentsWithStats(
        binary_image, connectivity=8
    )

    # Label 0 / row 0 is the background.
    return total_labels - 1, label_map, all_stats[1:]

def keep_largest_component(image):
    """
    Keep only the dominant connected component(s) of a grayscale/binary image.

    The largest component is always kept. The second largest is kept as well
    when its area exceeds half of the largest one. Components whose area ties
    with a kept area are kept too.

    Returns a ``uint8`` mask with the shape of the label map: 255 on kept
    components, 0 elsewhere. If the image has no foreground component the
    mask is all zeros.
    """
    num_labels, labels, stats = find_connected_components(image)

    mask = np.zeros_like(labels, dtype=np.uint8)

    # Nothing to keep: avoid indexing into an empty statistics table.
    if len(stats) == 0:
        return mask

    areas = stats[:, cv2.CC_STAT_AREA]
    ordered = np.sort(areas)[::-1]
    largest_area = ordered[0]
    second_largest_area = ordered[1] if len(ordered) > 1 else 0

    keep = areas == largest_area
    if second_largest_area > largest_area / 2:
        keep |= areas == second_largest_area

    # Row k of stats corresponds to label k + 1 (label 0 is the background).
    kept_labels = np.flatnonzero(keep) + 1
    mask[np.isin(labels, kept_labels)] = 255

    return mask
def post_processing_for_Latin16746FS(rgb_image, orig_img):
    """
    Clean a colour-coded prediction for the Latin16746FS collection.

    Parameters
    ----------
    rgb_image : numpy.ndarray
        Predicted colour mask (3 channels, same height/width as the page
        image). It is modified in place and also returned.
    orig_img : str
        Path of the original page image.

    Raises
    ------
    FileNotFoundError
        If ``orig_img`` cannot be read.
    """
    gray_img = cv2.imread(orig_img,cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure instead of raising.
    if gray_img is None:
        raise FileNotFoundError(f"Could not read image: {orig_img}")

    # Step 1: clear predictions on pixels brighter than their local Sauvola
    # threshold (presumably paper rather than ink).
    thresh_ = threshold_sauvola(gray_img, window_size=1005)
    bin_img = (gray_img > thresh_).astype(np.uint8) * 255
    rgb_image[bin_img==255]=0

    # Step 2: dilate every remaining non-black pixel with a 91-row by
    # 41-column kernel so that nearby regions merge into blobs, then keep
    # only the dominant blob(s) and clear everything outside them.
    gray = cv2.cvtColor(rgb_image, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)
    kernel = np.ones((91, 41), np.uint8)
    dilated = cv2.dilate(thresh,kernel)
    removed_img = keep_largest_component(dilated)
    rgb_image[removed_img==0]=0
    return rgb_image

# Post processing for Latin2FS

In [None]:
def post_processing_for_Latin2FS(rgb_image, orig_img):
    """
    Clean a colour-coded prediction for the Latin2FS collection.

    Predictions are cleared on every pixel of the original page that is
    brighter than its local Sauvola threshold (presumably paper rather than
    ink). ``rgb_image`` is modified in place and returned; it must have the
    same height and width as the image at path ``orig_img``.

    Raises ``FileNotFoundError`` if ``orig_img`` cannot be read.
    """
    gray_img = cv2.imread(orig_img,cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure instead of raising.
    if gray_img is None:
        raise FileNotFoundError(f"Could not read image: {orig_img}")
    thresh_ = threshold_sauvola(gray_img, window_size=1005)
    bin_img = (gray_img > thresh_).astype(np.uint8) * 255
    rgb_image[bin_img==255]=0
    return rgb_image

# Post processing for Syr341FS

In [None]:
def imclearborder(img):
    '''
    Remove targets in binary images that are in contact with edges
    @param img: numpy.array, source image, must be a 2-D binary image
                (0 background, 255 foreground)
    @return cropImg: numpy.array, image without border targets (same size as img)
    '''
    h, w = img.shape
    # Surround the image with a white frame of thickness 10. Every white
    # target touching an image edge becomes connected to this frame.
    border = 10
    extended = cv2.copyMakeBorder(img, border, border, border, border, cv2.BORDER_CONSTANT, value=255)

    # Flood-fill the frame with black. The frame is one connected ring, so a
    # single fill seeded at the top-left corner reaches all of it together
    # with every target attached to it. No further seeds are needed, and a
    # seed such as (w, h) would land inside the image and could erase a
    # target that does not touch the border.
    mh, mw = extended.shape[:2]
    mask = np.zeros([mh + 2, mw + 2], np.uint8)  # floodFill needs a mask 2 px larger than the image
    cv2.floodFill(extended, mask, (0, 0), 0, flags=cv2.FLOODFILL_FIXED_RANGE)

    # crop from the original position
    cropImg = extended[border:border+h, border:border+w]
    return cropImg
def remove_small_components(image, min_area):
    """
    Blank out small connected regions of a 3-channel colour image.

    Every non-black region (8-connectivity, judged on the grayscale version
    of ``image``) whose area is at most ``min_area`` pixels is set to 0.
    ``image`` is modified in place and returned.
    """
    # Any pixel with a non-zero gray value counts as foreground.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)

    _, labels, stats, _ = cv2.connectedComponentsWithStats(thresh, connectivity=8)

    # Row 0 of stats is the background, so skip it and shift indices back to label ids.
    small_labels = np.flatnonzero(stats[1:, cv2.CC_STAT_AREA] <= min_area) + 1
    if small_labels.size:
        image[np.isin(labels, small_labels)] = 0
    return image

def post_processing_for_Syr341FS(rgb_image):
    """
    Clean a colour-coded prediction for the Syr341FS collection.

    First clears everything that ``imclearborder`` removes from the binary
    foreground of the prediction, then drops connected regions of at most
    50 pixels. ``rgb_image`` is modified in place and returned.
    """
    # Binary foreground: any pixel whose gray value is non-zero.
    foreground = cv2.threshold(
        cv2.cvtColor(rgb_image, cv2.COLOR_BGR2GRAY), 0, 255, cv2.THRESH_BINARY
    )[1]
    interior_only = imclearborder(foreground)
    rgb_image[interior_only == 0] = 0
    return remove_small_components(rgb_image, 50)


# Post-processing for Latin14396FS




In [None]:
def post_processing_for_Latin14396FS(rgb_image, orig_img):
    """
    Clean a colour-coded prediction for the Latin14396FS collection.

    Predictions are cleared on every pixel of the original page that is
    brighter than its local Sauvola threshold (presumably paper rather than
    ink). ``rgb_image`` is modified in place and returned; it must have the
    same height and width as the image at path ``orig_img``.

    Raises ``FileNotFoundError`` if ``orig_img`` cannot be read.
    """
    gray_img = cv2.imread(orig_img,cv2.IMREAD_GRAYSCALE)
    # cv2.imread returns None on failure instead of raising.
    if gray_img is None:
        raise FileNotFoundError(f"Could not read image: {orig_img}")
    thresh_ = threshold_sauvola(gray_img, window_size=1005)
    bin_img = (gray_img > thresh_).astype(np.uint8) * 255
    rgb_image[bin_img==255]=0
    return rgb_image

# Prediction

In [None]:
# Run the model on every validation image of each selected collection, apply
# the collection-specific post-processing and save the colour-coded result
# as <image stem>.png in the result folder.
test_dir = '/content/icdar_2024_SAM/U-DIADS-Bib-FS/'
# Built once; converts an H x W x C array into a C x H x W tensor.
Transforms = transforms.Compose([transforms.ToTensor()])
#for DS in ["Latin2FS", "Latin14396FS", "Latin16746FS", "Syr341FS"]:
for DS in [collection_name]:
    out_results = f'/content/icdar_2024_SAM/Unet-based/{DS}/result-{h_}x{w_}'
    os.makedirs(out_results,exist_ok=True)
    current_path = os.path.join(test_dir,DS,'img-'+DS,'validation')

    list_img = sorted(glob.glob(current_path+'/*'))
    for im in list_img:
        # cv2.imread returns None for anything it cannot decode (the glob
        # above matches every file in the folder, not only images).
        bgr = cv2.imread(im)
        if bgr is None:
            print(f'Skipping unreadable file: {im}')
            continue

        # Per-image min-max scaling to [-1, 1]; a constant image has zero
        # range and is mapped to zeros instead of dividing by zero.
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float64)
        lo, hi = img.min(), img.max()
        if hi > lo:
            img = 2*((img - lo) / (hi - lo)) - 1
        else:
            img = np.zeros_like(img)

        inputs = Transforms(img).unsqueeze(0).float().to(device)
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            outputs = model(inputs)
        preds = torch.argmax(outputs, 1)
        # preds has shape (1, H, W); take the single image of the batch.
        t_np = preds[0].cpu().numpy()

        rgb_image = convertLabel_to_RGB(t_np)
        # OpenCV writes images in BGR channel order.
        rgb_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
        if DS == 'Latin16746FS':
            rgb_image = post_processing_for_Latin16746FS(rgb_image, im)
        elif DS == 'Latin2FS':
            rgb_image = post_processing_for_Latin2FS(rgb_image, im)
        elif DS == 'Latin14396FS':
            rgb_image = post_processing_for_Latin14396FS(rgb_image, im)
        else:
            # Any other collection name falls back to the Syr341FS post-processing.
            rgb_image = post_processing_for_Syr341FS(rgb_image)

        # splitext copes with extensions of any length (e.g. ".jpeg").
        stem = os.path.splitext(os.path.basename(im))[0]
        cv2.imwrite(os.path.join(out_results, stem + '.png'), rgb_image)

# Compute evaluation metrics

In [None]:
# Run the repository's metric script on the saved predictions.
# NOTE(review): the recorded output below shows per-collection scores for Latin14396FS only,
# "The result folder is not a directory" for the rest, and a TypeError in the final averaging
# step. Presumably the script expects results for all four collections; confirm in metric.py.
!python /content/icdar_2024_SAM/Unet-based/metric.py

############## Latin14396FS Scores ##############
Precision:  0.8571748922565886
Recall:  0.6950635384544867
F1 score:  0.718465854224091
Intersection Over Union:  0.6499802605482814
The result folder is not a directory
The result folder is not a directory
The result folder is not a directory
############## Final Scores ##############
Traceback (most recent call last):
  File "/content/icdar_2024_SAM/Unet-based/metric.py", line 124, in <module>
    print("Final result of Intersection Over Union: ", np.mean(result))
  File "/usr/local/lib/python3.10/dist-packages/numpy/core/fromnumeric.py", line 3504, in mean
    return _methods._mean(a, axis=axis, dtype=dtype,
  File "/usr/local/lib/python3.10/dist-packages/numpy/core/_methods.py", line 118, in _mean
    ret = umr_sum(arr, axis, dtype, out, keepdims, where=where)
TypeError: unsupported operand type(s) for +: 'float' and 'NoneType'
