In [None]:
import cv2
import math
import glob
import inspect
import json
import mlflow
import torch
import os
import time
import sys
import datasets

import torchvision as tv
import lightning as L
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import ultralytics as ula

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from pytorch_lightning.loggers import MLFlowLogger
from lightning.pytorch.callbacks import ModelCheckpoint

In [None]:
# Let PyTorch use lower-precision internal formats for float32 matrix
# multiplications ("medium" trades some precision for speed on supporting GPUs).
torch.set_float32_matmul_precision('medium')

In [None]:
# Print small floating-point values in fixed-point form instead of scientific notation.
np.set_printoptions(suppress=True)

In [None]:
class CustomDataset(Dataset):
    """Synthetic foreground/background segmentation dataset.

    Every item is generated on the fly: a random image from the ``fg``
    configuration of ``skytnt/anime-segmentation`` is pasted at a random
    position onto a random image from the ``bg`` configuration, and a random
    square crop of the composite is returned together with its binary mask.

    The integer index only determines the dataset length and is bounds-checked;
    it does not select the images. Train and test instances therefore draw from
    the same image pools and differ only in length.

    Args:
        root: Stored as ``self.root``; not otherwise used by this class.
        train: If True expose the first 95 % of the 50 000 indices, otherwise
            the remaining 5 %.
        patch_size: Side length in pixels of the square crop that is returned.
    """

    # Upper bound on image redraws before __getitem__ gives up with an error
    # instead of looping forever.
    MAX_DRAWS = 1000

    def __init__(self, root, train=True, patch_size=256):
        self.root = root
        self.patch_size = patch_size

        self.dataset_fg = datasets.load_dataset("skytnt/anime-segmentation", 'fg', split="train")
        self.dataset_bg = datasets.load_dataset("skytnt/anime-segmentation", 'bg', split="train")

        # Nominal epoch size; samples are random, so this is only a length.
        self.elements = np.arange(0, 50_000)

        n_train = int(len(self) * 0.95)
        self.elements = self.elements[:n_train] if train else self.elements[n_train:]

    def __len__(self):
        return len(self.elements)

    def _draw_pair(self, patch_size):
        """Draw a random (fg, bg) pair where bg can hold both fg and one patch."""
        for _ in range(self.MAX_DRAWS):
            fg_idx = np.random.randint(len(self.dataset_fg))
            bg_idx = np.random.randint(len(self.dataset_bg))

            fg = np.asarray(self.dataset_fg[fg_idx]["image"])
            bg = np.asarray(self.dataset_bg[bg_idx]["image"])

            h_small, w_small, _ = fg.shape
            h_big, w_big, _ = bg.shape

            # The background must be able to contain the foreground AND a full
            # patch; otherwise the crop offsets below would have an empty range.
            if h_big >= max(h_small, patch_size) and w_big >= max(w_small, patch_size):
                return fg, bg

        raise RuntimeError(
            f"no background image large enough for a {patch_size}px patch "
            f"found after {self.MAX_DRAWS} draws"
        )

    def __getitem__(self, idx):
        """Return ``{"patches": float32 (C, p, p) in [0, 1], "label_patches": int64 (p, p)}``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # Only validates the index (IndexError when out of range); the sample
        # itself is drawn at random.
        _ = self.elements[idx]

        p = self.patch_size
        fg, bg = self._draw_pair(p)
        h_small, w_small, _ = fg.shape
        h_big, w_big, _ = bg.shape

        # randint's upper bound is exclusive, hence the +1 so that the last
        # valid offset (foreground flush with the bottom/right edge) can occur.
        h_off = np.random.randint(0, h_big - h_small + 1)
        w_off = np.random.randint(0, w_big - w_small + 1)

        # Foreground pixels are those whose last channel (presumably alpha) is
        # fully opaque.
        mask = fg[:, :, -1] == 255

        img = bg.copy()
        region = img[h_off:h_off + h_small, w_off:w_off + w_small]  # view into img
        region[mask] = fg[mask, :3]

        label_img = np.zeros(img.shape[:2], dtype=np.int64)
        label_img[h_off:h_off + h_small, w_off:w_off + w_small] = mask

        rx = np.random.randint(0, w_big - p + 1)
        ry = np.random.randint(0, h_big - p + 1)

        # Crop first, then convert: avoids a float copy of the whole image.
        patch = torch.from_numpy(img[ry:ry + p, rx:rx + p].copy())
        patch = patch.to(torch.float32).permute(2, 0, 1) / 255
        label_patches = torch.from_numpy(label_img[ry:ry + p, rx:rx + p].copy())

        sample = dict()
        sample["patches"] = patch
        sample["label_patches"] = label_patches

        return sample

In [None]:
# Run configuration shared by the cells below.
ROOT_PATH = "../data"  # passed to CustomDataset as `root`
BATCH_SIZE = 16  # batch size of both DataLoaders
NUM_WORKERS = 15  # DataLoader worker processes
DEVICE = "cuda"  # device the model, batches and class weights are moved to
LR = 1e-3  # Adam learning rate, read in Model.configure_optimizers
N_CLASSES = 2  # number of segmentation classes
N_EPOCHS = 100  # passed to the Trainer as max_epochs
MODEL_IDENTIFIER = "multi-seg-documents"  # NOTE(review): not referenced in the visible cells

In [None]:
# NOTE(review): not referenced by any visible cell; presumably a placeholder
# for a checkpoint path to resume from — confirm or remove.
checkpoint = None

In [None]:
# Build the two dataset splits (train first, then test) and wrap each in a
# DataLoader that shares the same options.
train_dataset, test_dataset = (CustomDataset(ROOT_PATH, train=is_train) for is_train in (True, False))

_loader_options = dict(batch_size=BATCH_SIZE, pin_memory=True, shuffle=True, num_workers=NUM_WORKERS)
train_loader = DataLoader(train_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)

for _split_label, _split in (("Train:", train_dataset), ("Test:", test_dataset)):
    print(_split_label, len(_split))

In [None]:
# Pull one batch and report shape and value range of its two tensors.
element = next(iter(train_loader))
patches, label_patches = element["patches"], element["label_patches"]

for _tensor_name, _tensor in (("patches", patches), ("label_patches", label_patches)):
    print(_tensor_name, _tensor.shape, _tensor.min(), _tensor.max())

In [None]:
# Estimate inverse-frequency class weights from a sample of training items.
# `weight` is read as a global by Model.training_step.
counts = np.zeros(N_CLASSES)
N = min(1000, len(train_dataset))  # never index past the end of the dataset
for i in range(N):
    element = train_dataset[i]
    # Progress is relative to the N items actually visited, not the dataset size.
    print(f"\r{i + 1}|{N}", end="")
    label_img = element["label_patches"]
    for j in range(N_CLASSES):
        counts[j] += int((label_img == j).sum())
# Classes never observed would give an infinite weight; fall back to the mean count.
counts[counts == 0] = np.mean(counts)
print()
print()
print("counts")
print(counts)
print()
weight = 1 / torch.tensor(counts).to(DEVICE).to(torch.float32)
weight = weight / weight.sum()  # normalise so the weights sum to 1
print("weight")
print(weight)
print(weight.sum())

In [None]:
@torch.no_grad()
def test_model_config():
    """Smoke-test the global ``model`` on one training batch.

    Prints the input shape, the prediction and label shapes, and the
    unweighted loss of that batch. Reads the globals ``train_loader``,
    ``model`` and ``DEVICE``; returns nothing.
    """
    batch = next(iter(train_loader))

    patches = batch["patches"].to(DEVICE)
    label_patches = batch["label_patches"].to(DEVICE)

    print(f"input")
    print(f"  {patches.shape}")

    pred = model(patches)

    print(f"output")
    print(f"  {pred.shape}")
    print(f"  {label_patches.shape}")
    print()

    # Flatten (B, C, H, W) -> (B*H*W, C); take C from the prediction itself so
    # the check also works when the model was built with a different class count.
    n_classes = pred.shape[1]
    pred = pred.permute(0, 2, 3, 1).reshape(-1, n_classes)

    label_patches = label_patches.reshape(-1)

    # model.forward already returns softmax probabilities, so take their log
    # and use NLL; F.cross_entropy would apply softmax a second time.
    loss = F.nll_loss(torch.log(pred.clamp_min(1e-12)), label_patches)
    print(loss)

In [None]:
class DownBaseBlock(nn.Module):
    """Densely connected encoder stage that halves the spatial resolution.

    Three 3x3 conv + GroupNorm layers operate at ``f`` channels; the second
    and third see the concatenation of the block input and all earlier
    feature maps. Their outputs are summed with the input, and a final
    stride-2 4x4 convolution maps the result to ``f_out`` channels.

    Args:
        f: Number of input (and internal) channels.
        f_out: Number of output channels.
    """

    def __init__(self, f, f_out):
        super().__init__()

        # Layer creation order is kept fixed so parameter order and
        # RNG consumption during initialisation stay the same.
        self.conv1 = nn.Conv2d(f, f, kernel_size=3, stride=1, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

        self.conv2 = nn.Conv2d(2 * f, f, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

        self.conv3 = nn.Conv2d(3 * f, f, kernel_size=3, stride=1, padding=1)
        self.norm3 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

        self.conv4 = nn.Conv2d(f, f_out, kernel_size=4, stride=2, padding=1)
        self.norm4 = nn.GroupNorm(num_groups=f_out // 4, num_channels=f_out)

    def forward(self, x0):
        """Map ``(B, f, H, W)`` to ``(B, f_out, H/2, W/2)`` (for even H, W)."""
        feat1 = F.relu(self.norm1(self.conv1(x0)))

        stacked = torch.cat((x0, feat1), dim=1)
        feat2 = F.relu(self.norm2(self.conv2(stacked)))

        stacked = torch.cat((x0, feat1, feat2), dim=1)
        feat3 = self.norm3(self.conv3(stacked))

        # Residual sum of the input and all three feature maps, then ReLU.
        fused = F.relu(x0 + feat1 + feat2 + feat3)

        # Strided convolution performs the downsampling.
        return F.relu(self.norm4(self.conv4(fused)))
class UpBaseBlock(nn.Module):
    """Densely connected decoder stage that doubles the spatial resolution.

    A stride-2 4x4 transposed convolution maps ``f_in`` to ``f`` channels
    while upsampling. Three 3x3 conv + GroupNorm layers follow; the later
    ones see the concatenation of all earlier feature maps. The block output
    is the ReLU of the sum of all four feature maps.

    Args:
        f_in: Number of input channels.
        f: Number of output (and internal) channels.
    """

    def __init__(self, f_in, f):
        super().__init__()

        # Layer creation order is kept fixed so parameter order and
        # RNG consumption during initialisation stay the same.
        self.conv1 = nn.ConvTranspose2d(f_in, f, kernel_size=4, stride=2, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

        self.conv2 = nn.Conv2d(f, f, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

        self.conv3 = nn.Conv2d(2 * f, f, kernel_size=3, stride=1, padding=1)
        self.norm3 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

        self.conv4 = nn.Conv2d(3 * f, f, kernel_size=3, stride=1, padding=1)
        self.norm4 = nn.GroupNorm(num_groups=f // 4, num_channels=f)

    def forward(self, x0):
        """Map ``(B, f_in, H, W)`` to ``(B, f, 2H, 2W)``."""
        # Transposed convolution performs the upsampling.
        up = F.relu(self.norm1(self.conv1(x0)))

        feat2 = F.relu(self.norm2(self.conv2(up)))

        stacked = torch.cat((up, feat2), dim=1)
        feat3 = F.relu(self.norm3(self.conv3(stacked)))

        stacked = torch.cat((up, feat2, feat3), dim=1)
        feat4 = self.norm4(self.conv4(stacked))

        # Residual sum of every feature map produced in this block.
        return F.relu(up + feat2 + feat3 + feat4)
class Model(L.LightningModule):
    """U-Net-style segmentation network with additive skip connections.

    An input convolution is followed by four ``DownBaseBlock`` stages (each
    halving the resolution) and four ``UpBaseBlock`` stages (each doubling
    it); every decoder output is added to the encoder feature map of the same
    resolution. Because of the four stride-2 stages the input height and
    width should be multiples of 16 so that the skip additions line up.

    ``forward`` returns per-class probabilities (softmax over the channel
    axis). Training uses the raw logits internally, because
    ``F.cross_entropy`` applies log-softmax itself.

    Args:
        f: Base channel width; stage ``k`` uses ``f * k`` channels.
        n_classes: Number of output classes (channels of the prediction).
    """

    def __init__(self, f=64, n_classes=12):
        super().__init__()

        self.n_classes = n_classes

        self.in_conv = nn.Conv2d(3, f, 3, 1, padding=1)
        self.norm1 = nn.GroupNorm(f // 4, f)

        self.down_block1 = DownBaseBlock(f * 1, f * 2)
        self.down_block2 = DownBaseBlock(f * 2, f * 3)
        self.down_block3 = DownBaseBlock(f * 3, f * 4)
        self.down_block4 = DownBaseBlock(f * 4, f * 5)

        self.up_block1 = UpBaseBlock(f * 5, f * 4)
        self.up_block2 = UpBaseBlock(f * 4, f * 3)
        self.up_block3 = UpBaseBlock(f * 3, f * 2)
        self.up_block4 = UpBaseBlock(f * 2, f * 1)

        self.out_conv = nn.Conv2d(f * 1, n_classes, 3, 1, padding=1)

        self._init_weights()

    def _init_weights(self):
        """Initialise conv kernels with N(0, 0.02), biases with 0, norm scales with 1.

        Drawing *every* parameter from N(0, 0.02) would also set the GroupNorm
        scale factors close to zero, which suppresses the signal of every
        normalised layer at the start of training.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.GroupNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _logits(self, x):
        """Return unnormalised class scores of shape ``(B, n_classes, H, W)``."""
        x = self.in_conv(x)
        x = self.norm1(x)
        x0 = F.relu(x)

        # NOTE(review): the encoder is fed the pre-ReLU tensor `x` while the
        # skip connection below uses the post-ReLU `x0`. Kept as in the
        # original; confirm whether `x0` was intended here.
        x1 = self.down_block1(x)
        x2 = self.down_block2(x1)
        x3 = self.down_block3(x2)
        x_ = self.down_block4(x3)

        # Decoder: upsample and add the encoder features of matching size.
        x3 = self.up_block1(x_) + x3
        x2 = self.up_block2(x3) + x2
        x1 = self.up_block3(x2) + x1
        x0 = self.up_block4(x1) + x0

        return self.out_conv(x0)

    def forward(self, x):
        """Return per-class probabilities of shape ``(B, n_classes, H, W)``."""
        return F.softmax(self._logits(x), dim=1)

    def configure_optimizers(self):
        """Adam over all parameters with the notebook-level learning rate ``LR``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=LR)
        return optimizer

    def training_step(self, batch, batch_idx):
        """Class-weighted cross-entropy on one batch.

        Expects ``batch["patches"]`` (images) and ``batch["label_patches"]``
        (integer class maps). Reads the notebook-level tensor ``weight`` for
        the per-class weights.
        """
        patches = batch["patches"]
        label_patches = batch["label_patches"]

        # Use logits, not self(patches): cross_entropy expects unnormalised
        # scores and would otherwise apply softmax on top of softmax.
        logits = self._logits(patches)

        # (B, C, H, W) -> (B*H*W, C); C is read from the tensor so it always
        # matches the model's own n_classes.
        n_classes = logits.shape[1]
        logits = logits.permute(0, 2, 3, 1).reshape(-1, n_classes)

        label_patches = label_patches.reshape(-1)

        loss = F.cross_entropy(logits, label_patches, weight=weight.to(logits.device))

        self.log("train_loss", loss, prog_bar=True, on_step=True, on_epoch=True)

        return loss

    def on_train_epoch_end(self):
        # Emit a newline at the end of each training epoch.
        print()

In [None]:
# Instantiate the network with one output channel per class and move it to DEVICE.
model = Model(n_classes=N_CLASSES).to(DEVICE)

In [None]:
# Smoke test: one forward pass on a training batch, printing shapes and loss.
test_model_config()

In [None]:
# Take the logger from the same `lightning` distribution that provides the
# Trainer and LightningModule; the top-level `MLFlowLogger` name is imported
# from the separate `pytorch_lightning` package, and the two should not be mixed.
mlflow_logger = L.pytorch.loggers.MLFlowLogger(
    experiment_name="default",
    tracking_uri="http://localhost:8888"  # expects an MLflow tracking server at this address
)

# Single-GPU trainer; metrics are logged every step to MLflow.
trainer = L.Trainer(devices=1,
                    accelerator="gpu",
                    max_epochs=N_EPOCHS,
                    enable_checkpointing=True,
                    benchmark=True,
                    log_every_n_steps=1,
                    logger=mlflow_logger)

In [None]:
# Train on the training loader only.
# NOTE(review): no validation dataloader is passed and Model defines no
# validation_step, so `test_loader` is unused by the visible cells.
trainer.fit(model=model, train_dataloaders=train_loader)

In [None]:
# Qualitative check: show input, predicted mask and ground-truth mask side by
# side for the first four items of one training batch.
with torch.no_grad():
    model = model.to(DEVICE)
    batch = next(iter(train_loader))

    patches = batch["patches"].to(DEVICE)[:4]
    label_patches = batch["label_patches"].to(DEVICE)[:4]

    pred = model(patches)

# Class ids are scaled to [0, 1] so they render as grey levels.
label_patches = label_patches.cpu().numpy() / (N_CLASSES - 1)
pred = pred.cpu().numpy().argmax(axis=1) / (N_CLASSES - 1)
patches = patches.cpu().numpy().transpose(0, 2, 3, 1)

width = 2


def _with_frame(image, thickness):
    """Return a copy of an (H, W, 3) float image with a red border drawn on it.

    The far corner is derived from the image size (last valid pixel), so the
    frame stays inside the image for any patch size.
    """
    h, w = image.shape[:2]
    return cv2.rectangle(image.copy(), (0, 0), (w - 1, h - 1), (1.0, 0, 0), thickness)


for i in range(len(pred)):
    fig, ax = plt.subplots(figsize=(20, 8))

    patches_ = _with_frame(patches[i], width)
    # Replicate the single-channel masks to three channels for display.
    pred_ = _with_frame(np.repeat(pred[i][:, :, None], 3, -1), width)
    gt_ = _with_frame(np.repeat(label_patches[i][:, :, None], 3, -1), width)

    img_ = np.concatenate([patches_, pred_, gt_], axis=1)

    ax.imshow(img_)
    ax.set_title("input | prediction | ground truth")
    plt.show()
    plt.close(fig)  # release the figure; one is created per sample