In [1]:
!pip install -q transformers
!pip install -q datasets jiwer

In [None]:
!mkdir sentences
%cd /content/sentences
!tar -xvf /content/drive/MyDrive/IAM_dataset/sentences.tgz

In [3]:
import os
import json
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from PIL import Image
from tqdm import tqdm
from transformers import AdamW
from functorch import make_functional, make_functional_with_buffers, grad

In [4]:
class IAM_fewshot_dataset(Dataset):
    def __init__(self,
                 image_dir,
                 meta_filename,
                 processor,
                 max_target_length=128,
                 episode_num=600,
                 shot=5,):

        self.image_dir = image_dir
        self.episode_num = episode_num
        self.shot = shot
        self.processor = processor
        self.max_target_length = max_target_length

        with open(meta_filename, 'r') as json_file:
            meta_data = json.load(json_file)

        for i in range(len(meta_data)):
            sample = meta_data[i]
            dir = os.path.join(image_dir, sample['image_dir'])
            if not os.path.exists(dir):
                print(dir, os.path.exists(dir))
                raise Exception

        self._writer_id_to_ind = {}
        writer_ind = 0
        for sample in meta_data:
            if sample['writer_id'] not in self._writer_id_to_ind:
                self._writer_id_to_ind[sample['writer_id']] = writer_ind
                writer_ind += 1

        self._ind_to_writer_id = {value: key for key, value in self._writer_id_to_ind.items()}

        self.writer_samples = [[] for ind in self._ind_to_writer_id]
        for sample in meta_data:
            writer_id = sample['writer_id']
            writer_ind = self._writer_id_to_ind[writer_id]
            self.writer_samples[writer_ind].append(sample)

        self.writer_num = len(self.writer_samples)

    def __len__(self,):
        return self.episode_num

    def get_encoding(self, sample):
        # get file name + text
        file_name = os.path.join(self.image_dir, sample['image_dir'])
        text = ' '.join(sample['transcription'])

        # prepare image (i.e. resize + normalize)
        image = Image.open(file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


    def __getitem__(self, idx):
        # get writer
        while True:
            writer_ind = np.random.randint(0, self.writer_num)
            samples = self.writer_samples[writer_ind]
            if len(samples) > self.shot:
                break

        random.shuffle(samples)
        supports = samples[:self.shot]
        query = samples[self.shot]

        supports = [self.get_encoding(sample) for sample in supports]
        query = self.get_encoding(query)

        pixel_values = []
        labels = []
        for batch in supports:
            pixel_values.append(batch['pixel_values'])
            labels.append(batch['labels'])
        pixel_values = torch.stack(pixel_values, 0)
        labels = torch.stack(labels, 0)
        supports = {'pixel_values': pixel_values, "labels": labels}
        return supports, query

In [5]:
class IAM_global_dataset(Dataset):
    def __init__(self,
                 image_dir,
                 meta_filename,
                 processor,
                 max_target_length=128):

        self.image_dir = image_dir
        self.processor = processor
        self.max_target_length = max_target_length

        with open(meta_filename, 'r') as json_file:
            self.meta_data = json.load(json_file)

        for i in range(len(self.meta_data)):
            sample = self.meta_data[i]
            dir = os.path.join(image_dir, sample['image_dir'])
            if not os.path.exists(dir):
                print(dir, os.path.exists(dir))
                raise Exception

    def __len__(self,):
        return len(self.meta_data)

    def get_encoding(self, sample):
        # get file name + text
        file_name = os.path.join(self.image_dir, sample['image_dir'])
        text = ' '.join(sample['transcription'])

        # prepare image (i.e. resize + normalize)
        image = Image.open(file_name).convert("RGB")
        pixel_values = self.processor(image, return_tensors="pt").pixel_values
        # add labels (input_ids) by encoding the text
        labels = self.processor.tokenizer(text,
                                          padding="max_length",
                                          max_length=self.max_target_length).input_ids
        # important: make sure that PAD tokens are ignored by the loss function
        labels = [label if label != self.processor.tokenizer.pad_token_id else -100 for label in labels]

        encoding = {"pixel_values": pixel_values.squeeze(), "labels": torch.tensor(labels)}
        return encoding


    def __getitem__(self, idx):
        sample = self.meta_data[idx]
        sample = self.get_encoding(sample)
        return sample

In [6]:
from transformers import TrOCRProcessor

processor = TrOCRProcessor.from_pretrained("microsoft/trocr-small-printed")

global_train_dataset = IAM_global_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_train_data.json', processor=processor)

fewshot_train_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_train_data.json', processor=processor, episode_num=2000)
fewshot_test_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_test_data.json', processor=processor, episode_num=100)
fewshot_val_dataset = IAM_fewshot_dataset('/content/sentences', '/content/drive/MyDrive/IAM_dataset/meta_val_data.json', processor=processor, episode_num=100)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [7]:
from torch.utils.data import DataLoader

global_train_dataloader = DataLoader(global_train_dataset, batch_size=8)

fewshot_train_dataloader = DataLoader(fewshot_train_dataset, batch_size=1)
fewshot_test_dataloader = DataLoader(fewshot_test_dataset, batch_size=1)
fewshot_val_dataloader = DataLoader(fewshot_val_dataset, batch_size=1)

In [8]:
from transformers import VisionEncoderDecoderModel
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-small-printed")
model.to(device)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


VisionEncoderDecoderModel(
  (encoder): DeiTModel(
    (embeddings): DeiTEmbeddings(
      (patch_embeddings): DeiTPatchEmbeddings(
        (projection): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): DeiTEncoder(
      (layer): ModuleList(
        (0-11): 12 x DeiTLayer(
          (attention): DeiTAttention(
            (attention): DeiTSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): DeiTSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): DeiTIntermediate(
            (dense): Linear(

In [9]:
model.config.decoder_start_token_id = processor.tokenizer.eos_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.vocab_size = model.config.decoder.vocab_size

In [10]:
sd = model.state_dict()
for k in sd:
    sd[k] = torch.randn_like(sd[k])
model.load_state_dict(sd)

<All keys matched successfully>

In [11]:
from datasets import load_metric

cer_metric = load_metric("cer")

  cer_metric = load_metric("cer")
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [12]:
def compute_cer(pred_ids, label_ids):
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return cer

In [13]:
def compute_confidence_interval(data):
    """
    Compute 95% confidence interval
    :param data: An array of mean accuracy (or mAP) across a number of sampled episodes.
    :return: the 95% confidence interval for this data.
    """
    a = 1.0 * np.array(data)
    m = np.mean(a)
    std = np.std(a)
    pm = 1.96 * (std / np.sqrt(len(a)))
    return m, pm

In [18]:
def fewshot_testing(model, fewshot_dataloader):
    model.eval()
    valid_cer = []
    inner_iter_num = 5
    init_state_dict = model.state_dict()

    for supports, query in tqdm(fewshot_dataloader):
        model.load_state_dict(init_state_dict)
        optimizer = AdamW(model.parameters(), lr=5e-5)
        # finetune model on support samples
        for _ in range(inner_iter_num):
            # supports_new = {}
            # for k,v in supports.items():
            #     supports_new[k] = v.to(device).squeeze(0)
            for k in supports:
                supports[k] = supports[k].to(model.device).squeeze(0)
            outputs = model(**supports)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        # evaluate fine-tuned model on query
        model.eval()
        outputs = model.generate(query["pixel_values"].to(device))
        cer = compute_cer(pred_ids=outputs, label_ids=query["labels"])
        valid_cer.append(cer)
    return compute_confidence_interval(valid_cer)

In [19]:
best_cer, best_std = fewshot_testing(model, fewshot_val_dataloader)
best_cer

100%|██████████| 100/100 [01:35<00:00,  1.05it/s]


4.209980950708013

In [20]:
from tqdm import tqdm
num_epochs = 10
inner_iter_num = 1
valid_cer = []
optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()
best_cer, best_std = fewshot_testing(model, fewshot_val_dataloader)
print(f"cer: {best_cer} += {best_std}")

def compute_loss(params, **x):
    output = func(params, **x)
    return output.loss

def tree_map(func, x, y):
    print(isinstance(x[0], list))

for epoch in range(num_epochs):
    pbar = tqdm(fewshot_train_dataloader)
    for supports, query in pbar:
        func, params = make_functional(model)
        for k in supports:
            supports[k] = supports[k].to(model.device).squeeze(0)
            query[k] = query[k].to(model.device)

        updated_params = params

        for _ in range(inner_iter_num):
            grad_weights = grad(compute_loss)(updated_params, **supports)
            updated_params = [x - 5e-5 * y for x, y in zip(updated_params, grad_weights)]
            updated_params = tuple(updated_params)

        outputs = func(updated_params, **query)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss {loss.item()}")

    cer, std = fewshot_testing(model, fewshot_val_dataloader)
    print(f"cer: {cer} += {std}")
    if cer < best_cer:
        best_cer, best_std = cer, std
        print(f"Best checkpoint found at epoch {epoch}")
        torch.save(model.state_dict(), f'/content/best_epoch_{epoch}.pth')

100%|██████████| 100/100 [01:30<00:00,  1.10it/s]


cer: 2.7390523053513642 += 0.915965384795954


  warn_deprecated('make_functional', 'torch.func.functional_call')
  warn_deprecated('grad')
loss 15.900981903076172: 100%|██████████| 2000/2000 [21:19<00:00,  1.56it/s]
100%|██████████| 100/100 [01:30<00:00,  1.11it/s]


cer: 1.6860801187690975 += 0.45653230124290495
Best checkpoint found at epoch 0


loss 25.235246658325195: 100%|██████████| 2000/2000 [21:16<00:00,  1.57it/s]
100%|██████████| 100/100 [01:28<00:00,  1.13it/s]


cer: 1.201715406162465 += 0.2698849438099793
Best checkpoint found at epoch 1


loss 10.685157775878906: 100%|██████████| 2000/2000 [21:18<00:00,  1.56it/s]
100%|██████████| 100/100 [01:29<00:00,  1.11it/s]


cer: 1.1310205866970948 += 0.10930752539283478
Best checkpoint found at epoch 2


loss 14.53603744506836: 100%|██████████| 2000/2000 [21:18<00:00,  1.56it/s]
100%|██████████| 100/100 [01:29<00:00,  1.11it/s]


cer: 1.0531909902496408 += 0.050901938522680545
Best checkpoint found at epoch 3


loss 7.898475170135498: 100%|██████████| 2000/2000 [21:17<00:00,  1.57it/s]
100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


cer: 1.0767455950911833 += 0.1220590579449819


loss 10.606493949890137: 100%|██████████| 2000/2000 [21:17<00:00,  1.57it/s]
100%|██████████| 100/100 [01:28<00:00,  1.13it/s]


cer: 1.0242862797236993 += 0.024582452159994342
Best checkpoint found at epoch 5


loss 12.865299224853516: 100%|██████████| 2000/2000 [21:17<00:00,  1.57it/s]
100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


cer: 1.003720930232558 += 0.007256466518061305
Best checkpoint found at epoch 6


loss 9.770898818969727: 100%|██████████| 2000/2000 [21:18<00:00,  1.56it/s]
100%|██████████| 100/100 [01:29<00:00,  1.12it/s]


cer: 1.0399197860962566 += 0.06840283024053091


loss 7.15339469909668: 100%|██████████| 2000/2000 [21:18<00:00,  1.56it/s]
100%|██████████| 100/100 [01:28<00:00,  1.13it/s]


cer: 1.0125567435082141 += 0.02095103640624679


loss 9.036985397338867: 100%|██████████| 2000/2000 [21:17<00:00,  1.57it/s]
100%|██████████| 100/100 [01:27<00:00,  1.14it/s]


cer: 1.0034166666666666 += 0.005195155531839253
Best checkpoint found at epoch 9
