In [1]:
import numpy as np
import os
from embeddings import WordEmbeddingManager, create_embedding_dataloader
import embeddings
import utils
from gan import Generator, Discriminator, train as train_gan
import torch

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/johnhenryrudden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Path of the raw training text handed to utils.process_training_data.
TRAIN_PATH = 'data/raw_train.txt'
# Number of sentences kept from the start of the processed data.
MAX_SENTENCES = 500

# Only use the first MAX_SENTENCES sentences for now.
# Single assignment keeps this cell idempotent when it is re-run.
tokenized_sentences = utils.process_training_data(TRAIN_PATH)[:MAX_SENTENCES]

In [3]:
# The longest sentence determines the sequence length used by the
# dataloader, the models and the training call further down.
longest_sentence = max(tokenized_sentences, key=len)
SEQ_LENGTH = len(longest_sentence)
print(f'Longest sentence has {len(longest_sentence)} tokens')

Longest sentence has 17 tokens


In [4]:
# Location of the saved word2vec model on disk.
WORD2VEC_MODEL_PATH = 'data/word2vec.model'
# The manager is used below for one-hot encoding/decoding and for mapping
# generated indices back to words.
word2vec_manager = WordEmbeddingManager(WORD2VEC_MODEL_PATH)

Model loaded successfully from data/word2vec.model


In [5]:
# Build the loader that yields one-hot encoded batches of the sentences.
dataloader = create_embedding_dataloader(
    tokenized_sentences,
    word2vec_manager,
    seq_length=SEQ_LENGTH,
    batch_size=4,
    encoding_method="one_hot",
    verbose=True,  # prints each sentence as it is encoded
)

# Double-check that one-hot encoding and decoding work as expected

In [6]:
# Pull a single batch from the loader and display it.
batch_iterator = iter(dataloader)
batch = next(batch_iterator)
batch

Encoding sentence: ['Such', 'a', 'nature', ',', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'], with encoding method: one_hot
Encoding sentence: ['First', 'Citizen', ':', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'], with encoding method: one_hot
Encoding sentence: ['I', 'speak', 'from', '<UNK>', '.', 'Nay', ',', 'more', ',', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'], with encoding method: one_hot
Encoding sentence: ['Nay', ',', 'let', 'them', 'follow', ':', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>'], with encoding method: one_hot


tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 1.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 1., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 

In [7]:
# Decode the first sentence of the batch, one token vector at a time.
encoded_first_sentence = batch[0]
decoded_first_sentence = list(map(word2vec_manager.decode_one_hot, encoded_first_sentence))
decoded_first_sentence


['Such',
 'a',
 'nature',
 ',',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>']

In [8]:
# Hand-written sample: nine tokens followed by eleven padding tokens.
sentence_to_encode = [
    'That', 'you', 'have', "ta'en", 'a', 'tardy', '<UNK>', 'here', '.',
] + ['<PAD>'] * 11
# One-hot encode every token individually.
encoded_sentence = [word2vec_manager.one_hot_encode(token) for token in sentence_to_encode]
encoded_sentence

[tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([0., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.]),
 tensor([1., 0., 0.,  ..., 0., 0., 0.])]

In [9]:
# Map every one-hot vector back to its token.
decoded = [word2vec_manager.decode_one_hot(one_hot) for one_hot in encoded_sentence]
# Fail loudly if the encode/decode round trip does not reproduce the input,
# instead of relying on comparing the two printed lists by eye.
assert decoded == sentence_to_encode, 'one-hot round trip changed the sentence'
decoded

['That',
 'you',
 'have',
 "ta'en",
 'a',
 'tardy',
 '<UNK>',
 'here',
 '.',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>',
 '<PAD>']

# Start model training

In [10]:
# Initialize models

# Seed the global torch and numpy RNGs before the models are built, so that
# randomness drawn from them (here and in later cells run in order) is
# repeatable after a kernel restart.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

seq_length = SEQ_LENGTH

# Generator first

# Feature width of the noise tensor fed to the generator.
gen_input_dim = embeddings.EMBEDDING_SIZE
gen_hidden_dim = 300

# add 1 to output dim to account for padding token
# NOTE(review): this reaches into the manager's private `_model` attribute;
# consider exposing the vocabulary size through the manager instead.
gen_output_dim = len(word2vec_manager._model.wv.key_to_index) + 1

generator = Generator(input_size=gen_input_dim, hidden_size=gen_hidden_dim, output_size=gen_output_dim, seq_length=seq_length)


# Discriminator
# Discriminator input is the same as the generator output (the generated next token probability distribution)
discrim_input_dim = gen_output_dim
discrim_hidden_dim = 100

discriminator = Discriminator(input_dim=discrim_input_dim, hidden_dim=discrim_hidden_dim, seq_length=seq_length)

In [11]:
# Report how many scalar parameters each model holds.
discrim_params = list(discriminator.parameters())
gen_params = list(generator.parameters())
# Tensor.numel() returns the element count as a plain int. np.prod over a
# torch.Size would yield float 1.0 for a 0-dim parameter, and the generator
# expressions avoid building throwaway lists.
num_params_gen = sum(p.numel() for p in gen_params)
num_params_discrim = sum(p.numel() for p in discrim_params)
print(f'Generator has {num_params_gen} parameters')
print(f'Discriminator has {num_params_discrim} parameters')

Generator has 2485755 parameters
Discriminator has 2782901 parameters


In [12]:
# Train models

# Hyperparameters
num_epochs = 2
batch_size = 4
learning_rate = 1e-3
temperature = 1.0

# NOTE(review): `device` is selected here but is not passed to train_gan in
# this cell — confirm whether training is meant to run on it.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

# Run adversarial training; the call returns a (generator, discriminator)
# pair of average losses, unpacked below.
training_losses = train_gan(
    generator=generator,
    discriminator=discriminator,
    tokenized_sentences=tokenized_sentences,
    word2vec_manager=word2vec_manager,
    seq_length=SEQ_LENGTH,
    generator_input_features=gen_input_dim,
    num_epochs=num_epochs,
    batch_size=batch_size,
    learning_rate=learning_rate,
    temperature=temperature,
    encoding_method="one_hot",
    debug=False,
)
avg_g_loss, avg_d_loss = training_losses



Epoch 1/2: 100%|██████████| 125/125 [00:13<00:00,  9.56it/s]
Epoch 2/2: 100%|██████████| 125/125 [00:13<00:00,  9.57it/s]


In [13]:
# Number of sentences to sample from the trained generator.
NUM_SAMPLES = 5

# Inference only: disable autograd so no computation graph is recorded
# while sampling.
with torch.no_grad():
    for _ in range(NUM_SAMPLES):
        # One noise sequence per sample: (1, seq_length, gen_input_dim).
        noise = torch.randn(1, seq_length, gen_input_dim)
        generated_data = generator(noise, temperature, should_sample=True)
        # Translate each generated vocabulary index back into its word.
        generated_sentence = [word2vec_manager.index_to_word(word_index) for word_index in generated_data]
        print(generated_sentence)

['fort', 'characters', 'HENRY', 'cuckolds', 'parliament', 'turn', 'Mars', 'appellant', 'elbow', 'prophecies', 'beggars', 'Conspirators', 'smock', 'Caius', 'shift', 'utterance', 'Pardon']
['Seeking', 'spoke', 'enrich', "'Whoop", 'Westminster', 'fool', 'meet', 'suns', 'helping', 'transport', 'passion', "'shall", 'harder', 'mask', "'love", 'vouches', "Thou'rt"]
['thy', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['ancient', 'CORIOLANUS', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
['Stabb', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
