In [1]:
import numpy as np 
import pandas as pd 
import torch
from data_processing import generate_vocab, process_data, create_dataloaders
from model import get_pretrained_emb, EncoderRNN, DecoderRNN, DecoderAttnRNN, EncoderDecoder, EncoderDecoderAttn, EncoderCNN, EncoderCNN2, Decoder_RNN_from_CNN, CNN_RNN_EncoderDecoder 
from train_eval import evaluate, train_and_eval, summarize_results, plot_single_learning_curve, load_experiment_log
import pickle as pkl 
from datetime import datetime
# Pick the torch device once: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Language codes for the two sides of the translation pair.
SRC_LANG, TARG_LANG = 'zh', 'en'

# Per-side sentence-length and vocabulary-size settings.
SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN = 10, 10
SRC_VOCAB_SIZE, TARG_VOCAB_SIZE = 30000, 30000

# Batch size handed to the dataloaders.
BATCH_SIZE = 64

# Building the vocab takes a long time, so it is cached as a pickle next to the notebook:
# reload it when the file exists, otherwise build it once and save it for future runs.
# NOTE(security): pickle.load can execute arbitrary code -- only load a vocab file that
# was produced by this project, never one from an untrusted source.
vocab_filename = "{}-{}-vocab.p".format(SRC_LANG, TARG_LANG)
try:
    # `with` guarantees the file handle is closed even if unpickling fails.
    with open(vocab_filename, "rb") as vocab_file:
        vocab = pkl.load(vocab_file)
except FileNotFoundError:
    vocab = generate_vocab(SRC_LANG, TARG_LANG, SRC_VOCAB_SIZE, TARG_VOCAB_SIZE)
    with open(vocab_filename, "wb") as vocab_file:
        pkl.dump(vocab, vocab_file)

# Process the corpus three times with the same settings; only `sample_limit` differs.
# NOTE(review): `sample_limit` presumably caps the number of sentence pairs kept --
# confirm against data_processing.process_data.
data = process_data(SRC_LANG, TARG_LANG, SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN, vocab, filter_long=False)
data_minibatch = process_data(SRC_LANG, TARG_LANG, SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN, vocab,
                              sample_limit=BATCH_SIZE, filter_long=False)
data_minitrain = process_data(SRC_LANG, TARG_LANG, SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN, vocab,
                              sample_limit=1000, filter_long=False)

# create dataloaders (one set per processed dataset; indexed later by split name, e.g. 'dev' / 'test')
loaders_full = create_dataloaders(data, SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN, BATCH_SIZE)
loaders_minibatch = create_dataloaders(data_minibatch, SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN, BATCH_SIZE)
loaders_minitrain = create_dataloaders(data_minitrain, SRC_MAX_SENTENCE_LEN, TARG_MAX_SENTENCE_LEN, BATCH_SIZE)

# --- model architecture params ---
NETWORK_TYPE, RNN_CELL_TYPE = 'cnn', 'gru'
NUM_LAYERS = 1
# The decoder hidden size is tied to the encoder hidden size.
ENC_HIDDEN_DIM = 512
DEC_HIDDEN_DIM = ENC_HIDDEN_DIM
TEACHER_FORCING_RATIO, CLIP_GRAD_MAX_NORM = 1, 1
# Both dropout values are 0 (0.2 is the alternative the author left commented).
ENC_DROPOUT, DEC_DROPOUT = 0, 0
ATTENTION_TYPE = 'additive'

# --- training params ---
NUM_EPOCHS = 10   # alternative left by the author: 5
LR = 0.0003       # alternative left by the author: 0.0005
OPTIMIZER, LAZY_TRAIN = 'Adam', False

# name the model and experiment
EXPERIMENT_NAME = 'zh_final'
if NETWORK_TYPE == 'rnn':
    MODEL_NAME = '{}-rnn-{}-attn'.format(SRC_LANG, ATTENTION_TYPE)
elif NETWORK_TYPE == 'cnn':
    MODEL_NAME = '{}-cnn'.format(SRC_LANG)
else:
    # Without this branch an unsupported NETWORK_TYPE would leave MODEL_NAME undefined
    # (NameError further down) or, in a live kernel, silently reuse a stale name from
    # an earlier run. Fail here with a clear message instead.
    raise ValueError("Unsupported NETWORK_TYPE: {!r} (expected 'rnn' or 'cnn')".format(NETWORK_TYPE))

# Snapshot of every setting above, collected in one dict so it can be saved with the results later.
params = {
    # identification
    'experiment_name': EXPERIMENT_NAME,
    'model_name': MODEL_NAME,
    'src_lang': SRC_LANG,
    'targ_lang': TARG_LANG,
    # data / architecture
    'rnn_cell_type': RNN_CELL_TYPE,
    'src_max_sentence_len': SRC_MAX_SENTENCE_LEN,
    'targ_max_sentence_len': TARG_MAX_SENTENCE_LEN,
    'src_vocab_size': SRC_VOCAB_SIZE,
    'targ_vocab_size': TARG_VOCAB_SIZE,
    'num_layers': NUM_LAYERS,
    'enc_hidden_dim': ENC_HIDDEN_DIM,
    'dec_hidden_dim': DEC_HIDDEN_DIM,
    'teacher_forcing_ratio': TEACHER_FORCING_RATIO,
    'clip_grad_max_norm': CLIP_GRAD_MAX_NORM,
    'enc_dropout': ENC_DROPOUT,
    'dec_dropout': DEC_DROPOUT,
    'attention_type': ATTENTION_TYPE,
    # training
    'batch_size': BATCH_SIZE,
    'num_epochs': NUM_EPOCHS,
    'learning_rate': LR,
    'optimizer': OPTIMIZER,
    'lazy_train': LAZY_TRAIN,
}


# instantiate model

# Encoder: EncoderCNN fed with the embedding data returned by get_pretrained_emb for the
# source-language vocab. Sentence length and dropout come from the config constants
# (previously the literals 10 and 0), so the model always matches what `params` records.
encoder = EncoderCNN(
    pretrained_word2vec=get_pretrained_emb(vocab[SRC_LANG]['word2vec'], vocab[SRC_LANG]['token2id']),
    src_max_sentence_len=SRC_MAX_SENTENCE_LEN,
    dropout=ENC_DROPOUT,
    enc_hidden_dim=params['enc_hidden_dim'])

# Decoder: Decoder_RNN_from_CNN fed with the target-language embedding data.
# NOTE(review): DEC_DROPOUT is not passed here -- confirm whether Decoder_RNN_from_CNN
# accepts a dropout argument.
decoder = Decoder_RNN_from_CNN(
    dec_hidden_dim=params['dec_hidden_dim'],
    enc_hidden_dim=params['enc_hidden_dim'],
    num_layers=NUM_LAYERS,
    targ_vocab_size=TARG_VOCAB_SIZE,
    targ_max_sentence_len=TARG_MAX_SENTENCE_LEN,
    batch_size=BATCH_SIZE,
    pretrained_word2vec=get_pretrained_emb(vocab[TARG_LANG]['word2vec'], vocab[TARG_LANG]['token2id']))

# Wrap both halves into one module and move it to the selected device.
model = CNN_RNN_EncoderDecoder(encoder, decoder, vocab[TARG_LANG]['token2id']).to(device)

In [2]:
# Load the saved checkpoint for the named model (mapped onto the active device) and
# apply it to `model` as a state dict.
MODEL_NAME_TO_RELOAD = 'zh-cnn'
checkpoint_path = 'model_checkpoints/{}.pth.tar'.format(MODEL_NAME_TO_RELOAD)
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint)

In [None]:
# Load the log for this experiment and display the first row of its summary,
# restricted to the columns of interest.
experiment_results = load_experiment_log(experiment_name=EXPERIMENT_NAME)

summary_columns = [
    'model_name', 'best_val_loss', 'best_val_bleu', 'runtime',
    'total_params', 'trainable_params', 'dt_created',
]
summarize_results(experiment_results)[summary_columns].head(1)

In [3]:
# check performance on validation set
# evaluate() returns nine values; capture them first, then unpack into named results.
val_outputs = evaluate(
    model=model,
    loader=loaders_full['dev'],
    src_id2token=vocab[SRC_LANG]['id2token'],
    targ_id2token=vocab[TARG_LANG]['id2token'],
)
(val_loss, val_bleu,
 val_hyp_idxs, val_ref_idxs, val_source_idxs,
 val_hyp_tokens, val_ref_tokens, val_source_tokens,
 val_attn) = val_outputs
print("Validation BLEU: {:.2f} | Validation Loss: {:.2f}".format(val_bleu, val_loss))



Validation BLEU: 5.73 | Validation Loss: 5.80


In [4]:
# print predictions on val data
# One labelled line each for source, reference and hypothesis, then a blank separator line.
templates = ("SOURCE: {}", "REFERENCE: {}", "HYPOTHESIS: {}")
for example in zip(val_source_tokens, val_ref_tokens, val_hyp_tokens):
    for template, tokens in zip(templates, example):
        print(template.format(' '.join(tokens)))
    print()

SOURCE: 我 11 岁 那年 记得 得有 一天 早晨 醒来 听见
REFERENCE: when i was 11 , i remember waking up
HYPOTHESIS: <SOS> i i want to , the i &apos;t ,

SOURCE: 我 的 父亲 在 用 他 的 灰色 小 收音
REFERENCE: my father was listening to bbc news on his
HYPOTHESIS: <SOS> i i is to to it , . i

SOURCE: 他 面带 <UNK> <UNK> 笑容 这 很少 少见 因为 大部
REFERENCE: there was a big smile on his face which
HYPOTHESIS: <SOS> and &apos;s a <UNK> of , and , and

SOURCE: 塔利 塔利班 走 了 父亲 大声 叫 着 <EOS> <PAD>
REFERENCE: &quot; the taliban are gone ! &quot; my father
HYPOTHESIS: <SOS> and , , is is going . . .

SOURCE: 我 不知 知道 那 意味 意味着 什么 但是 我 能
REFERENCE: i didn &apos;t know what it meant , but
HYPOTHESIS: <SOS> i i &apos;t know to he was like .

SOURCE: 你 现在 可以 去 个 真正 的 学校 念书 了
REFERENCE: &quot; you can go to a real school now
HYPOTHESIS: <SOS> you you can , that the the , ,

SOURCE: 我 永远 不会 忘记 那个 早晨 <EOS> <PAD> <PAD> <PAD>
REFERENCE: a morning that i will never forget . <EOS>
HYPOTHESIS: <SOS> i i i i was to to to <EOS>

SOURCE: 一个 真正 的 学校 <EOS> <PAD


SOURCE: 它们 帮 我们 记录 保存 回忆 和 我们 的 过去
REFERENCE: they &apos;re our <UNK> and our histories , the
HYPOTHESIS: <SOS> they &apos;re us to , , , , ,

SOURCE: 这 就是 这个 项目 目的 全部 为了 恢复 人性 中
REFERENCE: that &apos;s all this project was about , about
HYPOTHESIS: <SOS> this &apos;s what he of , the the ,

SOURCE: 当 这样 一张 照片 重回 它 的 主人 人身 身边
REFERENCE: when a photo like this can be returned to
HYPOTHESIS: <SOS> and when &apos;re of , , , , ,

SOURCE: 这个 项目 也 为 我们 这些 修 图 师 带来
REFERENCE: the project &apos;s also made a big difference in
HYPOTHESIS: <SOS> and this is that , we , , ,

SOURCE: 对 一些 修 图 师 这段 经历 为 他们 建立
REFERENCE: for some of them , it &apos;s given them
HYPOTHESIS: <SOS> and , , , , , &apos;s not to

SOURCE: 我 想 读 给 大家 一封 封电子邮件 电子 电子邮件 邮件
REFERENCE: i would like to conclude by reading an email
HYPOTHESIS: <SOS> i i to to show it the to and

SOURCE: 当 我 在 修复 那些 照片 的 时候 我 不禁
REFERENCE: &quot; as i worked , i couldn &apos;t help
HYPOTHESIS: <SOS> when i , i to , i to get

SOURCE: 其中 特别 是 一张 照

REFERENCE: he was trembling when our boat approached , frightened
HYPOTHESIS: <SOS> when we , to , were , , ,

SOURCE: 他 很 害怕 会 掉 到 水里 <EOS> <PAD> <PAD>
REFERENCE: he was petrified he would be knocked in the
HYPOTHESIS: <SOS> he &apos;s , , <EOS> <EOS> . . <EOS>

SOURCE: <UNK> 下 浸 着 些 树 的 枝干 经常 常会
REFERENCE: the skeletal tree limbs submerged in lake volta often
HYPOTHESIS: <SOS> the <UNK> is to , , the , in

SOURCE: 很多 人 都 淹死 了 <EOS> <PAD> <PAD> <PAD> <PAD>
REFERENCE: many of them drown . <EOS> <PAD> <PAD> <PAD>
HYPOTHESIS: <SOS> and , are are bi . <EOS> . .

SOURCE: 从 他 记事 开始 就 被迫 迫在 <UNK> 上工 工作
REFERENCE: for as long as he can recall , he
HYPOTHESIS: <SOS> and , , he , , , , ,

SOURCE: 他 非常 害怕 主人 不敢 逃跑 由于 他 从小 就
REFERENCE: terrified of his master , he will not run
HYPOTHESIS: <SOS> he &apos;s &apos;t , , , he , to

SOURCE: 我 在 早上 五点 时 看到 这些 男孩 男孩子 孩子
REFERENCE: i met these boys at five in the morning
HYPOTHESIS: <SOS> i was in in in in , in ,

SOURCE: 在 这样 寒冷 <UNK> 的 晚上 <EOS> <PAD> <

In [5]:
# check performance on test set
# evaluate() returns nine values; capture them first, then unpack into named results.
test_outputs = evaluate(
    model=model,
    loader=loaders_full['test'],
    src_id2token=vocab[SRC_LANG]['id2token'],
    targ_id2token=vocab[TARG_LANG]['id2token'],
)
(test_loss, test_bleu,
 test_hyp_idxs, test_ref_idxs, test_source_idxs,
 test_hyp_tokens, test_ref_tokens, test_source_tokens,
 test_attn) = test_outputs
print("Test BLEU: {:.2f} | Test Loss: {:.2f}".format(test_bleu, test_loss))



Test BLEU: 6.45 | Test Loss: 5.60
