In [None]:
import cv2
import math
import glob
import inspect
import json
import mlflow
import tiktoken
import torch
import os

import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from dataclasses import dataclass
from torch.utils.data import Dataset, DataLoader, Sampler

In [None]:
# --- Training hyperparameters ---
N_EPOCHS = 100    # iterations of the outer (epoch) loop in train()
LR = 1e-4         # learning rate passed to the optimizer
BATCH_SIZE = 4    # samples per DataLoader batch
BATCH_ACC = 1     # micro-batches processed before each optimizer step

# --- Runtime ---
DEVICE = "cuda"
NUM_WORKERS = 16  # DataLoader worker processes

# --- Paths ---
# Directory that contains data.json (read by CustomDataset).
ROOT_PATH = "/home/henning/repos/ai-playground/projects/label_images"
# Checkpoint file written by torch.save.
FILE_NAME = "/mnt/data/checkpoints/bill_detection.pth"

In [None]:
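# Optional: to resume from a saved checkpoint, uncomment the line below.
# The following `checkpoint = None` cell overwrites it, so skip that cell when resuming.
# checkpoint = torch.load(FILE_NAME)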

In [None]:
# Start from scratch: while this is None, the setup cell skips every load_state_dict call.
checkpoint = None

In [None]:
class CustomDataset(Dataset):
    """Labelled image patches described by ``<root>/data.json``.

    The JSON file must contain an ``"entries"`` list. For every entry the keys
    ``"filename"``, ``"coord"`` (a string of the form ``"(xl, yl, xu, yu)"``),
    ``"label_visible"`` and ``"label_text"`` are read; both labels must be one
    of ``"yes"``, ``"no"`` or ``"unclear"``.

    Args:
        root: Directory that contains ``data.json``.
        train: If True, keep the first 80% of the entries, otherwise the
            remaining ones. The split follows file order (no shuffling).
    """

    def __init__(self, root, train=True):
        self.root = root

        with open(os.path.join(self.root, "data.json"), "r") as f:
            data = json.load(f)

        self.elements = data["entries"]
        # Class-name -> class-index mappings for the two classification targets.
        self.visible_categories = {"yes": 0, "no": 1, "unclear": 2}
        self.text_categories = {"yes": 0, "no": 1, "unclear": 2}

        # Split point is computed on the full list before it is truncated.
        n_train = int(len(self) * 0.8)

        if train:
            self.elements = self.elements[:n_train]
        else:
            self.elements = self.elements[n_train:]

    def __len__(self):
        """Return the number of entries in this split."""
        return len(self.elements)

    @staticmethod
    def _parse_coord(coord):
        """Parse a ``"(xl, yl, xu, yu)"`` string into four ints."""
        cleaned = coord.replace("(", "").replace(")", "")
        xl, yl, xu, yu = (int(value) for value in cleaned.split(","))
        return xl, yl, xu, yu

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            dict with
            ``"img"``: float tensor (channels first, values scaled by 1/255),
            ``"patch"``: the ``[:, yl:yu, xl:xu]`` crop of ``img``,
            ``"label_visible"`` / ``"label_text"``: 0-d numpy arrays holding
            the class index.

        Raises:
            FileNotFoundError: if the image file cannot be read.
            KeyError: if a label is not one of the known categories.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        element = self.elements[idx]

        label_visible = np.array(self.visible_categories[element["label_visible"]])
        label_text = np.array(self.text_categories[element["label_text"]])
        xl, yl, xu, yu = self._parse_coord(element["coord"])

        filename = element["filename"]
        img = cv2.imread(filename)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image file: {filename!r}")

        h, w = img.shape[:2]
        # Reverse the channel axis (OpenCV loads BGR, so this yields RGB);
        # the copy makes the array contiguous for torch.from_numpy.
        img = np.ascontiguousarray(img[:, :, ::-1])
        img = torch.from_numpy(img) / 255
        img = img.permute(2, 0, 1)

        # Landscape images are transposed so that the longer side comes first.
        if w > h:
            img = img.permute(0, 2, 1)

        # The crop is taken after the optional transpose.
        patch = img[:, yl:yu, xl:xu]

        return {
            "img": img,
            "patch": patch,
            "label_visible": label_visible,
            "label_text": label_text,
        }

In [None]:
# Build the train/test splits and wrap each one in a shuffling DataLoader.
loader_kwargs = dict(batch_size=BATCH_SIZE, pin_memory=True, shuffle=True, num_workers=NUM_WORKERS)

train_dataset = CustomDataset(ROOT_PATH, train=True)
test_dataset = CustomDataset(ROOT_PATH, train=False)

train_loader = DataLoader(train_dataset, **loader_kwargs)
test_loader = DataLoader(test_dataset, **loader_kwargs)

# Report the number of samples per split.
for split_name, split_dataset in (("Train:", train_dataset), ("Test:", test_dataset)):
    print(split_name, len(split_dataset))

In [None]:
# Sanity check: shapes of a single raw sample, then of one collated batch.
element = train_dataset[0]
for key in ("img", "patch", "label_visible", "label_text"):
    print(element[key].shape)
print()
batch = next(iter(train_loader))
for key in ("img", "patch", "label_visible", "label_text"):
    print(batch[key].shape)

In [None]:
class Model(nn.Module):
    """Convolutional encoder followed by two 3-way classification heads.

    The encoder halves the spatial size in each of its six stride-2
    convolutions and two 2x2 average-pooling layers (factor 256 overall) and
    ends with 1024 channels. Both heads start with ``nn.Linear(1024, 128)``,
    so the flattened encoder output must have exactly 1024 features, i.e. a
    1x1 spatial map (for example a 256x256 input).

    ``forward`` returns raw, unnormalised logits. ``F.cross_entropy`` applies
    log-softmax itself, so the heads must not end in a Softmax layer; apply
    ``softmax(dim=-1)`` to the outputs when probabilities are needed.
    Softmax has no parameters, so the state-dict keys are unaffected by its
    absence.
    """

    def __init__(self):
        super().__init__()

        self.encoder = nn.Sequential(nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.AvgPool2d(2, 2),
                                     nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.AvgPool2d(2, 2),
                                     nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU(),
                                     nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1),
                                     nn.ReLU())

        # Heads output logits; the loss function handles the normalisation.
        self.latent_visible = nn.Sequential(nn.Linear(1024, 128),
                                            nn.ReLU(),
                                            nn.Linear(128, 3))

        self.latent_text = nn.Sequential(nn.Linear(1024, 128),
                                         nn.ReLU(),
                                         nn.Linear(128, 3))

    def forward(self, x):
        """Encode ``x`` and classify it with both heads.

        Args:
            x: Image batch with 3 channels in dimension 1.

        Returns:
            Tuple ``(visible, text)`` of logit tensors, each of shape
            ``(batch, 3)``.
        """
        x = self.encoder(x)
        # Collapse channel and spatial dimensions into one feature vector per sample.
        latent = x.flatten(start_dim=1)
        visible = self.latent_visible(latent)
        text = self.latent_text(latent)

        return visible, text

In [None]:
# Create the model, the AMP gradient scaler and the optimizer.
model = Model().to(DEVICE)
scaler = torch.cuda.amp.GradScaler(enabled=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

# Restore all three from the checkpoint when one was loaded.
if checkpoint is not None:
    model.load_state_dict(checkpoint["model_state_dict"])
    scaler.load_state_dict(checkpoint["scaler_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

# Parameter count, printed with "." as the thousands separator.
n_model_parameters = sum(p.numel() for p in model.parameters())
print(f"Model: {n_model_parameters:,}".replace(",", "."))
print()

# Smoke test: push one training batch through the model and the loss, without gradients.
with torch.no_grad():
    batch = next(iter(train_loader))
    img = batch["img"].to(DEVICE)
    patch = batch["patch"].to(DEVICE)
    label_visible = batch["label_visible"].to(DEVICE)
    label_text = batch["label_text"].to(DEVICE)
    print(img.shape, patch.shape, label_visible.shape, label_text.shape, sep="\n")

    visible, text = model(patch)
    print(visible.shape, text.shape, sep="\n")

    loss_visible = F.cross_entropy(visible, label_visible)
    loss_text = F.cross_entropy(text, label_text)
    loss = loss_visible + loss_text
    print(loss_visible, loss_text, loss, sep="\n")
    print()

# Point MLflow at the tracking server and select the experiment.
mlflow.set_tracking_uri(uri="http://localhost:8080")
_ = mlflow.set_experiment("Bill images.")

In [None]:
def generate_examples(loader, num_new_tokens=16, verbose=False):
    """Placeholder example generator.

    All arguments are currently ignored; a new empty dict is returned on
    every call.
    """
    return {}
def train():
    """Train the module-level ``model`` for ``N_EPOCHS`` passes over ``train_loader``.

    Relies on the module-level ``model``, ``optimizer``, ``scaler``,
    ``train_loader`` and the configuration constants. For every optimizer
    step, gradients of ``BATCH_ACC`` micro-batches are accumulated under fp16
    autocast, unscaled, clipped to a norm of 0.1 and applied through the
    GradScaler. Losses are logged to MLflow against a monotonically increasing
    micro-batch counter. Every 150 optimizer steps within an epoch a
    checkpoint is written to ``FILE_NAME`` and ``test`` is run.
    """
    # len(train_loader) already counts batches, so this is the number of
    # optimizer steps one pass over the loader provides.
    N = len(train_loader) // BATCH_ACC
    # autocast expects the bare device type ("cuda"), not e.g. "cuda:0".
    device_type = torch.device(DEVICE).type

    sum_loss = 0
    count = 0
    global_step = 0  # micro-batches seen so far; used as the MLflow step
    for n_epoch in range(N_EPOCHS):
        # A fresh iterator per epoch: a single iterator shared by all epochs
        # raises StopIteration once the loader is exhausted.
        iterator = iter(train_loader)
        for n in range(N):
            for b in range(BATCH_ACC):
                batch = next(iterator)

                patch = batch["patch"].to(DEVICE)
                label_visible = batch["label_visible"].to(DEVICE)
                label_text = batch["label_text"].to(DEVICE)

                with torch.autocast(device_type=device_type, dtype=torch.float16, enabled=True):
                    visible, text = model(patch)

                    loss_visible = F.cross_entropy(visible, label_visible)
                    loss_text = F.cross_entropy(text, label_text)
                    loss = loss_visible + loss_text

                # Average over the accumulated micro-batches so the gradient
                # scale does not depend on BATCH_ACC.
                scaler.scale(loss / BATCH_ACC).backward()

                loss_value = loss.item()
                sum_loss += loss_value
                count += 1
                global_step += 1

                mlflow.log_metric("train_loss", loss_value, step=global_step, synchronous=False)
                mlflow.log_metric("mean_train_loss", sum_loss / count, step=global_step, synchronous=False)

                print(f"\r{n_epoch + 1:03d}|{N_EPOCHS:03d}, {n + 1:04d}|{N:04d}, {b + 1:03d}|{BATCH_ACC:03d}, loss: {sum_loss / count:.05f}", end="")

            # Gradients must be unscaled before clipping so the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
            scaler.step(optimizer)
            scaler.update()

            optimizer.zero_grad(set_to_none=True)

            # Restart the running mean every 25 optimizer steps.
            if (n + 1) % 25 == 0:
                print()
                sum_loss = 0
                count = 0

            # Periodic checkpoint plus evaluation.
            if (n + 1) % 150 == 0:
                print("\nSave...")
                torch.save({"model_state_dict": model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                            "scaler_state_dict": scaler.state_dict()}, FILE_NAME)
                print("...done!\n")
                data = generate_examples(train_loader, num_new_tokens=16, verbose=False)
                mlflow.log_dict(data, f"example_{n_epoch + 1}_{n + 1}.json")
                test(n_epoch + 1, global_step)
@torch.no_grad()
def test(epoch=0, step=0):
    """Evaluate the module-level ``model`` on ``test_loader`` without gradients.

    Prints a running mean of the summed visible/text cross-entropy loss, logs
    the (currently empty) examples dict and, when at least one batch was
    evaluated, the mean loss as ``test_loss`` to MLflow.

    Args:
        epoch: Epoch number used in the logged artifact name.
        step: MLflow step under which ``test_loss`` is recorded.
    """
    sum_loss = 0
    count = 0
    n_batches = len(test_loader)
    for batch in test_loader:
        patch = batch["patch"].to(DEVICE)
        label_visible = batch["label_visible"].to(DEVICE)
        label_text = batch["label_text"].to(DEVICE)

        visible, text = model(patch)

        loss_visible = F.cross_entropy(visible, label_visible)
        loss_text = F.cross_entropy(text, label_text)
        loss = loss_visible + loss_text

        sum_loss += loss.item()
        count += 1

        print(f"\r{count:04d}|{n_batches:04d}, loss: {sum_loss / count:.05f}", end="")

    data = generate_examples(test_loader, num_new_tokens=16, verbose=False)
    # count equals the number of evaluated batches (0 for an empty loader).
    mlflow.log_dict(data, f"example_{epoch}_{count}.json")
    if count > 0:
        # Skip the metric for an empty loader; there is no mean to report.
        mlflow.log_metric("test_loss", sum_loss / count, step=step, synchronous=False)
    print()
    print()

In [None]:
# run_id is forwarded to mlflow.start_run; None presumably starts a new run,
# while an existing run id would continue logging into that run (confirm in the MLflow docs).
run_id = None
with mlflow.start_run(run_id=run_id):
    # All metrics and artifacts logged inside train() go to the active run.
    train()

In [None]:
# Final checkpoint, using the same keys the setup cell restores from.
# Model has no `.model` attribute, so the state dict is taken from `model` itself.
torch.save({"model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict()}, FILE_NAME)