In [13]:
import torch
import torch.nn as nn
from torch import optim

import sys

In [5]:
# local imports 
sys.path.append("..")

from transformer_model import generate_text, train_model
from data import ArticlesIter, articles_filepath

## Setup

In [6]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Load Vocab

In [7]:
SAVE_PATH = "../transformer_model/vocab_save/" + "embeddings_vocab.pt"

# The vocab is presumably a pickled vocab object (it supports len() and is
# passed as `vocab=` to generate_text), not a plain tensor state_dict.
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which refuses to
# unpickle arbitrary classes, so opt out explicitly (the keyword itself needs
# PyTorch >= 1.13).
# NOTE: weights_only=False runs arbitrary pickled code - only load files you
# created yourself or otherwise trust.
emb_vocab_obj = torch.load(SAVE_PATH, weights_only=False)

print("Vocab size:", len(emb_vocab_obj))

Vocab size: 400004


### Load Model

In [8]:
SAVE_PATH = "../transformer_model/model_save/" + "glove_emb_model.pt"

# The checkpoint is a whole pickled model object (this notebook reads its
# hyperparameter attributes below and later calls .parameters() on it), not a
# state_dict. Two consequences:
#  * weights_only=False is required: PyTorch >= 2.6 defaults to True and would
#    refuse to unpickle the model class (the keyword needs PyTorch >= 1.13).
#    This executes arbitrary pickled code - only load checkpoints you trust.
#  * map_location=device remaps the stored tensors onto `device`, so a
#    checkpoint saved on a GPU still loads on a CPU-only machine, and the
#    weights already live on `device` when the optimizer is built later.
emb_model = torch.load(SAVE_PATH, map_location=device, weights_only=False)

# Label/attribute pairs; labels are pre-padded so the values line up.
loaded_hparams = [
    ("Embedding Dimension:", emb_model.embedding_dim),
    ("Hidden Dim:         ", emb_model.hidden_dim),
    ("Num Layers:         ", emb_model.num_layers),
    ("Num Heads:          ", emb_model.num_heads),
    ("Max Length:         ", emb_model.max_len),
    ("Dropout Per.:       ", emb_model.dropout_p),
]

print("Loaded hyperparameters:")
print("---------------------------")
for label, value in loaded_hparams:
    print(label, value)

Loaded hyperparameters:
---------------------------
Embedding Dimension: 300
Hidden Dim:          300
Num Layers:          2
Num Heads:           4
Max Length:          100
Dropout Per.:        0.2


## Generate Some Text
Using the vocabulary and model we just loaded, we can now have the transformer continue a short seed sequence.

In [9]:
start_seq = ["<sos>", "the", "quick", "brown", "fox"]
number_of_words = 10  # passed as max_length; the output shows this many generated words

# Echo the prompt (without the start-of-sequence token) so the generated words
# read as its continuation.
for word in start_seq:
    if word != "<sos>":
        print(word, end=" ")

# The loaded model reports dropout_p = 0.2, and its train/eval mode after
# torch.load is whatever it was when saved. generate_text's internals are not
# visible from here, so switch to eval mode (dropout off) and disable autograd
# explicitly. The no_grad context wraps the whole iteration in case
# generate_text is a lazy generator.
emb_model.eval()
with torch.no_grad():
    for word in generate_text(
        model=emb_model,
        vocab=emb_vocab_obj,
        start_seq=start_seq,
        max_length=number_of_words,
    ):
        print(word, end=" ")

the quick brown fox or one until to officials officials of about iraq countries 

## Train the Model

In [14]:
learning_rate = 3e-4  # Adam step size

# Loss and optimizer for fine-tuning every parameter of the loaded model.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(emb_model.parameters(), lr=learning_rate)

In [12]:
article_iter = ArticlesIter(batch_size=1000)

# Nothing guarantees the model is in train mode here: its mode after torch.load
# is whatever it was when saved, and the generation step may have switched it to
# eval. train_model's internals are not visible from this notebook, so re-enable
# dropout explicitly before updating weights.
emb_model.train()

# NOTE(review): the captured output of this cell is a full DataFrame per batch,
# presumably printed from inside ArticlesIter or train_model - consider
# silencing it so the notebook output stays readable.
for batch in article_iter:
    train_model(
        model=emb_model,
        optimizer=optimizer,
        criterion=criterion,
        data=batch,
        device=device,
    )


     Unnamed: 0                                        ID  \
0             0  f49ee725a0360aa6881ed1f7999cc531885dd06a   
1             1  808fe317a53fbd3130c9b7563341a7eea6d15e94   
2             2  98fd67bd343e58bc4e275bbb5a4ea454ec827c0d   
3             3  e12b5bd7056287049d9ec98e41dbb287bd19a981   
4             4  b83e8bcfcd51419849160e789b6658b21a9aedcd   
..          ...                                       ...   
995         995  00773877bc7719de3f57057eee2cc600a3d60a19   
996         996  8c991735dbe909123fcb6eee51343f550d3a55af   
997         997  7e68a105ff1d93ba0ebb9651f9ec16242600b363   
998         998  f3310e85dc96de5b25a201b14a98600da0561e6b   
999         999  7b490a0d7dd91dc489d0c103413dfa788bd04332   

                                               Content  \
0    New York police are concerned drones could bec...   
1    By . Ryan Lipman . Perhaps Australian porn sta...   
2    This was, Sergio Garcia conceded, much like be...   
3    An Ebola outbreak that began i