In [1]:
from transformers import PreTrainedModel
from transformers import AutoModel, AutoTokenizer, MT5Tokenizer, GPT2LMHeadModel, AutoConfig

# Hub identifiers of the two pre-trained checkpoints used throughout the notebook.
ENCODER_MODEL = "bert-base-multilingual-cased"
DECODER_MODEL = "THUMT/mGPT"

# Encoder side: model and tokenizer are both resolved through the Auto* factories.
mBERT = AutoModel.from_pretrained(ENCODER_MODEL)
mBERT_tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL)

# Decoder side. The disabled line below is an earlier attempt at passing an explicit config.
# gpt_config = AutoConfig.from_pretrained(DECODER_MODEL) , config=gpt_config
# Loaded with AutoModel (GPT2LMHeadModel is imported above but not used here).
mGPT = AutoModel.from_pretrained(DECODER_MODEL)
# The decoder checkpoint's tokenizer is loaded with the MT5Tokenizer class.
mGPT_tokenizer = MT5Tokenizer.from_pretrained(DECODER_MODEL)
# Switch off the use_cache flag on the decoder's config.
mGPT.config.use_cache = False

  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "
Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at

In [2]:
def model_memory_bytes(model):
    """Return the number of bytes held by a module's parameters and buffers.

    Args:
        model: any object exposing ``parameters()`` and ``buffers()`` iterators
            of tensors (e.g. a ``torch.nn.Module``).

    Returns:
        int: sum of ``nelement() * element_size()`` over every parameter and buffer.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in model.buffers())
    return param_bytes + buffer_bytes


# One helper instead of a copy-pasted computation per model; the printed text is unchanged.
for _label, _module in (("mGPT", mGPT), ("mBERT", mBERT)):
    print(f'{_label} memory usage: {model_memory_bytes(_module)/1e6:.2f} MB')

mGPT memory usage: 2263.02 MB
mBERT memory usage: 711.42 MB


In [3]:
# Side-by-side look at both tokenizers: vocabulary size, special tokens, and how each
# one splits the same two Korean sentences written without spaces.
print("mGPT Tokenizer info")
print(f"vocab size: {len(mGPT_tokenizer)}")
print(f"special tokens: {mGPT_tokenizer.all_special_tokens}")  # This tokenizer has no bos token.
print(mGPT_tokenizer.tokenize("이순신은조선중기의무신이다"))
print(mGPT_tokenizer.tokenize("아버지가방에들어가신다"), '\n')

print("mBERT Tokenizer info")
print(f"vocab size: {len(mBERT_tokenizer)}")
print(f"special tokens: {mBERT_tokenizer.all_special_tokens}")
print(mBERT_tokenizer.tokenize("이순신은조선중기의무신이다"))
print(mBERT_tokenizer.tokenize("아버지가방에들어가신다"))

mGPT Tokenizer info
vocab size: 250100
special tokens: ['</s>', '<unk>', '<pad>']
['▁이', '순', '신', '은', '조선', '중', '기의', '무', '신', '이다']
['▁아', '버', '지가', '방', '에', '들어', '가', '신', '다'] 

mBERT Tokenizer info
vocab size: 119547
special tokens: ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
['이', '##순', '##신', '##은', '##조', '##선', '##중', '##기의', '##무', '##신', '##이다']
['아버지', '##가', '##방', '##에', '##들어', '##가', '##신', '##다']


In [4]:
# The decoder tokenizer lists no bos token among its special tokens, so register "<s>" as one.
special_tokens_dict = {"additional_special_tokens": ["<s>"]}
mGPT_tokenizer.add_special_tokens(special_tokens_dict=special_tokens_dict)
# Resize the decoder's token embeddings to the tokenizer's new length.
mGPT.resize_token_embeddings(len(mGPT_tokenizer))
# Make "<s>" the tokenizer's bos_token; the preprocessing cell prepends it to every target sentence.
mGPT_tokenizer.bos_token = "<s>"

In [5]:
import os
import logging
import json
from pathlib import Path

from datasets import load_dataset, DatasetDict

# Root of the Korean-Japanese corpus; every sub-directory is loaded as one split.
# KO_JA_DATA_DIR overrides the location so the notebook is not tied to one machine;
# the original absolute path remains the default.
data_dir = Path(os.environ.get("KO_JA_DATA_DIR", "/opt/ml/final-project-level3-nlp-01/data/ko-ja"))

# Keep directories only: os.listdir() would also return stray files, and globbing
# "<file>/*.csv" yields nothing, which makes load_dataset fail on empty data_files.
folder_list = sorted(entry.name for entry in data_dir.iterdir() if entry.is_dir())

dataset_dict = dict()

for folder in folder_list:
    # Sorted so the row order does not depend on filesystem enumeration order.
    csv_files = sorted(str(p) for p in data_dir.joinpath(folder).glob("*.csv"))
    if not csv_files:
        # A directory without CSV files is not a split; skip it instead of crashing.
        continue
    data = load_dataset("csv", data_files=csv_files)
    # The csv loader places everything under a single 'train' key.
    dataset_dict[folder] = data['train']

# Fail early with an actionable message rather than a bare KeyError below.
missing_splits = [name for name in ("Training", "Validation") if name not in dataset_dict]
if missing_splits:
    raise FileNotFoundError(f"No CSV data found for split(s) {missing_splits} under {data_dir}")

raw_dataset = DatasetDict(dataset_dict)
train_set = raw_dataset["Training"]
valid_set = raw_dataset["Validation"]

Using custom data configuration default-d15fbe40f58f4a20
Reusing dataset csv (/opt/ml/.cache/huggingface/datasets/csv/default-d15fbe40f58f4a20/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-67c6f6653f5a37c1
Reusing dataset csv (/opt/ml/.cache/huggingface/datasets/csv/default-67c6f6653f5a37c1/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a)


  0%|          | 0/1 [00:00<?, ?it/s]

In [6]:
def func(examples, max_length, max_target_length):
    """Tokenize one batch of Korean/Japanese sentence pairs.

    Args:
        examples: batch mapping with a '한국어' (Korean, source) column and a
            '일본어' (Japanese, target) column, each a list of strings.
        max_length: padded/truncated length of the encoder inputs.
        max_target_length: padded/truncated length of the decoder inputs.

    Returns:
        The mBERT tokenizer output for the source sentences, extended with
        'decoder_input_ids', 'decoder_attention_mask' and 'labels'.
    """
    # Source side goes through the encoder's tokenizer.
    encoded = mBERT_tokenizer(
        examples['한국어'],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )

    # Target side: every sentence is prefixed with the decoder tokenizer's bos token.
    bos = mGPT_tokenizer.bos_token
    targets = []
    for sentence in examples['일본어']:
        targets.append(bos + sentence)
    target_encoding = mGPT_tokenizer(
        targets,
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    target_ids = target_encoding['input_ids']
    encoded['decoder_input_ids'] = target_ids
    encoded['decoder_attention_mask'] = target_encoding['attention_mask']
    # Labels are the decoder ids with the first position dropped.
    # NOTE(review): this makes each label row one element shorter than its
    # decoder_input_ids row, and padding ids stay in the labels as-is — confirm
    # that whatever computes the loss accounts for both.
    encoded['labels'] = [row[1:] for row in target_ids]

    return encoded

# Length budgets forwarded to func: encoder inputs and decoder inputs respectively.
fn_kwargs = dict(max_length=512, max_target_length=1024)

# Tokenize both splits with identical settings; the raw text columns are dropped.
tokenized_train, tokenized_valid = (
    split.map(
        func,
        batched=True,
        num_proc=5,
        remove_columns=split.column_names,
        fn_kwargs=fn_kwargs,
    )
    for split in (train_set, valid_set)
)


Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/csv/default-d15fbe40f58f4a20/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a/cache-21ef7ec33d993a64.arrow
Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/csv/default-d15fbe40f58f4a20/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a/cache-913a4dd487b68236.arrow
Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/csv/default-d15fbe40f58f4a20/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a/cache-abc9db3d2c1941a2.arrow
Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/csv/default-d15fbe40f58f4a20/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a/cache-0fe2bd0b2320d02e.arrow
Loading cached processed dataset at /opt/ml/.cache/huggingface/datasets/csv/default-d15fbe40f58f4a20/0.0.0/bf68a4c4aefa545d0712b2fcbb1b327f905bbe2f6425fbc5e8c25234acb9e14a/cache-c2fab142a4348dd3.a

In [7]:
import random
# Draw the index from the split that is actually displayed. The original sampled
# from range(len(tokenized_valid)) while selecting from tokenized_train, which only
# ever showed a prefix of the training set and would raise if the validation split
# were the larger one.
sample_indices = random.choices(range(len(tokenized_train)), k=1)

# Decode one tokenized training example back to text to eyeball the preprocessing.
sample = tokenized_train.select(sample_indices)
for s in sample:
    print(f"inputs: {mBERT_tokenizer.decode(s['input_ids'])}")
    print(f"decoder inputs: {mGPT_tokenizer.decode(s['decoder_input_ids'])}")
    print(f"labels: {mGPT_tokenizer.decode(s['labels'])}", "\n\n")

inputs: [CLS] 드릴 말씀은 신학기 개학과 함께 학생 교통안전에 관한 내용입니다. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [8]:
import torch

# Use the first CUDA device when one is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Module.to() and Module.eval() both return the module, so the two steps chain.
mBERT.to(device).eval()
mGPT.to(device).eval()

# Ends the cell on print() so the module repr is not echoed as the cell output.
print()




In [9]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

# Settings shared by both loaders.
_loader_options = dict(
    batch_size=4,
    num_workers=5,
    pin_memory=True,
    collate_fn=default_data_collator,
)

# Training batches are shuffled and the trailing incomplete batch is dropped.
train_dataloader = DataLoader(tokenized_train, shuffle=True, drop_last=True, **_loader_options)

# Validation keeps dataset order and every example.
valid_dataloader = DataLoader(tokenized_valid, shuffle=False, drop_last=False, **_loader_options)

In [10]:
from copy import deepcopy
# Configuration used to build the BART encoder/decoder layers in the cells below.
bart_enc_config = AutoConfig.from_pretrained("facebook/bart-base")
# Independent copy for the decoder layer so the two configs can be changed separately.
bart_dec_config = deepcopy(bart_enc_config)


  "Passing `gradient_checkpointing` to a config initialization is deprecated and will be removed in v5 "


In [11]:
import torch.nn as nn
from transformers.models.bart.modeling_bart import BartEncoderLayer, BartDecoderLayer, _expand_mask

class graft_module(nn.Module):
    """Placeholder for the grafting layers.

    No sub-modules are registered yet: `num_enc_layers` and `num_dec_layers`
    are accepted by the constructor but not used.
    """

    def __init__(self, num_enc_layers=1, num_dec_layers=1):
        nn.Module.__init__(self)

        # Not enabled yet — intended encoder-side stack:
        # self.graft_enc_layers = nn.ModuleList([BartEncoderLayer() for _ in range(num_enc_layers)])
        


In [30]:
from transformers import models
from transformers.models.bart.modeling_bart import _expand_mask, _make_causal_mask

def make_cross_mask(enc_mask, dec_mask, dtype):
    """Build the additive cross-attention mask for a BART decoder layer.

    BART attention adds the mask to the raw attention scores, so a usable mask
    holds 0 where attention is allowed and the most negative value of `dtype`
    where it is blocked. The previous 0/1 outer product did not block anything
    when added to the scores, and it only worked for sequence lengths of
    exactly 1024 (decoder) and 512 (encoder).

    Args:
        enc_mask: (batch, src_len) padding mask of the encoder input,
            non-zero for real tokens and 0 for padding.
        dec_mask: (batch, tgt_len) padding mask of the decoder input. Only its
            length is used: padded decoder rows are deliberately left
            unblocked, because a fully blocked row has no valid softmax.
        dtype: floating-point dtype of the returned mask.

    Returns:
        Tensor of shape (batch, 1, tgt_len, src_len) and dtype `dtype`, with 0
        at encoder token positions and `torch.finfo(dtype).min` at encoder
        padding positions.
    """
    batch_size, src_len = enc_mask.shape
    tgt_len = dec_mask.shape[-1]

    # Broadcast the encoder padding pattern over every decoder position.
    is_padding = (enc_mask == 0)[:, None, None, :].expand(batch_size, 1, tgt_len, src_len)
    additive = torch.zeros((batch_size, 1, tgt_len, src_len), dtype=dtype, device=enc_mask.device)
    return additive.masked_fill(is_padding, torch.finfo(dtype).min)


In [31]:

# Smoke test: push a single training batch through both pre-trained models and
# one freshly initialised BART encoder layer and decoder layer.
train_batch = next(iter(train_dataloader))

# Linear projection of the decoder hidden states from width 1024 to width 768.
decoder_pooler = nn.Linear(1024, 768).to(device)
graft_enc_layer = BartEncoderLayer(bart_enc_config).to(device)
graft_dec_layer = BartDecoderLayer(bart_dec_config).to(device)

# Inference only: no gradients are tracked for this check.
with torch.no_grad():
    # Move every tensor of the batch to the same device as the models.
    train_batch = {k: t.to(device) for k, t in train_batch.items()}
    bert_output = mBERT(train_batch["input_ids"], attention_mask=train_batch["attention_mask"])
    gpt_output = mGPT(train_batch["decoder_input_ids"], attention_mask=train_batch["decoder_attention_mask"])

    # Expanded padding mask for the encoder layer's self-attention.
    mask = _expand_mask(train_batch['attention_mask'], bert_output.last_hidden_state.dtype)
    graft_dec_input = decoder_pooler(gpt_output.last_hidden_state)
    # Expanded padding mask for the decoder layer's self-attention.
    # NOTE(review): only padding is expanded here; no causal mask is applied — confirm intended.
    dec_mask = _expand_mask(train_batch["decoder_attention_mask"], gpt_output.last_hidden_state.dtype)
    # Cross-attention mask derived from the encoder and decoder padding masks.
    cross_mask = make_cross_mask(train_batch["attention_mask"], train_batch["decoder_attention_mask"], dec_mask.dtype)
    print(mask.dtype)
    
    graft_enc_output = graft_enc_layer(bert_output.last_hidden_state, attention_mask=mask, layer_head_mask=None)[0]
    # NOTE(review): the decoder layer attends over bert_output.last_hidden_state, so
    # graft_enc_output computed just above is never consumed — confirm whether it
    # was meant to be passed as encoder_hidden_states.
    graft_dec_output = graft_dec_layer(hidden_states=graft_dec_input, attention_mask=dec_mask, 
                                       encoder_hidden_states=bert_output.last_hidden_state, encoder_attention_mask=cross_mask, use_cache=False)

    


# Release cached GPU memory held by the allocator after the experiment.
torch.cuda.empty_cache()


torch.float32
1
2
3
4


In [None]:
# Show the shapes of the encoder and decoder hidden states from the smoke-test batch.
bert_output.last_hidden_state.shape, gpt_output.last_hidden_state.shape

In [None]:
class Graformer(PreTrainedModel):
  """Stack of BART encoder layers followed by BART decoder layers.

  The encoder layers refine externally produced encoder hidden states; the
  decoder layers then process externally produced decoder hidden states while
  cross-attending to the refined encoder states.

  The number of layers comes from `config.encoder_layers` and
  `config.decoder_layers`.
  """

  def __init__(self, config):
    super().__init__(config)

    self.encoder = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)])
    self.decoder = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)])

    # self.init_weights()

  @staticmethod
  def _as_additive_mask(mask, dtype, tgt_len=None):
    """Expand a 2-D (batch, src_len) padding mask to the 4-D additive form.

    `None` and masks that are not 2-D (i.e. already expanded by the caller)
    are returned unchanged.
    """
    if mask is None or mask.dim() != 2:
      return mask
    return _expand_mask(mask, dtype, tgt_len=tgt_len)

  def forward(
        self,
        encoder_hidden_states: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        decoder_hidden_states=None,
        decoder_attention_mask=None,
        output_attentions: bool = False,
        cross_attn_head_mask=None,
        use_cache=None,
        head_mask=None,
    ):
    """Run the encoder stack, then the decoder stack.

    Args:
      encoder_hidden_states: hidden states fed to the first encoder layer.
      encoder_attention_mask: encoder padding mask. A 2-D (batch, src_len)
        mask is expanded here; an already expanded 4-D mask is passed to the
        layers as given.
      decoder_hidden_states: hidden states fed to the first decoder layer.
      decoder_attention_mask: decoder padding mask, handled like the encoder
        mask.
      output_attentions: forwarded to every layer.
      cross_attn_head_mask: optional per-layer head mask for cross-attention.
      use_cache: forwarded to every decoder layer.
      head_mask: optional per-layer head mask, indexed by layer position in
        both stacks.

    Returns:
      The hidden states produced by the last decoder layer (or
      `decoder_hidden_states` unchanged when there are no decoder layers).
    """
    # The BART layers add the mask to the attention scores and require the shape
    # (batch, 1, tgt_len, src_len); raw 2-D padding masks must be expanded first.
    encoder_self_mask = self._as_additive_mask(encoder_attention_mask, encoder_hidden_states.dtype)

    for idx, encoder_layer in enumerate(self.encoder):
      encoder_layer_outputs = encoder_layer(
              encoder_hidden_states,
              encoder_self_mask,
              layer_head_mask=(head_mask[idx] if head_mask is not None else None),
              output_attentions=output_attentions,
          )

      encoder_hidden_states = encoder_layer_outputs[0]

    cross_mask = encoder_attention_mask
    decoder_self_mask = decoder_attention_mask
    if decoder_hidden_states is not None:
      decoder_dtype = decoder_hidden_states.dtype
      # Cross-attention has one row per decoder position, so its target length
      # differs from the encoder self-attention mask built above.
      cross_mask = self._as_additive_mask(
          encoder_attention_mask, decoder_dtype, tgt_len=decoder_hidden_states.shape[1]
      )
      # NOTE(review): only padding is masked for decoder self-attention, matching
      # the experiment cell earlier in this notebook; no causal mask is applied —
      # confirm this is intended.
      decoder_self_mask = self._as_additive_mask(decoder_attention_mask, decoder_dtype)

    for idx, decoder_layer in enumerate(self.decoder):
      decoder_layer_outputs = decoder_layer(
          decoder_hidden_states,
          attention_mask=decoder_self_mask,
          encoder_hidden_states=encoder_hidden_states,
          encoder_attention_mask=cross_mask,
          layer_head_mask=(head_mask[idx] if head_mask is not None else None),
          cross_attn_layer_head_mask=(
              cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
          ),
          past_key_value=None,
          output_attentions=output_attentions,
          use_cache=use_cache,
      )
      decoder_hidden_states = decoder_layer_outputs[0]

    return decoder_hidden_states

In [None]:
import torch
from torch import nn

class GraformerModel(PreTrainedModel):
  """Pre-trained encoder and decoder joined by a `Graformer` grafting stack.

  Args:
    config: configuration used to build the `Graformer` layers; `config.d_model`
      is the width those layers operate on.
    encoder: optional encoder model; defaults to the notebook-level `mBERT`.
    decoder: optional decoder model; defaults to the notebook-level `mGPT`.
  """

  def __init__(self, config, encoder=None, decoder=None):
    super().__init__(config)

    # Explicit arguments make the class usable without the notebook globals;
    # omitting them keeps the original behaviour.
    self.encoder = mBERT if encoder is None else encoder
    self.decoder = mGPT if decoder is None else decoder
    self.graformer = Graformer(config)

    # The decoder's hidden width may differ from the width of the grafting layers
    # (the experiment cell projects 1024 -> 768 with `decoder_pooler`). Without a
    # projection both the grafting decoder layers and the residual sum in
    # forward() fail on the shape mismatch.
    decoder_dim = self.decoder.config.hidden_size
    if decoder_dim == config.d_model:
      self.decoder_pooler = nn.Identity()
    else:
      self.decoder_pooler = nn.Linear(decoder_dim, config.d_model)

    # self.init_weights()

  def forward(
        self,
        input_ids=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
    """Encode the source, decode the target, and add the grafting stack's output.

    Args:
      input_ids: token ids for the encoder.
      attention_mask: padding mask for the encoder input; also forwarded to the
        grafting stack as its encoder mask.
      decoder_input_ids: token ids for the decoder.
      decoder_attention_mask: padding mask for the decoder input; also
        forwarded to the grafting stack as its decoder mask.
      head_mask, inputs_embeds, output_attentions, output_hidden_states,
      return_dict: accepted but currently not used.

    Returns:
      Tensor: the decoder hidden states (projected to `config.d_model` when the
      widths differ) plus the hidden states returned by the grafting stack.
    """
    encoder_outputs = self.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
    encoder_hidden_state = encoder_outputs[0]

    decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
        )
    # Bring the decoder states to the grafting width (identity when they already match).
    decoder_hidden_state = self.decoder_pooler(decoder_outputs[0])

    graformer_hidden_state = self.graformer(
      encoder_hidden_states=encoder_hidden_state,
      encoder_attention_mask=attention_mask,
      decoder_hidden_states=decoder_hidden_state,
      decoder_attention_mask=decoder_attention_mask,
    )

    # Residual connection around the grafting stack. A separate local name is used
    # so the `output_hidden_states` argument is no longer overwritten.
    fused_hidden_state = decoder_hidden_state + graformer_hidden_state

    return fused_hidden_state

In [None]:
from transformers import AutoConfig

# Build the combined model with the BART-base configuration for its grafting layers.
config = AutoConfig.from_pretrained("facebook/bart-base")
model = GraformerModel(config=config)