In [3]:
import os
import cv2
import time
import glob
import random
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models

from tqdm.notebook import tqdm

from sklearn.utils.class_weight import compute_class_weight
from torch.optim.lr_scheduler import StepLR

from focal_loss import FocalLoss
from dataloader_ICDAR import ICDAR2024, ICDAR2024_v2,SlidingWindowCrop
from utils import get_sampler, pixel_accuracy, mIoU, f1_score_metric
from model import finetuning_unet_model, unet_model

# Expose only the GPU with index 4 to this process.
# NOTE(review): the index is machine-specific; adjust it (or set the variable
# outside the notebook) when running on a different host.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

In [None]:
# Names of the collections this notebook can be trained on.
collection_list = ['Latin14396FS', 'Latin16746FS', 'Latin2FS', 'Syr341FS']
collection_name = 'Latin14396FS'  # Choose the collection to train

# Fail fast on a typo: the name is interpolated into every data and checkpoint
# path, so an unknown value would otherwise only surface later as missing files.
if collection_name not in collection_list:
    raise ValueError(
        f"Unknown collection {collection_name!r}; expected one of {collection_list}"
    )

## Set paths and load data

In [None]:
# Directory that receives the checkpoints written by the training cell.
os.makedirs(f'ckpt_finetune_{collection_name}', exist_ok=True)

# Images and pixel-level ground truth live in parallel trees, each with
# `training` and `validation` sub-directories.
img_DIR = f'../U-DIADS-Bib-FS/{collection_name}/img-{collection_name}/'
mask_DIR = f'../U-DIADS-Bib-FS/{collection_name}/pixel-level-gt-{collection_name}/'

# Load training and validation data
x_train_dir = os.path.join(img_DIR, 'training')
y_train_dir = os.path.join(mask_DIR, 'training')

x_valid_dir = os.path.join(img_DIR, 'validation')
y_valid_dir = os.path.join(mask_DIR, 'validation')

# Sorting is what pairs image i with mask i.
# NOTE(review): this assumes images and masks share base filenames — confirm.
train_img_paths = sorted(glob.glob(os.path.join(x_train_dir, "*.jpg")))
train_mask_paths = sorted(glob.glob(os.path.join(y_train_dir, "*.png")))
val_img_paths = sorted(glob.glob(os.path.join(x_valid_dir, "*.jpg")))
val_mask_paths = sorted(glob.glob(os.path.join(y_valid_dir, "*.png")))

# Catch a wrong path or an incomplete download here instead of deep inside the
# dataset / training code, where images and masks would silently misalign.
for split_name, split_imgs, split_masks in (
    ('training', train_img_paths, train_mask_paths),
    ('validation', val_img_paths, val_mask_paths),
):
    if not split_imgs:
        raise FileNotFoundError(
            f"No *.jpg images found for the {split_name} split under {img_DIR!r}"
        )
    if len(split_imgs) != len(split_masks):
        raise ValueError(
            f"{split_name} split has {len(split_imgs)} images but "
            f"{len(split_masks)} masks"
        )

# Report the number of files actually matched by the globs above (a plain
# directory listing would also count unrelated files).
print('the number of image/label in the train: ', len(train_img_paths))
print('the number of image/label in the validation: ', len(val_img_paths))

## Compute the class weights for the cross-entropy loss (the loss weights are the square of the `compute_class_weight` output)

In [None]:
# Estimate "balanced" class weights from the ground-truth masks of the first
# few training samples.
# NOTE(review): the training cell reassigns `weights` with hard-coded values,
# so the result computed here is informational only.
xx_dataset = ICDAR2024(train_img_paths, train_mask_paths)

# Use at most 3 samples, but never index past the end of a smaller dataset.
num_weight_samples = min(3, len(xx_dataset))

mask_chunks = []
for i in range(num_weight_samples):
    _, mask, _ = xx_dataset[i]
    # Keep the labels as a flat NumPy array instead of expanding every pixel
    # into a Python int (much faster and far less memory for full-page masks).
    mask_chunks.append(np.asarray(mask).reshape(-1))
list_gt = np.concatenate(mask_chunks)

weights = compute_class_weight(class_weight="balanced", classes=np.unique(list_gt), y=list_gt)

## Train the model and save the checkpoints with the best validation loss and the best validation F1 score

In [None]:
# Data loader
# Training samples go through a 256x256 sliding-window crop transform; the
# validation set is built without a transform.
transform = SlidingWindowCrop((256, 256))
train_dataset = ICDAR2024_v2(train_img_paths, train_mask_paths, transform)
valid_dataset = ICDAR2024(val_img_paths, val_mask_paths)
# batch_size=1: the training loop squeezes the loader's batch dimension away.
# NOTE(review): presumably each training item is already a stack of crops — confirm.
train_loader = DataLoader(train_dataset, batch_size=1, sampler=get_sampler(train_dataset), num_workers=10)
valid_loader = DataLoader(valid_dataset, batch_size=1, shuffle=False, num_workers=10)

# Fall back to the CPU when CUDA is unavailable ("gpu" is not a valid torch
# device string and would raise as soon as `.to(device)` is called).
device = "cuda" if torch.cuda.is_available() else "cpu"
pretrained_model = unet_model().to(device)

# Fine-tuning wrapper around the base U-Net with a 6-channel output head.
model = finetuning_unet_model(pretrained_model, out_channels=6)
model = model.to(device)
print('number of trainable parameters: ',sum(p.numel() for p in model.parameters() if p.requires_grad))

# Optimization
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-05)
# Multiply the learning rate by 0.1 every 50 scheduler steps (stepped once per epoch).
scheduler = StepLR(optimizer, step_size=50, gamma=0.1)

# Hard-coded per-class loss weights, one entry per output channel.
# NOTE(review): this overrides any `weights` computed in an earlier cell.
weights = torch.tensor([0.4, 11, 3, 1.7, 7, 3]).to(device)
criterion = nn.CrossEntropyLoss(weight=weights)
# Training
# Per-epoch history of the mean loss / F1 / mIoU on both splits.
train_loss = []
val_loss = []
train_f1 = []
val_f1 = []
train_IoU = []
val_IoU = []
# `np.Inf` was removed in NumPy 2.0; the built-in float works on every version.
best_loss = float('inf')
best_f1_score = 0.0
epochs = 200


def _save_checkpoint(filename, epoch_number, mean_val_loss):
    """Write the current model/optimizer state to the collection's checkpoint directory.

    Args:
        filename: File name inside `ckpt_finetune_<collection_name>/`.
        epoch_number: 1-based index of the epoch that produced this state.
        mean_val_loss: Mean validation loss of that epoch, stored under 'loss'.
    """
    torch.save({
        'epoch': epoch_number,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Store the loss *value*; pickling the criterion module here would bloat
        # the file and break loading with `torch.load(..., weights_only=True)`.
        'loss': mean_val_loss,
        }, f'ckpt_finetune_{collection_name}/{filename}')


fit_time = time.time()
for epoch in range(epochs):
    print('Epoch: [{}/{}]'.format(epoch+1, epochs))

    # Running sums over the batches of this epoch.
    trainloss = 0
    train_f1_score = 0
    trainIoU = 0

    since = time.time()

    # ---- Training pass ----
    model.train()
    for img, label, _ in train_loader:
        optimizer.zero_grad()
        # The loader uses batch_size=1: drop that leading dimension before the
        # tensors are moved to the device and fed to the model.
        img = img.float().squeeze(dim=0).to(device)
        label = label.squeeze(dim=0).to(device)

        output = model(img)
        preds = torch.argmax(output, 1)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

        trainloss += loss.item()
        train_f1_score += f1_score_metric(preds, label)
        trainIoU += mIoU(output, label)
    # The learning-rate schedule advances once per epoch.
    scheduler.step()

    print('Epoch:', epoch+1, 'LR:', scheduler.get_last_lr()[0])

    # ---- Validation pass ----
    model.eval()
    valloss = 0
    val_f1_score = 0
    valIoU = 0

    with torch.no_grad():
        for img_val, label_val, _ in valid_loader:
            img_val = img_val.float().to(device)
            label_val = label_val.to(device)

            output_val = model(img_val)
            preds_val = torch.argmax(output_val, 1)
            loss_val = criterion(output_val, label_val)

            valloss += loss_val.item()
            val_f1_score += f1_score_metric(preds_val, label_val)
            valIoU += mIoU(output_val, label_val)

    # ---- Bookkeeping: convert the running sums into per-batch means ----
    n_train_batches = len(train_loader)
    n_valid_batches = len(valid_loader)
    mean_train_loss = trainloss / n_train_batches
    mean_val_loss = valloss / n_valid_batches
    mean_train_f1 = train_f1_score / n_train_batches
    mean_val_f1 = val_f1_score / n_valid_batches

    train_loss.append(mean_train_loss)
    train_f1.append(mean_train_f1)
    train_IoU.append(trainIoU / n_train_batches)
    val_loss.append(mean_val_loss)
    val_f1.append(mean_val_f1)
    val_IoU.append(valIoU / n_valid_batches)

    # Save the model whenever the validation loss improves. Comparing the sums
    # is equivalent to comparing the means: the batch count is constant.
    if best_loss > valloss:
        best_loss = valloss
        _save_checkpoint('best_val_loss_256x256.pth', epoch + 1, mean_val_loss)
        print('Loss_Model saved!')

    # Save the model whenever the validation F1 score improves.
    if best_f1_score < val_f1_score:
        best_f1_score = val_f1_score
        _save_checkpoint('best_val_f1score_256x256.pth', epoch + 1, mean_val_loss)
        print('F1_Model saved!')

    print("Train Loss: {}".format(mean_train_loss),
          "Val Loss: {}".format(mean_val_loss),
          "Train F1:{}".format(mean_train_f1),
          "Val F1:{}".format(mean_val_f1),
          "Time: {:.2f}m".format((time.time()-since)/60))
print('Total time: {:.2f} m' .format((time.time()- fit_time)/60))