In [1]:
import argparse
import sys
import pickle
import math
import os 

import torch
from transformers import LayoutLMv3Tokenizer, AutoConfig
import numpy as np


sys.path.append('../src')
from model import My_DataLoader
from model.LayoutLMv3forMIM import LayoutLMv3ForPretraining

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_vocab_dir", type=str, required=True)
parser.add_argument("--input_file", type=str, required=True)
parser.add_argument("--model_params", type=str)
parser.add_argument("--ratio_train", type=float,default=0.9)
parser.add_argument("--output_model_dir", type=str, required=True)
parser.add_argument("--output_file_name", type=str, required=True)
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--leaning_rate", type=int, default=1e-5)
parser.add_argument("--max_epochs", type=int, default=1)
args_list = ["--tokenizer_vocab_dir", "../data/vocab/tokenizer_vocab/","--input_file",
            "../data/preprocessing_shared/wpa_10000/",
            "--output_model_dir", "../data/train/model/", \
            "--output_file_name", "model.param", \
            "--batch_size", "4", \
            "--model_name", "microsoft/layoutlmv3-base", \
              "--model_params", "../data/train/pretrain_lr_1e-4_datasiez_10/epoch_1/checkpoint.cpt"]
args = parser.parse_args(args_list)

In [3]:
tokenizer = LayoutLMv3Tokenizer(f"{args.tokenizer_vocab_dir}vocab.json", f"{args.tokenizer_vocab_dir}merges.txt")
ids = range(tokenizer.vocab_size)
vocab = tokenizer.convert_ids_to_tokens(ids)

In [4]:
data = []
input_names = os.listdir(args.input_file)
for file_name in input_names[0:1]:
    print(file_name)
    with open(f"{args.input_file}{file_name}", "rb") as f:
        d = pickle.load(f)
        data += d

0.pkl


In [19]:
import collections
import itertools
MaskedLMInstance = collections.namedtuple("MaskedLmInstance",
                                          ["index", "label"])
#token id から span maskをする
#bpe baseではなく word base
def create_span_mask_for_ids(token_ids, masked_lm_prob, max_predictions_per_seq, vocab_words, param , rng):
    cand_indexes = []
    for i, id in enumerate(token_ids):
        if id == vocab_words.index("<s>") or id == vocab_words.index("</s>") or id == vocab_words.index("<pad>"):
            continue

        if len(cand_indexes) >= 1 and not vocab_words[id].startswith("Ġ"):
            cand_indexes[-1].append(i)
        else:
            cand_indexes.append([i])
    output_tokens = list(token_ids)
    #全単語×0.3(masked_lm_prob)がmaskの対象
    num_to_predict = min(max_predictions_per_seq, 
                      max(1, int(round(len(cand_indexes) * masked_lm_prob))))
    

    span_count = 0
    covered_indexes = [] #mask候補のリスト
    covered_set = set()  # 被らないか確かめるための集合
    #spanのword数が全words数の30%を超えたら終了
    while (span_count < num_to_predict):

        span_length = np.random.poisson(lam=param)
        if span_count + span_length > num_to_predict or span_length == 0:
            continue
        #cand_indexesから初めの単語を決める
        if len(cand_indexes) -(1 + span_length) <= 0:
            break
            # continue
        start_index = rng.randint(0, len(cand_indexes)-(1 + span_length))
        #span_lengthからsubword単位のspanの範囲を決める
        covered_index = cand_indexes[start_index: start_index +span_length]
        covered_index = list(itertools.chain.from_iterable(covered_index))
        if covered_set.isdisjoint(set(covered_index)):
            covered_set = covered_set | set(covered_index)
            span_count += span_length
            # print(span_length)
            covered_indexes.append(covered_index)
            # print(covered_indexes)

    masked_lms = []
    for span_index in covered_indexes:
        if rng.random() < 0.8:
            mask_token_id = vocab_words.index("<mask>")
            masked_tokens= [mask_token_id for _ in range(len(span_index))]
            #maskした場所と元のtokenを記録
            for i in span_index:
                masked_lms.append(MaskedLMInstance(index=i, label=token_ids[i]))

        else:
            if rng.random() < 0.5:
                masked_tokens = [token_ids[i] for i in span_index]

            else:
                #replace words
                masked_tokens = [rng.randint(0, len(vocab_words) - 1) for _ in range(len(span_index))]
                
        for i, index in enumerate(span_index):
            output_tokens[index] = masked_tokens[i]

    masked_lms = sorted(masked_lms, key=lambda x: x.index)

    masked_lm_positions = []
    masked_lm_labels = []    
    for p in masked_lms:
        masked_lm_positions.append(p.index)
        masked_lm_labels.append(p.label)
    
    #debag
    # if len(token_ids) > 300 and len(masked_lm_positions) < 2:
    #     print(f"error!!! token length: {len(token_ids)}, postions : {masked_lm_positions}, num_to_predict:{num_to_predict}, span_coont:{span_count},covered_indexes:{covered_indexes}, cand_index:{len(cand_indexes)} coverd_indexes_lenght:{len(covered_indexes)}", flush=True)


    return (output_tokens, masked_lm_positions, masked_lm_labels)

In [20]:
data[0]["input_ids"]

[0,
 7730,
 1567,
 16,
 436,
 22662,
 334,
 903,
 1725,
 971,
 402,
 784,
 65,
 7104,
 400,
 271,
 11482,
 285,
 685,
 16,
 11609,
 295,
 11106,
 299,
 81,
 5384,
 91,
 446,
 903,
 1725,
 351,
 7773,
 3045,
 342,
 1118,
 45476,
 1365,
 397,
 627,
 944,
 557,
 1365,
 408,
 6865,
 334,
 271,
 4387,
 285,
 271,
 816,
 285,
 2336,
 7184,
 271,
 35359,
 860,
 2220,
 380,
 14543,
 285,
 1783,
 295,
 11824,
 1272,
 364,
 271,
 2826,
 334,
 903,
 1725,
 1265,
 368,
 9893,
 295,
 5367,
 8976,
 400,
 271,
 8585,
 295,
 1381,
 6041,
 9823,
 299,
 3284,
 2026,
 13,
 1705,
 5771,
 923,
 2362,
 18,
 1283,
 478,
 351,
 476,
 6074,
 11547,
 1821,
 17948,
 903,
 17,
 1326,
 7730,
 7934,
 271,
 35359,
 16,
 759,
 46099,
 334,
 923,
 2362,
 364,
 262,
 7181,
 1603,
 4373,
 659,
 1990,
 2925,
 275,
 295,
 4268,
 271,
 8976,
 1385,
 342,
 979,
 342,
 673,
 903,
 1725,
 2362,
 402,
 825,
 455,
 271,
 11482,
 4704,
 303,
 551,
 2073,
 903,
 17,
 1326,
 7730,
 342,
 979,
 342,
 903,
 1725,
 3383,
 299,
 958,


In [36]:
if  args.model_params is not None:
    checkpoint = torch.load(args.model_params, map_location=torch.device('cpu'))
    config = AutoConfig.from_pretrained(args.model_name)
    config.num_visual_tokens = 8192
    model = LayoutLMv3ForPretraining(config)
    model.load_state_dict(checkpoint["model_state_dict"])
else:
    config = AutoConfig.from_pretrained(args.model_name)
    config.num_visual_tokens = 8192
    model = LayoutLMv3ForPretraining(config)
    # Roberta_model = RobertaModel.from_pretrained("roberta-base")
    # ## embedidng 層の重みをRobertaの重みで初期化
    # weight_size = model.state_dict()["model.embeddings.word_embeddings.weight"].shape
    # for i in range(weight_size[0]):
    #   model.state_dict()["model.embeddings.word_embeddings.weight"][i] = \
    #   Roberta_model.state_dict()["embeddings.word_embeddings.weight"][i]

In [44]:
import random
random.seed(12)
np.random.seed(12)
ids, pos , lab = create_span_mask_for_ids(data[2]["input_ids"], 0.3, 153, vocab, 1, random)
ids = torch.tensor(ids)
pos = torch.tensor(pos)
lab = torch.tensor(lab)
ids.shape, pos.shape , lab.shape

(torch.Size([512]), torch.Size([124]), torch.Size([124]))

In [45]:
cnt = 0
for i in ids:
  if i == 4:
    cnt += 1
cnt

124

In [46]:
pos, ids[:100]

(tensor([  1,   9,  13,  14,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,
          45,  46,  55,  65,  66,  67,  76,  77,  78,  79,  82,  83,  86,  87,
          88,  89,  90,  91,  92, 108, 109, 110, 123, 127, 128, 138, 147, 156,
         157, 162, 163, 164, 165, 166, 167, 168, 169, 175, 176, 183, 197, 201,
         205, 208, 209, 210, 211, 212, 213, 228, 229, 233, 234, 248, 249, 255,
         258, 265, 266, 269, 270, 276, 281, 282, 283, 284, 285, 286, 299, 300,
         311, 313, 314, 315, 321, 322, 334, 335, 341, 343, 346, 347, 357, 364,
         368, 371, 378, 379, 380, 386, 394, 395, 396, 397, 398, 407, 428, 429,
         436, 437, 447, 456, 457, 462, 463, 464, 469, 479, 480, 481]),
 tensor([    0,     4, 34922, 45492, 18593,   272,    30,    29,  5374,     4,
         13589, 11654,  3493,     4,     4,  6153,    17,  2708,  4864,   334,
             4,     4,     4,     4,     4,     4,     4,     4,     4,     4,
         19441,    81,    21,  6351,  6843, 20040, 11348,   

In [51]:
bo = torch.zeros(512)
for i in pos:
  bo[i] = 1
bo  = bo.to(torch.bool)
lab  , torch.tensor(data[]["input_ids"])[bo]

(tensor([44427,   428,   848,  6761, 38621,  3565,   287,  7698,  1446, 14340,
          1841,   613,   284,    16,  2011, 14590,   295,  8313,    16, 14735,
          8040,    16,  2483,  8831, 25188,    16,   335,   522,   626,   378,
           285, 39276,    16,  2087,  9823,    16,   262,   659,  4013,   285,
         38621,  4153,    18,   272,    30,    29,  1122,   406,   271,   814,
           285,   428,  1606,   271,  1351,   295,  1272,   295,   546,    18,
         34922, 45492,   280,   293,    18,   295,   271,  1801,  1017,   807,
         49925,  4464,    18, 38041,    31, 15315, 10835,   934,    31, 44427,
         34922, 45492,  3424,    30,   310,  5778,   285,   271, 12795,   339,
          1253, 38041,  1577, 26188,   295,   262,   351, 38621,   351,  2699,
         12530,   325,    16,   271,  6725,    16,  9497,  1206,    16,   271,
           285,  2336,  5116,    16,  1694, 36989,   295,   271,  5125,   285,
           271, 11239,   280,  1690]),
 tensor([4442

In [8]:
import random
my_dataloader = My_DataLoader.My_Dataloader(vocab, random)
dataloader = my_dataloader(data, batch_size=6, shuffle=True)

In [39]:
losses = []
# model.train()
for epoch in range(args.max_epochs):
    for iter, batch in enumerate(dataloader):
        # print(batch["input_ids"].shape)
        inputs = {k: batch[k] for k in ["input_ids", "bbox", "pixel_values", "attention_mask", "bool_mi_pos"]}
        text_logits, image_logits, wpa_logits = model.forward(inputs)
        break
        
        # for i in range(batch["input_ids"].shape[0]):
            
        #     for j in batch["ml_position"][i]:
        #         bo = torch.zeros(512)
        #         if batch["input_ids"][i][j] != 4:
        #             print(f"{i}, {j} ")
        #             break
        #         bo[j] = 1
        #     bo = bo.to(torch.bool)
        #     a = (batch["input_ids"][i][bo] == batch["ml_label"][i]).sum()
        #     print(a)
        #     if a == False:
        #         break
            



torch.Size([6, 512, 50265])

In [20]:
batch["input_ids"][0][batch["ml_position"][0]]

tensor([4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4])

In [28]:
inp = torch.stack([torch.tensor([[1,3], [2, 4]]), torch.tensor([[5,6], [7, 8]])])
inp

tensor([[[1, 3],
         [2, 4]],

        [[5, 6],
         [7, 8]]])

In [31]:
inp.shape

torch.Size([2, 2, 2])

In [35]:
inp[0][0][0]

tensor(1)

In [27]:
inp[torch.tensor([1])]

tensor([[[5, 6],
         [7, 8]]])

In [None]:
def cal_ml_loss(text_logits, batch):
    t = []
    for i in range(len(batch["ml_position"])):
        if len(batch["ml_position"][i]) == 0:
            continue
        t.append(text_logits[i][batch["ml_position"][i]])
    if len(t) == 0:
        return 
    predict_word_token = torch.cat(t)
    labels = torch.cat(batch["ml_label"])
    print(predict_word_token.shape)
    print(labels.shape)
    # labels = labels.to(f'cuda:{model.device_ids[0]}')
    loss = criterion(predict_word_token + 1e-12, labels)
    return loss

In [18]:
logits = torch.arange(512)

torch.cat([torch.arange(10), torch.arange(10,20)])

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19])

In [None]:
bo = torch.zeros(512)
for i in batch["ml_position"][2]:
  bo[i] = 1
bo  = bo.to(torch.bool)
batch["ml_label"][2], batch["input_ids"][2][bo].shape