In [1]:
%matplotlib inline

import argparse
import torch 
import pytorch_lightning as pl
import json
import random
import yaml
import numpy as np
import skimage.io as io
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import os
from ipywidgets import interact
import matplotlib.patches as patches

from munch import Munch
from PIL import Image

from model import PostOCRLearner
from dataset import PostOCRDataLoader
import feature_engineering

from ipywidgets import interact
import matplotlib.patches as patches
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import *
import pandas as pd

import json

## I. Test 이미지 Resize, OCR API requests, Bbox 합치기

#### 01. Test 이미지 Resize

In [2]:
import cv2
import matplotlib.pyplot as plt

# Resize every readable image in the raw test folder to 900x500 and save the
# result under test_image/ using the same file name.
src_dir = '/opt/final-project-level3-cv05/model/image_test'
out_dir = 'test_image'

# cv2.imwrite reports failure through its return value instead of raising, so
# a missing output folder would silently drop every resized image.
os.makedirs(out_dir, exist_ok=True)

files = sorted(os.listdir(src_dir))

for img in files:
    if img != '.DS_Store':
        im = os.path.join(src_dir, img)
        ims = cv2.imread(im)
        # cv2.imread returns None for files it cannot decode; skip those.
        if ims is None:
            continue
        resize_img = cv2.resize(ims, (900, 500))
        cv2.imwrite(f'{out_dir}/{img}', resize_img)

#### 02. OCR API requests & Bbox 합치기

In [3]:
import os 
import json
import requests

from word2line import word2line

def api(img):
    """Upload one image file to the OCR endpoint and return the parsed JSON reply.

    Args:
        img: Path of the image file to send.

    Returns:
        The response body as decoded by ``response.json()``.
    """
    api_url = "http://118.222.179.32:30000/ocr/"
    # NOTE(review): the secret is hard-coded here; consider reading it from an
    # environment variable before sharing this notebook.
    headers = {"secret": "Boostcamp0000"}
    # Open the file in a context manager so the handle is closed once the
    # upload has finished (the original left it open).
    with open(img, "rb") as image_file:
        file_dict = {"file": image_file}
        response = requests.post(api_url, headers=headers, files=file_dict)
    response_json = response.json()
    return response_json
    
# Start from the template in sample.json, then add one "images" entry and one
# "annotations" entry per test image.
with open('sample.json', 'r', encoding="UTF-8") as j:
    json_object = json.load(j)

test_image_dir = '/opt/final-project-level3-cv05/model/test_image'
files = sorted(os.listdir(test_image_dir))

for idx, img in enumerate(files):
    response_json = api(os.path.join(test_image_dir, img))
    # Width/height are fixed at 900x500, the size used by the resize step.
    json_object["images"].append({"width": 900, "height": 500, "file": img, "id": idx})
    # Bbox merge step -> json. word2line presumably merges word-level boxes
    # into line-level boxes (per the section title) -- confirm in word2line.py.
    annotation = {"image_id": idx, "ocr": {"word": response_json['ocr']['word']}}
    line_annotation = word2line(annotation)
    json_object["annotations"].append(line_annotation)

# json.dump writes straight to the file and returns None, so there is nothing
# useful to bind its result to.
with open('info_test.json', 'w', encoding="UTF-8") as j:
    json.dump(json_object, j, indent=2, ensure_ascii=False)

In [4]:
def config():
    """Read config_inference.yaml and return its contents wrapped in a Munch."""
    config_path = '/opt/final-project-level3-cv05/model/config_inference.yaml'
    with open(config_path, 'r') as stream:
        loaded = yaml.safe_load(stream)
    return Munch(loaded)

# Pin every random number generator used here (Python, NumPy, PyTorch CPU and
# CUDA) to the same seed so repeated runs draw the same numbers.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # covers multi-GPU setups
torch.cuda.manual_seed(seed)
# Disable the cuDNN autotuner and request deterministic cuDNN algorithms.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [5]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

------------------

In [6]:
# Build the data module from the inference config, restore the best
# checkpoint and run prediction.
cfg = config()
datamodule = PostOCRDataLoader(cfg, **cfg.Dataset)
# Freshly initialised model; kept under its original name in case other
# cells reference it. The weights actually used come from the checkpoint.
test_model = PostOCRLearner(cfg)

best_model = '/opt/final-project-level3-cv05/model/boostcamp3_cv_final_ocr/pmucsaka/checkpoints/epoch=8_val_accuracy=0.9789.ckpt'
# load_from_checkpoint is a classmethod that builds a new model, so call it on
# the class: going through an instance adds nothing, and newer Lightning
# releases reject the instance call.
model = PostOCRLearner.load_from_checkpoint(best_model, cfg=cfg)

trainer = pl.Trainer(**cfg.trainer)
# One entry per predicted batch; the aggregation cell reads entry['pred'].
pred = trainer.predict(model, dataloaders=datamodule)

Using 16bit native Automatic Mixed Precision (AMP)
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: 0it [00:00, ?it/s]

In [8]:
# Convert the raw per-batch outputs in `pred` into one predicted category id
# per sample.
pred_categoryid = []

model_pred_lst = []
small_pred_lst = []
for i in pred:
    # Class probabilities along dim 1; dtype=torch.float casts the input to
    # float32 before the softmax is computed.
    softmax = torch.nn.functional.softmax(i['pred'], dim=1, dtype=torch.float)
    prd = softmax.cpu().detach().numpy()
    # Keep every row of the batch. Taking only prd[0] silently dropped all but
    # the first sample whenever the batch size was larger than 1; for batch
    # size 1 the result is unchanged.
    small_pred_lst.extend(prd)

# Assuming each i['pred'] is (batch, num_classes), this stacks the rows into
# (num_samples, num_classes, 1); the trailing axis indexes the model.
model_pred_lst.append(np.array(small_pred_lst)[..., np.newaxis])
# Average over the model axis (a single model here), then pick the most
# probable class for each sample.
pred_categoryid.append(np.argmax(np.mean(np.concatenate(model_pred_lst, axis=2), axis=2), axis=1))


In [9]:
# Display the predicted category ids (a list holding one array of class indices).
pred_categoryid

[array([ 0,  5,  8,  5,  0,  6,  0,  6,  3,  0,  8, 10,  5,  5,  7,  5, 10])]

In [10]:

# Category id (as a string key) -> human-readable label.
categories = {
            "0": "UNKNOWN",
            "1": "name",
            "2": "phone",
            "3": "email",
            "4": "position",
            "5": "company",
            "6": "department",
            "7": "address",
            "8": "site",
            "9": "account",
            "10": "wise",
        }

root_path = 'test_image' # images folder 
anno_root = 'info_test.json' # json file folder

# info_test.json is written as UTF-8 with ensure_ascii=False, so read it back
# as UTF-8 explicitly instead of relying on the platform default encoding.
with open(anno_root, 'r', encoding="UTF-8") as f:
    train_json = json.load(f)
    images = train_json['images']
    annotations = train_json['annotations']

# image id -> everything showImg needs to draw that image.
images_viz = dict()
for item in images:
    images_viz[item['id']] = {
        'id': item['id'],
        'file_name': item['file'],
        # NOTE(review): every image receives the complete prediction list, which
        # lines up with the boxes only when there is a single image -- confirm
        # how predictions should be split when several images are present.
        'category_id': list(pred_categoryid[0]),
    }

for anno in annotations:
    images_viz[anno['image_id']]['bbox'] = anno['ocr']['word']

# One colour per category id (11 categories).
palette = sns.color_palette('bright',11)

# (image id, file name) pairs in insertion order; used for the slider bounds.
fnames = [(entry['id'], entry['file_name']) for entry in images_viz.values()]

@interact(idx=(fnames[0][0], fnames[-1][0]), cls_id=range(0, len(categories)+1))
def showImg(idx=0, cls_id=len(categories)):
    """Draw one test image with its boxes and predicted category labels.

    Args:
        idx: Image id (a key of ``images_viz``) to display.
        cls_id: Category id to show. The value ``len(categories)`` (one past
            the last real id) means "show every category".
    """
    fig, ax = plt.subplots(dpi=200)

    # Look the image up by id. Indexing fnames by position only worked while
    # the ids happened to be 0..n-1 in order.
    entry = images_viz[idx]
    file_name = entry['file_name']
    img = io.imread(root_path + '/' + file_name)

    anns = entry['bbox']
    category_ids = entry['category_id']

    ax.imshow(img)

    # Title and axis clean-up do not depend on the annotations, so do them once
    # here; this also applies them when the image has no boxes.
    ax.set_title(f"{file_name}", fontsize=7)

    # Remove the axes.
    ax.set_xticks([])
    ax.set_yticks([])
    for pos in ['right', 'top', 'bottom', 'left']:
        ax.spines[pos].set_visible(False)

    # Derived from the category table instead of a hard-coded 11.
    show_all = cls_id == len(categories)

    for i, ann in enumerate(anns):
        class_idx = category_ids[i]
        if not (show_all or class_idx == cls_id):
            continue

        points = np.array(ann['points'])

        # Visualize the bbox.
        color = palette[class_idx]
        ax.add_patch(
            patches.Polygon(
                points,
                closed=True,
                edgecolor=color,
                fill=False,
                ),
            )

        # Anchor the label at the x of the last point and the y of the first.
        x, y = points[-1][0], points[0][1]

        # Put the label above the box unless that would leave the image.
        text_y = y-5 if y>5 else y+5 
        plt_text = ax.text(x,text_y, f'{class_idx} : {categories[str(class_idx)]}', color='white', fontsize='3', weight='semibold', backgroundcolor=color)
        plt_text.set_bbox(dict(
            facecolor=palette[class_idx],  # background color
            alpha=0.6,  # background alpha
            edgecolor='none',  # border color
            pad=2
        ))

interactive(children=(IntSlider(value=0, description='idx', max=0), Dropdown(description='cls_id', index=11, o…