In [1]:
# import libraries
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import os
import random
import pickle

Loading Pretrained LSTM

In [2]:
# network architecture
class LSTMNameGenerator(nn.Module):
    def __init__(self, vocab_size, hidden_size, embedding_size=32, dropout=0.2, char2int=None, int2char=None):
        super(LSTMNameGenerator, self).__init__()

        # useful model properties
        self.vocab_size = vocab_size        
        self.char2int = char2int
        self.int2char = int2char
        self.hidden_size = hidden_size

        # embedding layer is useful to map input into vector representation
        self.embedding = nn.Embedding(vocab_size, embedding_size)

        # LSTM layer preserved by PyTorch library
        # this layer handles LSTM Cell loops
        self.lstm = nn.LSTM(embedding_size, hidden_size, batch_first=True)

        # Linear layer for output
        self.output = nn.Linear(hidden_size, vocab_size)

        # Dropout layer
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, h=None, c=None):
        # its optional to init hidden state by ourselves
        # bcs PyTorch will handle it if we don't provide it

        # map input to vector
        x = self.embedding(x)

        # compute current hidden state
        if h != None and c != None:
            o, (h, c) = self.lstm(x, (h, c))
        else:
            o, (h, c) = self.lstm(x, None)

        # apply dropout
        o = self.dropout(o)

        # compute output
        o = self.output(o)

        # here we return hidden state too cz we want to use it in inference mode
        return o, h, c

    def init_hidden(self, device):
        return torch.randn(1, 1, self.hidden_size, device=device)

In [3]:
HID_SIZE = 128
EMB_SIZE = 64

In [4]:
checkpoint = torch.load('../models/name_genLSTM.pt')
vocab = pickle.load(open('../models/vocab.pkl', 'rb'))
int2char = pickle.load(open('../models/vocab_int2char.pkl', 'rb'))

gen = LSTMNameGenerator(len(vocab), HID_SIZE, EMB_SIZE, char2int=vocab, int2char=int2char)
gen.load_state_dict(checkpoint)

<All keys matched successfully>

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [6]:
gen.to(device)

LSTMNameGenerator(
  (embedding): Embedding(56, 64)
  (lstm): LSTM(64, 128, batch_first=True)
  (output): Linear(in_features=128, out_features=56, bias=True)
  (dropout): Dropout(p=0.2, inplace=False)
)

In [7]:
def generate_sample(model, phrase=None, max_length=6, temperature=1.0, top_k=None):
    # x_enc = [[vocab[ch] for ch in phrase]]
    x_enc = [[0]]
    # x_pad = pad_features(x_enc, max_length)
    x_torch = torch.tensor(x_enc, dtype=torch.int64, device=device)

    # create list for output
    char_out = phrase.split()

    # move to device
    x_torch = x_torch

    # init empty hidden state
    h = c = model.init_hidden(device)

    # running through seed phrase to generate hidden_state
    # here we leave the last character cz we will feed it in
    # the generating phase as the first sequence
    for i in range(len(phrase)-1):
        out, h, c = model(x_torch[:, i:i+1], h, c)

    # start generating
    for _ in range(max_length - len(phrase)):
        out, h, c = model(x_torch[:, -1:], h, c)
        # p = F.softmax(out / temperature, dim=-1).data        
        p = F.gumbel_softmax(out, 1).data        

        # pick top K token by top_k (if defined)
        if top_k is None:
            top_char = np.arange(len(vocab))
        else:
            p, top_char = p.topk(top_k).detach()
            top_char = top_char.cpu().numpy().squeeze()        

        # select next token and push it to input sequence
        p = p.cpu().numpy().squeeze()
        char_id = np.random.choice(top_char, p=p/p.sum())
        char = torch.tensor([[char_id]], dtype=torch.int64, device=device)
        x_torch = torch.cat([x_torch, char], dim=-1)
    
        # push to char_out too
        char_out.append(int2char[char_id] if char_id > 0 else ' ')

    return ''.join(char_out)

In [8]:
generate_sample(gen, phrase='', max_length=16)


'                '

Loading Data

In [9]:
# read dataset
PATH = '../data/names.txt'

with open(PATH, 'r', encoding='utf-8') as r:
    names = r.read().split('\n')

# shuffle dataset
index = list(range(len(names)))
random.shuffle(index)
names = [names[i] for i in index]

In [10]:
# build vocab
chars = tuple(set(''.join(names)))
int2char = dict(enumerate(chars, 1))
int2char[0] = '<PAD>'
char2int = {v: k for k, v in int2char.items()}

# encode words
names_enc = [[char2int[ch] for ch in name] for name in names]
names_enc[:5]

[[37, 3, 27, 55, 3, 51],
 [33, 26, 50, 40, 32, 1, 51, 3, 26],
 [16, 50, 40, 32, 26, 26],
 [21, 52, 40, 3, 3, 55, 3],
 [2, 32, 55, 27, 51, 32, 1, 40]]

In [11]:
# pad features
seq_length = max([len(x) for x in names_enc])

def pad_features(names, seq_length):
    features = np.zeros((len(names), seq_length), dtype=int)    

    for i, row in enumerate(names):
        # if seq_length < len(row) then row will be trimmed (expected)        
        features[i, :len(row)] = np.array(row)[:seq_length]

    return features

features = pad_features(names_enc, seq_length)

assert len(features) == len(names_enc)
assert len(features[0]) == seq_length

features[:5]

array([[37,  3, 27, 55,  3, 51,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [33, 26, 50, 40, 32,  1, 51,  3, 26,  0,  0,  0,  0,  0,  0],
       [16, 50, 40, 32, 26, 26,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       [21, 52, 40,  3,  3, 55,  3,  0,  0,  0,  0,  0,  0,  0,  0],
       [ 2, 32, 55, 27, 51, 32,  1, 40,  0,  0,  0,  0,  0,  0,  0]])

In [12]:
# train test split
train_size = .75     # we will use 80% of data as train set
# val_size = .5       # we will use 50% of test set as validation set

split_id = int(len(features) * train_size)
train_x, test_x = features[:split_id], features[split_id:]

# test_id = int(len(remain_x) * val_size)
# val_x, test_x = remain_x[:test_id], remain_x[test_id:]

print('Feature Shapes:')
print('===============')
print('Train set: {}'.format(train_x.shape))
# print('Validation set: {}'.format(val_x.shape))
print('Test set: {}'.format(test_x.shape))

Feature Shapes:
Train set: (5958, 15)
Test set: (1986, 15)


In [13]:
# generate batches
batch_size = 128

trainset = TensorDataset(torch.from_numpy(train_x))
# validset = TensorDataset(torch.from_numpy(val_x))
testset = TensorDataset(torch.from_numpy(test_x))

trainloader = DataLoader(trainset, shuffle=True, batch_size=batch_size)
# validloader = DataLoader(validset, shuffle=True, batch_size=batch_size)
testloader = DataLoader(testset, shuffle=True, batch_size=batch_size)

In [14]:
diter = iter(trainloader)
x = diter.next()[0]

print('Sample batch size: ', x.size())   # batch_size, seq_length
print('Sample batch input: \n', x)

Sample batch size:  torch.Size([128, 15])
Sample batch input: 
 tensor([[39, 40, 50,  ...,  0,  0,  0],
        [21, 52, 26,  ...,  0,  0,  0],
        [35, 26, 50,  ...,  0,  0,  0],
        ...,
        [ 2, 52, 40,  ...,  0,  0,  0],
        [ 4, 51, 50,  ...,  0,  0,  0],
        [18,  6,  6,  ...,  0,  0,  0]], dtype=torch.int32)


Discriminator Architecture

In [15]:
class LSTMNameDiscriminator(nn.Module):
    def __init__(self, vocab_size, out_size, hid_size=64, emb_size=32):
        super(LSTMNameDiscriminator, self).__init__()

        # embedding layer
        self.embedding = nn.Embedding(vocab_size, emb_size)

        # LSTM layer
        self.lstm = nn.LSTM(emb_size, hid_size, batch_first=True)

        # linear layer
        self.output = nn.Linear(hid_size, out_size)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # map input to vector
        x = self.embedding(x)

        # compute current hidden state
        o, _ = self.lstm(x)

        # get last sequence output
        o = o[:, -1, :]

        # feed output to linear layer
        logit = self.output(o)

        out = self.sigmoid(logit)

        return out

In [16]:
disc = LSTMNameDiscriminator(len(vocab), 1, HID_SIZE, EMB_SIZE)
disc.to(device)

LSTMNameDiscriminator(
  (embedding): Embedding(56, 64)
  (lstm): LSTM(64, 128, batch_first=True)
  (output): Linear(in_features=128, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

GAN Training

In [17]:
torch.cat([torch.tensor([[1]]), torch.tensor([[1]]), torch.tensor([[1]])], dim=0)

tensor([[1],
        [1],
        [1]])

In [18]:
def generate_fake(model, batch_size=batch_size, max_length=12, tau=1.0, noise=None):
    # X_enc = torch.zeros((batch_size, 1), dtype=torch.int64, device=device)
    X_enc = torch.tensor([], device=device, dtype=torch.int64)

    generated = []

    # insert noise to hidden and cell state
    if noise == None:
        h = c = model.init_hidden(device)
    else:
        h = c = noise

    for _ in range(batch_size):
        X_interim = torch.tensor([], device=device, dtype=torch.int64)
        char_out = []
        # start generating
        for _ in range(max_length):
            X = torch.zeros((1, 1), dtype=torch.int64, device=device)
            out, h, c = model(X, h, c)

            # apply gumbel softmax here
            p = F.gumbel_softmax(out, tau=tau).data

            top_char = np.arange(len(vocab))
            p = p.cpu().numpy().squeeze()
            char_id = np.random.choice(top_char, p=p/p.sum())
            char = torch.tensor([[char_id]], dtype=torch.int64, device=device)
            X_interim = torch.cat([X_interim, char], dim=-1)

            char_out.append(int2char[char_id] if char_id > 0 else ' ')
        
        generated.append(char_out)
        X_enc = torch.cat([X_enc, X_interim], dim=0)

    return X_enc, generated

In [19]:
X, fake = generate_fake(gen, batch_size=10, tau=5)
X, fake

(tensor([[42,  0,  0,  1, 13,  0, 27,  0,  0,  0, 46,  0],
         [ 0,  0, 52, 37,  0, 10, 55,  0,  0,  0, 12, 40],
         [ 0, 28,  0,  0, 23,  0,  0,  0, 13,  0, 46,  0],
         [18, 36,  0,  0, 23,  0, 25,  0,  0,  0,  0, 37],
         [ 0, 42, 54,  0,  4,  0,  4, 47, 32,  0, 18,  0],
         [ 0,  0, 38,  0,  0,  0,  0,  0, 42, 30,  0,  0],
         [ 0,  0, 54,  0,  0, 28,  0,  0, 24, 28,  0,  0],
         [ 0, 46,  3,  0,  0, 30,  0,  0, 28, 42, 52,  0],
         [20,  1,  0, 42,  0, 28, 13,  1, 38,  0, 28,  0],
         [ 0, 46,  0,  0, 23,  0,  0, 52,  0,  0, 38,  3]], device='cuda:0'),
 [['U', ' ', ' ', 'd', "'", ' ', 't', ' ', ' ', ' ', 'b', ' '],
  [' ', ' ', 'h', 'A', ' ', 'u', 'o', ' ', ' ', ' ', 'Q', 'a'],
  [' ', 'L', ' ', ' ', 'I', ' ', ' ', ' ', "'", ' ', 'b', ' '],
  ['E', 's', ' ', ' ', 'I', ' ', 'P', ' ', ' ', ' ', ' ', 'A'],
  [' ', 'U', 'x', ' ', 'M', ' ', 'M', 'F', 'l', ' ', 'E', ' '],
  [' ', ' ', 'k', ' ', ' ', ' ', ' ', ' ', 'U', 'W', ' ', ' '],
  [' ',

In [20]:
lr = 0.001
opt_gen = Adam(gen.parameters(), lr=lr)
opt_disc = Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
num_epoch = 100

In [21]:
fixed_noise = gen.init_hidden(device)

In [22]:
for epoch in tqdm(range(num_epoch), desc='Epochs'):

    gen.train()
    disc.train()

    for batch_idx, real in enumerate(tqdm(trainloader, desc='Batch')):
        
        real = real[0].to(device)

        # Train Discriminator
        fake, _ = generate_fake(gen, batch_size=batch_size, tau=5)
        disc_real = disc(real).view(-1)
        # lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        lossD_real = -1 * torch.mean(torch.log(disc_real))

        disc_fake = disc(fake).view(-1)
        # lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD_fake = -1 * torch.mean(torch.log(1 - disc_fake))

        # lossD = (lossD_real + lossD_fake) / 2
        lossD = (lossD_real + lossD_fake)

        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()


        # Train Generator
        output = disc(fake).view(-1)
        lossG = -1 * torch.mean(torch.log(output/(1 - output)))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f'Epoch [{epoch}/{num_epoch} Batch {batch_idx}/{len(trainloader)}] | Loss D: {lossD:.4f}, loss G: {lossG:.4f}'
            )

            with torch.no_grad():    
                _, fake_gen = generate_fake(gen, batch_size=1, tau=5, noise=fixed_noise)
                print(f'Generated: {"".join(fake_gen[0])}')

Epochs:   0%|          | 0/100 [00:00<?, ?it/s]

Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [0/100 Batch 0/47] | Loss D: 1.3758, loss G: 0.0706
Generated: 'S  a  A  U 


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [1/100 Batch 0/47] | Loss D: 1.3765, loss G: 0.0693
Generated: CZ      r  v


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [2/100 Batch 0/47] | Loss D: 1.3787, loss G: 0.0649
Generated: In      YP k


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [3/100 Batch 0/47] | Loss D: 1.3810, loss G: 0.0599
Generated: W    v    cH


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [4/100 Batch 0/47] | Loss D: 1.3803, loss G: 0.0616
Generated:  A 'u PVY F 


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [5/100 Batch 0/47] | Loss D: 1.3751, loss G: 0.0723
Generated: Z  a   Zc h 


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [6/100 Batch 0/47] | Loss D: 1.3758, loss G: 0.0705
Generated: c      F   U


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [7/100 Batch 0/47] | Loss D: 1.3789, loss G: 0.0642
Generated: Y k g Q     


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [8/100 Batch 0/47] | Loss D: 1.3749, loss G: 0.0725
Generated: Ig      U   


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [9/100 Batch 0/47] | Loss D: 1.3780, loss G: 0.0663
Generated:  M 'LZ      


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [10/100 Batch 0/47] | Loss D: 1.3761, loss G: 0.0699
Generated:  L  g N     


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [11/100 Batch 0/47] | Loss D: 1.3775, loss G: 0.0674
Generated:   VUM     Z 


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [12/100 Batch 0/47] | Loss D: 1.3775, loss G: 0.0673
Generated:  e Mn'   w M


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [13/100 Batch 0/47] | Loss D: 1.3780, loss G: 0.0661
Generated:  Ua L   gb R


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [14/100 Batch 0/47] | Loss D: 1.3762, loss G: 0.0702
Generated:  bz  t  M  R


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [15/100 Batch 0/47] | Loss D: 1.3797, loss G: 0.0627
Generated:  U L Z   G  


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [16/100 Batch 0/47] | Loss D: 1.3775, loss G: 0.0672
Generated: o  b   u RPY


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [17/100 Batch 0/47] | Loss D: 1.3761, loss G: 0.0701
Generated:       n    c


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [18/100 Batch 0/47] | Loss D: 1.3779, loss G: 0.0661
Generated: B d  b  ' n 


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [19/100 Batch 0/47] | Loss D: 1.3773, loss G: 0.0675
Generated:  Phg     Fbv


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [20/100 Batch 0/47] | Loss D: 1.3800, loss G: 0.0622
Generated:  M PZ -A    


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [21/100 Batch 0/47] | Loss D: 1.3759, loss G: 0.0708
Generated: I nE  P Hb  


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [22/100 Batch 0/47] | Loss D: 1.3789, loss G: 0.0646
Generated:    b  o     


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [23/100 Batch 0/47] | Loss D: 1.3780, loss G: 0.0663
Generated: W  W       u


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [24/100 Batch 0/47] | Loss D: 1.3759, loss G: 0.0704
Generated: HB HN       


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [25/100 Batch 0/47] | Loss D: 1.3782, loss G: 0.0660
Generated: o   P lE    


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [26/100 Batch 0/47] | Loss D: 1.3773, loss G: 0.0676
Generated:    Z YebEP z


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [27/100 Batch 0/47] | Loss D: 1.3792, loss G: 0.0638
Generated: UWE R cPd   


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [28/100 Batch 0/47] | Loss D: 1.3805, loss G: 0.0612
Generated:    gu mP    


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [29/100 Batch 0/47] | Loss D: 1.3768, loss G: 0.0687
Generated:  TR  V  bP M


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [30/100 Batch 0/47] | Loss D: 1.3761, loss G: 0.0701
Generated: e UPw       


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [31/100 Batch 0/47] | Loss D: 1.3798, loss G: 0.0623
Generated: tY U    '  L


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [32/100 Batch 0/47] | Loss D: 1.3766, loss G: 0.0691
Generated: C Pbs  ldU  


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [33/100 Batch 0/47] | Loss D: 1.3782, loss G: 0.0659
Generated: kNl  lP   cL


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [34/100 Batch 0/47] | Loss D: 1.3761, loss G: 0.0706
Generated: - Vu  N cl P


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [35/100 Batch 0/47] | Loss D: 1.3785, loss G: 0.0654
Generated: X t  z      


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [36/100 Batch 0/47] | Loss D: 1.3791, loss G: 0.0641
Generated:  bU tL P W  


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [37/100 Batch 0/47] | Loss D: 1.3811, loss G: 0.0601
Generated: M  PCh  V   


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [38/100 Batch 0/47] | Loss D: 1.3808, loss G: 0.0605
Generated:     UMhVg   


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [39/100 Batch 0/47] | Loss D: 1.3766, loss G: 0.0691
Generated: Pc   bh b I 


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [40/100 Batch 0/47] | Loss D: 1.3796, loss G: 0.0629
Generated: t     w W   


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [41/100 Batch 0/47] | Loss D: 1.3788, loss G: 0.0646
Generated: H lM'h     b


Batch:   0%|          | 0/47 [00:00<?, ?it/s]

Epoch [42/100 Batch 0/47] | Loss D: 1.3794, loss G: 0.0634
Generated: gtn    W P  


KeyboardInterrupt: 