# OCR Text Recognition Training

In [2]:
import sagemaker
from sagemaker import get_execution_role
sagemaker_session = sagemaker.Session()

bucket = 'dikers-data'
prefix = 'sagemaker/ocr-pytorch-train'
role = 'arn:aws-cn:iam::690704700794:role/service-role/AmazonSageMaker-ExecutionRole-20200430T123312'

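## Prepare the training data

#### Option 1: Use the demo data [Download the demo data](https://dikers-data.s3.cn-northwest-1.amazonaws.com.cn/dataset/ocr_train_demo_data.zip)

#### Option 2: Generate your own data


* Step 1: Generate small images ([reference code](https://github.com/dikers/ocr-text-renderer))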



```
# Generate the images and the label file
# Label file format: the image path first, followed by its ground-truth (gt) text
00000000.jpg F六G七H八I九J十
00000001.jpg e六f七g八h九i十
00000002.jpg W千X一Y二Z三?!
00000003.jpg t七u八v九w十x百
00000004.jpg 四P五Q六R七S八T
00000005.jpg Y二Z三?!@#%
00000006.jpg d五e六f七g八h九
00000007.jpg ,.A一B二C三D四
00000008.jpg p三q四r五s六t七
00000009.jpg 六t七u八v九w十x
```
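To make the label format concrete, here is a minimal Python sketch that reads such lines back into `(image_path, ground_truth)` pairs. It assumes the path and the text are separated by the first run of whitespace, as in the sample above; the function names `parse_label_line` and `parse_labels` are hypothetical and are not part of the referenced repository.

```python
from typing import List, Tuple


def parse_label_line(line: str) -> Tuple[str, str]:
    """Split one label line into (image_path, ground_truth).

    Assumption: the first whitespace run separates the path from the text,
    so the ground truth itself may contain spaces.
    A line with no ground truth raises ValueError.
    """
    path, text = line.rstrip("\n").split(maxsplit=1)
    return path, text


def parse_labels(lines: List[str]) -> List[Tuple[str, str]]:
    """Parse every non-empty label line."""
    return [parse_label_line(line) for line in lines if line.strip()]


sample = [
    "00000000.jpg F六G七H八I九J十",
    "00000005.jpg Y二Z三?!@#%",
]
pairs = parse_labels(sample)
```

Running a check like this over the whole label file before building the LMDB is a cheap way to catch lines with a missing path or missing text.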

* Step 2: Split the label file into train.txt and valid.txt

```
head -n 10000 labels.txt > train.txt

tail -n 1000 labels.txt  >  valid.txt
```
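Note that the `head`/`tail` commands above produce overlapping sets when `labels.txt` has fewer than 11,000 lines, and they always take contiguous ranges. As an alternative, the following sketch shuffles the lines and returns two disjoint lists. The helper name `split_labels` and the fixed seed are assumptions for illustration only.

```python
import random
from typing import List, Tuple


def split_labels(lines: List[str], n_valid: int, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Shuffle the label lines and return disjoint (train, valid) lists."""
    if not 0 <= n_valid <= len(lines):
        raise ValueError("n_valid must be between 0 and the number of lines")
    shuffled = list(lines)                 # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)  # seeded for a reproducible split
    return shuffled[n_valid:], shuffled[:n_valid]


# Small synthetic example; real input would be the lines of labels.txt.
lines = ["%08d.jpg text%d" % (i, i) for i in range(20)]
train, valid = split_labels(lines, n_valid=4)
```

The two lists can then be written to `train.txt` and `valid.txt`, one label per line.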


* Step 3: Convert the images into LMDB (.mdb) files


```
# Run the script
cd data_generate
sh create-lmdb.sh

```


```
# Adjust the paths in the script to match your data
python3 create_lmdb_dataset.py --inputPath images_path/ \
--gtFile valid.txt \
--outputPath ./output/valid

python3 create_lmdb_dataset.py --inputPath images_path/  \
--gtFile train.txt \
--outputPath ./output/train
```


### Upload the data to S3



The training data can include multiple datasets. The relevant hyperparameters and the expected directory layout are:

```
'train_data': '/opt/ml/input/data/training/train',
'valid_data': '/opt/ml/input/data/training/valid',

'select_data': 'db1-db2',    # names of the training datasets
'batch_ratio': '0.5-0.5',    # batch ratio for each dataset



.
├── train
│   ├── db1
│   │   ├── data.mdb
│   │   └── lock.mdb
│   └── db2
│       ├── data.mdb
│       └── lock.mdb
└── valid
    ├── db1
    │   ├── data.mdb
    │   └── lock.mdb
    └── db2
        ├── data.mdb
        └── lock.mdb

```
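Before uploading, it can help to check that `select_data` and `batch_ratio` agree with each other and with the directory tree. The sketch below assumes both values are `-`-separated lists of equal length (so dataset names must not contain a hyphen) and that each dataset directory holds `data.mdb` and `lock.mdb`, as shown above. The helper names are hypothetical and are not part of the training script.

```python
from pathlib import Path
from typing import Dict, List


def parse_dataset_spec(select_data: str, batch_ratio: str) -> Dict[str, float]:
    """Pair each dataset name with its batch ratio."""
    names = select_data.split("-")
    ratios = [float(r) for r in batch_ratio.split("-")]
    if len(names) != len(ratios):
        raise ValueError("select_data and batch_ratio must have the same number of entries")
    return dict(zip(names, ratios))


def missing_lmdb_files(root: Path, names: List[str]) -> List[Path]:
    """Return the data.mdb / lock.mdb paths absent under root/<split>/<name>."""
    missing = []
    for split in ("train", "valid"):
        for name in names:
            for filename in ("data.mdb", "lock.mdb"):
                candidate = root / split / name / filename
                if not candidate.is_file():
                    missing.append(candidate)
    return missing


spec = parse_dataset_spec("db1-db2", "0.5-0.5")
```

For example, `missing_lmdb_files(Path("../data_generate/output"), list(spec))` should return an empty list when the layout matches the tree above.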

Upload path:
```
   s3://'bucket'/'prefix'/
                        ├── train
                        │   ├── db1
                        │   │   ├── data.mdb
                        │   │   └── lock.mdb
                        │   └── db2
                        │       ├── data.mdb
                        │       └── lock.mdb
                        └── valid
                            ├── db1
                            │   ├── data.mdb
                            │   └── lock.mdb
                            └── db2
                                ├── data.mdb
                                └── lock.mdb

```
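To preview which object keys a recursive upload would create, the following sketch walks a local directory and joins each relative path onto the key prefix. It assumes the upload keeps the relative directory structure under `prefix`, which is what the tree above shows. `expected_s3_keys` is a hypothetical helper that only computes names and does not call SageMaker or S3.

```python
import os
from typing import List


def expected_s3_keys(local_dir: str, key_prefix: str) -> List[str]:
    """List the object keys a recursive upload of local_dir would create."""
    keys = []
    for dirpath, _, filenames in os.walk(local_dir):
        rel_dir = os.path.relpath(dirpath, start=local_dir)
        # Files at the top level go directly under the prefix.
        parts = [key_prefix] if rel_dir == "." else [key_prefix] + rel_dir.split(os.sep)
        for name in filenames:
            keys.append("/".join(parts + [name]))
    return sorted(keys)
```

Calling `expected_s3_keys('../data_generate/output/', prefix)` should list entries such as `sagemaker/ocr-pytorch-train/train/db1/data.mdb` if the local output directory follows the layout above.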

In [None]:
# On the first run, upload the data to S3
inputs = sagemaker_session.upload_data(path='../data_generate/output/', bucket=bucket, key_prefix=prefix)
print('input spec (in this case, just an S3 path): {}'.format(inputs))


### Run training in SageMaker
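The estimator cell in this section uses SageMaker Python SDK v1 argument names (`train_instance_count`, `train_instance_type`, `train_volume_size`, `train_max_run`). In SDK v2 the `train_` prefix was dropped from these arguments. The sketch below shows the renaming for the four arguments used here. The `V1_TO_V2` table and the `to_v2_kwargs` helper are illustrative only and are not part of the SDK.

```python
from typing import Any, Dict

# SDK v1 keyword -> SDK v2 keyword, limited to the arguments used in this notebook.
V1_TO_V2 = {
    "train_instance_count": "instance_count",
    "train_instance_type": "instance_type",
    "train_volume_size": "volume_size",
    "train_max_run": "max_run",
}


def to_v2_kwargs(v1_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the kwargs with v1 names replaced by their v2 names."""
    return {V1_TO_V2.get(key, key): value for key, value in v1_kwargs.items()}


v1_kwargs = {
    "train_instance_count": 1,
    "train_instance_type": "ml.p3.2xlarge",
    "train_volume_size": 100,   # GB
    "train_max_run": 432000,    # seconds (5 days)
    "framework_version": "1.4.0",
}
v2_kwargs = to_v2_kwargs(v1_kwargs)
```

Arguments that are not in the table, such as `framework_version`, pass through unchanged.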



In [3]:
from sagemaker.pytorch import PyTorch
inputs = 's3://{}/{}'.format(bucket, prefix)
print(inputs)

estimator = PyTorch(entry_point='train.py',
                    source_dir='../atte/',
                    role=role,
                    framework_version='1.4.0',
                    train_instance_count=1,
                    train_instance_type='ml.p3.2xlarge',
                    base_job_name='ocr-train',
                    train_volume_size=100,
                    train_max_run=432000,
                    output_path='s3://{}/{}/output'.format(bucket, prefix),
                    hyperparameters={
                        'train_data': '/opt/ml/input/data/training/train',
                        'valid_data': '/opt/ml/input/data/training/valid',
                        'select_data': 'db1-db2',
                        'batch_ratio': '0.5-0.5',
                        'batch_size': 160,
                        'num_iter': 10000,
                        'valInterval': 200,
                        'Transformation': 'TPS',
                        'FeatureExtraction': 'ResNet',
                        'SequenceModeling': 'BiLSTM',
                        'Prediction': 'Attn'
                        
                    })
estimator.fit({'training': inputs})

s3://dikers-data/sagemaker/ocr-pytorch-train


's3_input' class will be renamed to 'TrainingInput' in SageMaker Python SDK v2.
'create_image_uri' will be deprecated in favor of 'ImageURIProvider' class in SageMaker Python SDK v2.


2020-08-03 13:42:45 Starting - Starting the training job...
2020-08-03 13:42:47 Starting - Launching requested ML instances......
2020-08-03 13:43:52 Starting - Preparing the instances for training......
2020-08-03 13:45:16 Downloading - Downloading input data......
2020-08-03 13:45:50 Training - Downloading the training image......
2020-08-03 13:47:08 Training - Training image download completed. Training in progress.
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2020-08-03 13:47:09,888 sagemaker-containers INFO     Imported framework sagemaker_pytorch_container.training
2020-08-03 13:47:09,913 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2020-08-03 13:47:10,121 sagemaker_pytorch_container.training INFO     Invoking user training script.
2020-08-03 13:47:10,497 sagemaker-containers INFO     Module default_user_module_name does no