In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [1]:
!rm -rf tmp-master/

In [2]:
!git clone https://github.com/israelcamp/tmp-master.git

Cloning into 'tmp-master'...
remote: Enumerating objects: 188, done.[K
remote: Counting objects: 100% (188/188), done.[K
remote: Compressing objects: 100% (140/140), done.[K
remote: Total 188 (delta 85), reused 148 (delta 45), pack-reused 0[K
Receiving objects: 100% (188/188), 5.07 MiB | 18.21 MiB/s, done.
Resolving deltas: 100% (85/85), done.


In [3]:
%pip install -q transformers imagecorruptions pytorch-ignite neptune sentencepiece evaluate jiwer

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m104.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m97.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.8/266.8 kB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m448.1/448.1 kB[0m [31m47.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m90.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m30.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m120.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

In [116]:
CODE_PATH = "../trainer"

In [117]:
import sys
sys.path.append(CODE_PATH)

In [118]:
import os

import torch
import torchvision as tv
from transformers import AutoTokenizer
from transformers import AutoImageProcessor, ViTModel
import srsly

from ignite.engine import (
    Engine,
    Events,
)
from ignite.handlers import Checkpoint
from ignite.contrib.handlers import global_step_from_engine
from ignite.contrib.handlers import ProgressBar
from ignite.contrib.handlers.neptune_logger import NeptuneLogger

In [119]:
from datamodule import SROIETask2DataModule
from ctc import GreedyDecoder
from igmetrics import ExactMatch, WordF1

In [120]:
tokenizer = AutoTokenizer.from_pretrained(f"{CODE_PATH}/sroie-tokenizers/tokenizer-pad0")
decoder = GreedyDecoder(tokenizer.pad_token_id)

In [121]:
tokenizer.eos_token_id, tokenizer.pad_token_id, tokenizer.vocab_size

(1, 0, 77)

In [122]:
if tokenizer.eos_token_id is None:
    tokenizer.eos_token = tokenizer.sep_token
    tokenizer.eos_token_id = tokenizer.sep_token_id

In [123]:
tokenizer.eos_token, tokenizer.eos_token_id

('</s>', 1)

In [124]:
VIT_IMAGE_PROCESSOR = "google/vit-base-patch16-224-in21k"
VIT_MODEL = "baudm/vitstr-small-patch16-224"
IMAGES_DIR = "../data/SROIETask2/data/"
DATA_PATH = "../data/SROIETask2/data.json"

# Loader

In [125]:
val_pct = 0.1
image_processor = AutoImageProcessor.from_pretrained(VIT_IMAGE_PROCESSOR)

In [126]:
img2label = srsly.read_json(DATA_PATH)
image_names = sorted(
    list(set([k.split("__")[0] for k in img2label]))
)
train_size = round(len(image_names) * (1.0 - val_pct))
train_image_names = image_names[:train_size]
valid_image_names = image_names[train_size:]

# create datasets
train_img2label = {
    k: v
    for k, v in img2label.items()
    if k.split("__")[0] in train_image_names
}
valid_img2label = {
    k: v
    for k, v in img2label.items()
    if k.split("__")[0] in valid_image_names
}

In [127]:
import imgaug.augmenters as iaa
from datamodule.augs import ImgaugBackend

def train_augs():
    rotate = iaa.KeepSizeByResize(
        iaa.Affine(rotate=(-5, 5), cval=255, fit_output=True)
    )
    affine = iaa.Affine(
        scale=(0.98, 1.02),
        cval=255,
    )
    pad = iaa.Pad(
        percent=((0, 0.01), (0, 0.1), (0, 0.01), (0, 0.1)),
        keep_size=False,
        pad_cval=255,
    )
    elastic = iaa.ElasticTransformation(alpha=(0.0, 10.0), sigma=2.0)
    gaussian = iaa.imgcorruptlike.GaussianNoise(severity=(1, 3))
    jpeg = iaa.imgcorruptlike.JpegCompression(severity=(1, 5))
    pixelate = iaa.imgcorruptlike.Pixelate(severity=(1, 4))
    dropout = iaa.Dropout(p=(0, 0.05))
    tfms = [
        rotate,
        affine,
        pad,
        elastic,
        gaussian,
        jpeg,
        pixelate,
        dropout,
    ]
    augment = iaa.OneOf(
        [
            iaa.OneOf(tfms),
            iaa.SomeOf(2, tfms),
            iaa.SomeOf(3, tfms),
            iaa.SomeOf(4, tfms),
            iaa.SomeOf(5, tfms),
            iaa.Sequential(tfms),
        ]
    )
    tfms = iaa.OneOf([augment, augment, augment, iaa.Noop()])
    return ImgaugBackend(tfms=tfms)

In [128]:
from PIL import Image
class ViTDataset(torch.utils.data.Dataset):
    def __init__(self, images_dir, img2label, image_processor, training=False):
        self.images_dir = images_dir
        self.img2label = img2label
        self.image_processor = image_processor
        self.labeltuple = sorted(
            [(k, v) for k, v in self.img2label.items()], key=lambda x: x[0]
        )
        self.training = training
        self.tfms = train_augs()
        self.height = 32
        self.min_width = 40

    def __len__(self):
        return len(self.labeltuple)
    
    @staticmethod
    def expand_image(img, h, w):
        expanded = Image.new("RGB", (w, h), color=3 * (255,))  # white
        expanded.paste(img)
        return expanded

    def __getitem__(self, idx):
        image_name, label = self.labeltuple[idx]
        image_path = os.path.join(self.images_dir, f"{image_name}.png")

        image = Image.open(image_path).convert("RGB")

        w, h = image.size
        ratio = self.height / h  # how the height will change
        nw = round(w * ratio)

        image = image.resize((nw, self.height))

        if nw < self.min_width:
            image = self.expand_image(image, self.height, self.min_width)

        if self.training:
            image = self.tfms(image)

        image = self.image_processor(image, return_tensors="pt")
        return image, label

In [129]:
train_dataset = ViTDataset(IMAGES_DIR, train_img2label, image_processor, training=True)
valid_dataset = ViTDataset(IMAGES_DIR, valid_img2label, image_processor)

In [130]:
def collate_fn(samples):
    images = [s[0]["pixel_values"] for s in samples]
    labels = [s[1] for s in samples]
    images = torch.cat(images)
    tokens = tokenizer.batch_encode_plus(
        labels, padding="max_length", return_tensors="pt", max_length=197
    )
    input_ids = tokens.get("input_ids")
    attention_mask = tokens.get("attention_mask")
    return images, input_ids, attention_mask

In [131]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=2,
    shuffle=True,
    collate_fn=collate_fn,
    num_workers=4,
)
val_loader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
)

# Model

In [132]:
class ViTSTR(torch.nn.Module):
    def __init__(self, vit_model):
        super().__init__()
        self.vit_model = vit_model
        self.lm_head = torch.nn.Linear(
            self.vit_model.config.hidden_size, tokenizer.vocab_size
        )

    def forward(self, pixel_values):
        outputs = self.vit_model(pixel_values=pixel_values)
        logits = self.lm_head(outputs.last_hidden_state)
        return logits

vit_model = ViTModel.from_pretrained("WinKawaks/vit-small-patch16-224")
model = ViTSTR(vit_model)

Some weights of the model checkpoint at WinKawaks/vit-small-patch16-224 were not used when initializing ViTModel: ['classifier.bias', 'classifier.weight']
- This IS expected if you are initializing ViTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ViTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ViTModel were not initialized from the model checkpoint at WinKawaks/vit-small-patch16-224 and are newly initialized: ['vit.pooler.dense.weight', 'vit.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [133]:
# count models parameters
c = sum(p.numel() for p in model.parameters() if p.requires_grad)
# format with _
f"{c:_}"

'21_843_149'

In [134]:
_ = model.to("cuda")

In [135]:
batch = next(iter(val_loader))
model.eval()
with torch.no_grad():
    images, labels, attention_mask = batch
    images = images.to("cuda")
    labels = labels.to("cuda")
    attention_mask = attention_mask.to("cuda")
    logits = model(images)
logits.shape

torch.Size([2, 197, 77])

In [136]:
y_pred = logits.argmax(-1)
y_pred = tokenizer.batch_decode(y_pred)
# stop at eos token
y_pred = [s.split(tokenizer.eos_token)[0] for s in y_pred]
y_pred

['&Z!<&!GHU<<&U!!))``)<&<<<..!.)!>>{<!<0)<]&.I&|.<#)I|)`><<I#U7<pad>#7##~|##&.RI|#)7#<7R#~#A>|]#?#7?|IIR)~*~##)##?!R)#I)Rl|#4##<&|)<7I|`||)R|)R7I&.(*&7|)||<&>|',
 'GZ|<GGDGZGZ)UUZ.<<<()<<I))?<X<])(#<Y<G))))G?#)>7G?.G#U}U<<#*)]7>I>]DG]#<#U~#)`U>>I<##AG#)>)~#>)7Y4###U)))<G])-GG##GU)~))IG)]<G']

In [137]:
labels.shape

torch.Size([2, 197])

In [138]:
torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)(logits.view(-1, tokenizer.vocab_size), labels.view(-1))

tensor(4.9064, device='cuda:0')

# Ignite

In [99]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [100]:
_ = model.to(device)

In [79]:
MAX_EPOCHS=30
STEPS = len(train_loader) * MAX_EPOCHS
STEPS

511110

In [80]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=0)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, STEPS, 1e-6)
criterion = torch.nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)


In [112]:
def get_preds_from_logits(logits, labels):
    y_pred = logits.argmax(-1)
    y_pred = tokenizer.batch_decode(y_pred)
    # stop at eos token
    y_pred = [s.split(tokenizer.eos_token)[0] for s in y_pred]
    y = tokenizer.batch_decode(labels, skip_special_tokens=True)
    return y_pred, y

In [113]:
def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()

    images, labels, attention_mask = [x.to(device) for x in batch]

    logits = model(images)

    loss = criterion(logits.view(-1, logits.shape[-1]), labels.view(-1))
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

    optimizer.step()
    lr_scheduler.step()
    return loss.item()

In [114]:
def val_step(engine, batch):
    model.eval()
    images, labels, attention_mask = [x.to(device) for x in batch]
    with torch.no_grad():
        logits = model(images)

    y_pred, y = get_preds_from_logits(logits, labels)
    return y_pred, y
    

In [105]:
def log_validation_results(engine):
    validation_evaluator.run(val_loader)
    metrics = validation_evaluator.state.metrics
    avg_accuracy = metrics['accuracy']
    avg_f1 = metrics['f1']
    print(f"Validation Results - Epoch: {engine.state.epoch}  Avg accuracy: {avg_accuracy:.3f} Avg F1: {avg_f1:.3f}")

In [85]:
trainer = Engine(train_step)
train_evaluator = Engine(val_step)
validation_evaluator = Engine(val_step)

In [115]:
val_step(engine=None, batch=next(iter(val_loader)))

(['RESTORAN WAN SHENG', '002043319-W'], ['RESTORAN WAN SHENG', '002043319-W'])

In [40]:
trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)

ExactMatch().attach(train_evaluator, "accuracy")
ExactMatch().attach(validation_evaluator, "accuracy")
WordF1().attach(train_evaluator, "f1")
WordF1().attach(validation_evaluator, "f1")

In [71]:
!rm -rf vit-checkpoint-models

In [72]:
to_save = {'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'trainer': trainer}
gst = lambda *_: trainer.state.epoch
handler = Checkpoint(
    to_save, 
    'vit-checkpoint-models', 
    n_saved=1, 
    global_step_transform=gst,
)
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)

<ignite.engine.events.RemovableEventHandle at 0x7f3d581f6bf0>

In [73]:
to_save = {'model': model}
handler = Checkpoint(
    to_save, 
    "vit-checkpoint-models",
    n_saved=1, 
    filename_prefix='best',
    score_name="accuracy",
    global_step_transform=global_step_from_engine(trainer)
)
validation_evaluator.add_event_handler(Events.COMPLETED, handler)

<ignite.engine.events.RemovableEventHandle at 0x7f3d582d6350>

In [74]:
neptune_logger = NeptuneLogger(
    project="i155825/OCRMsc",
    api_token="eyJhcGlfYWRkcmVzcyI6Imh0dHBzOi8vYXBwLm5lcHR1bmUuYWkiLCJhcGlfdXJsIjoiaHR0cHM6Ly9hcHAubmVwdHVuZS5haSIsImFwaV9rZXkiOiJhOGUyY2VlOS1hZTU5LTQ2NGQtYTY5Zi04OGJmZWM2M2NlMDAifQ==",
)

neptune_logger.attach_output_handler(
    trainer,
    event_name=Events.ITERATION_COMPLETED,
    tag="training",
    output_transform=lambda loss: {"loss": loss},
)

neptune_logger.attach_output_handler(
    validation_evaluator,
    event_name=Events.EPOCH_COMPLETED,
    tag="validation",
    metric_names=["f1", "accuracy"],
    global_step_transform=global_step_from_engine(trainer),  
)

neptune_logger["code"].upload_files([
    f"{CODE_PATH}/*.py", 
    f"{CODE_PATH}/**/*.py",
    __file__
])

https://app.neptune.ai/i155825/OCRMsc/e/OC-56


In [41]:
pbar = ProgressBar()
pbar.attach(trainer, output_transform=lambda x: {'loss': x})

In [42]:
trainer.run(train_loader, max_epochs=MAX_EPOCHS)

Epoch [1/30]: [437/8519]   5%|▌         , loss=1.87 [00:33<09:47]Engine run is terminating due to exception: 


KeyboardInterrupt: 

Epoch [1/30]: [438/8519]   5%|▌         , loss=1.87 [00:50<09:47]

# Test

In [44]:
import collections
import os

from tqdm.auto import tqdm

In [40]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [41]:
_ = model.to(device)

In [None]:
!unzip {DATA_PATH}/testdata.zip

In [38]:
test_img2label = srsly.read_json(f"../data/SROIETask2/testdata.json")
test_images_dir = "../data/SROIETask2/testdata/"

test_dataset = ViTDataset(test_images_dir, test_img2label, image_processor)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    collate_fn=collate_fn,
    num_workers=4,
)

In [41]:
test_image_names = [x[0] for x in test_dataset.labeltuple]

In [42]:
test_image_names[0]

'X00016469670.jpg__0'

In [47]:
state_dict = torch.load(handler.last_checkpoint)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [45]:
test_results = collections.defaultdict(list)
start = 0
for i, batch in tqdm(enumerate(test_loader), total=len(test_loader)):
    y_pred, y = val_step(None, batch)

    names = test_image_names[start : start + len(y_pred)]
    start += len(y_pred)

    for p, n in zip(y_pred, names):
        n = n.split(".")[0]
        test_results[n].extend(p.strip().split())

    break

  0%|          | 0/9692 [00:00<?, ?it/s]


In [50]:
dir_path = "testsroie"
os.makedirs(dir_path, exist_ok=True)
for key, values in test_results.items():
    with open(f"{dir_path}/{key}.txt", "w") as f:
        f.write("\n".join(values))

In [None]:
! cd {dir_path} && zip -r sub.zip *.txt && mv sub.zip ../

In [46]:
!python ../sroie_evaluator/script.py -g=../sroie_evaluator/gtz.zip -s=../sub.zip

Error!
Error loading the ZIP archive

