In [None]:
# Mount Google Drive; the dataset under /content/drive/MyDrive is read from there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch 
import torch.nn as nn
import torch.nn.functional as F
import torch.optim

In [None]:
# Clone the repository that provides the `configs` (imported as cf) and `backend` modules used below.
!git clone https://github.com/huyhoang17/KIE_invoice_minimal.git

Cloning into 'KIE_invoice_minimal'...
remote: Enumerating objects: 114, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 114 (delta 0), reused 1 (delta 0), pack-reused 111[K
Receiving objects: 100% (114/114), 11.33 MiB | 32.06 MiB/s, done.
Resolving deltas: 100% (25/25), done.


In [None]:
import copy
import imageio

import cv2
import numpy as np
import torch


In [None]:
# Install the cloned repository's Python requirements.
!pip install -r /content/KIE_invoice_minimal/requirements.txt

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# Work from the repo root; the later `configs` / `backend` imports presumably rely on this working directory.
%cd KIE_invoice_minimal

/content/KIE_invoice_minimal


In [None]:
# Upgrade gdown (it is imported in the next cell).
!pip install --upgrade gdown

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting gdown
  Downloading gdown-4.4.0.tar.gz (14 kB)
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
    Preparing wheel metadata ... [?25l[?25hdone
Building wheels for collected packages: gdown
  Building wheel for gdown (PEP 517) ... [?25l[?25hdone
  Created wheel for gdown: filename=gdown-4.4.0-py3-none-any.whl size=14774 sha256=67b5fb06222b073d4936404515c71a31e4d59e116db60677e63035d38f46fedf
  Stored in directory: /root/.cache/pip/wheels/fb/c3/0e/c4d8ff8bfcb0461afff199471449f642179b74968c15b7a69c
Successfully built gdown
Installing collected packages: gdown
  Attempting uninstall: gdown
    Found existing installation: gdown 3.11.0
    Uninstalling gdown-3.11.0:
      Successfully uninstalled gdown-3.11.0
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installe

In [None]:
import gdown
from backend.backend_utils import timer

In [None]:
# Drive folder holding the dataset; training images are read from <data_path>/Images/.
data_path = '/content/drive/MyDrive/transferjson'

In [None]:
class GraphNorm(nn.Module):
    """Normalise features separately within each graph of a batch.

    Rows belonging to one graph are standardised feature-wise (zero mean,
    unit standard deviation along dim 0). All graphs are then passed through
    the same learnable affine transform ``gamma * x + beta``.

    Args:
        num_features: size of the last feature dimension.
        eps: added to the standard deviation to avoid division by zero.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.num_features = num_features
        self.gamma = nn.Parameter(torch.ones(self.num_features))
        self.beta = nn.Parameter(torch.zeros(self.num_features))

    def norm(self, x):
        """Standardise the rows of ``x`` (one graph) feature-wise.

        A segment with fewer than two rows has no defined unbiased standard
        deviation. ``torch.std`` would return NaN there and poison the whole
        batch, so its spread is treated as zero instead, which maps such rows
        to 0.
        """
        mean = x.mean(dim=0, keepdim=True)
        if x.shape[0] > 1:
            std = x.std(dim=0, keepdim=True)
        else:
            std = torch.zeros_like(mean)
        return (x - mean) / (std + self.eps)

    def forward(self, x, graph_size):
        """Normalise each graph's slice of ``x``, then apply gamma/beta.

        Args:
            x: features of all graphs concatenated along dim 0.
            graph_size: per-graph row counts passed to ``torch.split``; they
                must sum to ``x.shape[0]``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        segments = torch.split(x, graph_size)
        normed = torch.cat([self.norm(segment) for segment in segments], 0)
        return self.gamma * normed + self.beta


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LSTM
from torch.nn.utils.rnn import pack_padded_sequence
import configs as cf
import numpy as np



"""
    ResGatedGCN: Residual Gated Graph ConvNets
    An Experimental Study of Neural Networks for Variable Graphs (Xavier Bresson and Thomas Laurent, ICLR 2018)
    https://arxiv.org/pdf/1711.07553v2.pdf
"""


class MLPReadout(nn.Module):
    """Funnel-shaped MLP head in which every hidden layer halves the width.

    With ``L`` hidden layers the widths are
    input_dim -> input_dim // 2 -> ... -> input_dim // 2**L -> output_dim.
    ReLU follows each hidden layer; the output layer has no activation.
    """

    def __init__(self, input_dim, output_dim, L=2):  # L=nb_hidden_layers
        super().__init__()
        widths = [input_dim // 2 ** depth for depth in range(L + 1)]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(fan_in, fan_out, bias=True))
        layers.append(nn.Linear(widths[-1], output_dim, bias=True))
        self.FC_layers = nn.ModuleList(layers)
        self.L = L

    def forward(self, x):
        out = x
        for hidden_layer in self.FC_layers[: self.L]:
            out = F.relu(hidden_layer(out))
        return self.FC_layers[self.L](out)


class GatedGCNLayer(nn.Module):
    """
    One residual gated graph convolution layer.

    Edge state (message_func):   e_ij = C e_ij + D h_src + E h_dst
    Node state (reduce_func):    h_i  = A h_i + sum_j(sigmoid(e_ij) * B h_j)
                                               / (sum_j sigmoid(e_ij) + 1e-6)
    After message passing, forward() applies optional size normalisation,
    optional GraphNorm, ReLU, an optional residual connection and dropout.

    Args:
        input_dim: size of incoming node/edge features.
        output_dim: size of produced node/edge features.
        dropout: dropout probability applied to both h and e.
        graph_norm: if True, multiply h/e by snorm_n/snorm_e.
        batch_norm: if True, apply GraphNorm per graph (not nn.BatchNorm).
        residual: add the inputs to the outputs; forced off when
            input_dim != output_dim.
    """

    def __init__(
        self, input_dim, output_dim, dropout, graph_norm, batch_norm, residual=False
    ):
        super().__init__()
        self.in_channels = input_dim
        self.out_channels = output_dim
        self.dropout = dropout
        self.graph_norm = graph_norm
        self.batch_norm = batch_norm
        self.residual = residual

        # A residual sum needs matching input/output widths.
        if input_dim != output_dim:
            self.residual = False

        # A, B: node update terms; C, D, E: terms of the edge gate e_ij.
        self.A = nn.Linear(input_dim, output_dim, bias=True)
        self.B = nn.Linear(input_dim, output_dim, bias=True)
        self.C = nn.Linear(input_dim, output_dim, bias=True)
        self.D = nn.Linear(input_dim, output_dim, bias=True)
        self.E = nn.Linear(input_dim, output_dim, bias=True)

        # Always created; only used in forward() when batch_norm is True.
        self.bn_node_h = GraphNorm(output_dim)
        self.bn_node_e = GraphNorm(output_dim)

    def message_func(self, edges):
        """DGL message UDF: build the edge state and send it with B h_src.

        Stores e_ij in ``edges.data["e"]``. Returns the mailbox entries
        ``Bh_j`` (source Bh) and ``e_ij`` for the destination node.
        """
        Bh_j = edges.src["Bh"]
        e_ij = (
            edges.data["Ce"] + edges.src["Dh"] + edges.dst["Eh"]
        )  # e_ij = Ce_ij + Dhi + Ehj
        # NOTE(review): forward() later reads g.edata["e"] as the updated edge
        # state. That relies on DGL writing this assignment back to the graph.
        # Confirm for the installed DGL version; g.apply_edges is the explicit
        # alternative.
        edges.data["e"] = e_ij
        return {"Bh_j": Bh_j, "e_ij": e_ij}

    def reduce_func(self, nodes):
        """DGL reduce UDF: gated, normalised sum of incoming B h_j messages.

        The sigmoid gates are divided by their sum over the incoming messages
        (dim 1 of the mailbox tensors), so the aggregation is a weighted mean.
        """
        Ah_i = nodes.data["Ah"]
        Bh_j = nodes.mailbox["Bh_j"]
        e = nodes.mailbox["e_ij"]
        sigma_ij = torch.sigmoid(e)  # sigma_ij = sigmoid(e_ij)
        # h = Ah_i + torch.mean( sigma_ij * Bh_j, dim=1 ) # hi = Ahi + mean_j alpha_ij * Bhj
        h = Ah_i + torch.sum(sigma_ij * Bh_j, dim=1) / (
            torch.sum(sigma_ij, dim=1) + 1e-6
        )  # hi = Ahi + sum_j eta_ij/sum_j' eta_ij' * Bhj <= dense attention
        return {"h": h}

    def forward(self, g, h, e, snorm_n, snorm_e, graph_node_size, graph_edge_size):
        """Run one layer of gated message passing.

        Side effect: writes ndata "h", "Ah", "Bh", "Dh", "Eh" and edata "e",
        "Ce" on ``g``.

        Args:
            g: DGL graph (batched or single).
            h: node features, one row per node.
            e: edge features, one row per edge.
            snorm_n / snorm_e: per-node / per-edge scale factors, used when
                graph_norm is True.
            graph_node_size / graph_edge_size: per-graph node / edge counts
                for GraphNorm.

        Returns:
            (h, e): updated node and edge features.
        """

        h_in = h  # for residual connection
        e_in = e  # for residual connection

        g.ndata["h"] = h
        g.ndata["Ah"] = self.A(h)
        g.ndata["Bh"] = self.B(h)
        g.ndata["Dh"] = self.D(h)
        g.ndata["Eh"] = self.E(h)
        g.edata["e"] = e
        g.edata["Ce"] = self.C(e)
        g.update_all(self.message_func, self.reduce_func)
        h = g.ndata["h"]  # result of graph convolution
        e = g.edata["e"]  # result of graph convolution

        if self.graph_norm:
            h = h * snorm_n  # normalize activation w.r.t. graph size
            e = e * snorm_e  # normalize activation w.r.t. graph size

        if self.batch_norm:
            h = self.bn_node_h(h, graph_node_size)  # graph normalization
            e = self.bn_node_e(e, graph_edge_size)  # graph normalization

        h = F.relu(h)  # non-linear activation
        e = F.relu(e)  # non-linear activation

        if self.residual:
            h = h_in + h  # residual connection
            e = e_in + e  # residual connection

        h = F.dropout(h, self.dropout, training=self.training)
        e = F.dropout(e, self.dropout, training=self.training)

        return h, e

    def __repr__(self):
        return "{}(in_channels={}, out_channels={})".format(
            self.__class__.__name__, self.in_channels, self.out_channels
        )


class DenseLayer(nn.Module):
    """Pre-activation projection block: LayerNorm -> ReLU -> Linear.

    The normalisation attribute keeps the name ``bn`` so saved state_dicts
    stay loadable, even though it is a LayerNorm.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.bn = nn.LayerNorm(in_dim)
        self.linear = nn.Linear(in_dim, out_dim)

    def forward(self, feat):
        normalised = self.bn(feat)
        return self.linear(F.relu(normalised))


class GatedGCNNet(nn.Module):
    """Gated-GCN per-node classifier over a document graph.

    Each node's input is the sum of a linear projection of its geometric
    features and a BiLSTM encoding of its character sequence. ``L`` GatedGCN
    layers refine node/edge states. After each layer, the input embedding and
    all layer outputs so far are concatenated (DenseNet-style) and projected
    back to ``hidden_dim``. An MLP head produces per-node class logits.

    ``net_params`` keys read: in_dim_text, in_dim_node, in_dim_edge,
    hidden_dim, n_classes, dropout, L, OHEM, readout, graph_norm, batch_norm,
    residual, device.
    """

    def __init__(self, net_params):
        super().__init__()
        in_dim_text = net_params["in_dim_text"]  # character vocabulary size
        in_dim_node = net_params["in_dim_node"]  # node feature size (floats)
        in_dim_edge = net_params["in_dim_edge"]  # edge feature size (floats)
        hidden_dim = net_params["hidden_dim"]
        n_classes = net_params["n_classes"]
        dropout = net_params["dropout"]
        n_layers = net_params["L"]
        self.ohem = net_params["OHEM"]  # negatives kept per positive in loss()

        self.readout = net_params["readout"]
        self.graph_norm = net_params["graph_norm"]
        self.batch_norm = net_params["batch_norm"]
        self.residual = net_params["residual"]
        self.n_classes = n_classes
        self.device = net_params["device"]

        self.embedding_text = nn.Embedding(
            in_dim_text, hidden_dim
        )  # character ids -> vectors
        self.embedding_h = nn.Linear(in_dim_node, hidden_dim)  # node features
        self.embedding_e = nn.Linear(in_dim_edge, hidden_dim)  # edge features
        self.layers = nn.ModuleList(
            [
                GatedGCNLayer(
                    hidden_dim,
                    hidden_dim,
                    dropout,
                    self.graph_norm,
                    self.batch_norm,
                    self.residual,
                )
                for _ in range(n_layers)
            ]
        )
        # dense_layers[l] projects the concatenation of the input embedding and
        # the first l + 1 layer outputs, i.e. (l + 2) * hidden_dim features.
        self.dense_layers = nn.ModuleList(
            [
                DenseLayer(hidden_dim + i * hidden_dim, hidden_dim)
                for i in range(1, n_layers + 1)
            ]
        )

        self.lstm = LSTM(
            input_size=hidden_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.MLP_layer = MLPReadout(hidden_dim, n_classes)
        # -100 marks nodes dropped by OHEM in loss().
        self.criterion = nn.CrossEntropyLoss(ignore_index=-100)

    def lstm_text_embeding(self, text, text_length):
        """Encode padded character embeddings into one vector per node.

        Args:
            text: (num_nodes, max_len, hidden_dim) embedded characters,
                batch-first and right-padded.
            text_length: true length of each sequence.

        Returns:
            (num_nodes, hidden_dim): the final LSTM hidden state averaged over
            the two directions.
        """
        # pack_padded_sequence rejects zero-length sequences (a cell with empty
        # text). Clamp to 1 so such nodes are encoded from one padding step
        # instead of crashing.
        lengths = text_length.detach().cpu().clamp(min=1)
        packed_sequence = pack_padded_sequence(
            text, lengths, batch_first=True, enforce_sorted=False
        )
        _, (h_last, _) = self.lstm(packed_sequence)
        return h_last.mean(0)

    def clamp(self):
        """Clip learnable ``lambda_*`` mixing weights to be non-negative.

        Applies to any submodule exposing lambda_batch / lambda_graph /
        lambda_adja / lambda_node tensors, on whatever device they live.
        None of this model's own submodules define them, so this is a no-op
        unless such a module is attached.
        """
        with torch.no_grad():
            for module in self.modules():
                for name in ("lambda_batch", "lambda_graph", "lambda_adja", "lambda_node"):
                    weight = getattr(module, name, None)
                    if isinstance(weight, torch.Tensor):
                        weight.clamp_(min=0.0)

    def concat(self, h_list, l):
        """Concatenate node states along features and project with dense layer ``l``."""
        h_concat = torch.cat(h_list, dim=1)
        return self.dense_layers[l](h_concat)

    def forward(
        self,
        g,
        h,
        e,
        text,
        text_length,
        snorm_n,
        snorm_e,
        graph_node_size,
        graph_edge_size,
    ):
        """Compute per-node class logits.

        Args:
            g: DGL graph; GatedGCNLayer writes intermediate features onto it.
            h: node features, (num_nodes, in_dim_node).
            e: edge features, (num_edges, in_dim_edge).
            text: padded character ids, (num_nodes, max_len).
            text_length: true character count per node.
            snorm_n / snorm_e: per-node / per-edge scale factors.
            graph_node_size / graph_edge_size: per-graph node / edge counts.

        Returns:
            (num_nodes, n_classes) unnormalised logits.
        """
        h_embeding = self.embedding_h(h)
        e_embeding = self.embedding_e(e)

        text_embeding = self.embedding_text(text.long())
        text_embeding = self.lstm_text_embeding(text_embeding, text_length)
        # Unit L2 norm per node before adding to the geometric embedding.
        text_embeding = F.normalize(text_embeding)

        e = e_embeding
        h = h_embeding + text_embeding
        all_h = [h]
        for i, conv in enumerate(self.layers):
            h1, e = conv(g, h, e, snorm_n, snorm_e, graph_node_size, graph_edge_size)
            all_h.append(h1)
            h = self.concat(all_h, i)

        return self.MLP_layer(h)

    def _ohem(self, pred, label):
        """Online hard-example-mining mask over nodes.

        Keeps every positive node (label != 0) plus the ``OHEM * #positives``
        negatives with the highest non-background logit. When there are no
        positives, or no more negatives than that budget, every node is kept.

        Returns:
            CPU bool tensor with one entry per node.
        """
        pred = pred.detach().cpu().numpy()
        label = label.detach().cpu().numpy()

        pos_num = int(np.sum(label != 0))
        if pos_num == 0:
            return torch.ones(label.shape[0], dtype=torch.bool)
        neg_num = int(pos_num * self.ohem)

        # Highest non-background logit: how "positive" each node looks.
        pred_value = pred[:, 1:].max(1)
        neg_score_sorted = np.sort(-pred_value[label == 0])

        if neg_num > 0 and neg_score_sorted.shape[0] > neg_num:
            threshold = -neg_score_sorted[neg_num - 1]
            mask = (pred_value >= threshold) | (label != 0)
        else:
            mask = label != -1
        return torch.from_numpy(mask)

    def loss(self, pred, label):
        """Cross-entropy over positives and OHEM-selected hard negatives.

        Args:
            pred: (num_nodes, n_classes) logits.
            label: (num_nodes,) class ids, 0 = background.
        """
        mask_label = label.clone()
        # Keep the mask on the labels' device rather than assuming CUDA.
        mask = self._ohem(pred, label).to(label.device)
        mask_label[~mask] = -100  # ignored by self.criterion
        return self.criterion(pred, mask_label)


if __name__ == "__main__":

    # Hyper-parameters for the character-BiLSTM + GatedGCN model.
    # NOTE(review): "out_dim" and "in_feat_dropout" are not read by GatedGCNNet.
    net_params = {
        "in_dim_text": len(cf.alphabet),  # one embedding row per alphabet symbol
        "in_dim_node": 10,  # 8 polygon coordinates + box width + box height
        "in_dim_edge": 2,  # (y, x) centre offsets between two cells
        "hidden_dim": 512,
        "out_dim": 384,
        "n_classes": 5,
        "in_feat_dropout": 0.1,
        "dropout": 0.0,
        "L": 4,
        "readout": True,
        "graph_norm": True,
        "batch_norm": True,
        "residual": True,
        "device": "cuda",
        "OHEM": 3,
    }

    model = GatedGCNNet(net_params)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.00075)
    print(model)


GatedGCNNet(
  (embedding_text): Embedding(3791, 512)
  (embedding_h): Linear(in_features=10, out_features=512, bias=True)
  (embedding_e): Linear(in_features=2, out_features=512, bias=True)
  (layers): ModuleList(
    (0): GatedGCNLayer(in_channels=512, out_channels=512)
    (1): GatedGCNLayer(in_channels=512, out_channels=512)
    (2): GatedGCNLayer(in_channels=512, out_channels=512)
    (3): GatedGCNLayer(in_channels=512, out_channels=512)
  )
  (dense_layers): ModuleList(
    (0): DenseLayer(
      (bn): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=1024, out_features=512, bias=True)
    )
    (1): DenseLayer(
      (bn): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=1536, out_features=512, bias=True)
    )
    (2): DenseLayer(
      (bn): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=2048, out_features=512, bias=True)
    )
    (3): DenseLayer(
      

In [None]:
# Put every submodule in training mode before the training loop.
model.train()

GatedGCNNet(
  (embedding_text): Embedding(3791, 512)
  (embedding_h): Linear(in_features=10, out_features=512, bias=True)
  (embedding_e): Linear(in_features=2, out_features=512, bias=True)
  (layers): ModuleList(
    (0): GatedGCNLayer(in_channels=512, out_channels=512)
    (1): GatedGCNLayer(in_channels=512, out_channels=512)
    (2): GatedGCNLayer(in_channels=512, out_channels=512)
    (3): GatedGCNLayer(in_channels=512, out_channels=512)
  )
  (dense_layers): ModuleList(
    (0): DenseLayer(
      (bn): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=1024, out_features=512, bias=True)
    )
    (1): DenseLayer(
      (bn): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=1536, out_features=512, bias=True)
    )
    (2): DenseLayer(
      (bn): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=2048, out_features=512, bias=True)
    )
    (3): DenseLayer(
      

In [None]:
# Move parameters to the GPU; the training and inference cells use device = "cuda".
model = model.to("cuda")

In [None]:
# Show the module tree.
print(model)

GatedGCNNet(
  (embedding_text): Embedding(3791, 512)
  (embedding_h): Linear(in_features=10, out_features=512, bias=True)
  (embedding_e): Linear(in_features=2, out_features=512, bias=True)
  (layers): ModuleList(
    (0): GatedGCNLayer(in_channels=512, out_channels=512)
    (1): GatedGCNLayer(in_channels=512, out_channels=512)
    (2): GatedGCNLayer(in_channels=512, out_channels=512)
    (3): GatedGCNLayer(in_channels=512, out_channels=512)
  )
  (dense_layers): ModuleList(
    (0): DenseLayer(
      (bn): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=1024, out_features=512, bias=True)
    )
    (1): DenseLayer(
      (bn): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=1536, out_features=512, bias=True)
    )
    (2): DenseLayer(
      (bn): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (linear): Linear(in_features=2048, out_features=512, bias=True)
    )
    (3): DenseLayer(
      

In [None]:
def make_text_encode(text):
    """Encode ``text`` (upper-cased) as indices into ``cf.alphabet``.

    Characters missing from the alphabet are encoded as the index of " ".
    """
    alphabet = cf.alphabet
    codes = [alphabet.index(ch if ch in alphabet else " ") for ch in text.upper()]
    return np.array(codes)


def prepare_data(cells, text_key="vietocr_text"):
    """Extract encoded text, text lengths and box geometry from annotated cells.

    Args:
        cells: dicts holding a text string under ``text_key`` and a ``"poly"``
            list of coordinates, presumably x1, y1, ..., x4, y4 (confirm
            against the label files).
        text_key: key of the recognised text in each cell.

    Returns:
        texts: 1-D object array with one encoded character array per cell.
        text_lengths: int array of character counts.
        polys: int array, one row per cell: the polygon followed by the box
            width and height.
    """
    texts = []
    text_lengths = []
    polys = []
    for cell in cells:
        text_encode = make_text_encode(cell[text_key])
        texts.append(text_encode)
        text_lengths.append(text_encode.shape[0])

        poly = list(cell["poly"])
        xs, ys = poly[0::2], poly[1::2]
        poly.append(np.max(xs) - np.min(xs))  # box width
        poly.append(np.max(ys) - np.min(ys))  # box height
        polys.append([int(v) for v in poly])

    # np.array(list_of_arrays, dtype=object) silently builds a 2-D array when
    # every sequence has the same length (e.g. a single cell). Torch cannot
    # convert that downstream, so fill a 1-D object array element by element
    # to always get one encoded array per cell.
    text_array = np.empty(len(texts), dtype=object)
    for index, encoded in enumerate(texts):
        text_array[index] = encoded
    return text_array, np.array(text_lengths), np.array(polys)


def _scale_to_unit_range(values):
    """Min-max scale each column of ``values`` to [-1, 1].

    A column whose min equals its max (e.g. a single row) has no spread. It is
    mapped to 0 instead of the NaN/inf a plain division would produce. An
    empty array is returned unchanged, as float.
    """
    values = np.asarray(values)
    if values.size == 0:
        return values.astype(float)
    col_min = values.min(0)
    span = values.max(0) - col_min
    safe_span = np.where(span == 0, 1, span)
    scaled = np.where(span == 0, 0.5, (values - col_min) / safe_span)
    return (scaled - 0.5) / 0.5


def prepare_pipeline(boxes, edge_data, text, text_length):
    """Normalise node and edge features column-wise to [-1, 1].

    ``text`` and ``text_length`` are passed through unchanged.

    Returns:
        (boxes, edge_data, text, text_length)
    """
    boxes = _scale_to_unit_range(boxes)
    edge_data = _scale_to_unit_range(edge_data)
    return boxes, edge_data, text, text_length



def prepare_graph(cells):
    """Build the DGL graph and the model inputs for one annotated document.

    Nodes are cells. A directed edge i -> j is added for every ordered pair of
    distinct cells whose vertical centre distance is at most 3x the height of
    cell i. Its features are the (y, x) centre offsets of i relative to j.

    Returns:
        (g, boxes, edge_data, snorm_n, snorm_e, texts, text_length,
         origin_boxes, graph_node_size, graph_edge_size), where boxes and
        edge_data are normalised float tensors, texts is the right-padded
        character-id matrix, and origin_boxes are the raw integer box rows.

    Raises:
        ValueError: if ``cells`` is empty.
    """
    texts, text_lengths, boxes = prepare_data(cells)

    origin_boxes = boxes.copy()
    node_nums = text_lengths.shape[0]
    if node_nums == 0:
        raise ValueError("prepare_graph: document has no cells")

    # Cell centres, computed once instead of inside the O(n^2) pair loop.
    y_centres = boxes[:, :8][:, 1::2].mean(axis=1)
    x_centres = boxes[:, :8][:, 0::2].mean(axis=1)

    src = []
    dst = []
    edge_data = []
    for i in range(node_nums):
        height_i = boxes[i, 9]
        for j in range(node_nums):
            if i == j:
                continue
            y_distance = y_centres[i] - y_centres[j]
            if np.abs(y_distance) > 3 * height_i:
                continue
            edge_data.append([y_distance, x_centres[i] - x_centres[j]])
            src.append(i)
            dst.append(j)

    # reshape keeps the (num_edges, 2) layout even when no edge survives.
    edge_data = np.array(edge_data, dtype=float).reshape(-1, 2)
    g = dgl.DGLGraph()
    g.add_nodes(node_nums)
    g.add_edges(src, dst)

    boxes, edge_data, padded_source, text_length = prepare_pipeline(
        boxes, edge_data, texts, text_lengths
    )
    boxes = torch.from_numpy(boxes).float()
    edge_data = torch.from_numpy(edge_data).float()

    # Size-normalisation factors 1/sqrt(count). max(..., 1) avoids a division
    # by zero for an edge-less graph, which then gets an empty (0, 1) tensor.
    num_nodes = g.number_of_nodes()
    snorm_n = torch.full((num_nodes, 1), 1.0 / max(num_nodes, 1)).sqrt()
    num_edges = g.number_of_edges()
    snorm_e = torch.full((num_edges, 1), 1.0 / max(num_edges, 1)).sqrt()

    # Right-pad every character sequence with 0 to the longest length.
    max_length = text_lengths.max()
    padded = np.stack(
        [np.pad(t, (0, max_length - t.shape[0]), "constant") for t in padded_source]
    )
    texts = torch.from_numpy(padded)
    text_length = torch.from_numpy(np.array(text_length))

    graph_node_size = [num_nodes]
    graph_edge_size = [num_edges]

    return (
        g,
        boxes,
        edge_data,
        snorm_n,
        snorm_e,
        texts,
        text_length,
        origin_boxes,
        graph_node_size,
        graph_edge_size,
    )


In [None]:
import PIL

In [None]:
import json

In [None]:
# Training annotations keyed by document id; each document holds a "cells" list.
# The `with` block closes the file handle once the JSON is parsed.
with open('/content/Labels.json') as a:
    L = json.load(a)

In [None]:
import os

In [None]:
# Image file names 1.jpg ... 10.jpg, matching the training document ids.
images_folder = ["{}.jpg".format(n) for n in range(1, 11)]

In [None]:
# Read each training image from <data_path>/Images/ (cv2 gives BGR arrays, or None if unreadable).
imgs = [cv2.imread(data_path + '/Images/' + name) for name in images_folder]


In [None]:
# Stack the loaded images into one array.
# NOTE(review): images of different sizes make this ragged (recent NumPy raises
# ValueError), and unreadable files appear as None — confirm before relying on it.
imgs=np.array(imgs)

In [None]:
# Document ids present in the training label file.
L.keys()

dict_keys(['1', '2', '3', '4', '5', '6', '7', '8', '9', '10'])

In [None]:
# Collect each document's list of annotated cells, echoing the document id.
Labels = []
for doc_id, document in L.items():
    print(doc_id)
    Labels.append(document['cells'])

1
2
3
4
5
6
7
8
9
10


In [None]:
import dgl

DGL backend not selected or invalid.  Assuming PyTorch for now.
Using backend: pytorch


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [None]:
# Number of annotated cells in the first training document.
len(Labels[0])

84

In [None]:
score_ths = 0.98  # kept for reference; the postprocessing cell reads cf.score_ths
device = "cuda"
n_epochs = 19

# Class id of every annotated category; ids must stay within [0, n_classes).
LABEL_TO_ID = {"OTHER": 0, "ADDRESS": 1, "STATUS": 2, "NUMBER": 3, "NAME": 4}


def encode_labels(cells):
    """Return a tensor with the class id of each cell's ``cate_text``.

    Raises KeyError on an unknown category. Silently dropping the cell would
    misalign the labels with the graph's nodes.
    """
    return torch.tensor([LABEL_TO_ID[cell["cate_text"]] for cell in cells])


# Graphs and labels do not depend on the model, so build them once instead of
# on every epoch.
train_batches = []
for cells in Labels:
    (
        graph,
        node_feats,
        edge_feats,
        snorm_n,
        snorm_e,
        text_ids,
        text_lens,
        _,
        node_sizes,
        edge_sizes,
    ) = prepare_graph(cells)
    train_batches.append(
        {
            # Argument order of GatedGCNNet.forward.
            "inputs": (
                graph.to(device),
                node_feats.to(device),
                edge_feats.to(device),
                text_ids.to(device),
                text_lens.to(device),
                snorm_n.to(device),
                snorm_e.to(device),
                node_sizes,
                edge_sizes,
            ),
            "labels": encode_labels(cells).to(device),
        }
    )

for epoch in range(1, n_epochs + 1):
    loss_train = 0.0
    for batch in train_batches:
        batch_scores = model(*batch["inputs"])
        loss_train += model.loss(batch_scores, batch["labels"])
    # Full-batch training: one optimizer step per epoch on the mean document loss.
    loss_train /= len(train_batches)
    print("{}: {}".format(epoch, loss_train))
    optimizer.zero_grad()
    loss_train.backward()
    optimizer.step()
        



1: 1.5019820928573608
2: 1.091110348701477
3: 0.8703385591506958
4: 0.7411569952964783
5: 0.5751540064811707
6: 0.5211777091026306
7: 0.4624095559120178
8: 0.3605923354625702
9: 0.3187809884548187
10: 0.2602454721927643
11: 0.21209333837032318
12: 0.16747455298900604
13: 0.125735342502594
14: 0.06969473510980606
15: 0.03844205662608147
16: 0.02322150580585003
17: 0.012573055922985077
18: 0.005203959997743368
19: 0.0032408263068646193


In [None]:
# Persist only the learned parameters (state_dict); restore with model.load_state_dict(torch.load(path)).
torch.save(model.state_dict(), "/content/kie_mcocr.pkl")

In [None]:
# Held-out annotations, same layout as the training file; `with` closes the file once parsed.
with open('/content/Labels_test.json') as a:
    L = json.load(a)
Labels_test = []
for key in L.keys():
    print(key)
    Labels_test.append(L[key]['cells'])

1
2
3
4
5
6
7
8
9
10


In [None]:
# Predict on one held-out document. Inference needs no autograd graph, so run
# under no_grad(). eval() does not change the outputs here: dropout is 0.0 and
# GraphNorm keeps no running statistics.
model.eval()
(
    batch_graphs,
    batch_x,
    batch_e,
    batch_snorm_n,
    batch_snorm_e,
    text,
    text_length,
    boxes,
    graph_node_size,
    graph_edge_size,
) = prepare_graph(Labels_test[2])

batch_graphs = batch_graphs.to(device)
batch_x = batch_x.to(device)
batch_e = batch_e.to(device)
text = text.to(device)
text_length = text_length.to(device)
batch_snorm_e = batch_snorm_e.to(device)
batch_snorm_n = batch_snorm_n.to(device)

with torch.no_grad():
    batch_scores = model(
        batch_graphs,
        batch_x,
        batch_e,
        text,
        text_length,
        batch_snorm_n,
        batch_snorm_e,
        graph_node_size,
        graph_edge_size,
    )

RuntimeError: ignored

In [None]:
def postprocess_scores(batch_scores, score_ths=0.98, get_max=False):
    """Turn per-node logits into (confidence, predicted class) pairs.

    Args:
        batch_scores: (num_nodes, n_classes) logits.
        score_ths: minimum softmax confidence for a non-background prediction.
        get_max: if True, always return the argmax class, ignoring score_ths.

    Returns:
        values: list with the max softmax probability of each node.
        preds: array of class ids; when get_max is False, low-confidence
            non-background predictions are replaced by 0.
    """
    probabilities = batch_scores.cpu().softmax(1)
    rows = [row.detach().cpu().numpy() for row in probabilities]
    values = [row.max() for row in rows]

    def choose(row):
        best = np.argmax(row)
        if get_max or (best != 0 and row.max() >= score_ths):
            return best
        return 0

    preds = np.array([choose(row) for row in rows])
    return values, preds

In [None]:
# Per-node confidence and predicted class; threshold and argmax mode come from the repo config.
values, preds = postprocess_scores(
        batch_scores, score_ths=cf.score_ths, get_max=cf.get_max
    )

In [None]:
def postprocess_write_info(merged_cells, preds, text_key="vietocr_text", label_names=None):
    """Join the text of all cells predicted as each non-background class.

    Args:
        merged_cells: cell dicts, aligned index-for-index with ``preds``.
        preds: predicted class id per cell; 0 is background and is skipped.
        text_key: key of the recognised text inside each cell dict.
        label_names: class names indexed by class id; defaults to
            ``cf.node_labels``. NOTE(review): the training cell encodes
            OTHER/ADDRESS/STATUS/NUMBER/NAME as ids 0-4, whereas this
            function's earlier comment named ADDRESS/SELLER/TIMESTAMP/TOTAL_COST
            for ids 1-4. Confirm cf.node_labels matches the training mapping,
            or pass the names explicitly.

    Returns:
        dict mapping the title-cased class name to the space-joined texts of
        its cells. Classes with no predicted cell are omitted.
    """
    if label_names is None:
        label_names = cf.node_labels
    preds = np.asarray(preds)
    kie_info = dict()
    for class_id in range(1, len(label_names)):
        indexes = np.flatnonzero(preds == class_id)
        if len(indexes) > 0:
            kie_info[label_names[class_id].title()] = " ".join(
                merged_cells[index][text_key] for index in indexes
            )
    return kie_info

In [None]:
# preds index the cells of the document passed to prepare_graph in the
# inference cell (Labels_test[2]), so join texts from that same document.
kie_info = postprocess_write_info(Labels_test[2], preds)

In [None]:
def vis_kie_pred(img, preds, values, boxes, save_path, label_names=None):
    """Draw predicted non-background cells on a copy of ``img`` and save it.

    Args:
        img: image array (in this notebook it comes from cv2.imread).
        preds: per-cell predicted class ids; 0 (background) is not drawn.
        values: per-cell confidences, printed next to the class name.
        boxes: per-cell rows whose first 8 entries are the polygon corners
            x1, y1, ..., x4, y4.
        save_path: output file; its parent directory is created if missing.
        label_names: class names indexed by class id; defaults to
            ``cf.node_labels``.

    Returns:
        The annotated image.
    """
    import os

    if label_names is None:
        label_names = cf.node_labels
    vis_img = img.copy()
    color = (0, 0, 255)
    for i, pred_id in enumerate(preds):
        if pred_id == 0:
            continue
        info = boxes[i]
        msg = "{}-{}".format(label_names[pred_id], round(float(values[i]), 2))
        # cv2.polylines only accepts int32 points; a default integer array is
        # int64 on most platforms and is rejected.
        box = np.array(info[:8], dtype=np.int32).reshape(4, 2)
        cv2.polylines(vis_img, [box], 1, (255, 0, 0))
        cv2.putText(
            vis_img,
            msg,
            (int(info[0]), int(info[1])),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            1,
            cv2.LINE_AA,
        )

    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): imageio expects RGB, while cv2.imread returns BGR, so the
    # saved colours are presumably channel-swapped — confirm before relying on them.
    imageio.imwrite(save_path, vis_img)
    return vis_img

In [None]:
# Image to annotate. NOTE(review): assumed to be the page for Labels_test[2] used in
# the inference cell; cv2.imread returns None (no error) if the path is missing.
img = cv2.imread('/content/3.jpg')

In [None]:
# Draw the predictions on the image and save them to /content/results/3.jpg.
# `boxes` holds the un-normalised box rows (origin_boxes) returned by prepare_graph.
save_path = os.path.join('/content/results', "{}.jpg".format('3'))
vis_img = vis_kie_pred(img, preds, values, boxes, save_path)