<a href="https://colab.research.google.com/github/mveiyo/mveiyo/blob/main/Document_Binarization_Collection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
!git clone https://github.com/RichSu95/Document_Binarization_Collection

Cloning into 'Document_Binarization_Collection'...
remote: Enumerating objects: 805, done.[K
remote: Counting objects: 100% (35/35), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 805 (delta 14), reused 35 (delta 14), pack-reused 770[K
Receiving objects: 100% (805/805), 434.26 MiB | 38.16 MiB/s, done.
Resolving deltas: 100% (88/88), done.
Updating files: 100% (726/726), done.


In [None]:

#Train
import os
from time import time

import torch
import torch.utils.data as data

from data import ImageFolder
from framework import MyFrame
from loss import dice_bce_loss

# from networks.unet import UNet
# from networks.dunet import DUNet
from networks.dplinknet import LinkNet34, DLinkNet34, DPLinkNet34

# ---------------------------------------------------------------------------
# Training script for the DP-LinkNet family of binarization networks.
#
# Reads the "<id>_img.png" patches found in `train_root`, trains the selected
# network with `dice_bce_loss`, keeps the weights of the best epoch and writes
# a per-epoch log file.
# ---------------------------------------------------------------------------
SHAPE = (256, 256)
DATA_NAME = "DIBCO"  # BickleyDiary, DIBCO, PLM
DEEP_NETWORK_NAME = "DPLinkNet34"
print("Now training dataset: {}, using network model: {}".format(DATA_NAME, DEEP_NETWORK_NAME))

train_root = "Document_Binarization_Collection/DP-LinkNet/data/train/"
# Keep only the image patches; slicing off the last 8 characters removes the
# "_img.png" suffix so that only the patch id is handed to ImageFolder.
imagelist = [name for name in os.listdir(train_root) if "img" in name]
trainlist = [name[:-8] for name in imagelist]
log_name = DATA_NAME.lower() + "_" + DEEP_NETWORK_NAME.lower()

BATCHSIZE_PER_CARD = 32

# Map the configured name to the network class instead of an if/elif chain.
networks = {
    "DPLinkNet34": DPLinkNet34,
    "DLinkNet34": DLinkNet34,
    "LinkNet34": LinkNet34,
}
if DEEP_NETWORK_NAME not in networks:
    # A configuration error must not terminate with a success exit status.
    raise SystemExit("Deep network not found, please have a check!")
solver = MyFrame(networks[DEEP_NETWORK_NAME], dice_bce_loss, 2e-4)

# device_count() is 0 without a GPU; a batch size of 0 is rejected by DataLoader.
batchsize = max(torch.cuda.device_count(), 1) * BATCHSIZE_PER_CARD

dataset = ImageFolder(trainlist, train_root)
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batchsize,
    shuffle=True,
    num_workers=4)

weights_path = "Document_Binarization_Collection/DP-LinkNet/weights/" + log_name + ".th"
log_dir = "Document_Binarization_Collection/DP-LinkNet/logs"
# Make sure both output locations exist before the first write.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(os.path.dirname(weights_path), exist_ok=True)

no_optim = 0  # number of consecutive epochs without a new best loss
total_epoch = 500
train_epoch_best_loss = 100.

tic = time()
# The context manager closes the log even if training is interrupted by an error.
with open(os.path.join(log_dir, log_name + ".log"), "w") as mylog:
    for epoch in range(1, total_epoch + 1):
        train_epoch_loss = 0.0
        trained_batches = 0
        last_error = None
        for img, mask in data_loader:
            # Best effort: a failing batch is reported and skipped, not fatal.
            try:
                solver.set_input(img, mask)
                train_loss = solver.optimize()
            except Exception as e:
                last_error = e
                print(f"An error occurred: {str(e)}")
                print(f"An error occurred: {str(e)}", file=mylog)
                continue
            train_epoch_loss += train_loss
            trained_batches += 1

        if trained_batches == 0:
            # Nothing was learned this epoch; continuing would only save and
            # compare meaningless losses.
            raise RuntimeError(
                "No batch could be trained in epoch %d "
                "(empty dataset or every batch failed)" % epoch
            ) from last_error

        # Average over the batches that actually contributed to the sum.
        train_epoch_loss /= trained_batches
        print("********", file=mylog)
        print("epoch:", epoch, "    time:", int(time() - tic), file=mylog)
        print("train_loss:", train_epoch_loss, file=mylog)
        print("SHAPE:", SHAPE, file=mylog)
        print("********")
        print("epoch:", epoch, "    time:", int(time() - tic))
        print("train_loss:", train_epoch_loss)
        print("SHAPE:", SHAPE)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            # New best epoch: remember the loss and checkpoint the weights.
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save(weights_path)

        if no_optim > 20:
            print("early stop at %d epoch" % epoch, file=mylog)
            print("early stop at %d epoch" % epoch)
            break

        if no_optim > 10:
            # Stop once the learning rate has become negligibly small.
            if solver.old_lr < 1e-7:
                break
            # Restart from the best checkpoint with an adjusted learning rate
            # (presumably reduced by the factor 5.0 - confirm in MyFrame.update_lr).
            solver.load(weights_path)
            solver.update_lr(5.0, factor=True, mylog=mylog)
        mylog.flush()

    print("Finish!", file=mylog)
    print("Finish!")

In [None]:
# Prepare dataset for training

import os

import cv2

from utils import get_patches

# ---------------------------------------------------------------------------
# Cut every document image and its ground-truth mask into TILE_SIZE x TILE_SIZE
# patches and store them as "<n>_img.png" / "<n>_mask.png" in the train folder.
# ---------------------------------------------------------------------------
TILE_SIZE = 256
print("Image patch size:", TILE_SIZE, "x", TILE_SIZE)

data_root_image = "Document_Binarization_Collection/DP-LinkNet/dataset"
data_root_mask = "Document_Binarization_Collection/DP-LinkNet/data_GT"
data_save = "Document_Binarization_Collection/DP-LinkNet/data"
img_list = sorted(os.listdir(data_root_image))

data_train_dir = os.path.join(data_save, "train")
os.makedirs(data_train_dir, exist_ok=True)

# Running index used to name the patch files; patches are written as soon as
# they are extracted so the whole dataset never has to be held in memory.
patch_count = 0

for img_name in img_list:
    img_path = os.path.join(data_root_image, img_name)
    if os.path.isdir(img_path):
        continue

    print("Now processing image:", img_path)
    fname, _ = os.path.splitext(img_name)
    # The ground truth is looked up by base name and is always expected as PNG.
    msk_path = os.path.join(data_root_mask, fname + ".png")
    img = cv2.imread(img_path)
    msk = cv2.imread(msk_path)
    # cv2.imread signals an unreadable or missing file by returning None.
    if img is None or msk is None:
        print("Skipping, image or ground truth could not be read:", img_path, msk_path)
        continue

    # extract the patches from the original document image and the corresponding ground truth
    img_patch_locations, img_patches = get_patches(img, TILE_SIZE, TILE_SIZE)
    msk_patch_locations, msk_patches = get_patches(msk, TILE_SIZE, TILE_SIZE)
    if len(img_patches) != len(msk_patches):
        raise ValueError(
            "Image and ground truth yield a different number of patches "
            "(%d vs %d): %s" % (len(img_patches), len(msk_patches), img_path)
        )

    print("Patches extracted:", len(img_patches))
    for img_patch, msk_patch in zip(img_patches, msk_patches):
        cv2.imwrite(os.path.join(data_train_dir, str(patch_count) + "_img.png"), img_patch)
        cv2.imwrite(os.path.join(data_train_dir, str(patch_count) + "_mask.png"), msk_patch)
        patch_count += 1

print("Total number of train patches:", patch_count)

print("Done")


In [3]:
from google.colab.patches import cv2_imshow
import cv2

# Read one sample image from the dataset folder and show it inline in Colab
# (point the path at an image file that actually exists).
img = cv2.imread('Document_Binarization_Collection/DP-LinkNet/dataset/3.png')

# cv2.imread yields None when the file cannot be read, so check before displaying.
if img is None:
    print("Error: Image not loaded successfully.")
else:
    cv2_imshow(img)

Error: Image not loaded successfully.


In [4]:
#!pip install -r Document_Binarization_Collection/DP-LinkNet/requirements.txt
if img is None:
    # The image was not loaded correctly, so there is nothing to inspect.
    print("Error: Image not loaded successfully.")
else:
    # The image is available: record its shape for any further processing.
    img_shape = img.shape

Error: Image not loaded successfully.


In [5]:
# Create the working folders (dataset, data, data_GT, ...) used by the scripts.
# Change the paths in data_prepare.py accordingly.
# Upload the ground-truth images and the original images afterwards.
import os

# Names of the working folders expected below the project root
folders = ['dataset', 'data', 'data_GT', 'train', 'weights', 'test_set']

# Path to the project root directory in Colab
root_dir = 'Document_Binarization_Collection/DP-LinkNet'

# Create every folder; exist_ok=True makes re-running the cell harmless and
# avoids the check-then-create race of a separate os.path.exists() test.
for folder in folders:
    os.makedirs(os.path.join(root_dir, folder), exist_ok=True)

# Report which folders are now available
print("Folders created:", folders)


Folders created: ['dataset', 'data', 'data_GT', 'train']


In [10]:
!python Document_Binarization_Collection/DP-LinkNet/data_prepare.py

Image patch size: 256 x 256
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0005.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0006.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0007.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0008.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0009.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0010.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0011.png
Patches extracted: 48
Now processing image: Document_Binarization_Collection/DP-LinkNet/dataset/IMG-20230804-WA0012.png
Patches extracted: 48
Now processi

In [None]:
!python Document_Binarization_Collection/DP-LinkNet/train.py

In [15]:
!python Document_Binarization_Collection/DP-LinkNet/test.py

Image input directory: Document_Binarization_Collection/DP-LinkNet/test_set/
Image output directory: Document_Binarization_Collection/DP-LinkNet/test_set/Binarized
Now loading the model weights: Document_Binarization_Collection/DP-LinkNet/weights/dibco_dplinknet34.th
Now processing image: IMG-20230804-WA0005.jpg
Now processing image: IMG-20230804-WA0006.jpg
Now processing image: IMG-20230804-WA0007.jpg
Now processing image: IMG-20230804-WA0008.jpg
Now processing image: IMG-20230804-WA0009.jpg
Now processing image: IMG-20230804-WA0010.jpg
Now processing image: IMG-20230804-WA0011.jpg
Now processing image: IMG-20230804-WA0012.jpg
Now processing image: IMG-20230804-WA0013.jpg
Now processing image: IMG-20230804-WA0014.jpg
Now processing image: WhatsApp Prent 2023-08-04 om 15.17.23.jpg
Total running time: 22.488048 sec.
Finished!


In [None]:
#test file
import os
from time import time

import cv2
import numpy as np
import torch
from torch.autograd import Variable as V

# from networks.unet import UNet
# from networks.dunet import DUNet
from networks.dplinknet import LinkNet34, DLinkNet34, DPLinkNet34
from utils import get_patches, stitch_together

BATCHSIZE_PER_CARD = 32


class TTAFrame():
    """Run a segmentation network with test-time augmentation (TTA).

    Every image is evaluated in eight variants (original and 90-degree
    rotation, each with and without a flip along either image axis). The
    predictions are mapped back to the original orientation and summed, so
    the returned mask is the sum of eight network outputs, not their mean.

    The ``test_one_img_from_path_N`` methods all compute the same sum; they
    only differ in how many images are sent through the network per forward
    pass (``_1`` uses one pass of 8 images, ``_2`` two passes of 4, ``_4`` and
    ``_8`` four passes of 2).
    """

    def __init__(self, net):
        """Instantiate ``net`` and wrap it in ``DataParallel``.

        Args:
            net: Callable without arguments returning a ``torch.nn.Module``.
        """
        # Use the GPU when there is one, otherwise stay on the CPU instead of
        # failing inside ``.cuda()``.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = net().to(self.device)
        self.net = torch.nn.DataParallel(self.net, device_ids=range(torch.cuda.device_count()))

    def test_one_img_from_path(self, path, evalmode=True):
        """Predict the TTA mask for one image.

        Args:
            path: The image itself as an array indexed (row, column, channel);
                despite the name no file is read here. The image must be
                square because it is stacked with its 90-degree rotation.
            evalmode: Switch the network to evaluation mode first.

        Returns:
            2-D NumPy array with the summed prediction of the eight variants.
        """
        if evalmode:
            self.net.eval()
        batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
        if batchsize >= 8:
            return self.test_one_img_from_path_1(path)
        if batchsize >= 4:
            return self.test_one_img_from_path_2(path)
        if batchsize >= 2:
            return self.test_one_img_from_path_4(path)
        # No GPU or a very small per-card budget: use the variant with the
        # smallest batches instead of silently returning None.
        return self.test_one_img_from_path_8(path)

    @staticmethod
    def _stack_with_rot90(image):
        """Return a batch of two images: ``image`` and its 90-degree rotation."""
        img = np.array(image)
        img90 = np.array(np.rot90(img))
        return np.concatenate([img[None], img90[None]])

    def _predict(self, batch):
        """Normalise a (N, H, W, C) batch, run the network, return a NumPy array.

        Pixel values are mapped from [0, 255] to [-1.6, 1.6]. Gradients are
        disabled because this class is only used for inference.
        """
        data = np.array(batch.transpose(0, 3, 1, 2), np.float32) / 255.0 * 3.2 - 1.6
        tensor = torch.Tensor(data).to(self.device)
        with torch.no_grad():
            output = self.net(tensor)
        return output.squeeze().detach().cpu().numpy()

    def test_one_img_from_path_8(self, path):
        """Lowest-memory variant; computes exactly what ``_4`` computes."""
        return self.test_one_img_from_path_4(path)

    def test_one_img_from_path_4(self, path):
        """TTA with four forward passes of two images each."""
        img1 = self._stack_with_rot90(path)
        img2 = img1[:, ::-1]      # flipped along the row axis
        img3 = img1[:, :, ::-1]   # flipped along the column axis
        img4 = img2[:, :, ::-1]   # flipped along both axes

        maska = self._predict(img1)
        maskb = self._predict(img2)
        maskc = self._predict(img3)
        maskd = self._predict(img4)

        # Undo each flip before summing, then undo the rotation of entry 1.
        mask1 = maska + maskb[:, ::-1] + maskc[:, :, ::-1] + maskd[:, ::-1, ::-1]
        mask2 = mask1[0] + np.rot90(mask1[1])[::-1, ::-1]

        return mask2

    def test_one_img_from_path_2(self, path):
        """TTA with two forward passes of four images each."""
        img1 = self._stack_with_rot90(path)
        img2 = img1[:, ::-1]
        img3 = np.concatenate([img1, img2])
        img4 = img3[:, :, ::-1]

        maska = self._predict(img3)
        maskb = self._predict(img4)

        mask1 = maska + maskb[:, :, ::-1]
        mask2 = mask1[:2] + mask1[2:, ::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1, ::-1]

        return mask3

    def test_one_img_from_path_1(self, path):
        """TTA with a single forward pass of all eight images."""
        img1 = self._stack_with_rot90(path)
        img2 = img1[:, ::-1]
        img3 = np.concatenate([img1, img2])
        img4 = img3[:, :, ::-1]
        img5 = np.concatenate([img3, img4])

        mask = self._predict(img5)
        mask1 = mask[:4] + mask[4:, :, ::-1]
        mask2 = mask1[:2] + mask1[2:, ::-1]
        mask3 = mask2[0] + np.rot90(mask2[1])[::-1, ::-1]

        return mask3

    def load(self, path):
        """Load the network weights stored at ``path``.

        ``map_location`` lets a checkpoint saved on a GPU be loaded on a
        machine that has none.
        """
        self.net.load_state_dict(torch.load(path, map_location=self.device))


# ---------------------------------------------------------------------------
# Inference script: binarize every image found in `img_indir` patch by patch
# and write the results as PNG files to the "Binarized" sub-folder.
# ---------------------------------------------------------------------------
TILE_SIZE = 256
DATA_NAME = "DIBCO"  # BickleyDiary, DIBCO, PLM
DEEP_NETWORK_NAME = "DPLinkNet34"  # LinkNet34, DLinkNet34, DPLinkNet34

img_indir = "Document_Binarization_Collection/DP-LinkNet/test_set/"
print("Image input directory:", img_indir)

img_outdir = os.path.join(img_indir, "Binarized")
os.makedirs(img_outdir, exist_ok=True)
print("Image output directory:", img_outdir)

# The output folder lives inside the input folder; it is skipped below
# together with every other directory.
img_list = sorted(os.listdir(img_indir))

# Map the configured name to the network class instead of an if/elif chain.
networks = {
    "DPLinkNet34": DPLinkNet34,
    "DLinkNet34": DLinkNet34,
    "LinkNet34": LinkNet34,
}
if DEEP_NETWORK_NAME not in networks:
    # A configuration error must not terminate with a success exit status.
    raise SystemExit("Deep network not found, please have a check!")
solver = TTAFrame(networks[DEEP_NETWORK_NAME])
# print(solver.net)
# summary(solver.net, input_size=(3, TILE_SIZE, TILE_SIZE))  # summary(your_model, input_size=(channels, H, W))

weights_path = "Document_Binarization_Collection/DP-LinkNet/weights/" + DATA_NAME.lower() + "_" + DEEP_NETWORK_NAME.lower() + ".th"
print("Now loading the model weights:", weights_path)
solver.load(weights_path)

start_time = time()
for img_name in img_list:
    img_input = os.path.join(img_indir, img_name)
    if os.path.isdir(img_input):
        continue

    print("Now processing image:", img_name)
    fname, _ = os.path.splitext(img_name)
    img_output = os.path.join(img_outdir, fname + ".png")

    img = cv2.imread(img_input)
    # cv2.imread returns None for files it cannot decode (e.g. stray non-image
    # files in the folder); skip those instead of crashing in get_patches.
    if img is None:
        print("Skipping, image could not be read:", img_name)
        continue

    locations, patches = get_patches(img, TILE_SIZE, TILE_SIZE)
    masks = [solver.test_one_img_from_path(patch) for patch in patches]
    prediction = stitch_together(locations, masks, tuple(img.shape[0:2]), TILE_SIZE, TILE_SIZE)
    # Each mask is the sum of eight augmented predictions; values of at least
    # 5.0 become white (255) and everything else black (0).
    prediction = np.where(prediction >= 5.0, 255, 0)
    # prediction = np.concatenate([prediction[:, :, None], prediction[:, :, None], prediction[:, :, None]], axis=2)
    cv2.imwrite(img_output, prediction.astype(np.uint8))

print("Total running time: %f sec." % (time() - start_time))
print("Finished!")