In [None]:
import cv2
import glob
import json
import os
import torch
import sys

import torch.nn.functional as F
import torch.nn as nn
import gradio as gr
import numpy as np

import matplotlib.pyplot as plt

In [None]:
# Root data directory: images are globbed from FILE_PATH/images and labels are appended to FILE_PATH/data.json.
FILE_PATH = "../data"

In [None]:
def show(filenames, n=2, m=2, figsize=(20, 5)):
    """Display the first ``n * m`` images of ``filenames`` in an n-by-m grid.

    Images are read with OpenCV (BGR channel order) and flipped to RGB for
    Matplotlib. Grid cells without a corresponding file are left blank instead
    of raising IndexError.

    Args:
        filenames: Sequence of image paths; images fill the grid row by row.
        n: Number of grid rows.
        m: Number of grid columns.
        figsize: Matplotlib figure size in inches.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the displayed files.
    """
    # squeeze=False keeps `axes` 2-D even when n or m is 1.
    _, axes = plt.subplots(n, m, figsize=figsize, squeeze=False)
    for idx, ax in enumerate(axes.flat):  # row-major: idx == i * m + j
        if idx >= len(filenames):
            ax.axis("off")
            continue
        img = cv2.imread(filenames[idx])
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f"Could not read image {filenames[idx]!r}")
        ax.imshow(img[:, :, ::-1])  # BGR -> RGB
        ax.set_title(os.path.basename(filenames[idx]))
    plt.show()

In [None]:
# Pass path components separately: a literal "images\*" only works with Windows
# separators (and "\*" is an invalid string escape), so it matched nothing on Linux/macOS.
filenames = sorted(glob.glob(os.path.join(FILE_PATH, "images", "*")))
assert filenames, f"No images found in {os.path.join(FILE_PATH, 'images')}"

In [None]:
# Show how many images were found plus a small sample rather than the full list.
len(filenames), filenames[:5]

# Rotate landscape images to portrait orientation and save them in place

In [None]:
# Rotate every landscape image (width > height) 90 degrees counter-clockwise and
# overwrite the file, so all images are stored in portrait orientation.
# Re-running the cell is a no-op because rotated images are no longer landscape.
for filename in filenames:
    img = cv2.imread(filename)
    if img is None:
        # cv2.imread signals unreadable / non-image files by returning None.
        print("skip (unreadable)", filename)
        continue
    h, w = img.shape[:2]
    if w > h:
        # Use cv2.rotate on every platform: ndarray.transpose(0, 2, 1) swaps the
        # width and *channel* axes, yielding an (h, 3, w) array rather than a
        # rotated image.
        print("save", filename)
        cv2.imwrite(filename, cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE))

# Model definition for label pre-prediction

In [None]:
class Model(nn.Module):
    """Small CNN that pre-predicts the two labelling answers for an image patch.

    The encoder halves the spatial size with six stride-2 convolutions and two
    2x2 average pools, so a 256x256 input collapses to a 1024x1x1 feature map,
    which is exactly what the first linear layer expects.

    ``forward`` returns ``(visible, text)``: two ``(batch, 3)`` tensors of
    softmax probabilities. Attribute names and layer order are part of the
    checkpoint format (``state_dict`` keys such as ``encoder.0.weight`` and
    ``out.4.bias``), so they must stay stable.
    """

    def __init__(self):
        super().__init__()

        def down(in_ch, out_ch):
            # 4x4 kernel, stride 2, padding 1 halves height and width.
            return [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1), nn.ReLU()]

        self.encoder = nn.Sequential(
            *down(3, 32),
            *down(32, 64),
            nn.AvgPool2d(2, 2),
            *down(64, 128),
            *down(128, 256),
            nn.AvgPool2d(2, 2),
            *down(256, 512),
            *down(512, 1024),
        )

        self.dropout = nn.Dropout(p=0.25)

        # Six logits: the first three answer "visible", the last three "text".
        self.out = nn.Sequential(
            nn.Linear(1024, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 3 + 3),
        )

    def forward(self, x):
        features = self.encoder(x).view(len(x), -1)
        logits = self.out(self.dropout(features))
        visible_probs = F.softmax(logits[:, :3], dim=-1)
        text_probs = F.softmax(logits[:, 3:], dim=-1)
        return visible_probs, text_probs

In [None]:
def append_to_json_file(payload, path=None):
    """Append one record to the JSON label store.

    The store is a JSON object of the form ``{"entries": [...]}`` and is created
    on first use. The updated document is written to a temporary sibling file
    and swapped in with ``os.replace``, so a crash or a non-serialisable payload
    during the dump leaves the previously saved labels intact instead of
    truncating the file.

    Args:
        payload: JSON-serialisable record to append.
        path: Location of the store. Defaults to ``FILE_PATH/data.json``.

    Raises:
        TypeError: if ``payload`` cannot be serialised by ``json`` (the store
            on disk is left unchanged).
    """
    if path is None:
        path = os.path.join(FILE_PATH, "data.json")

    if os.path.isfile(path):
        with open(path, "r") as f:
            data = json.load(f)
    else:
        print(f"Create file {path}")
        data = {"entries": []}

    data["entries"].append(payload)
    # Log only the new record: printing the whole store grows with every call.
    print(f"Appended entry #{len(data['entries'])}: {payload}")

    tmp_path = path + ".tmp"
    try:
        with open(tmp_path, "w") as f:
            json.dump(data, f)
        os.replace(tmp_path, path)
    except BaseException:
        # Drop the partial temp file; the real store was never touched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Load model (currently disabled — uncomment the cell below to enable label pre-prediction)

In [None]:
# HEIGHT = 4032
# WIDTH = 3024
# DEVICE = "cuda"
# checkpoint = torch.load(f"/mnt/data/checkpoints/bill_detection.pth")
# model = Model().to(DEVICE)
# _ = model.load_state_dict(checkpoint["model_state_dict"])

# Gradio application running in a separate browser tab

In [None]:
def create_app():
    """Build the Gradio labelling UI and launch it in a separate browser tab.

    Each round shows a random image from the module-level ``filenames`` list
    (scaled to 25 % with the sampled patch outlined in red) next to the
    full-resolution patch. "Submit" appends the chosen labels, together with
    the file name and patch coordinates of the patch currently on screen, via
    ``append_to_json_file``, then loads a new random patch.

    Raises:
        ValueError: if ``filenames`` is empty.
    """
    if not filenames:
        raise ValueError("No images to label: 'filenames' is empty")

    def random_filename():
        # Uniformly pick one of the images to label.
        return filenames[np.random.randint(0, len(filenames))]

    def get_image_patch(filename, patch_size=(256, 256)):
        """Cut a random ``patch_size`` (height, width) crop out of ``filename``.

        Returns ``(preview, patch, coord, pred_label_visible, pred_label_text, h, w)``:
        ``preview`` is the RGB image with the crop outlined in red and scaled
        by 0.25, ``patch`` the RGB crop at full resolution, ``coord`` the
        crop's ``((x0, y0), (x1, y1))`` corners in full-resolution pixels, the
        two predictions are ``None`` while model pre-labelling is disabled, and
        ``h``/``w`` are the preview's height and width.

        Raises:
            FileNotFoundError: if OpenCV cannot read ``filename``.
            ValueError: if the image is smaller than ``patch_size``.
        """
        img = cv2.imread(filename)
        if img is None:  # cv2.imread returns None instead of raising
            raise FileNotFoundError(f"Could not read image {filename!r}")
        h, w = img.shape[:2]
        ph, pw = patch_size
        if h < ph or w < pw:
            raise ValueError(f"Image {filename!r} ({h}x{w}) is smaller than patch {ph}x{pw}")

        # OpenCV loads BGR; swap to RGB for display.
        img[:, :, [0, 1, 2]] = img[:, :, [2, 1, 0]]

        # randint's upper bound is exclusive: +1 lets the crop touch the
        # bottom/right edge and makes an exactly patch-sized image valid.
        ry = np.random.randint(0, h - ph + 1)
        rx = np.random.randint(0, w - pw + 1)
        patch = img[ry:ry + ph, rx:rx + pw]
        coord = (rx, ry), (rx + pw, ry + ph)

        # Draw on a copy: `patch` is a view into `img` and must stay unmarked.
        preview = cv2.rectangle(img.copy(), (rx, ry), (rx + pw, ry + ph), (255, 0, 0), 10)
        preview = cv2.resize(preview, None, fx=0.25, fy=0.25)
        h, w = preview.shape[:2]

        # Optional model pre-labelling. Requires the commented-out "Load model"
        # cell (DEVICE, model) to be enabled:
        # visible_categories = ["yes", "no", "unclear"]
        # text_categories = ["yes", "no", "unclear"]
        #
        # patch_tensor = patch.copy()
        # patch_tensor = torch.from_numpy(patch_tensor) / 255
        # patch_tensor = patch_tensor.permute(2, 0, 1)
        # patch_tensor = patch_tensor[None, :]
        # patch_tensor = patch_tensor.to(DEVICE)
        #
        # visible, text = model(patch_tensor)
        # visible_idx = visible.detach().cpu().numpy().argmax(axis=1)[0]
        # text_idx = text.detach().cpu().numpy().argmax(axis=1)[0]
        #
        # pred_label_visible = visible_categories[visible_idx]
        # pred_label_text = text_categories[text_idx]

        pred_label_visible = None
        pred_label_text = None

        return preview, patch, coord, pred_label_visible, pred_label_text, h, w

    with gr.Blocks(theme="adam-haile/DSTheme") as demo:

        filename = random_filename()
        img, patch, coord, pred_label_visible, pred_label_text, h, w = get_image_patch(filename)

        with gr.Row():
            big_image = gr.Image(img, height=h, width=w, label="Full image for reference")
            image = gr.Image(patch, height=h, width=w, label="Patch to label")

        label_visible_radio = gr.Radio(value=pred_label_visible,
                                       choices=["yes", "no", "unclear"],
                                       label="Is the bill visible in the image?")

        label_text_radio = gr.Radio(value=pred_label_text,
                                    choices=["yes", "no", "unclear"],
                                    label="Is text visible in the image?")

        # Read-only fields that carry the identity of the patch on screen into submit().
        filename_text = gr.Text(filename,
                                label="filename",
                                interactive=False)

        coord_text = gr.Text(coord,
                             label="coord",
                             interactive=False)

        output_textbox = gr.Textbox(label="Submitted data",
                                    interactive=False)

        button = gr.Button("Submit")

        @button.click(inputs=[label_visible_radio,
                              label_text_radio,
                              filename_text,
                              coord_text],
                      outputs=[output_textbox,
                               label_visible_radio,
                               label_text_radio,
                               filename_text,
                               coord_text,
                               big_image,
                               image])
        def submit(label_visible, label_text, filename, coord):
            """Save the labels for the patch on screen, then show a new random patch."""
            # Persist first so the labels survive even if the next image fails to load.
            entry = {
                "label_visible": label_visible,
                "label_text": label_text,
                "filename": filename,
                "coord": coord,
            }
            append_to_json_file(entry)
            text = f"{label_visible=}, {label_text=}, {filename=}, {coord=}"

            # The next patch must come from new_filename: that is what filename_text
            # will display and what the next submission will record.
            new_filename = random_filename()
            img, patch, new_coord, pred_label_visible, pred_label_text, _, _ = get_image_patch(new_filename)

            return text, pred_label_visible, pred_label_text, new_filename, new_coord, img, patch

    _ = demo.launch(inline=False, inbrowser=True)

In [None]:
# Start the labelling UI; create_app() launches it with inline=False, inbrowser=True (a separate browser tab).
create_app()