In [31]:
import torch
from torch import nn
import pandas as pd
from collections import Counter
import argparse
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np

In [3]:
# from https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.lstm_size,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2, # adjust hyperparam?
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        return (torch.zeros(self.num_layers, sequence_length, self.lstm_size),
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))

In [20]:
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        args,
        data
    ):
        self.args = args
        self.words = self.load_words(data)
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self, data):
        train_df = pd.read_csv(data)
        text = train_df['Lyric'].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.args.sequence_length

    def __getitem__(self, index):
        return (
            torch.tensor(self.words_indexes[index:index+self.args.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.args.sequence_length+1]),
        )

In [34]:
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
args,u = parser.parse_known_args()

In [35]:
data = 'ArianaGrande.csv'
dataset = Dataset(args,data)
model = Model(dataset)

In [29]:
def train(dataset, model, args):
    model.train()

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(args.max_epochs):
        state_h, state_c = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            loss = criterion(y_pred.transpose(1, 2), y)

            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

def predict(dataset, model, text, next_words=100):
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))

        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])

    return words

In [36]:
train(dataset,model,args)

{'epoch': 0, 'batch': 0, 'loss': 8.41429328918457}
{'epoch': 0, 'batch': 1, 'loss': 8.400004386901855}
{'epoch': 0, 'batch': 2, 'loss': 8.409750938415527}
{'epoch': 0, 'batch': 3, 'loss': 8.392723083496094}
{'epoch': 0, 'batch': 4, 'loss': 8.380059242248535}
{'epoch': 0, 'batch': 5, 'loss': 8.365432739257812}
{'epoch': 0, 'batch': 6, 'loss': 8.38187313079834}
{'epoch': 0, 'batch': 7, 'loss': 8.370115280151367}
{'epoch': 0, 'batch': 8, 'loss': 8.322590827941895}
{'epoch': 0, 'batch': 9, 'loss': 8.29200267791748}
{'epoch': 0, 'batch': 10, 'loss': 8.266642570495605}
{'epoch': 0, 'batch': 11, 'loss': 8.177728652954102}
{'epoch': 0, 'batch': 12, 'loss': 8.062816619873047}
{'epoch': 0, 'batch': 13, 'loss': 7.9323906898498535}
{'epoch': 0, 'batch': 14, 'loss': 7.781162738800049}
{'epoch': 0, 'batch': 15, 'loss': 7.612934112548828}
{'epoch': 0, 'batch': 16, 'loss': 7.5601935386657715}
{'epoch': 0, 'batch': 17, 'loss': 7.267599582672119}
{'epoch': 0, 'batch': 18, 'loss': 7.019326210021973}
{'ep

['this',
 'is',
 'sample',
 'frightens',
 'blow',
 'to',
 'wanna',
 'make',
 'up',
 "thinkin'",
 'shit',
 'and',
 'you',
 'want',
 "i'm",
 'not',
 'looooooolove',
 'you',
 'the',
 'time',
 'and',
 '',
 'sweetener',
 'problem',
 'ariana',
 'ladies',
 'on',
 'the',
 'little',
 'conversation',
 "don't",
 'catch',
 'this',
 'sad',
 'alright',
 "you're",
 'been',
 'that',
 'and',
 'come',
 'kiss',
 'come',
 "somethin'",
 'in',
 'my',
 'lot',
 'i',
 'need',
 'me',
 'you',
 'say',
 'focus',
 'cause',
 "that's",
 'all',
 "you're",
 'up',
 'but',
 'i',
 'well',
 'when',
 'i',
 'love',
 'you',
 'i',
 "can't",
 'tell',
 'a',
 'more',
 'and',
 'we',
 'been',
 "cravin'",
 'and',
 'all',
 'is',
 'broke',
 'the',
 'minaj',
 "i've",
 'always',
 "beamin'",
 'you',
 "ain't",
 'worried',
 "'em",
 'away',
 'to',
 'been',
 "breathin'",
 'is',
 'my',
 'heart',
 'is',
 'ot',
 'forever',
 'too',
 'more',
 'not',
 'picture',
 'it',
 'like',
 'what']

In [40]:
predict(dataset,model,"and i am")

['and',
 'i',
 'am',
 'loosen',
 'lust',
 'left',
 "it's",
 'touch',
 'that',
 'no',
 'touch',
 'you',
 'all',
 'but',
 'you',
 'he',
 'do',
 'up',
 'mine',
 'baby',
 'a',
 'only',
 'one',
 'i',
 'should',
 'slip',
 'know',
 'it',
 'got',
 'the',
 'clouds',
 "don't",
 'do',
 'everything',
 'all',
 'is',
 'maybe',
 'you',
 'was',
 'here',
 'so',
 'from',
 'my',
 'phone',
 'about',
 'no',
 'makes',
 'you',
 'just',
 'ya',
 'boy',
 'i',
 'just',
 'felt',
 'say',
 "i'm",
 'not',
 'happen',
 'nobody',
 'from',
 'me',
 'i',
 'feel',
 'saying',
 'love',
 "i'm",
 'tell',
 'you',
 'next',
 'and',
 'you',
 'did',
 'it',
 'oh',
 'yeah',
 'you',
 'love',
 'me',
 'rolling',
 'and',
 "you're",
 'honeymoon',
 'of',
 'him',
 '',
 'pre',
 'grande',
 'and',
 'next',
 'us',
 'i',
 'will',
 'know',
 'i',
 'like',
 'it',
 'if',
 'i',
 'really',
 'make',
 'your',
 'hands',
 'let']