In [1]:
# #!git clone https://github.com/sberbank-ai/fusion_brain_aij2021.git
# !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
# !pip install tpu_star==0.0.1rc10
# !pip install albumentations==0.5.2
# !pip install einops==0.3.2
# !pip install pytorch_lightning
# !pip install comet_ml
# !pip install transformers==4.10.0 
# !pip install colorednoise==1.1.1
# !pip install catalyst==21.8 
# !pip install opencv-python==4.5.3
# !pip install gdown==4.0.2
# !pip install pymorphy2

In [None]:
# Доступные ресурсы
import multiprocessing
import torch
from psutil import virtual_memory

ram_gb = round(virtual_memory().total / 1024**3, 1)

print('CPU:', multiprocessing.cpu_count())
print('RAM GB:', ram_gb)

print("PyTorch version:", torch.__version__)
print("CUDA version:", torch.version.cuda)
print("cuDNN version:", torch.backends.cudnn.version())
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("device:", device.type)

!nvidia-smi

In [None]:
import json
import os
import random
from collections import OrderedDict

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score
import albumentations as A
from transformers import GPT2Model, GPT2Tokenizer

from fb_utils.utils import simple_detect_lang
from fb_utils.metrics import cer, wer, string_accuracy, acc, vqa_evaluate, detection_evaluate
from fb_utils.c2c_eval import Beam, eval_bleu
from fb_utils.detection_eval import cross_detection_evaluation, inverse_detection_evaluation
from fb_utils.vqa_eval import cross_vqa_evaluation, inverse_vqa_evaluation
from model.utils.utils import CTCLabeling
from model.dataset.dataset import DatasetRetriever, fb_collate_fn
from model.model import CrossAttentionGPT2FusionBrain, InverseAttentionGPT2FusionBrain

# Dataset

In [None]:
# Data preparation: load the ground-truth files of every task into one DataFrame per task
# (they are merged into a single DataFrame below).


def _load_json(path):
    """Parse the JSON file at `path`, closing the file handle as soon as it has been read."""
    with open(path, 'rb') as fp:
        return json.load(fp)


# #
# Handwritten
# #
# '%' is not part of the CTC alphabet (CHARS) used by the handwritten head, which is
# presumably why such transcriptions are skipped.
json_marking = _load_json('true_fb/true_HTR.json')
df_handwritten = pd.DataFrame(
    [
        {'path': image_name, 'text': text, 'lang': simple_detect_lang(text)}
        for image_name, text in json_marking.items()
        if '%' not in text
    ],
    columns=['path', 'text', 'lang'],
)
df_handwritten['stage'] = 'test'

# #
# Detection
# #
json_true_zsod_test = _load_json('true_fb/true_zsOD.json')
df_detection = pd.DataFrame(
    [
        {
            'task_id': 'detection',
            'path': image_name,
            # every text request for the image, joined with ';'
            'req': ';'.join(requests),
            # ground-truth boxes, in the same order as the requests
            'boxes': list(requests.values()),
        }
        for image_name, requests in json_true_zsod_test.items()
    ],
    columns=['task_id', 'path', 'req', 'boxes'],
)
df_detection['stage'] = 'test'

# #
# VQA
# #
json_questions = _load_json('private_fb/VQA/questions.json')
json_answers = _load_json('true_fb/true_VQA.json')
df_vqa = pd.DataFrame(
    [
        {
            'path': question['file_name'],
            'question': question['question'],
            # only the first reference answer is kept
            'answer': json_answers[key]['answer'][0],
            'lang': simple_detect_lang(json_answers[key]['answer'][0]),
        }
        for key, question in json_questions.items()
    ],
    columns=['path', 'question', 'answer', 'lang'],
)
df_vqa['stage'] = 'test'

# #
# C2C (Java -> Python code translation)
# #
java_json = _load_json('private_fb/C2C/requests.json')
python_json = _load_json('true_fb/true_C2C.json')
df_c2c = pd.DataFrame(
    [{'java': java_json[key], 'python': python_json[key]} for key in java_json],
    columns=['java', 'python'],
)
df_c2c['stage'] = 'test'


# #
# Merge every task into one common set
# #
# Seed for the shuffle below, so that `df` (and df.head) is reproducible between runs.
SHUFFLE_SEED = 42

dataset = [
    {
        'task_id': 'handwritten',
        'modality': 'image',
        'input_image': image_name,
        'output_text': text,
        'stage': stage,
    }
    for image_name, text, stage in zip(df_handwritten['path'], df_handwritten['text'], df_handwritten['stage'])
]
dataset += [
    {
        'task_id': 'trans',
        'modality': 'code',
        'input_text': java,
        'output_text': python,
        'stage': stage,
    }
    for java, python, stage in zip(df_c2c['java'], df_c2c['python'], df_c2c['stage'])
]
dataset += [
    {
        'task_id': 'vqa',
        'modality': 'image+text',
        'input_image': image_name,
        'input_text': text_input,
        'output_text': text_output,
        'stage': stage,
    }
    for image_name, text_input, text_output, stage in zip(
        df_vqa['path'], df_vqa['question'], df_vqa['answer'], df_vqa['stage']
    )
]
dataset += [
    {
        'task_id': 'detection',
        'modality': 'image+text',
        'input_image': image_name,
        'input_text': text_input,
        'output_boxes': boxes,
        'stage': stage,
    }
    for image_name, text_input, boxes, stage in zip(
        df_detection['path'], df_detection['req'], df_detection['boxes'], df_detection['stage']
    )
]

# A private, seeded RNG gives a reproducible order without touching the global `random` state.
random.Random(SHUFFLE_SEED).shuffle(dataset)
df = pd.DataFrame(dataset)
df.head(10)

In [None]:
# Per-task albumentations pipelines, keyed by task id.
# NOTE(review): the 'handwritten' pipeline is purely stochastic augmentation (CLAHE, rotation,
# JPEG artefacts). Whether DatasetRetriever applies it at stage='test' is not visible here;
# if it does, evaluation results vary between runs. TODO confirm.


def _resize_and_normalize():
    """Deterministic 224x224 resize followed by ImageNet mean/std normalisation."""
    return A.Compose(
        [
            A.Resize(224, 224, always_apply=True),
            A.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ],
        p=1.0,
    )


task_augs = {
    'handwritten': A.Compose(
        [
            A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.25, always_apply=False),
            # interpolation=1 / border_mode=0 are OpenCV's INTER_LINEAR / BORDER_CONSTANT
            A.Rotate(limit=3, interpolation=1, border_mode=0, p=0.5),
            A.JpegCompression(quality_lower=75, p=0.5),
        ],
        p=1.0,
    ),
    # VQA and detection share the same preprocessing, each with its own Compose instance.
    'vqa': _resize_and_normalize(),
    'detection': _resize_and_normalize(),
}

In [None]:
# Evaluation split: every row whose stage is 'test'.
df_eval = df.loc[df['stage'].eq('test')]

In [None]:
# Prepare the pretrained model and tokenizer, plus the CTC labeling for the
# handwritten text recognition task.
# Character alphabet used by CTCLabeling to encode/decode handwritten transcriptions.
CHARS = ' !"#&\'()*+,-./0123456789:;<=>?ABCDEFGHIJKLMNOPQRSTUVWXYZ' + \
        '[]_abcdefghijklmnopqrstuvwxyz|}ЁАБВГДЕЖЗИКЛМНОПРСТУФХЦЧШЩЫЭЮЯабвгдежзийклмнопрстуфхцчшщъыьэюяё№'
ctc_labeling = CTCLabeling(CHARS)

# Directory holding the pretrained GPT weights and tokenizer files. Set the
# FB_GPT_MODEL_PATH environment variable instead of editing this machine-specific default.
model_name = os.environ.get('FB_GPT_MODEL_PATH', '/home/jovyan/vladimir/fb_baseline/gpt3_medium_py')
gpt_tokenizer = GPT2Tokenizer.from_pretrained(model_name, bos_token='<s>',
                                              eos_token='</s>', pad_token='<pad>', unk_token='<|UNKNOWN|>',
                                              sep_token='<|SEP|>')

gpt_model = GPT2Model.from_pretrained(model_name)
# The special tokens above may extend the vocabulary, so grow the embedding matrix to match.
# The trailing ';' stops the notebook from echoing the returned embedding module.
gpt_model.resize_token_embeddings(len(gpt_tokenizer));

In [None]:
# Dataset over the whole test split (all tasks); maps DatasetRetriever arguments to df_eval columns.
_eval_columns = {
    'task_ids': 'task_id',
    'input_images': 'input_image',
    'input_texts': 'input_text',
    'output_texts': 'output_text',
    'output_boxes': 'output_boxes',
}
eval_dataset = DatasetRetriever(
    **{arg_name: df_eval[column].values for arg_name, column in _eval_columns.items()},
    stage='test',
    ctc_labeling=ctc_labeling,
    tokenizer=gpt_tokenizer,
    max_request_tokens_length=21,
    vqa_max_tokens_length=21,
    task_augs=task_augs,
)

In [None]:
# What a dataset sample looks like for each task.
def demo_sample(sample):
    """Print a human-readable view of one dataset sample (and render the image for handwriting).

    Args:
        sample: dict returned by ``DatasetRetriever.__getitem__``; which keys are read
            depends on ``sample['task_id']``.

    Returns:
        The result of ``io.imshow`` for 'handwritten' samples, ``None`` for 'trans',
        'detection' and 'vqa', and ``sample`` itself for any other task id.
    """
    task = sample['task_id']

    if task == 'handwritten':
        print('[gt_text]:', sample['gt_text'])
        # permute(1, 2, 0) moves the first axis last (presumably CHW -> HWC for display).
        # NOTE(review): `io` is not imported in the visible cells; presumably
        # `from skimage import io` is intended -- TODO add it, otherwise this raises NameError.
        return io.imshow(sample['image'].permute(1, 2, 0).numpy())

    if task in ('trans', 'vqa'):
        decoded_input = gpt_tokenizer.decode(sample['input_ids'].numpy(), skip_special_tokens=True)
        if task == 'trans':
            print('[source_text]:', decoded_input)
            print('[target_text]:', sample['target'])
        else:
            print('[question and answer]:', decoded_input)
        return None

    if task == 'detection':
        print('[boxes]:', sample['boxes'])
        return None

    return sample

In [None]:
# Show the first 'detection' sample. An explicit lookup replaces `(mask).argmax()`,
# which silently returns index 0 (some other task's sample) when nothing matches.
_detection_rows = np.flatnonzero(eval_dataset.task_ids == 'detection')
if _detection_rows.size == 0:
    raise ValueError("eval_dataset has no 'detection' samples")
demo_sample(eval_dataset[_detection_rows[0]])

In [None]:
# Show the first 'trans' (code translation) sample. An explicit lookup replaces
# `(mask).argmax()`, which silently returns index 0 when nothing matches.
_trans_rows = np.flatnonzero(eval_dataset.task_ids == 'trans')
if _trans_rows.size == 0:
    raise ValueError("eval_dataset has no 'trans' samples")
demo_sample(eval_dataset[_trans_rows[0]])

In [None]:
# Show the first 'vqa' sample. An explicit lookup replaces `(mask).argmax()`,
# which silently returns index 0 when nothing matches.
_vqa_rows = np.flatnonzero(eval_dataset.task_ids == 'vqa')
if _vqa_rows.size == 0:
    raise ValueError("eval_dataset has no 'vqa' samples")
demo_sample(eval_dataset[_vqa_rows[0]])

In [None]:
# Show the first 'handwritten' sample. An explicit lookup replaces `(mask).argmax()`,
# which silently returns index 0 when nothing matches.
_handwritten_rows = np.flatnonzero(eval_dataset.task_ids == 'handwritten')
if _handwritten_rows.size == 0:
    raise ValueError("eval_dataset has no 'handwritten' samples")
demo_sample(eval_dataset[_handwritten_rows[0]])

# Model

In [None]:
# Configuration of the handwritten-text-recognition head.
_HTR_PATCH_W = 8
_HTR_PATCH_H = 128
_HTR_CHANNELS = 3  # presumably RGB

handwritten_config = dict(
    patch_w=_HTR_PATCH_W,
    patch_h=_HTR_PATCH_H,
    # Presumably one flattened patch per input vector: 8 * 128 * 3 = 3072.
    in_layer_sizes=[_HTR_PATCH_W * _HTR_PATCH_H * _HTR_CHANNELS],
    out_layer_sizes=[64],
    orth_gain=1.41,  # ~= sqrt(2)
    dropout=0.1,
    lstm_num_layers=3,
    output_dim=len(ctc_labeling),  # 152
)

# Shared attention stack configuration for the fusion model.
attention_config = dict(
    num_attention_layers=2,
    num_heads=8,
    pf_dim=2048,
)

# VQA head configuration: token count equals the tokenizer size, including the added special tokens.
vqa_config = dict(tokens_num=len(gpt_tokenizer))

# Detection head configuration.
detection_config = dict(
    num_mlp_layers=1,
    num_queries=8,
)

In [None]:
# Fusion model: the GPT backbone plus the per-task configurations above.
# (InverseAttentionGPT2FusionBrain, imported alongside, is the alternative architecture.)
fusion_head_configs = {
    'attention_config': attention_config,
    'handwritten_config': handwritten_config,
    'vqa_config': vqa_config,
    'detection_config': detection_config,
}
model = CrossAttentionGPT2FusionBrain(gpt_model, **fusion_head_configs)

# Experiments

In [None]:
def _eval_handwritten_batch(model, htr_batch, device):
    """Decode one handwritten sub-batch (per-timestep argmax + ctc_labeling) into result rows."""
    htr_images, _encoded, _encoded_length, gt_texts = htr_batch
    if len(htr_images) == 0:
        return []
    logits = model('handwritten', images=htr_images.to(device, dtype=torch.float32))
    pred_codes = logits.argmax(2).data.cpu().numpy()
    return [
        {'task_id': 'handwritten', 'gt_output': gt_text, 'pred_output': ctc_labeling.decode(codes)}
        for codes, gt_text in zip(pred_codes, gt_texts)
    ]


def _eval_trans_batch(model, code_batch, tokenizer, device):
    """Translate one code sub-batch (beam size 5) and score it with eval_bleu; one row per batch."""
    code_input_ids, code_input_labels, code_targets = code_batch
    if len(code_input_ids) == 0:
        return []
    code_input_ids = code_input_ids.to(device, dtype=torch.long)
    code_input_labels = code_input_labels.to(device, dtype=torch.long)
    hidden_states = model('trans', input_ids=code_input_ids, input_labels=code_input_labels, eval_bleu=True)
    bleu_score, _ = eval_bleu(model, hidden_states, input_ids=code_input_ids, beam_size=5,
                              tokenizer=tokenizer, targets=code_targets)
    return [{'task_id': 'trans', 'true_text': code_targets, 'bleu_score': bleu_score}]


def _eval_vqa_batch(model, vqa_batch, tokenizer, device):
    """Run cross_vqa_evaluation on one VQA sub-batch and decode the argmax token ids into result rows."""
    vqa_images, vqa_input_ids, labels, targets = vqa_batch
    if len(labels) == 0:
        return []
    if tokenizer is None:
        # Backward-compatible fallback: VQA answers were always decoded with the notebook-level tokenizer.
        tokenizer = gpt_tokenizer
    images = vqa_images.to(device, dtype=torch.float32)
    input_ids = vqa_input_ids.to(device, dtype=torch.long)
    labels = labels.to(device, dtype=torch.float)
    # Label value 0 is masked out (presumably padding -- TODO confirm); stays on labels' device.
    attention_mask = (labels != 0).to(torch.uint8)
    # NOTE(review): the trailing 10 is interpreted by cross_vqa_evaluation (generation length?).
    vqa_outputs = cross_vqa_evaluation(model, images, input_ids, attention_mask, 10)
    pred_ids = vqa_outputs.argmax(-1).cpu().numpy()
    return [
        {'task_id': 'vqa', 'gt_output': target, 'pred_output': tokenizer.decode(pred_labels)}
        for target, pred_labels in zip(targets, pred_ids)
    ]


def _eval_detection_batch(model, detection_batch, device, true_json, pred_json):
    """Predict boxes for one detection sub-batch, filling `true_json`/`pred_json` in place.

    NOTE(review): only size[0] and names[0] are read, so this is correct only for
    detection batches of size 1.
    """
    names, images, input_ids, attention_masks, boxes, size = detection_batch
    if len(boxes) == 0:
        return []
    images = images.to(device, dtype=torch.float32)
    input_ids = [ids.to(device, dtype=torch.long) for ids in input_ids]
    attention_masks = [mask.to(device, dtype=torch.long) for mask in attention_masks]
    # 0.2 is passed through to cross_detection_evaluation (presumably a score threshold -- TODO confirm).
    outputs = cross_detection_evaluation(model, images, input_ids, attention_masks, 0.2)

    # Scale predicted coordinates to pixels in place: columns 0/2 by width, 1/3 by height.
    img_h, img_w = size[0]
    for pred in outputs:
        if pred.numel() != 0:
            for col, scale in ((0, img_w), (2, img_w), (1, img_h), (3, img_h)):
                pred[:, col] = pred[:, col] * scale

    image_name = names[0]
    for boxes_for_img in boxes:
        true_json[image_name] = boxes_for_img
        # Pair each request text with the prediction at the same position.
        pred_json[image_name] = {
            input_text: output.type(torch.int32).cpu().tolist()
            for input_text, output in zip(boxes_for_img.keys(), outputs)
        }
    # Placeholder row only; detection metrics are computed from the JSON dicts.
    return [{'task_id': 'detection'}]


def _print_eval_summary(result, true_json_detection, pred_json_detection):
    """Print metrics for every task that has rows in `result` (detection: non-empty true JSON)."""
    handwritten_result = result[result['task_id'] == 'handwritten']
    if handwritten_result.shape[0]:
        print('= Handwritten =')
        print('CER:', round(cer(handwritten_result['pred_output'], handwritten_result['gt_output']), 3))
        print('WER:', round(wer(handwritten_result['pred_output'], handwritten_result['gt_output']), 3))
        print('ACC:', round(string_accuracy(handwritten_result['pred_output'], handwritten_result['gt_output']), 3))
        print('=== === === ===')

    trans_result = result[result['task_id'] == 'trans']
    if trans_result.shape[0]:
        print('== C2C ==')
        print('meanBLEU:', np.mean(trans_result['bleu_score']))
        print('=== === === ===')

    vqa_result = result[result['task_id'] == 'vqa']
    if vqa_result.shape[0]:
        print('== VQA ==')
        print('ACC:', round(vqa_evaluate(vqa_result), 3))
        print('=== === === ===')

    if len(true_json_detection):
        print('== Detection ==')
        print('ACC:', round(detection_evaluate(true_json_detection, pred_json_detection), 3))
        print('=== === === ===')


def run_evaluation(loader, model, tokenizer=None, device=None):
    """Run `model` over every batch of `loader`, print per-task metrics and return the raw rows.

    Each batch is unpacked as the 4-tuple of sub-batches produced by ``fb_collate_fn``:
    (handwritten, code translation, VQA, detection). A sub-batch whose key field is
    empty is skipped.

    Args:
        loader: iterable of collated batches (e.g. a DataLoader with ``fb_collate_fn``).
        model: fusion model, called as ``model(task_id, ...)``; moved to `device` and set to eval mode.
        tokenizer: passed to ``eval_bleu`` and used to decode VQA answers. When None, VQA
            answers fall back to the notebook-level ``gpt_tokenizer``.
        device: torch device to run on. None selects cuda:0 when CUDA is available,
            otherwise the CPU (the old hard-coded cuda:0 default failed on CPU-only hosts).

    Returns:
        pandas.DataFrame with one row per handwritten/VQA sample, one row per translation
        batch and one placeholder row per detection batch. It always has a 'task_id' column.
    """
    if device is None:
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    rows = []
    true_json_detection = {}
    pred_json_detection = {}
    model.to(device)
    model.eval()
    with torch.no_grad():
        for htr_batch, code_batch, vqa_batch, detection_batch in tqdm(loader):
            rows.extend(_eval_handwritten_batch(model, htr_batch, device))
            rows.extend(_eval_trans_batch(model, code_batch, tokenizer, device))
            rows.extend(_eval_vqa_batch(model, vqa_batch, tokenizer, device))
            rows.extend(_eval_detection_batch(model, detection_batch, device,
                                              true_json_detection, pred_json_detection))

    # An empty loader yields no rows; keep a 'task_id' column so the per-task filters still work.
    result = pd.DataFrame(rows) if rows else pd.DataFrame(columns=['task_id'])
    _print_eval_summary(result, true_json_detection, pred_json_detection)
    return result

In [None]:
# Evaluate one task of the test split with a trained checkpoint.
# NOTE(review): the `vqa_*` names are kept for any later cells, but they hold EVAL_TASK rows
# ('detection' in this experiment), not VQA rows.
EVAL_TASK = 'detection'
CHECKPOINT_PATH = 'LightningExperimentsNew/main_concat_small/checkpoints/weight-epoch=118.ckpt'

vqa_valid = df[(df['task_id'] == EVAL_TASK) & (df['stage'] == 'test')]

vqa_eval_dataset = DatasetRetriever(
    task_ids=vqa_valid['task_id'].values,
    input_images=vqa_valid['input_image'].values,
    input_texts=vqa_valid['input_text'].values,
    output_texts=vqa_valid['output_text'].values,
    output_boxes=vqa_valid['output_boxes'].values,
    stage='test',
    ctc_labeling=ctc_labeling,
    tokenizer=gpt_tokenizer,
    max_request_tokens_length=21,
    vqa_max_tokens_length=21,
    task_augs=task_augs,
)

model = model.to(device)
# map_location lets a checkpoint saved on GPU load on the current `device` (e.g. a CPU-only host).
# torch.load unpickles the file -- only load checkpoints from a trusted source.
state_dict = torch.load(CHECKPOINT_PATH, map_location=device)['state_dict']
# Drop the first 6 characters of every key -- presumably the 'model.' prefix added by the
# training wrapper (TODO confirm against the Lightning module).
state_dict = OrderedDict((key[6:], value) for key, value in state_dict.items())
model.load_state_dict(state_dict)

valid_loader = torch.utils.data.DataLoader(
    vqa_eval_dataset,
    # run_evaluation's detection branch reads only the first image's name/size of a batch.
    batch_size=1,
    # SequentialSampler was used here without an import; reach it through torch.utils.data.
    sampler=torch.utils.data.SequentialSampler(vqa_eval_dataset),
    pin_memory=False,
    drop_last=False,
    num_workers=2,
    collate_fn=fb_collate_fn,
)

evaluation_result = run_evaluation(valid_loader, model, tokenizer=gpt_tokenizer, device=device)