In [10]:
import sys
import os
import time
import argparse
from matplotlib import pyplot as plt
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

from PIL import Image
from craft import CRAFT
import uuid

import cv2
from skimage import io
import numpy as np
import craft_utils
import imgproc
import file_utils
import json
import zipfile
from collections import OrderedDict
%matplotlib inline

In [2]:
def copyStateDict(state_dict):
    if list(state_dict.keys())[0].startswith("module"):
        start_idx = 1
    else:
        start_idx = 0
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = ".".join(k.split(".")[start_idx:])
        new_state_dict[name] = v
    return new_state_dict

def run_net(net, image, text_threshold, link_threshold, low_text, cuda, poly, refine_net=None):
    t0 = time.time()

    # resize
    img_resized, target_ratio, size_heatmap = imgproc.resize_aspect_ratio(image, 1280, interpolation=cv2.INTER_LINEAR, mag_ratio=1.5)
    ratio_h = ratio_w = 1 / target_ratio

    # preprocessing
    x = imgproc.normalizeMeanVariance(img_resized)
    x = torch.from_numpy(x).permute(2, 0, 1)    # [h, w, c] to [c, h, w]
    x = Variable(x.unsqueeze(0))                # [c, h, w] to [b, c, h, w]
    if cuda:
        x = x.cuda()

    # forward pass
    with torch.no_grad():
        y, feature = net(x)

    # make score and link map
    score_text = y[0,:,:,0].cpu().data.numpy()
    score_link = y[0,:,:,1].cpu().data.numpy()

    # refine link
    if refine_net is not None:
        with torch.no_grad():
            y_refiner = refine_net(y, feature)
        score_link = y_refiner[0,:,:,0].cpu().data.numpy()

    t0 = time.time() - t0
    t1 = time.time()

    # Post-processing
    boxes, polys = craft_utils.getDetBoxes(score_text, score_link, text_threshold, link_threshold, low_text, poly)

    # coordinate adjustment
    boxes = craft_utils.adjustResultCoordinates(boxes, ratio_w, ratio_h)
    polys = craft_utils.adjustResultCoordinates(polys, ratio_w, ratio_h)
    for k in range(len(polys)):
        if polys[k] is None: polys[k] = boxes[k]

    t1 = time.time() - t1

    # render results (optional)
    render_img = score_text.copy()
    render_img = np.hstack((render_img, score_link))
    ret_score_text = imgproc.cvt2HeatmapImg(render_img)

    # if args.show_time : print("\ninfer/postproc time : {:.3f}/{:.3f}".format(t0, t1))

    return boxes, polys, ret_score_text

In [3]:
net = CRAFT()  
net.load_state_dict(copyStateDict(torch.load("./weights/craft_mlt_25k.pth")))
net = net.cuda()
net = torch.nn.DataParallel(net)
net.eval()



DataParallel(
  (module): CRAFT(
    (basenet): vgg16_bn(
      (slice1): Sequential(
        (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
        (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (9): ReLU(inplace=True)
        (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (slice2): Sequential(


In [16]:
for image in os.listdir("/home/hb/pubg_640_dataset/test_names"):
    img = cv2.imread(os.path.join("/home/hb/pubg_640_dataset/test_names", image))
    bboxes, polys, score_text = run_net(net, img, 0.7, 
                                        9999999, 0.5, 
                                        True, False, None)
    print(bboxes)
    bboxes = sorted(bboxes, key=lambda x: x[0][0])
    _a = []
    for box in bboxes:
        x1, y1 = [int(x) for x in box[0]]
        x2, y2 = [int(x) for x in box[2]]
        crop = img[y1:y2, x1:x2]
        cv2.imwrite(f"/home/hb/pubg_640_dataset/chars/{uuid.uuid4().hex}.png", crop)
        crop = cv2.resize(crop, (112,112))
        _a.append(crop)
    col = np.hstack(_a)

[[[ 4.         4.       ]
  [13.333333   4.       ]
  [13.333333  16.       ]
  [ 4.        16.       ]]

 [[28.         5.3333335]
  [36.         5.3333335]
  [36.        16.       ]
  [28.        16.       ]]

 [[36.         5.3333335]
  [40.         5.3333335]
  [40.        14.666667 ]
  [36.        14.666667 ]]

 [[48.         5.3333335]
  [56.         5.3333335]
  [56.        16.       ]
  [48.        16.       ]]

 [[12.         6.6666665]
  [20.         6.6666665]
  [20.        16.       ]
  [12.        16.       ]]

 [[20.         6.6666665]
  [28.         6.6666665]
  [28.        16.       ]
  [20.        16.       ]]

 [[40.         6.6666665]
  [48.         6.6666665]
  [48.        16.       ]
  [40.        16.       ]]

 [[57.333332   6.6666665]
  [64.         6.6666665]
  [64.        16.       ]
  [57.333332  16.       ]]

 [[64.         6.6666665]
  [72.         6.6666665]
  [72.        16.       ]
  [64.        16.       ]]]
[[[48.         4.       ]
  [54.666668   4.   