In [1]:
import numpy as np
import json
import torch
import editdistance
import transformers
import random
import gc
import time
import wandb
import os

from datasets import load_dataset, Dataset
from tqdm.auto import tqdm
import pandas as pd
from PIL import Image
from transformers import LayoutLMv2FeatureExtractor
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Disable Hugging Face tokenizers' internal parallelism (commonly set to avoid
# fork-related warnings/deadlocks when DataLoader forks workers).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Seed every RNG the notebook uses (dataset retries use `random`, the training
# DataLoader shuffles with torch's RNG) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def normalize_bbox(bboxes, width, height):
    """Scale an 8-value quadrilateral box to a 0-1000 ``[x0, y0, x1, y1]`` box.

    Only the first corner (indices 0, 1) and the third corner (indices 4, 5) of
    ``bboxes`` are used, presumably the top-left and bottom-right points.
    Each coordinate is divided by the page extent on its axis (``width`` for x,
    ``height`` for y), multiplied by 1000 and truncated with ``int``.
    """
    corners = (bboxes[0], bboxes[1], bboxes[4], bboxes[5])
    extents = (width, height, width, height)
    return [int(1000 * (coord / extent)) for coord, extent in zip(corners, extents)]

In [4]:
def fuzzy(s1, s2, threshold=0.2):
    """Return True when ``s1`` and ``s2`` are approximately equal.

    The edit distance between the two strings is divided by their mean length.
    The strings match when that ratio is strictly below ``threshold``.

    Args:
        s1, s2: strings to compare.
        threshold: maximum (exclusive) normalised edit distance; defaults to 0.2.

    Returns:
        bool. Two empty strings are considered equal.
    """
    mean_len = (len(s1) + len(s2)) / 2
    if mean_len == 0:
        # Both strings empty: identical, and the ratio below would divide by zero.
        return True
    return editdistance.eval(s1, s2) / mean_len < threshold

In [7]:
class DocVQADataset(torch.utils.data.Dataset):
    """Extractive-QA dataset over the ``Trailblazer-Yoo/boostcamp-docvqa`` hub dataset.

    Each example (``question``, OCR ``words`` with ``boxes``, ``answers``,
    ``image``) is encoded with the LayoutLMv2 tokenizer. It is labelled with the
    token span of one answer located in the OCR words. Examples for which no
    answer can be located are replaced by a randomly chosen other example in
    ``__getitem__``.
    """

    def __init__(self, split):
        """Load the hub dataset and the tokenizer.

        Args:
            split: ``'train'`` selects the ``train`` split; any other value
                selects the ``val`` split.
        """
        datasets = load_dataset("Trailblazer-Yoo/boostcamp-docvqa")
        self.dataset = datasets['train'] if split == 'train' else datasets['val']

        model_checkpoint = "microsoft/layoutlmv2-base-uncased"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
        except Exception:
            # Single retry of the identical call, presumably for transient
            # download failures. ``Exception`` rather than a bare ``except`` so
            # KeyboardInterrupt / SystemExit are not swallowed.
            self.tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

    # source: https://stackoverflow.com/a/12576755
    def subfinder(self, words_list, answer_list):
        """Find the first run of ``words_list`` that fuzzily matches ``answer_list``.

        Each answer token is compared with the word at the same offset using
        ``fuzzy``.

        Returns:
            ``(answer_list, start, end)`` with inclusive word indices for the
            first match, or ``(None, 0, 0)`` when there is none.
        """
        span = len(answer_list)
        for start in range(len(words_list)):
            if start + span > len(words_list):
                # Every later start is even shorter, so no further match is possible.
                break
            if all(fuzzy(words_list[start + j], answer_list[j]) for j in range(span)):
                return answer_list, start, start + span - 1
        return None, 0, 0

    def _locate_answer(self, words_lower, answer):
        """Word span of ``answer`` in ``words_lower`` (lower-cased OCR words).

        If the whole answer is not found, the search is retried with each
        single character of ``answer`` removed in turn. This accounts for OCR
        text that does not exactly match the answer. Returns the
        ``subfinder`` triple.
        """
        match, start, end = self.subfinder(words_lower, answer.lower().split())
        if not match and len(answer) > 1:
            for i in range(len(answer)):
                answer_i = answer[:i] + answer[i + 1:]
                match, start, end = self.subfinder(words_lower, answer_i.lower().split())
                if match:
                    break
        return match, start, end

    @staticmethod
    def _context_token_range(sequence_ids):
        """Inclusive (first, last) token indices whose sequence id is 1 (the OCR-word context)."""
        first = 0
        while sequence_ids[first] != 1:
            first += 1
        last = len(sequence_ids) - 1
        while sequence_ids[last] != 1:
            last -= 1
        return first, last

    def encode_dataset(self, example, max_length=512):
        """Tokenize one example and attach a start/end token label for one answer.

        Args:
            example: mapping with ``'question'``, ``'words'``, ``'boxes'``,
                ``'answers'`` and ``'image'``.
            max_length: padded/truncated sequence length.

        Returns:
            ``None`` when no answer could be located in the (possibly
            truncated) context. Otherwise a dict containing:
            the tokenizer tensors (``input_ids``, ``attention_mask``,
            ``token_type_ids``, ``bbox``); ``answers``; ``image``; and
            one-element ``start_position`` / ``end_position`` tensors for a
            randomly chosen located answer.
        """
        question = example['question']
        words = list(example['words'])  # handles numpy arrays and lists
        boxes = example['boxes']

        batch_index = 0
        encoding = self.tokenizer([question], [words], [boxes], max_length=max_length,
                                  padding="max_length", truncation=True, return_tensors="pt")

        answers = example['answers']
        words_lower = [word.lower() for word in words]
        start_positions = []
        end_positions = []
        for answer in answers:
            match, word_idx_start, word_idx_end = self._locate_answer(words_lower, answer)
            if not match:
                continue

            context_first, context_last = self._context_token_range(
                encoding.sequence_ids(batch_index))
            word_ids = encoding.word_ids(batch_index)[context_first:context_last + 1]

            start_offsets = [k for k, wid in enumerate(word_ids) if wid == word_idx_start]
            if not start_offsets:
                # The answer's first word is not among the (truncated) context tokens.
                continue
            # Last token belonging to any answer word. If the answer's last word
            # was truncated away, this is the last surviving answer token. That
            # keeps end >= start instead of falling back to an index before the
            # context.
            end_offset = max(k for k, wid in enumerate(word_ids)
                             if wid is not None and word_idx_start <= wid <= word_idx_end)
            start_positions.append(context_first + start_offsets[0])
            end_positions.append(context_first + end_offset)

        if not start_positions:
            return None

        ans_i = random.randrange(len(start_positions))
        encoded = {
            'input_ids': encoding['input_ids'],
            'attention_mask': encoding['attention_mask'],
            'token_type_ids': encoding['token_type_ids'],
            'bbox': encoding['bbox'],
            'answers': answers,
        }
        # Changed: example['image'].copy() -> example['image'].copy()[0]
        encoded['image'] = torch.LongTensor(example['image'].copy()[0])
        encoded['start_position'] = torch.LongTensor([start_positions[ans_i]])
        encoded['end_position'] = torch.LongTensor([end_positions[ans_i]])
        return encoded

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the encoded example at ``index``.

        When ``encode_dataset`` returns ``None`` (no answer located), random
        other indices are tried until one encodes. This is done iteratively, so
        long runs of unusable examples cannot exhaust the recursion limit. It
        never terminates if no example in the dataset is encodable.
        """
        data = self.encode_dataset(self.dataset[index])
        while data is None:
            index = random.randrange(len(self))
            data = self.encode_dataset(self.dataset[index])
        return data

In [8]:
def collate(data):
    """Merge a list of ``DocVQADataset`` items into one model-ready batch.

    Tokenizer tensors and the one-element position tensors are concatenated
    along dim 0. Images are stacked into a new leading dim. The raw answer
    lists are kept as a Python list. The per-item keys ``start_position`` /
    ``end_position`` are renamed to the plural ``start_positions`` /
    ``end_positions`` used as model keyword arguments.
    """
    def _cat(key):
        return torch.cat([sample[key] for sample in data], dim=0)

    batch = {key: _cat(key) for key in ('input_ids', 'attention_mask', 'token_type_ids', 'bbox')}
    batch['image'] = torch.stack([sample['image'] for sample in data], dim=0)
    batch['start_positions'] = _cat('start_position')
    batch['end_positions'] = _cat('end_position')
    batch['answers'] = [sample['answers'] for sample in data]
    return batch

In [9]:
from transformers import LayoutLMv2Processor

# Only `processor.tokenizer` is used in this notebook: `run` decodes predicted spans with it.
processor = LayoutLMv2Processor.from_pretrained(
    pretrained_model_name_or_path="microsoft/layoutlmv2-base-uncased"
)

In [1]:
from transformers import AutoModelForQuestionAnswering, AutoModel

# LayoutLMv2 with an extractive-QA (start/end span) head.
model = AutoModelForQuestionAnswering.from_pretrained("microsoft/layoutlmv2-base-uncased")
# `transformers.AdamW` is deprecated in favour of `torch.optim.AdamW` (and removed in
# newer releases). eps=1e-6 / weight_decay=0.0 match the transformers AdamW defaults;
# torch's own defaults differ (eps=1e-8, weight_decay=0.01).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
# Resume markers; not read by any cell shown in this notebook.
start_epoch = 0
start_idx = -1

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/802M [00:00<?, ?B/s]

Some weights of the model checkpoint at microsoft/layoutlmv2-base-uncased were not used when initializing LayoutLMv2ForQuestionAnswering: ['layoutlmv2.visual.backbone.bottom_up.res4.12.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.21.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.2.conv3.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.0.shortcut.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.7.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.1.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res2.0.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.1.conv2.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res3.3.conv1.norm.num_batches_tracked', 'layoutlmv2.visual.backbone.bottom_up.res4.18.conv2.norm.num_batches_tracke

In [5]:
# Print the attention-dropout submodules (qualified names containing 'self.dropout').
for name, _module in model.named_modules():
    if 'self.dropout' not in name:
        continue
    print(name)

layoutlmv2.encoder.layer.0.attention.self.dropout
layoutlmv2.encoder.layer.1.attention.self.dropout
layoutlmv2.encoder.layer.2.attention.self.dropout
layoutlmv2.encoder.layer.3.attention.self.dropout
layoutlmv2.encoder.layer.4.attention.self.dropout
layoutlmv2.encoder.layer.5.attention.self.dropout
layoutlmv2.encoder.layer.6.attention.self.dropout
layoutlmv2.encoder.layer.7.attention.self.dropout
layoutlmv2.encoder.layer.8.attention.self.dropout
layoutlmv2.encoder.layer.9.attention.self.dropout
layoutlmv2.encoder.layer.10.attention.self.dropout
layoutlmv2.encoder.layer.11.attention.self.dropout


In [11]:
def train_epoch():
    """Run one optimisation pass over the global ``dataloader``.

    Uses the notebook globals ``model``, ``optimizer``, ``dataloader``,
    ``device`` and ``wandb``. After each step, the progress bar shows the
    running mean training loss and the current learning rate. The running mean
    is also logged to W&B as ``train_loss``.
    """
    # Batch entries moved to ``device`` and passed to the model under the same names.
    model_inputs = ("input_ids", "attention_mask", "token_type_ids", "bbox",
                    "image", "start_positions", "end_positions")
    model.train()
    loss_total = 0
    n_steps = 0
    progress = tqdm(dataloader)
    for batch in progress:
        on_device = {key: batch[key].to(device) for key in model_inputs}
        optimizer.zero_grad()
        n_steps += 1
        loss = model(**on_device).loss
        loss_total += loss.item()
        loss.backward()
        optimizer.step()
        running_mean = loss_total / n_steps
        progress.set_postfix({
            'loss': running_mean,
            'lr': optimizer.param_groups[0]['lr'],
        })
        wandb.log({'train_loss': running_mean})
    progress.close()

def ANLS(pred, answers, threshold=0.5):
    """Per-sample ANLS score of ``pred`` against a list of reference answers.

    For each reference, the case-insensitive edit distance to ``pred`` is
    divided by the longer of the two lower-cased lengths, giving NL. The
    reference scores ``1 - NL`` when ``NL < threshold`` and ``0`` otherwise.
    The best reference score is returned.

    Args:
        pred: predicted answer string.
        answers: reference answer strings; ``None`` entries are ignored.
        threshold: NL cut-off at or above which a reference scores 0
            (0.5 in the standard ANLS definition).

    Returns:
        A number in ``[0, 1]``. ``0.0`` when there is no usable reference
        (an empty list, or only ``None`` entries), so callers can always sum scores.
    """
    references = [ans for ans in answers if ans is not None]
    if not references:
        return 0.0
    pred_norm = pred.lower()
    scores = []
    for ans in references:
        ans_norm = ans.lower()
        longest = max(len(ans_norm), len(pred_norm))
        # Two empty strings are an exact match; avoid dividing by zero.
        nl = editdistance.eval(ans_norm, pred_norm) / longest if longest else 0.0
        scores.append(1 - nl if nl < threshold else 0)
    return max(scores)

def run(batch, start_logits, end_logits):
    """Mean ANLS of the greedy span predictions for one batch.

    For each sample, the argmax start and end token indices select a slice of
    ``batch['input_ids']``. That slice is decoded with ``processor.tokenizer``
    and scored with ``ANLS`` against ``batch['answers']``.

    Returns:
        Sum of per-sample scores divided by the batch size; ``0.0`` for an empty batch.
        A sample whose ANLS cannot be summed counts as 0 but stays in the denominator.
    """
    length = len(batch['input_ids'])
    if length == 0:
        return 0.0
    batch_score = 0
    for i in range(length):
        predicted_start_idx = start_logits[i].argmax(-1).item()
        predicted_end_idx = end_logits[i].argmax(-1).item()
        pred_text = processor.tokenizer.decode(
            batch['input_ids'][i][predicted_start_idx:predicted_end_idx + 1])
        try:
            batch_score += ANLS(pred_text, batch['answers'][i])
        except (TypeError, ZeroDivisionError):
            # Only the known ANLS failure modes are tolerated: a non-numeric
            # result when there is no ground truth, or a division by zero on
            # empty strings. Any other error now surfaces instead of being hidden.
            continue
    return batch_score / length

# Best validation loss seen so far. Starting at +inf guarantees the first
# validation epoch saves a checkpoint (a start value of 1 never saves if the loss stays >= 1).
min_loss = float('inf')
# Mean validation ANLS of the most recent epoch; rewritten by valid_epoch.
score = 0
def valid_epoch(epoch, save_path='/opt/ml/docvqa/model.pt'):
    """Evaluate ``model`` on ``valid_dataloader``, log metrics, and checkpoint on improvement.

    Reads the notebook globals ``model``, ``valid_dataloader``, ``optimizer``,
    ``device``, ``run`` and ``wandb``. Rewrites the global ``score`` (mean
    validation ANLS of this epoch). Updates the global ``min_loss`` (best
    validation loss so far).

    Args:
        epoch: zero-based epoch index, used only for printing/logging (shown as ``epoch+1``).
        save_path: where ``model.state_dict()`` is written when the epoch's
            validation loss is lower than ``min_loss``.
    """
    global min_loss
    global score
    model.eval()
    gc.collect()
    loss_sum = 0.0
    score_sum = 0.0
    n_samples = 0
    with torch.no_grad():
        pbar = tqdm(valid_dataloader)
        for batch in pbar:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            token_type_ids = batch["token_type_ids"].to(device)
            bbox = batch["bbox"].to(device)
            image = batch["image"].to(device)
            start_positions = batch["start_positions"].to(device)
            end_positions = batch["end_positions"].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids, bbox=bbox, image=image,
                            start_positions=start_positions, end_positions=end_positions)
            # run() averages over the batch and outputs.loss is presumably a batch
            # mean too; weight both by batch size so a short final batch does not
            # skew the epoch averages.
            batch_size = len(batch["input_ids"])
            batch_score = run(batch, outputs.start_logits, outputs.end_logits)
            loss_sum += outputs.loss.item() * batch_size
            score_sum += batch_score * batch_size
            n_samples += batch_size
            pbar.set_postfix({
                'val_loss': loss_sum / n_samples,
                'val_score': batch_score,
                'lr': optimizer.param_groups[0]['lr'],
            })
        pbar.close()

    if n_samples == 0:
        print(f"Epoch [{epoch+1}] validation set is empty; nothing to evaluate")
        return

    epoch_val_loss = loss_sum / n_samples
    # Computed fresh each epoch; accumulating into the global would carry over
    # the previous epoch's score.
    score = score_sum / n_samples
    print(f"Epoch [{epoch+1}] Val_loss : {epoch_val_loss}")
    print(f"Epoch [{epoch+1}] Val_Score : {score}" )
    wandb.log({'epoch': epoch + 1, 'val_loss': epoch_val_loss, 'val_score': score})

    if epoch_val_loss < min_loss:
        torch.save(model.state_dict(), save_path)
        min_loss = epoch_val_loss

In [12]:
# Train / validation datasets (any split name other than 'train' loads the 'val' split).
dataset = DocVQADataset('train')
valid_dataset = DocVQADataset('valid')

# Batches of 16 merged by `collate`; only the training loader is shuffled.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=collate,
)
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=collate,
)

model.to(device)
print(device)

Using custom data configuration Trailblazer-Yoo--boostcamp-docvqa-36fdb9dc869c2269
Found cached dataset parquet (/opt/ml/.cache/huggingface/datasets/Trailblazer-Yoo___parquet/Trailblazer-Yoo--boostcamp-docvqa-36fdb9dc869c2269/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 2/2 [00:00<00:00, 238.71it/s]
Using custom data configuration Trailblazer-Yoo--boostcamp-docvqa-36fdb9dc869c2269
Found cached dataset parquet (/opt/ml/.cache/huggingface/datasets/Trailblazer-Yoo___parquet/Trailblazer-Yoo--boostcamp-docvqa-36fdb9dc869c2269/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)
100%|██████████| 2/2 [00:00<00:00, 269.99it/s]


cuda


In [None]:
# Authenticate, then open a W&B run that the training/validation loops log to.
wandb.login()
wandb.init(
    project='LayoutLM',
    entity='hundredeuk2',
    group='test_QA',
    name='test',
)

In [None]:
# Number of passes over the training set; each is followed by a validation pass.
NUM_EPOCHS = 3

try:
    for epoch in range(NUM_EPOCHS):
        train_epoch()
        valid_epoch(epoch)
        gc.collect()
finally:
    # Close the W&B run even when training raises or is interrupted.
    wandb.finish()