# Test de validation

In [1]:
%matplotlib inline
import os
import json

os.environ['USE_TORCH'] = '1'

import matplotlib.pyplot as plt

from doctr.io import DocumentFile
from doctr.models import ocr_predictor

In [2]:
# Directory holding the input images. The OCR_DATA_DIR environment variable
# overrides it; when the variable is unset the original location is used.
path = os.environ.get('OCR_DATA_DIR', '/home/skit/formation/ml/datas/final')

In [7]:
# Predictors already built in this process, keyed by detection architecture.
# Building one loads pretrained weights, so it is done once and then reused.
_OCR_PREDICTORS = {}


def ocr_treatment(img_path, det_arch='linknet_resnet18'):
	"""Run OCR on the image(s) at ``img_path`` and return the predictor output.

	Args:
		img_path: Passed unchanged to ``DocumentFile.from_images``.
		det_arch: Detection architecture name handed to ``ocr_predictor``.
			Defaults to ``'linknet_resnet18'``, the value that was previously
			hard-coded, so existing calls behave as before.

	Returns:
		Whatever the predictor returns for the loaded document. Other helpers
		in this notebook call ``export()``, ``render()``, ``synthesize()`` and
		``show()`` on it.
	"""
	# Load the document first so a bad path fails before any model is built.
	doc = DocumentFile.from_images(img_path)
	predictor = _OCR_PREDICTORS.get(det_arch)
	if predictor is None:
		predictor = ocr_predictor(det_arch, pretrained=True, assume_straight_pages=False, preserve_aspect_ratio=True, det_bs=4, reco_bs=1024)
		_OCR_PREDICTORS[det_arch] = predictor
	return predictor(doc)

def ocr_display(result):
	"""Show the first synthesized page of ``result`` in a 15x15 figure.

	Only the first element returned by ``result.synthesize()`` is drawn; the
	axes are hidden before the figure is shown.
	"""
	pages = result.synthesize()
	plt.figure(figsize=(15, 15))
	plt.imshow(pages[0])
	plt.axis('off')
	plt.show()
	
	
def extract_text_blocks(json_data):
    """Collect the text of every block in an exported OCR document.

    Walks ``json_data["pages"]`` -> ``"blocks"`` -> ``"lines"`` -> ``"words"``
    and joins the ``"value"`` of each word in a block with single spaces.

    Returns:
        A list with one stripped string per block, in document order. A block
        without words contributes an empty string.
    """
    collected = []
    for page in json_data["pages"]:
        for block in page["blocks"]:
            values = [
                word["value"]
                for line in block["lines"]
                for word in line["words"]
            ]
            collected.append(" ".join(values).strip())
    return collected


def ocr_display_text(result):
	"""Print each text block of ``result`` on its own line, numbered from 1.

	The blocks come from ``extract_text_blocks`` applied to ``result.export()``.
	"""
	blocks = extract_text_blocks(result.export())
	for number, content in enumerate(blocks, 1):
		print(f"Block {number}: {content}")
	
	
def save_json(result, file_name='result'):
	"""Write ``result.export()`` as JSON to ``file_name + '.json'``.

	Args:
		result: Object whose ``export()`` output is JSON-serializable.
		file_name: Destination path without the ``.json`` suffix.
	"""
	json_path = file_name + '.json'
	with open(json_path, 'w') as out:
		serialized = json.dumps(result.export())
		out.write(serialized)
	
	
def save_txt(result, file_name='result'):
	"""Write ``result.render()`` to ``file_name + '.txt'`` as UTF-8 text.

	Args:
		result: Object whose ``render()`` returns the text to store.
		file_name: Destination path without the ``.txt`` suffix.
	"""
	# Recognized text can contain non-ASCII characters; pinning the encoding
	# keeps the write independent of the platform's locale default.
	with open(file_name + '.txt', 'w', encoding='utf-8') as f:
		f.write(result.render())

In [8]:
# Load the entries stored in 'bad_ocr.txt', one per line. They are used below
# as file names relative to `path`.
with open('bad_ocr.txt', 'r') as f:
	bad_ocr = f.read().splitlines()
from random import randint

In [11]:
# Pick one entry of `bad_ocr` at random (no seed is set, so the choice differs
# between runs) and resolve it against the image directory.
img_path = os.path.join(path, bad_ocr[randint(0, len(bad_ocr) - 1)])
result = ocr_treatment(img_path)
# Visualize through the result object's own show() method.
result.show()
# Alternative view, currently disabled: ocr_display(result)
ocr_display_text(result)
# Both savers use their default file_name, so this writes 'result.txt' and
# 'result.json' in the working directory, overwriting any previous run.
save_txt(result)
save_json(result)

In [5]:
# Name of the directory where per-file OCR text outputs are written and listed.
# NOTE(review): the predictor built in ocr_treatment is 'linknet_resnet18' —
# confirm that outputs stored under this name really come from this architecture.
model = 'linknet_resnet50'

In [7]:
# Build the work list: every entry of 'to_ocr.txt' that does not yet have a
# matching '<entry>.txt' file in the `model` output directory.
import os

with open('to_ocr.txt', 'r') as f:
	files = f.read().splitlines()

# Create the output directory on a first run so the listing starts out empty
# instead of failing in os.listdir.
os.makedirs(model, exist_ok=True)
treated_files = [f for f in os.listdir(model) if os.path.isfile(os.path.join(model, f))]

# A set makes each membership test constant-time, which matters when both the
# input list and the output directory are large.
treated_lookup = set(treated_files)
totreat = [file for file in files if file + '.txt' not in treated_lookup]

In [None]:
from multiprocessing import Pool

def treat(file):
	"""OCR one file from the image directory and store its rendered text.

	Prints ``file``, runs ``ocr_treatment`` on ``path/file`` and saves the
	text through ``save_txt`` under ``model + '/' + file``. Returns ``None``.
	"""
	print(file)
	ocr_result = ocr_treatment(os.path.join(path, file))
	save_txt(ocr_result, model + '/' + file)

# Fan the pending files out over 4 worker processes. treat() has no return
# statement, so the printed list holds one None per processed file.
with Pool(4) as p:
    print(p.map(treat, totreat))

In [7]:
import os
# Number of regular files currently present in the `model` output directory.
sum(1 for entry in os.listdir(model) if os.path.isfile(os.path.join(model, entry)))

In [1]:
# Names of the text-detection architectures under consideration.
# NOTE(review): presumably valid docTR detection architecture identifiers — confirm
# against the installed docTR version.
models_detection = [
	'db_resnet34',
	'db_resnet50',
	'db_mobilenet_v3_large',
	'linknet_resnet18',
	'linknet_resnet34',
	'linknet_resnet50'
]

In [2]:
# Names of the text-recognition architectures under consideration.
# NOTE(review): presumably valid docTR recognition architecture identifiers — confirm
# against the installed docTR version.
models_recognition = [
	'crnn_vgg16_bn',
	'crnn_mobilenet_v3_small',
	'crnn_mobilenet_v3_large',
	'master',
	'sar_resnet31',
	'vitstr_small',
	'vitstr_base',
	'parseq'
]