In [1]:
!pip install -q transformers
!pip install -q datasets jiwer

In [2]:
!mkdir sentences
%cd /content/sentences
!tar -xvf /content/drive/MyDrive/IAM_dataset/sentences.tgz

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
k07/k07-067a/k07-067a-s01-00.png
k07/k07-067a/k07-067a-s01-01.png
k07/k07-067a/k07-067a-s02-00.png
k07/k07-067a/k07-067a-s02-01.png
k07/k07-067a/k07-067a-s03-00.png
k07/k07-067a/k07-067a-s03-01.png
k07/k07-067a/k07-067a-s04-00.png
k07/k07-067a/k07-067a-s05-00.png
k07/k07-067a/k07-067a-s06-00.png
k07/k07-067a/k07-067a-s07-00.png
k07/k07-067a/k07-067a-s08-00.png
k07/k07-067a/k07-067a-s08-01.png
k07/k07-077/
k07/k07-077/k07-077-s00-00.png
k07/k07-077/k07-077-s00-01.png
k07/k07-077/k07-077-s01-00.png
k07/k07-077/k07-077-s02-00.png
k07/k07-077/k07-077-s03-00.png
k07/k07-077/k07-077-s03-01.png
k07/k07-077/k07-077-s04-00.png
k07/k07-077/k07-077-s04-01.png
k07/k07-077/k07-077-s05-00.png
k07/k07-077/k07-077-s05-01.png
k07/k07-077/k07-077-s05-02.png
k07/k07-077/k07-077-s06-00.png
k07/k07-085/
k07/k07-085/k07-085-s00-00.png
k07/k07-085/k07-085-s01-00.png
k07/k07-085/k07-085-s02-00.png
k07/k07-085/k07-085-s03-00.png
k07/k07-085/k07-0

In [3]:
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image

class IAM_fewshot_dataset(Dataset):
    def __init__(self,
                 image_dir,
                 meta_filename,
                 processor,
                 max_target_length=128,
                 episode_num=1000,
                 shot=5,):

        self.image_dir = image_dir
        self.episode_num = episode_num
        self.shot = shot
        self.processor = processor
        self.max_target_length = max_target_length

        with open(meta_filename, 'r') as json_file:
            meta_data = json.load(json_file)

        for i in range(len(meta_data)):
            sample = meta_data[i]
            dir = os.path.join(image_dir, sample['image_dir'])
            if not os.path.exists(dir):
                print(dir, os.path.exists(dir))
                raise Exception

        self._writer_id_to_ind = {}
        writer_ind = 0
        for sample in meta_data:
            if sample['writer_id'] not in self._writer_id_to_ind:
                self._writer_id_to_ind[sample['writer_id']] = writer_ind
                writer_ind += 1

        self._ind_to_writer_id = {value: key for key, value in self._writer_id_to_ind.items()}

        self.writer_samples = [[] for ind in self._ind_to_writer_id]
        for sample in meta_data:
            writer_id = sample['writer_id']
            writer_ind = self._writer_id_to_ind[writer_id]
            self.writer_samples[writer_ind].append(sample)

        self.writer_num = len(self.writer_samples)

    def __len__(self,):
        return self.episode_num

    def get_encoding(self, sample):
        # get file name + text
        file_name = os.path.join(self.image_dir, sample['image_dir'])
        text = ' '.join(sample['transcription'])

        # prepare image (i.e. resize + normalize)
        image = Image.open(file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


    def __getitem__(self, idx):
        # get writer
        while True:
            writer_ind = np.random.randint(0, self.writer_num)
            samples = self.writer_samples[writer_ind]
            if len(samples) > self.shot:
                break

        random.shuffle(samples)
        supports = samples[:self.shot]
        query = samples[self.shot]

        supports = [self.get_encoding(sample) for sample in supports]
        query = self.get_encoding(query)

        pixel_values = []
        labels = []
        for batch in supports:
            pixel_values.append(batch['pixel_values'])
            labels.append(batch['labels'])
        pixel_values = torch.stack(pixel_values, 0)
        labels = torch.stack(labels, 0)
        supports = {'pixel_values': pixel_values, "labels": labels}
        return supports, query

In [4]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-handwritten")
# processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-stage1")
fewshot_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_data.json', processor=processor)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [5]:
from torch.utils.data import DataLoader

fewshot_dataloader = DataLoader(fewshot_dataset, batch_size=1)

In [6]:
from transformers import VisionEncoderDecoderModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-handwritten")
# model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-stage1")
model.to(device)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


VisionEncoderDecoderModel(
  (encoder): DeiTModel(
    (embeddings): DeiTEmbeddings(
      (patch_embeddings): DeiTPatchEmbeddings(
        (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DeiTEncoder(
      (layer): ModuleList(
        (0-11): 12 x DeiTLayer(
          (attention): DeiTAttention(
            (attention): DeiTSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): DeiTSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): DeiTIntermediate(
            (dense): Linear(

In [7]:
# set special tokens used for creating the decoder_input_ids from the labels
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# set beam search parameters
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.max_length = 64
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4

In [8]:
from datasets import load_metric

cer_metric = load_metric("cer")

  cer_metric = load_metric("cer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [9]:
def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [20]:
from tqdm import tqdm
valid_cer = []
for supports, query in tqdm(fewshot_dataloader):
    model.train()
    # finetune model on support samples


    # evaluate fine-tuned model on query
    model.eval()
    outputs = model.generate(query["pixel_values"].to(device))
    cer = compute_cer(pred_ids=outputs, label_ids=query["labels"])
    valid_cer.append(cer)
print(np.mean(valid_cer))

100%|██████████| 1000/1000 [07:27<00:00,  2.23it/s]

6.552375339671563





In [10]:
init_state_dict = model.state_dict()

In [31]:
def support_process(supports, device):
    pixel_values = []
    labels = []
    for batch in supports:
        pixel_values.append(batch['pixel_values'])
        labels.append(batch['labels'])
    print(len(pixel_values))
    pixel_values = torch.cat(pixel_values, 0).to(device)
    labels = torch.cat(labels, 0).to(device)
    return pixel_values, labels

In [11]:
# 5-shot
from tqdm import tqdm
from transformers import AdamW

valid_cer = []
inner_iter_num = 5

for supports, query in tqdm(fewshot_dataloader):
    model.load_state_dict(init_state_dict)
    optimizer = AdamW(model.parameters(), lr=5e-5)
    model.train()
    # finetune model on support samples
    for _ in range(inner_iter_num):
        for k,v in supports.items():
            supports[k] = v.to(device).squeeze(0)
        outputs = model(**supports)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    # evaluate fine-tuned model on query
    model.eval()
    outputs = model.generate(query["pixel_values"].to(device))
    cer = compute_cer(pred_ids=outputs, label_ids=query["labels"])
    valid_cer.append(cer)
print(np.mean(valid_cer))

100%|██████████| 1000/1000 [23:35<00:00,  1.42s/it]

0.8604835506041866





In [14]:
supports.keys()

dict_keys(['pixel_values', 'labels'])

In [16]:
supports['pixel_values'].shape

torch.Size([1, 15, 384, 384])