<a href="https://colab.research.google.com/github/honghanhh/icdar_2024_SAM/blob/main/L3iFewShotLayoutSegmentation_Training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive/ (Colab-only API).
# NOTE(review): none of the cells visible here read from /content/drive —
# confirm the mount is still needed.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
# Remove any earlier checkout so the clone in the next cell starts from a clean directory.
!rm -rf icdar_2024_SAM

In [3]:

# Clone the project repository into the current Colab working directory;
# later cells read the U-DIADS-Bib-FS data from /content/icdar_2024_SAM.
!git clone https://github.com/honghanhh/icdar_2024_SAM.git


Cloning into 'icdar_2024_SAM'...
remote: Enumerating objects: 188, done.[K
remote: Counting objects: 100% (188/188), done.[K
remote: Compressing objects: 100% (181/181), done.[K
remote: Total 188 (delta 18), reused 167 (delta 5), pack-reused 0[K
Receiving objects: 100% (188/188), 39.80 MiB | 24.04 MiB/s, done.
Resolving deltas: 100% (18/18), done.


In [4]:
from torchvision import models
from torchsummary import summary
import os

import numpy as np
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torch.nn.functional as F
from PIL import Image
import cv2

import time
import glob
from tqdm.notebook import tqdm

import random
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score

from torch.optim.lr_scheduler import StepLR
import albumentations as A
from torch.utils.data import RandomSampler
# Expose only GPU 0 to this process.
# NOTE(review): presumably this has to run before the first CUDA call to take effect — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Utility functions for data processing

In [5]:
def convertRGB_to_label(image):
    """
    Turn a colour-coded annotation image into a 2-D array of class ids.

    The class id of a colour is its position in the palette below
    (0 = Background ... 5 = Chapter Headings). Pixels whose colour is not in
    the palette keep id 0. The result has dtype int8 and the same height and
    width as the input.
    """
    # Position in this tuple is the class id.
    palette = (
        ((0, 0, 0), "Background"),
        ((255, 255, 0), "Paratext"),
        ((0, 255, 255), "Decoration"),
        ((255, 0, 255), "Main Text"),
        ((255, 0, 0), "Title"),
        ((0, 255, 0), "Chapter Headings"),
    )

    pixels = np.array(image)

    # A single-channel input is replicated to three identical channels.
    if pixels.ndim == 2:
        pixels = np.stack((pixels,) * 3, axis=-1)

    label_map = np.zeros_like(pixels[:, :, 0], dtype=np.int8)

    for class_id, (rgb, _name) in enumerate(palette):
        matches = np.all(pixels == np.array(rgb), axis=-1)
        label_map[matches] = class_id

    return label_map

def padding_image(image, divisible):
    """
    Zero-pad an image on the bottom and right so that its height and width
    become multiples of ``divisible``.

    Args:
        image: array whose first two axes are height and width; any further
            axes (e.g. channels) are left unpadded.
        divisible: positive integer the padded height and width must be a
            multiple of.

    Returns:
        The padded array. A dimension that is already a multiple of
        ``divisible`` receives no padding.
    """
    h, w = image.shape[:2]
    # (-h) % divisible is 0 when h is already a multiple, whereas
    # divisible - (h % divisible) would add a whole extra block in that case.
    pad_h = (-h) % divisible
    pad_w = (-w) % divisible
    # Pad only height and width; trailing axes get (0, 0).
    pad_tuple = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (image.ndim - 2)
    return np.pad(image, pad_tuple, mode='constant')

class SlidingWindowCrop(object):
    """
    Cut an image/mask pair into fixed-size tiles.

    The pair is first zero-padded on the bottom/right to a multiple of the
    tile size and split into a non-overlapping grid. A number of additional
    tiles are then taken at random positions (drawn with ``np.random``).

    Args:
        output_size: tile size, either an int (square tile) or a
            ``(height, width)`` tuple.
        n_random_crops: how many random tiles to add after the grid tiles.
            Defaults to 10; raise it (e.g. to 40) when GPU budget allows.
    """
    def __init__(self, output_size, n_random_crops=10):
        assert isinstance(output_size, (int, tuple))
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size)
        else:
            assert len(output_size) == 2
            self.output_size = output_size
        self.n_random_crops = n_random_crops

    def __call__(self, image, mask):
        """
        Args:
            image: array with height and width as its first two axes.
            mask: 2-D array with the same height and width as ``image``.

        Returns:
            ``(cropped_images, cropped_masks)``: two equally long lists, grid
            tiles first, then the random tiles.
        """
        h, w = image.shape[:2]
        new_h, new_w = self.output_size

        # Pad bottom/right up to the next multiple of the tile size.
        pad_h = (-h) % new_h
        pad_w = (-w) % new_w
        # Trailing axes (channels), if any, are not padded.
        image_pad = ((0, pad_h), (0, pad_w)) + ((0, 0),) * (image.ndim - 2)
        image = np.pad(image, image_pad, mode='constant')
        mask = np.pad(mask, ((0, pad_h), (0, pad_w)), mode='constant')

        cropped_images = []
        cropped_masks = []

        # Non-overlapping grid; the padded size is an exact multiple of the
        # tile size, so every tile is complete.
        for i in range(0, h + pad_h, new_h):
            for j in range(0, w + pad_w, new_w):
                cropped_images.append(image[i:i+new_h, j:j+new_w])
                cropped_masks.append(mask[i:i+new_h, j:j+new_w])

        # Random tiles start inside the unpadded area. When a dimension is
        # not larger than the tile there is only one valid start (0);
        # np.random.randint would raise on the empty range.
        for _ in range(self.n_random_crops):
            top = np.random.randint(0, h - new_h) if h > new_h else 0
            left = np.random.randint(0, w - new_w) if w > new_w else 0
            cropped_images.append(image[top: top + new_h, left: left + new_w])
            cropped_masks.append(mask[top: top + new_h, left: left + new_w])
        return cropped_images, cropped_masks

def get_sampler(dataset, seed=123):
    """Return a RandomSampler over ``dataset`` driven by its own torch generator seeded with ``seed``."""
    seeded_rng = torch.Generator().manual_seed(seed)
    return RandomSampler(dataset, generator=seeded_rng)

# Metrics


In [6]:

def f1_score_metric(output, mask):
    """
    Macro-averaged F1 between predicted class ids (``output``) and
    ground-truth class ids (``mask``).

    Both tensors are flattened and moved to CPU before being passed to
    scikit-learn's ``f1_score``.
    """
    with torch.no_grad():
        y_true = mask.flatten().cpu()
        y_pred = output.flatten().cpu()
        return f1_score(y_true, y_pred, average='macro')
def pixel_accuracy(output, mask):
    """
    Fraction of positions where the arg-max class of ``output`` (taken over
    dim 1 after a softmax) equals ``mask``. Returns a Python float.
    """
    with torch.no_grad():
        probabilities = F.softmax(output, dim=1)
        predicted = probabilities.argmax(dim=1)
        hits = float((predicted == mask).sum())
        total = float(predicted.numel())
        return hits / total

def mIoU(pred_mask, mask, smooth=1e-10, n_classes=4):
    """
    Mean intersection-over-union across classes.

    Args:
        pred_mask: per-class scores with the class axis at dim 1.
        mask: ground-truth class ids, one per position of ``pred_mask`` once
            the class axis is reduced.
        smooth: small constant added to numerator and denominator.
        n_classes: class ids ``0 .. n_classes - 1`` are evaluated.

    Returns:
        The mean IoU over the classes that occur in ``mask``; classes absent
        from the ground truth are skipped (recorded as NaN and ignored by
        ``np.nanmean``).
    """
    with torch.no_grad():
        # Softmax is monotonic, so the arg-max of the raw scores is the
        # predicted class.
        pred_labels = torch.argmax(pred_mask, dim=1).contiguous().view(-1)
        true_labels = mask.contiguous().view(-1)

        iou_per_class = []
        for clas in range(n_classes):
            pred_is_class = pred_labels == clas
            true_is_class = true_labels == clas

            if not true_is_class.any():
                # Class not present in the ground truth: exclude from the mean.
                iou_per_class.append(np.nan)
                continue

            # Count intersection and union explicitly. Adding two bool
            # tensors is already a logical OR, so "sum(a + b) - intersect"
            # would yield the symmetric difference, not the union.
            intersect = (pred_is_class & true_is_class).sum().item()
            union = (pred_is_class | true_is_class).sum().item()
            iou_per_class.append((intersect + smooth) / (union + smooth))
        return np.nanmean(iou_per_class)

# Dataset preparation

In [7]:
class UDIADS(Dataset):
    """
    Dataset for the training phase.

    Each item is one page: the image is min-max scaled to [-1, 1], the
    colour-coded ground truth is converted to class ids, and ``transform``
    (called as ``transform(img, mask)``) turns the pair into lists of tiles.

    Args:
        imagePaths: list of image file paths.
        maskPaths: list of ground-truth file paths, index-aligned with
            ``imagePaths``.
        transform: callable returning ``(list_of_images, list_of_masks)``;
            if falsy, the whole page is returned as a single tile.
    """
    def __init__(
            self,
            imagePaths,
            maskPaths,
            transform
    ):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.trans = transform

    @staticmethod
    def _read_rgb(path):
        """Read an image file as an RGB array; raise if OpenCV cannot load it."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _normalize(img):
        """Min-max scale ``img`` to [-1, 1]; a constant image maps to -1 everywhere."""
        lo = img.min()
        span = img.max() - lo
        if span == 0:
            # Avoid 0/0 (NaN) on a constant image.
            return np.full(img.shape, -1.0)
        return 2 * ((img - lo) / span) - 1

    def __getitem__(self, idx):
        """
        Returns:
            ``(img, mask, (h, w))`` where ``img`` is a stacked tensor of tiles
            permuted to channels-first, ``mask`` is the matching stack of
            long class-id tiles, and ``(h, w)`` is the original page size.
        """
        img = self._read_rgb(self.imagePaths[idx])
        h, w = img.shape[0], img.shape[1]
        img = self._normalize(img)
        mask = convertRGB_to_label(self._read_rgb(self.maskPaths[idx]))

        if self.trans:
            img, mask = self.trans(img, mask)
        else:
            # No tiling requested: treat the whole page as a single tile so
            # the stacking/permute below still sees a list of (H, W, C) arrays.
            img, mask = [img], [mask]

        img = torch.stack([torch.from_numpy(i) for i in img])
        mask = torch.stack([torch.from_numpy(i).long() for i in mask])
        # (tiles, H, W, C) -> (tiles, C, H, W)
        img = img.permute(0, 3, 1, 2)

        return img, mask, (h, w)

    def __len__(self):
        return len(self.imagePaths)

In [8]:
class UDIADS_Validation(Dataset):
    """
    Dataset for simple evaluation and testing.

    Each item is one whole page (no tiling): the image is min-max scaled to
    [-1, 1] and converted with ``transforms.ToTensor``; the colour-coded
    ground truth is converted to a long tensor of class ids.

    Args:
        imagePaths: list of image file paths.
        maskPaths: list of ground-truth file paths, index-aligned with
            ``imagePaths``.
    """
    def __init__(
            self,
            imagePaths,
            maskPaths,

    ):
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths

    @staticmethod
    def _read_rgb(path):
        """Read an image file as an RGB array; raise if OpenCV cannot load it."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {path}")
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _normalize(img):
        """Min-max scale ``img`` to [-1, 1]; a constant image maps to -1 everywhere."""
        lo = img.min()
        span = img.max() - lo
        if span == 0:
            # Avoid 0/0 (NaN) on a constant image.
            return np.full(img.shape, -1.0)
        return 2 * ((img - lo) / span) - 1

    def __getitem__(self, idx):
        """
        Returns:
            ``(img, mask, (h, w))``: the image tensor produced by
            ``transforms.ToTensor``, the long class-id mask, and the original
            page size.
        """
        read_img = self._read_rgb(self.imagePaths[idx])
        h, w = read_img.shape[0], read_img.shape[1]
        img = self._normalize(read_img)
        mask = convertRGB_to_label(self._read_rgb(self.maskPaths[idx]))

        # To tensor
        img = transforms.ToTensor()(img)
        mask = torch.from_numpy(mask).long()

        return img, mask, (h, w)

    def __len__(self):
        return len(self.imagePaths)

# Model based on L-U-Net

In [9]:
def dil_block(in_c, out_c):
    """
    Build four consecutive 3x3 Conv -> BatchNorm -> ReLU stages.

    The first two convolutions use dilation 1 and the last two dilation 2;
    padding always equals the dilation, so height and width are preserved.
    Only the first convolution takes ``in_c`` channels; everything after it
    works on ``out_c`` channels.
    """
    stages = []
    channels = in_c
    for rate in (1, 1, 2, 2):
        stages += [
            nn.Conv2d(channels, out_c, kernel_size=3, stride=1, padding=rate, dilation=rate),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        channels = out_c
    return nn.Sequential(*stages)


def encoding_block(in_c, out_c):
    """Build a single 3x3 Conv -> BatchNorm -> ReLU stage (padding 1, stride 1)."""
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

def encoding_block1(in_c, out_c):
    """
    Build two consecutive 3x3 Conv -> BatchNorm -> ReLU stages (padding 1,
    stride 1). The first convolution maps ``in_c`` to ``out_c`` channels, the
    second keeps ``out_c``.
    """
    layers = []
    for channels in (in_c, out_c):
        layers.extend((
            nn.Conv2d(channels, out_c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ))
    return nn.Sequential(*layers)

class unet_model(nn.Module):
    """
    Small U-Net with dilated-convolution encoder stages.

    Encoder: four ``dil_block`` stages, each followed by 2x2 max-pooling.
    Bottleneck: ``encoding_block1``.
    Decoder: four stride-2 transposed convolutions; each upsampled map is
    concatenated with the matching encoder output and fused by a conv block.
    A final 1x1 convolution produces ``out_channels`` logits per pixel.

    Args:
        out_channels: number of channels of the output logits.
        features: ``(width, merged)``; ``width`` is the channel count of every
            stage and ``merged`` the channel count after a skip concatenation.
            Because two ``width``-channel maps are concatenated, ``merged``
            must equal ``2 * width``.

    Raises:
        ValueError: if ``features[1] != 2 * features[0]``.
    """
    def __init__(self, out_channels=4, features=(16, 32)):
        super(unet_model, self).__init__()

        if features[1] != 2 * features[0]:
            raise ValueError(
                "features[1] must equal 2 * features[0] (skip connections "
                f"concatenate two features[0]-channel maps), got {list(features)}"
            )

        # Attribute names and creation order are kept stable so that existing
        # state_dict checkpoints and seeded initialisation are unaffected.
        self.dil1 = dil_block(3, features[0])
        self.pool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.dil2 = dil_block(features[0], features[0])
        self.pool2 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.dil3 = dil_block(features[0], features[0])
        self.pool3 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.dil4 = dil_block(features[0], features[0])
        self.pool4 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))

        self.bott = encoding_block1(features[0], features[0])

        self.tconv1 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv1 = encoding_block(features[1], features[0])

        self.tconv2 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv2 = encoding_block(features[1], features[0])

        self.tconv3 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv3 = encoding_block(features[1], features[0])

        self.tconv4 = nn.ConvTranspose2d(features[0], features[0], kernel_size=2, stride=2)
        self.conv4 = encoding_block1(features[1], features[0])

        self.final_layer = nn.Conv2d(features[0], out_channels, kernel_size=1)

    @staticmethod
    def _match_size(upsampled, skip):
        """Zero-pad ``upsampled`` on the bottom/right to the height/width of ``skip``.

        Max-pooling floors odd sizes, so after the 2x upsampling a map can be
        one row/column short of its skip connection. When the sizes already
        agree the tensor is returned unchanged.
        """
        pad_h = skip.shape[-2] - upsampled.shape[-2]
        pad_w = skip.shape[-1] - upsampled.shape[-1]
        if pad_h == 0 and pad_w == 0:
            return upsampled
        # F.pad order for the last two dims: (left, right, top, bottom).
        return F.pad(upsampled, (0, pad_w, 0, pad_h))

    def forward(self, x):
        """Map a 3-channel image batch to per-pixel logits of the same height and width."""
        # Encoder
        dil_1 = self.dil1(x)
        dil_2 = self.dil2(self.pool1(dil_1))
        dil_3 = self.dil3(self.pool2(dil_2))
        dil_4 = self.dil4(self.pool3(dil_3))

        bott = self.bott(self.pool4(dil_4))

        # Decoder: upsample, align with the skip connection, concatenate, fuse.
        tconv_1 = self._match_size(self.tconv1(bott), dil_4)
        conv_1 = self.conv1(torch.cat((tconv_1, dil_4), dim=1))

        tconv_2 = self._match_size(self.tconv2(conv_1), dil_3)
        conv_2 = self.conv2(torch.cat((tconv_2, dil_3), dim=1))

        tconv_3 = self._match_size(self.tconv3(conv_2), dil_2)
        conv_3 = self.conv3(torch.cat((tconv_3, dil_2), dim=1))

        tconv_4 = self._match_size(self.tconv4(conv_3), dil_1)
        conv_4 = self.conv4(torch.cat((tconv_4, dil_1), dim=1))

        return self.final_layer(conv_4)

class finetuning_unet_model(nn.Module):
    """
    Wrap an existing U-Net and replace its ``final_layer`` with a fresh 1x1
    convolution that outputs ``out_channels`` classes.

    Args:
        unet_model: module exposing a ``final_layer`` ``nn.Conv2d``; it is
            modified in place (its head is replaced).
        out_channels: number of classes of the new head.
        features: kept for backward compatibility and ignored. The new head's
            input width is read from the wrapped model's current head, so it
            always matches the backbone.
    """
    def __init__(self, unet_model, out_channels=10, features=(16, 32)):
        super(finetuning_unet_model, self).__init__()
        self.unet_model = unet_model
        old_head = self.unet_model.final_layer
        new_head = nn.Conv2d(old_head.in_channels, out_channels, kernel_size=1)
        # Keep the new head on the same device/dtype as the layer it replaces,
        # so a backbone already moved to GPU stays usable without another .to().
        self.unet_model.final_layer = new_head.to(
            device=old_head.weight.device, dtype=old_head.weight.dtype
        )

    def forward(self, x):
        return self.unet_model(x)


In [10]:
# Collection to train on and the directories holding its images / ground truth.
collection_name = 'Latin16746FS'
img_DIR = f'/content/icdar_2024_SAM/U-DIADS-Bib-FS/{collection_name}/img-{collection_name}/'
mask_DIR = f'/content/icdar_2024_SAM/U-DIADS-Bib-FS/{collection_name}/pixel-level-gt-{collection_name}/'

# Training and validation sub-directories
x_train_dir = os.path.join(img_DIR, 'training')
y_train_dir = os.path.join(mask_DIR, 'training')

x_valid_dir = os.path.join(img_DIR, 'validation')
y_valid_dir = os.path.join(mask_DIR, 'validation')

# Sorted so that image i and mask i refer to the same page
# (NOTE(review): assumes image and mask files share their base names — confirm).
train_img_paths = sorted(glob.glob(os.path.join(x_train_dir, "*.jpg")))
train_mask_paths = sorted(glob.glob(os.path.join(y_train_dir, "*.png")))
val_img_paths = sorted(glob.glob(os.path.join(x_valid_dir, "*.jpg")))
val_mask_paths = sorted(glob.glob(os.path.join(y_valid_dir, "*.png")))


def _require_paired(split, img_paths, mask_paths, img_dir):
    """Fail early if a split is empty or its images and masks cannot be paired by index."""
    if not img_paths:
        raise FileNotFoundError(f"No {split} images (*.jpg) found in {img_dir}")
    if len(img_paths) != len(mask_paths):
        raise ValueError(
            f"{split}: {len(img_paths)} images but {len(mask_paths)} masks"
        )


_require_paired('training', train_img_paths, train_mask_paths, x_train_dir)
_require_paired('validation', val_img_paths, val_mask_paths, x_valid_dir)

# Report the files that will actually be used (the globbed lists), not the raw directory listing.
print('the number of image/label in the train: ', len(train_img_paths))
print('the number of image/label in the validation: ', len(val_img_paths))

the number of image/label in the train:  3
the number of image/label in the validation:  10


# Training

In [11]:
# Compute "balanced" class weights from the ground-truth pixels of every training page.
tmp_dataset = UDIADS_Validation(train_img_paths, train_mask_paths)
# Iterate over the whole dataset (not a fixed count) and keep the labels in a
# NumPy array rather than a Python list with one element per pixel.
list_gt = np.concatenate(
    [tmp_dataset[i][1].flatten().numpy() for i in range(len(tmp_dataset))]
)
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(list_gt), y=list_gt)
class_weights

array([ 0.18784916, 35.34627426,  6.26214292,  2.4999391 , 50.95926274,
       14.50002676])

In [15]:
# Checkpoint directory
save_ckpt = f'ckpt_finetune_{collection_name}_aug'
os.makedirs(save_ckpt, exist_ok=True)


def _seed_worker(worker_id):
    """Seed NumPy inside a DataLoader worker from that worker's torch seed.

    SlidingWindowCrop draws its random crops with np.random. Worker processes
    otherwise start from the same inherited NumPy state, so the "random"
    crops would repeat across workers and epochs.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def _save_checkpoint(filename, epoch, mean_val_loss):
    """Write model/optimizer state plus the current epoch and mean validation loss to ``save_ckpt``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': mean_val_loss,
        }, os.path.join(save_ckpt, filename))


# Data loaders
# Never request more workers than the machine has CPUs.
num_workers = min(10, os.cpu_count() or 1)
slidingwindow = SlidingWindowCrop((512, 512))
train_dataset = UDIADS(train_img_paths, train_mask_paths, slidingwindow)
valid_dataset = UDIADS_Validation(val_img_paths, val_mask_paths)
train_loader = DataLoader(train_dataset, batch_size=1, sampler=get_sampler(train_dataset),
                          num_workers=num_workers, worker_init_fn=_seed_worker)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=num_workers)

# "cpu" is the valid fallback device name ("gpu" is rejected by torch).
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): no checkpoint is loaded into this model in the visible cells,
# so despite the name it starts from random initialisation — confirm intent.
pretrained_model = unet_model().to(device)

# Define model: swap the head for the 6 layout classes
model = finetuning_unet_model(pretrained_model, out_channels=6)
model = model.to(device)
print('number of trainable parameters: ',sum(p.numel() for p in model.parameters() if p.requires_grad))

# Optimization
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-05)
scheduler = StepLR(optimizer, step_size=50, gamma=0.1)

# Per-class loss weights.
# NOTE(review): these differ from the balanced weights computed in the
# previous cell — confirm how they were derived.
weights = torch.tensor([0.4, 11, 3, 1.7, 4.7, 3]).to(device)
criterion = nn.CrossEntropyLoss(weights)

# Training history (one entry per epoch)
train_loss = []
val_loss = []
train_f1 = []
val_f1 = []
train_IoU = []
val_IoU = []
best_loss = float('inf')  # np.Inf was removed in NumPy 2.0
best_f1_score = 0.0
epochs = 200
fit_time = time.time()
for epoch in range(epochs):
    print('Epoch: [{}/{}]'.format(epoch+1, epochs))

    trainloss = 0
    train_f1_score = 0
    trainIoU = 0

    since = time.time()

    # ---- Training ----
    model.train()
    for img, label, _ in train_loader:
        optimizer.zero_grad()
        # batch_size is 1 and every item is already a stack of tiles: drop
        # the DataLoader's batch axis so the tiles form the batch.
        img = img.float().squeeze(dim=0).to(device)
        label = label.squeeze(dim=0).to(device)

        output = model(img)
        preds = torch.argmax(output, 1)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        trainloss += loss.item()
        train_f1_score += f1_score_metric(preds, label)
        trainIoU += mIoU(output, label, n_classes=6)
    scheduler.step()

    print('Epoch:', epoch+1, 'LR:', scheduler.get_last_lr()[0])

    # ---- Validation ----
    model.eval()
    valloss = 0
    val_f1_score = 0
    valIoU = 0

    with torch.no_grad():
        for img_val, label_val, _ in valid_loader:
            img_val = img_val.float().to(device)
            label_val = label_val.to(device)
            output_val = model(img_val)
            preds_val = torch.argmax(output_val, 1)
            loss_val = criterion(output_val, label_val)

            valloss += loss_val.item()
            val_f1_score += f1_score_metric(preds_val, label_val)
            valIoU += mIoU(output_val, label_val, n_classes=6)

    # ---- Epoch averages ----
    n_train = len(train_loader)
    n_val = len(valid_loader)
    mean_train_loss = trainloss / n_train
    mean_train_f1 = train_f1_score / n_train
    mean_train_iou = trainIoU / n_train
    mean_val_loss = valloss / n_val
    mean_val_f1 = val_f1_score / n_val
    mean_val_iou = valIoU / n_val

    train_loss.append(mean_train_loss)
    train_f1.append(mean_train_f1)
    train_IoU.append(mean_train_iou)
    val_loss.append(mean_val_loss)
    val_f1.append(mean_val_f1)
    val_IoU.append(mean_val_iou)

    # Save the model when the validation loss improves
    if best_loss > valloss:
        best_loss = valloss
        _save_checkpoint('best_val_loss_512x512.pth', epoch + 1, mean_val_loss)
        print('Loss_Model saved!')

    # Save the model when the validation F1 score improves
    if best_f1_score < val_f1_score:
        best_f1_score = val_f1_score
        _save_checkpoint('best_val_f1score_512x512.pth', epoch + 1, mean_val_loss)
        print('F1_Model saved!')

    print("Train Loss: {}".format(mean_train_loss),
          "Val Loss: {}".format(mean_val_loss),
          "Train mIoU:{}".format(mean_train_iou),
          "Val mIoU: {}".format(mean_val_iou),
          "Train F1:{}".format(mean_train_f1),
          "Val F1:{}".format(mean_val_f1),
          "Time: {:.2f}m".format((time.time()-since)/60))
print('Total time: {:.2f} m' .format((time.time()- fit_time)/60))