In [27]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
import os
import mlflow
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset
import torch
import torchvision

In [19]:
# load model and labels
processor = TrOCRProcessor.from_pretrained('microsoft/trocr-large-handwritten')
model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-large-handwritten')
path = '/home/riikoro/fossil_data/tooth_samples/v1'
label_filename = 'labels.txt'
with open(os.path.join(path, label_filename)) as label_file:
    labels = label_file.readlines()

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-large-handwritten and are newly initialized: ['encoder.pooler.dense.weight', 'encoder.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
def evaluate(generated, label):
    """
    Generate word counts required for precision and recall

    Args:
        generated: sentence generated with OCR
        label:     correct text in the image

    Returns:
        int: count of correct words in the generated sequence (correct word in wrong position is wrong)
        int: count of words in the generated sequence
        int: count of words in the label
    """
    generated_words = generated.split(' ')
    label_words = label.split(' ')
    correct_words = 0

    generated_word_count = len(generated_words)
    label_word_count = len(label_words)
    for i in range(generated_word_count):
        if i >= label_word_count:
            break
        if label_words[i] == generated_words[i]:
            correct_words += 1

    return correct_words, generated_word_count, label_word_count

In [5]:
total_correct = 0
total_generated = 0
total_label_words = 0
for file in os.listdir(path):
    if not file.endswith('.png'):
        continue
    img_path = os.path.join(path, file)
    img_no = int(file.split('.')[0])

    with Image.open(img_path) as image:
        pixel_values = processor(images=image, return_tensors="pt").pixel_values

        generated_ids = model.generate(pixel_values)
        generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        
        correct_label = labels[img_no]
        print(f'--- reading dental marking from {file} ---')
        print(generated_text)
        print(correct_label)
        print
        correct_word_count, generated_word_count, label_word_count = evaluate(generated_text, correct_label)
        print(correct_word_count)
        print(generated_word_count)
        print(label_word_count)

        total_correct += correct_word_count
        total_generated += generated_word_count
        total_label_words += label_word_count



--- yks sana taas juuuu ---
Left mandible frog with broken Px, M, and
Left mandible frag with broken P4̌, M1̌ and

4
8
8
--- yks sana taas juuuu ---
A- Proximal # distal end right femur
Proximal & distal end right femur

0
7
6
--- yks sana taas juuuu ---
Isolated.
Isolated

0
1
1
--- yks sana taas juuuu ---
As Lktibia.
A= Lt tibia

0
2
3
--- yks sana taas juuuu ---
proximal.
proximal

0
1
1
--- yks sana taas juuuu ---
As left mandibular corpus with teeth " : " 3
A= Left mandibular corpus with teeth M1̄̌-M3̄̌

4
10
7
--- yks sana taas juuuu ---
Shalt lacking desstal A proximal ends
Shaft lacking distal & proximal ends

2
6
6
--- yks sana taas juuuu ---
isolated # RC. RM3, HM3, WHO, HM2.
Isolated L.C̱, RM3̱̂, LM3̱̂, LM2̱̂, LM3̌

0
7
6
--- yks sana taas juuuu ---
10kote.
lokote

0
1
1
--- yks sana taas juuuu ---
llms.
M3̂

0
1
1
--- yks sana taas juuuu ---
Ai prox.
A: prox

0
2
2
--- yks sana taas juuuu ---
Aitupper premolas Uninterrupted. L.S. MDR
A: 4 upper premolar (P3̱̌) unerrupted. L

In [6]:
precision = total_correct / total_generated
recall = total_correct / total_label_words
f1 = (2*precision*recall)/(precision+recall)


In [8]:
print(precision)
print(recall)
print(f1)

0.3310344827586207
0.3490909090909091
0.3398230088495576


In [13]:
%env MLFLOW_TRACKING_URI=sqlite:///mlruns.db
mlflow.set_experiment("Dental element OCR")

params = {
    'data_v': 1
}

# Start an MLflow run
with mlflow.start_run():
    # Log the hyperparameters
    mlflow.log_params(params)

    # Log the loss metric
    mlflow.log_metric("precision", precision)
    mlflow.log_metric("recall", recall)
    mlflow.log_metric("f1", f1)

    # Set a tag that we can use to remind ourselves what this run was for
    mlflow.set_tag("info", "Untuned TrOCR-handwritten-large")


env: MLFLOW_TRACKING_URI=sqlite:///home/riikoro/thesis/code/mlruns.db


## Train on V1 data & check how metrics change

todo:
- uniformize train image sizes
- perform training (try allowing any weight change & only last layer changes)
- log metrics to mlflow

In [23]:
# modified from https://pytorch.org/tutorials/beginner/basics/data_tutorial.html#creating-a-custom-dataset-for-your-files
class DentalElementDataset(Dataset):
    def __init__(self, annotations_file, img_dir):
        with open(os.path.join(img_dir, annotations_file)) as label_file:
            labels = label_file.readlines()
        self.img_labels = pd.DataFrame(labels, columns=['label'])
        self.img_dir = img_dir

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, f"{idx}.png")
        image = read_image(img_path)
        label = self.img_labels.iloc[idx]
        return image, label

In [28]:
dental_data = DentalElementDataset(label_filename, path)
train_loader = torch.utils.data.DataLoader(dental_data, batch_size=4, shuffle=True)

In [29]:
import matplotlib.pyplot as plt
import numpy as np

# Helper function for inline image display
def matplotlib_imshow(img, one_channel=False):
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

dataiter = iter(train_loader)
images, labels = next(dataiter)

# Create a grid from the images and show them
img_grid = torchvision.utils.make_grid(images)
matplotlib_imshow(img_grid, one_channel=True)
print('  '.join(labels[j] for j in range(4)))

# left off: each image should be of the same size --> convert to equal sizes

RuntimeError: stack expects each tensor to be equal size, but got [3, 158, 1608] at entry 0 and [3, 138, 1492] at entry 1