In [8]:
import sys
import argparse
import os
import random
import math

sys.path.append('../src')
import  pickle
import torch
from transformers import LayoutLMv3Tokenizer, AutoConfig, AutoModel, RobertaModel, LayoutLMv3Model
from model import LayoutLMv3forMLM, My_DataLoader
from utils import utils, masking_generator
from torch.optim import AdamW
from transformers import get_constant_schedule_with_warmup
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    BaseModelOutputWithPoolingAndCrossAttentions,
)
from transformers.modeling_outputs import (
    BaseModelOutput,
    QuestionAnsweringModelOutput,
    SequenceClassifierOutput,
    TokenClassifierOutput,
)

In [9]:
parser = argparse.ArgumentParser()
parser.add_argument("--tokenizer_vocab_dir", type=str, required=True)
parser.add_argument("--input_file", type=str, required=True)
parser.add_argument("--model_params", type=str)
parser.add_argument("--ratio_train", type=float,default=0.9)
parser.add_argument("--output_model_dir", type=str, required=True)
parser.add_argument("--output_file_name", type=str, required=True)
parser.add_argument("--model_name", type=str, required=True)
parser.add_argument("--batch_size", type=int, default=2)
parser.add_argument("--leaning_rate", type=int, default=1e-5)
parser.add_argument("--max_epochs", type=int, default=1)
args_list = ["--tokenizer_vocab_dir", "../data/vocab/tokenizer_vocab/","--input_file",
            "../data/preprocessing_shared/encoded_dataset.pkl",
            "--output_model_dir", "../data/train/model/", \
            "--output_file_name", "model.param", \
            "--model_name", "microsoft/layoutlmv3-base"]
args = parser.parse_args(args_list)

In [5]:
tokenizer = LayoutLMv3Tokenizer(f"{args.tokenizer_vocab_dir}vocab.json", f"{args.tokenizer_vocab_dir}merges.txt")
ids = range(tokenizer.vocab_size)
vocab = tokenizer.convert_ids_to_tokens(ids)

In [37]:
if not args.model_params is None:
    model = torch.load(args.model_params)
else:
    config = AutoConfig.from_pretrained(args.model_name)
    model = LayoutLMv3forMLM.LayoutLMv3ForMLM(config)
    # Roberta_model = RobertaModel.from_pretrained("roberta-base")
    # ## embedidng 層の重みをRobertaの重みで初期化
    # weight_size = model.state_dict()["model.embeddings.word_embeddings.weight"].shape
    # for i in range(weight_size[0]):
    #   model.state_dict()["model.embeddings.word_embeddings.weight"][i] = \
    #   Roberta_model.state_dict()["embeddings.word_embeddings.weight"][i]

In [5]:
with open(args.input_file, 'rb') as f:
    data = pickle.load(f)

In [48]:
patch_size = (config.patch_size, config.patch_size)
image_size = (config.input_size, config.input_size)
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
patch_size, patch_shape,

((16, 16), (14, 14))

In [10]:
proj = torch.nn.Conv2d(config.num_channels, config.hidden_size, kernel_size=patch_size, stride=patch_size)

In [11]:
proj

Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))

In [12]:
pixel_values.shape

(3, 224, 224)

In [13]:
pixel_values = torch.tensor(pixel_values)

In [14]:
pixel_values

tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]],

        [[1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         ...,
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 1., 1., 1.]]])

In [15]:
vision_emb = proj(pixel_values)
vision_emb.shape

torch.Size([768, 14, 14])

In [16]:
x = vision_emb.flatten(2).transpose(1,2)

In [17]:
x.size()

torch.Size([768, 14, 14])

In [19]:
mask_token = torch.nn.Parameter(torch.zeros(1, 1, config.hidden_size))

In [22]:
mask_token.shape

torch.Size([1, 1, 768])

In [28]:
mask_tokens = mask_token.expand(args.batch_size, x.shape[1], -1)
mask_tokens.shape

torch.Size([2, 14, 768])

In [36]:
bool_masked_pos = torch.zeros([1, 14, 14])
bool_masked_pos.shape

torch.Size([1, 14, 14])

In [43]:
w = bool_masked_pos.type_as(mask_tokens)

In [44]:
w.shape

torch.Size([1, 14, 14])

In [47]:
w.type()

'torch.FloatTensor'

In [45]:
mask_tokens * w

RuntimeError: The size of tensor a (768) must match the size of tensor b (14) at non-singleton dimension 2

In [109]:
model.config.use_mask_tokne=True

In [110]:
model.config

LayoutLMv3Config {
  "_name_or_path": "microsoft/layoutlmv3-base",
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "coordinate_size": 128,
  "eos_token_id": 2,
  "has_relative_attention_bias": true,
  "has_spatial_attention_bias": true,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "input_size": 224,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_2d_position_embeddings": 1024,
  "max_position_embeddings": 514,
  "max_rel_2d_pos": 256,
  "max_rel_pos": 128,
  "model_type": "layoutlmv3",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "patch_size": 16,
  "rel_2d_pos_bins": 64,
  "rel_pos_bins": 32,
  "second_input_size": 112,
  "shape_size": 128,
  "text_embed": true,
  "torch_dtype": "float32",
  "transformers_version": "4.21.1",
  "type_vocab_size": 1,
  "use_mask_tokne": true,
  "visual_embed": true,
  "vocab_size"

In [3]:
class LayoutLMv3(LayoutLMv3Model):
    def __init__(self, config):
        super().__init__(config)
        self.mask_token = torch.nn.Parameter(torch.zeros(1, 1, config.hidden_size))
  
    def forward_image(self, pixel_values, bool_masked_pos=None):
        embeddings = self.patch_embed(pixel_values)
        batch_size, seq_len, _ = embeddings.size()
        print(embeddings.shape)
        print(self.mask_token.shape)

        if bool_masked_pos is not None:
            mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
            # replace the masked visual tokens by mask_tokens
            w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
            print(w.shape)
            embeddings = embeddings * (1 - w) + mask_tokens * w

        # add [CLS] token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        embeddings = torch.cat((cls_tokens, embeddings), dim=1)

        # add position embeddings
        if self.pos_embed is not None:
            embeddings = embeddings + self.pos_embed
      


        embeddings = self.pos_drop(embeddings)
        embeddings = self.norm(embeddings)

        return embeddings


    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        pixel_values=None,
        output_attentions=None,
        output_hidden_states=None,
        bool_masked_pos=None,
        return_dict=None,
    ):
        r"""
        Returns:
        Examples:
        ```python
        >>> from transformers import AutoProcessor, AutoModel
        >>> from datasets import load_dataset
        >>> processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)
        >>> model = AutoModel.from_pretrained("microsoft/layoutlmv3-base")
        >>> dataset = load_dataset("nielsr/funsd-layoutlmv3", split="train")
        >>> example = dataset[0]
        >>> image = example["image"]
        >>> words = example["tokens"]
        >>> boxes = example["bboxes"]
        >>> encoding = processor(image, words, boxes=boxes, return_tensors="pt")
        >>> outputs = model(**encoding)
        >>> last_hidden_states = outputs.last_hidden_state
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None:
            input_shape = input_ids.size()
            batch_size, seq_length = input_shape
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size, seq_length = input_shape
            device = inputs_embeds.device
        elif pixel_values is not None:
            batch_size = len(pixel_values)
            device = pixel_values.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds or pixel_values")

        if input_ids is not None or inputs_embeds is not None:
            if attention_mask is None:
                attention_mask = torch.ones(((batch_size, seq_length)), device=device)
            if token_type_ids is None:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
            if bbox is None:
                bbox = torch.zeros(tuple(list(input_shape) + [4]), dtype=torch.long, device=device)

            embedding_output = self.embeddings(
                input_ids=input_ids,
                bbox=bbox,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                inputs_embeds=inputs_embeds,
            )

        final_bbox = final_position_ids = None
        patch_height = patch_width = None
        if pixel_values is not None:
            patch_height, patch_width = int(pixel_values.shape[2] / self.config.patch_size), int(
                pixel_values.shape[3] / self.config.patch_size
            )
            visual_embeddings = self.forward_image(pixel_values, bool_masked_pos)
            visual_attention_mask = torch.ones(
                (batch_size, visual_embeddings.shape[1]), dtype=torch.long, device=device
            )
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask, visual_attention_mask], dim=1)
            else:
                attention_mask = visual_attention_mask

            if self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
                if self.config.has_spatial_attention_bias:
                    visual_bbox = self.calculate_visual_bbox(device, dtype=torch.long, batch_size=batch_size)
                    if bbox is not None:
                        final_bbox = torch.cat([bbox, visual_bbox], dim=1)
                    else:
                        final_bbox = visual_bbox

                visual_position_ids = torch.arange(
                    0, visual_embeddings.shape[1], dtype=torch.long, device=device
                ).repeat(batch_size, 1)
                if input_ids is not None or inputs_embeds is not None:
                    position_ids = torch.arange(0, input_shape[1], device=device).unsqueeze(0)
                    position_ids = position_ids.expand(input_shape)
                    final_position_ids = torch.cat([position_ids, visual_position_ids], dim=1)
                else:
                    final_position_ids = visual_position_ids

            if input_ids is not None or inputs_embeds is not None:
                embedding_output = torch.cat([embedding_output, visual_embeddings], dim=1)
            else:
                embedding_output = visual_embeddings

            embedding_output = self.LayerNorm(embedding_output)
            embedding_output = self.dropout(embedding_output)
        elif self.config.has_relative_attention_bias or self.config.has_spatial_attention_bias:
            if self.config.has_spatial_attention_bias:
                final_bbox = bbox
            if self.config.has_relative_attention_bias:
                position_ids = self.embeddings.position_ids[:, : input_shape[1]]
                position_ids = position_ids.expand_as(input_ids)
                final_position_ids = position_ids

        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
            attention_mask, None, device, dtype=embedding_output.dtype
        )

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            bbox=final_bbox,
            position_ids=final_position_ids,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            patch_height=patch_height,
            patch_width=patch_width,
        )

        sequence_output = encoder_outputs[0]

        if not return_dict:
            return (sequence_output,) + encoder_outputs[1:]

        return BaseModelOutput(
            last_hidden_state=sequence_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

In [10]:
config = AutoConfig.from_pretrained(args.model_name)
model = LayoutLMv3(config)

In [17]:
model.config.patch_size

16

In [47]:
patch_size

(16, 16)

In [49]:
window_size = patch_shape
num_masking_patches = 75
max_mask_patches_per_block = None
min_mask_patches_per_block = 16

# generating mask for the corresponding image
mask_generator = masking_generator.MaskingGenerator(
            window_size, num_masking_patches=num_masking_patches,
            max_num_patches=max_mask_patches_per_block,
            min_num_patches=min_mask_patches_per_block,
        )

In [61]:
bool_masked_pos = mask_generator()
bool_masked_pos = torch.from_numpy(bool_masked_pos).unsqueeze(0)

In [62]:
bool_masked_pos.shape

torch.Size([1, 14, 14])

In [63]:
bool_masked_pos = bool_masked_pos.flatten(1).to(torch.bool)

In [65]:
bool_masked_pos.shape

torch.Size([1, 196])

In [73]:
 bool_masked_pos

tensor([[False, False, False,  True,  True,  True, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         False, False,  True,  True,  True, False, False, False, False, False,
         False, False, False, False, False, False,  True,  True,  True, False,
         False, False, False, False, False, False, False, False, False, False,
          True,  True,  True, False, False, False, False, False, False, False,
         False, False, False, False,  True,  True,  True, False, False, False,
         False, False, False, False, False, False, False, False,  True,  True,
          True, False, False, False, False, False,  True,  True,  True, False,
         False, False,  True,  True,  True, False, False, False,  True,  True,
          True,  True,  True,  True,  True,  True,  

In [47]:
#divide into train and valid
n_train = math.floor(len(data) * args.ratio_train)
train_data = data[:n_train]
valid_data = data[n_train:]

my_dataloader = My_DataLoader.My_Dataloader(vocab)
train_dataloader = my_dataloader(train_data, batch_size=args.batch_size, shuffle=False)
valid_dataloader = my_dataloader(valid_data, batch_size=args.batch_size, shuffle=False)

In [104]:
losses = []
model.train()
for epoch in range(args.max_epochs):
    for iter, batch in enumerate(train_dataloader):
        # inputs = {k: v.to(f'cuda:{model.device_ids[0]}') for k in ["input_ids, bbox", "pixel_values", "attention_mask"]}
        inputs = {k: batch[k] for k in ["input_ids", "bbox", "pixel_values", "attention_mask"]}
        inputs["bool_masked_pos"] = torch.ones([1, 14, 14]).flatten(1)
        print(inputs.keys())
        logits = model.forward(**inputs)
        # loss = cal_loss(logits, batch)
        # if loss is None:
        #     continue
        # # labels = labels.to(f'cuda:{model.device_ids[0]}')
        # loss.backward()
        # optimizer.step()
        # scheduler.step()
        # optimizer.zero_grad()
        # losses.append(loss.item())
        # if iter % math.floor(iter_per_epoch*0.01) == 0:
        #     val_loss = validation()
        #     print(iter, loss.item())
        #     print(iter,"val", val_loss)

dict_keys(['input_ids', 'bbox', 'pixel_values', 'attention_mask', 'bool_masked_pos'])
torch.Size([2, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 196, 1])
dict_keys(['input_ids', 'bbox', 'pixel_values', 'attention_mask', 'bool_masked_pos'])
torch.Size([2, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 196, 1])
dict_keys(['input_ids', 'bbox', 'pixel_values', 'attention_mask', 'bool_masked_pos'])
torch.Size([2, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 196, 1])
dict_keys(['input_ids', 'bbox', 'pixel_values', 'attention_mask', 'bool_masked_pos'])
torch.Size([2, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 196, 1])
dict_keys(['input_ids', 'bbox', 'pixel_values', 'attention_mask', 'bool_masked_pos'])
torch.Size([2, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 196, 1])
dict_keys(['input_ids', 'bbox', 'pixel_values', 'attention_mask', 'bool_masked_pos'])
torch.Size([2, 196, 768])
torch.Size([1, 1, 768])
torch.Size([1, 196, 1])
dict_keys(['input_ids', 'bbox', 'pixel_v

KeyboardInterrupt: 

In [27]:
batch["mask_position"]

torch.Size([89])

In [13]:
# losses = []
# model.train()
# for epoch in range(args.max_epochs):
#     for iter, batch in enumerate(dataloader):
#         # inputs = {k: v.to(f'cuda:{model.device_ids[0]}') for k in ["input_ids, bbox", "pixel_values", "attention_mask"]}
#         inputs = {k: batch[k] for k in ["input_ids", "bbox", "pixel_values", "attention_mask"]}
#         logits = model.forward(inputs)
#         t = []
#         for i in range(len(batch["mask_position"])):
#             if len(batch["mask_position"][i]) == 0:
#                 continue
#             t.append(logits[i][batch["mask_position"][i]])
#         logits = torch.cat(t)

#         labels = torch.cat(batch["mask_label"])
#         # labels = labels.to(f'cuda:{model.device_ids[0]}')
        
#         loss = loss_fn(logits, labels)
#         loss.backward()
#         optimizer.step()
#         scheduler.step()
#         optimizer.zero_grad()
#         losses.append(loss.item())
#         if iter % 4 == 0:
#             print(iter, loss.item())

0 10.87540340423584
1 10.913202285766602
2 10.854154586791992
3 11.003539085388184
4 10.977672576904297
5 10.788056373596191
6 11.0016450881958
7 10.83245849609375
8 11.038727760314941
9 10.941429138183594
10 10.746644020080566
11 10.765816688537598
12 10.817831993103027
13 10.815743446350098
14 10.858794212341309
15 10.890091896057129
16 10.817254066467285
17 10.730331420898438
18 10.704963684082031
19 10.630881309509277
20 10.737689971923828
21 10.534788131713867
22 10.640432357788086
23 10.627967834472656
24 10.67199993133545
25 10.671628952026367
26 10.399262428283691
27 10.393364906311035
28 10.457799911499023
29 10.454362869262695
30 10.372509956359863
31 10.474432945251465
32 10.55550765991211
33 10.313406944274902
34 10.34122371673584
35 10.251405715942383
36 10.119297981262207
37 10.092342376708984
38 9.98108959197998
39 10.217385292053223
40 10.075906753540039
41 10.293874740600586
42 10.036898612976074
43 10.015714645385742
44 10.167057991027832
45 10.199872970581055
46 10.4

KeyboardInterrupt: 

In [25]:
torch.save(
    {
        "epoch": args.max_epochs,
        "batch_size": args.batch_size,
        "loss_list": losses,
        "model_state_dict": model.module.to("cpu").state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    },
    f"{args.output_model_dir}{args.output_file_name}",
)     

In [29]:
# state = torch.load(args.output_model_dir+args.output_file_name)