In [None]:
from datasets import load_dataset

# Load the DocVQA splits from the Hugging Face Hub (downloaded on first use).
# NOTE(review): DocVQADataset.__init__ calls load_dataset again and a later cell
# rebinds `dataset`, so this object appears to be for interactive inspection only.
dataset = load_dataset("Trailblazer-Yoo/boostcamp-docvqa")

In [None]:
import numpy as np
import json
import torch
import editdistance
import transformers
import random
import gc
import time
import wandb
import os

from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
import pandas as pd
from PIL import Image
from transformers import LayoutLMv2FeatureExtractor
from transformers import AutoTokenizer


# Disable the HF tokenizers' internal parallelism (avoids fork warnings/deadlocks with DataLoader workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


from heapq import heappush
from nltk.tag import pos_tag
from typing import *
import re

def NLD(s1:str,s2:str) -> float:
    """Case-insensitive normalized Levenshtein distance.

    The edit distance is divided by the mean length of the two strings, so
    0.0 means identical (ignoring case) and the value is at most 2.0
    (e.g. "" vs "abc" -> 3 / 1.5).

    Args:
        s1: first string.
        s2: second string.

    Returns:
        The normalized distance; 0.0 for two empty strings.
    """
    if not s1 and not s2:
        # Two empty strings are identical; avoid dividing by a zero mean length.
        return 0.0
    return editdistance.eval(s1.lower(), s2.lower()) / ((len(s1) + len(s2)) / 2)

def check_answer(answer_list:List[str], words_list:List[str], boundary:float=0.2) -> float:
    """Score how well a window of OCR words matches a tokenized answer.

    Args:
        answer_list: answer tokens to look for.
        words_list: candidate OCR words, compared position by position with answer_list.
        boundary: per-word NLD threshold; any word pair at or above it rejects the window.

    Returns:
        The mean per-word NLD (0.0 is the best possible value), or the sentinel 100
        when the window is rejected. NLD never exceeds 2, so 100 is never a real score.
    """
    # A window whose length differs from the answer (e.g. cut off at the end of the
    # page) or an empty answer can never be a real match: zip() would otherwise
    # silently compare only a prefix, and an empty answer would divide by zero.
    if not answer_list or len(words_list) != len(answer_list):
        return 100

    similarity_score = 0
    for answer, word in zip(answer_list, words_list):
        ld_score = NLD(answer, word)
        # One word too far from its counterpart means this window is not the answer.
        if ld_score >= boundary:
            return 100
        similarity_score += ld_score

    return similarity_score / len(answer_list)
    
def calculate_euclidean_mean(question_points:List[Tuple[float, float]], boxes:Sequence[float]) -> float:
    """Mean Euclidean distance from the center of one box to a set of points.

    Args:
        question_points: (x, y) reference points; must be non-empty.
        boxes: a single bounding box (x1, y1, x2, y2), despite the plural name.

    Returns:
        The average distance from the box center to each point.
    """
    center_x = (boxes[2] + boxes[0]) / 2
    center_y = (boxes[3] + boxes[1]) / 2
    total = sum(
        ((center_x - point[0]) ** 2 + (center_y - point[1]) ** 2) ** (1 / 2)
        for point in question_points
    )
    return total / len(question_points)

 
# Special characters removed from question tokens before matching them to OCR words.
# A raw string avoids Python's invalid-escape warnings for "\?" and "\※"; the
# resulting regex is identical to the original (both escapes are literals in re).
_SPECIAL_CHARS_RE = re.compile(r'[-=+,#/\?^$.@*\※~ㆍ!…]')

def clean_text(raw_string:str) -> str:
    """Remove special characters (punctuation such as '-', '?', '.', '※', '…') from a string.

    Args:
        raw_string: the text to clean.

    Returns:
        raw_string with every character from the special-character class deleted.
    """
    return _SPECIAL_CHARS_RE.sub('', raw_string)
 
def find_noun_ngram(questions:str, ngram:int) -> Iterator[Tuple[str, ...]]:
    """Yield distinct n-grams of content words from a question.

    The question is split on whitespace, and every contiguous window of `ngram`
    tokens is POS-tagged with nltk's pos_tag. A window is kept only if all of
    its tokens carry an allowed tag; kept tokens are passed through clean_text.

    Args:
        questions: the question text.
        ngram: window size, in whitespace tokens.

    Yields:
        Unique tuples of `ngram` cleaned tokens, in set-iteration order
        (not necessarily the order they appear in the question).
    """
    if ngram == 1:  # unigrams: nouns, numbers, foreign words, gerunds, ...
        part_of_speech = {'NN', 'NNS','NNP', 'NNPS', 'POS','RP', 'CD', 'FW', 'VBG'}
    else:  # longer n-grams may also contain prepositions and comparative/superlative modifiers
        part_of_speech = {'NN', 'NNS','NNP', 'NNPS', 'IN', 'POS','RP', 'CD', 'FW', 'VBG', 'JJR', 'JJS', 'RBR', 'RBS'}

    tokens: List[str] = questions.split()
    result = set()
    for start in range(len(tokens) - (ngram - 1)):
        window = tokens[start:start + ngram]
        kept = [clean_text(word) for word, tag in pos_tag(window) if tag in part_of_speech]
        # Only windows made entirely of allowed tags become n-grams.
        if len(kept) == ngram:
            result.add(tuple(kept))

    yield from result

def find_points(questions:str, words_list:List[str], boxes:List[List[int]], ngrams:int=3) -> List[Tuple[float, float]]:
    """Locate the question's key n-grams in the OCR words and return their positions.

    Every 1..`ngrams`-gram produced by find_noun_ngram is slid over words_list.
    Each window accepted by check_answer (NLD threshold 0.2) contributes one point:
    - odd-length n-gram: the center of its middle word's box;
    - even-length n-gram: the midpoint between the centers of its two middle words' boxes.

    Args:
        questions: the question text.
        words_list: OCR words of the document.
        boxes: one (x1, y1, x2, y2) box per entry of words_list.
        ngrams: largest n-gram size to search for.

    Returns:
        A list of (x, y) points; empty when nothing from the question is found.
    """
    def _center(box):
        # boxes are (x1, y1, x2, y2)
        return ((box[2] + box[0]) / 2, (box[3] + box[1]) / 2)

    question_words = [gram for n in range(1, ngrams + 1) for gram in find_noun_ngram(questions, n)]

    result = []
    for question in question_words:
        question_list = list(question)
        width = len(question_list)
        half = width // 2
        for idx in range(len(words_list) - (width - 1)):
            if check_answer(question_list, words_list[idx:idx + width], boundary=0.2) == 100:
                continue
            if width % 2 == 0:
                # Midpoint of the two middle words' centers. The original code averaged
                # the boxes' top-left corners here, inconsistent with the odd branch.
                left = _center(boxes[idx + half - 1])
                right = _center(boxes[idx + half])
                result.append(((left[0] + right[0]) / 2, (left[1] + right[1]) / 2))
            else:
                result.append(_center(boxes[idx + half]))

    return result

def find_candidates(answer_list:List[str], words_list:List[str], questions, boxes):
    """Find the best-matching occurrence of a tokenized answer among the OCR words.

    Every full-length window of words_list is scored with check_answer. The window
    with the lowest mean NLD wins. If several windows tie for the lowest score, the
    one whose first word's box is closest on average to the question's key
    n-grams (find_points) is chosen. If no key n-gram is found, the earliest tied
    window wins.

    Args:
        answer_list: answer tokens.
        words_list: OCR words of the document.
        questions: the question text, used only to break ties.
        boxes: one (x1, y1, x2, y2) box per entry of words_list.

    Returns:
        (answer_list, start_word_idx, end_word_idx) with an inclusive end index,
        or (None, 0, 0) when no window matches.
    """
    if not answer_list:
        return None, 0, 0

    width = len(answer_list)
    # Each element: (normalized_levenshtein_distance, answer, start_idx, end_idx).
    # Only full-length windows are considered, so end_idx never passes the last word.
    nld_l = []
    for idx in range(len(words_list) - (width - 1)):
        nld = check_answer(answer_list, words_list[idx:idx + width])
        if nld != 100:
            nld_l.append((nld, answer_list, idx, idx + width - 1))

    if not nld_l:
        return None, 0, 0

    nld_l.sort(key=lambda x: x[0])  # stable: ties keep positional order
    best_nld, _, best_start, best_end = nld_l[0]
    if len(nld_l) == 1 or nld_l[1][0] != best_nld:
        return answer_list, best_start, best_end

    # Several windows share the best NLD: break the tie by Euclidean distance to
    # the question's key n-grams.
    question_points = find_points(questions, words_list, boxes, ngrams=3)
    if not question_points:
        return answer_list, best_start, best_end

    tied = [cand for cand in nld_l if cand[0] == best_nld]
    _, start, end = min(
        (calculate_euclidean_mean(question_points, boxes[cand[2]]), cand[2], cand[3])
        for cand in tied
    )
    return answer_list, start, end

In [None]:
import torch
from transformers import AutoTokenizer

class DocVQADataset(torch.utils.data.Dataset):
    """Extractive-QA view of the boostcamp DocVQA dataset for LayoutLMv2.

    Each item is one tokenized (question, OCR words, boxes, image) example, with
    the token span of one located answer as the training target. Examples whose
    answers cannot be mapped onto context tokens are replaced by a randomly
    chosen other example.
    """

    # Upper bound on random resamples in __getitem__, so a dataset in which no
    # example can be encoded fails loudly instead of looping forever.
    _MAX_RESAMPLE_ATTEMPTS = 1000

    def __init__(self, split, model_checkpoint="microsoft/layoutlmv2-large-uncased",
                 dataset_name="Trailblazer-Yoo/boostcamp-docvqa"):
        """
        Args:
            split: 'train' selects the train split; anything else selects 'val'.
            model_checkpoint: tokenizer checkpoint to load.
            dataset_name: Hugging Face dataset to load.
        """
        datasets = load_dataset(dataset_name)
        self.dataset = datasets['train'] if split == 'train' else datasets['val']

        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        except OSError:
            # One retry for transient hub/download failures, as the original code
            # did, without swallowing KeyboardInterrupt and unrelated errors.
            self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    @staticmethod
    def _match_answer(answer, words, question, boxes):
        """Locate an answer string in the OCR words.

        Tries the whitespace-split answer first. If that fails and the answer is
        longer than one character, retries with one character deleted at a time
        (from the last position backwards) to tolerate OCR/answer mismatches.

        Returns:
            find_candidates' (match, word_idx_start, word_idx_end) triple.
        """
        match, word_idx_start, word_idx_end = find_candidates(answer.split(), words, question, boxes)
        if not match and len(answer) > 1:
            for i in range(len(answer), 0, -1):
                answer_i_list = (answer[:i - 1] + answer[i:]).split()
                match, word_idx_start, word_idx_end = find_candidates(answer_i_list, words, question, boxes)
                if match:
                    break
        return match, word_idx_start, word_idx_end

    @staticmethod
    def _word_span_to_token_span(sequence_ids, word_ids, word_idx_start, word_idx_end):
        """Map an inclusive context word span onto an inclusive token span.

        Only context tokens (sequence id 1) are searched. The start is the first
        token of word_idx_start, and the end is the last token of word_idx_end.

        Returns:
            (token_start, token_end), or None when there are no context tokens or
            either word was truncated away. The original code appended an end
            position lying before the start in that case.
        """
        context_tokens = [i for i, seq_id in enumerate(sequence_ids) if seq_id == 1]
        if not context_tokens:
            return None
        span = range(context_tokens[0], context_tokens[-1] + 1)
        token_start = next((t for t in span if word_ids[t] == word_idx_start), None)
        token_end = next((t for t in reversed(span) if word_ids[t] == word_idx_end), None)
        if token_start is None or token_end is None:
            return None
        return token_start, token_end

    def encode_dataset(self, example, max_length=512):
        """Tokenize one raw example and attach the target span of one of its answers.

        Args:
            example: a raw dataset row with 'question', 'words', 'boxes', 'answers' and 'image'.
            max_length: tokenizer max length; sequences are padded/truncated to it.

        Returns:
            A dict with 'input_ids', 'attention_mask', 'token_type_ids', 'bbox',
            'image', 'start_position', 'end_position' tensors plus the raw 'answers';
            None when no answer could be mapped onto context tokens. When several
            answers can be mapped, one is chosen at random.
        """
        question = example['question']
        words = list(example['words'])  # handles numpy arrays and lists
        boxes = example['boxes']

        encoding = self.tokenizer([question], [words], [boxes], max_length=max_length,
                                  padding="max_length", truncation=True, return_tensors="pt")
        batch_index = 0
        sequence_ids = encoding.sequence_ids(batch_index)
        word_ids = encoding.word_ids(batch_index)

        start_positions = []
        end_positions = []
        answers = example['answers']
        for answer in answers:
            match, word_idx_start, word_idx_end = self._match_answer(answer, words, question, boxes)
            if not match:
                continue
            token_span = self._word_span_to_token_span(sequence_ids, word_ids, word_idx_start, word_idx_end)
            if token_span is None:
                continue
            start_positions.append(token_span[0])
            end_positions.append(token_span[1])

        if not start_positions:
            return None

        ans_i = random.randrange(len(start_positions))

        encoded = {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'token_type_ids': encoding['token_type_ids'],
            'bbox': encoding['bbox'],
            'answers': answers,
        }
        # [0] drops the image's leading dimension (changed from example['image'].copy()).
        # NOTE(review): assumes the stored image is nested one level deeper than the
        # (channels, height, width) array the model expects — confirm against the dataset.
        encoded['image'] = torch.LongTensor(example['image'].copy()[0])
        encoded['start_position'] = torch.LongTensor([start_positions[ans_i]])
        encoded['end_position'] = torch.LongTensor([end_positions[ans_i]])
        return encoded

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the encoded example at `index`, or a random encodable replacement."""
        encoded = self.encode_dataset(self.dataset[index])
        attempts = 0
        # Iterative resampling instead of recursion: avoids RecursionError when many
        # consecutive draws have no locatable answer.
        while encoded is None:
            attempts += 1
            if attempts > self._MAX_RESAMPLE_ATTEMPTS:
                raise RuntimeError(
                    f"no encodable example found after {self._MAX_RESAMPLE_ATTEMPTS} random resamples"
                )
            index = random.randrange(len(self))
            encoded = self.encode_dataset(self.dataset[index])
        return encoded

def collate(data):
    """Batch encoded DocVQA examples.

    Token-level tensors and the (1,)-shaped answer positions are concatenated
    along dim 0, images are stacked into a new batch dimension, and the raw
    answers are kept as a plain Python list.
    """
    def concat(field):
        return torch.cat([item[field] for item in data], dim=0)

    return {
        'input_ids': concat('input_ids'),
        'attention_mask': concat('attention_mask'),
        'token_type_ids': concat('token_type_ids'),
        'bbox': concat('bbox'),
        'image': torch.stack([item['image'] for item in data], dim=0),
        'start_positions': concat('start_position'),
        'end_positions': concat('end_position'),
        'answers': [item['answers'] for item in data],
    }

In [None]:
# Build the training split and pull one shuffled batch of two encoded examples as a smoke test.
# NOTE(review): this rebinds `dataset` from the first cell, and no random seed is set,
# so `sample` differs between runs.
dataset = DocVQADataset('train')
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2,collate_fn=collate, shuffle=True)
sample = next(iter(dataloader))
print(sample)

In [None]:
from transformers import AutoModelForQuestionAnswering

# Pretrained LayoutLMv2-large with a QA head; only its `.layoutlmv2` backbone is used below.
org_model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-large-uncased")

In [None]:
# Inspect the batch fields produced by `collate`.
sample.keys()

In [None]:
# Run only the LayoutLMv2 backbone on the sample batch to get encoder hidden states for the decoder test.
# NOTE(review): this runs on CPU with autograd enabled; wrap it in torch.no_grad() if only shapes are needed.
enc_result = org_model.layoutlmv2(input_ids = sample['input_ids'],
                    attention_mask = sample['attention_mask'],
                    token_type_ids = sample['token_type_ids'],
                    bbox = sample['bbox'],
                    image = sample['image']    
                )

In [None]:
# Encoder output shape; presumably (batch, text + visual tokens, hidden_size) — confirm.
enc_result['last_hidden_state'].shape

# Decoder

In [None]:
import torch.nn as nn
from transformers import MBartConfig, MBartForCausalLM, AutoTokenizer
import torch.nn.functional as F
from transformers.file_utils import ModelOutput

class BARTDecoder(nn.Module):
    """
    Donut Decoder based on Multilingual BART
    Set the initial weights and configuration with a pretrained multilingual BART model,
    and modify the detailed configurations as a Donut decoder

    Args:
        decoder_layer:
            Number of layers of BARTDecoder
        max_position_embeddings:
            The maximum sequence length to be trained
        name_or_path:
            Name or local path of the tokenizer to use. When not given, the
            `microsoft/layoutlmv2-base-uncased` tokenizer is used and the decoder
            weights are initialized from `hyunwoongko/asian-bart-ecjk`; when given,
            the decoder keeps its random initialization (using `transformers`)
    """

    def __init__(
        self, decoder_layer: int, max_position_embeddings: int, name_or_path: Union[str, bytes, os.PathLike] = None
    ):
        super().__init__()
        self.decoder_layer = decoder_layer
        self.max_position_embeddings = max_position_embeddings

        self.tokenizer = AutoTokenizer.from_pretrained(
             "microsoft/layoutlmv2-base-uncased" if not name_or_path else name_or_path
        )

        self.model = MBartForCausalLM(
            config=MBartConfig(
                is_decoder=True,
                is_encoder_decoder=False,
                add_cross_attention=True,
                decoder_layers=self.decoder_layer,
                max_position_embeddings=self.max_position_embeddings,
                vocab_size=len(self.tokenizer),
                scale_embedding=True,
                add_final_layer_norm=True,
            )
        )
        self.model.forward = self.forward  #  to get cross attentions and utilize `generate` function

        self.model.config.is_encoder_decoder = True  # to get cross-attention
        self.add_special_tokens(["<sep/>"])  # <sep/> is used for representing a list in a JSON
        self.model.model.decoder.embed_tokens.padding_idx = self.tokenizer.pad_token_id
        self.model.prepare_inputs_for_generation = self.prepare_inputs_for_inference

        # weight init with asian-bart
        if not name_or_path:
            bart_state_dict = MBartForCausalLM.from_pretrained("hyunwoongko/asian-bart-ecjk").state_dict()
            new_bart_state_dict = self.model.state_dict()
            for x in new_bart_state_dict:
                if x.endswith("embed_positions.weight") and self.max_position_embeddings != 1024:
                    new_bart_state_dict[x] = torch.nn.Parameter(
                        self.resize_bart_abs_pos_emb(
                            bart_state_dict[x],
                            self.max_position_embeddings
                            + 2,  # https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L118-L119
                        )
                    )
                elif x.endswith("embed_tokens.weight") or x.endswith("lm_head.weight"):
                    # NOTE(review): keeps the first len(tokenizer) rows of asian-bart's
                    # embedding matrix. That model presumably uses a different vocabulary
                    # than the LayoutLMv2 tokenizer, so rows may not match the same
                    # tokens — confirm this is intended.
                    new_bart_state_dict[x] = bart_state_dict[x][: len(self.tokenizer), :]
                else:
                    new_bart_state_dict[x] = bart_state_dict[x]
            self.model.load_state_dict(new_bart_state_dict)

    def add_special_tokens(self, list_of_tokens: List[str]):
        """
        Add special tokens to tokenizer and resize the token embeddings
        """
        newly_added_num = self.tokenizer.add_special_tokens({"additional_special_tokens": sorted(set(list_of_tokens))})
        if newly_added_num > 0:
            self.model.resize_token_embeddings(len(self.tokenizer))

    def prepare_inputs_for_inference(self, input_ids: torch.Tensor, encoder_outputs: torch.Tensor, past=None, use_cache: bool = None, attention_mask: torch.Tensor = None, past_key_values=None):
        """
        Build the decoder inputs for one `generate` step.

        Args:
            input_ids: (batch_size, sequence_length)
            encoder_outputs: object exposing `.last_hidden_state`
            past: cached key/values under the legacy keyword (transformers 4.11)
            use_cache: forwarded unchanged
            attention_mask: ignored; the mask is rebuilt from non-pad tokens
            past_key_values: cached key/values under the keyword newer `generate` versions pass
        Returns:
            input_ids: (batch_size, sequence_length), or only the last token when a cache is given
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, embedding_dim)
        """
        # Accept both keyword spellings of the cache so `generate` works across versions.
        if past is None:
            past = past_key_values
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).long()
        if past is not None:
            # With a cache, only the newest token has to be fed through the decoder.
            input_ids = input_ids[:, -1:]
        output = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "past_key_values": past,
            "use_cache": use_cache,
            "encoder_hidden_states": encoder_outputs.last_hidden_state,
        }
        return output

    def forward(
        self,
        input_ids,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        past_key_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        use_cache: bool = None,
        output_attentions: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[torch.Tensor] = None,
        return_dict: bool = None,
    ):
        """
        A forward function to get cross attentions and utilize `generate` function

        Source:
        https://github.com/huggingface/transformers/blob/v4.11.3/src/transformers/models/mbart/modeling_mbart.py#L1669-L1810

        Args:
            input_ids: (batch_size, sequence_length)
            attention_mask: (batch_size, sequence_length)
            encoder_hidden_states: (batch_size, sequence_length, hidden_size)
            labels: target token ids; positions equal to -100 are ignored by the loss

        Returns:
            loss: (1, ), only when labels are given
            logits: (batch_size, sequence_length, vocab_size)
            hidden_states: (batch_size, sequence_length, hidden_size)
            decoder_attentions: (batch_size, num_heads, sequence_length, sequence_length)
            cross_attentions: (batch_size, num_heads, sequence_length, sequence_length)
        """
        output_attentions = output_attentions if output_attentions is not None else self.model.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.model.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.model.config.use_return_dict
        outputs = self.model.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.model.lm_head(outputs[0])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(logits.view(-1, self.model.config.vocab_size), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return ModelOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            decoder_attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )

    @staticmethod
    def resize_bart_abs_pos_emb(weight: torch.Tensor, max_length: int) -> torch.Tensor:
        """
        Resize position embeddings
        Truncate if sequence length of Bart backbone is greater than given max_length,
        else interpolate to max_length
        """
        if weight.shape[0] > max_length:
            weight = weight[:max_length, ...]
        else:
            weight = (
                F.interpolate(
                    weight.permute(1, 0).unsqueeze(0),
                    size=max_length,
                    mode="linear",
                    align_corners=False,
                )
                .squeeze(0)
                .permute(1, 0)
            )
        return weight

In [None]:
# Optional: uncomment the lines below to print the reference Donut config for comparison; nothing else depends on it.
# from transformers import VisionEncoderDecoderConfig

# config = VisionEncoderDecoderConfig.from_pretrained("naver-clova-ix/donut-base")
# print(config)

In [None]:
# 4-layer MBart decoder. Since 1536 != 1024, the asian-bart position embeddings are
# resized (truncated or linearly interpolated) to 1536 + 2 rows by resize_bart_abs_pos_emb.
decoder = BARTDecoder(decoder_layer = 4,
                    max_position_embeddings = 1536
)

In [None]:
# Shape smoke test: feed the encoder's input ids as decoder input ids, with the
# LayoutLMv2 hidden states as cross-attention memory.
# NOTE(review): no labels are passed, so no loss is computed, and these ids are not a real target sequence.
decoder_result = decoder(input_ids = sample['input_ids'], 
                    encoder_hidden_states = enc_result['last_hidden_state']
)

In [None]:
# Display the ModelOutput returned by BARTDecoder.forward.
decoder_result