In [70]:
#todo
#0. try to find the token used for masking
#1. train the base model for masking

import argparse
import pandas as pd
from transformers import AutoTokenizer, pipeline, AutoModelForSeq2SeqLM
import torch
import os
import time
from tqdm import tqdm
import editdistance as ed
device = torch.device('cuda')
import difflib
from termcolor import colored

In [71]:
def get_edit_distance(predicted_text, transcript):
    cer = ed.eval(predicted_text, transcript) / max(len(predicted_text), len(transcript))
    return cer

def highlight_changes(line1, line2, label):
    # Generate a diff between the two lines
    diff = list(difflib.ndiff(line1, line2))

    # Create colored output based on diff
    highlighted_line1 = []
    highlighted_line2 = []

    i = 0
    while i < len(diff):
        if diff[i].startswith('- ') and i+1 < len(diff) and diff[i+1].startswith('+ '):
            # Replace operation
            highlighted_line1.append(colored(diff[i][2:], 'yellow', attrs=['bold']))
            highlighted_line2.append(colored(diff[i+1][2:], 'yellow', attrs=['bold']))
            i += 2
        elif diff[i].startswith('- '):
            # Delete operation
            highlighted_line1.append(colored(diff[i][2:], 'blue', attrs=['bold']))
            highlighted_line2.append(colored(' ', 'blue', attrs=['bold']))
            i += 1
        elif diff[i].startswith('+ '):
            # Insert operation
            highlighted_line1.append(colored(' ', 'red', attrs=['bold']))
            highlighted_line2.append(colored(diff[i][2:], 'red', attrs=['bold']))
            i += 1
        elif diff[i].startswith('  '):
            # Unchanged characters
            highlighted_line1.append(diff[i][2:])
            highlighted_line2.append(diff[i][2:])
            i += 1
        else:
            # Skip '?' lines
            i += 1

    final_line1 = ''.join(highlighted_line1)
    final_line2 = ''.join(highlighted_line2)

    # Ensure equal length by padding with spaces
    max_length = max(len(final_line1), len(final_line2))
    final_line1 = final_line1.ljust(max_length)
    final_line2 = final_line2.ljust(max_length)


    # Pad the ground_truth_label and label_label to the same length
    ground_truth = "Ground Truth            "
    max_length_l = max(len(ground_truth), len(label))
    padded_ground_truth = ground_truth.ljust(max_length_l)
    padded_label = label.ljust(max_length_l)

    print(f"{padded_ground_truth}" + final_line2)
    print(f"{padded_label}" + final_line1)


# Example usage:
highlight_changes("abcdef", "abxcef",'kartik')

Ground Truth            ab[1m[31mx[0mc[1m[34m [0mef
kartik                  ab[1m[31m [0mc[1m[34md[0mef


In [72]:
experiment = 'big_7pagetrained_1x'
fold = 1

In [73]:
data_path = f"/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/data/experiment_{experiment}/test_fold_{fold}.csv"
# tokenizer_path = '/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/downloaded_model/models--byt5-sanskrit'

#tokenizer_path = '/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/downloaded_model/models--byt5-sanskrit'
tokenizer_path = '/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/downloaded_model/sanskrit-multitask'

base_model_dir = '/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/downloaded_model/sanskrit-multitask'
pretrained_model_dir = '/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/downloaded_model/models--byt5-sanskrit'
finetuned_model_dir = '/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/model_checkpoints/experiment_big_7pagetrained_1x/1/checkpoint-800'

tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, max_length=512)
print('tokenizer loaded')
base_model = AutoModelForSeq2SeqLM.from_pretrained(base_model_dir)
pretrained_model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_dir)
finetuned_model = AutoModelForSeq2SeqLM.from_pretrained(finetuned_model_dir)
print('models loaded')

# Ensure the results directory exists
dir_path = os.path.dirname(data_path)
experiment_folder = os.path.basename(dir_path)
folder_name = os.path.splitext(os.path.basename(data_path))[0]
results_folder_path = f'/home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/outputs/MASK_{experiment_folder}/{folder_name}'
if not os.path.exists(results_folder_path):
    os.makedirs(results_folder_path)
    print(f"Directory created: {results_folder_path}")
else:
    print(f"Directory already exists: {results_folder_path}")

# Read the test CSV file
test_df = pd.read_csv(data_path, sep=';')
print('Test csv read')

ocr_list = list(test_df.input_text.values)
target_list = list(test_df.target_text.values)
paths = list(test_df.path.values)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


tokenizer loaded
models loaded
Directory already exists: /home/ocr_proj/OCR/post_correction/pe-ocr-sanskrit/outputs/MASK_experiment_big_7pagetrained_1x/test_fold_1
Test csv read


In [74]:
def generate_text(model, tokenizer, text):
   
    tensor = torch.tensor([tokenizer(text).input_ids])
    output_ids = model.generate(tensor,max_length=200,num_beams=3, eos_token_id=1, pad_token_id=0,decoder_start_token_id=0).tolist()
    post_corrected_list = tokenizer.batch_decode(output_ids)

    return post_corrected_list[0].replace("<pad>", "").replace("</s>", "").strip()

In [75]:
line_number = 44
ocr_text = ocr_list[line_number]
tgt_text = target_list[line_number]

In [76]:
_in = 0
_out = -1

highlight_changes(tgt_text[_in:_out],tgt_text[_in:_out],"Ground Truth")
print(' ')
highlight_changes(tgt_text[_in:_out],ocr_text[_in:_out], "OCR Output")
print(' ')
highlight_changes(tgt_text[_in:_out],generate_text(pretrained_model,tokenizer,ocr_text[_in:_out]),"Pretrained Model")
print(' ')
highlight_changes(tgt_text[_in:_out],generate_text(finetuned_model,tokenizer,ocr_text[_in:_out]),"Finetuned Model")
print(' ')
highlight_changes(tgt_text[_in:_out],generate_text(base_model,tokenizer,ocr_text[_in:_out]),"Base Model")




Ground Truth            karaṇakatvasaṃbaṃdhenapratyayārthopradhānenvayamapekṣyasvāśrayakaraṇakatvasaṃbaṃdhenaivābhyarhitenānvayaḥsiddh
Ground Truth            karaṇakatvasaṃbaṃdhenapratyayārthopradhānenvayamapekṣyasvāśrayakaraṇakatvasaṃbaṃdhenaivābhyarhitenānvayaḥsiddh
 
Ground Truth            karaṇakatvasaṃbaṃdhenapratyayārthopradhān[1m[33ma[0mnvayamapekṣyasvāśraya[1m[31mṃ[0mkaraṇakatvasa[1m[34m [0mbaṃdhen[1m[31me[0m[1m[31mś[0ma[1m[31m.[0m[1m[34m [0m[1m[34m [0m[1m[34m [0mbhyarhi[1m[31mr[0m[1m[31ma[0mtenānvayaḥsid[1m[34m [0mh[1m[31me[0m
OCR Output              karaṇakatvasaṃbaṃdhenapratyayārthopradhān[1m[33me[0mnvayamapekṣyasvāśraya[1m[31m [0mkaraṇakatvasa[1m[34mṃ[0mbaṃdhen[1m[31m [0m[1m[31m [0ma[1m[31m [0m[1m[34mi[0m[1m[34mv[0m[1m[34mā[0mbhyarhi[1m[31m [0m[1m[31m [0mtenānvayaḥsid[1m[34md[0mh[1m[31m [0m
 
Ground Truth            karaṇakatvasaṃbaṃdhenapratyayārthopradhān[1m[33ma[0mnvayamapekṣyasvāśraya[1m

## Masked Prediction

In [173]:
tokenizer.batch_decode([tokenizer('ā').input_ids])

['ā</s>']

In [186]:
line_number = 1

In [236]:

ocr_text = ocr_list[line_number]
tgt_text = target_list[line_number]
i =30
j=40

input_ids = tokenizer(ocr_text).input_ids
pre = tokenizer.batch_decode([input_ids[:i]])
mask = tokenizer.batch_decode([input_ids[i:j]])
post = tokenizer.batch_decode([input_ids[j:]])
print(pre)
print(mask)
#print(post)

input_ids_tensor = torch.tensor([input_ids[:i] + [258] + input_ids[j:]])
output_ids = base_model.generate(input_ids_tensor, max_length=10)[0].tolist()
print(tokenizer.batch_decode([output_ids]))
line_number+=1


IndexError: list index out of range

['<pad>R virodha']

['kāsenet']

In [110]:
output_ids_list = []
start_token = 0
sentinel_token = 258
while sentinel_token in output_ids:
    split_idx = output_ids.index(sentinel_token)
    output_ids_list.append(output_ids[start_token:split_idx])
    start_token = split_idx
    sentinel_token -= 1

output_ids_list.append(output_ids[start_token:])
output_string = tokenizer.batch_decode([output_ids_list])

TypeError: int() argument must be a string, a bytes-like object or a real number, not 'list'

In [87]:
print(output_string)

['<pad>R saṃkhyāyāḥ_saṃkhyāyāḥ_saṃkhyāyāḥ_saṃkhyāyāḥ_</s>']


In [18]:
ocr_text

'panvayenadhātvarathenanvayāt.saṃkhyāyāapyupadārthatavāviśeghāt.tenasākṣāhmaṃkhyānvayā..vanābheṃda'

['']