In [1]:
import cv2 as cv
import torch.nn as nn
import torch
import torch.nn.functional as F
import os
import pyro
import numpy as np
import pyro.optim as optim
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO
from os import walk
import os
import os.path
import cv2
import glob
import imutils
import uuid
import random

In [5]:
from shutil import move

sampleFolder = "generated_captcha_images"
testFolder = "test_images"
_, _, filenames = next(walk(sampleFolder))
'''
# moving 10% train data as test data
test_files = random.sample(filenames, int(len(filenames)* 0.1))
for f in test_files:
    s_p = os.path.join(sampleFolder, f)
    d_p = os.path.join(testFolder, f)
    
    move(s_p, d_p)
'''
_, _, test_filenames = next(walk(testFolder))

In [3]:
# Promoted by https://medium.com/@ageitgey/how-to-break-a-captcha-system-in-15-minutes-with-machine-learning-dbebb035a710
def split_image(img, gd_truth_label, TRAIN=True):
    # Add some extra padding around the image
    gray = cv2.copyMakeBorder(img, 8, 8, 8, 8, cv2.BORDER_REPLICATE)

    # threshold the image (convert it to pure black and white)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)[1]

    # find the contours (continuous blobs of pixels) the image
    contours = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Hack for compatibility with different OpenCV versions
    contours = contours[0] #if imutils.is_cv2() else contours[1]

    letter_image_regions = []

    # Now we can loop through each of the four contours and extract the letter
    # inside of each one
    for contour in contours:
        # Get the rectangle that contains the contour
        (x, y, w, h) = cv2.boundingRect(contour)

        # Compare the width and height of the contour to detect letters that
        # are conjoined into one chunk
        if w / h > 1.25:
            # This contour is too wide to be a single letter!
            # Split it in half into two letter regions!
            half_width = int(w / 2)
            letter_image_regions.append((x, y, half_width, h))
            letter_image_regions.append((x + half_width, y, half_width, h))
        else:
            # This is a normal letter by itself
            letter_image_regions.append((x, y, w, h))
    
    if ((len(letter_image_regions) != 4) and TRAIN):
        return None, None

    # Sort the detected letter images based on the x coordinate to make sure
    # we are processing them from left-to-right so we match the right image
    # with the right letter
    letter_image_regions = sorted(letter_image_regions, key=lambda x: x[0])
    char_images = []
    labels = []
    # Save out each letter as a single image
    for letter_bounding_box, letter_text in zip(letter_image_regions, gd_truth_label):
        # Grab the coordinates of the letter in the image
        x, y, w, h = letter_bounding_box

        # Extract the letter from the original image with a 2-pixel margin around the edge
        letter_image = gray[y - 2:y + h + 2, x - 2:x + w + 2]
        char_images.append(letter_image)
        labels.append(letter_text)
    return char_images, labels

In [4]:
# Promoted by https://medium.com/@ageitgey/how-to-break-a-captcha-system-in-15-minutes-with-machine-learning-dbebb035a710
def resize_to_fit(image, width, height):
    """
    A helper function to resize an image to fit within a given size
    :param image: image to resize
    :param width: desired width in pixels
    :param height: desired height in pixels
    :return: the resized image
    """

    # grab the dimensions of the image, then initialize
    # the padding values
    (h, w) = image.shape[:2]

    # if the width is greater than the height then resize along
    # the width
    if w > h:
        image = imutils.resize(image, width=width)

    # otherwise, the height is greater than the width so resize
    # along the height
    else:
        image = imutils.resize(image, height=height)

    # determine the padding values for the width and height to
    # obtain the target dimensions
    padW = int((width - image.shape[1]) / 2.0)
    padH = int((height - image.shape[0]) / 2.0)

    # pad the image then apply one more resizing to handle any
    # rounding issues
    image = cv2.copyMakeBorder(image, padH, padH, padW, padW,
        cv2.BORDER_REPLICATE)
    image = cv2.resize(image, (width, height))

    # return the pre-processed image
    return image

In [None]:

vocabulary = {
    "1" : 1,
    "2" : 2,
    "3" : 3,
    "4" : 4,
    "5" : 5,
    "6" : 6,
    "7" : 7,
    "8" : 8,
    "9" : 9,
    "a" : 10,
    "b" : 11,
    "c" : 12,
    "d" : 13,
    "e" : 14,
    "f" : 15,
    "g" : 16,
    "h" : 17,
    "i" : 18,
    "j" : 19,
    "k" : 20,
    "l" : 21,
    "m" : 22,
    "n" : 23,
    "o" : 24,
    "p" : 25,
    "q" : 26,
    "r" : 27,
    "s" : 28,
    "t" : 29,
    "u" : 30,
    "v" : 31,
    "w" : 32, 
    "x" : 33,
    "y" : 34,
    "z" : 35
}
vocabulary_size = len(vocabulary)

MAXCHAR = 10 # assume a captcha has at most 10 chars

# p(N | img)
class NumNet(nn.Module):
    def __init__(self, img_size, out_size = 10):
        super(NumNet, self).__init__()
        self.neural_net = nn.Sequential(
            nn.Linear(img_size[0] * img_size[1], img_size[0] * img_size[1] * 2),
            nn.ReLU(),
            nn.Linear(img_size[0] * img_size[1] * 2, 128),
            nn.ReLU(),
            nn.Linear(128, out_size),
            nn.LogSoftmax())
        
    def forward(self, img):
        img = torch.reshape(img, (-1,))
        prob = self.neural_net(img)
        return prob

# p(c | img)
class CharNet(nn.Module):
    def __init__(self, img_size, out_size=vocabulary_size, noise=None):
        super(CharNet, self).__init__()
        self.neural_net = nn.Sequential(
            nn.Linear(img_size[0] * img_size[1], img_size[0] * img_size[1] * 2),
            nn.ReLU(),
            nn.Linear(img_size[0] * img_size[1] * 2, 128),
            nn.ReLU(),
            nn.Linear(128, out_size),
            nn.Softmax())

    def forward(self, img, noise=None):
        img = torch.reshape(img, (-1,))
        prob = self.neural_net(img)
        return prob

class CaptchaSolver(nn.Module):
    def __init__(self):
        super().__init__()
        self.numNet = NumNet((24, 72), MAXCHAR)
        self.charNet = CharNet((20, 20), vocabulary_size)
        self.epi = 0
        
    def model(self, original_img, imgs, n_char, charList):
        pyro.module("captchasolver", self)
        num_p = torch.tensor(1 / MAXCHAR).repeat(MAXCHAR)
        N = pyro.sample("num_char", dist.Categorical(num_p)).float()
        pyro.sample("num_char_obs", dist.Normal(N, torch.tensor(1e-6)), obs=n_char)
        
        sampled_c_probs = []
        for i in range(1, int(n_char) + 1):
            c_prob = torch.tensor(1 / vocabulary_size).repeat(vocabulary_size)
            c_prob = torch.unsqueeze(c_prob, dim=0)
            sampled_c_probs.append(c_prob)
        c_probs = torch.cat(sampled_c_probs, dim=0)
        chars = pyro.sample("chars", dist.Categorical(c_probs)).float()
        sigmas = torch.tensor([1e-6, 1e-6, 1e-6, 1e-6])
        pyro.sample("chars_obs", dist.Normal(chars, sigmas), obs=charList)
    
    def guide(self, original_img, imgs, n_char, charList):
        pyro.module("captchasolver", self)
        num_p = self.numNet(original_img)
        N = pyro.sample("num_char", dist.Categorical(num_p)).float()
        sampled_c_probs = []
        for i in range(int(n_char)):
            segmented_img = imgs[i]
            c_prob = self.charNet(segmented_img)
            
            c_prob = torch.unsqueeze(c_prob, dim=0)
            sampled_c_probs.append(c_prob)
        c_probs = torch.cat(sampled_c_probs, dim=0)
        chars = pyro.sample("chars", dist.Categorical(c_probs)).float()
        sigmas = torch.tensor([0.05, 0.05, 0.05, 0.05]).float()

captchaSolver = CaptchaSolver()
model = captchaSolver.model
guide = captchaSolver.guide
learning_rate = 2e-4
optimizer = optim.Adam({"lr":learning_rate})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

def optimize():
    loss = 0
    num_steps = 80000
    print("Optimizing...")
    for t in range(num_steps):
        loss += inference(t)
        if (t % 1000 == 0) and (t > 0):
            print("at {} step loss is {}".format(t, loss / t))
            run_test(n_test=250, mute=True)

def inference(t):
    global DATA, RAW_DATA
    batchNum = 64
    loss = 0
    batchFileNames = random.sample(list(DATA), batchNum)
    batchImages = [ DATA[k] for k in batchFileNames ]
    batchOriImages = [ RAW_DATA[k] for k in batchFileNames ]
    for i in range(batchNum):
        label = batchFileNames[i]
        if t == 0:
            print("label is", label)
        label = convert_label_to_indexList(label)
        loss += svi.step(batchOriImages[i], batchImages[i], torch.tensor(len(label)).float(), label) / batchNum
    return loss

def convert_label_to_indexList(label):
    indexList = []
    for char in label:
        indexList.append(vocabulary[char.lower()] - 1)
    return torch.tensor(indexList).float()

# compute the global mean of the images
def calculate_global_mean():
    imgs = []
    for file in filenames:
        sample_img_gray = cv.imread(os.path.join(sampleFolder, file), cv.IMREAD_GRAYSCALE)
        sample_img_gray = np.expand_dims(sample_img_gray, axis=0)
        imgs.append(sample_img_gray)
    imgs = np.concatenate(imgs)
    return np.mean(imgs)

GLOBAL_MEAN = calculate_global_mean()
print("GLOBAL_MEAN is", GLOBAL_MEAN)

def preprocess_image(img, gd_truth_label="", TRAIN=True, saveFolder=None):
    letter_images, letter_labels = split_image(img, gd_truth_label, TRAIN)
    if letter_images is None or letter_labels is None:
        return None, None
    resized_letter_images = []
    for letter_image, letter_label in zip(letter_images, letter_labels):
        letter_image = letter_image.astype(np.uint8) # convert type to uint8
        letter_image = resize_to_fit(letter_image, width=20, height=20)
        if saveFolder: # save the extracted and resized image
            # Get the folder to save the image in
            save_dir = os.path.join(saveFolder, letter_label)
            
            # if the output directory does not exist, create it
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            # write the letter image to a file
            save_path = os.path.join(save_dir, "{}.png".format(uuid.uuid4()))
            cv2.imwrite(save_path, letter_image)
        
        letter_image = torch.from_numpy(letter_image - GLOBAL_MEAN).float() / 255 # normalize
        resized_letter_images.append(letter_image)
            
    return resized_letter_images, letter_labels

import time

DATA = {} # scaled/cropped input
RAW_DATA = {} # original input 
def prepare_data():
    global DATA, RAW_DATA
    for file in filenames:
        sample_img_gray = cv.imread(os.path.join(sampleFolder, file), cv.IMREAD_GRAYSCALE)
        gd_truth_label = file.split(".")[0]
        resized_letter_images, letter_labels = preprocess_image(sample_img_gray, gd_truth_label, TRAIN=True, saveFolder=None)
        if resized_letter_images is None or letter_labels is None:
            continue
        sample_img_gray = torch.from_numpy(sample_img_gray - GLOBAL_MEAN).float() / 255
        DATA[gd_truth_label] = resized_letter_images
        RAW_DATA[gd_truth_label] = sample_img_gray

def check_data():
    global DATA
    for k, v in DATA.items():
        assert len(k) == 4
        assert len(v) == 4
        for sub_img in v:
            assert int(sub_img.shape[0]) == 20
            assert int(sub_img.shape[1]) == 20


def predict_char(img):
    n_sample = 20
    outs = []
    for i in range(n_sample):
        c_prob = captchaSolver.charNet(img)
        char = dist.Categorical(c_prob).sample()
        outs.append(char)
    return round(np.mean(np.array(outs)))

def predict_num(img):
    n_sample = 20
    outs = []
    for i in range(n_sample):
        n_prob = captchaSolver.numNet(img)
        N = int(dist.Categorical(n_prob).sample())
        outs.append(N)
    return round(np.mean(np.array(outs)))            
            
def test(testFileName, mute=False):
    test_img_gray = cv.imread(os.path.join(sampleFolder, testFileName), cv.IMREAD_GRAYSCALE)
    groundtruth = testFileName.split('.')[0]
    
    if not mute:
        print("testing expected label is", groundtruth)
    trueCharList =  convert_label_to_indexList(groundtruth)
    if not mute:
        print("testing expected indices are", trueCharList)
    resized_letter_images, letter_labels = preprocess_image(test_img_gray, groundtruth)
    if resized_letter_images is None or letter_labels is None:
        return 0, 0, 0
    
    test_img_gray = torch.from_numpy(test_img_gray - GLOBAL_MEAN).float() / 255
    N = predict_num(test_img_gray)
    if not mute:
        print("testing N is", N)
    
    corre_num = 1 if N == len(trueCharList) else 0

    chars = []
    for i in range(min(N, len(resized_letter_images))):
        segmented_img = resized_letter_images[i]
        char = predict_char(segmented_img)
        chars.append(char)
        
    if not mute:
        print("testing actual indices are", chars)
    count = 0
    
    for i in range(min(len(chars), len(trueCharList))):
        if int(chars[i]) == int(trueCharList[i]):
            count += 1
    
    correct_word = 1 if count == len(trueCharList) else 0
    return count, corre_num, correct_word

def run_test(n_test=250, mute=False):
    
    NUM_TEST = n_test
    correct_num = 0
    correct_chars = 0
    correct_words = 0
    for i in range(NUM_TEST):
        testFileName = random.sample(test_filenames, 1)[0]
        correct_c, corrent_n, correct_word = test(testFileName, mute)
        correct_chars += correct_c
        correct_num += corrent_n
        correct_words += correct_word

    print("in %d tests, correct chars is" % NUM_TEST, correct_chars)
    print("in %d tests, correct predicted number of chars" % NUM_TEST, correct_num)
    print("in %d tests, correct predicted words" % NUM_TEST, correct_words)
       
prepare_data()
check_data()
print("KVP in data processed:", len(DATA))

optimize()

run_test()


9955 35
torch.Size([1, 1, 24, 72])
torch.FloatTensor
GLOBAL_MEAN is 227.96146786697545
KVP in data processed: 9686
Optimizing...
OrderedDict([('neural_net.0.weight', tensor([[-0.0207,  0.0099, -0.0137,  ...,  0.0236, -0.0178, -0.0071],
        [ 0.0237, -0.0227,  0.0141,  ...,  0.0108, -0.0111,  0.0057],
        [-0.0119, -0.0238,  0.0102,  ..., -0.0032,  0.0136, -0.0074],
        ...,
        [ 0.0217,  0.0011,  0.0093,  ..., -0.0215,  0.0186, -0.0240],
        [-0.0212, -0.0178,  0.0123,  ..., -0.0022, -0.0103,  0.0164],
        [-0.0156, -0.0098, -0.0173,  ...,  0.0122, -0.0222, -0.0190]])), ('neural_net.0.bias', tensor([ 0.0006, -0.0200, -0.0094,  ..., -0.0225,  0.0143, -0.0231])), ('neural_net.2.weight', tensor([[ 0.0053,  0.0062, -0.0010,  ...,  0.0142,  0.0074,  0.0121],
        [-0.0033, -0.0168, -0.0086,  ..., -0.0010,  0.0043, -0.0018],
        [ 0.0053, -0.0167,  0.0029,  ...,  0.0019,  0.0165,  0.0059],
        ...,
        [-0.0010,  0.0061,  0.0008,  ..., -0.0149, -0.0049

  input = module(input)
  input = module(input)
  allow_unreachable=True)  # allow_unreachable flag


label is DUUX
label is E7HY
label is HZD5
label is DVJC
label is MQQC
label is 97GN
label is BCQQ
label is 6Z2B
label is NHLW
label is 2XF7
label is C4HR
label is BL3Y
label is V3TZ
label is YE34
label is QK2V
label is YF79
label is KY8P
label is LGB9
label is ED7B
label is W4F9
label is 3J38
label is B7E5
label is GGSC
label is NV7Z
label is PWLF
label is 6L5Y
label is 9VAE
label is BJXY
label is UZR9
label is QJD7
label is 6P5Y
label is WKYM
label is 5D9J
label is GYVE
label is 7CNG
label is 8VFD
label is CSAX
label is 373H
label is RNW5
label is CKJU
label is NW3S
label is L46V
label is BBT5
label is VSH5
label is W96G
label is M6LP
label is NNMM
label is RENL
label is 6Q27
label is 3GYK
label is 7W48
label is U55R
label is 2TNL
label is CU56
label is 29DD
label is YXW2
label is BCUW
label is FALW
label is E4N7
at 1000 step loss is 104959288102031.56
in 250 tests, correct chars is 175
in 250 tests, correct predicted number of chars 241
in 250 tests, correct predicted words 0
at 2000