In [2]:
import os
from glob import glob
from zipfile import ZipFile
from natsort import natsorted
from Config import MODERN_CLASSES_V2
from huggingface_hub import snapshot_download
from Source.Utils import create_dir
from Source.Trainer import MultiSegmentationTrainer

In [None]:
# Hugging Face dataset repository to fetch.
layout_dataset = "BDRC/LayoutSegmentation_Dataset"

# Download a snapshot of the dataset repo, using the local "Datasets" folder
# as the cache directory; the call returns the path of the snapshot.
dataset_path = snapshot_download(
    repo_id=layout_dataset,
    repo_type="dataset",
    cache_dir="Datasets",
)

# The handle is named `archive` rather than `zip`: a cell-level
# `with ... as zip` would rebind the builtin `zip` for every later cell
# executed in this kernel.
with ZipFile(os.path.join(dataset_path, "data.zip"), "r") as archive:
    archive.extractall(dataset_path)

print(f"downloaded and extracted the dataset to: {dataset_path}")

Fetching 3 files:   0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# NOTE(review): this rebinds `dataset_path`, discarding the value returned by
# snapshot_download in the download cell — confirm that the tiled dataset
# really lives at this hard-coded location.
dataset_path = "Datasets/WesternTiledDataset"

# One sub-directory per split.
train_data, val_data, test_data = (
    os.path.join(dataset_path, split) for split in ("train", "val", "test")
)

# Naturally sorted image / mask PNG paths for each split.
train_x, train_y = (
    natsorted(glob(f"{train_data}/images/*.png")),
    natsorted(glob(f"{train_data}/masks/*.png")),
)
valid_x, valid_y = (
    natsorted(glob(f"{val_data}/images/*.png")),
    natsorted(glob(f"{val_data}/masks/*.png")),
)
test_x, test_y = (
    natsorted(glob(f"{test_data}/images/*.png")),
    natsorted(glob(f"{test_data}/masks/*.png")),
)

# Report how many images and masks were found per split.
print(
    f"Training data => Images: {len(train_x)}, Masks: {len(train_y)}",
    f"Validation data => Images: {len(valid_x)}, Masks: {len(valid_y)}",
    f"Test data => Images: {len(test_x)}, Masks: {len(test_y)}",
    sep="\n",
)

Training data => Images: 16914, Masks: 16914
Validation data => Images: 16126, Masks: 16126
Test data => Images: 16130, Masks: 16130


In [3]:
# patch_size is passed to the trainer as both image width and image height;
# batch_size is passed to the trainer as its batch size.
patch_size, batch_size = 512, 32

In [4]:
# Directory handed to the trainer as its `output_path`.
output_dir = os.path.join(dataset_path, "Output")
create_dir(output_dir)

# Keyword settings for the trainer, gathered in one place.
trainer_options = dict(
    image_width=patch_size,
    image_height=patch_size,
    batch_size=batch_size,
    network="deeplab",
    output_path=output_dir,
    classes=MODERN_CLASSES_V2,
)

# Positional arguments are the image/mask path lists for the train,
# validation and test splits, in that order.
segmentation_trainer = MultiSegmentationTrainer(
    train_x, train_y,
    valid_x, valid_y,
    test_x, test_y,
    **trainer_options,
)

Initializing Mutliclass Segmentation trainer...




In [None]:
# Validate the train loader: pull a single item from the trainer's `train_ds`
# so that problems in the data pipeline show up before training is started.
train_sample = next(iter(segmentation_trainer.train_ds))

In [None]:
# Train the model for a fixed number of epochs.
epochs = 12
segmentation_trainer.train(epochs=epochs)

In [None]:
# Export the trainer's model to ONNX, passing "modernbookformat" as the model name.
segmentation_trainer.export2onnx(segmentation_trainer.model, model_name="modernbookformat")