# BARTモデルの学習

In [1]:
from transformers import PreTrainedTokenizerFast
# Load the trained tokenizer from tokenizer.json and register which vocabulary
# strings act as BOS / EOS / UNK / PAD.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="tokenizer.json",
    bos_token="<s>",
    eos_token="</s>",
    unk_token="<unk>",
    pad_token="<pad>",
)

In [2]:
from transformers import BartForConditionalGeneration, BartConfig
import json

# Start from the facebook/bart-base architecture hyper-parameters and override
# every vocabulary-dependent field so it matches our own tokenizer.
with open("config_facebook_bart_base.json", "r") as config_file:
    config_facebook_bart_base = json.load(config_file)

# Drop metadata entries that are not model hyper-parameters. pop(..., None)
# tolerates config files that lack one of these keys (del would raise KeyError).
for metadata_key in ('_name_or_path', 'task_specific_params', 'transformers_version'):
    config_facebook_bart_base.pop(metadata_key, None)

# len(tokenizer) also counts tokens added on top of the base vocabulary (e.g. special
# tokens that are missing from tokenizer.json). tokenizer.vocab_size does not, and
# would leave such ids outside the embedding matrix.
config_facebook_bart_base['vocab_size'] = len(tokenizer)
config_facebook_bart_base['bos_token_id'] = tokenizer.bos_token_id
config_facebook_bart_base['forced_bos_token_id'] = tokenizer.bos_token_id
config_facebook_bart_base['eos_token_id'] = tokenizer.eos_token_id
config_facebook_bart_base['forced_eos_token_id'] = tokenizer.eos_token_id
config_facebook_bart_base['pad_token_id'] = tokenizer.pad_token_id
# BART convention: the decoder starts generating from </s>.
config_facebook_bart_base['decoder_start_token_id'] = tokenizer.eos_token_id

config = BartConfig(**config_facebook_bart_base)

# Randomly initialised model (no pretrained weights are loaded here).
model = BartForConditionalGeneration(config)

In [3]:
# This function is copied from modeling_bart.py
def shift_tokens_right(input_ids, pad_token_id):
    """Build decoder inputs by shifting ``input_ids`` one position to the right.

    The last token of every row is dropped. The freed first position is filled
    with that row's last non-pad token (usually ``</s>``). The input tensor is
    not modified; a new tensor is returned.

    Args:
        input_ids: 2-D tensor of token ids, one sequence per row.
            NOTE(review): assumes right padding, since the last non-pad token is
            located by counting non-pad tokens.
        pad_token_id: id used for padding.

    Returns:
        A tensor with the same shape and dtype as ``input_ids``.
    """
    # Count the real (non-pad) tokens per row; that count minus one is the index
    # of the last real token.
    last_real_pos = input_ids.ne(pad_token_id).sum(dim=1).sub(1).unsqueeze(-1)
    wrapped_token = input_ids.gather(1, last_real_pos).squeeze()

    shifted = input_ids.clone()
    shifted[:, 1:] = input_ids[:, :-1]
    shifted[:, 0] = wrapped_token
    return shifted

In [4]:
def convert_to_features(example_batch, max_length=512):
    """Tokenize a batch of (text, phoneme_text) pairs into BART training features.

    Intended for ``datasets.Dataset.map(convert_to_features, batched=True)``.

    Args:
        example_batch: batch dict that must contain the ``'text'`` (source) and
            ``'phoneme_text'`` (target) columns.
        max_length: every source and target sequence is truncated/padded to exactly
            this many tokens.

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``decoder_input_ids`` and
        ``labels``. Padding positions in ``labels`` are set to -100 so the loss
        ignores them.
    """
    # padding="max_length" replaces the deprecated pad_to_max_length=True.
    # truncation=True is required: without it, sequences longer than max_length
    # are left as-is and the ragged batch cannot be turned into a tensor.
    input_encodings = tokenizer.batch_encode_plus(example_batch['text'],
                                                  padding="max_length",
                                                  truncation=True,
                                                  max_length=max_length,
                                                  return_tensors="pt")
    target_encodings = tokenizer.batch_encode_plus(example_batch['phoneme_text'],
                                                   padding="max_length",
                                                   truncation=True,
                                                   max_length=max_length,
                                                   return_tensors="pt")

    labels = target_encodings['input_ids']
    # Build decoder inputs BEFORE masking labels: shift_tokens_right locates the
    # last real token via pad_token_id, which would no longer match after -100.
    decoder_input_ids = shift_tokens_right(labels, tokenizer.pad_token_id)
    labels[labels == tokenizer.pad_token_id] = -100

    encodings = {
        'input_ids': input_encodings['input_ids'],
        'attention_mask': input_encodings['attention_mask'],
        'decoder_input_ids': decoder_input_ids,
        'labels': labels
    }

    return encodings

In [13]:
from datasets import Dataset
# Number of examples actually used for this run; set to None to use the whole corpus.
N_SAMPLES = 5000
# Fixed seed so the train/test split is reproducible across runs.
SPLIT_SEED = 42

dataset_all = Dataset.from_json("all.json")

# Choose the subset of the data actually used here. select() keeps the original
# features instead of round-tripping rows through a Python dict.
n_used = len(dataset_all) if N_SAMPLES is None else min(N_SAMPLES, len(dataset_all))
dataset_sub = dataset_all.select(range(n_used))
dataset_dict = dataset_sub.train_test_split(test_size=0.1, seed=SPLIT_SEED)
dataset_dict = dataset_dict.map(convert_to_features, batched=True)

columns = ['input_ids', 'labels', 'decoder_input_ids', 'attention_mask']
dataset_dict.set_format(type='torch', columns=columns)

Map:   0%|          | 0/4500 [00:00<?, ? examples/s]



Map:   0%|          | 0/500 [00:00<?, ? examples/s]

In [6]:
# Persist the tokenized train/test splits so they can be reloaded without
# re-running the tokenization step.
dataset_dir = "dataset_dict"
dataset_dict.save_to_disk(dataset_dir)

Saving the dataset (0/1 shards):   0%|          | 0/4500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

In [7]:
from transformers import Trainer, TrainingArguments

# Training hyper-parameters: output/log locations first, then optimisation
# settings, then batch sizes.
training_args = TrainingArguments(
    output_dir='./models/g2p_prosody_bart',
    logging_dir='./logs',
    num_train_epochs=50,
    warmup_steps=500,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1,
)

# NOTE(review): no evaluation strategy is set above. If the library default
# ("no") applies, eval_dataset is only used by explicit trainer.evaluate()
# calls. Confirm this is intended.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict['train'],
    eval_dataset=dataset_dict['test'],
)

In [8]:
# Train from scratch. Passing a checkpoint path (or True) instead of None would
# resume an interrupted run.
trainer.train(resume_from_checkpoint=None)

Step,Training Loss
500,3.4987
1000,1.9172
1500,1.7399
2000,1.6065
2500,1.4364


KeyboardInterrupt: 

In [9]:
# Pick one corpus sentence to sanity-check the model. Index 234 is inside the
# subset used for training/evaluation, so it may be a training example.
sample_index = 234
sample_row = dataset_all[sample_index]
input_text = sample_row['text']
print(input_text)


行なった実験における個々の乳児の


In [10]:
# Move the encoded ids to whatever device the model lives on instead of
# hard-coding "cuda", so this cell also works on CPU-only machines.
input_ids = tokenizer.encode(input_text, return_tensors="pt").to(model.device)
print(input_ids)


tensor([[   0, 2583, 3770, 3578, 3972,  366, 4384,  267,  423,  116,    2]],
       device='cuda:0')


In [11]:
# The Trainer leaves the model in training mode (and the run above was interrupted).
# generate() only disables gradients; it does not switch modes. Turn dropout off
# explicitly so beam search is deterministic.
model.eval()
output_ids = model.generate(input_ids, max_length=1024, num_beams=5, early_stopping=True)
print(output_ids)


tensor([[   2,    0,   11,   23,   27,    9,   26,   27,    4,   35,   27,    9,
         3303,   23,   13, 3303,   23,   21,   10, 3303,   31,   17,    4,   22,
           21,    9, 3303, 3305,   32,    4,   22,   27,   10,   27,   26,   27,
            5,    2]], device='cuda:0')


In [12]:
# Only one sequence was generated, so take row 0. Special tokens (<s>, </s>,
# <pad>, ...) are stripped from the decoded string.
best_sequence = output_ids[0]
output_text = tokenizer.decode(best_sequence, skip_special_tokens=True)
print(output_text)

^ k o [ n o # y o [ cl k a cl k i ] cl t e # j i [ cl ts u # j o ] o n o $
