https://arxiv.org/abs/1711.05772 / https://arxiv.org/abs/1802.04877

https://github.com/natashamjaques/magenta/blob/affective-reward/magenta/models/affective_reward/latent_gan.py

In [1]:
import torch

print('cuda.is_available:', torch.cuda.is_available())
print(f'available: {torch.cuda.device_count()}; current: {torch.cuda.current_device()}')
DEVICE = torch.device(f'cuda:{torch.cuda.current_device()}' if torch.cuda.is_available() else 'cpu')
print(DEVICE)
print('pytorch', torch.__version__)

cuda.is_available: True
available: 1; current: 0
cuda:0
pytorch 0.4.0


## Labeling data

In [2]:
# pip install git+https://github.com/iconix/pytorch-text-vae.git
from pytorchtextvae import generate

In [3]:
DEVICE = torch.device('cpu') # CPU inference
n_samples = 2000
temp = 0.2

# workaround for un-pickling after module directory change https://stackoverflow.com/a/45264751
#import sys
#sys.path.append('../../pytorch-text-vae/pytorchtextvae')

vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state = generate.load_model('../../pytorch-text-vae/model/best/reviews_and_metadata_5yrs_state.pt', 'reviews_and_metadata_5yrs_stored_info.pkl', DEVICE, cache_path='../../pytorch-text-vae/model/best/tmp')

Fetching cached info at ../../pytorch-text-vae/model/best/tmp/reviews_and_metadata_5yrs_stored_info.pkl
Cache ../../pytorch-text-vae/model/best/tmp/reviews_and_metadata_5yrs_stored_info.pkl loaded (load time: 0.68s)
Found saved model ../../pytorch-text-vae/model/best/reviews_and_metadata_5yrs_state.pt
MAX_SAMPLE: False; TRUNCATED_SAMPLE: True
Trained for 360000 steps (load time: 18.85s)
Setting new random seed


In [4]:
#generate.generate(vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE, genres=['downtempo', 'dream pop', 'indietronica'], num_sample=10, temp=temp)

In [5]:
#gens, zs, conditions = generate.generate(vae, input_side, output_side, pairs, dataset, EMBED_SIZE, random_state, DEVICE, num_sample=n_samples, temp=temp)

In [6]:
#list(zip(range(len(gens)), gens))

In [7]:
def to_embed(z, condition):
    if condition.dim() == 1:
        condition = condition.unsqueeze(0)
    squashed_condition = vae.decoder.c2h(condition)
    return torch.cat([z, squashed_condition], 1)

In [8]:
n_latent = 128
from pytorchtextvae.datasets import EOS_token

def generate(condition, gan=None, z=None, max_sample=False, truncated_sample=True, temp=temp):
    with torch.no_grad():
        if gan is None:
            z_prime = z
        else:
            gan.eval()
            z = torch.randn(1, n_latent).to(DEVICE)
            decode_embed = to_embed(z, condition).to(DEVICE)
            z_prime = gan.G(decode_embed)

        generated = vae.decoder.generate_with_embed(z_prime, 50, temp, DEVICE, max_sample=max_sample, trunc_sample=truncated_sample)
        generated_str = model.float_word_tensor_to_string(output_side, generated)

        EOS_str = f' {output_side.index_to_word(torch.LongTensor([EOS_token]))} '

        if generated_str.endswith(EOS_str):
            generated_str = generated_str[:-5]

        # flip it back
        return generated_str[::-1], z, z_prime

### Topical
Prefer certain topics to others

In [9]:
from pytorchtextvae import datasets

def tokenize(line):
    l = line.strip().lstrip().rstrip()
    l = datasets.normalize_string(l)
    return l.split(' ')

In [10]:
n_examples = 3

sents = [pair[0] for pair in pairs]
texts = [tokenize(sentence) for sentence in sents]
texts[:n_examples]

[['after',
  'featuring',
  'on',
  'jakwobs',
  'fade',
  'late',
  'last',
  'year',
  'and',
  'also',
  'penning',
  'and',
  'co',
  'producing',
  'wretch',
  '32s',
  'no',
  '1',
  'single',
  'dont',
  'go',
  'artist',
  'has',
  'released',
  'her',
  'first',
  'solo',
  'single',
  'ahead',
  'of',
  'her',
  'ep',
  'due',
  'in',
  'spring'],
 ['that',
  'means',
  'doing',
  'everything',
  'ourselves',
  'from',
  'scratch',
  'generic',
  'bringing',
  'the',
  'hand',
  'of',
  'the',
  'artist',
  'back'],
 ['if',
  'you',
  'don',
  't',
  'like',
  'it',
  'at',
  'first',
  'wait',
  'until',
  'the',
  'synths',
  'come',
  'in',
  'during',
  'the',
  'chorus']]

In [11]:
from nltk.corpus import stopwords

# remove stop words and words that appear only once
stoplist = [datasets.normalize_string(word) for word in stopwords.words('english')]
fillerlist = ['author', 'song_title', 'artist', 'sitename']

texts = [[word for word in text if word not in stoplist and word not in fillerlist] for text in texts]
texts[:n_examples]

[['featuring',
  'jakwobs',
  'fade',
  'late',
  'last',
  'year',
  'also',
  'penning',
  'co',
  'producing',
  'wretch',
  '32s',
  '1',
  'single',
  'go',
  'released',
  'first',
  'solo',
  'single',
  'ahead',
  'ep',
  'due',
  'spring'],
 ['means', 'everything', 'scratch', 'generic', 'bringing', 'hand', 'back'],
 ['like', 'first', 'wait', 'synths', 'come', 'chorus']]

In [12]:
from gensim.corpora.dictionary import Dictionary

dictionary = Dictionary(texts)

In [13]:
from gensim.models.ldamodel import LdaModel
import time

start = time.time()
n_topics = 4
passes = 20 # number of passes through documents
iterations = 400 # maximum number of iterations through the corpus when inferring the topic distribution of a corpus.
minimum_probability = 0

corpus = [dictionary.doc2bow(text) for text in texts]
# Train the model on the corpus.
lda = LdaModel(corpus, id2word=dictionary, num_topics=n_topics, iterations=iterations, passes=passes, minimum_probability=minimum_probability)
#lda = LdaModel(corpus, id2word=dictionary, num_topics=n_topics)
print(f'Runtime: {time.time() - start:.2f}s')
lda.print_topics(n_topics)

Runtime: 534.08s


[(0,
  '0.036*"new" + 0.021*"single" + 0.020*"album" + 0.016*"track" + 0.014*"release" + 0.012*"ep" + 0.012*"year" + 0.011*"debut" + 0.011*"first" + 0.011*"released"'),
 (1,
  '0.014*"like" + 0.014*"one" + 0.011*"song" + 0.011*"music" + 0.008*"time" + 0.007*"get" + 0.006*"love" + 0.006*"something" + 0.006*"us" + 0.006*"make"'),
 (2,
  '0.015*"track" + 0.012*"vocals" + 0.012*"pop" + 0.007*"sound" + 0.006*"song" + 0.006*"production" + 0.005*"electronic" + 0.005*"vocal" + 0.005*"dance" + 0.005*"like"'),
 (3,
  '0.009*"tour" + 0.006*"music" + 0.005*"10" + 0.005*"festival" + 0.005*"live" + 0.005*"uk" + 0.004*"remix" + 0.004*"dates" + 0.004*"london" + 0.003*"club"')]

In [14]:
from operator import itemgetter

for i in range(n_examples):
    print(max(lda[corpus[i]],key=itemgetter(1)), datasets.normalize_string(sents[i]))

(0, 0.90399534) after featuring on jakwobs fade late last year and also penning and co producing wretch 32s no 1 single dont go artist has released her first solo single ahead of her ep due in spring
(2, 0.52396774) that means doing everything ourselves from scratch generic bringing the hand of the artist back
(1, 0.56518143) if you don t like it at first wait until the synths come in during the chorus


In [15]:
import pyLDAvis.gensim
pyLDAvis.enable_notebook()

import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

pyLDAvis.gensim.prepare(lda, corpus, dictionary)

In [16]:
from collections import Counter
Counter([max(lda[corpus[i]], key=itemgetter(1))[0] for i in range(len(texts))])

Counter({0: 29749, 2: 29463, 1: 36441, 3: 8814})

In [17]:
out = sorted([f'{max(lda[corpus[i]], key=itemgetter(1))} {datasets.normalize_string(sents[i])}\n' for i in range(len(texts))], reverse=True)
with open('pairs_sentence_topics.txt', 'w') as f:
    f.writelines(out)

^using this file to select desirable/preferred topics for **weights** below.

In [18]:
# 1 if a good topic, -1 if bad, 0 if neutral
topic_weights = torch.tensor([0, -1, 1, -1], dtype=torch.float)

In [19]:
def get_example(i, pairs, input_side, output_side, random_state, device):
    pair = pairs[i]

    inp = model.word_tensor(input_side, pair[0]).to(device)
    target = model.word_tensor(output_side, pair[1]).to(device)
    condition = torch.tensor(pair[2], dtype=torch.float).unsqueeze(0).to(device) if len(pair) == 3 else None

    return inp, target, condition

In [20]:
import numpy as np
import time
from pytorchtextvae import model

labels = np.zeros((n_samples, n_topics), dtype=float)
embeds = []
start = time.time()

# debug vars
ts = []

for i in range(n_samples):
    pair_i = random_state.choice(len(pairs))
    ts.append(pairs[pair_i][0])
    input, target, condition = get_example(pair_i, pairs, input_side, output_side, random_state, DEVICE)
    with torch.no_grad():
        _, _, z, _ = vae(input, target, condition, DEVICE, temp)
        squashed_condition = vae.decoder.c2h(condition)
        decode_embed = torch.cat([z, squashed_condition], 1)
        embeds.append(decode_embed)
    
    labels[i] = [tup[1] for tup in lda[corpus[pair_i]]]

print(f'runtime: {time.time() - start:.2f}s')
list(zip(labels, ts))

runtime: 1433.14s


[(array([0.27701294, 0.67650062, 0.02342723, 0.02305923]),
  'on hearing this latest offering from the man who begat the band thats certainly become more of a truism'),
 (array([0.01691901, 0.01731366, 0.94908559, 0.01668175]),
  'its fluid and evolving beginning with a slow jazzy strum and rising with a funky unrestrained groove that seems content to float about the atmosphere'),
 (array([0.48876944, 0.30449063, 0.17548878, 0.0312512 ]),
  'much in the vein of last years breakout single something about you'),
 (array([0.36397493, 0.07009388, 0.09863776, 0.46729341]),
  'the ep features collaborations with sneaky sound system ra ra riot s wes and lastly x ambassadors who also appeared on eminem s the marshall mathers lp 2'),
 (array([0.17830433, 0.24738024, 0.31226519, 0.26205021]),
  'arms and sleepers are now closer to acts like sun glitters slow magic or tycho than former ambient or post rock experiments'),
 (array([0.04207178, 0.50906318, 0.22618909, 0.2226759 ]),
  'but if he were

## Data

In [21]:
batch_size = 16
embed_size = embeds[0].size(1)

In [22]:
from fastai.dataset import *

class LatentDataset(Dataset):
    def __init__(self, embeds, labels): self.embeds,self.labels = embeds,labels
    def __getitem__(self, idx): return A(self.embeds[idx], self.labels[idx])
    def __len__(self): return len(self.embeds)
    
ds = LatentDataset(embeds, labels.astype(float))
dl = DataLoader(ds, batch_size)
md = ModelData('.', dl, None)

In [23]:
md.trn_ds[0]

[array([[ 3.93861, -0.026  ,  2.31492,  0.82133,  0.16714, -1.47847,  0.96384,  1.79145, -0.75436, -2.18008,
          2.54672,  1.55539, -0.87499,  3.0986 ,  2.02656,  1.90291,  0.40648,  1.95088, -1.61148, -0.22457,
         -0.76677,  0.92869, -0.03142, -0.99057,  0.66427,  1.36929, -2.46057, -0.27369,  2.15403, -0.96158,
         -1.25664, -0.93988,  0.3731 ,  1.90596, -0.66382,  2.47713, -0.78741, -1.63785,  1.21423,  2.629  ,
         -1.36975, -0.18394,  1.30741, -1.12852, -1.11652,  0.03081, -3.13996, -0.0284 ,  3.93696,  0.6328 ,
          1.36508,  4.19587, -1.17954,  0.21503, -1.02109, -0.00846, -2.02669,  0.83514,  2.70851, -2.3371 ,
          2.07963, -0.4968 , -1.3802 , -1.89908, -1.59388,  0.42006, -2.87025, -1.13497,  3.78593, -2.28085,
         -0.59909,  0.57125,  2.3702 , -2.49419, -0.35535, -0.82451, -0.03281,  1.48842,  2.81842, -0.23682,
         -1.80256, -1.56277,  1.65381, -0.80938,  2.38202,  1.01897,  1.15433, -0.00042,  1.21729, -0.08357,
         -0.8742 , 

## Model

In [24]:
n_hidden = 1024
lr = 3e-4
fixed_genres = torch.FloatTensor(dataset.encode_genres(['neo soul', 'pop', 'r&b', 'urban contemporary'])).to(DEVICE)

In [25]:
import torch.optim as optim
import torch.nn as nn

class LCGAN_D(nn.Module):
    '''Discriminator'''
    def __init__(self, n_embed, n_hidden=n_hidden, n_output=n_topics):
        super(LCGAN_D, self).__init__()
        
        self.i2h = nn.Linear(n_embed, n_hidden)
        self.h2h = nn.Linear(n_hidden, n_hidden)
        self.h2o = nn.Linear(n_hidden, n_output)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, emb):
        x = emb
        x = self.relu(self.i2h(x))
        x = self.relu(self.h2h(x))
        x = self.relu(self.h2h(x))
        v = self.sigmoid(self.h2o(x))
        
        return v

class LCGAN_G(nn.Module):
    '''Generator'''
    def __init__(self, n_embed, n_hidden=n_hidden):
        super(LCGAN_G, self).__init__()
        self.n_embed = n_embed
        
        self.i2h = nn.Linear(n_embed, n_hidden)
        self.h2h = nn.Linear(n_hidden, n_hidden)
        # hidden-to-gating mechanism
        self.h2g = nn.Linear(n_hidden, 2*n_embed)
        
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, emb):
        x = emb
        x = self.relu(self.i2h(x))
        x = self.relu(self.h2h(x))
        x = self.relu(self.h2h(x))
        x = self.h2g(x)
        
        # gating mechanism: allow network to remember/forget
        # what it wants to about the original emb(edding) and x
        emb_mid = x[:, self.n_embed:]
        gates = self.sigmoid(x[:, :self.n_embed])
        demb = gates * emb_mid # TODO: why naming?
        emb_prime = (1 - gates)*emb + demb
        
        return emb_prime

In [26]:
class LCGAN(nn.Module):
    def __init__(self, D, G, batch_size=batch_size):
        super(LCGAN, self).__init__()
        self.batch_size = batch_size
        
        self.D = D
        self.G = G

    def train(self):
        self.D.train()
        self.G.train()
        
    def eval(self):
        self.D.eval()
        self.G.eval()
        
    def forward(self, emb=None):       
        if emb is not None:
            # train discriminator
            #embed = to_embed(z, fixed_genres)
            v = self.D(emb)
            return v
        else: # train GAN
            # gaussian random noise
            emb_prior = torch.randn(self.batch_size, self.G.n_embed).to(DEVICE)
            
            emb_prime = self.G(emb_prior)
            v_prime = self.D(emb_prime)
            
            return v_prime

## Training

In [27]:
gan = LCGAN(LCGAN_D(embed_size).to(DEVICE), LCGAN_G(embed_size).to(DEVICE)).to(DEVICE)

In [28]:
import fastai

fastai.core.set_trainable(gan.D, True)
fastai.core.set_trainable(gan.G, True)

opt_d = optim.Adam(gan.D.parameters(), lr=lr)
opt_g = optim.Adam(gan.G.parameters(), lr=lr)

In [29]:
# test what the GAN is doing before any training
for i in range(10):
    print(generate(fixed_genres, gan)[0])

UNK UNK are you of the new more in a link below and also also one of the UNK UNK UNK july
UNK UNK and UNK artist has been back to a north american tour which is back at the kind of star stuff
in a part of danish trio artist returns with the video sing according to this one on my list
an official remix of UNK has teamed up with with such as the song featured for late last mark series
his his UNK UNK UNK and UNK at UNK UNK UNK UNK to keep your eye on my list
he went on the collaborations trio and and in his on on what youd expect from dj set
the UNK has come with a part of my here and listen to the r b side to music today
one thing we listened to combined with a tracks that was well in the vocals and production reminiscent of these instrumentation
with a pretty people in the life and the tried to stand out is feels like the kind of live effect
its the sort the big that are back in UNK in the better known as as quite path


In [30]:
# adapted from: https://github.com/fastai/fastai/blob/master/courses/dl2/wgan.ipynb
def train(n_iter, alternate=False, first=False):
    gen_iters = 0
    for epoch in trange(n_iter):
        gan.train()
        data_iter = iter(md.trn_dl)
        i, n = 0, len(md.trn_dl)
        
        def train_G():
            ''' Train generator '''
            nonlocal gen_iters
            
            fastai.core.set_trainable(gan.D, False)
            fastai.core.set_trainable(gan.G, True)

            gan.G.zero_grad()

            #print(i, n)
            v_prime = gan()
            log_loss_g = torch.log(v_prime)
            loss_g = (-log_loss_g * topic_weights).mean()
            loss_g.backward()
            opt_g.step()
            gen_iters += 1
            
            return loss_g

        def train_D():
            ''' Train discriminator '''
            nonlocal i
            
            fastai.core.set_trainable(gan.D, True)
            fastai.core.set_trainable(gan.G, False)
            d_iters = 100 if (first and (gen_iters < 25) or (gen_iters % 500 == 0)) else 3
            j = 0

            while (j < d_iters) and (i < n):
                j += 1; i += 1
                batch = next(data_iter)
                #print(j, i, batch[0].size(), batch[1].size())
                emb_real = batch[0].to(DEVICE)
                v = gan(emb_real).to(DEVICE)

                gan.D.zero_grad()

                #loss_d = - (batch[1] * torch.log(v) + (1.0-batch[1]) * torch.log(1.0 - v)).mean()
                loss_d = - (batch[1].to(DEVICE) * torch.log(v) + (1.0-batch[1].to(DEVICE)) * torch.log(1.0 - v)).mean()
                loss_d.backward()
                opt_d.step()
                
                pbar.update()
                
            return loss_d
        
        with tqdm(total=n) as pbar:
            while i < n:
                if alternate:
                    # train discriminator
                    loss_d = train_D()
                    # then train generator a little bit
                    loss_g = train_G()
                else:
                    # train generator only
                    i += 1
                    loss_g = train_G()
                    pbar.update()
        
        if alternate:
            print(f'Loss_D {to_np(loss_d)}; Loss_G {to_np(loss_g)}; ')
        else:
            print(f'Loss_G {to_np(loss_g)}; ')

In [31]:
def train_and_generate(gan, n_epoch, genres=fixed_genres, alternate=False, n_sample=10):
    train(n_epoch, alternate)
    res = []
    for i in range(n_sample):
        res.append(generate(genres, gan))
    return res

In [32]:
[res[0] for res in train_and_generate(gan, 1, alternate=True)]

100%|██████████| 125/125 [00:01<00:00, 74.74it/s]
Loss_D 0.5376699566841125; Loss_G -0.3489548861980438; 
100%|██████████| 1/1 [00:01<00:00,  1.23s/it]

['vocals are are yet another full of of with UNK UNK r b and UNK and also the perfect example to it creating an uplifting space',
 'with a a UNK UNK UNK is taking over three but in the tracks that i get lost in and and has a beauty to to it in your work',
 'there are a bit for indie soul infused step in this combined with the neo guitars and soulful vocals and sounds and UNK features the the UNK vocals UNK this mix',
 'about the game right and and in in and and and and and and and and and and and and and and and we havent heard in a sort of disclosure and found on the small studio in san late',
 'we are name for a UNK years of atlantic and and and and it has to look at future beats and soulful and that that are its just feel good vibes',
 'kill j had a a different a collaboration and its clear that its the two of both sound and both and and the sound and this final mark',
 'the fact that it begins with a UNK vocal melodies and deep b beat that seems to it by the deep in artist s atmosp

In [33]:
[res[0] for res in train_and_generate(gan, 1)]

100%|██████████| 125/125 [00:02<00:00, 50.46it/s]
Loss_G -22.094064712524414; 
100%|██████████| 1/1 [00:02<00:00,  2.48s/it]

[' work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work label work',
 ' work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work label work',
 ' work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work label work',
 ' work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work work

## Inference with saved models

In [34]:
def save():
    save_state_filename = 'ganG_state.pt'
    torch.save(gan.G.state_dict(), save_state_filename)
    print('Saved as %s' % (save_state_filename))
    
save()

Saved as ganG_state.pt


In [55]:
def gan_generate(vae, condition, n_latent, ganG, max_sample=False, trunc_sample=True):
    with torch.no_grad():
        ganG.eval()
        z = torch.randn(1, n_latent).to(DEVICE)
        decode_embed = to_embed(z, condition).to(DEVICE)
        z_prime = ganG(decode_embed)

        generated = vae.decoder.generate_with_embed(z_prime, 50, temp, DEVICE, max_sample=max_sample, trunc_sample=trunc_sample)
        generated_str = model.float_word_tensor_to_string(output_side, generated)

        EOS_str = f' {output_side.index_to_word(torch.LongTensor([EOS_token]))} '

        if generated_str.endswith(EOS_str):
            generated_str = generated_str[:-5]

        # flip it back
        return generated_str[::-1], z, z_prime

In [56]:
ganG = LCGAN_G(embed_size).to(DEVICE)
ganG.load_state_dict(torch.load('ganG_state.pt'))

gan_generate(vae, torch.FloatTensor(dataset.encode_genres(['hip hop','pop','pop rap','rap','trap music'])).to(DEVICE), n_latent, ganG)

('a form of UNK and UNK is no wonder throughout a track just interested welcome addition to their music career',
 tensor([[ 0.3426, -0.4434,  0.6537, -1.0167,  0.2749,  0.3956, -1.0323,
           1.9262, -0.2886,  0.0875, -1.1526,  0.5510, -0.6338, -0.5600,
          -1.1005, -0.0089, -0.0115,  0.3233, -1.4833, -0.4620,  0.2851,
          -0.1228, -2.3361, -1.1106,  0.3530,  0.6186, -0.2250, -1.9346,
           0.4420, -1.3540, -0.9850,  1.5157,  1.8120,  0.8877, -1.1195,
           0.1646,  0.2872, -1.1713, -0.0716, -0.2902, -0.3561,  0.2219,
           0.8907, -1.1704, -0.0048, -0.6173,  1.3518,  0.4813, -0.1438,
          -1.4539,  0.2248, -0.2255,  1.8956,  0.4242,  0.5299, -0.6540,
           0.7414, -0.7336, -0.3256,  1.0842,  0.2970, -1.3575, -1.0421,
          -1.7367, -0.6553,  0.8228,  1.5621, -0.5879, -0.6136, -0.1467,
          -0.0047,  1.4218,  1.6615,  0.0985, -0.0368, -0.6737,  0.0706,
          -1.5198, -0.6521, -1.7595,  0.6300,  2.0502,  1.0153, -0.0525,
           

# Extras

## Labeling data

### 'Banned' approach

label a sample as -1 (=="bad") if it contains a banned word; label as 1 otherwise (=="good")

In [7]:
#new_labels = np.array([(1, -1), (10, -1)])

banned = ['below']
labels = np.ones(n_samples, dtype=int)
gens_lose = list(set([i for b in banned for i in np.where([b in g.split() for g in gens])[0]]))
labels[gens_lose] = -1
zs_keep = zs

labels

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1, -1,  1,  1,  1, -1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1,  1,  1,  1])

In [8]:
from collections import Counter
c1 = Counter([word for gen in gens for word in gen.split()])
[(b, c1[b]) for b in banned], c1.most_common(25)

([('below', 3)],
 [('the', 94),
  ('UNK', 82),
  ('of', 67),
  ('a', 66),
  ('and', 59),
  ('to', 55),
  ('is', 29),
  ('as', 26),
  ('artist', 26),
  ('with', 26),
  ('from', 24),
  ('it', 21),
  ('this', 21),
  ('track', 20),
  ('on', 20),
  ('up', 19),
  ('one', 18),
  ('be', 15),
  ('in', 14),
  ('has', 14),
  ('at', 11),
  ('i', 10),
  ('been', 10),
  ('trying', 10),
  ('that', 10)])

### 'Realism' approach

label a sample as 1 (=="good") if it came from the training data; label as -1 (=="bad") if it came from a random Gaussian `z`

In [10]:
from pytorchtextvae import model

input, target, condition = model.random_training_set(pairs, input_side, output_side, random_state, DEVICE)
model.long_word_tensor_to_string(input_side, input), dataset.decode_genres(condition)

('newcomer artist released his debut single last week and its already gaining major attention and a following that is demanding more after fill EOS ',
 ['vapor soul'])

In [11]:
temperature = 1.0

m, l, z, decoded = vae(input, target, condition, DEVICE, temperature)

z.size(), decoded.size()

(torch.Size([1, 128]), torch.Size([24, 333336]))

In [12]:
generate(condition, z=z, max_sample=True)[0]

'artist released his debut single and is just released and more than a ago and that that get your attention'

**TODO:** shouldn't generate with max sampling always return the same sample?

Even though the encoding is imperfect, we will still consider these `z`s as "realistic"

In [13]:
real_z = []
real_gens = []
for i in range(int(n_samples/2)):
    input, target, condition = model.random_training_set(pairs, input_side, output_side, random_state, DEVICE)
    with torch.no_grad():
        _, _, z, _ = vae(input, target, condition, DEVICE, temperature)
        real_z.append(z)
        real_gens.append(generate(condition, z=z, max_sample=True)[0])

In [14]:
from collections import Counter
c1 = Counter([word for gen in real_gens for word in gen.split()])
c1.most_common(25)

[('UNK', 106),
 ('and', 53),
 ('the', 50),
 ('a', 30),
 ('of', 30),
 ('on', 28),
 ('to', 24),
 ('in', 15),
 ('is', 15),
 ('i', 14),
 ('with', 14),
 ('w', 13),
 ('even', 11),
 ('that', 10),
 ('for', 10),
 ('who', 10),
 ('this', 9),
 ('always', 9),
 ('it', 9),
 ('tour', 8),
 ('into', 8),
 ('be', 8),
 ('but', 8),
 ('you', 7),
 ('new', 7)]

In [15]:
# up until now, `zs` held random zs - now concat with real zs
zs = torch.cat((torch.stack(real_z).squeeze(), torch.stack(zs[:int(n_samples/2)]).squeeze()))

In [16]:
labels = np.ones(n_samples, dtype=int)
labels[range(len(real_z), len(zs))] = -1
embeds = zs

labels

array([ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,
        1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1,  1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

### Realism + Readability
in addition to the realism discriminator, add readability as a conditioning attribute

In [14]:
# https://github.com/shivam5992/textstat/issues/43
from textstat.textstat import textstat

# Score 	 Difficulty
# 90-100 	 Very Easy
# 80-89 	 Easy
# 70-79 	 Fairly Easy
# 60-69 	 Standard
# 50-59 	 Fairly Difficult
# 30-49 	 Difficult
# 0-29 	 Very Confusing

[textstat.flesch_reading_ease(sent) for sent in ["This is a sentence", "To be or not to be", ]]

[92.8, 116.15]

In [26]:
[(' '.join(gen.replace('UNK', '').split()), textstat.text_standard(' '.join(gen.replace('UNK', '').split()))) for gen in np.array(real_gens)[random_state.choice(len(real_gens), 10)]]

[('one of those who dont want to the one of their own who are their own in their own music',
  '4th and 5th grade'),
 ('the remix the production from the original and it sounds like a it with the bass and that sounds like it sounds like it from the chorus',
  '6th and 7th grade'),
 ('we been waiting for the last year while while its soon as waiting for little while we we as we as it',
  '12th and 13th grade'),
 ('one of those who dont want to the one of their own who are their own in their own music',
  '4th and 5th grade'),
 ('we trying to get into an artist is at the one of the and in the song that a part of the is in the trying to into it into a song to artist is just into part into that that or else',
  '14th and 15th grade'),
 ('if youre on the first version of you of know if you hear the version of this is that is to a song',
  '1th and 2th grade'),
 ('we been waiting for the last year while while its soon as waiting for little while we we as we as it',
  '12th and 13th grade'),


## Training

retrain discriminator with new samples... `z_prime`s that the discriminator is still not rejecting strongly enough

In [36]:
labels = np.ones(len(g2), dtype=int)
gens_lose = list(set([i for b in banned for i in np.where([b in res[0].split() for res in g2])[0]]))
#gens_keep = list(set(range(len(g2))) - gens_lose)
labels[gens_lose] = -1
zs_keep = np.array([res[2] for res in g2], dtype=object)

labels        

array([-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
       -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,  1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1])

In [32]:
ds = LatentDataset(zs_keep, labels.astype(float))
dl = DataLoader(ds, batch_size)
md = ModelData('.', dl, None)

In [33]:
[res[0] for res in train_and_generate(gan, 1, alternate=True)]

100%|██████████| 7/7 [00:00<00:00, 284.11it/s]
Loss_D 1.193210244178772; Loss_G 1.1896806955337524; 


['artist has teams up with a a one of the UNK and it would be taken from i don no feat',
 'this with a remix of UNK and trying to take of the UNK theres no feat',
 'the remix of UNK has been trying to take of the UNK theres no feat',
 'UNK is a listen to the UNK and coming side of what to look at feat',
 'a new track called UNK and serves as one of the trying would be ready for at times',
 'artist has up with a UNK and one of the track will be inspired as no feat',
 'artist is back in the likes of UNK and serves as a no times',
 'artist who reminds back to a more stuck on the UNK and theres many feat',
 'he continues to serves as one of the week what i look no feat',
 'with a layers of UNK and and UNK of the perfect trying i look at feat']