In [1]:
import torch
import torch.nn.functional as F
from torchvision import transforms

import numpy as np

from utils import train_valid_test_split, get_num_curr_models, show_error_images, GreedyCTCDecoder
from crnn_model import CRNN
from captcha_dataset import CaptchaDataSet



In [2]:
import os
import gc
import numpy as np

import gdown

In [3]:
to_download_zip = True

if to_download_zip is True:
  url = 'https://drive.google.com/uc?id=1nalIGeKAJk9OaFrmLALEJC56lAxyE7K6'
  output = './DataSet 109k.zip'
  gdown.download(url, output, quiet=False)

unzip_data_script = """
if [ ! -d "./DataSet" ];
then
    jar xvf './DataSet 109k.zip';
    rm './DataSet/.png' #remove one pic without a name
    
fi 
"""
with open('unzip_data_script.sh', 'w') as file:
    file.write(unzip_data_script)

!bash './unzip_data_script.sh'

Downloading...
From: https://drive.google.com/uc?id=1nalIGeKAJk9OaFrmLALEJC56lAxyE7K6
To: /content/DataSet 109k.zip
100%|██████████| 603M/603M [00:12<00:00, 47.3MB/s]


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 inflated: DataSet/xvytv.png
 inflated: DataSet/xvyv8n.png
 inflated: DataSet/xvyvd.png
 inflated: DataSet/xvyy.png
 inflated: DataSet/xw33d.png
 inflated: DataSet/xw35.png
 inflated: DataSet/xw3ac.png
 inflated: DataSet/xw3c.png
 inflated: DataSet/xw3k.png
 inflated: DataSet/xw3nv.png
 inflated: DataSet/xw3x.png
 inflated: DataSet/xw4c.png
 inflated: DataSet/xw4d.png
 inflated: DataSet/xw4h5.png
 inflated: DataSet/xw4ruk.png
 inflated: DataSet/xw4t33.png
 inflated: DataSet/xw4tuh.png
 inflated: DataSet/xw4y.png
 inflated: DataSet/xw4y4m.png
 inflated: DataSet/xw53ns.png
 inflated: DataSet/xw54us.png
 inflated: DataSet/xw56.png
 inflated: DataSet/xw58mm.png
 inflated: DataSet/xw5ab.png
 inflated: DataSet/xw5ax9.png
 inflated: DataSet/xw5d.png
 inflated: DataSet/xw5nnc.png
 inflated: DataSet/xw5vk.png
 inflated: DataSet/xw64.png
 inflated: DataSet/xw6cu.png
 inflated: DataSet/xw6k.png
 inflated: DataSet/xw6pn.png
 inflated

In [4]:
# --- Paths -------------------------------------------------------------------
dataset_folder = './DataSet'      # folder the captcha images are read from
saved_models_folder = '.'         # folder the model checkpoint is read from

# --- Run parameters ----------------------------------------------------------
num_epochs = 10
batch_size = 64

# --- Split ratios handed to train_valid_test_split ---------------------------
train_ratio = 0.8
test_ratio = 0.2

# --- Device ------------------------------------------------------------------
# Prefer the GPU when one is visible; in that case also release Python garbage
# and cached CUDA memory before anything is allocated.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    gc.collect()
    torch.cuda.empty_cache()
print(device)

# --- Image preprocessing -----------------------------------------------------
# Convert to a tensor, then normalise each of the three channels.
# NOTE(review): these mean/std values match the widely used ImageNet channel
# statistics - confirm they are what the model was trained with.
_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.485, 0.456, 0.406),
        std=(0.229, 0.224, 0.225),
    ),
])

cpu


In [None]:
# Build the captcha dataset from the extracted image folder, applying the
# preprocessing pipeline defined in the configuration cell.
captcha_dataset = CaptchaDataSet(path=dataset_folder, transform=_transforms)
# Class metadata queried from the dataset; used later for decoding and for
# turning label indices back into characters.
num_classes = captcha_dataset.get_num_classes()
captcha_classes = captcha_dataset.get_classes()
end_sentence_idx = captcha_dataset.get_end_sentence_idx()
blank_idx = captcha_dataset.get_blank_idx()


# The split helper returns three values; only the third is kept here.
# NOTE(review): the first two are presumably the train and validation loaders -
# confirm against utils.train_valid_test_split.
_, _, test_loader = train_valid_test_split(captcha_dataset,train_ratio,test_ratio,batch_size)
print(f"test_len:{len(test_loader.dataset)}")

In [7]:
# Optional, Google Colab only: uncomment the three lines below to mount Google
# Drive and copy the saved checkpoint into the working directory before the
# model-loading cell runs. Left disabled so the notebook also runs outside Colab.
#from google.colab import drive
#drive.mount('/content/drive')
#!cp './drive/MyDrive/CaptchaCracking/rcnn_model10.pt' '.'

Mounted at /content/drive


In [8]:
# Checkpoint files are named <file_name><curr_file_idx><file_ext> and live in
# saved_models_folder.
file_name, file_ext = 'rcnn_model', '.pt'
curr_file_idx = 10
#curr_file_idx = get_num_curr_models(saved_models_folder,file_name,file_ext)

checkpoint_path = f"{saved_models_folder}/{file_name}{curr_file_idx}{file_ext}"

# map_location moves the stored tensors onto the device chosen in the config cell.
# NOTE(review): torch.load deserialises with pickle, and the checkpoint stores a
# whole model object under 'model' - only load checkpoints from a trusted source.
loaded_data = torch.load(checkpoint_path, map_location=device)
print(f"Loaded model with idx:{curr_file_idx}")

model = loaded_data['model']
print(model)

Loaded model with idx:10
CRNN(
  (cnn_seq_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, t

In [9]:
def test(model, loader, decoder, captcha_classes):
  model.eval()
  total_errors = 0
  errors_images_dict = {}
  with torch.no_grad():
    for batch_id, (images_data, letters_captchas, _, images) in enumerate(loader):

      images_data = images_data.to(device)
      preds = model(images_data).detach().cpu()
      pred_indices, chars_lst = decoder.forward(preds,captcha_classes)

      _, req_seq_len = letters_captchas.size()

      pred_indices = pred_indices[:,:req_seq_len]
      chars_lst = chars_lst[:,:req_seq_len]

      sum_of_abs_diff_pred_label = torch.sum(torch.abs(pred_indices-letters_captchas), dim=1)
      batch_errors_indices = torch.nonzero(sum_of_abs_diff_pred_label, as_tuple=True)[0].numpy()
      if len(batch_errors_indices) > 0:
        letters_captchas_preds = chars_lst[batch_errors_indices]
        letters_captchas_np = letters_captchas.numpy()
        letters_captchas_labels = np.choose(letters_captchas_np, captcha_classes)[batch_errors_indices]
        error_images = images.numpy()[batch_errors_indices]
        errors_images_dict[batch_id] = [batch_errors_indices,letters_captchas_preds, letters_captchas_labels, error_images]

      curr_batch_errors = torch.count_nonzero(sum_of_abs_diff_pred_label, dim=0).item()
      total_errors += curr_batch_errors

      print(f"batch_id:{batch_id}, curr_batch_errors:{curr_batch_errors}")

      
    accuracy_metric = 100 * (1 - (total_errors / len(loader.dataset)))
    return accuracy_metric, errors_images_dict


In [None]:
# Build the decoder from the blank and end-of-sentence class indices, run the
# evaluation over the test loader and report the accuracy (a percentage,
# printed with two decimals). The per-batch error details are kept for the
# visualisation cell below.
decoder = GreedyCTCDecoder(blank_idx, end_sentence_idx)
accuracy_metric, errors_images_dict = test(model,test_loader, decoder, captcha_classes)
print(f"Accuracy:{accuracy_metric:.2f}")

In [None]:
# Character that corresponds to the end-of-sentence class index.
end_sentence_char = captcha_classes[end_sentence_idx] 
# Upper bound on how many misclassified captchas to display.
# NOTE(review): presumably to_rand=True picks them at random - confirm against
# utils.show_error_images.
num_error_images_to_show = 16
show_error_images(errors_images_dict,end_sentence_char,num_error_images_to_show,to_rand=True)