In [1]:
import os

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torch import nn

from vietocr.model.trainer import Trainer
from vietocr.tool.config import Cfg
from vietocr.tool.predictor import Predictor
from vietocr.tool.translate import build_model

In [2]:
def save_models(model, file_name, output_path='./weights/'):
    """Save ``model.state_dict()`` to ``<output_path>/<file_name>``.

    Args:
        model: module whose ``state_dict()`` is written with ``torch.save``.
        file_name: name of the file created inside ``output_path``.
        output_path: destination directory (default ``'./weights/'``); it is
            created, including any missing parent directories, if needed.

    Returns:
        The path of the written file.
    """
    # makedirs(exist_ok=True) also creates missing parents and does not fail
    # if the directory appears between the check and the creation.
    os.makedirs(output_path, exist_ok=True)
    saved_path = os.path.join(output_path, file_name)
    # Remove a previous file first so a stale file never survives a failed save.
    if os.path.exists(saved_path):
        os.remove(saved_path)
    print('Save files in: ', saved_path)
    torch.save(model.state_dict(), saved_path)
    return saved_path
    
def save_torchscript_model(model, file_name, output_path='./weights/'):
    """Compile ``model`` with ``torch.jit.script`` and save it to ``<output_path>/<file_name>``.

    Args:
        model: scriptable ``nn.Module``.
        file_name: name of the archive created inside ``output_path``.
        output_path: destination directory (default ``'./weights/'``); it is
            created, including any missing parent directories, if needed.

    Returns:
        The path of the saved TorchScript archive.
    """
    # makedirs(exist_ok=True) handles missing parents and an already-existing folder.
    os.makedirs(output_path, exist_ok=True)
    model_filepath = os.path.join(output_path, file_name)
    torch.jit.save(torch.jit.script(model), model_filepath)
    print('Save in: ', model_filepath)
    return model_filepath

def load_torchscript_model(model_filepath, device):
    """Load a TorchScript archive from ``model_filepath``, mapping its tensors onto ``device``."""
    return torch.jit.load(model_filepath, map_location=device)

## 1. Download sample dataset

In [None]:
! gdown https://drive.google.com/uc?id=19QU4VnKtgm3gf0Uw_N2QKSquW1SQ5JiE

In [None]:
! unzip -qq -o ./data_line.zip

## 2. Define config

* *data_root*: the folder that contains all of your images
* *train_annotation*: path to the training annotation file
* *valid_annotation*: path to the validation annotation file
* *print_every*: print the training loss every n steps
* *valid_every*: compute the validation loss every n steps
* *iters*: number of training iterations
* *export*: path where the trained weights are exported for later inference
* *metrics*: number of validation samples used to compute full_sequence_accuracy; on a large dataset this can take very long, so you can reduce this number

In [5]:
# Start from VietOCR's predefined 'vgg_seq2seq' configuration (by its name, a
# VGG CNN backbone with a seq2seq decoder); the next cell overrides parts of it.
config = Cfg.load_config_from_name('vgg_seq2seq')

In [6]:
# Dataset location and annotation files (see the parameter list above).
dataset_params = {
    'name': 'hw',
    'data_root': './data_line/',
    'train_annotation': 'train_line_annotation.txt',
    'valid_annotation': 'test_line_annotation.txt',
}

# Trainer settings for quantization-aware fine-tuning.
params = {
    'print_every': 200,
    'valid_every': 15 * 200,  # validate once every 15 logging intervals
    'iters': 100000,
    # NOTE(review): this is the same file as config['weights'] below (the FP32
    # weights loaded in section 3). If the Trainer writes checkpoints here, a
    # re-run would load a QAT checkpoint instead of the FP32 weights; consider
    # a separate path once the fork's checkpoint behaviour is confirmed.
    'checkpoint': './weights/transformerocr.pth',
    'export': './weights/quantize_transformerocr.pth',
    'metrics': 10000,
}

config['trainer'].update(params)
config['dataset'].update(dataset_params)

# The notebook was written for the second GPU ('cuda:1'). Fall back to the first
# GPU or the CPU so torch.load(map_location=...) does not fail on smaller machines.
if torch.cuda.device_count() > 1:
    config['device'] = 'cuda:1'
elif torch.cuda.is_available():
    config['device'] = 'cuda:0'
else:
    config['device'] = 'cpu'

config['cnn']['pretrained'] = False
config['weights'] = "./weights/transformerocr.pth"

In [7]:
# Device string chosen in the config cell; used below for loading weights and training.
device = config['device']

## 3. Get pretrained model

In [8]:
# Build the float model described by `config`, together with its vocabulary
# (presumably the character set used for decoding).
model, vocab = build_model(config)

In [9]:
# Load the pretrained FP32 weights from config['weights'] into the freshly built
# model, mapping the stored tensors onto the configured device. The cell output
# is the load_state_dict() result (the missing/unexpected-keys report).
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
weights = config['weights']
model.load_state_dict(torch.load(weights, map_location=torch.device(device)))

<All keys matched successfully>

### 3.1. Define the inputs and outputs of the quantized model

In [10]:
class QuantizedCNN(nn.Module):
    """Wrap a float CNN between a QuantStub and a DeQuantStub for eager-mode quantization.

    The stubs mark where tensors switch between floating point and quantized
    representation: the input is quantized on entry and the output is
    dequantized on exit, so callers keep exchanging float tensors.

    Args:
        model_fp32: the floating-point CNN to wrap (stored as ``model_fp32``).
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Children are registered in the order quant, dequant, model_fp32 so
        # the state_dict key layout stays stable.
        self.quant = torch.quantization.QuantStub()      # float -> quantized (input side only)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float (output side only)
        self.model_fp32 = model_fp32

    def forward(self, x):
        quantized_input = self.quant(x)
        features = self.model_fp32(quantized_input)
        return self.dequant(features)

## 4. Quantization-Aware Training

### 4.1. Fuse layers

Fuse each `Conv2d + BatchNorm2d + ReLU` sequence in the CNN backbone into a single module, so that it is quantized as one unit.

In [None]:
# Fuse every Conv2d -> BatchNorm2d -> ReLU triple in the CNN backbone so each
# triple is fake-quantized (and later quantized) as a single module.
try:
    # PyTorch >= 1.11: train-mode (QAT) fusion is fuse_modules_qat; plain
    # fuse_modules there asserts the modules are in eval mode.
    from torch.ao.quantization import fuse_modules_qat as fuse_for_qat
except ImportError:
    # Older PyTorch: fuse_modules already fuses train-mode modules for QAT.
    fuse_for_qat = torch.quantization.fuse_modules

model = model.train()
for m in model.cnn.model.modules():
    if type(m) == nn.Sequential:
        layers = list(m)
        # Only fuse genuine triples; this also avoids indexing past the end
        # when a Conv2d is one of the last two layers.
        triples = [
            [str(n), str(n + 1), str(n + 2)]
            for n in range(len(layers) - 2)
            if type(layers[n]) == nn.Conv2d
            and type(layers[n + 1]) == nn.BatchNorm2d
            and type(layers[n + 2]) == nn.ReLU
        ]
        if triples:
            fuse_for_qat(m, triples, inplace=True)

### 4.2. Prepare the model for quantization-aware training

In [None]:
# Wrap the fused CNN between quant/dequant stubs.
quantized_cnn = QuantizedCNN(model_fp32=model.cnn)
# QAT must use the QAT qconfig, which attaches FakeQuantize modules that simulate
# int8 rounding in every training forward pass. get_default_qconfig() is the
# post-training variant: it only attaches observers, which record ranges but
# leave weights and activations in float, so training would not be quantization-aware.
quantized_cnn.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# Print quantization configurations
print(quantized_cnn.qconfig)

# prepare_qat() swaps the fused float modules for their QAT counterparts and
# inserts fake-quantization according to the qconfig (in place).
quantized_cnn = torch.quantization.prepare_qat(quantized_cnn, inplace=True)

In [None]:
# Swap the prepared QAT wrapper in as the model's CNN, so the Trainer below
# fine-tunes the fake-quantized backbone together with the rest of the model.
model.cnn = quantized_cnn

### 4.3 Training

Results will differ depending on the dataset used for training. In this tutorial, I temporarily use the sample dataset provided by the VietOCR library.

In [None]:
# Switch the whole model to training mode on the target device (train() and to()
# both return the module itself), then hand it to the forked Trainer.
# pretrained=False presumably stops the Trainer from loading its own weights,
# since the FP32 weights were already loaded in section 3.
model = model.train().to(device)
trainer = Trainer(qmodel=model, config=config, pretrained=False)

In [None]:
# Visualize samples from the dataset configured in section 2 as a sanity check before training.
trainer.visualize_dataset()

In [None]:
# Fine-tune the fake-quantized model. Iteration count, logging/validation cadence
# and output paths come from config['trainer'] as set in section 2.
trainer.train()

## 5. Inference

In [None]:
# Fresh configuration for inference. PyTorch's eager-mode quantized kernels run
# on CPU only, so the device is forced to 'cpu'.
config = Cfg.load_config_from_name('vgg_seq2seq')
config['weights'] = "./weights/quantize_transformerocr.pth"
config['device'] = 'cpu'
config['cnn']['pretrained'] = False

In [None]:
# Build a fresh float model from the inference config; the following cells
# recreate the quantization structure on top of it.
model, vocab = build_model(config)

In [None]:
# Fuse Conv2d -> BatchNorm2d -> ReLU triples the same way as in section 4.1, so
# the module layout matches the QAT-trained model.
try:
    # PyTorch >= 1.11: train-mode (QAT) fusion is fuse_modules_qat; plain
    # fuse_modules there asserts the modules are in eval mode.
    from torch.ao.quantization import fuse_modules_qat as fuse_for_qat
except ImportError:
    # Older PyTorch: fuse_modules already fuses train-mode modules for QAT.
    fuse_for_qat = torch.quantization.fuse_modules

model = model.train()
for m in model.cnn.model.modules():
    if type(m) == nn.Sequential:
        layers = list(m)
        # Only fuse genuine triples; this also avoids indexing past the end
        # when a Conv2d is one of the last two layers.
        triples = [
            [str(n), str(n + 1), str(n + 2)]
            for n in range(len(layers) - 2)
            if type(layers[n]) == nn.Conv2d
            and type(layers[n + 1]) == nn.BatchNorm2d
            and type(layers[n + 2]) == nn.ReLU
        ]
        if triples:
            fuse_for_qat(m, triples, inplace=True)

In [None]:
# Recreate the QAT module structure (stubs + QAT modules) on the fresh model so
# that convert() below produces the quantized layout used for inference.
quantized_cnn = QuantizedCNN(model_fp32=model.cnn)
# QAT qconfig (FakeQuantize-based); get_default_qconfig() is the post-training,
# observer-only variant.
quantized_cnn.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# Print quantization configurations
print(quantized_cnn.qconfig)

# prepare_qat() swaps the fused float modules for their QAT counterparts (in place).
quantized_cnn = torch.quantization.prepare_qat(quantized_cnn, inplace=True)

In [None]:
# Quantized kernels run on CPU, so move the prepared CNN there. Switch it to
# eval mode before converting, as in PyTorch's documented QAT flow; this also
# disables the backbone's dropout for inference.
quantized_cnn = quantized_cnn.to(torch.device('cpu')).eval()
# Replace QAT/fake-quant modules with real int8 modules (in place).
model.cnn = torch.quantization.convert(quantized_cnn, inplace=True)

In [None]:
# Create the predictor around the converted quantized model.
# NOTE(review): presumably this forked Predictor loads config['weights'] (the
# exported QAT weights) into `qmodel`; confirm against the fork.
detector = Predictor(config, qmodel=model)

In [None]:
# Download sample image
! gdown --id 1uMVd6EBjY4Q0G2IkU5iMOQ34X0bysm0b
! unzip  -qq -o sample.zip

In [None]:
# Open one sample line image, plot it, and show the recognized text as the cell output.
img = Image.open('./sample/031189003299.jpeg')
plt.imshow(img)
s = detector.predict(img)
s