In [224]:
import argparse
import os
import  pickle5 as pickle
import math
import contextlib
import random 
import sys

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoConfig, RobertaModel, LayoutLMv3Tokenizer
from torch.optim import AdamW
from transformers import get_constant_schedule_with_warmup

sys.path.append('../src')
from model import  My_DataLoader
from model.LayoutLMv3forMIM import LayoutLMv3ForPretraining
from utils.slack import notification_slack

In [225]:
# Command-line interface of the pre-training run. In the notebook the
# arguments are supplied through `args_list` instead of `sys.argv`.
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_vocab_dir", type=str, required=True)
parser.add_argument("--input_file", type=str, required=True)
parser.add_argument("--model_params", type=str)
parser.add_argument("--ratio_train", type=float, default=0.9)
parser.add_argument("--output_model_dir", type=str, required=True)
parser.add_argument("--output_file_name", type=str, required=True)
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--batch_size", type=int, default=2)
# The learning rate is fractional. With type=int, argparse rejects any value
# given on the command line (int("1e-4") raises), so only the default worked.
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--max_epochs", type=int, default=3)
parser.add_argument("--datasize", type=int, default=2)

# Arguments used for this interactive session.
# The output directory previously ended with a stray space ("test2/ "), which
# would name a different directory than the one intended.
args_list = [
    "--tokenizer_vocab_dir", "../data/vocab/tokenizer_vocab/",
    "--input_file", "../data/preprocessing_shared/wpa_10000/",
    "--output_model_dir", "../data/train/test2/",
    "--output_file_name", "model.param",
    "--batch_size", "32",
    "--model_name", "microsoft/layoutlmv3-base",
]
args = parser.parse_args(args_list)

In [226]:
def _collate_fn(self, batch):
    """Collate a list of samples into one padded batch dictionary.

    Free-standing variant of a collate method: ``self`` must provide
    ``vocab``, ``rng``, ``seq_len`` and ``_create_attention_mask``.

    Each sample in ``batch`` is read for the keys ``input_ids``, ``bbox``,
    ``pixel_values``, ``bool_masked_pos``, ``label`` and ``alignment_labels``.
    As a side effect the keys ``mask_input_ids``, ``ml_position`` and
    ``ml_label`` are written into every sample.

    Returns a dict with ``input_ids``, ``bbox``, ``pixel_values``,
    ``ml_position``, ``ml_label``, ``bool_mi_pos``, ``mi_label``,
    ``attention_mask`` and ``alignment_labels``.

    NOTE(review): the masked ids produced by ``create_span_mask_for_ids`` are
    stored on the samples, but the batch is built from the *unmasked*
    ``input_ids`` -- confirm that this is intended.
    """
    pad_id = self.vocab.index("<pad>")
    output_dict = {}
    for i, b in enumerate(batch):
        # Span lengths are drawn from a Poisson distribution with lambda = 1.
        batch[i]["mask_input_ids"], batch[i]["ml_position"], batch[i]["ml_label"] = create_span_mask_for_ids(b["input_ids"], 0.3, 153, self.vocab, 1, self.rng)

    for i in ["input_ids", "bbox", "pixel_values"]:
        # Token ids have to be padded with the <pad> id, because the attention
        # mask below is derived by comparing against that id. The previous
        # check tested for "mask_input_ids", which is never iterated here, so
        # the ids were padded with 0 instead.
        padding_value = pad_id if i == "input_ids" else 0
        output_dict[i] = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(b[i]) for b in batch],
            batch_first=True,
            padding_value=padding_value
        )
        # pad_sequence only pads to the longest sample of the batch; extend to
        # seq_len when the batch is still shorter (not for pixel values).
        if i != "pixel_values" and output_dict[i].shape[1] != self.seq_len:
            notification_slack(f"padding_{i}:{output_dict[i].shape} < 512, do pading")
            pad_len = self.seq_len - output_dict[i].shape[1]
            if i == "input_ids":
                # token ids: (batch, pad_len) filled with the padding value
                pad_tensor = torch.ones((output_dict[i].shape[0], pad_len), dtype=torch.long) * padding_value
            else:
                # bbox: (batch, pad_len, 4) filled with [0, 0, 0, 0]
                pad_tensor = torch.ones((output_dict[i].shape[0], pad_len, 4), dtype=torch.long) * padding_value

            output_dict[i] = torch.cat((output_dict[i], pad_tensor), dim=1)

    # Masked positions / labels differ in length per sample, so they stay lists.
    for i in ["ml_position", "ml_label"]:
        output_dict[i] = [torch.LongTensor(b[i]) for b in batch]

    output_dict["bool_mi_pos"] = torch.cat([b["bool_masked_pos"] for b in batch])
    output_dict["mi_label"] = [b["label"] for b in batch]

    attention_mask = self._create_attention_mask(output_dict["input_ids"])
    output_dict["attention_mask"] = attention_mask

    # Alignment labels (word-patch alignment), padded with False up to seq_len.
    al_labels = torch.nn.utils.rnn.pad_sequence(
        [b["alignment_labels"] for b in batch],
        batch_first=True,
        padding_value=False
    )
    if al_labels.shape[1] != self.seq_len:
        notification_slack(f"padding_alignment_labels:{al_labels.shape} < 512, do pading")
        pad_len = self.seq_len - al_labels.shape[1]
        pad_tensor = torch.zeros((al_labels.shape[0], pad_len)).to(torch.bool)
        al_labels = torch.cat((al_labels, pad_tensor), dim=1)

    output_dict["alignment_labels"] = al_labels

    return output_dict

In [227]:
# -*-coding:utf-8-*-

from ctypes import alignment
import sys

sys.path.append('../')
import torch
from torch.utils.data import DataLoader, Dataset
from utils import utils
from utils.slack import notification_slack
import numpy as np


class My_Dataloader():
    """Factory that builds data loaders whose collate step applies span masking.

    Args:
        vocab: sequence of token strings; must contain ``"<pad>"`` and is
            passed on to ``create_span_mask_for_ids``.
        random: random source handed to ``create_span_mask_for_ids``
            (stored as both ``self.random`` and ``self.rng``).
        seq_len: length every token-level tensor is padded to.
        DataLoader: loader class/callable used by ``__call__``.
        masked_lm_prob: fraction of words to mask (previously hard-coded 0.3).
        max_predictions_per_seq: upper bound on masked words per sample
            (previously hard-coded 153).
        poisson_lambda: lambda of the Poisson distribution for span lengths
            (previously hard-coded 1).
    """

    def __init__(self, vocab, random, seq_len=512, DataLoader=DataLoader,
                 masked_lm_prob=0.3, max_predictions_per_seq=153, poisson_lambda=1):
        self.vocab = vocab
        self.DataLoader = DataLoader
        self.random = random
        self.seq_len = seq_len
        self.rng = random
        self.masked_lm_prob = masked_lm_prob
        self.max_predictions_per_seq = max_predictions_per_seq
        self.poisson_lambda = poisson_lambda

    def __call__(self, dataset, batch_size, shuffle):
        """Return a loader over ``dataset`` that collates with ``_collate_fn``.

        Incomplete trailing batches are dropped (``drop_last=True``).
        """
        # Use the injected loader; the global DataLoader was used before, which
        # silently ignored the constructor argument.
        return self.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=True, collate_fn=self._collate_fn)

    def _create_attention_mask(self, x):
        """Return a tensor shaped like ``x``: 0 where ``x`` is ``<pad>``, else 1."""
        return torch.masked_fill(torch.ones(x.shape), x == self.vocab.index("<pad>"), 0)

    def _collate_fn(self, batch):
        """Mask, pad and stack a list of samples into one batch dictionary.

        Each sample is read for ``input_ids``, ``bbox``, ``pixel_values``,
        ``bool_masked_pos``, ``label`` and ``alignment_labels``. The samples
        themselves are not modified.

        Returns a dict with ``mask_input_ids``, ``bbox``, ``pixel_values``,
        ``ml_position``, ``ml_label``, ``bool_mi_pos``, ``mi_label``,
        ``attention_mask`` and ``alignment_labels``.
        """
        pad_id = self.vocab.index("<pad>")
        output_dict = {}

        # One (masked_ids, positions, labels) triple per sample. The results are
        # kept locally instead of being written back into the dataset samples.
        masked = [
            create_span_mask_for_ids(
                b["input_ids"],
                self.masked_lm_prob,
                self.max_predictions_per_seq,
                self.vocab,
                self.poisson_lambda,
                self.rng,
            )
            for b in batch
        ]

        sequences = {
            "mask_input_ids": [m[0] for m in masked],
            "bbox": [b["bbox"] for b in batch],
            "pixel_values": [b["pixel_values"] for b in batch],
        }
        for key, values in sequences.items():
            # Token ids are padded with <pad>, everything else with 0.
            padding_value = pad_id if key == "mask_input_ids" else 0
            padded = torch.nn.utils.rnn.pad_sequence(
                [torch.tensor(v) for v in values],
                batch_first=True,
                padding_value=padding_value
            )
            # pad_sequence only pads to the longest sample of the batch; extend
            # to seq_len when the batch is still shorter (not for pixel values).
            if key != "pixel_values" and padded.shape[1] != self.seq_len:
                notification_slack(f"padding_{key}:{padded.shape} < 512, do pading")
                pad_len = self.seq_len - padded.shape[1]
                if key == "mask_input_ids":
                    # token ids: (batch, pad_len) filled with the <pad> id
                    pad_tensor = torch.ones((padded.shape[0], pad_len), dtype=torch.long) * padding_value
                else:
                    # bbox: (batch, pad_len, 4) filled with [0, 0, 0, 0]
                    pad_tensor = torch.ones((padded.shape[0], pad_len, 4), dtype=torch.long) * padding_value
                padded = torch.cat((padded, pad_tensor), dim=1)
            output_dict[key] = padded

        # Masked positions / labels differ in length per sample -> keep lists.
        output_dict["ml_position"] = [torch.LongTensor(m[1]) for m in masked]
        output_dict["ml_label"] = [torch.LongTensor(m[2]) for m in masked]

        output_dict["bool_mi_pos"] = torch.cat([b["bool_masked_pos"] for b in batch])
        output_dict["mi_label"] = [b["label"] for b in batch]

        output_dict["attention_mask"] = self._create_attention_mask(output_dict["mask_input_ids"])

        # Alignment labels (word-patch alignment), padded with False to seq_len.
        al_labels = torch.nn.utils.rnn.pad_sequence(
            [b["alignment_labels"] for b in batch],
            batch_first=True,
            padding_value=False
        )
        if al_labels.shape[1] != self.seq_len:
            notification_slack(f"padding_alignment_labels:{al_labels.shape} < 512, do pading")
            pad_len = self.seq_len - al_labels.shape[1]
            pad_tensor = torch.zeros((al_labels.shape[0], pad_len)).to(torch.bool)
            al_labels = torch.cat((al_labels, pad_tensor), dim=1)

        output_dict["alignment_labels"] = al_labels

        return output_dict
       
            

In [228]:
import collections

import sys
import fitz
import numpy as np
import itertools
import numpy as np
import torch


# (position in the sequence, original token id at that position)
MaskedLMInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])


def create_span_mask_for_ids(token_ids, masked_lm_prob, max_predictions_per_seq, vocab_words, param , rng):
    """Apply whole-word span masking to a sequence of token ids.

    Tokens are grouped into words (a token whose vocab string starts with
    "Ġ" opens a new word; ``<s>``, ``</s>`` and ``<pad>`` are never
    candidates). Spans of consecutive words are selected until
    ``num_to_predict`` words are covered. Each selected span is then replaced
    by ``<mask>`` (80%), left unchanged (10%) or replaced by random vocabulary
    ids (10%).

    Args:
        token_ids: sequence of token ids; it is not modified.
        masked_lm_prob: fraction of candidate words to select.
        max_predictions_per_seq: upper bound on the number of selected words.
        vocab_words: list of token strings, indexed by token id.
        param: lambda of the Poisson distribution used for span lengths
            (drawn from ``np.random``, not from ``rng``).
        rng: object providing ``randint(a, b)`` (inclusive) and ``random()``.

    Returns:
        ``(output_tokens, masked_lm_positions, masked_lm_labels)`` where
        ``output_tokens`` is the corrupted copy of ``token_ids`` and the other
        two lists hold, sorted by position, every selected position and the
        original token id at that position.
    """
    # Look the special ids up once instead of scanning the vocab per token.
    bos_id = vocab_words.index("<s>")
    eos_id = vocab_words.index("</s>")
    pad_id = vocab_words.index("<pad>")
    mask_id = vocab_words.index("<mask>")

    # Sanity check: the input should not be masked already.
    if mask_id in token_ids:
        print(f"error!!!!!! {mask_id} in token_ids")

    # Group token positions into whole words.
    cand_indexes = []
    for i, token_id in enumerate(token_ids):
        if token_id == bos_id or token_id == eos_id or token_id == pad_id:
            continue

        if len(cand_indexes) >= 1 and not vocab_words[token_id].startswith("Ġ"):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
    output_tokens = list(token_ids)

    # Number of words to select: masked_lm_prob of all words, at least 1.
    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(cand_indexes) * masked_lm_prob))))

    span_count = 0
    covered_indexes = []  # selected spans, each a list of token positions
    covered_set = set()   # all selected positions, to reject overlapping spans
    # Bound the number of draws so that a configuration in which no further
    # disjoint span fits (or param == 0) cannot loop forever.
    max_attempts = 1000 * max(1, num_to_predict)
    attempts = 0
    while span_count < num_to_predict and attempts < max_attempts:
        attempts += 1

        span_length = np.random.poisson(lam=param)
        if span_count + span_length > num_to_predict or span_length == 0:
            continue
        # Too few words to place a span of this length: give up.
        if len(cand_indexes) - (1 + span_length) <= 0:
            break
        # NOTE(review): with an inclusive randint this upper bound means the
        # last candidate word can never be part of a span -- confirm intended.
        start_index = rng.randint(0, len(cand_indexes) - (1 + span_length))
        # Expand the selected words into their token positions.
        covered_index = cand_indexes[start_index: start_index + span_length]
        covered_index = list(itertools.chain.from_iterable(covered_index))
        if covered_set.isdisjoint(covered_index):
            covered_set.update(covered_index)
            span_count += span_length
            covered_indexes.append(covered_index)

    masked_lms = []
    for span_index in covered_indexes:
        if rng.random() < 0.8:
            # 80%: replace the whole span with <mask>
            masked_tokens = [mask_id] * len(span_index)
        elif rng.random() < 0.5:
            # 10%: keep the original tokens
            masked_tokens = [token_ids[i] for i in span_index]
        else:
            # 10%: replace with random vocabulary ids
            masked_tokens = [rng.randint(0, len(vocab_words) - 1) for _ in range(len(span_index))]

        for offset, index in enumerate(span_index):
            # Record the label for every selected position, not only for the
            # <mask> branch; otherwise kept / randomly replaced tokens corrupt
            # the input without ever being predicted.
            masked_lms.append(MaskedLMInstance(index=index, label=token_ids[index]))
            output_tokens[index] = masked_tokens[offset]

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = [p.index for p in masked_lms]
    masked_lm_labels = [p.label for p in masked_lms]

    return (output_tokens, masked_lm_positions, masked_lm_labels)

In [229]:
# NOTE(review): `output_tokens` is a local variable of create_span_mask_for_ids,
# so it does not exist at notebook scope and this cell raises NameError.
output_tokens

NameError: name 'output_tokens' is not defined

In [None]:
# Keep only the first 1000 samples for a quick experiment.
# NOTE(review): `data` is not loaded in any cell shown here, and the in-place
# reassignment makes this cell non-idempotent -- confirm where it comes from.
data = data[:1000]

In [230]:
import copy 
# Work on an independent deep copy so that anything done to the samples
# downstream leaves `data` itself untouched.
data1 = copy.deepcopy(data)

In [231]:
# Create the data loader. The `random` module itself is passed as the random
# source used for span masking.
# NOTE(review): `vocab` is not defined in any cell shown here, and no seed is
# set in view, so the masking is presumably not reproducible -- confirm.
my_dataloader = My_Dataloader(vocab, random)
train_dataloader = my_dataloader(data1, batch_size=args.batch_size, shuffle=True)

In [232]:
# Number of samples in `data`.
len(data)

1000

In [233]:
# Fields available on a single sample.
data[0].keys()

dict_keys(['input_ids', 'bbox', 'pixel_values', 'label', 'bool_masked_pos', 'alignment_labels'])

In [234]:
# Sanity check: collect the samples whose raw `input_ids` already contain
# token id 4 and show how many there are, next to the vocab entry for id 4.
# (The counter used to be incremented by 4 per sample, so it was not a count.)
data4 = [d for d in data1 if 4 in d["input_ids"]]
cnt = len(data4)

cnt, vocab[4]

(0, '<mask>')

In [218]:
# Iterate once over the whole loader and count the batches.
# The loop variable `input` (which shadows the builtin) is deliberately left
# bound to the last batch: the following cells inspect it.
cnt = 0
for input in train_dataloader:
    cnt += 1

In [237]:
# Keep a reference to the masked ids of the last batch.
ma = input["mask_input_ids"]

In [238]:
# Rename "mask_input_ids" to "input_ids" on the last batch. The guard keeps
# the cell idempotent: without it, running the cell a second time raises
# KeyError because the key has already been popped.
if "mask_input_ids" in input:
    input["input_ids"] = input.pop("mask_input_ids")

In [240]:
# Keys of the last batch after the rename.
input.keys()

dict_keys(['bbox', 'pixel_values', 'ml_position', 'ml_label', 'bool_mi_pos', 'mi_label', 'attention_mask', 'alignment_labels', 'input_ids'])

In [243]:
# Masked ids captured earlier from the last batch.
ma

tensor([[    0,     4,     4,  ...,     1,     1,     1],
        [    0,     4, 10445,  ...,     1,     1,     1],
        [    0,   337, 14003,  ...,   446,  1118,     2],
        ...,
        [    0,  4664,  5459,  ...,  3274,  1668,     2],
        [    0,   939,  3348,  ...,     1,     1,     1],
        [    0, 28849,  8048,  ...,     1,     1,     1]])

In [241]:
# Same tensor object as `ma`: the rename only moved the reference to a new key.
input["input_ids"]

tensor([[    0,     4,     4,  ...,     1,     1,     1],
        [    0,     4, 10445,  ...,     1,     1,     1],
        [    0,   337, 14003,  ...,   446,  1118,     2],
        ...,
        [    0,  4664,  5459,  ...,  3274,  1668,     2],
        [    0,   939,  3348,  ...,     1,     1,     1],
        [    0, 28849,  8048,  ...,     1,     1,     1]])

In [158]:
# Aliasing demo: `b = a` binds a second name to the same list object, so a
# write through `b` is visible through `a` as well.
a = [1, 2, 3, 4, 5]
b = a
b[0] = -100
a, b

([-100, 2, 3, 4, 5], [-100, 2, 3, 4, 5])