https://arxiv.org/abs/1711.05772 / https://arxiv.org/abs/1802.04877

https://github.com/natashamjaques/magenta/blob/affective-reward/magenta/models/affective_reward/latent_gan.py

In [1]:
import torch

print('cuda.is_available:', torch.cuda.is_available())
print(f'available: {torch.cuda.device_count()}; current: {torch.cuda.current_device()}')
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
print('pytorch', torch.__version__)

cuda.is_available: True
available: 1; current: 0
cuda:0
pytorch 0.4.0


## Labeling data

In [2]:
# pip install git+https://github.com/iconix/pytorch-text-vae.git
from pytorchtextvae import generate

In [3]:
DEVICE = torch.device('cpu') # CPU inference
n_samples = 200
temp = 0.2

# workaround for un-pickling after module directory change https://stackoverflow.com/a/45264751
#import sys
#sys.path.append('../../pytorch-text-vae/pytorchtextvae')

vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state = generate.load_model('../../pytorch-text-vae/model/best/reviews_and_metadata_5yrs_state.pt', 'reviews_and_metadata_5yrs_stored_info.pkl', DEVICE, cache_path='../../pytorch-text-vae/model/best/tmp')

Fetching cached info at ../../pytorch-text-vae/model/best/tmp/reviews_and_metadata_5yrs_stored_info.pkl
Cache ../../pytorch-text-vae/model/best/tmp/reviews_and_metadata_5yrs_stored_info.pkl loaded (load time: 0.61s)
Found saved model ../../pytorch-text-vae/model/best/reviews_and_metadata_5yrs_state.pt
MAX_SAMPLE: False; TRUNCATED_SAMPLE: True
Trained for 360000 steps (load time: 18.87s)
Setting new random seed


In [4]:
#generate.generate(vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE, genres=['downtempo', 'dream pop', 'indietronica'], num_sample=10, temp=temp)

In [5]:
#gens, zs, conditions = generate.generate(vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE, num_sample=n_samples, temp=temp)

In [6]:
#list(zip(range(len(gens)), gens))

In [7]:
def to_embed(z, condition):
    if condition.dim() == 1:
        condition = condition.unsqueeze(0)
    squashed_condition = vae.decoder.c2h(condition)
    return torch.cat([z, squashed_condition], 1)

In [8]:
n_latent = 128
from pytorchtextvae.datasets import EOS_token

def generate(condition, gan=None, z=None, max_sample=False, truncated_sample=True, temp=temp):
    with torch.no_grad():
        if gan is None:
            z_prime = z
        else:
            gan.eval()
            z = torch.randn(1, n_latent).to(DEVICE)
            decode_embed = to_embed(z, condition).to(DEVICE)
            z_prime = gan.G(decode_embed)

        generated = vae.decoder.generate_with_embed(z_prime, 50, temp, DEVICE, max_sample=max_sample, trunc_sample=truncated_sample)
        generated_str = model.float_word_tensor_to_string(output_side, generated)

        EOS_str = f' {output_side.index_to_word(torch.LongTensor([EOS_token]))} '

        if generated_str.endswith(EOS_str):
            generated_str = generated_str[:-5]

        # flip it back
        return generated_str[::-1], z, z_prime

### Topical
Prefer certain topics to others

In [9]:
from pytorchtextvae import datasets

def tokenize(line):
    l = line.strip().lstrip().rstrip()
    l = datasets.normalize_string(l)
    return l.split(' ')

In [10]:
n_examples = 3

sents = [pair[0] for pair in pairs]
texts = [tokenize(sentence) for sentence in sents]
texts[:n_examples]

[['ive',
  'been',
  'listening',
  'to',
  'so',
  'much',
  'lately',
  'with',
  'all',
  'of',
  'this',
  'time',
  'in',
  'the',
  'van'],
 ['song_title',
  'is',
  'the',
  'first',
  'track',
  'axel',
  'boman',
  'and',
  'john',
  'talabot',
  'shared',
  'from',
  'their',
  'artist',
  'next',
  'lp'],
 ['on',
  'our',
  'favorite',
  'new',
  'track',
  'song_title',
  'simon',
  'green',
  'aka',
  'bonobo',
  'does',
  'exactly',
  'what',
  'he',
  's',
  'good',
  'at',
  'creating',
  'complex',
  'melodies',
  'a',
  'la',
  'cirrus',
  'and',
  'adding',
  'chimes',
  'and',
  'bells',
  'so',
  'effortlessly',
  'that',
  'even',
  'pantha',
  'du',
  'prince',
  'could',
  'die',
  'for']]

In [11]:
from nltk.corpus import stopwords

# remove stop words and words that appear only once
stoplist = [datasets.normalize_string(word) for word in stopwords.words('english')]
fillerlist = ['author', 'song_title', 'artist', 'sitename']

texts = [[word for word in text if word not in stoplist and word not in fillerlist] for text in texts]
texts[:n_examples]

[['ive', 'listening', 'much', 'lately', 'time', 'van'],
 ['first',
  'track',
  'axel',
  'boman',
  'john',
  'talabot',
  'shared',
  'next',
  'lp'],
 ['favorite',
  'new',
  'track',
  'simon',
  'green',
  'aka',
  'bonobo',
  'exactly',
  'good',
  'creating',
  'complex',
  'melodies',
  'la',
  'cirrus',
  'adding',
  'chimes',
  'bells',
  'effortlessly',
  'even',
  'pantha',
  'du',
  'prince',
  'could',
  'die']]

In [12]:
from gensim.corpora.dictionary import Dictionary

dictionary = Dictionary(texts)

In [13]:
from gensim.models.ldamodel import LdaModel
import time

start = time.time()
n_topics = 4
passes = 20 # number of passes through documents
iterations = 400 # maximum number of iterations through the corpus when inferring the topic distribution of a corpus.
minimum_probability = 0

corpus = [dictionary.doc2bow(text) for text in texts]
# Train the model on the corpus.
lda = LdaModel(corpus, id2word=dictionary, num_topics=n_topics, iterations=iterations, passes=passes, minimum_probability=minimum_probability)
#lda = LdaModel(corpus, id2word=dictionary, num_topics=n_topics)
print(f'Runtime: {time.time() - start:.2f}s')
lda.print_topics(n_topics)

Runtime: 533.40s


[(0,
  '0.025*"album" + 0.017*"release" + 0.016*"ep" + 0.015*"debut" + 0.014*"new" + 0.013*"released" + 0.010*"single" + 0.009*"via" + 0.008*"first" + 0.007*"track"'),
 (1,
  '0.025*"new" + 0.015*"remix" + 0.014*"single" + 0.013*"track" + 0.010*"one" + 0.009*"producer" + 0.008*"year" + 0.008*"duo" + 0.007*"based" + 0.007*"last"'),
 (2,
  '0.015*"track" + 0.013*"vocals" + 0.010*"pop" + 0.007*"sound" + 0.006*"song" + 0.006*"like" + 0.005*"vocal" + 0.005*"production" + 0.005*"synth" + 0.005*"beat"'),
 (3,
  '0.015*"like" + 0.013*"song" + 0.011*"one" + 0.011*"music" + 0.009*"get" + 0.009*"time" + 0.006*"really" + 0.006*"love" + 0.006*"something" + 0.006*"much"')]

In [14]:
from operator import itemgetter

for i in range(n_examples):
    print(max(lda[corpus[i]],key=itemgetter(1)), datasets.normalize_string(sents[i]))

(3, 0.7441754) ive been listening to so much lately with all of this time in the van
(0, 0.9229266) song_title is the first track axel boman and john talabot shared from their artist next lp
(2, 0.42214796) on our favorite new track song_title simon green aka bonobo does exactly what he s good at creating complex melodies a la cirrus and adding chimes and bells so effortlessly that even pantha du prince could die for


In [15]:
import pyLDAvis.gensim
pyLDAvis.enable_notebook()

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

pyLDAvis.gensim.prepare(lda, corpus, dictionary)

In [16]:
from collections import Counter
Counter([max(lda[corpus[i]], key=itemgetter(1))[0] for i in range(len(texts))])

Counter({3: 34284, 0: 20337, 2: 26852, 1: 22994})

In [17]:
out = sorted([f'{max(lda[corpus[i]], key=itemgetter(1))} {datasets.normalize_string(sents[i])}\n' for i in range(len(texts))], reverse=True)
with open('pairs_sentence_topics.txt', 'w') as f:
    f.writelines(out)

^using this file to select desirable/preferred topics for **weights** below.

In [18]:
# 1 if a good topic, -1 if bad, 0 if neutral
topic_weights = torch.tensor([-1, 0, 1, 0], dtype=torch.float)

In [19]:
def get_example(i, pairs, input_side, output_side, random_state, device):
    pair = pairs[i]

    inp = model.word_tensor(input_side, pair[0]).to(device)
    target = model.word_tensor(output_side, pair[1]).to(device)
    condition = torch.tensor(pair[2], dtype=torch.float).unsqueeze(0).to(device) if len(pair) == 3 else None

    return inp, target, condition

In [20]:
import numpy as np
import time
from pytorchtextvae import model

labels = np.zeros((n_samples, n_topics), dtype=float)
embeds = []
start = time.time()

# debug vars
ts = []

for i in range(n_samples):
    pair_i = random_state.choice(len(pairs))
    ts.append(pairs[pair_i][0])
    input, target, condition = get_example(pair_i, pairs, input_side, output_side, random_state, DEVICE)
    with torch.no_grad():
        _, _, z, _ = vae(input, target, condition, DEVICE, temp)
        squashed_condition = vae.decoder.c2h(condition)
        decode_embed = torch.cat([z, squashed_condition], 1)
        embeds.append(decode_embed)
    
    labels[i] = [tup[1] for tup in lda[corpus[pair_i]]]

print(f'runtime: {time.time() - start:.2f}s')
list(zip(labels, ts))

runtime: 143.83s


[(array([0.2141604 , 0.02269608, 0.02091037, 0.7422331 ]),
  'the vacation however turned into a writing session and after a couple of weeks they had come up with material for a whole album'),
 (array([0.22517413, 0.0143771 , 0.28643605, 0.4740127 ]),
  'this is in advance of guesting at their debut london show together with loads of other hd faves if you like the rhythmic chiming flow of goddards work then enjoy these two'),
 (array([0.06529655, 0.08316156, 0.26611835, 0.58542353]),
  'aluna sings if you wanna train me like an animal better keep your eye on my every move theres no need to be so damn'),
 (array([0.03611599, 0.23360415, 0.03688584, 0.69339401]),
  'conversely where is the tipping point when suddenly the track is no longer his but'),
 (array([0.02311261, 0.16285406, 0.78821659, 0.02581673]),
  'the vocal sample is chopped and mysterious a yearning female voice lost in unfamiliar sounds'),
 (array([0.25903478, 0.04536976, 0.04359853, 0.65199691]),
  'that kind of action w

## Data

In [21]:
batch_size = 16
embed_size = embeds[0].size(1)

In [22]:
from fastai.dataset import *

class LatentDataset(Dataset):
    def __init__(self, embeds, labels): self.embeds,self.labels = embeds,labels
    def __getitem__(self, idx): return A(self.embeds[idx], self.labels[idx])
    def __len__(self): return len(self.embeds)
    
ds = LatentDataset(embeds, labels.astype(float))
dl = DataLoader(ds, batch_size)
md = ModelData('.', dl, None)

In [23]:
md.trn_ds[0]

[array([[-1.88831, -3.77544,  2.12043, -1.49156,  1.21908, -0.69846,  1.31653,  0.85273, -0.50136, -2.75683,
         -0.68626,  1.46093, -0.27001,  0.59422, -0.24358,  0.71455,  2.37467,  0.38596,  0.42892,  1.29327,
         -0.22243,  0.47079, -0.64109, -1.86893, -1.40863,  1.66425, -0.0186 ,  0.10167,  0.67046,  0.29173,
          3.53054, -1.98305, -0.35301,  3.9853 ,  0.63015,  1.03038, -0.06038, -0.43655,  1.57195,  0.28702,
         -1.00542, -1.09672, -2.6014 ,  0.20153, -0.09539, -0.49441,  0.85762,  0.39356,  1.42792, -0.87912,
          0.60129, -1.10921,  0.26257,  2.55725,  1.715  ,  1.53366, -0.25937, -1.6421 ,  1.37921,  0.68685,
          0.38404, -0.67452,  1.57176,  2.73031,  0.68616,  0.62398, -2.13882,  2.72947, -0.96153,  0.40305,
          0.59293,  3.23975, -0.37111,  1.51339,  0.9162 ,  2.90308, -0.18822, -1.00512, -1.28772,  2.21644,
          0.90739,  1.4183 , -0.07918,  1.46065, -2.43391, -1.44411,  0.54959,  1.65002, -0.36701,  3.50553,
          0.99829, 

## Model

In [24]:
n_hidden = 1024
lr = 3e-4
fixed_genres = torch.FloatTensor(dataset.encode_genres(['neo soul', 'pop', 'r&b', 'urban contemporary'])).to(DEVICE)

In [25]:
import torch.optim as optim
import torch.nn as nn

class LCGAN_D(nn.Module):
    '''Discriminator'''
    def __init__(self, n_embed, n_hidden=n_hidden, n_output=n_topics):
        super(LCGAN_D, self).__init__()
        
        self.i2h = nn.Linear(n_embed, n_hidden)
        self.h2h = nn.Linear(n_hidden, n_hidden)
        self.h2o = nn.Linear(n_hidden, n_output)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, emb):
        x = emb
        x = self.relu(self.i2h(x))
        x = self.relu(self.h2h(x))
        x = self.relu(self.h2h(x))
        v = self.sigmoid(self.h2o(x))
        
        return v

class LCGAN_G(nn.Module):
    '''Generator'''
    def __init__(self, n_embed, n_hidden=n_hidden):
        super(LCGAN_G, self).__init__()
        self.n_embed = n_embed
        
        self.i2h = nn.Linear(n_embed, n_hidden)
        self.h2h = nn.Linear(n_hidden, n_hidden)
        # hidden-to-gating mechanism
        self.h2g = nn.Linear(n_hidden, 2*n_embed)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, emb):
        x = emb
        x = self.relu(self.i2h(x))
        x = self.relu(self.h2h(x))
        x = self.relu(self.h2h(x))
        x = self.h2g(x)
        
        # gating mechanism: allow network to remember/forget
        # what it wants to about the original emb(edding) and x
        emb_mid = x[:, self.n_embed:]
        gates = self.sigmoid(x[:, :self.n_embed])
        demb = gates * emb_mid # TODO: why naming?
        emb_prime = (1 - gates)*emb + demb
        
        return emb_prime

In [26]:
class LCGAN(nn.Module):
    def __init__(self, D, G, batch_size=batch_size):
        super(LCGAN, self).__init__()
        self.batch_size = batch_size
        
        self.D = D
        self.G = G

    def train(self):
        self.D.train()
        self.G.train()
        
    def eval(self):
        self.D.eval()
        self.G.eval()
        
    def forward(self, emb=None):       
        if emb is not None:
            # train discriminator
            #embed = to_embed(z, fixed_genres)
            v = self.D(emb)
            return v
        else: # train GAN
            # gaussian random noise
            emb_prior = torch.randn(self.batch_size, self.G.n_embed).to(DEVICE)
            
            emb_prime = self.G(emb_prior)
            v_prime = self.D(emb_prime)
            
            return v_prime

## Training

In [27]:
gan = LCGAN(LCGAN_D(embed_size).to(DEVICE), LCGAN_G(embed_size).to(DEVICE)).to(DEVICE)

In [28]:
import fastai

fastai.core.set_trainable(gan.D, True)
fastai.core.set_trainable(gan.G, True)

opt_d = optim.Adam(gan.D.parameters(), lr=lr)
opt_g = optim.Adam(gan.G.parameters(), lr=lr)

In [29]:
# test what the GAN is doing before any training
for i in range(10):
    print(generate(fixed_genres, gan)[0])

UNK danish trio artist is back into the world months combined with a track that was featured on nearly two chance
UNK UNK and UNK for the likes UNK UNK that is up on his sound and as well as light effect
according to the UNK UNK UNK UNK at the UNK UNK UNK UNK for the courtesy of UNK light
artist decided to come back with UNK UNK and its about to come stuck in your major which songs magic
it makes for a a UNK UNK UNK continues to take on her vocals and a first time to pop minutes
the duo are all over and and better known and the i trying to get your play above
UNK UNK and some of the UNK and with a UNK which is going on july
this is going to get obsessed with his sound and grew up on the remix and you think of artist s space
UNK is rather is a UNK UNK and i tried to work in a UNK while staying true to my atmosphere
apart from the likes of UNK and UNK and sing returned into the record that we fell in full space


In [30]:
# adapted from: https://github.com/fastai/fastai/blob/master/courses/dl2/wgan.ipynb
def train(n_iter, alternate=False, first=False):
    gen_iters = 0
    for epoch in trange(n_iter):
        gan.train()
        data_iter = iter(md.trn_dl)
        i, n = 0, len(md.trn_dl)
        
        def train_G():
            ''' Train generator '''
            nonlocal gen_iters
            
            fastai.core.set_trainable(gan.D, False)
            fastai.core.set_trainable(gan.G, True)

            gan.G.zero_grad()

            #print(i, n)
            v_prime = gan()
            log_loss_g = torch.log(v_prime)
            loss_g = (-log_loss_g * topic_weights).mean()
            loss_g.backward()
            opt_g.step()
            gen_iters += 1
            
            return loss_g

        def train_D():
            ''' Train discriminator '''
            nonlocal i
            
            fastai.core.set_trainable(gan.D, True)
            fastai.core.set_trainable(gan.G, False)
            d_iters = 100 if (first and (gen_iters < 25) or (gen_iters % 500 == 0)) else 3
            j = 0

            while (j < d_iters) and (i < n):
                j += 1; i += 1
                batch = next(data_iter)
                #print(j, i, batch[0].size(), batch[1].size())
                emb_real = batch[0].to(DEVICE)
                v = gan(emb_real).to(DEVICE)

                gan.D.zero_grad()

                #loss_d = - (batch[1] * torch.log(v) + (1.0-batch[1]) * torch.log(1.0 - v)).mean()
                loss_d = - (batch[1].to(DEVICE) * torch.log(v) + (1.0-batch[1].to(DEVICE)) * torch.log(1.0 - v)).mean()
                loss_d.backward()
                opt_d.step()
                
                pbar.update()
                
            return loss_d
        
        with tqdm(total=n) as pbar:
            while i < n:
                if alternate:
                    # train discriminator
                    loss_d = train_D()
                    # then train generator a little bit
                    loss_g = train_G()
                else:
                    # train generator only
                    i += 1
                    loss_g = train_G()
                    pbar.update()
        
        if alternate:
            print(f'Loss_D {to_np(loss_d)}; Loss_G {to_np(loss_g)}; ')
        else:
            print(f'Loss_G {to_np(loss_g)}; ')

In [31]:
def train_and_generate(gan, n_epoch, genres=fixed_genres, alternate=False, n_sample=10):
    train(n_epoch, alternate)
    res = []
    for i in range(n_sample):
        res.append(generate(genres, gan))
    return res

In [32]:
[res[0] for res in train_and_generate(gan, 1, alternate=True)]

100%|██████████| 13/13 [00:00<00:00, 81.34it/s]
Loss_D 0.5684986710548401; Loss_G -0.02011490799486637; 
100%|██████████| 1/1 [00:00<00:00,  6.17it/s]

['one of an electro pop ago artist is back in the UNK UNK continues to keep an final atmosphere',
 'is some of the slow and and his production but perfect addition to look at the release of UNK again',
 'UNK UNK come along of the self described in a UNK UNK UNK on portland or stand enjoy',
 'UNK which is a lot of artist and falling into something we were talking about her vocal away',
 'such a new tune duo artist she returns with the image courtesy of its called age',
 'from the r spin artist who all addition from the ultra before you might have to play play',
 'artist is a bit UNK UNK in the UNK and UNK on his long if they were even more minutes',
 'UNK a version of the happens to be released on the year and if youre looking for this aptly titled ride',
 'canadian synth UNK UNK is back in some of the artist and impressed us with some with its soulful atmosphere',
 'some about the UNK UNK UNK UNK and get for a song is time to share it earlier today']

In [34]:
[res[0] for res in train_and_generate(gan, 1)]

100%|██████████| 13/13 [00:00<00:00, 49.67it/s]
Loss_G -0.2309029996395111; 
100%|██████████| 1/1 [00:00<00:00,  3.77it/s]

['artist started working with the very well known for tracks together with the very much two well we really well we see you really well see more tracks very amazing',
 'with their brand new electro r b between artist working with the electro r b electro r b working with two tracks really well if you much more tracks we get this little better',
 ' artist artist artist artist artist artist artist always well with the brand new new artist artist always really working between artist artist working between r b voice working with voice really well as well really much really well really get you get really get really get get their new side',
 'artist artist artist artist with a very looking for the the b side with the very little bit more tracks we dont really much much if we will get much more tracks like the tracks',
 'while artist very well well with artist very well with with r b side with tracks very well see how much we really really much we really really much we really get down the past

## Inference with saved models

In [33]:
def save():
    save_state_filename = 'ganG_state.pt'
    torch.save(gan.G.state_dict(), save_state_filename)
    print('Saved as %s' % (save_state_filename))
    
save()

Saved as ganG_state.pt


In [35]:
def gan_generate(vae, condition, n_latent, ganG, max_sample=False, trunc_sample=True):
    with torch.no_grad():
        ganG.eval()
        z = torch.randn(1, n_latent).to(DEVICE)
        decode_embed = to_embed(z, condition).to(DEVICE)
        z_prime = ganG(decode_embed)

        generated = vae.decoder.generate_with_embed(z_prime, 50, temp, DEVICE, max_sample=max_sample, trunc_sample=trunc_sample)
        generated_str = model.float_word_tensor_to_string(output_side, generated)

        EOS_str = f' {output_side.index_to_word(torch.LongTensor([EOS_token]))} '

        if generated_str.endswith(EOS_str):
            generated_str = generated_str[:-5]

        # flip it back
        return generated_str[::-1], z, z_prime

In [36]:
ganG = LCGAN_G(embed_size).to(DEVICE)
ganG.load_state_dict(torch.load('ganG_state.pt'))

gan_generate(vae, torch.FloatTensor(dataset.encode_genres(['hip hop','pop','pop rap','rap','trap music'])).to(DEVICE), n_latent, ganG)

('artist artist are reminiscent of many of our on this UNK the the may just get familiar with each other touch',
 tensor([[ 1.3266,  0.3218,  1.2880, -0.1209, -0.9666, -0.0200,  1.1089,
          -0.5409, -0.0544, -0.7155, -0.4585, -0.7051, -0.2313,  2.9348,
          -2.2418,  0.6079,  1.5914,  0.4511, -1.5066,  0.7279,  1.7946,
          -0.4472,  0.5314, -1.8664,  0.0953, -0.9538,  1.1704,  1.5429,
           0.0830, -0.4723, -1.2089, -1.2786, -0.1234, -0.6102,  0.5797,
           2.2356,  0.0159,  3.0624,  1.3961, -1.2540,  0.0048,  1.8790,
          -0.9674, -0.0236,  1.1817,  0.2974, -1.3599, -0.0940, -0.9623,
           1.3509,  0.3204, -0.2358, -0.8275,  1.4198, -0.8605,  0.4200,
           0.8859, -0.0607, -1.6140,  0.3870, -0.8037,  1.0355, -0.8868,
          -0.0426,  0.4908, -0.1655, -0.5113, -0.3807,  1.9757, -0.6120,
          -0.0254,  1.1762,  0.2535,  0.0041,  0.2726, -0.9798,  0.8222,
           1.4550, -0.2161, -0.0972,  0.2395,  0.0939,  0.1039,  0.7530,
           

# Extras

## Labeling data

### 'Banned' approach

label a sample as -1 (=="bad") if it contains a banned word; label as 1 otherwise (=="good")

In [7]:
#new_labels = np.array([(1, -1), (10, -1)])

banned = ['below']
labels = np.ones(n_samples, dtype=int)
gens_lose = list(set([i for b in banned for i in np.where([b in g.split() for g in gens])[0]]))
labels[gens_lose] = -1
zs_keep = zs

labels

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1, -1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1])

In [8]:
from collections import Counter
c1 = Counter([word for gen in gens for word in gen.split()])
[(b, c1[b]) for b in banned], c1.most_common(25)

([('below', 3)],
 [('the', 94),
  ('UNK', 82),
  ('of', 67),
  ('a', 66),
  ('and', 59),
  ('to', 55),
  ('is', 29),
  ('as', 26),
  ('artist', 26),
  ('with', 26),
  ('from', 24),
  ('it', 21),
  ('this', 21),
  ('track', 20),
  ('on', 20),
  ('up', 19),
  ('one', 18),
  ('be', 15),
  ('in', 14),
  ('has', 14),
  ('at', 11),
  ('i', 10),
  ('been', 10),
  ('trying', 10),
  ('that', 10)])

### 'Realism' approach

label a sample as 1 (=="good") if it came from the training data; label as -1 (=="bad") if it came from a random Gaussian `z`

In [10]:
from pytorchtextvae import model

input, target, condition = model.random_training_set(pairs, input_side, output_side, random_state, DEVICE)
model.long_word_tensor_to_string(input_side, input), dataset.decode_genres(condition)

('newcomer artist released his debut single last week and its already gaining major attention and a following that is demanding more after fill EOS ',
 ['vapor soul'])

In [11]:
temperature = 1.0

m, l, z, decoded = vae(input, target, condition, DEVICE, temperature)

z.size(), decoded.size()

(torch.Size([1, 128]), torch.Size([24, 333336]))

In [12]:
generate(condition, z=z, max_sample=True)[0]

'artist released his debut single and is just released and more than a ago and that that get your attention'

**TODO:** shouldn't generate with max sampling always return the same sample?

Even though the encoding is imperfect, we will still consider these `z`s as "realistic"

In [13]:
real_z = []
real_gens = []
for i in range(int(n_samples/2)):
    input, target, condition = model.random_training_set(pairs, input_side, output_side, random_state, DEVICE)
    with torch.no_grad():
        _, _, z, _ = vae(input, target, condition, DEVICE, temperature)
        real_z.append(z)
        real_gens.append(generate(condition, z=z, max_sample=True)[0])

In [14]:
from collections import Counter
c1 = Counter([word for gen in real_gens for word in gen.split()])
c1.most_common(25)

[('UNK', 106),
 ('and', 53),
 ('the', 50),
 ('a', 30),
 ('of', 30),
 ('on', 28),
 ('to', 24),
 ('in', 15),
 ('is', 15),
 ('i', 14),
 ('with', 14),
 ('w', 13),
 ('even', 11),
 ('that', 10),
 ('for', 10),
 ('who', 10),
 ('this', 9),
 ('always', 9),
 ('it', 9),
 ('tour', 8),
 ('into', 8),
 ('be', 8),
 ('but', 8),
 ('you', 7),
 ('new', 7)]

In [15]:
# up until now, `zs` held random zs - now concat with real zs
zs = torch.cat((torch.stack(real_z).squeeze(), torch.stack(zs[:int(n_samples/2)]).squeeze()))

In [16]:
labels = np.ones(n_samples, dtype=int)
labels[range(len(real_z), len(zs))] = -1
embeds = zs

labels

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

### Realism + Readability
in addition to the realism discriminator, add readability as a conditioning attribute

In [14]:
# https://github.com/shivam5992/textstat/issues/43
from textstat.textstat import textstat

# Score 	 Difficulty
# 90-100 	 Very Easy
# 80-89 	 Easy
# 70-79 	 Fairly Easy
# 60-69 	 Standard
# 50-59 	 Fairly Difficult
# 30-49 	 Difficult
# 0-29 	 Very Confusing

[textstat.flesch_reading_ease(sent) for sent in ["This is a sentence", "To be or not to be", ]]

[92.8, 116.15]

In [26]:
[(' '.join(gen.replace('UNK', '').split()), textstat.text_standard(' '.join(gen.replace('UNK', '').split()))) for gen in np.array(real_gens)[random_state.choice(len(real_gens), 10)]]

[('one of those who dont want to the one of their own who are their own in their own music',
  '4th and 5th grade'),
 ('the remix the production from the original and it sounds like a it with the bass and that sounds like it sounds like it from the chorus',
  '6th and 7th grade'),
 ('we been waiting for the last year while while its soon as waiting for little while we we as we as it',
  '12th and 13th grade'),
 ('one of those who dont want to the one of their own who are their own in their own music',
  '4th and 5th grade'),
 ('we trying to get into an artist is at the one of the and in the song that a part of the is in the trying to into it into a song to artist is just into part into that that or else',
  '14th and 15th grade'),
 ('if youre on the first version of you of know if you hear the version of this is that is to a song',
  '1th and 2th grade'),
 ('we been waiting for the last year while while its soon as waiting for little while we we as we as it',
  '12th and 13th grade'),


## Training

retrain discriminator with new samples... `z_prime`s that the discriminator is still not rejecting strongly enough

In [36]:
labels = np.ones(len(g2), dtype=int)
gens_lose = list(set([i for b in banned for i in np.where([b in res[0].split() for res in g2])[0]]))
#gens_keep = list(set(range(len(g2))) - gens_lose)
labels[gens_lose] = -1
zs_keep = np.array([res[2] for res in g2], dtype=object)

labels        

array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

In [32]:
ds = LatentDataset(zs_keep, labels.astype(float))
dl = DataLoader(ds, batch_size)
md = ModelData('.', dl, None)

In [33]:
[res[0] for res in train_and_generate(gan, 1, alternate=True)]

100%|██████████| 7/7 [00:00<00:00, 284.11it/s]
Loss_D 1.193210244178772; Loss_G 1.1896806955337524; 


['artist has teams up with a a one of the UNK and it would be taken from i don no feat',
 'this with a remix of UNK and trying to take of the UNK theres no feat',
 'the remix of UNK has been trying to take of the UNK theres no feat',
 'UNK is a listen to the UNK and coming side of what to look at feat',
 'a new track called UNK and serves as one of the trying would be ready for at times',
 'artist has up with a UNK and one of the track will be inspired as no feat',
 'artist is back in the likes of UNK and serves as a no times',
 'artist who reminds back to a more stuck on the UNK and theres many feat',
 'he continues to serves as one of the week what i look no feat',
 'with a layers of UNK and and UNK of the perfect trying i look at feat']