In [7]:
# NOTE(review): this cell was executed as In [7], i.e. after the cells below,
# because it needs `show`, which is only created by the last cell
# (`show = Show()`). Move it to the end of the notebook so that
# Restart Kernel -> Run All works.
show.run()

HBox(children=(Button(description='刷新验证码', style=ButtonStyle()), Button(description='识别验证码', style=ButtonStyle…

HBox(children=(Image(value=b'', format='jpg'), Image(value=b'', format='jpg')))

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import ipywidgets as widgets
from fastai.vision import *
from fastai.metrics import error_rate, accuracy
from IPython.display import display

import cv2
import random
import numpy as np
import matplotlib.pyplot as plt

In [3]:
from utils import get_all_images, get_char_images
from putText import cv2ImgAddText

In [4]:
def inference_all(img, learn):
    """Classify each of the eight picture tiles of a captcha image.

    Args:
        img: Captcha image as loaded by ``cv2.imread`` (BGR channel order).
            It is not modified; the RGB conversion is done on a new array.
        learn: fastai learner used to classify every tile.

    Returns:
        List of the eight predicted categories (``learn.predict(...)[0]`` for
        each tile), in the order the tiles are returned by ``get_all_images``.

    Raises:
        ValueError: if ``get_all_images`` does not yield exactly eight tiles.
    """
    # Convert into a new array instead of in place. The caller's BGR image then
    # stays intact even if tile extraction raises, and the tiles can never
    # alias a buffer that gets swapped back to BGR before prediction.
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    tiles = list(get_all_images(rgb))
    if len(tiles) != 8:
        raise ValueError(f"expected 8 captcha tiles, got {len(tiles)}")

    # Each tile is wrapped in fastai's Image (assumed to be a CHW tensor, as
    # in the original code -- TODO confirm against get_all_images).
    return [learn.predict(Image(tile))[0] for tile in tiles]


In [5]:
def inference_char(img, learn, size=(80, 80)):
    """Recognise the character crop(s) extracted from a captcha image.

    Args:
        img: Captcha image as loaded by ``cv2.imread`` (BGR channel order).
            It is not modified; the RGB conversion is done on a new array.
        learn: fastai learner for the character classifier.
        size: (height, width) every crop is resized to before prediction.
            Defaults to the 80x80 the original code used.

    Returns:
        List of predicted class names as strings, one per crop, in the order
        the crops are returned by ``get_char_images``.
    """
    rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    out = get_char_images(rgb)

    # get_char_images may return one (C, H, W) tensor or several crops. Do not
    # branch on len(out): len() of a single tensor is its channel count, and a
    # list with other than two crops used to crash.
    if torch.is_tensor(out) and out.dim() == 3:
        crops = [out]
    else:
        crops = list(out)  # list of crops, or a stacked (N, C, H, W) tensor

    cls = []
    for crop in crops:
        # Every crop goes to the same model, so resample them all the same way.
        # The original used nearest-neighbour for two crops but bilinear for
        # one. squeeze(0) removes only the batch dim, keeping C even if C == 1.
        resized = torch.nn.functional.interpolate(
            crop.unsqueeze(0), size, mode='bilinear', align_corners=True
        ).squeeze(0)
        cls.append(str(learn.predict(Image(resized))[0]))

    return cls

In [6]:
class Show:
    """Interactive ipywidgets demo of the captcha recogniser.

    Shows the current captcha image next to a copy annotated with the
    predicted tile labels. The "刷新验证码" (refresh captcha) button loads a
    random sample image. The "识别验证码" (recognise captcha) button runs both
    classifiers on the current image and prints the recognised characters.

    Args:
        images_dir: ImageFolder-style directory used to rebuild the tile
            classifier before its weights are loaded.
        chars_dir: Same, for the character classifier.
        samples_dir: Directory of captcha images to browse.
        image_weights: Name passed to ``load`` for the tile classifier.
        char_weights: Name passed to ``load`` for the character classifier.

    Raises:
        FileNotFoundError: if ``samples_dir`` contains no files.
        ValueError: if a sample image cannot be read by OpenCV.
    """

    # Labels are drawn on a 2 x 4 grid. Tile i is written at
    # (_LABEL_X0 + _LABEL_STEP * column, _LABEL_Y0 + _LABEL_STEP * row).
    _LABEL_COLS = 4
    _LABEL_X0 = 5
    _LABEL_Y0 = 41
    _LABEL_STEP = 72
    _LABEL_COLOR = (0, 0, 255)  # textColor passed to cv2ImgAddText
    _MATCH_MARK = '〇'  # appended to tiles whose label is a recognised char

    def __init__(self,
                 images_dir='/data/12306/images/',
                 chars_dir='/data/12306/chars/',
                 samples_dir='/data/12306/pure_img/train_data_2/',
                 image_weights='stage-1',
                 char_weights='char2000-5'):
        # Widgets
        self.refresh = widgets.Button(description="刷新验证码")
        self.answer = widgets.Button(description="识别验证码")
        self.image_box1 = widgets.Image(format='jpg')
        self.image_box2 = widgets.Image(format='jpg')
        self.output = widgets.Output()
        self.btn_box = widgets.HBox([self.refresh, self.answer, self.output])
        self.img_box = widgets.HBox([self.image_box1, self.image_box2])

        # Register the handlers exactly once. Button.on_click appends
        # callbacks, so registering them in run() made every extra run() call
        # (e.g. re-executing the cell) fire inference again per click.
        self.refresh.on_click(self._on_refresh_clicked)
        self.answer.on_click(self._on_answer_clicked)

        # Models
        self.im_learn = self._load_learner(images_dir, models.resnet34,
                                           image_weights)
        self.ch_learn = self._load_learner(chars_dir, models.resnet18,
                                           char_weights)

        # Images
        self.p = Path(samples_dir).ls()
        if not self.p:
            raise FileNotFoundError(f"no captcha images found in {samples_dir}")
        self.img = self._read(self.p[0])

    @staticmethod
    def _load_learner(data_dir, arch, weights):
        """Build a cnn_learner over ``data_dir`` and load saved ``weights``."""
        # The DataBunch is used only to construct the learner (presumably so its
        # head matches the saved weights' classes -- TODO confirm).
        data = ImageDataBunch.from_folder(Path(data_dir),
                                          valid_pct=0.2).normalize(imagenet_stats)
        learn = cnn_learner(data, arch)
        learn.load(weights)
        return learn

    @staticmethod
    def _read(path):
        """Load an image with OpenCV, failing loudly instead of returning None."""
        img = cv2.imread(str(path))
        if img is None:  # cv2.imread returns None for unreadable/non-image files
            raise ValueError(f"could not read image {path}")
        return img

    @staticmethod
    def _to_jpg(img):
        """Encode an OpenCV image as JPEG bytes for a widgets.Image."""
        return cv2.imencode('.jpg', img)[1].tobytes()

    def run(self):
        """Display the widgets and show the current captcha image."""
        display(self.btn_box, self.img_box)
        self.image_box1.value = self._to_jpg(self.img)

    def _on_refresh_clicked(self, _button):
        """Load a random sample captcha and clear the previous results."""
        self.img = self._read(random.choice(self.p))
        self.image_box1.value = self._to_jpg(self.img)
        # The annotated image and printed characters belong to the old captcha.
        self.image_box2.value = b''
        self.output.clear_output()

    def _on_answer_clicked(self, _button):
        """Run both classifiers and show the annotated captcha."""
        labels = [str(o) for o in inference_all(self.img, self.im_learn)]
        cls = inference_char(self.img, self.ch_learn)
        with self.output:
            self.output.clear_output()
            print(cls)

        # Mark matched target: tiles whose label is one of the recognised chars.
        labels = [label + self._MATCH_MARK if label in cls else label
                  for label in labels]

        # Draw on a copy so self.img is never annotated, even if
        # cv2ImgAddText were to draw in place.
        imt = self.img.copy()
        for i, label in enumerate(labels):
            row, col = divmod(i, self._LABEL_COLS)
            imt = cv2ImgAddText(imt, label,
                                self._LABEL_X0 + self._LABEL_STEP * col,
                                self._LABEL_Y0 + self._LABEL_STEP * row,
                                textColor=self._LABEL_COLOR)
        self.image_box2.value = self._to_jpg(imt)
        
# Build the demo: creates the widgets, loads both trained classifiers and lists
# the sample captcha images. Call `show.run()` afterwards to display it.
show = Show()