In [1]:
!pip install torch torchvision transformers tqdm sentencepiece



In [2]:
import os
import csv
import json
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms as T
import transformers
from transformers import DonutProcessor, VisionEncoderDecoderModel
from PIL import Image
from tqdm import tqdm

In [8]:
# Create json files to be read by model
def make_json(csv_file_path, json_file_path):

    data = {}
    with open(csv_file_path, mode='r', encoding='utf-8') as csv_file:
        csv_reader = csv.DictReader(csv_file)
        for row in csv_reader:
            filename = row['IMAGE']
            transcription = row['MEDICINE_NAME']
            data[filename] = transcription
    
    with open(json_file_path, mode='w', encoding='utf-8') as json_file:
        json.dump(data, json_file, ensure_ascii=False, indent=4)
    
    print(f"Successfully created {json_file_path}")

make_json('dataset/Training/training_labels.csv', 'dataset/Training/training_labels.json')
make_json('dataset/Validation/validation_labels.csv', 'dataset/Validation/validation_labels.json')
make_json('dataset/Testing/testing_labels.csv', 'dataset/Testing/testing_labels.json')

Successfully converted dataset/Training/training_labels.csv to dataset/Training/training_labels.json
Successfully converted dataset/Validation/validation_labels.csv to dataset/Validation/validation_labels.json
Successfully converted dataset/Testing/testing_labels.csv to dataset/Testing/testing_labels.json


In [36]:
# Load dataset
train_img_folder = "dataset/Training/training_words"
train_annotations_file = "dataset/Training/training_labels.json"
val_img_folder = "dataset/Validation/validation_words"
val_annotations_file = "dataset/Validation/validation_labels.json"
test_img_folder = "dataset/Testing/testing_words"
test_annotations_file = "dataset/Testing/testing_labels.json"

transform = T.Compose([
    T.Resize((296,296), 
    interpolation=T.InterpolationMode.BILINEAR),
    T.ToTensor(),
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

class DonutOCRDataset(Dataset):
    def __init__(self, img_folder, annotations_file, transform):
        self.img_folder = img_folder
        self.transform = transform
        with open(annotations_file, "r", encoding="utf-8") as f:
            self.annotations = json.load(f)
        self.image_files = list(self.annotations.keys())
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_file = self.image_files[idx]
        img_path = os.path.join(self.img_folder, img_file)
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as e:
            raise RuntimeError(f"Error opening image {img_path}: {e}")
        if self.transform:
            image = self.transform(image)
        target_text = self.annotations[img_file]
        return image, target_text

train_dataset = DonutOCRDataset(train_img_folder, train_annotations_file, transform)
val_dataset = DonutOCRDataset(val_img_folder, val_annotations_file, transform)
test_dataset = DonutOCRDataset(test_img_folder, test_annotations_file, transform)

def collate_fn(batch):
    images, texts = zip(*batch)
    images = torch.stack(images, 0)
    tokenized = processor(
        text=list(texts),
        padding='max_length',  
        truncation=True,       
        max_length=512,     
        return_tensors="pt"
    )
    return images, tokenized.input_ids, tokenized.attention_mask

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

In [37]:
def train(model, train_loader, val_loader, optimizer, device, processor, epoch):
    model.train()
    total_loss = 0.0
    for images, labels, attention_mask in tqdm(train_loader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.to(device)
        attention_mask = attention_mask.to(device)
        optimizer.zero_grad()
        outputs = model(pixel_values=images, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    loss = total_loss / len(train_loader)
    val_accuracy = evaluate_accuracy(model, val_loader, device, processor)
    print(f"Epoch {epoch+1}: Loss: {loss}, Val Accuracy: {val_accuracy}")


def evaluate_accuracy(model, loader, device, processor):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, targets, _ in loader:  # Assuming collate_fn returns (images, input_ids, attention_mask)
            images = images.to(device)
            targets = targets.to(device)
            generated_ids = model.generate(pixel_values=images)
            # Decode predictions (list of strings)
            preds = processor.batch_decode(generated_ids, skip_special_tokens=True)
            # Decode targets (convert tensor to list of strings)
            target_strs = processor.batch_decode(targets, skip_special_tokens=True)
            for pred, targ in zip(preds, target_strs):
                if pred.strip().lower() == targ.strip().lower():
                    correct += 1
                total += 1
    model.train()
    return correct / total if total > 0 else 0.0

In [38]:
# train model
transformers.logging.set_verbosity_error()

device = torch.device("cpu")

model_name = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(model_name, use_fast=False)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)

num_epochs = 3
for epoch in range(num_epochs):
    train(model, train_loader, val_loader, optimizer, device, processor, epoch)
    
test_accuracy = evaluate_accuracy(model, test_loader, device, processor)
print(f"Test accuracy: {test_accuracy}")

model.save_pretrained("nocom_donut_finetuned")
processor.save_pretrained("nocom_donut_finetuned")

                                                                                

Epoch 1: Loss: 0.19403019092666607, Val Accuracy: 0.002564102564102564


                                                                                

Epoch 2: Loss: 0.009468480019877927, Val Accuracy: 0.008974358974358974


                                                                                

Epoch 3: Loss: 0.008908019508593358, Val Accuracy: 0.01282051282051282


NameError: name 'test_loader' is not defined