In [1]:
!pip install -q transformers
!pip install -q datasets jiwer

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m13.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
!mkdir sentences
%cd /content/sentences
!tar -xvf /content/drive/MyDrive/IAM_dataset/sentences.tgz

/content/sentences
./
a01/
a01/a01-000u/
a01/a01-000u/a01-000u-s00-00.png
a01/a01-000u/a01-000u-s00-01.png
a01/a01-000u/a01-000u-s00-02.png
a01/a01-000u/a01-000u-s00-03.png
a01/a01-000u/a01-000u-s01-00.png
a01/a01-000u/a01-000u-s01-01.png
a01/a01-000u/a01-000u-s01-02.png
a01/a01-000u/a01-000u-s01-03.png
a01/a01-000x/
a01/a01-000x/a01-000x-s00-00.png
a01/a01-000x/a01-000x-s00-01.png
a01/a01-000x/a01-000x-s00-02.png
a01/a01-000x/a01-000x-s01-00.png
a01/a01-000x/a01-000x-s01-01.png
a01/a01-000x/a01-000x-s01-02.png
a01/a01-000x/a01-000x-s01-03.png
a01/a01-003/
a01/a01-003/a01-003-s00-00.png
a01/a01-003/a01-003-s00-01.png
a01/a01-003/a01-003-s00-02.png
a01/a01-003/a01-003-s01-00.png
a01/a01-003/a01-003-s01-01.png
a01/a01-003/a01-003-s01-02.png
a01/a01-003/a01-003-s01-03.png
a01/a01-003/a01-003-s01-04.png
a01/a01-003/a01-003-s02-00.png
a01/a01-003/a01-003-s02-01.png
a01/a01-003/a01-003-s02-02.png
a01/a01-003/a01-003-s02-03.png
a01/a01-003/a01-003-s02-04.png
a01/a01-003u/
a01/a01-003u/a01-003

In [None]:
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from tqdm import tqdm
from transformers import AdamW

In [None]:
class IAM_fewshot_dataset(Dataset):
    def __init__(self,
                 image_dir,
                 meta_filename,
                 processor,
                 max_target_length=128,
                 episode_num=600,
                 shot=5,):

        self.image_dir = image_dir
        self.episode_num = episode_num
        self.shot = shot
        self.processor = processor
        self.max_target_length = max_target_length

        with open(meta_filename, 'r') as json_file:
            meta_data = json.load(json_file)

        for i in range(len(meta_data)):
            sample = meta_data[i]
            dir = os.path.join(image_dir, sample['image_dir'])
            if not os.path.exists(dir):
                print(dir, os.path.exists(dir))
                raise Exception

        self._writer_id_to_ind = {}
        writer_ind = 0
        for sample in meta_data:
            if sample['writer_id'] not in self._writer_id_to_ind:
                self._writer_id_to_ind[sample['writer_id']] = writer_ind
                writer_ind += 1

        self._ind_to_writer_id = {value: key for key, value in self._writer_id_to_ind.items()}

        self.writer_samples = [[] for ind in self._ind_to_writer_id]
        for sample in meta_data:
            writer_id = sample['writer_id']
            writer_ind = self._writer_id_to_ind[writer_id]
            self.writer_samples[writer_ind].append(sample)

        self.writer_num = len(self.writer_samples)

    def __len__(self,):
        return self.episode_num

    def get_encoding(self, sample):
        # get file name + text
        file_name = os.path.join(self.image_dir, sample['image_dir'])
        text = ' '.join(sample['transcription'])

        # prepare image (i.e. resize + normalize)
        image = Image.open(file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


    def __getitem__(self, idx):
        # get writer
        while True:
            writer_ind = np.random.randint(0, self.writer_num)
            samples = self.writer_samples[writer_ind]
            if len(samples) > self.shot:
                break

        random.shuffle(samples)
        supports = samples[:self.shot]
        query = samples[self.shot]

        supports = [self.get_encoding(sample) for sample in supports]
        query = self.get_encoding(query)

        pixel_values = []
        labels = []
        for batch in supports:
            pixel_values.append(batch['pixel_values'])
            labels.append(batch['labels'])
        pixel_values = torch.stack(pixel_values, 0)
        labels = torch.stack(labels, 0)
        supports = {'pixel_values': pixel_values, "labels": labels}
        return supports, query

In [None]:
class IAM_global_dataset(Dataset):
    def __init__(self,
                 image_dir,
                 meta_filename,
                 processor,
                 max_target_length=128):

        self.image_dir = image_dir
        self.processor = processor
        self.max_target_length = max_target_length

        with open(meta_filename, 'r') as json_file:
            self.meta_data = json.load(json_file)

        for i in range(len(self.meta_data)):
            sample = self.meta_data[i]
            dir = os.path.join(image_dir, sample['image_dir'])
            if not os.path.exists(dir):
                print(dir, os.path.exists(dir))
                raise Exception

    def __len__(self,):
        return len(self.meta_data)

    def get_encoding(self, sample):
        # get file name + text
        file_name = os.path.join(self.image_dir, sample['image_dir'])
        text = ' '.join(sample['transcription'])

        # prepare image (i.e. resize + normalize)
        image = Image.open(file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


    def __getitem__(self, idx):
        sample = self.meta_data[idx]
        sample = self.get_encoding(sample)
        return sample

In [None]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

global_train_dataset = IAM_global_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_train_data.json', processor=processor)

fewshot_train_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_train_data.json', processor=processor)
fewshot_test_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_test_data.json', processor=processor)
fewshot_val_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_val_data.json', processor=processor)

In [None]:
from torch.utils.data import DataLoader

global_train_dataloader = DataLoader(global_train_dataset, batch_size=8)

fewshot_train_dataloader = DataLoader(fewshot_train_dataset, batch_size=1)
fewshot_test_dataloader = DataLoader(fewshot_test_dataset, batch_size=1)
fewshot_val_dataloader = DataLoader(fewshot_val_dataset, batch_size=1)

In [None]:
from transformers import VisionEncoderDecoderModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
model.to(device)

In [None]:
model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

In [None]:
from datasets import load_metric

cer_metric = load_metric("cer")

In [None]:
def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [None]:
def compute_confidence_interval(data):
    """
    Compute 95% confidence interval
    :param data: An array of mean accuracy (or mAP) across a number of sampled episodes.
    :return: the 95% confidence interval for this data.
    """
    a = 1.0 * np.array(data)
    m = np.mean(a)
    std = np.std(a)
    pm = 1.96 * (std / np.sqrt(len(a)))
    return m, pm

In [None]:
def fewshot_testing(model, fewshot_dataloader):
    model.eval()
    valid_cer = []
    inner_iter_num = 5
    init_state_dict = model.state_dict()

    for supports, query in tqdm(fewshot_dataloader):
        model.load_state_dict(init_state_dict)
        optimizer = AdamW(model.parameters(), lr=5e-5)
        model.train()
        # finetune model on support samples
        for _ in range(inner_iter_num):
            # supports_new = {}
            # for k,v in supports.items():
            #     supports_new[k] = v.to(device).squeeze(0)
            for k in supports:
                supports[k] = supports[k].to(model.device).squeeze(0)
            outputs = model(**supports)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # evaluate fine-tuned model on query
        model.eval()
        outputs = model.generate(query["pixel_values"].to(device))
        cer = compute_cer(pred_ids=outputs, label_ids=query["labels"])
        valid_cer.append(cer)
    return compute_confidence_interval(valid_cer)

In [None]:
from tqdm import tqdm
num_epochs = 10
valid_cer = []
optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()
best_cer, best_std = fewshot_testing(model, fewshot_val_dataloader)
print(f"cer: {best_cer} += {best_std}")
for epoch in range(num_epochs):
    pbar = tqdm(global_train_dataloader)
    for samples in pbar:
        for k in samples:
            samples[k] = samples[k].to(model.device)
        output = model(**samples)
        loss = output.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss {loss.item()}")

    cer, std = fewshot_testing(model, fewshot_val_dataloader)
    print(f"cer: {cer} += {std}")
    if cer < best_cer:
        best_cer, best_std = cer, std
        print(f"Best checkpoint found at epoch {epoch}")
        torch.save(model.state_dict(), f'/content/best_epoch_{epoch}.pth')
print(compute_confidence_interval(valid_cer))