In [1]:
import numpy as np
import os
from embeddings import SentenceEmbeddingDataset, WordEmbeddingManager, create_embedding_dataloader
import embeddings
import utils
from gan import Generator, Discriminator
import torch

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     /Users/johnhenryrudden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
TRAIN_PATH = 'data/raw_train.txt'
tokenized_sentences = utils.process_training_data(TRAIN_PATH)

In [3]:
longest_sentence = max(tokenized_sentences, key=len)
print(f'Longest sentence has {len(longest_sentence)} tokens')
SEQ_LENGTH = len(longest_sentence)

Longest sentence has 20 tokens


In [4]:
WORD2VEC_MODEL_PATH = 'data/word2vec.model'
word2vec_manager = WordEmbeddingManager(WORD2VEC_MODEL_PATH)

Model loaded successfully from data/word2vec.model


In [5]:
dataloader = create_embedding_dataloader(tokenized_sentences, word2vec_manager, seq_length=SEQ_LENGTH, batch_size=4)

In [6]:
batch = next(iter(dataloader))

In [7]:
batch[0]

tensor([[ 7.1316e-03,  1.9339e-02, -2.7511e-02,  6.9083e-03, -1.6032e-02,
         -4.1205e-02,  6.9458e-02,  1.0333e-01, -8.4358e-02,  1.8224e-03,
         -3.7226e-02, -3.8608e-02, -2.3026e-02,  5.3395e-02, -7.1094e-02,
          5.7522e-03,  2.3071e-02, -3.9609e-02, -9.7750e-02, -2.4183e-03,
          1.7771e-02,  2.1389e-02,  4.9973e-02, -1.1299e-02,  3.1689e-02,
         -3.8005e-02, -2.7232e-02, -2.7236e-02, -7.9745e-02,  7.4413e-03,
          9.5555e-04,  2.6702e-03,  6.4233e-03,  2.9464e-02, -2.4807e-02,
          4.9826e-02,  7.7195e-02,  2.9932e-02, -3.2439e-02,  3.7582e-02,
          5.7510e-02, -3.2887e-02, -1.8172e-02, -2.4403e-02,  1.1959e-01,
         -3.2395e-03, -1.1700e-02, -6.5287e-02,  7.9087e-02,  6.0487e-02],
        [-3.0621e-01,  6.4465e-01, -2.8543e-01,  5.6573e-01, -4.2993e-01,
         -1.0852e+00,  1.2997e+00,  1.5104e+00, -5.8759e-01, -4.2063e-01,
         -2.8488e-01, -2.5938e-01, -9.5989e-01,  3.7014e-02, -1.0439e+00,
         -2.6326e-01, -4.5191e-01,  3

In [12]:
# Initialize models
input_dim = embeddings.EMBEDDING_SIZE
hidden_dim = 300
output_dim = len(word2vec_manager._model.wv.key_to_index)
seq_length = SEQ_LENGTH 

generator = Generator(input_size=input_dim, hidden_size=hidden_dim, output_size=output_dim, seq_length=seq_length)
discriminator = Discriminator(input_dim=input_dim, hidden_dim=hidden_dim, seq_length=seq_length)

In [13]:
discrim_params = list(discriminator.parameters())
gen_params = list(generator.parameters())
num_params_gen = sum([np.prod(p.size()) for p in gen_params])
num_params_discrim = sum([np.prod(p.size()) for p in discrim_params])
print(f'Generator has {num_params_gen} parameters')
print(f'Discriminator has {num_params_discrim} parameters')

Generator has 2485454 parameters
Discriminator has 422701 parameters


In [14]:
# Train models
import torch.optim as optim

# Hyperparameters
learning_rate = 0.001
num_epochs = 2
batch_size = 4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
temperature = 0.5

# Init optimizers
optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)

# Init loss functions
loss = torch.nn.BCELoss()

# data loader
dataloader = create_embedding_dataloader(tokenized_sentences, word2vec_manager, seq_length=SEQ_LENGTH, batch_size=batch_size)
dataloader = iter(dataloader)

for _ in range(num_epochs):
    for real_data in dataloader:
        batch_size = real_data.size(0)

        # gen fake data
        noise = torch.randn(batch_size, seq_length, input_dim)
        generated_data = generator(noise, temperature)
        fake_labels = torch.zeros(batch_size, 1)

        # train discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones(batch_size, 1)
        real_loss = loss(discriminator(real_data), real_labels)
        # need to detach generated data from the graph to avoid training the generator
        print(generated_data.shape)
        fake_loss = loss(discriminator(generated_data.detach()), fake_labels)
        total_loss = real_loss + fake_loss
        total_loss.backward()
        optimizer_D.step()

        # train generator
        optimizer_G.zero_grad()
        fake_pred = discriminator(fake_data)
        gen_loss = loss(fake_pred, real_labels)
        gen_loss.backward()
        optimizer_G.step()
        
    print(f'Generator loss: {gen_loss}')
    print(f'Discriminator loss: {total_loss}')


    


torch.Size([20, 4, 6854])


RuntimeError: input.size(-1) must be equal to input_size. Expected 50, got 6854