In [None]:
from transformers import PreTrainedModel
from transformers import AutoModel, AutoTokenizer, MT5Tokenizer, GPT2LMHeadModel, AutoConfig

ENCODER_MODEL = "bert-base-multilingual-cased"
DECODER_MODEL = "THUMT/mGPT"

# --- Encoder side: multilingual BERT and its tokenizer ---
mBERT_tokenizer = AutoTokenizer.from_pretrained(ENCODER_MODEL)
mBERT = AutoModel.from_pretrained(ENCODER_MODEL)

# --- Decoder side: mGPT loaded through AutoModel (base model, no task head),
# paired with an mT5 SentencePiece tokenizer. Key/value caching is switched off
# via the model config.
mGPT_tokenizer = MT5Tokenizer.from_pretrained(DECODER_MODEL)
mGPT = AutoModel.from_pretrained(DECODER_MODEL)
mGPT.config.use_cache = False

In [None]:
def _module_memory_bytes(module):
    """Return the number of bytes held by a module's parameters plus its buffers.

    Args:
        module: a torch ``nn.Module``.

    Returns:
        int: sum of ``nelement() * element_size()`` over ``module.parameters()``
        and ``module.buffers()``.
    """
    param_bytes = sum(p.nelement() * p.element_size() for p in module.parameters())
    buffer_bytes = sum(b.nelement() * b.element_size() for b in module.buffers())
    return param_bytes + buffer_bytes


# One loop instead of two copy-pasted stanzas; the printed lines are unchanged.
for _name, _model in (("mGPT", mGPT), ("mBERT", mBERT)):
    print(f'{_name} memory usage: {_module_memory_bytes(_model) / 1e6:.2f} MB')

In [None]:
def _describe_tokenizer(name, tokenizer, samples):
    """Print a tokenizer's vocab size, its special tokens, and how it splits each sample.

    Args:
        name: label used in the header line.
        tokenizer: a Hugging Face tokenizer.
        samples: iterable of strings to tokenize and print.
    """
    print(f"{name} Tokenizer info")
    print(f"vocab size: {len(tokenizer)}")
    print(f"special tokens: {tokenizer.all_special_tokens}")
    for text in samples:
        print(tokenizer.tokenize(text))


# Korean sentences written without spaces, tokenized by both tokenizers for comparison.
_TOKENIZER_SAMPLES = ("이순신은조선중기의무신이다", "아버지가방에들어가신다")

# Note: the mGPT tokenizer has no BOS token; one is registered in a later cell.
_describe_tokenizer("mGPT", mGPT_tokenizer, _TOKENIZER_SAMPLES)
print()  # blank line between the two reports
_describe_tokenizer("mBERT", mBERT_tokenizer, _TOKENIZER_SAMPLES)

In [None]:
# The mGPT tokenizer has no BOS token, so register "<s>" as its BOS.
# Using the `bos_token` key sets `mGPT_tokenizer.bos_token` in the same call.
# Passing `additional_special_tokens` instead can replace the tokenizer's existing
# additional special tokens (behaviour depends on the transformers version).
special_tokens_dict = {"bos_token": "<s>"}
mGPT_tokenizer.add_special_tokens(special_tokens_dict=special_tokens_dict)
# Grow the embedding matrix so the newly added token id has a row.
mGPT.resize_token_embeddings(len(mGPT_tokenizer))

In [None]:
import os
import logging
import json
from pathlib import Path

from datasets import load_dataset, DatasetDict

# Root of the ko-ja corpus. Each sub-directory (e.g. "Training", "Validation")
# holds CSV shards. Override the location with the KO_JA_DATA_DIR environment variable.
data_dir = Path(os.environ.get("KO_JA_DATA_DIR", "/opt/ml/final-project-level3-nlp-01/data/ko-ja"))

# Only directories are splits (os.listdir would also return stray files).
# Sort them so the split order is deterministic.
folder_list = sorted(p.name for p in data_dir.iterdir() if p.is_dir())

dataset_dict = dict()

for folder in folder_list:
    # Sort the shards so the row order is reproducible across runs and filesystems.
    csv_files = sorted(str(p) for p in data_dir.joinpath(folder).glob("*.csv"))
    if not csv_files:
        # Nothing for load_dataset to read in this folder; skip it.
        continue
    data = load_dataset("csv", data_files=csv_files)
    dataset_dict[folder] = data['train']  # the csv builder puts all rows in its "train" split

raw_dataset = DatasetDict(dataset_dict)
train_set = raw_dataset["Training"]
valid_set = raw_dataset["Validation"]

In [None]:
def func(examples, max_length, max_target_length, mask_label_padding=False):
    """Tokenize a batch of Korean→Japanese pairs for the mBERT encoder / mGPT decoder.

    Args:
        examples: batched ``datasets`` mapping with the '한국어' (source) and
            '일본어' (target) text columns.
        max_length: pad/truncate length for the encoder inputs.
        max_target_length: pad/truncate length for the decoder inputs.
        mask_label_padding: when True, label positions whose target is padding
            (or past the end) are set to -100 so a loss ignores them. Off by
            default because ``tokenizer.decode`` cannot decode -100.

    Returns:
        The mBERT encoding extended with 'decoder_input_ids',
        'decoder_attention_mask' and 'labels'. ``labels[i]`` is
        ``decoder_input_ids[i + 1]`` (next-token target). The final position has
        no successor and gets the pad id (or -100), so labels have the same
        length as decoder_input_ids.
    """
    model_inputs = mBERT_tokenizer(examples['한국어'], max_length=max_length, padding="max_length", truncation=True)

    # The mGPT tokenizer has no BOS of its own; prepend the registered "<s>" token.
    target_sentences = [mGPT_tokenizer.bos_token + ex for ex in examples['일본어']]
    decoder_inputs = mGPT_tokenizer(target_sentences, max_length=max_target_length, padding="max_length", truncation=True)

    model_inputs['decoder_input_ids'] = decoder_inputs['input_ids']
    model_inputs['decoder_attention_mask'] = decoder_inputs['attention_mask']

    pad_id = mGPT_tokenizer.pad_token_id
    labels = []
    for ids, attn in zip(decoder_inputs['input_ids'], decoder_inputs['attention_mask']):
        # Shift left by one and refill the last slot, so labels line up with
        # decoder_input_ids position by position (the old ids[1:] was one shorter).
        shifted_ids = ids[1:] + [pad_id]
        if mask_label_padding:
            shifted_attn = attn[1:] + [0]
            shifted_ids = [tok if keep else -100 for tok, keep in zip(shifted_ids, shifted_attn)]
        labels.append(shifted_ids)
    model_inputs['labels'] = labels

    return model_inputs

# Encoder inputs are padded/truncated to 512 tokens, decoder inputs to 1024.
fn_kwargs = {"max_length": 512, "max_target_length": 1024}

# Tokenize both splits the same way, dropping the raw text columns.
# The order is train first, then validation.
tokenized_train, tokenized_valid = (
    split.map(
        func,
        batched=True,
        num_proc=5,
        fn_kwargs=fn_kwargs,
        remove_columns=split.column_names,
    )
    for split in (train_set, valid_set)
)

In [None]:
import random
# Pick one random training example and decode its fields for a visual check.
# Draw the index from the same dataset we select from. It was previously drawn
# from len(tokenized_valid) but applied to tokenized_train.
sample_indices = random.choices(range(len(tokenized_train)), k=1)

sample = tokenized_train.select(sample_indices)
for s in sample:
    print(f"inputs: {mBERT_tokenizer.decode(s['input_ids'])}")
    print(f"decoder inputs: {mGPT_tokenizer.decode(s['decoder_input_ids'])}")
    print(f"labels: {mGPT_tokenizer.decode(s['labels'])}", "\n\n")

In [None]:
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Move both pretrained models to the chosen device and put them in eval mode.
# The trailing `;` hides the module repr the last expression would otherwise
# display; this replaces the former bare `print()`.
mBERT.to(device).eval()
mGPT.to(device).eval();

In [None]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

# Training loader: reshuffled each epoch; the final partial batch is dropped.
train_dataloader = DataLoader(
    tokenized_train,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    num_workers=5,
    pin_memory=True,
    collate_fn=default_data_collator,
)

# Validation loader: fixed order, every example kept.
valid_dataloader = DataLoader(
    tokenized_valid,
    batch_size=4,
    shuffle=False,
    drop_last=False,
    num_workers=5,
    pin_memory=True,
    collate_fn=default_data_collator,
)

In [None]:
from copy import deepcopy
# Both graft layer types take their hyperparameters from BART-base.
bart_enc_config = AutoConfig.from_pretrained(
    "facebook/bart-base"
)
# A deep copy rather than an alias, so editing the decoder-side config never
# changes `bart_enc_config` (and vice versa).
bart_dec_config = deepcopy(bart_enc_config)

In [None]:
import torch.nn as nn
from transformers.models.bart.modeling_bart import BartEncoderLayer, BartDecoderLayer

class graft_module(nn.Module):
    """Stacks of BART encoder and decoder layers meant to sit between mBERT and mGPT.

    Args:
        num_enc_layers: number of ``BartEncoderLayer`` modules to build.
        num_dec_layers: number of ``BartDecoderLayer`` modules to build.
        enc_config: config for the encoder layers. Defaults to ``BartConfig()``
            with library defaults, which may not match the encoder's hidden size.
            Prefer passing e.g. the BART-base config explicitly.
        dec_config: config for the decoder layers. Defaults to ``enc_config``.
    """

    def __init__(self, num_enc_layers=1, num_dec_layers=1, enc_config=None, dec_config=None):
        super().__init__()

        if enc_config is None:
            from transformers import BartConfig
            enc_config = BartConfig()
        if dec_config is None:
            dec_config = enc_config

        # Bart*Layer constructors require a config; the previous bare
        # `BartEncoderLayer()` call raised TypeError.
        self.graft_enc_layers = nn.ModuleList([BartEncoderLayer(enc_config) for _ in range(num_enc_layers)])
        # num_dec_layers was accepted but never used; build the decoder stack too.
        self.graft_dec_layers = nn.ModuleList([BartDecoderLayer(dec_config) for _ in range(num_dec_layers)])
        

In [None]:
from transformers import models
from transformers.models.bart.modeling_bart import _expand_mask
# Run one training batch through mBERT, mGPT and a single grafted BART
# encoder/decoder layer pair, to check that the shapes and masks line up.
train_batch = next(iter(train_dataloader))

# Size the pooler from the configs instead of hard-coding 1024 -> 768. It
# projects mGPT hidden states into the BART decoder layer's width.
decoder_pooler = nn.Linear(mGPT.config.hidden_size, bart_dec_config.d_model).to(device)
graft_enc_layer = BartEncoderLayer(bart_enc_config).to(device)
graft_dec_layer = BartDecoderLayer(bart_dec_config).to(device)

with torch.no_grad():
    train_batch = {k: t.to(device) for k, t in train_batch.items()}
    bert_output = mBERT(train_batch["input_ids"], attention_mask=train_batch["attention_mask"])
    gpt_output = mGPT(train_batch["decoder_input_ids"], attention_mask=train_batch["decoder_attention_mask"])

    enc_dtype = bert_output.last_hidden_state.dtype
    dec_dtype = gpt_output.last_hidden_state.dtype
    dec_len = train_batch["decoder_input_ids"].shape[1]

    # Encoder self-attention: additive padding mask of shape (bsz, 1, src_len, src_len).
    mask = _expand_mask(train_batch['attention_mask'], enc_dtype)
    # Cross-attention needs (bsz, 1, tgt_len, src_len). Reusing `mask`
    # (tgt_len == src_len) fails BartAttention's mask-shape check.
    cross_mask = _expand_mask(train_batch['attention_mask'], enc_dtype, tgt_len=dec_len)

    graft_dec_input = decoder_pooler(gpt_output.last_hidden_state)

    # Decoder self-attention: _expand_mask only covers padding. Also block
    # attention to future positions (j > i) so the decoder layer is causal.
    dec_mask = _expand_mask(train_batch["decoder_attention_mask"], dec_dtype)
    future = torch.ones(dec_len, dec_len, dtype=torch.bool, device=dec_mask.device).triu(1)
    dec_mask = dec_mask.masked_fill(future, torch.finfo(dec_dtype).min)
    print(dec_mask.shape)

    graft_enc_output = graft_enc_layer(bert_output.last_hidden_state, attention_mask=mask, layer_head_mask=None)[0]
    # Cross-attend to the grafted encoder output; it was previously computed but unused.
    graft_dec_output = graft_dec_layer(hidden_states=graft_dec_input, attention_mask=dec_mask,
                                       encoder_hidden_states=graft_enc_output, encoder_attention_mask=cross_mask,
                                       use_cache=False)


torch.cuda.empty_cache()


In [None]:
# Display the (encoder, decoder) last-hidden-state shapes as one tuple.
(bert_output.last_hidden_state.shape, gpt_output.last_hidden_state.shape)