In [1]:
import numpy as np
import os
from embeddings import WordEmbeddingManager, create_embedding_dataloader
import embeddings
import utils
from gan import Generator, Discriminator, train as train_gan
import torch
from sklearn.model_selection import train_test_split

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/johnhenryrudden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
TRAIN_PATH = 'data/raw_train.txt'
# Tokenize the whole training file. No sentence cap is applied here: the
# previous `tokenized_sentences = tokenized_sentences` line was a no-op, so
# the full corpus has always been used. Sub-sampling for training is done
# later through TRAIN_ON_FRAC.
tokenized_sentences = utils.process_training_data(TRAIN_PATH)

In [3]:
# The longest tokenized sentence fixes the sequence length used everywhere
# else in the notebook.
longest_sentence = max(tokenized_sentences, key=len)
SEQ_LENGTH = len(longest_sentence)
print(f'Longest sentence has {SEQ_LENGTH} tokens')

Longest sentence has 20 tokens


In [4]:
# Location of the word2vec model file, relative to the notebook's working directory.
WORD2VEC_MODEL_PATH = 'data/word2vec.model'
# Shared embedding manager; later cells use it to size the vocabulary and to
# map generated token indices back to words.
word2vec_manager = WordEmbeddingManager(WORD2VEC_MODEL_PATH)

Model loaded successfully from data/word2vec.model


# Double-check that one-hot encoding and decoding work as expected

In [5]:
# Disabled sanity check (kept for reference): build a one-hot dataloader with
# batch_size=4, take the first batch, and decode every token of its first
# sentence back to a word with decode_one_hot. Uncomment to run it.
# dataloader = create_embedding_dataloader(tokenized_sentences, word2vec_manager, seq_length=SEQ_LENGTH, batch_size=4, encoding_method="one_hot", verbose=True)
# batch = next(iter(dataloader))
# print(f'first batch sentence: {batch[0]}')
# encoded_first_sentence = batch[0]
# decoded_first_sentence = [word2vec_manager.decode_one_hot(encoded_token) for encoded_token in encoded_first_sentence]
# decoded_first_sentence


In [6]:
# Disabled sanity check (kept for reference): one-hot encode a hand-written,
# <PAD>-padded token list word by word, decode it again, and display both so
# the round trip can be inspected. Uncomment to run it.
# sentence_to_encode = ['That', 'you', 'have', "ta'en", 'a', 'tardy', '<UNK>', 'here', '.', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>', '<PAD>']
# encoded_sentence = [word2vec_manager.one_hot_encode(word) for word in sentence_to_encode]
# decoded = [word2vec_manager.decode_one_hot(one_hot) for one_hot in encoded_sentence]
# encoded_sentence,decoded

# Start model training

In [7]:
# ---- Model dimensions ----
seq_length = SEQ_LENGTH

# Generator: input width is the embedding size; output width is the
# vocabulary size plus one extra slot that accounts for the padding token.
gen_input_dim = embeddings.EMBEDDING_SIZE
gen_hidden_dim = 300
gen_output_dim = 1 + len(word2vec_manager._model.wv.key_to_index)

# Discriminator: its input width equals the generator's output width, since
# it consumes the generated next-token probability distribution.
discrim_input_dim = gen_output_dim
discrim_hidden_dim = 100

# ---- Model instances (generator is built first, then the discriminator) ----
generator = Generator(
    input_size=gen_input_dim,
    hidden_size=gen_hidden_dim,
    output_size=gen_output_dim,
    seq_length=seq_length,
)
discriminator = Discriminator(
    input_dim=discrim_input_dim,
    hidden_dim=discrim_hidden_dim,
    seq_length=seq_length,
)

In [8]:
# Materialize both parameter lists once so they can be reused below.
discrim_params = list(discriminator.parameters())
gen_params = list(generator.parameters())

# Tensor.numel() gives the element count of each parameter as a plain Python
# int. The previous np.prod(p.size()) produced NumPy scalars and evaluates to
# the float 1.0 for a 0-dim parameter (empty shape), which would silently turn
# the whole total into a float.
num_params_gen = sum(p.numel() for p in gen_params)
num_params_discrim = sum(p.numel() for p in discrim_params)

print(f'Generator has {num_params_gen} parameters')
print(f'Discriminator has {num_params_discrim} parameters')

Generator has 2485755 parameters
Discriminator has 2782901 parameters


In [9]:
# ---- Training configuration ----

# Optimisation hyperparameters.
num_epochs = 1
batch_size = 4
learning_rate = 0.0001
temperature = 1.0

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Train on a fixed-seed random subset of the corpus. `train_size` is the
# corresponding sentence count; the split itself is driven by the fraction.
TRAIN_ON_FRAC = 0.1
train_size = int(TRAIN_ON_FRAC * len(tokenized_sentences))
train_sents, _ = train_test_split(
    tokenized_sentences,
    train_size=TRAIN_ON_FRAC,
    random_state=42,
)



In [10]:
def generate_sentences(generator, temp, num_sentences=10, seq_len=None, input_dim=None, manager=None):
    """Sample sentences from ``generator`` and decode them into strings.

    Args:
        generator: Callable invoked as ``generator(noise, temp, hard=False)``.
            Element ``[0]`` of its result is arg-maxed over dim 1 to pick one
            token index per sequence position.
        temp: Temperature forwarded to the generator. Previously this argument
            was ignored and the notebook-level ``temperature`` was used, so
            every configuration was sampled at the same temperature.
        num_sentences: Number of sentences to sample (default 10, as before).
        seq_len: Sequence length of the noise input; defaults to the
            notebook-level ``seq_length``.
        input_dim: Feature size of the noise input; defaults to the
            notebook-level ``gen_input_dim``.
        manager: Object exposing ``index_to_word(index)``; defaults to the
            notebook-level ``word2vec_manager``.

    Returns:
        list[str]: Space-joined decoded sentences with every ``"<PAD>"``
        occurrence removed (the surrounding spaces are left in place).
    """
    # Fall back to the notebook globals only when the caller gave no override.
    seq_len = seq_length if seq_len is None else seq_len
    input_dim = gen_input_dim if input_dim is None else input_dim
    manager = word2vec_manager if manager is None else manager

    # Create the noise on the same device as the generator's weights so that
    # sampling also works after the model was moved off the CPU. Callables
    # without parameters keep torch's default device.
    try:
        device = next(generator.parameters()).device
    except (AttributeError, StopIteration):
        device = None

    sentences = []
    # Sampling is inference only, so no autograd graph needs to be recorded.
    with torch.no_grad():
        for _ in range(num_sentences):
            noise = torch.randn(1, seq_len, input_dim, device=device)
            generated = generator(noise, temp, hard=False)
            # Most likely token per position, moved to the CPU for decoding.
            token_ids = torch.argmax(generated[0], dim=1).cpu()
            words = [manager.index_to_word(index) for index in token_ids]
            sentences.append(" ".join(words).replace("<PAD>", ""))
    return sentences

In [11]:
# Grid search: every (learning rate, batch size, temperature) combination
# trains a freshly initialised generator/discriminator pair, then samples a
# few sentences from the trained generator.
# TODO: evaluate the trained models on a validation set
# TODO: report precision, recall and F1 score for each model
results = []

# Arguments that are identical for every run of the sweep.
shared_train_args = dict(
    tokenized_sentences=train_sents,
    word2vec_manager=word2vec_manager,
    seq_length=SEQ_LENGTH,
    generator_input_features=gen_input_dim,
    num_epochs=num_epochs,
    noise_sample_method="normal",
    gumbel_hard=True,
    encoding_method="one_hot",
    device=device,
    debug=False,
)

for lr in (0.001, 0.0005, 0.0001):
    for batch_size in (4, 8, 16):
        for temp in (0.5, 1.0, 1.5):
            # Fresh models per configuration so runs do not share weights.
            generator = Generator(
                input_size=gen_input_dim,
                hidden_size=gen_hidden_dim,
                output_size=gen_output_dim,
                seq_length=seq_length,
            )
            discriminator = Discriminator(
                input_dim=discrim_input_dim,
                hidden_dim=discrim_hidden_dim,
                seq_length=seq_length,
            )
            print(f'lr: {lr}, batch_size: {batch_size}, temp: {temp}')

            g_loss, d_loss = train_gan(
                generator=generator,
                discriminator=discriminator,
                batch_size=batch_size,
                learning_rate=lr,
                temperature=temp,
                **shared_train_args,
            )

            # Record the configuration, its sample sentences and the losses
            # returned by training.
            gens = generate_sentences(generator, temp)
            results.append((lr, batch_size, temp, gens, g_loss, d_loss))


            

lr: 0.001, batch_size: 4, temp: 0.5


Epoch 1/1:   0%|          | 0/731 [00:00<?, ?it/s]

Epoch 1/1 | Generator Loss: 8.4974 | Discriminator Loss: 0.0008: 100%|██████████| 731/731 [01:40<00:00,  7.24it/s]


lr: 0.001, batch_size: 4, temp: 1.0


Epoch 1/1 | Generator Loss: 8.1949 | Discriminator Loss: 0.0008: 100%|██████████| 731/731 [01:43<00:00,  7.06it/s]


lr: 0.001, batch_size: 4, temp: 1.5


Epoch 1/1 | Generator Loss: 5.7720 | Discriminator Loss: 0.0030: 100%|██████████| 731/731 [01:44<00:00,  7.02it/s]


lr: 0.001, batch_size: 8, temp: 0.5


Epoch 1/1 | Generator Loss: 4.3503 | Discriminator Loss: 0.7030: 100%|██████████| 366/366 [02:09<00:00,  2.83it/s]


lr: 0.001, batch_size: 8, temp: 1.0


Epoch 1/1 | Generator Loss: 1.9736 | Discriminator Loss: 0.2262: 100%|██████████| 366/366 [01:40<00:00,  3.65it/s]


lr: 0.001, batch_size: 8, temp: 1.5


Epoch 1/1 | Generator Loss: 6.0809 | Discriminator Loss: 0.0031: 100%|██████████| 366/366 [01:42<00:00,  3.58it/s]


lr: 0.001, batch_size: 16, temp: 0.5


Epoch 1/1 | Generator Loss: 0.7343 | Discriminator Loss: 1.3689: 100%|██████████| 183/183 [07:20<00:00,  2.41s/it]   


lr: 0.001, batch_size: 16, temp: 1.0


Epoch 1/1 | Generator Loss: 1.5237 | Discriminator Loss: 0.6204: 100%|██████████| 183/183 [01:08<00:00,  2.67it/s]


lr: 0.001, batch_size: 16, temp: 1.5


Epoch 1/1 | Generator Loss: 1.1824 | Discriminator Loss: 0.5345: 100%|██████████| 183/183 [01:08<00:00,  2.67it/s]


lr: 0.0005, batch_size: 4, temp: 0.5


Epoch 1/1 | Generator Loss: 4.2601 | Discriminator Loss: 0.0939: 100%|██████████| 731/731 [01:45<00:00,  6.95it/s]


lr: 0.0005, batch_size: 4, temp: 1.0


Epoch 1/1 | Generator Loss: 5.3449 | Discriminator Loss: 0.3717: 100%|██████████| 731/731 [01:47<00:00,  6.82it/s]


lr: 0.0005, batch_size: 4, temp: 1.5


Epoch 1/1 | Generator Loss: 1.2972 | Discriminator Loss: 1.6144: 100%|██████████| 731/731 [01:44<00:00,  6.97it/s]


lr: 0.0005, batch_size: 8, temp: 0.5


Epoch 1/1 | Generator Loss: 2.9739 | Discriminator Loss: 0.0768: 100%|██████████| 366/366 [01:36<00:00,  3.80it/s]


lr: 0.0005, batch_size: 8, temp: 1.0


Epoch 1/1 | Generator Loss: 0.6228 | Discriminator Loss: 1.3979: 100%|██████████| 366/366 [01:37<00:00,  3.75it/s]


lr: 0.0005, batch_size: 8, temp: 1.5


Epoch 1/1 | Generator Loss: 2.1383 | Discriminator Loss: 0.2904: 100%|██████████| 366/366 [01:44<00:00,  3.50it/s]


lr: 0.0005, batch_size: 16, temp: 0.5


Epoch 1/1 | Generator Loss: 0.7921 | Discriminator Loss: 1.3552: 100%|██████████| 183/183 [01:12<00:00,  2.53it/s]


lr: 0.0005, batch_size: 16, temp: 1.0


Epoch 1/1 | Generator Loss: 0.7260 | Discriminator Loss: 1.3720: 100%|██████████| 183/183 [01:08<00:00,  2.67it/s]


lr: 0.0005, batch_size: 16, temp: 1.5


Epoch 1/1 | Generator Loss: 0.7579 | Discriminator Loss: 1.3522: 100%|██████████| 183/183 [01:07<00:00,  2.70it/s]


lr: 0.0001, batch_size: 4, temp: 0.5


Epoch 1/1 | Generator Loss: 0.7977 | Discriminator Loss: 1.4066: 100%|██████████| 731/731 [01:44<00:00,  7.02it/s]


lr: 0.0001, batch_size: 4, temp: 1.0


Epoch 1/1 | Generator Loss: 1.4879 | Discriminator Loss: 1.0192: 100%|██████████| 731/731 [01:52<00:00,  6.48it/s]


lr: 0.0001, batch_size: 4, temp: 1.5


Epoch 1/1 | Generator Loss: 0.7224 | Discriminator Loss: 1.3949: 100%|██████████| 731/731 [01:57<00:00,  6.21it/s]


lr: 0.0001, batch_size: 8, temp: 0.5


Epoch 1/1 | Generator Loss: 0.7290 | Discriminator Loss: 0.9488: 100%|██████████| 366/366 [01:37<00:00,  3.75it/s]


lr: 0.0001, batch_size: 8, temp: 1.0


Epoch 1/1 | Generator Loss: 0.7040 | Discriminator Loss: 1.3744:   6%|▋         | 23/366 [00:06<01:37,  3.51it/s]


KeyboardInterrupt: 

In [None]:
results

[(0.001,
  4,
  0.5,
  ['Most heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard',
   'nor down down down down down down down down down down down down down down down down down down down',
   'Paulina rashly heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard',
   'affecting bring down down down down down down down down down down down down down down down down down down',
   'so down down down down down down down down down down down down down down down down down down down',
   'pretty spoil heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard heard',
   'followed fawn down down down down down down down down down down down down down down down down down down',
   'and welcome down down down down down down down down down down down down down down down down down down',
   'scorn gasping down down down down down down down down d