https://arxiv.org/abs/1711.05772 / https://arxiv.org/abs/1802.04877

https://github.com/natashamjaques/magenta/blob/affective-reward/magenta/models/affective_reward/latent_gan.py

In [1]:
import torch

print('cuda.is_available:', torch.cuda.is_available())
print(f'available: {torch.cuda.device_count()}; current: {torch.cuda.current_device()}')
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
print('pytorch', torch.__version__)

cuda.is_available: True
available: 1; current: 0
cuda:0
pytorch 0.4.0


## Labeling data

In [2]:
# pip install git+https://github.com/iconix/pytorch-text-vae.git
from pytorchtextvae import generate

In [3]:
DEVICE = torch.device('cpu') # CPU inference
n_samples = 2000
temp = 1.0

# workaround for un-pickling after module directory change https://stackoverflow.com/a/45264751
#import sys
#sys.path.append('../../pytorch-text-vae/pytorchtextvae')

vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state = generate.load_model('../../pytorch-text-vae/model/best/reviews_and_metadata_5yrs_state.pt', 'reviews_and_metadata_5yrs_stored_info.pkl', DEVICE, cache_path='../../pytorch-text-vae/model/best/tmp')

Fetching cached info at ../../pytorch-text-vae/model/best/tmp/reviews_and_metadata_5yrs_stored_info.pkl
Cache ../../pytorch-text-vae/model/best/tmp/reviews_and_metadata_5yrs_stored_info.pkl loaded (load time: 0.61s)
Found saved model ../../pytorch-text-vae/model/best/reviews_and_metadata_5yrs_state.pt
MAX_SAMPLE: False; TRUNCATED_SAMPLE: True
Trained for 360000 steps (load time: 18.88s)
Setting new random seed


In [4]:
#generate.generate(vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE, genres=['downtempo', 'dream pop', 'indietronica'], num_sample=10, temp=temp)

In [5]:
#gens, zs, conditions = generate.generate(vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE, num_sample=n_samples, temp=temp)

In [6]:
#list(zip(range(len(gens)), gens))

In [7]:
def to_embed(z, condition):
    if condition.dim() == 1:
        condition = condition.unsqueeze(0)
    squashed_condition = vae.decoder.c2h(condition)
    return torch.cat([z, squashed_condition], 1)

In [8]:
n_latent = 128
from pytorchtextvae.datasets import EOS_token

def generate(condition, gan=None, z=None, max_sample=False, truncated_sample=True, temp=temp):
    with torch.no_grad():
        if gan is None:
            z_prime = z
        else:
            gan.eval()
            z = torch.randn(1, n_latent).to(DEVICE)
            decode_embed = to_embed(z, condition).to(DEVICE)
            z_prime = gan.G(decode_embed)

        generated = vae.decoder.generate_with_embed(z_prime, 50, temp, DEVICE, max_sample=max_sample, trunc_sample=truncated_sample)
        generated_str = model.float_word_tensor_to_string(output_side, generated)

        EOS_str = f' {output_side.index_to_word(torch.LongTensor([EOS_token]))} '

        if generated_str.endswith(EOS_str):
            generated_str = generated_str[:-5]

        # flip it back
        return generated_str[::-1], z, z_prime

### Topical
Prefer certain topics to others

In [9]:
from pytorchtextvae import datasets

def tokenize(line):
    l = line.strip().lstrip().rstrip()
    l = datasets.normalize_string(l)
    return l.split(' ')

In [10]:
n_examples = 3

sents = [pair[0] for pair in pairs]
texts = [tokenize(sentence) for sentence in sents]
texts[:n_examples]

[['at',
  'just',
  'over',
  '4',
  'minutes',
  'in',
  'length',
  'it',
  'ebbs',
  'and',
  'flows',
  'in',
  'a',
  'manner',
  'that',
  'leaves',
  'me',
  'wanting',
  'more',
  'more',
  'more',
  'when',
  'it',
  'finally',
  'comes',
  'to',
  'its',
  'abrupt',
  'stop'],
 ['the',
  'track',
  'is',
  'a',
  'remix',
  'by',
  'maya',
  'jane',
  'coles',
  'who',
  'is',
  'based',
  'out',
  'of',
  'uk',
  'london'],
 ['breezy',
  'and',
  'pensive',
  'song_title',
  'is',
  'tinged',
  'with',
  'a',
  'bit',
  'of',
  'bittersweet',
  'nostalgia',
  'a',
  'fitting',
  'song',
  'to',
  'play',
  'as',
  'you',
  'polish',
  'off',
  'the',
  'last',
  'of',
  'a',
  'bottle',
  'of',
  'rose',
  'and',
  'reminisce',
  'over',
  'happy',
  'memories',
  'of',
  'you',
  'and',
  'last',
  'summers',
  'spanish',
  'fling',
  'at',
  '12',
  '41',
  'am',
  'perhaps']]

In [11]:
from nltk.corpus import stopwords

# remove stop words and words that appear only once
stoplist = [datasets.normalize_string(word) for word in stopwords.words('english')]
fillerlist = ['author', 'song_title', 'artist', 'sitename']

texts = [[word for word in text if word not in stoplist and word not in fillerlist] for text in texts]
texts[:n_examples]

[['4',
  'minutes',
  'length',
  'ebbs',
  'flows',
  'manner',
  'leaves',
  'wanting',
  'finally',
  'comes',
  'abrupt',
  'stop'],
 ['track', 'remix', 'maya', 'jane', 'coles', 'based', 'uk', 'london'],
 ['breezy',
  'pensive',
  'tinged',
  'bit',
  'bittersweet',
  'nostalgia',
  'fitting',
  'song',
  'play',
  'polish',
  'last',
  'bottle',
  'rose',
  'reminisce',
  'happy',
  'memories',
  'last',
  'summers',
  'spanish',
  'fling',
  '12',
  '41',
  'perhaps']]

In [12]:
from gensim.corpora.dictionary import Dictionary

dictionary = Dictionary(texts)

In [13]:
from gensim.models.ldamodel import LdaModel
import time

start = time.time()
n_topics = 4
passes = 20 # number of passes through documents
iterations = 400 # maximum number of iterations through the corpus when inferring the topic distribution of a corpus.
minimum_probability = 0

corpus = [dictionary.doc2bow(text) for text in texts]
# Train the model on the corpus.
lda = LdaModel(corpus, id2word=dictionary, num_topics=n_topics, iterations=iterations, passes=passes, minimum_probability=minimum_probability)
#lda = LdaModel(corpus, id2word=dictionary, num_topics=n_topics)
print(f'Runtime: {time.time() - start:.2f}s')
lda.print_topics(n_topics)

Runtime: 583.60s


[(0,
  '0.018*"pop" + 0.010*"music" + 0.008*"electronic" + 0.007*"indie" + 0.005*"rock" + 0.005*"remix" + 0.005*"hip" + 0.005*"hop" + 0.005*"dance" + 0.004*"duo"'),
 (1,
  '0.018*"track" + 0.011*"vocals" + 0.011*"like" + 0.010*"song" + 0.006*"sound" + 0.005*"vocal" + 0.005*"sounds" + 0.005*"one" + 0.004*"production" + 0.004*"beat"'),
 (2,
  '0.013*"music" + 0.010*"one" + 0.009*"like" + 0.008*"song" + 0.007*"time" + 0.006*"know" + 0.006*"im" + 0.005*"get" + 0.005*"would" + 0.005*"really"'),
 (3,
  '0.032*"new" + 0.020*"single" + 0.019*"album" + 0.016*"track" + 0.013*"release" + 0.012*"ep" + 0.012*"first" + 0.012*"year" + 0.011*"debut" + 0.010*"released"')]

In [14]:
from operator import itemgetter

for i in range(n_examples):
    print(max(lda[corpus[i]],key=itemgetter(1)), datasets.normalize_string(sents[i]))

(2, 0.37183803) at just over 4 minutes in length it ebbs and flows in a manner that leaves me wanting more more more when it finally comes to its abrupt stop
(3, 0.48564482) the track is a remix by maya jane coles who is based out of uk london
(1, 0.5647102) breezy and pensive song_title is tinged with a bit of bittersweet nostalgia a fitting song to play as you polish off the last of a bottle of rose and reminisce over happy memories of you and last summers spanish fling at 12 41 am perhaps


In [15]:
import pyLDAvis.gensim
pyLDAvis.enable_notebook()

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

pyLDAvis.gensim.prepare(lda, corpus, dictionary)

In [16]:
from collections import Counter
topic_counter = Counter([max(lda[corpus[i]], key=itemgetter(1))[0] for i in range(len(texts))])
topic_counter

In [17]:
out = sorted([f'{max(lda[corpus[i]], key=itemgetter(1))} {datasets.normalize_string(sents[i])}\n' for i in range(len(texts))], reverse=True)
with open('pairs_sentence_topics.txt', 'w') as f:
    f.writelines(out)

^using this file to select desirable/preferred topics for **weights** later on.

In [18]:
def get_example(i, pairs, input_side, output_side, random_state, device):
    pair = pairs[i]

    inp = model.word_tensor(input_side, pair[0]).to(device)
    target = model.word_tensor(output_side, pair[1]).to(device)
    condition = torch.tensor(pair[2], dtype=torch.float).unsqueeze(0).to(device) if len(pair) == 3 else None

    return inp, target, condition

In [19]:
import numpy as np
import time
from pytorchtextvae import model

labels = np.zeros((n_samples, n_topics), dtype=float)
embeds = []
start = time.time()

# debug vars
ts = []

for i in range(n_samples):
    pair_i = random_state.choice(len(pairs))
    ts.append(pairs[pair_i][0])
    input, target, condition = get_example(pair_i, pairs, input_side, output_side, random_state, DEVICE)
    with torch.no_grad():
        _, _, z, _ = vae(input, target, condition, DEVICE, temp)
        squashed_condition = vae.decoder.c2h(condition)
        decode_embed = torch.cat([z, squashed_condition], 1)
        embeds.append(decode_embed)
    
    labels[i] = [tup[1] for tup in lda[corpus[pair_i]]]

print(f'runtime: {time.time() - start:.2f}s')
list(zip(labels, ts))

runtime: 1082.96s


[(array([0.01256001, 0.07217589, 0.17199828, 0.74326587]),
  'theyve hit 3 5 million soundcloud plays been made one of hype machines ones to watch in 2015 and are also on this weeks radio 1 playlist'),
 (array([0.28381112, 0.02176651, 0.02220486, 0.67221749]),
  'by sound check december 9 2014 listen to the new release from artist aka george lewis song_title'),
 (array([0.01192618, 0.11822031, 0.0934166 , 0.77643687]),
  'so its nice to see the cologne based singer producer back in action with song_title the new single is self described by marius lauber as an ode to the night and is a superb introduction to his forthcoming debut album'),
 (array([0.78441381, 0.19241385, 0.01122425, 0.01194806]),
  'yelle is a french electropop trio consisting of yelle aka julie budet songwriter and vocals grandmarnier aka jean francois perrier producer and tephr aka tanguy destable producer'),
 (array([0.0210867 , 0.02102763, 0.39596498, 0.5619207 ]),
  'artist is embarking on a massive world tour whic

Trying to get a feel for realistic topic distribution in sentences of my desired topic...

In [20]:
sorted([(lda[corpus[i]], datasets.normalize_string(sents[i])) for i in range(n_samples) if max(lda[corpus[i]],key=itemgetter(1))[0] == 2], reverse=True)

[([(0, 0.46298915), (1, 0.024377426), (2, 0.48967716), (3, 0.022956217)],
  'music aside on almost every other front without you bucks contemporary trends in edm'),
 ([(0, 0.45214072), (1, 0.06529999), (2, 0.46577886), (3, 0.01678043)],
  'at times like these its comforting to know that there is somebody who truly seeks balance in the world let alone somebody who can help you forget about all of it'),
 ([(0, 0.4482146), (1, 0.015815554), (2, 0.521623), (3, 0.014346868)],
  'spanning nine tracks each with some tie to the forested scandinavian countryside the work demonstrates artist s methodic dedication to the craft and bridges the gap between earthly and otherworldly sounds'),
 ([(0, 0.44535336), (1, 0.021235334), (2, 0.5118354), (3, 0.021575866)],
  'they are definitely one of the most iconic indie rock bands of the late 2000s and we are stoked to see them tour again'),
 ([(0, 0.4085881), (1, 0.02761355), (2, 0.5406606), (3, 0.023137746)],
  'koz might do well to remember how good he

In [21]:
'''# take the topic distributions for top 25% of "good" sentences (sentences at the top of my favorite topic)
fav_topic = 3
n_sents_in_fav_topic = topic_counter[fav_topic]
topic_ideals = torch.mean(torch.tensor(sorted([[tup[1] for tup in lda[corpus[i]]] for i in range(len(texts)) if max(lda[corpus[i]],key=itemgetter(1))[0] == fav_topic], reverse=True), dtype=torch.float)[:int(n_sents_in_fav_topic/4)], dim=0)
topic_ideals'''

# +1 if a good topic, -1 if bad, (close to) 0 if neutral
topic_weights = torch.tensor([-1, 1, -1, 0], dtype=torch.float)
fav_topic = 1
n_sents_in_fav_topic = topic_counter[fav_topic]
# normalize this distribution to be closer to realistic topic distributions for "good" sentences
topic_weights = topic_weights * torch.mean(torch.tensor(sorted([[tup[1] for tup in lda[corpus[i]]] for i in range(len(texts)) if max(lda[corpus[i]],key=itemgetter(1))[0] == fav_topic], reverse=True), dtype=torch.float)[:int(n_sents_in_fav_topic/4)], dim=0)

topic_weights

tensor([-0.3218,  0.5169, -0.0845,  0.0000])

## Data

In [22]:
batch_size = 16
embed_size = embeds[0].size(1)

In [23]:
from fastai.dataset import *

class LatentDataset(Dataset):
    def __init__(self, embeds, labels): self.embeds,self.labels = embeds,labels
    def __getitem__(self, idx): return A(self.embeds[idx], self.labels[idx])
    def __len__(self): return len(self.embeds)
    
ds = LatentDataset(embeds, labels.astype(float))
dl = DataLoader(ds, batch_size)
md = ModelData('.', dl, None)

In [24]:
md.trn_ds[0]

[array([[ 0.2529 ,  1.08561,  0.42682,  1.4279 , -3.13116,  0.38173, -1.65676, -0.75316,  1.6511 ,  0.52037,
          1.79006, -0.95392,  0.9263 , -0.18469,  0.96248,  0.74508, -0.22671,  0.60438,  0.94637, -1.89324,
          0.7963 , -0.15779,  0.6037 ,  1.85373,  0.99637,  2.22095, -1.09685, -0.38759, -1.11662, -1.21176,
          0.54961,  1.54183, -1.11218, -0.41495,  1.0431 , -0.38825, -0.19873, -0.07921,  0.28285,  0.6974 ,
          1.00569,  0.15646, -0.07856,  1.94879,  0.10479,  0.6953 ,  4.60469,  1.25179, -1.37442,  0.298  ,
          0.68261, -0.08909,  1.37193, -0.20981, -0.67205, -1.97877,  0.2947 ,  0.69821,  0.07696,  0.49988,
         -1.21676,  2.34188,  2.49875,  1.38599,  3.80572, -2.02557,  0.87619,  1.97335, -0.19872,  1.38709,
         -0.80022, -0.28105,  1.47238,  0.58725, -3.29688, -1.59301,  0.5262 , -0.81198, -1.91472, -0.53752,
         -0.70082, -1.9064 ,  2.98032, -1.17089,  0.6498 , -1.417  ,  1.04999, -0.04363, -4.99423, -1.57813,
          0.52295, 

## Model

In [25]:
n_hidden = 1024
lr = 3e-4
fixed_genres = torch.FloatTensor(dataset.encode_genres(['neo soul', 'pop', 'r&b', 'urban contemporary'])).to(DEVICE)

In [26]:
import torch.optim as optim
import torch.nn as nn

class LCGAN_D(nn.Module):
    '''Discriminator'''
    def __init__(self, n_embed, n_hidden=n_hidden, n_output=n_topics):
        super(LCGAN_D, self).__init__()
        
        self.i2h = nn.Linear(n_embed, n_hidden)
        self.h2h = nn.Linear(n_hidden, n_hidden)
        self.h2o = nn.Linear(n_hidden, n_output)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, emb):
        x = emb
        x = self.relu(self.i2h(x))
        x = self.relu(self.h2h(x))
        x = self.relu(self.h2h(x))
        v = self.sigmoid(self.h2o(x))
        
        return v

class LCGAN_G(nn.Module):
    '''Generator'''
    def __init__(self, n_embed, n_hidden=n_hidden):
        super(LCGAN_G, self).__init__()
        self.n_embed = n_embed
        
        self.i2h = nn.Linear(n_embed, n_hidden)
        self.h2h = nn.Linear(n_hidden, n_hidden)
        # hidden-to-gating mechanism
        self.h2g = nn.Linear(n_hidden, 2*n_embed)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, emb):
        x = emb
        x = self.relu(self.i2h(x))
        x = self.relu(self.h2h(x))
        x = self.relu(self.h2h(x))
        x = self.h2g(x)
        
        # gating mechanism: allow network to remember/forget
        # what it wants to about the original emb(edding) and x
        emb_mid = x[:, self.n_embed:]
        gates = self.sigmoid(x[:, :self.n_embed])
        demb = gates * emb_mid # TODO: why naming?
        emb_prime = (1 - gates)*emb + demb
        
        return emb_prime

In [27]:
class LCGAN(nn.Module):
    def __init__(self, D, G, batch_size=batch_size):
        super(LCGAN, self).__init__()
        self.batch_size = batch_size
        
        self.D = D
        self.G = G

    def train(self):
        self.D.train()
        self.G.train()
        
    def eval(self):
        self.D.eval()
        self.G.eval()
        
    def forward(self, emb=None):       
        if emb is not None:
            # train discriminator
            #embed = to_embed(z, fixed_genres)
            v = self.D(emb)
            return v
        else: # train GAN
            # gaussian random noise
            emb_prior = torch.randn(self.batch_size, self.G.n_embed).to(DEVICE)
            
            emb_prime = self.G(emb_prior)
            v_prime = self.D(emb_prime)
            
            return v_prime

## Training

In [28]:
gan = LCGAN(LCGAN_D(embed_size).to(DEVICE), LCGAN_G(embed_size).to(DEVICE)).to(DEVICE)

In [29]:
import fastai

fastai.core.set_trainable(gan.D, True)
fastai.core.set_trainable(gan.G, True)

opt_d = optim.Adam(gan.D.parameters(), lr=lr)
opt_g = optim.Adam(gan.G.parameters(), lr=lr)

In [30]:
# test what the GAN is doing before any training
for i in range(10):
    print(generate(fixed_genres, gan)[0])

UNK UNK UNK one of of the and is signed of no UNK of the ive portland for two material
this of album in UNK and with an any and she an an her to her next next
it seems to be be waiting on the the the here and and it with a out a of indie uk dj radio
with the a UNK UNK UNK have saw just with part of the UNK music of jay their their live tour
UNK that be up in in the UNK the and in and the the UNK that is keep isnt on our thoughts
UNK UNK and and and the and the artist is is a the a the his s enjoy
we is is is UNK UNK with UNK UNK UNK and that featured of the that as the being film
UNK was a a lot of of of on the the UNK and and and their and vocals and their lo on any girls
came out of the the andre UNK and the and and track that draws taken out the last week
you is is is with a a dreamy but you away on the hunger the on dirty march


In [31]:
# adapted from: https://github.com/fastai/fastai/blob/master/courses/dl2/wgan.ipynb
def train(n_iter, alternate=False, first=False):
    gen_iters = 0
    for epoch in trange(n_iter):
        gan.train()
        data_iter = iter(md.trn_dl)
        i, n = 0, len(md.trn_dl)
        
        def train_G():
            ''' Train generator '''
            nonlocal gen_iters
            
            fastai.core.set_trainable(gan.D, False)
            fastai.core.set_trainable(gan.G, True)

            gan.G.zero_grad()

            #print(i, n)
            #loss = nn.MSELoss()
            v_prime = gan()
            log_loss_g = torch.log(v_prime)
            loss_g = (-log_loss_g * topic_weights).mean()
            #loss_g = loss(v_prime, topic_ideals.expand(batch_size, topic_ideals.size(0)))
            loss_g.backward()
            opt_g.step()
            gen_iters += 1
            
            return loss_g

        def train_D():
            ''' Train discriminator '''
            nonlocal i
            
            fastai.core.set_trainable(gan.D, True)
            fastai.core.set_trainable(gan.G, False)
            d_iters = 100 if (first and (gen_iters < 25) or (gen_iters % 500 == 0)) else 3
            j = 0

            while (j < d_iters) and (i < n):
                j += 1; i += 1
                batch = next(data_iter)
                #print(j, i, batch[0].size(), batch[1].size())
                emb_real = batch[0].to(DEVICE)
                v = gan(emb_real).to(DEVICE)

                gan.D.zero_grad()

                #loss_d = - (batch[1] * torch.log(v) + (1.0-batch[1]) * torch.log(1.0 - v)).mean()
                loss_d = - (batch[1].to(DEVICE) * torch.log(v) + (1.0-batch[1].to(DEVICE)) * torch.log(1.0 - v)).mean()
                loss_d.backward()
                opt_d.step()
                
                pbar.update()
                
            return loss_d
        
        with tqdm(total=n) as pbar:
            while i < n:
                if alternate:
                    # train discriminator
                    loss_d = train_D()
                    # then train generator a little bit
                    loss_g = train_G()
                else:
                    # train generator only
                    i += 1
                    loss_g = train_G()
                    pbar.update()
        
        if alternate:
            print(f'Loss_D {to_np(loss_d)}; Loss_G {to_np(loss_g)}; ')
        else:
            print(f'Loss_G {to_np(loss_g)}; ')

In [32]:
def train_and_generate(gan, n_epoch, genres=fixed_genres, alternate=False, n_sample=10):
    train(n_epoch, alternate)
    res = []
    for i in range(n_sample):
        res.append(generate(genres, gan))
    return res

In [33]:
[res[0] for res in train_and_generate(gan, 1, alternate=True)]

100%|██████████| 125/125 [00:01<00:00, 80.68it/s]
Loss_D 0.5566402673721313; Loss_G -0.0005247571971267462; 
100%|██████████| 1/1 [00:01<00:00,  1.55s/it]

['in their to the the track in and listeners a track with a a the and its its easy the vocals with the the and harmonies as as with the bit of of of rather harmonies',
 'an the the tracks of the UNK UNK UNK has show with UNK UNK with with with vocals and and and and and and and and and the and well the vocals beats for or banks',
 'on on a burst of the trio UNK UNK a a and addition of the the UNK and of and the the tracks has to love like a a a sound that you an on like a excellent vocals on this radio',
 'in a bit in in used to with the sound in in the sound and sounds and the the and it and a and as a song like it exactly like like two two',
 'the the UNK UNK the UNK a UNK UNK UNK on a a UNK with her single and may a while for love that an bit for the an as love the an artist s female smith',
 'the the UNK with the UNK UNK to with all with their latest UNK and out on the with a bit the a to to track to the to to the the to an record',
 'their new a their has out in your falling in wi

In [34]:
[res[0] for res in train_and_generate(gan, 1)]

100%|██████████| 125/125 [00:02<00:00, 53.30it/s]
Loss_G 2.4423203468322754; 
100%|██████████| 1/1 [00:02<00:00,  2.22s/it]

[' elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant bass love',
 ' elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant bass love',
 ' elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elephant elep

## Inference with saved models

In [33]:
def save():
    save_state_filename = 'ganG_state.pt'
    torch.save(gan.G.state_dict(), save_state_filename)
    print('Saved as %s' % (save_state_filename))
    
save()

Saved as ganG_state.pt


In [35]:
def gan_generate(vae, condition, n_latent, ganG, max_sample=False, trunc_sample=True):
    with torch.no_grad():
        ganG.eval()
        z = torch.randn(1, n_latent).to(DEVICE)
        decode_embed = to_embed(z, condition).to(DEVICE)
        z_prime = ganG(decode_embed)

        generated = vae.decoder.generate_with_embed(z_prime, 50, temp, DEVICE, max_sample=max_sample, trunc_sample=trunc_sample)
        generated_str = model.float_word_tensor_to_string(output_side, generated)

        EOS_str = f' {output_side.index_to_word(torch.LongTensor([EOS_token]))} '

        if generated_str.endswith(EOS_str):
            generated_str = generated_str[:-5]

        # flip it back
        return generated_str[::-1], z, z_prime

In [36]:
ganG = LCGAN_G(embed_size).to(DEVICE)
ganG.load_state_dict(torch.load('ganG_state.pt'))

gan_generate(vae, torch.FloatTensor(dataset.encode_genres(['hip hop','pop','pop rap','rap','trap music'])).to(DEVICE), n_latent, ganG)

('artist artist are reminiscent of many of our on this UNK the the may just get familiar with each other touch',
 tensor([[ 1.3266,  0.3218,  1.2880, -0.1209, -0.9666, -0.0200,  1.1089,
          -0.5409, -0.0544, -0.7155, -0.4585, -0.7051, -0.2313,  2.9348,
          -2.2418,  0.6079,  1.5914,  0.4511, -1.5066,  0.7279,  1.7946,
          -0.4472,  0.5314, -1.8664,  0.0953, -0.9538,  1.1704,  1.5429,
           0.0830, -0.4723, -1.2089, -1.2786, -0.1234, -0.6102,  0.5797,
           2.2356,  0.0159,  3.0624,  1.3961, -1.2540,  0.0048,  1.8790,
          -0.9674, -0.0236,  1.1817,  0.2974, -1.3599, -0.0940, -0.9623,
           1.3509,  0.3204, -0.2358, -0.8275,  1.4198, -0.8605,  0.4200,
           0.8859, -0.0607, -1.6140,  0.3870, -0.8037,  1.0355, -0.8868,
          -0.0426,  0.4908, -0.1655, -0.5113, -0.3807,  1.9757, -0.6120,
          -0.0254,  1.1762,  0.2535,  0.0041,  0.2726, -0.9798,  0.8222,
           1.4550, -0.2161, -0.0972,  0.2395,  0.0939,  0.1039,  0.7530,
           

# Extras

## Labeling data

### 'Banned' approach

label a sample as -1 (=="bad") if it contains a banned word; label as 1 otherwise (=="good")

In [7]:
#new_labels = np.array([(1, -1), (10, -1)])

banned = ['below']
labels = np.ones(n_samples, dtype=int)
gens_lose = list(set([i for b in banned for i in np.where([b in g.split() for g in gens])[0]]))
labels[gens_lose] = -1
zs_keep = zs

labels

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1, -1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1])

In [8]:
from collections import Counter
c1 = Counter([word for gen in gens for word in gen.split()])
[(b, c1[b]) for b in banned], c1.most_common(25)

([('below', 3)],
 [('the', 94),
  ('UNK', 82),
  ('of', 67),
  ('a', 66),
  ('and', 59),
  ('to', 55),
  ('is', 29),
  ('as', 26),
  ('artist', 26),
  ('with', 26),
  ('from', 24),
  ('it', 21),
  ('this', 21),
  ('track', 20),
  ('on', 20),
  ('up', 19),
  ('one', 18),
  ('be', 15),
  ('in', 14),
  ('has', 14),
  ('at', 11),
  ('i', 10),
  ('been', 10),
  ('trying', 10),
  ('that', 10)])

### 'Realism' approach

label a sample as 1 (=="good") if it came from the training data; label as -1 (=="bad") if it came from a random Gaussian `z`

In [10]:
from pytorchtextvae import model

input, target, condition = model.random_training_set(pairs, input_side, output_side, random_state, DEVICE)
model.long_word_tensor_to_string(input_side, input), dataset.decode_genres(condition)

('newcomer artist released his debut single last week and its already gaining major attention and a following that is demanding more after fill EOS ',
 ['vapor soul'])

In [11]:
temperature = 1.0

m, l, z, decoded = vae(input, target, condition, DEVICE, temperature)

z.size(), decoded.size()

(torch.Size([1, 128]), torch.Size([24, 333336]))

In [12]:
generate(condition, z=z, max_sample=True)[0]

'artist released his debut single and is just released and more than a ago and that that get your attention'

**TODO:** shouldn't generate with max sampling always return the same sample?

Even though the encoding is imperfect, we will still consider these `z`s as "realistic"

In [13]:
real_z = []
real_gens = []
for i in range(int(n_samples/2)):
    input, target, condition = model.random_training_set(pairs, input_side, output_side, random_state, DEVICE)
    with torch.no_grad():
        _, _, z, _ = vae(input, target, condition, DEVICE, temperature)
        real_z.append(z)
        real_gens.append(generate(condition, z=z, max_sample=True)[0])

In [14]:
from collections import Counter
c1 = Counter([word for gen in real_gens for word in gen.split()])
c1.most_common(25)

[('UNK', 106),
 ('and', 53),
 ('the', 50),
 ('a', 30),
 ('of', 30),
 ('on', 28),
 ('to', 24),
 ('in', 15),
 ('is', 15),
 ('i', 14),
 ('with', 14),
 ('w', 13),
 ('even', 11),
 ('that', 10),
 ('for', 10),
 ('who', 10),
 ('this', 9),
 ('always', 9),
 ('it', 9),
 ('tour', 8),
 ('into', 8),
 ('be', 8),
 ('but', 8),
 ('you', 7),
 ('new', 7)]

In [15]:
# up until now, `zs` held random zs - now concat with real zs
zs = torch.cat((torch.stack(real_z).squeeze(), torch.stack(zs[:int(n_samples/2)]).squeeze()))

In [16]:
labels = np.ones(n_samples, dtype=int)
labels[range(len(real_z), len(zs))] = -1
embeds = zs

labels

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

### Realism + Readability
in addition to the realism discriminator, add readability as a conditioning attribute

In [14]:
# https://github.com/shivam5992/textstat/issues/43
from textstat.textstat import textstat

# Score 	 Difficulty
# 90-100 	 Very Easy
# 80-89 	 Easy
# 70-79 	 Fairly Easy
# 60-69 	 Standard
# 50-59 	 Fairly Difficult
# 30-49 	 Difficult
# 0-29 	 Very Confusing

[textstat.flesch_reading_ease(sent) for sent in ["This is a sentence", "To be or not to be", ]]

[92.8, 116.15]

In [26]:
[(' '.join(gen.replace('UNK', '').split()), textstat.text_standard(' '.join(gen.replace('UNK', '').split()))) for gen in np.array(real_gens)[random_state.choice(len(real_gens), 10)]]

[('one of those who dont want to the one of their own who are their own in their own music',
  '4th and 5th grade'),
 ('the remix the production from the original and it sounds like a it with the bass and that sounds like it sounds like it from the chorus',
  '6th and 7th grade'),
 ('we been waiting for the last year while while its soon as waiting for little while we we as we as it',
  '12th and 13th grade'),
 ('one of those who dont want to the one of their own who are their own in their own music',
  '4th and 5th grade'),
 ('we trying to get into an artist is at the one of the and in the song that a part of the is in the trying to into it into a song to artist is just into part into that that or else',
  '14th and 15th grade'),
 ('if youre on the first version of you of know if you hear the version of this is that is to a song',
  '1th and 2th grade'),
 ('we been waiting for the last year while while its soon as waiting for little while we we as we as it',
  '12th and 13th grade'),


## Training

retrain discriminator with new samples... `z_prime`s that the discriminator is still not rejecting strongly enough

In [36]:
labels = np.ones(len(g2), dtype=int)
gens_lose = list(set([i for b in banned for i in np.where([b in res[0].split() for res in g2])[0]]))
#gens_keep = list(set(range(len(g2))) - gens_lose)
labels[gens_lose] = -1
zs_keep = np.array([res[2] for res in g2], dtype=object)

labels        

array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

In [32]:
ds = LatentDataset(zs_keep, labels.astype(float))
dl = DataLoader(ds, batch_size)
md = ModelData('.', dl, None)

In [33]:
[res[0] for res in train_and_generate(gan, 1, alternate=True)]

100%|██████████| 7/7 [00:00<00:00, 284.11it/s]
Loss_D 1.193210244178772; Loss_G 1.1896806955337524; 


['artist has teams up with a a one of the UNK and it would be taken from i don no feat',
 'this with a remix of UNK and trying to take of the UNK theres no feat',
 'the remix of UNK has been trying to take of the UNK theres no feat',
 'UNK is a listen to the UNK and coming side of what to look at feat',
 'a new track called UNK and serves as one of the trying would be ready for at times',
 'artist has up with a UNK and one of the track will be inspired as no feat',
 'artist is back in the likes of UNK and serves as a no times',
 'artist who reminds back to a more stuck on the UNK and theres many feat',
 'he continues to serves as one of the week what i look no feat',
 'with a layers of UNK and and UNK of the perfect trying i look at feat']