<a href="https://colab.research.google.com/github/mandliya/ml/blob/master/text_detection.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Get the repository (TextSnake)

In [1]:
# Clone the TextSnake PyTorch implementation into ./TextSnake.pytorch.
!git clone https://github.com/princewang1994/TextSnake.pytorch

Cloning into 'TextSnake.pytorch'...
remote: Enumerating objects: 396, done.[K
remote: Total 396 (delta 0), reused 0 (delta 0), pack-reused 396[K
Receiving objects: 100% (396/396), 1.64 MiB | 9.75 MiB/s, done.
Resolving deltas: 100% (244/244), done.


In [5]:
# List the working directory to confirm the clone is present.
!ls

data  sample_data  TextSnake.pytorch


### Add the repository to the Python module search path

In [0]:
import sys

# Make the cloned repository's packages (network, util, ...) importable.
# Guarded so that re-running this cell does not keep appending the same
# entry to sys.path.
REPO_DIR = '/content/TextSnake.pytorch'
if REPO_DIR not in sys.path:
    sys.path.append(REPO_DIR)

In [0]:
!mkdir data

### Download the pretrained model from Google Drive

The pretrained model is available [here](https://drive.google.com/open?id=1YvsuxKH9M-Gseur9gc-SZJb3pCpTUddi).

In [6]:
# Install the PyDrive wrapper & import libraries.
# This only needs to be done once per notebook.
# NOTE: `google.colab` is imported below, so this cell is Colab-specific.
!pip install -U -q PyDrive
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# Authenticate and create the PyDrive client.
# This only needs to be done once per notebook.
# The statements below are order-dependent: the Colab user is authenticated
# first, then the application-default credentials are attached to the
# GoogleAuth object that the GoogleDrive client is built from.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Download a file based on its file ID.
# `file_id` is the ID from the Google Drive model link given above.
# The file is stored as data/textsnake_vgg_180.pth, the same path that the
# detection config later uses as `detection_model_path`.
# NOTE(review): presumably requires the `data/` directory to exist already
# (it is created by the `mkdir` cell) -- confirm.
file_id = '1YvsuxKH9M-Gseur9gc-SZJb3pCpTUddi'
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('data/textsnake_vgg_180.pth')

[?25l[K     |▎                               | 10kB 25.9MB/s eta 0:00:01[K     |▋                               | 20kB 2.1MB/s eta 0:00:01[K     |█                               | 30kB 3.0MB/s eta 0:00:01[K     |█▎                              | 40kB 2.0MB/s eta 0:00:01[K     |█▋                              | 51kB 2.5MB/s eta 0:00:01[K     |██                              | 61kB 3.0MB/s eta 0:00:01[K     |██▎                             | 71kB 3.4MB/s eta 0:00:01[K     |██▋                             | 81kB 3.8MB/s eta 0:00:01[K     |███                             | 92kB 4.3MB/s eta 0:00:01[K     |███▎                            | 102kB 3.4MB/s eta 0:00:01[K     |███▋                            | 112kB 3.4MB/s eta 0:00:01[K     |████                            | 122kB 3.4MB/s eta 0:00:01[K     |████▎                           | 133kB 3.4MB/s eta 0:00:01[K     |████▋                           | 143kB 3.4MB/s eta 0:00:01[K     |█████                     

### Imports

In [0]:
from PIL import Image
from IPython.display import display, HTML, clear_output
from ipywidgets import widgets, Layout
from io import BytesIO

import os
import time
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.utils.data as data

#from dataset.deploy import DeployDataset
from network.textnet import TextNet
from util.detection import TextDetector
from util.augmentation import BaseTransform
from util.config import config as cfg, update_config, print_config
from util.option import BaseOptions
from util.visualize import visualize_detection
from util.misc import mkdirs, rescale_result
from util.config import config as cfg, update_config, print_config

import torchvision.transforms as transforms

# torchvision provides converters between PIL images and tensors
# (see the torchvision docs for the exact layout and value scaling).
# pil2tensor turns the loaded PIL image into the tensor fed to the model.
pil2tensor = transforms.ToTensor()
# tensor2pil is the converter for the opposite direction.
tensor2pil = transforms.ToPILImage()

import requests

%matplotlib inline


### Some utilities to get user input

In [0]:

def init_widgets_TD(url):
  """Render the URL text box and the submit button of the text-detection demo.

  Args:
    url: Initial value shown in the "Image URL" text box.

  Returns:
    The `widgets.Text` instance that holds the image URL, so the caller can
    read or change its value later.
  """
  # `min_width` is the actual Layout trait name. The previous `minwidth`
  # keyword is not a Layout trait, so the width hint never took effect.
  image_text = widgets.Text(
    value=url, description="Image URL", layout=Layout(min_width="70%")
  )
  submit_button = widgets.Button(description="Ask Text Detection!")

  display(image_text)
  display(submit_button)

  # The handler receives the widget itself (not its current value), so it
  # sees whatever URL is in the box at click time.
  submit_button.on_click(lambda b: on_button_click_td(b, image_text))

  return image_text
  
  
def get_actual_image(image_path):
  """Resolve an image location into something that can be opened as an image.

  Args:
    image_path: Either an ``http://`` / ``https://`` URL or a local file path.

  Returns:
    For URLs, an in-memory ``BytesIO`` holding the downloaded bytes; for
    anything else, ``image_path`` itself, unchanged.

  Raises:
    requests.HTTPError: If the server answers with an error status code.
    requests.RequestException: On connection failures or timeout.
  """
  # Match the scheme explicitly: a bare startswith('http') would also treat
  # local paths such as 'http_cache/img.png' as URLs.
  if not image_path.lower().startswith(('http://', 'https://')):
    return image_path

  # A timeout keeps the notebook from hanging forever on a dead server.
  response = requests.get(image_path, timeout=30)
  # Fail here with a clear HTTP error instead of later with an opaque image
  # decoding error when an error page is handed to the image loader.
  response.raise_for_status()
  # `response.content` is already content-decoded (gzip/deflate), which the
  # `.raw` stream is not by default, and BytesIO is a seekable file object.
  return BytesIO(response.content)


### Set up configuration for text detection

Change the values here if you need to adjust the text detection parameters.


In [0]:
from easydict import EasyDict
import torch

def create_text_detection_config():
  """Build the configuration used by the text-detection demo.

  Returns:
    An ``EasyDict`` with the detection settings. Among others it carries
    ``tr_thresh`` / ``tcl_thresh`` (passed to the detector), ``device`` (where
    model and input tensors are placed) and ``detection_model_path`` (the
    checkpoint that is loaded).
  """
  config = EasyDict()
  config.num_workers = 1
  config.batch_size = 1
  config.max_epoch = 100
  config.start_epoch = 0
  config.lr = 1e-4
  # Use the GPU only when one is actually available. A hard-coded True makes
  # every later `.to(device)` call fail on CPU-only runtimes.
  config.cuda = torch.cuda.is_available()
  config.n_disk = 15
  config.output_dir = 'output'
  config.input_size = 512
  # max polygon per image
  config.max_annotation = 200

  # max point per polygon
  config.max_points = 20

  # use hard examples (annotated as '#')
  config.use_hard = True

  # demo tr threshold
  config.tr_thresh = 0.6

  # demo tcl threshold
  config.tcl_thresh = 0.4

  # expand ratio in post processing
  config.post_process_expand = 0.3

  # merge joined text instance when predicting
  config.post_process_merge = False

  # Derived from `config.cuda`, so device and flag can never disagree.
  config.device = torch.device('cuda' if config.cuda else 'cpu')
  # Location the pretrained checkpoint was downloaded to.
  config.detection_model_path = 'data/textsnake_vgg_180.pth'
  return config


def to_device(cfg, *tensors):
    """Move one or more tensors to the device named by ``cfg.device``.

    Args:
        cfg: Object exposing a ``device`` attribute.
        *tensors: Objects with a ``.to(device)`` method.

    Returns:
        The moved tensor itself when fewer than two tensors are given;
        otherwise a generator that yields each moved tensor in order
        (suitable for tuple unpacking, consumable only once).
    """
    if len(tensors) >= 2:
        # Lazy on purpose: each tensor is moved when the generator is consumed.
        return (item.to(cfg.device) for item in tensors)
    only = tensors[0]
    return only.to(cfg.device)




### Create the text detection model

In [0]:
class TextDetectorModel:
  """Bundles the TextSnake network and its detector for single-image inference."""

  def __init__(self, cfg):
    """Creates a text detection model.

    Args:
      cfg: Configuration object providing ``detection_model_path``,
        ``device``, ``cuda``, ``tr_thresh`` and ``tcl_thresh``.
    """
    self.config = cfg
    self.model = TextNet(is_training=False, backbone='vgg')
    self.model.load_model(cfg.detection_model_path)

    self.model = self.model.to(cfg.device)
    if cfg.cuda:
        cudnn.benchmark = True
    self.detector = TextDetector(self.model, tr_thresh=cfg.tr_thresh,
                                 tcl_thresh=cfg.tcl_thresh)

  def rescale_result(self, image, contours, H, W):
    """Resizes an image to ``H`` x ``W`` and scales the contours to match.

    Args:
      image: NumPy image whose first two axes are (height, width).
      contours: Iterable of 2-D arrays; column 0 is scaled with the width
        ratio and column 1 with the height ratio. Modified in place.
      H: Target height in pixels.
      W: Target width in pixels.

    Returns:
      Tuple ``(resized_image, contours)``.
    """
    ori_H, ori_W = image.shape[:2]
    # cv2.resize takes the target size as (width, height).
    image = cv2.resize(image, (W, H))
    for cont in contours:
        cont[:, 0] = (cont[:, 0] * W / ori_W).astype(int)
        cont[:, 1] = (cont[:, 1] * H / ori_H).astype(int)
    return image, contours

  def predict(self, image, H, W):
    """Runs text detection on a batched image tensor.

    Args:
      image: Tensor of shape (1, C, height, width); only the first sample of
        the batch is used for the returned image.
      H: Height the returned image and contours are rescaled to.
      W: Width the returned image and contours are rescaled to.

    Returns:
      Tuple ``(image, contours)``: the input image as an H x W x C NumPy array
      (same dtype and value range as the input tensor) and the detected
      contours in the coordinates of that resized image.
    """
    batch = to_device(self.config, image)
    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        contours, _ = self.detector.detect(batch)
    # rescale_result works on an HWC NumPy image, not on the batched NCHW
    # tensor: take the first sample and move channels last. Passing the
    # tensor directly made `image.shape[:2]` read (batch, channels) as
    # (height, width) and handed cv2.resize an unsupported input.
    image_hwc = np.ascontiguousarray(
        batch[0].detach().permute(1, 2, 0).cpu().numpy()
    )
    return self.rescale_result(image_hwc, contours, H, W)
    



In [0]:
def on_button_click_td(b, image_text):
  """Click handler: runs text detection on the image named in the text box.

  Args:
    b: The clicked button (unused; required by the `on_click` callback).
    image_text: Text widget whose ``value`` is the image URL or local path.

  Side effects:
    Clears the cell output, then displays the resized image and prints the
    detected contours.
  """
  clear_output()
  image_source = get_actual_image(image_text.value)
  # Force three channels: RGBA, grayscale or palette images would otherwise
  # yield tensors with 4 or 1 channels.
  # NOTE(review): presumably the 'vgg' backbone expects 3-channel input -- confirm.
  pil_image = Image.open(image_source).convert('RGB')

  # NOTE: the model is rebuilt (and the checkpoint reloaded) on every click.
  config = create_text_detection_config()
  text_detection_model = TextDetectorModel(config)

  # Add the leading batch dimension.
  batch = torch.unsqueeze(pil2tensor(pil_image), 0)
  result_image, contours = text_detection_model.predict(batch, 240, 240)

  # predict() hands back the output of cv2.resize, i.e. a NumPy array; turn it
  # into a PIL image so it renders as a picture instead of a number dump.
  if isinstance(result_image, np.ndarray):
    if result_image.dtype != np.uint8:
      # Assumes float pixels in [0, 1], the range ToTensor produces -- confirm.
      result_image = (np.clip(result_image, 0.0, 1.0) * 255).astype(np.uint8)
    result_image = Image.fromarray(result_image)
  display(result_image)
  print(contours)

In [0]:

# Show the demo widgets, pre-filled with a sample COCO training image.
sample_image_url = "http://images.cocodataset.org/train2017/000000505539.jpg"
image_text = init_widgets_TD(sample_image_url)

Loading from data/textsnake_vgg_180.pth


ValueError: ignored