In [1]:
# #!git clone https://github.com/sberbank-ai/fusion_brain_aij2021.git
# !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install tpu_star==0.0.1rc10
# !pip install albumentations==0.5.2
# !pip install einops==0.3.2
# !pip install pytorch_lightning
# !pip install comet_ml
# !pip install transformers==4.10.0 
# !pip install colorednoise==1.1.1
# !pip install catalyst==21.8 
# !pip install opencv-python==4.5.3
# !pip install gdown==4.0.2
# !pip install pymorphy2

In [None]:
# Доступные ресурсы
import multiprocessing
import torch
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)

print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

!nvidia-smi

In [None]:
import os
import json
import random
from collections import OrderedDict

import numpy as np
import pandas as pd
from tqdm import tqdm
from skimage import io
from sklearn.metrics import accuracy_score
import albumentations as A
from transformers import GPT2Model, GPT2Tokenizer
from torch.utils.data import SequentialSampler

from fb_baseline.fb_utils.metrics import cer, wer, string_accuracy, acc, vqa_evaluate, detection_evaluate
from fb_baseline.fb_utils.c2c_eval import Beam, eval_bleu
from fb_baseline.fb_utils.detection_eval import inverse_detection_evaluation
from fb_baseline.fb_utils.vqa_eval import inverse_vqa_evaluation
from fb_baseline.model.utils.utils import CTCLabeling
from fb_baseline.model.dataset.dataset import (
    HTRDataset, VQADataset, C2CDataset, DetectionDataset, FusionDataset,
    htr_collate_fn, c2c_collate_fn, vqa_collate_fn, detection_collate_fn
)
from fb_baseline.model.model import InverseAttentionGPT2FusionBrain

# Dataset

In [None]:
# Test-set marking for every task: model inputs come from `private_fb/`,
# ground truth from `true_fb/`. Each resulting frame gets `stage='test'`.
TEST_DATA_DIR = '/home/jovyan/vladimir/fusion_brain/data/test'
PRIVATE_DIR = os.path.join(TEST_DATA_DIR, 'private_fb')
TRUE_DIR = os.path.join(TEST_DATA_DIR, 'true_fb')

# #
# Handwritten
# #
with open(os.path.join(TRUE_DIR, 'true_HTR.json'), 'rb') as f:
    json_marking = json.load(f)
# Images live under data/test/private_fb like every other task's inputs
# (the previous path omitted the `test/` segment).
htr_images_dir = os.path.join(PRIVATE_DIR, 'HTR', 'images')
marking = [
    {
        'task_ids': 'handwritten',
        'images': os.path.join(htr_images_dir, image_name),
        'gt_texts': text,
    }
    for image_name, text in json_marking.items()
]
df_handwritten = pd.DataFrame(marking)
df_handwritten['stage'] = 'test'
# #
# C2C
# #
with open(os.path.join(PRIVATE_DIR, 'C2C', 'requests.json'), 'rb') as f:
    java_json = json.load(f)
with open(os.path.join(TRUE_DIR, 'true_C2C.json'), 'rb') as f:
    python_json = json.load(f)
# Pair each Java request with its reference Python translation by shared key.
marking = [
    {'task_ids': 'c2c', 'java': java_code, 'python': python_json[key]}
    for key, java_code in java_json.items()
]
df_c2c = pd.DataFrame(marking)
df_c2c['stage'] = 'test'
# #
# VQA
# #
with open(os.path.join(PRIVATE_DIR, 'VQA', 'questions.json'), 'rb') as f:
    json_questions = json.load(f)
with open(os.path.join(TRUE_DIR, 'true_VQA.json'), 'rb') as f:
    json_answers = json.load(f)
vqa_images_dir = os.path.join(PRIVATE_DIR, 'VQA', 'images')
marking = [
    {
        'task_ids': 'vqa',
        'images': os.path.join(vqa_images_dir, question['file_name']),
        'questions': question['question'],
        'answers': json_answers[key]['answer'],
    }
    for key, question in json_questions.items()
]
df_vqa = pd.DataFrame(marking)
df_vqa['stage'] = 'test'
# #
# Detection
# #
with open(os.path.join(TRUE_DIR, 'true_zsOD.json'), 'rb') as f:
    json_true_zsod_test = json.load(f)
zsod_images_dir = os.path.join(PRIVATE_DIR, 'zsOD', 'images')
# Each image maps request text -> ground-truth boxes; keep requests and boxes aligned.
marking = [
    {
        'task_ids': 'detection',
        'images': os.path.join(zsod_images_dir, image_name),
        'requests': list(request_to_boxes.keys()),
        'boxes': list(request_to_boxes.values()),
    }
    for image_name, request_to_boxes in json_true_zsod_test.items()
]
df_detection = pd.DataFrame(marking)
df_detection['stage'] = 'test'

In [None]:
# Per-task albumentations pipelines handed to the datasets.
# NOTE(review): the 'handwritten' pipeline is random (CLAHE / Rotate / JPEG with p < 1).
# Whether HTRDataset applies it during test evaluation is not visible here; if it does,
# handwritten metrics will vary between runs — confirm.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]


def _resize_and_normalize():
    """Build a fresh deterministic pipeline: 224x224 resize, then mean/std normalization."""
    return A.Compose(
        [
            A.Resize(224, 224, always_apply=True),
            A.Normalize(IMAGENET_MEAN, IMAGENET_STD),
        ],
        p=1.0,
    )


handwritten_augs = A.Compose(
    [
        A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.25, always_apply=False),
        A.Rotate(limit=3, interpolation=1, border_mode=0, p=0.5),
        A.JpegCompression(quality_lower=75, p=0.5),
    ],
    p=1.0,
)

task_augs = {
    'handwritten': handwritten_augs,
    'vqa': _resize_and_normalize(),
    'detection': _resize_and_normalize(),
}

In [None]:
# CTC alphabet for the handwritten head, plus the GPT-2 backbone shared by all tasks.
CTC_CHARS_PATH = '/home/jovyan/vladimir/fusion_brain/FusionBrain/fb_baseline/configs/ctc_chars.txt'
with open(CTC_CHARS_PATH) as chars_file:
    ctc_chars = chars_file.read()
ctc_labeling = CTCLabeling(ctc_chars)

model_name = 'gpt2-medium'
# Register an explicit padding token, then grow the embedding matrix to the new vocab size.
gpt_tokenizer = GPT2Tokenizer.from_pretrained(model_name, pad_token='<|pad|>')
gpt_model = GPT2Model.from_pretrained(model_name)
gpt_model.resize_token_embeddings(len(gpt_tokenizer))

In [None]:
# Small evaluation subsets: the first 100 rows per task (first 10 for C2C).
# NOTE(review): the positional numbers (512/128 for HTR, 300/250 for C2C, 21/8 for VQA,
# 21 for detection) are forwarded unchanged; presumably image sizes / max token lengths — confirm.
htr_subset = df_handwritten[:100]
c2c_subset = df_c2c[:10]
vqa_subset = df_vqa[:100]
detection_subset = df_detection[:100]

test_datasets = dict(
    handwritten=HTRDataset(htr_subset, ctc_labeling, 512, 128, task_augs),
    c2c=C2CDataset(c2c_subset, gpt_tokenizer, 300, 250, 'test'),
    vqa=VQADataset(vqa_subset, gpt_tokenizer, 21, 8, 'test', task_augs),
    detection=DetectionDataset(detection_subset, gpt_tokenizer, 21, 'test', task_augs),
)

In [None]:
# Show what one sample of each task looks like.
def demo_sample(sample):
    """Print a human-readable view of one dataset sample.

    Handwritten samples also get their image plotted, and the ``io.imshow``
    result is returned. The c2c, detection and vqa tasks only print and
    return ``None``. Any other ``task_id`` is returned unchanged.
    """
    task_id = sample['task_id']

    def _decode(ids):
        # Token ids -> text, hiding special tokens such as padding.
        return gpt_tokenizer.decode(ids.numpy(), skip_special_tokens=True)

    if task_id == 'handwritten':
        print('[gt_text]:', sample['gt_text'])
        # Move dim 0 to the end for plotting (presumably CHW -> HWC).
        return io.imshow(sample['image'].permute(1, 2, 0).numpy())

    if task_id == 'c2c':
        print('[source_text]:', _decode(sample['input_ids']))
        print('[target_text]:', sample['target'])
    elif task_id == 'detection':
        print('[input_ids]:', list(map(_decode, sample['input_ids'])))
        print('[boxes]:', sample['boxes'])
    elif task_id == 'vqa':
        print('[question]:', _decode(sample['input_ids']))
        print('[answers]:', sample['target'])
    else:
        return sample
    return None

In [None]:
sample_idx = np.random.randint(len(test_datasets['handwritten']))
demo_sample(test_datasets['handwritten'][sample_idx])

In [None]:
sample_idx = np.random.randint(len(test_datasets['c2c']))
demo_sample(test_datasets['c2c'][sample_idx])

In [None]:
sample_idx = np.random.randint(len(test_datasets['vqa']))
demo_sample(test_datasets['vqa'][sample_idx])

In [None]:
sample_idx = np.random.randint(len(test_datasets['detection']))
demo_sample(test_datasets['detection'][sample_idx])

# Model

In [None]:
# Per-task head settings for InverseAttentionGPT2FusionBrain.
HTR_PATCH_W, HTR_PATCH_H = 8, 128

handwritten_config = dict(
    patch_w=HTR_PATCH_W,
    patch_h=HTR_PATCH_H,
    # patch_w * patch_h * 3 (= 3072), presumably one flattened 3-channel patch — confirm.
    in_layer_sizes=[HTR_PATCH_W * HTR_PATCH_H * 3],
    out_layer_sizes=[64],
    orth_gain=1.41,
    dropout=0.1,
    lstm_num_layers=3,
    output_dim=len(ctc_labeling),  # one output per CTC label
)

vqa_config = dict(tokens_num=len(gpt_tokenizer))

detection_config = dict(num_mlp_layers=3, num_queries=12)

In [None]:
# Multi-task model on top of the shared GPT-2 backbone, one config dict per task head.
head_configs = {
    'handwritten_config': handwritten_config,
    'vqa_config': vqa_config,
    'detection_config': detection_config,
}
model = InverseAttentionGPT2FusionBrain(gpt_model, **head_configs)

# Test

In [None]:
_EVAL_TASKS = ('handwritten', 'c2c', 'vqa', 'detection')


def _resolve_device(model, device):
    """Return ``device`` if given, else the device of ``model``'s first parameter (``cuda:0`` if it has none)."""
    if device is not None:
        return device
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device('cuda:0')


def _eval_handwritten_batch(batch, model, device):
    """Greedy-decode one handwritten batch; return one record per sample."""
    htr_images, _encoded, _encoded_length, gt_texts = batch
    handwritten_outputs = model('handwritten', images=htr_images.to(device))
    # Best class per step along dim 2; ctc_labeling.decode turns the code sequence into text.
    pred_codes = handwritten_outputs.argmax(2).cpu().numpy()
    return [
        {'task_id': 'handwritten', 'gt_output': gt_text, 'pred_output': ctc_labeling.decode(codes)}
        for codes, gt_text in zip(pred_codes, gt_texts)
    ]


def _eval_c2c_batch(batch, model, device, tokenizer):
    """Beam-search translate one C2C batch and score it with BLEU; return one record."""
    code_input_ids, _attention_masks, code_targets = batch
    input_ids = code_input_ids.to(device)
    _, hidden_states = model('c2c', input_ids=input_ids)
    bleu_score, _ = eval_bleu(
        model, hidden_states, input_ids=input_ids, beam_size=5, tokenizer=tokenizer, targets=code_targets
    )
    return [{'task_id': 'c2c', 'true_text': code_targets, 'bleu_score': bleu_score}]


def _eval_vqa_batch(batch, model, device, tokenizer):
    """Generate answers for one VQA batch; return one record per sample."""
    vqa_images, vqa_input_ids, _, targets = batch
    # The literal 10 is forwarded unchanged; presumably the max answer length — TODO confirm.
    vqa_outputs = inverse_vqa_evaluation(model, vqa_images.to(device), vqa_input_ids.to(device), 10)
    pred_ids = vqa_outputs.argmax(-1).cpu().numpy()
    return [
        # Keep only the text before the first '.' of the decoded answer.
        {'task_id': 'vqa', 'gt_output': target, 'pred_output': tokenizer.decode(ids).split('.')[0]}
        for target, ids in zip(targets, pred_ids)
    ]


def _eval_detection_batch(batch, model, device, tokenizer, threshold, true_json, pred_json):
    """Predict boxes for one detection image and record them into ``true_json`` / ``pred_json``.

    Only the first element of each batch field is read, so the loader must use batch_size=1.
    """
    detection_names, detection_images, detection_input_ids, _, boxes, size = batch
    requests = detection_input_ids[0]
    input_ids = [request.unsqueeze(0).to(device) for request in requests]
    detection_outputs = inverse_detection_evaluation(model, detection_images.to(device), input_ids, threshold)

    img_h, img_w = size[0]
    pred_boxes_per_request = []
    for pred in detection_outputs:
        if pred.numel() != 0:
            # Scale columns 0/2 by width and 1/3 by height; assumes predictions are
            # normalized coordinates — TODO confirm.
            pred[:, [0, 2]] *= img_w
            pred[:, [1, 3]] *= img_h
        pred_boxes_per_request.append(pred.type(torch.int32).cpu().tolist())

    image_name = detection_names[0]
    true_json[image_name] = {}
    pred_json[image_name] = {}
    for request, pred_boxes, real_boxes in zip(requests, pred_boxes_per_request, boxes[0]):
        # [9:-1] drops a fixed 9-character prefix and the last character of the decoded
        # request (presumably the prompt wrapper around the query text) — TODO confirm.
        request_text = tokenizer.decode(request.numpy())[9:-1]
        true_json[image_name][request_text] = real_boxes
        pred_json[image_name][request_text] = pred_boxes
    return [{'task_id': 'detection'}]


def _report_metrics(task, result, true_json_detection, pred_json_detection):
    """Print the metrics for ``task`` and return its headline score."""
    if task == 'handwritten':
        pred, gt = result['pred_output'], result['gt_output']
        accuracy = round(string_accuracy(pred, gt), 3)
        print('= Handwritten =')
        print('CER:', round(cer(pred, gt), 3))
        print('WER:', round(wer(pred, gt), 3))
        print('ACC:', accuracy)
        print('=== === === ===')
        return accuracy
    if task == 'c2c':
        mean_bleu = np.mean(result['bleu_score'])
        print('== C2C ==')
        print('meanBLEU:', mean_bleu)
        print('=== === === ===')
        return mean_bleu
    if task == 'vqa':
        accuracy = round(vqa_evaluate(result), 3)
        print('== VQA ==')
        print('ACC:', accuracy)
        print('=== === === ===')
        return accuracy
    # detection: evaluated once, printed and returned.
    accuracy = round(detection_evaluate(true_json_detection, pred_json_detection), 3)
    print('== Detection ==')
    print('ACC:', accuracy)
    print('=== === === ===')
    return accuracy


def run_evaluation(task, loader, model, threshold=None, tokenizer=None, device=None):
    """Run ``model`` over ``loader`` for one task, print its metrics and return the headline score.

    Args:
        task: one of ``'handwritten'``, ``'c2c'``, ``'vqa'``, ``'detection'``.
        loader: iterable of collated batches in the layout of the matching ``*_collate_fn``.
            Detection requires batch_size=1.
        model: fusion model, called as ``model(task, ...)``.
        threshold: forwarded to ``inverse_detection_evaluation``; used only for detection.
        tokenizer: tokenizer used to decode token ids (c2c/vqa/detection). If ``None``,
            the notebook-level ``gpt_tokenizer`` is used.
        device: device batches are moved to. Defaults to the device of the model's first
            parameter, so CPU-only runs work as well.

    Returns:
        The rounded string accuracy for handwritten, the mean BLEU for c2c, or the rounded
        accuracy for vqa/detection. Returns ``None`` if the loader yields no samples.

    Raises:
        ValueError: if ``task`` is not one of the supported tasks.
    """
    if task not in _EVAL_TASKS:
        raise ValueError(f'Unknown task {task!r}; expected one of {_EVAL_TASKS}')
    device = _resolve_device(model, device)
    if tokenizer is None and task != 'handwritten':
        tokenizer = gpt_tokenizer

    records = []
    true_json_detection = {}
    pred_json_detection = {}
    with torch.no_grad():
        for batch in tqdm(loader):
            if task == 'handwritten':
                records.extend(_eval_handwritten_batch(batch, model, device))
            elif task == 'c2c':
                records.extend(_eval_c2c_batch(batch, model, device, tokenizer))
            elif task == 'vqa':
                records.extend(_eval_vqa_batch(batch, model, device, tokenizer))
            else:
                records.extend(_eval_detection_batch(
                    batch, model, device, tokenizer, threshold, true_json_detection, pred_json_detection
                ))

    if not records:
        return None
    return _report_metrics(task, pd.DataFrame(records), true_json_detection, pred_json_detection)

In [None]:
# Load the fine-tuned weights into the model and switch it to inference mode.
CHECKPOINT_PATH = '/home/jovyan/vladimir/fusion_brain/experiments/fusion/checkpoints/weight-epoch=02-v1.ckpt'

model = model.to(device)
model.eval()
# map_location lets a checkpoint saved on GPU load when `device` is the CPU.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
# Parameter names in the checkpoint carry a 6-character prefix the bare model lacks
# (presumably 'model.' from the training wrapper — TODO confirm); strip it.
state_dict = OrderedDict((key[6:], value) for key, value in checkpoint['state_dict'].items())
model.load_state_dict(state_dict)

# Handwritten

In [None]:
# Deterministic, one-sample-at-a-time pass over the handwritten test subset.
htr_test_dataset = test_datasets['handwritten']
test_loader = torch.utils.data.DataLoader(
    htr_test_dataset,
    sampler=SequentialSampler(htr_test_dataset),
    collate_fn=htr_collate_fn,
    batch_size=1,
    num_workers=2,
    drop_last=False,
    pin_memory=False,
)
evaluation_result = run_evaluation('handwritten', test_loader, model, tokenizer=gpt_tokenizer)

# C2C

In [None]:
# Deterministic, one-sample-at-a-time pass over the C2C test subset.
c2c_test_dataset = test_datasets['c2c']
test_loader = torch.utils.data.DataLoader(
    c2c_test_dataset,
    sampler=SequentialSampler(c2c_test_dataset),
    collate_fn=c2c_collate_fn,
    batch_size=1,
    num_workers=2,
    drop_last=False,
    pin_memory=False,
)
evaluation_result = run_evaluation('c2c', test_loader, model, tokenizer=gpt_tokenizer)

# VQA

In [None]:
# Deterministic, one-sample-at-a-time pass over the VQA test subset.
vqa_test_dataset = test_datasets['vqa']
test_loader = torch.utils.data.DataLoader(
    vqa_test_dataset,
    sampler=SequentialSampler(vqa_test_dataset),
    collate_fn=vqa_collate_fn,
    batch_size=1,
    num_workers=2,
    drop_last=False,
    pin_memory=False,
)
evaluation_result = run_evaluation('vqa', test_loader, model, tokenizer=gpt_tokenizer)

# Detection

In [None]:
# Detection evaluation reads only the first item of each batch, so batch_size must stay 1.
detection_test_dataset = test_datasets['detection']
test_loader = torch.utils.data.DataLoader(
    detection_test_dataset,
    sampler=SequentialSampler(detection_test_dataset),
    collate_fn=detection_collate_fn,
    batch_size=1,
    num_workers=2,
    drop_last=False,
    pin_memory=False,
)
# threshold=0.0 is forwarded to inverse_detection_evaluation.
evaluation_result = run_evaluation('detection', test_loader, model, threshold=0.0, tokenizer=gpt_tokenizer)