In [None]:
!pip install transformers

# Install sentencepiece to use the pretrained model
!pip install sentencepiece

from PIL import Image, ImageDraw, ImageFont, ImageFilter
from matplotlib.pyplot import imshow,cm

import numpy as np
import random
import cv2
import time
import csv
import tqdm

## Download Google Font
* The fonts are installed into **/usr/share/fonts/truetype/google-fonts/**.

In [None]:
# Download the Google Fonts repository archive, unpack it and install every
# .ttf file it contains into /usr/share/fonts/truetype/google-fonts/.
# NOTE(review): _wgeturl is defined here but the wget line below hard-codes the
# same URL instead of using it.
_wgeturl="https://github.com/google/fonts/archive/main.tar.gz"
_gf="google-fonts"

!echo "Connecting to Github server to download fonts..."
!wget "https://github.com/google/fonts/archive/main.tar.gz" -O google-fonts.tar.gz

!echo "Extracting the downloaded archive..."
!tar -zxvf google-fonts.tar.gz

!echo "Creating the /usr/share/fonts/truetype/$_gf folder"
!sudo mkdir -p /usr/share/fonts/truetype/$_gf

!echo "Installing all .ttf fonts in /usr/share/fonts/truetype/$_gf"
# Copies each .ttf found under fonts-main/ with mode 644; the echo after || runs
# only if find exits with a non-zero status.
!find $PWD/fonts-main/ -name "*.ttf" -exec sudo install -m644 {} /usr/share/fonts/truetype/google-fonts/ \; || echo "An error occured, please run this script again."

!echo "Updating the font cache"
!fc-cache -f

!echo "Done. Now you can delete the tarball file $_gf.tar.gz if you wish."

* Find the specific font you want

In [None]:
!find /usr/share/fonts/truetype/google-fonts/ -name "Roboto*"

## 1. Generate Dataset
We have to generate data to train the model.
We generate four types of image data files.

The types are named **pure**, **blur**, **gaussian** and **blur+gaussian**.


All generated files are stored in the **data/** directory.

In [None]:
#Set directory to save the data file
!mkdir data

In [None]:
def generate_gaussian_noise(image):
    """Add zero-mean Gaussian noise to an image and stretch the result to 0-255.

    A single noise field (standard deviation 10) with the image's height and
    width is drawn and added to each of the first three channels, so the three
    channels receive the same perturbation at every pixel. Channels beyond the
    third, if any, are left at zero before the final normalisation.

    Args:
        image: array indexed as ``image[row, col, channel]``.

    Returns:
        ``np.uint8`` array with the same shape as ``image``.
    """
    import cv2
    import numpy as np

    noise_mean, noise_sigma = 0, 10
    rows, cols = image.shape[0], image.shape[1]
    noise = np.random.normal(noise_mean, noise_sigma, (rows, cols))

    # The same pre-scaling factor is applied to every channel.
    scale = np.max(noise) + 255
    result = np.zeros(image.shape, np.float32)
    for channel in range(3):
        result[:, :, channel] = (image[:, :, channel] + noise) / scale * 255

    # In-place min-max normalisation to the [0, 255] range.
    cv2.normalize(result, result, 0, 255, cv2.NORM_MINMAX, dtype=-1)
    return result.astype(np.uint8)

In [None]:
def _text_width(font, text):
    """Return the rendered width of ``text`` in pixels as measured by ``font``."""
    # FreeTypeFont.getsize() was removed in Pillow 10; getlength() is the
    # supported replacement. Fall back to getsize() only on old Pillow versions
    # that do not provide getlength().
    if hasattr(font, 'getlength'):
        return font.getlength(text)
    return font.getsize(text)[0]


def ImageGenerator(word_list, data_type='pure',
                   background='white', fill=(0, 0, 0),
                   font_idx=0, num_of_generated_samples=0):
    """Render a list of words as one line of text and save it as a JPEG.

    The words are drawn left to right, separated by single spaces, on a
    64-pixel-high RGB canvas whose width follows the rendered text. Depending
    on ``data_type``, randomly chosen characters are Gaussian-blurred and/or
    Gaussian noise is added to the whole image. The result is written to
    ``data/<data_type>_<num_of_generated_samples>.jpg`` and the path is printed.

    Args:
        word_list: sequence of strings to render.
        data_type: one of 'pure', 'gaussian', 'blur', 'blur+gaussian'.
        background: canvas colour passed to ``Image.new``.
        fill: text colour passed to ``ImageDraw.text``.
        font_idx: index into the built-in font list (wraps around its length).
        num_of_generated_samples: number used in the output file name.

    Returns:
        None.

    Raises:
        KeyError: if ``data_type`` is not one of the four known names.
    """
    import math
    import os
    import random
    import time

    import numpy as np
    from PIL import Image, ImageDraw, ImageFont, ImageFilter

    # data_type -> (blur random characters?, add Gaussian noise?)
    type_dict = {'pure': (False, False),
                 'gaussian': (False, True),
                 'blur': (True, False),
                 'blur+gaussian': (True, True)}
    is_rand, is_gaussian = type_dict[data_type]

    font_list = ['/usr/share/fonts/truetype/google-fonts/RobotoMono[wght].ttf',
                 '/usr/share/fonts/truetype/google-fonts/MontserratAlternates-Medium.ttf',
                 '/usr/share/fonts/truetype/google-fonts/KdamThmorPro-Regular.ttf',
                 '/usr/share/fonts/truetype/google-fonts/Joan-Regular.ttf',
                 '/usr/share/fonts/truetype/google-fonts/Roboto-Bold.ttf']

    # Re-seed from the clock so the blurred characters differ between calls.
    random.seed(time.time())

    # Full line, each word followed by one space; only used to size the canvas.
    string = "".join(item + " " for item in word_list)

    font = ImageFont.truetype(font_list[font_idx % len(font_list)], size=28)

    height = 64
    margin = len(word_list)  # left margin in pixels, one per word
    y_pos = 16
    canvas_width = int(math.ceil(_text_width(font, string))) + margin
    image = Image.new('RGB', (canvas_width, height), color=background)
    imageDraw = ImageDraw.Draw(image)

    space_width = _text_width(font, " ")
    x_pos = margin
    blurr_list = []  # (x position, width) of every character selected for blurring
    for item in word_list:
        char_x = x_pos
        for char in item:
            char_width = _text_width(font, char)
            # int(uniform(0, 5)) is 0 for draws in [0, 1), i.e. about one character in five.
            if is_rand and int(random.uniform(0, 5)) == 0:
                blurr_list.append((char_x, char_width))
            char_x += char_width
        imageDraw.text((x_pos, y_pos), item, font=font, fill=fill)
        x_pos += _text_width(font, item) + space_width

    for pos, width in blurr_list:
        left, right = int(round(pos)), int(round(pos + width))
        if right <= left:
            # Nothing to blur for a zero-width glyph.
            continue
        box = (left, 10, right, 60)
        region = image.crop(box).filter(ImageFilter.GaussianBlur(1.5))
        image.paste(region, box)

    if is_gaussian:
        image = Image.fromarray(generate_gaussian_noise(np.array(image)))

    os.makedirs('data', exist_ok=True)
    file_name = 'data/' + data_type + "_" + str(num_of_generated_samples) + ".jpg"
    print(file_name)
    image.save(file_name)

    return

In [None]:
def generate(data_type_num = 0, num=10):
    """Generate at least ``num`` text-line images of one augmentation type.

    Text is produced with the ``distilgpt2`` text-generation pipeline, cleaned,
    split into words and cut into consecutive groups of six words. Each group
    is rendered by ``ImageGenerator`` with a random background/text colour pair
    and a font that cycles with the group's position, and one row
    ``[index, text, data_type, background, fill, font]`` is appended to
    ``data/data_list.csv`` (no header row is written).

    Args:
        data_type_num: index into ``['pure', 'blur', 'gaussian', 'blur+gaussian']``.
        num: minimum number of samples; the last generated text may push the
            total slightly above this value.

    Returns:
        None. Prints a message and returns early when ``data_type_num`` is out
        of range.
    """
    import os

    data_type_list = ['pure','blur','gaussian','blur+gaussian']
    # Validate against the actual list so that index 4 (and negative values)
    # are rejected instead of failing later or silently wrapping around.
    if not 0 <= data_type_num < len(data_type_list):
        print("datatype out of range!")
        return
    data_type = data_type_list[data_type_num]

    from transformers import pipeline, set_seed
    generator = pipeline('text-generation', model='distilgpt2')
    set_seed(42)
    prompt = ["A small", "I like", "He is", "She is", "It is", "I am", "you looks", "When a", "A", "Sit down", "There are", "Food"]
    # background colour name -> text colours that may be drawn on it
    color_dictionary = {'black':[(255,255,255),(255,255,0)],
                        'orange':[(255,255,255),(255,255,0),(0,0,0)],
                        'red':[(255,255,255),(255,255,0),(0,0,0)],
                        'white':[(0,0,0),(255,0,0),(0,0,255)],
                        'green':[(0,0,0),(255,255,0),(255,255,255)],
                        'blue':[(255,255,255),(0,0,0)]}

    # File names of the fonts ImageGenerator cycles through, recorded in the CSV.
    font_list = ['RobotoMono[wght].ttf','MontserratAlternates-Medium.ttf','KdamThmorPro-Regular.ttf','Joan-Regular.ttf','Roboto-Bold.ttf']
    color_list = ['black','orange','red','white','green','blue']
    words_per_image = 6
    max_images_per_text = 10
    num_of_generated_samples = 0

    os.makedirs('data', exist_ok=True)
    # The context manager closes the CSV even if generation or rendering fails.
    with open('data/data_list.csv', 'a', encoding='utf-8', newline='') as f:
        wr = csv.writer(f)

        while num_of_generated_samples < num:
            # random.choice can select every prompt, including the last one.
            result = generator(random.choice(prompt), max_length=100, num_return_sequences=1)
            for item in result:
                seq = item['generated_text']
                # Drop zero-width characters and double quotes.
                for unwanted in (u'\u200b', u'\u200c', "\""):
                    seq = seq.replace(unwanted, '')
                # Periods become separators; split() without arguments collapses
                # any run of whitespace (including newlines), so no empty words remain.
                texts = seq.replace(".", " ").split()

                for i in range(min(max_images_per_text, len(texts) // words_per_image)):
                    words = texts[words_per_image*i:words_per_image*(i+1)]
                    background = random.choice(color_list)
                    fill = random.choice(color_dictionary[background])
                    ImageGenerator(words, data_type = data_type, background=background, fill=fill,font_idx=i,num_of_generated_samples=num_of_generated_samples)
                    wr.writerow([num_of_generated_samples," ".join(words),data_type,background, fill, font_list[i%len(font_list)]])
                    num_of_generated_samples += 1
    return



### Data generation prompt

You can generate the dataset with the following command:

        generate(data_type_num, min_num_of_data)
                data_type_num: selects the type of augmentation
                    0: 'pure'
                    1: 'blur'
                    2: 'gaussian'
                    3: 'blur+gaussian'
                min_num_of_data: minimum number of samples to generate (slightly more may be produced)


Once the dataset has been created, the images and the annotation file `data_list.csv` can be found in the **data/** directory.

In [None]:
!rm data/*

In [None]:
# Open the annotation file in append mode and close it again without writing
# anything: no header row is added, generate() appends the data rows.
with open('data/data_list.csv', 'a', encoding='utf-8', newline='') as f:
    wr = csv.writer(f)

# (data_type_num, minimum number of samples) for
# 0: pure, 1: blur, 2: gaussian, 3: blur+gaussian
for data_type_num, min_samples in ((0, 4000), (1, 1600), (2, 1600), (3, 800)):
    generate(data_type_num, min_samples)

### Generated dataset

We can see from the generated images that each augmentation was applied as intended.


In [None]:
img = Image.open('data/pure_1.jpg').convert("RGB")
imshow(img)

In [None]:
img = Image.open('data/blur_5.jpg').convert("RGB")
imshow(img)

In [None]:
img = Image.open('data/gaussian_3.jpg').convert("RGB")
imshow(img)

In [None]:
img = Image.open('data/blur+gaussian_2.jpg').convert("RGB")
imshow(img)

## 2. Generate the model (Pretrained)



First, we use a pre-trained image-encoder/text-decoder model called TrOCR.

You can easily download it from Hugging Face.

We compare two variants of the model. Each one was fine-tuned on the SROIE dataset.

1. **trocr-small-printed**: DeiT(Encoder) + UniLM(Decoder) - # params: 61596672
2. **trocr-base-printed**: BeiT(Encoder) + Roberta(Decoder) - # params: 333921792

**trocr-small-printed**: DeiT(Encoder) + UniLM

It does not work well when tested on a blur+gaussian image.

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import requests

import numpy as np
from matplotlib.pyplot import imshow

# Load one of the generated blur+gaussian samples from data/ (printed text,
# which is what the *-printed TrOCR checkpoints are meant for).
image = Image.open('data/blur+gaussian_3.jpg').convert("RGB")
imshow(np.array(image))

# The processor turns the PIL image into the pixel tensor the model consumes
# and later decodes the generated token ids back into text.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-printed')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-printed')
pixel_values = processor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values)
# batch_decode returns one string per image; only one image was passed in.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("\n\n\n\ngenerated_text:{}".format(generated_text))

**trocr-base-printed**: BeiT(Encoder) + Roberta(Decoder)


It works well when tested on a blur+gaussian image.

In [None]:
# Load the same generated blur+gaussian sample from data/ so the result can be
# compared with the small checkpoint.
image = Image.open('data/blur+gaussian_3.jpg').convert("RGB")
imshow(np.array(image))

# NOTE: this rebinds `processor` and `model` to the base checkpoint.
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-printed')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-printed')
pixel_values = processor(images=image, return_tensors="pt").pixel_values

generated_ids = model.generate(pixel_values)
# batch_decode returns one string per image; only one image was passed in.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print("\n\n\n\ngenerated_text:{}".format(generated_text))

## 3. Generate other model - ViT+LSTM

We build another model to compare with the pretrained models above.

### 3-1. Generate Dataloader

We need to convert character sequence to tensor so that torch can handle it.


* One-hot encoding - we convert each character of the sequence into one-hot vector.

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence, pack_sequence, PackedSequence

from torchvision import transforms, utils
from PIL import Image

import torch
import string
import torch.nn.functional as F
from torchvision.transforms import ToTensor, ToPILImage

toT = ToTensor()
batch_size = 1

In [None]:
# Demo: encode every character of a short string as a 128-way one-hot vector
# indexed by its code point.
sample_text = "TO SEE WHAT'"
print(len(sample_text))

tensor_list = [F.one_hot(torch.tensor(ord(char)), num_classes=128) for char in sample_text]
tensors = torch.stack(tensor_list)

print(tensors[0])

* Image to tensor - we convert each image into a tensor with `ToTensor`.

In [None]:
image = Image.open('data/blur+gaussian_3.jpg').convert("RGB")
img_resize = image.resize((256, 256))
print(type(img_resize))
print(type(image))
imshow(np.array(img_resize))
toT = ToTensor()
toT(image)

In [None]:
class image_text_Dataset(Dataset):
    """Dataset that pairs each generated image with the character codes of its text.

    Every CSV row is read positionally: column 0 is the sample index, column 1
    the text and column 2 the data type. The image for a row is expected at
    ``<root_dir>/<data type>_<index>.jpg``.
    """

    def __init__(self, csv_file, root_dir, header=None):
        """
        Args:
            csv_file (string): path to the CSV annotation file
            root_dir (string): directory that contains the images
            header: forwarded to ``pandas.read_csv``. Defaults to ``None``
                because the annotation file is written without a header row;
                letting pandas infer one would swallow the first sample.
        """
        self.data_list = pd.read_csv(csv_file, header=header)
        self.root_dir = root_dir

    def __len__(self):
        """Return the number of annotated samples."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(image tensor, text tensor)`` for row ``idx``.

        The image is converted to RGB, resized to 224x224 and passed through
        the module-level ``toT`` transform. The text becomes a 1-D float tensor
        holding ``ord()`` of each character.
        """
        # Read from this instance's own table, not from a module-level dataset.
        row = self.data_list.iloc[idx]
        index, text, data_type = row.iloc[0], row.iloc[1], row.iloc[2]
        img_path = os.path.join(self.root_dir, '{}_{}.jpg'.format(data_type, index))

        img = Image.open(img_path).convert("RGB")
        img_transformed = toT(img.resize((224, 224)))
        text_tensors = torch.Tensor([ord(char) for char in text])

        return img_transformed, text_tensors

# Build the dataset over the generated files and fetch one sample as a quick check.
imgset = image_text_Dataset(csv_file='data/data_list.csv', root_dir='data/')
img_transformed, target = imgset[1]

In [None]:
train_dataset = image_text_Dataset(csv_file = 'data/data_list.csv',
                                   root_dir='data/')

train_dataloader = DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True)

# batch_iterator = iter(train_dataloader)
# images = next(batch_iterator)

### 3-2. Define other model

In [None]:
!pip install timm

In [None]:
import timm
import torch
# model = timm.create_model('resnet50d', pretrained=True)
model = timm.create_model('vit_small_r26_s32_224_in21k', pretrained=True)
#model1 = timm.create_model('resnet50d', pretrained=True)
o = model.forward_features(torch.randn(2, 3, 224, 224))
#o1 = model1.forward_features(torch.randn(2, 3, 224, 224))
print(o.shape)
#print(o1.shape)


In [None]:
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
import torch
import torch.nn as nn

class TextRecognizer(nn.Module):
    """LSTM text decoder followed by dropout and a linear projection.

    The LSTM takes ``embedding_size`` features per time step and the linear
    layer maps each hidden state back to ``embedding_size`` values.
    """

    def __init__(self, hidden_size,  embedding_size, n_layers=1):
        """
        Args:
            hidden_size: size of the LSTM hidden and cell states.
            embedding_size: LSTM input size and size of each output step.
            n_layers: number of stacked LSTM layers.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.embedding_size = embedding_size
        self.n_layers = n_layers

        self.lstm = nn.LSTM(self.embedding_size, self.hidden_size, self.n_layers)
        self.dropout = nn.Dropout(0.2)
        self.fc = nn.Linear(self.hidden_size, self.embedding_size)

    def forward(self, x, hidden):
        """Run the decoder over a sequence.

        Args:
            x: tensor of shape (seq_len, batch). A trailing feature dimension of
                size 1 is added here, so this only matches the LSTM input size
                when ``embedding_size`` is 1.
            hidden: ``(h_0, c_0)`` tuple for the LSTM.

        Returns:
            ``(output, (h_n, c_n))`` where ``output`` has shape
            (seq_len, batch, embedding_size).
        """
        x = x.unsqueeze(2)
        out, (ht1, ct1) = self.lstm(x, hidden)
        out = self.dropout(out)
        return self.fc(out), (ht1, ct1)

    def init_hidden(self, batch_size = 1):
        """Return a zero ``(h_0, c_0)`` pair on the same device as the model.

        nn.LSTM expects a tuple of two tensors, each of shape
        (n_layers, batch_size, hidden_size); the result can be passed directly
        to ``forward``. The device is taken from the module's own parameters,
        so no module-level ``device`` variable is required.
        """
        device = next(self.parameters()).device
        shape = (self.n_layers, batch_size, self.hidden_size)
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))

In [None]:
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Trainable text decoder.
# hidden_size presumably matches the feature width of vision_model, whose
# features are used as the initial LSTM state in the training cell - TODO confirm.
model = TextRecognizer(hidden_size=384, embedding_size=1, n_layers=1)
model.to(device)

# Pretrained image encoder; only `model`'s parameters are handed to the optimizer.
vision_model = timm.create_model('vit_small_r26_s32_224_in21k', pretrained=True)
vision_model.to(device)

# NOTE(review): the decoder emits a single value per step (embedding_size=1), so
# CrossEntropyLoss sees only one class and presumably yields a constant loss -
# confirm whether a regression loss was intended.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, amsgrad=True)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=True, factor=0.5)

n_epochs = 10
batch_size = 1
loss_avg = []


In [None]:
# Train the text decoder; the image encoder only supplies the initial LSTM state.
# `tqdm` is imported as a module in this notebook, so the progress-bar class has
# to be referenced as tqdm.tqdm.
for epoch in tqdm.tqdm(range(n_epochs)):
    for img, text_tensor in train_dataloader:
        img = img.to(device)
        target = text_tensor

        # Decoder input = a leading 0 followed by the target character codes.
        # The 0 column is sized from the batch itself instead of the global batch_size.
        zero = torch.zeros(target.size(0), 1)
        # `input` shadows the builtin; the name is kept because a later cell reads it.
        input = torch.cat([zero, target], dim=1)

        # (batch, seq) -> (seq, batch), the layout the decoder's forward expects.
        input = input.permute(1, 0).to(device)
        target = target.permute(1, 0).to(device)

        # vision_model is not part of the optimizer, so no gradients are needed for it.
        # assumes forward_features returns (batch, hidden_size) - TODO confirm for
        # the installed timm version.
        with torch.no_grad():
            hidden = vision_model.forward_features(img).unsqueeze(1)
        # The same image features seed both the hidden and the cell state.
        hidden = (hidden, hidden)

        output, hidden = model(input, hidden)

        # Drop the last time step so the outputs line up with the targets.
        loss = criterion(output[:-1].squeeze(1), target)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

In [None]:
# Decode one sample with the trained decoder and show the corresponding image.
sample_idx = 10
start_text = ' '

imgset = image_text_Dataset(csv_file='data/data_list.csv', root_dir='data/')

print(imgset.data_list.head())
img_transformed, target = imgset[sample_idx]
img_transformed = img_transformed.to(device)
target = target.to(device)

# Build the decoder input from this sample's own text (a leading 0 followed by
# its character codes, shape (seq_len + 1, 1)) instead of reusing whatever
# `input` the last training batch left behind.
# NOTE(review): this feeds the ground-truth characters to the decoder; a
# free-running decoder would feed back its own predictions instead.
decoder_input = torch.cat([torch.zeros(1, device=device), target]).unsqueeze(1)

# Switch dropout off for inference, then restore the previous mode so that
# re-running the training cell behaves as before.
was_training = model.training
model.eval()
with torch.no_grad():
    hidden = vision_model.forward_features(img_transformed.unsqueeze(0)).unsqueeze(1)
    output, hidden = model(decoder_input, (hidden, hidden))
model.train(was_training)

# Each output step is a single value; it is shifted by 67 and rounded to a code
# point, clamped to the valid range so chr() cannot raise.
decoded_text = ''.join(
    chr(min(max(round(output[i].item() + 67), 0), 0x10FFFF))
    for i in range(len(output) - 1)
)

print(decoded_text)

root_dir = 'data/'
row = imgset.data_list.values[sample_idx]
text = row[1]
img_path = root_dir + row[2] + '_' + str(row[0]) + '.jpg'

img = Image.open(img_path).convert("RGB")
imshow(img)
