In [1]:
import torch
from config import get_model_config, get_trainer_config
from model.utils import load_openai_weights, set_seed
from model.transformer_model import TransformerModel
import random
from model.trainer import Trainer
from model.text import BPEVocab
from model.dataset import FacebookDataset

# --- Configuration ---------------------------------------------------------
model_config = get_model_config()
trainer_config = get_trainer_config()

# Seeding happens before the model, datasets and trainer are constructed.
set_seed(trainer_config.seed)
device = torch.device(trainer_config.device)

# --- Vocabulary ------------------------------------------------------------
vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                            model_config.bpe_codes_path)

# --- Model -----------------------------------------------------------------
# Four arguments come from the vocabulary; every other keyword argument is
# forwarded from the model config attribute of the same name.
transformer = TransformerModel(
    n_embeddings=len(vocab),
    padding_idx=vocab.pad_id,
    bos_id=vocab.bos_id,
    eos_id=vocab.eos_id,
    **{field: getattr(model_config, field)
       for field in ('n_layers',
                     'n_pos_embeddings',
                     'embeddings_size',
                     'n_heads',
                     'dropout',
                     'embed_dropout',
                     'attn_dropout',
                     'ff_dropout',
                     'max_seq_len',
                     'beam_size',
                     'length_penalty',
                     'n_segments',
                     'annealing_topk',
                     'annealing',
                     'diversity_coef',
                     'diversity_groups')})

# --- Datasets --------------------------------------------------------------
# Train split is built first, then test; both receive the same vocabulary and
# the same third argument (number of position embeddings minus one).
train_dataset, test_dataset = (
    FacebookDataset(dataset_paths, vocab, transformer.n_pos_embeddings - 1)
    for dataset_paths in (trainer_config.train_datasets,
                          trainer_config.test_datasets)
)

# --- Trainer ---------------------------------------------------------------
# Optimisation settings are forwarded from the trainer config attribute of the
# same name; device and ignored token ids are supplied explicitly.
model_trainer = Trainer(
    transformer,
    train_dataset,
    test_dataset,
    device=device,
    ignore_idxs=vocab.special_tokens_ids,
    **{field: getattr(trainer_config, field)
       for field in ('batch_size',
                     'batch_split',
                     'lr',
                     'lr_warmup',
                     'lm_weight',
                     'risk_weight',
                     'n_jobs',
                     'clip_grad')})

# --- Checkpoint ------------------------------------------------------------
# The saved state is mapped onto the configured device and loaded through the
# trainer, not directly into the bare model.
ckpt = trainer_config.default_checkpoint_path
state_dict = torch.load(ckpt, map_location=device)
model_trainer.load_state_dict(state_dict)
print('Weights loaded from {}'.format(ckpt))



Weights loaded from /data/kdgyun425/transformer/checkpoints/default_checkpoint


In [2]:
def sample_text_func(epoch, n_samples=10):
    """Print random test dialogs together with the model's predicted reply.

    Draws up to ``n_samples`` distinct items from the module-level
    ``test_dataset``, runs ``model_trainer.model.predict`` on each one and
    prints the persona, the dialog history, the reference target and the
    prediction.

    Args:
        epoch: Accepted so the function fits a one-argument callback
            signature; the body does not read it.
        n_samples: Maximum number of examples to print. Clamped to the range
            ``[0, len(test_dataset)]`` so small datasets do not make
            ``random.sample`` raise ``ValueError``.

    Returns:
        None. All results are written to stdout.
    """
    # random.sample refuses to draw more items than the population holds.
    n_samples = max(0, min(n_samples, len(test_dataset)))
    samples_idxs = random.sample(range(len(test_dataset)), n_samples)
    samples = [test_dataset[idx] for idx in samples_idxs]

    model = model_trainer.model
    # NOTE(review): assumes the model follows torch.nn.Module semantics, where
    # `training` is True until eval() is called and dropout stays active in
    # that state. If the attribute is absent the mode switch is skipped.
    was_training = getattr(model, 'training', False)
    if was_training:
        model.eval()

    try:
        for persona_info, dialog, target in samples:
            # Empty contexts are dropped; each remaining id list becomes a
            # batch of size one on the trainer's device.
            contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device)
                        for c in [persona_info, dialog] if len(c) > 0]
            # Sampling never needs gradients.
            with torch.no_grad():
                prediction = model.predict(contexts)[0]

            # The first and last ids of persona and target are stripped before
            # decoding (presumably delimiter tokens -- confirm against the dataset).
            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            # Speaker-start markers become bullet prefixes; speaker-end markers are removed.
            dialog_str = dialog_str.replace(vocab.talker1_bos, '\n\t- ').replace(vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos, '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            print('\n')
            print('Persona info:\n\t{}'.format(persona_info_str))
            print('Dialog:{}'.format(dialog_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))
    finally:
        # Leave the model in the mode it was found in, so calling this between
        # training epochs does not silently disable training behaviour.
        if was_training:
            model.train()
        
# Print a round of sampled predictions; 100 is passed as the `epoch` argument.
sample_text_func(epoch=100)



Persona info:
	i 'm partly deaf . i love to drink fancy tea . i am a museum tour guide . i have a big library at home . 
Dialog:
	- soon i will be a nurse but now i 'm barmaid . crystal here . 
	- my name is rose and i grow them in my garden 
	- great ! do you have blue eyes and live with your best friend ? i do . 
Target:
	no i am deaf partly and i live in my home alone 
Prediction:
	no i do not have any 


Persona info:
	i enjoy buying new novels . i 'm a married woman . i enjoy educational films . i am having a baby and i have never given birth before . i just became an assistant last quarter . 
Dialog:
	- hello how are you today ? 
	- so tired . just laying on the couch reading my weekly book . you ? 
	- i 'm getting ready to go to a weekend concert to hear country music . 
	- that sounds fun . i wonder if i could do that while pregnant ? 
	- yes you can go to a concert . just be careful . do you like country music . 
	- i like some of it . it is popular here in pennsylvania . 
	