In [None]:
import os
import sys

# On puhti.csc.fi some Python packages are installed under the user's own
# folder; make sure those folders are on the import path.
# The base folder can be overridden with the USER_SITE_PACKAGES environment
# variable so the notebook is not tied to a single user account.
USER_SITE_PACKAGES = os.environ.get(
    "USER_SITE_PACKAGES", "/users/vesalaia/.local/lib/python3.9/site-packages"
)

for _extra_path in (
    USER_SITE_PACKAGES,
    os.path.join(USER_SITE_PACKAGES, "bin"),
    os.path.join(USER_SITE_PACKAGES, "lib", "python3.9", "site-packages"),
):
    # Append only when missing so re-running this cell does not grow sys.path.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
# detectron2 is used for object detection

import detectron2


In [None]:
# Display the installed detectron2 version to record the environment used.
detectron2.__version__

In [None]:
from detectron2 import model_zoo
from detectron2.engine import DefaultTrainer
from detectron2.engine import DefaultPredictor

from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.data.catalog import DatasetCatalog
from detectron2.utils.visualizer import ColorMode

In [None]:
from detectron2.structures import BoxMode

In [None]:
# some other key libraries

import torch

from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt




### Data setup

In [None]:
from detectron2.data.datasets import register_coco_instances
from pycocotools.coco import COCO

from detectron2.data import DatasetCatalog

# COCO-format annotation file and the folder holding the images it references.
ILLUSTRATIONS_JSON = "/scratch/project_2005488/DHH23/bbox/result.json"
ILLUSTRATIONS_IMAGE_ROOT = "/scratch/project_2005488/DHH23/bbox"

# detectron2 refuses to register the same dataset name twice, so guard the
# call to keep this cell safe to re-run.
if "Illustrations" not in DatasetCatalog.list():
    register_coco_instances("Illustrations", {},
                            ILLUSTRATIONS_JSON,
                            ILLUSTRATIONS_IMAGE_ROOT)

In [None]:
# Extra source trees: the COCO API and torchvision's detection reference scripts.
# NOTE(review): pycocotools is already imported in an earlier cell, so this
# entry cannot affect that import. Putting the package folder itself (rather
# than its parent PythonAPI) on the path only exposes its modules as top-level
# names — confirm this is what is intended.
for _extra_path in ('/users/vesalaia/cocoapi/PythonAPI/pycocotools',
                    '/users/vesalaia/vision/references/detection'):
    # Append only when missing so re-running this cell does not grow sys.path.
    if _extra_path not in sys.path:
        sys.path.append(_extra_path)

In [None]:
# Presumably the fraction of samples meant for training (rest for validation).
# NOTE(review): not referenced by any visible cell — confirm or remove.
TRAIN_RATIO = 0.9

In [None]:
# Conversions between PIL images and cv2/numpy arrays.
# NOTE: no colour conversion is performed (the cvtColor variants are commented
# out), so channel order is passed through unchanged: a BGR array from cv2
# stays BGR in the PIL image and vice versa.

def convert_from_cv2_to_image(img: np.ndarray) -> "Image.Image":
    """Wrap a numpy/cv2 image array as a PIL image without reordering channels.

    Args:
        img: image array, e.g. as returned by ``cv2.imread``.

    Returns:
        A ``PIL.Image.Image`` with the same pixel values and channel order.
    """
    # return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return Image.fromarray(img)


def convert_from_image_to_cv2(img: "Image.Image") -> np.ndarray:
    """Convert a PIL image (or array-like) to a numpy array without reordering channels.

    Args:
        img: a ``PIL.Image.Image`` or anything ``np.asarray`` accepts.

    Returns:
        A numpy array with the same pixel values and channel order.
    """
    # return cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    return np.asarray(img)

In [None]:
# Element classes of the PubLayNet model used as the starting point for
# transfer learning, mapped to their class indices.

Elementclasses = {'Text': 0, "Title": 1, "List": 2, "Table": 3, "Figure": 4}
# Highest class index (currently the index of "Figure").
# NOTE(review): the name suggests an "unknown" class, but the value collides
# with "Figure" — confirm intent.
UNKNOWN = len(Elementclasses) - 1

# Class names in index order, and the index -> name lookup table.
ElementclassLabels = list(Elementclasses)
reverse_Elementclass = dict(zip(Elementclasses.values(), Elementclasses.keys()))


def get_key(l):
    """Return the class name for class index ``l``; raises KeyError if unknown."""
    return reverse_Elementclass[l]

In [None]:
# Number of records in the registered "Illustrations" dataset
# (DatasetCatalog.get loads the annotations to build the list).
len(DatasetCatalog.get("Illustrations"))

In [None]:
# Metadata (e.g. class names) that detectron2 holds for the "Illustrations" dataset.
metadata = MetadataCatalog.get("Illustrations")

In [None]:
# Display the dataset metadata.
metadata

In [None]:
from detectron2.utils.visualizer import ColorMode

import random
import cv2
import matplotlib.pyplot as plt

def plot_samples(dataset_name, n=1):
    """Show ``n`` random records of a registered detectron2 dataset with their annotations.

    For each sampled record, prints its file name and image id, then draws the
    ground-truth annotations on the image at half scale.

    Args:
        dataset_name: name registered in ``DatasetCatalog`` / ``MetadataCatalog``.
        n: number of records to draw (``random.sample`` raises ValueError if
            this exceeds the dataset size).
    """
    records = DatasetCatalog.get(dataset_name)
    meta = MetadataCatalog.get(dataset_name)

    for record in random.sample(list(records), n):
        print(record['file_name'], record['image_id'])
        bgr = cv2.imread(record['file_name'])
        # cv2 loads BGR while the Visualizer expects RGB, hence the channel flip.
        drawing = Visualizer(bgr[:, :, ::-1], metadata=meta, scale=0.5).draw_dataset_dict(record)
        plt.figure(figsize=(15, 20))
        plt.imshow(drawing.get_image())
        plt.show()

### Training

In [None]:
from detectron2.utils.logger import setup_logger
# Enable detectron2's logger so training progress is printed in the notebook.
setup_logger()

In [None]:
# Sanity-check the registered annotations by drawing five random samples.
plot_samples("Illustrations",n=5)


In [None]:

from detectron2.data import detection_utils as utils
import detectron2.data.transforms as T
import copy

def custom_mapper(dataset_dict):
    """Training mapper: load an image, augment it and build detectron2 Instances.

    Resizes to a fixed (h, w) = (800, 600) and applies random brightness,
    contrast, saturation and lighting jitter. The annotations go through the
    same transforms.

    Args:
        dataset_dict: one detectron2 dataset record. It is deep-copied, so the
            caller's record is left untouched.

    Returns:
        The copied record with ``"image"`` (float32 CHW tensor) and
        ``"instances"`` added and ``"annotations"`` removed.
    """
    record = copy.deepcopy(dataset_dict)  # the steps below mutate the record
    augmentations = [
        T.Resize((800, 600)),
        T.RandomBrightness(0.8, 1.8),
        T.RandomContrast(0.6, 1.3),
        T.RandomSaturation(0.8, 1.4),
        T.RandomLighting(0.7),
    ]
    image, transforms = T.apply_transform_gens(
        augmentations, utils.read_image(record["file_name"], format="BGR")
    )
    height_width = image.shape[:2]
    # HWC -> CHW layout for the model input.
    record["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

    kept = []
    for annotation in record.pop("annotations"):
        if annotation.get("iscrowd", 0) != 0:
            continue  # skip crowd annotations
        kept.append(utils.transform_instance_annotations(annotation, transforms, height_width))

    record["instances"] = utils.filter_empty_instances(
        utils.annotations_to_instances(kept, height_width)
    )
    return record

In [None]:
from detectron2.engine import DefaultTrainer
from detectron2.data import build_detection_test_loader, build_detection_train_loader

class CustomTrainer(DefaultTrainer):
    """DefaultTrainer whose training loader maps records through ``custom_mapper``.

    Note: ``custom_mapper`` is looked up by name when the loader is built, so
    whichever definition is current in the kernel at that moment is used.
    """

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader with ``custom_mapper`` as the record mapper."""
        loader = build_detection_train_loader(cfg, mapper=custom_mapper)
        return loader

In [None]:


config_file_path = "/scratch/project_2005488/DHH23/config.yaml"

train_dataset_name = "Illustrations"
num_classes = len(ElementclassLabels)
device = "cuda"
output_dir = "/scratch/project_2005488/DHH23/model"

def get_train_cfg(config_file_path, train_dataset_name, num_classes, device, output_dir):
    """Build the detectron2 config used for fine-tuning.

    Args:
        config_file_path: YAML config file merged on top of detectron2 defaults.
        train_dataset_name: registered dataset name used for training.
        num_classes: number of ROI-head classes.
        device: torch device string for the model, e.g. ``"cuda"``.
        output_dir: directory for checkpoints and logs.

    Returns:
        A detectron2 ``CfgNode`` (not frozen).
    """
    cfg = get_cfg()
    cfg.merge_from_file(config_file_path)
    # Initial weights for transfer learning.
    cfg.MODEL.WEIGHTS = "/scratch/project_2005488/DHH23/model_final.pth"
    cfg.DATASETS.TRAIN = (train_dataset_name,)

    cfg.DATALOADER.NUM_WORKERS = 2
    cfg.SOLVER.IMS_PER_BATCH = 2
    cfg.SOLVER.BASE_LR = 0.001
    cfg.SOLVER.MAX_ITER = 20000
    cfg.SOLVER.STEPS = []  # no learning-rate decay milestones

    # cfg.TEST.EVAL_PERIOD = 100  # disabled: no test dataset / evaluator configured

    # These options live under TEST / INPUT. Assigning them at the top level
    # (cfg.DETECTIONS_PER_IMAGE, cfg.MASK_FORMAT) only adds unused keys.
    cfg.TEST.DETECTIONS_PER_IMAGE = 100
    cfg.INPUT.MASK_FORMAT = "bitmask"

    cfg.MODEL.ROI_HEADS.NUM_CLASSES = num_classes
    cfg.MODEL.DEVICE = device
    cfg.OUTPUT_DIR = output_dir
    return cfg

In [None]:
# Build the training configuration from the settings defined above.
cfg = get_train_cfg(config_file_path, train_dataset_name,  num_classes, device, output_dir)

In [None]:
# Number of classes the ROI heads are configured for.
num_classes

In [None]:
# Print the full merged config for inspection.
print(cfg)

In [None]:
from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_test_loader, build_detection_train_loader
from detectron2.data import detection_utils as utils
from detectron2.evaluation import COCOEvaluator
import detectron2.data.transforms as T
import copy

def aug_mapper(dataset_dict):
    """Alternative training mapper used by ``AugTrainer``.

    Resizes to a fixed (h, w) = (800, 800) and applies mild brightness jitter.
    The annotations go through the same transforms.

    NOTE: deliberately named ``aug_mapper`` rather than ``custom_mapper``.
    Re-defining ``custom_mapper`` here would silently replace the mapper that
    ``CustomTrainer`` looks up by name at training time.

    Args:
        dataset_dict: one detectron2 dataset record (deep-copied, not mutated).

    Returns:
        The copied record with ``"image"`` and ``"instances"`` added and
        ``"annotations"`` removed.
    """
    dataset_dict = copy.deepcopy(dataset_dict)
    image = utils.read_image(dataset_dict["file_name"], format="BGR")
    transform_list = [T.Resize((800, 800)),
                      T.RandomBrightness(0.9, 1.1)]

    image, transforms = T.apply_transform_gens(transform_list, image)
    # HWC -> CHW layout for the model input.
    dataset_dict["image"] = torch.as_tensor(image.transpose(2, 0, 1).astype("float32"))

    annos = [
        utils.transform_instance_annotations(obj, transforms, image.shape[:2])
        for obj in dataset_dict.pop("annotations")
        if obj.get("iscrowd", 0) == 0
    ]
    instances = utils.annotations_to_instances(annos, image.shape[:2])
    dataset_dict["instances"] = utils.filter_empty_instances(instances)
    return dataset_dict


class AugTrainer(DefaultTrainer):
    """DefaultTrainer whose training loader maps records through ``aug_mapper``."""

    @classmethod
    def build_train_loader(cls, cfg):
        """Build the training loader with ``aug_mapper`` as the record mapper."""
        return build_detection_train_loader(cfg, mapper=aug_mapper)



In [None]:
# Make sure the output directory exists before the trainer writes checkpoints
# and metrics there (as done in the detectron2 getting-started tutorial).
os.makedirs(cfg.OUTPUT_DIR, exist_ok=True)
trainer = CustomTrainer(cfg)

In [None]:
# Alternatives to the CustomTrainer created above: DefaultTrainer (default data
# mapper) or AugTrainer (mild augmentation):
#trainer = DefaultTrainer(cfg)
#trainer = AugTrainer(cfg) 
# resume=False: per the detectron2 docs, load cfg.MODEL.WEIGHTS rather than
# resuming from the last checkpoint in cfg.OUTPUT_DIR.
trainer.resume_or_load(resume=False)
trainer.train()

### Testing the model

In [None]:
def on_image(image_path, predictor):
    """Run ``predictor`` on one image file and show all its predictions.

    Predictions are drawn at half scale in ``ColorMode.SEGMENTATION`` with an
    empty metadata dict (no class names are supplied to the visualizer).
    """
    bgr = cv2.imread(image_path)
    predictions = predictor(bgr)['instances'].to("cpu")
    # The Visualizer works on RGB; cv2 loads BGR.
    painter = Visualizer(bgr[:, :, ::-1], metadata={}, scale=0.5,
                         instance_mode=ColorMode.SEGMENTATION)
    rendered = painter.draw_instance_predictions(predictions)
    plt.figure(figsize=(10, 6))
    plt.imshow(rendered.get_image())
    plt.show()

In [None]:
def drawBoxes(image_path, predictor, target_class=4, score_threshold=0.5):
    """Run ``predictor`` on an image and draw the confident boxes of one class.

    Args:
        image_path: path of the image file (read with cv2, i.e. BGR).
        predictor: callable mapping a BGR image to detectron2 outputs with an
            ``"instances"`` field (e.g. ``DefaultPredictor``).
        target_class: class index to draw. The default 4 is "Figure" in
            ``Elementclasses``.
        score_threshold: minimum score for a box to be drawn.

    Raises:
        FileNotFoundError: if cv2 cannot read ``image_path``.
    """
    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread returns None instead of raising on missing/unreadable files.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Move the predictions to the CPU once rather than field by field.
    instances = predictor(im)["instances"].to("cpu")
    v = Visualizer(
        im[:, :, ::-1],  # BGR -> RGB for the visualizer
        metadata={},
        scale=0.5,
    )
    for box, label, score in zip(instances.pred_boxes, instances.pred_classes, instances.scores):
        if score >= score_threshold and label == target_class:
            v.draw_box(box)
            # Label the box with its class name at the top-left corner.
            v.draw_text(get_key(label.item()), tuple(box[:2].numpy()))
    v = v.get_output()
    plt.figure(figsize=(10, 6))
    plt.imshow(v.get_image())
    plt.show()

In [None]:
from detectron2.engine import DefaultPredictor

cfg = get_train_cfg(config_file_path, train_dataset_name,  num_classes, device, output_dir)
# Presumably the final checkpoint written by the training run above.
cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")
# Keep only detections scoring at least 0.8. The key was previously misspelled
# SCORE_TRESH_TEST, which added an unused option and left the threshold unchanged.
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.8

predictor = DefaultPredictor(cfg)

In [None]:

imagedir = "/scratch/project_2005488/DHH23/Test"
IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

# Sorted for a reproducible display order (os.listdir order is arbitrary).
# Extensions are matched case-insensitively so files like "x.JPG" are included.
imagelist = sorted(os.listdir(imagedir))
for imgname in imagelist:
    if imgname.lower().endswith(IMAGE_EXTENSIONS):
        print(imgname)
        image_path = os.path.join(imagedir, imgname)
        drawBoxes(image_path, predictor)
        # on_image(image_path, predictor)  # alternative: draw all predictions