In [2]:
# Import packages
import torch
from torch import nn
import pandas as pd
from collections import Counter
import argparse
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# from https://www.kdnuggets.com/2020/07/pytorch-lstm-text-generation-tutorial.html

class Model(nn.Module):
    """Word-level LSTM language model: embedding -> stacked LSTM -> linear layer over the vocabulary.

    Args:
        dataset: object exposing ``uniq_words``; its length sets the vocabulary size of both
            the embedding table and the output layer.
        lstm_size: hidden size of each LSTM layer.
        embedding_dim: width of the word embeddings fed into the LSTM.
        num_layers: number of stacked LSTM layers.
        dropout: dropout applied between stacked LSTM layers.

    NOTE(review): the LSTM uses the default ``batch_first=False``, while ``train`` feeds
    DataLoader batches shaped ``(batch, seq)``. The LSTM therefore treats the first axis as
    time and the second as batch, which is why ``init_state`` is sized by sequence length.
    Kept unchanged so existing callers still work — confirm this layout is intended.
    """

    def __init__(self, dataset, lstm_size=128, embedding_dim=128, num_layers=3, dropout=0.2):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        n_vocab = len(dataset.uniq_words)
        self.embedding = nn.Embedding(
            num_embeddings=n_vocab,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes the embeddings, so its input width is the embedding width.
            # (Previously this was lstm_size, which only worked because both were 128.)
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, n_vocab)

    def forward(self, x, prev_state):
        """Map word indices ``x`` to vocabulary logits.

        Returns ``(logits, state)``: ``logits`` has the shape of ``x`` plus a trailing
        vocabulary axis, and ``state`` is the LSTM's final ``(h, c)`` for the next call.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return zero ``(h0, c0)`` tensors of shape ``(num_layers, sequence_length, lstm_size)``.

        They are allocated with the same device and dtype as the model's parameters, so the
        state still matches after ``model.to(device)``.
        """
        weight = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (weight.new_zeros(shape), weight.new_zeros(shape))

In [4]:
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from the ``Lyric`` column of a CSV.

    Item ``i`` is a pair ``(x, y)`` of 1-D index tensors of length ``args.sequence_length``:
    ``x`` holds words ``i .. i+L-1`` and ``y`` holds the same window shifted one word right,
    so ``y[t]`` is the word that follows ``x[t]``.

    Args:
        args: namespace providing ``sequence_length``.
        data: anything ``pandas.read_csv`` accepts (path or file-like) with a ``Lyric`` column.
    """

    def __init__(self, args, data):
        self.args = args
        self.words = self.load_words(data)
        # Vocabulary ordered by descending frequency, so index 0 is the most common word.
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load_words(self, data):
        """Read the CSV, print its head, and return the selected lyrics split on single spaces.

        NOTE(review): rows are taken by position 1..999, so the first lyric (row 0) is skipped.
        Confirm this is intended rather than an off-by-one for "first 1000 rows".
        Splitting on ``' '`` (not ``str.split()``) keeps the empty-string tokens produced by
        repeated spaces. The sampling cell prompts with ``""`` and presumably relies on that.
        """
        train_df = pd.read_csv(data)
        print(train_df.head())
        # .iloc makes the positional slice explicit; str.cat omits NaN lyrics.
        text = train_df['Lyric'].iloc[1:1000].str.cat(sep=' ')
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct words sorted by descending count (ties keep first-seen order)."""
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Number of full ``(x, y)`` windows; 0 when the corpus is too short for even one."""
        # Clamp at 0: len() raises ValueError if __len__ returns a negative number.
        return max(0, len(self.words_indexes) - self.args.sequence_length)

    def __getitem__(self, index):
        """Return ``(x, y)`` index tensors, with ``y`` being ``x`` shifted right by one word."""
        seq_len = self.args.sequence_length
        return (
            torch.tensor(self.words_indexes[index:index + seq_len]),
            torch.tensor(self.words_indexes[index + 1:index + seq_len + 1]),
        )

In [5]:
parser = argparse.ArgumentParser()
parser.add_argument('--max-epochs', type=int, default=10)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--sequence-length', type=int, default=4)
parser.add_argument('--seed', type=int, default=42)
# parse_known_args tolerates argv entries this parser does not define (they land in `u`),
# presumably so the Jupyter kernel's own command-line flags don't cause an error.
args, u = parser.parse_known_args()

# Seed every RNG the notebook draws from: torch for weight init and dropout, and NumPy for
# np.random.choice during sampling. This makes Restart & Run All reproducible.
torch.manual_seed(args.seed)
np.random.seed(args.seed)

In [6]:
# Build the vocabulary and sliding-window samples from the cleaned lyrics CSV, then
# size the model's embedding and output layers to that vocabulary.
data = 'data/country_data_cleaned_spaces.csv'
dataset = Dataset(args=args, data=data)
model = Model(dataset=dataset)

                                               Lyric
0  o death where is thy sting \n o grave where is...
1  they used to call me lightning \n i was always...
2  you were in college working parttime waiting t...
3  he was born in the summer of his 27th year \n ...
4  a shimmy shimmy go go motherfucking pop bitch ...


In [7]:
def train(dataset, model, args, lr=0.001, log_every=1):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy for ``args.max_epochs`` epochs.

    Args:
        dataset: map-style dataset yielding ``(x, y)`` index tensors of equal length.
        model: module providing ``init_state(n)`` and ``forward(x, (h, c)) -> (logits, (h, c))``.
        args: namespace providing ``batch_size``, ``max_epochs`` and ``sequence_length``.
        lr: Adam learning rate (defaults to the previously hard-coded 0.001).
        log_every: print the loss dict every ``log_every`` batches; the default 1 logs every batch.

    Raises:
        ValueError: if ``log_every`` is less than 1.
    """
    if log_every < 1:
        raise ValueError(f"log_every must be >= 1, got {log_every}")

    model.train()

    dataloader = DataLoader(dataset, batch_size=args.batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(args.max_epochs):
        # Fresh zero state each epoch; within the epoch it is carried from batch to batch.
        state_h, state_c = model.init_state(args.sequence_length)

        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x, (state_h, state_c))
            # CrossEntropyLoss expects the class axis second, but logits carry it last.
            loss = criterion(y_pred.transpose(1, 2), y)

            # Detach the carried state so backward() does not reach into previous batches.
            state_h = state_h.detach()
            state_c = state_c.detach()

            loss.backward()
            optimizer.step()

            if batch % log_every == 0:
                print({ 'epoch': epoch, 'batch': batch, 'loss': loss.item() })

def predict(dataset, model, text, next_words=100):
    """Sample ``next_words`` words after the space-separated prompt ``text``.

    Args:
        dataset: provides ``word_to_index`` and ``index_to_word`` lookups.
        model: module with ``init_state(n)`` and ``forward(x, (h, c)) -> (logits, (h, c))``.
        text: prompt; split on single spaces, and every piece must be in the vocabulary.
        next_words: number of words to sample.

    Returns:
        List of the prompt words followed by the sampled words.

    Raises:
        KeyError: if a prompt word is not in ``dataset.word_to_index``.

    Sampling draws from the global NumPy RNG, so seed ``np.random`` for reproducible output.
    The model is left in eval mode.
    """
    model.eval()

    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))

    # Inference only: don't build an autograd graph for every generated word.
    with torch.no_grad():
        for i in range(next_words):
            # One word is appended per step, so words[i:] always has the prompt's length:
            # a sliding window ending at the newest word.
            x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
            y_pred, (state_h, state_c) = model(x, (state_h, state_c))

            last_word_logits = y_pred[0][-1]
            p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
            word_index = np.random.choice(len(last_word_logits), p=p)
            words.append(dataset.index_to_word[word_index])

    return words

In [8]:
# Fit the model in place; train() prints a loss dict as it goes.
train(dataset=dataset, model=model, args=args)

{'epoch': 0, 'batch': 0, 'loss': 9.393013000488281}
{'epoch': 0, 'batch': 1, 'loss': 9.375574111938477}
{'epoch': 0, 'batch': 2, 'loss': 9.366560935974121}
{'epoch': 0, 'batch': 3, 'loss': 9.349126815795898}
{'epoch': 0, 'batch': 4, 'loss': 9.341114044189453}
{'epoch': 0, 'batch': 5, 'loss': 9.323019027709961}
{'epoch': 0, 'batch': 6, 'loss': 9.303452491760254}
{'epoch': 0, 'batch': 7, 'loss': 9.272867202758789}
{'epoch': 0, 'batch': 8, 'loss': 9.197030067443848}
{'epoch': 0, 'batch': 9, 'loss': 9.049039840698242}
{'epoch': 0, 'batch': 10, 'loss': 8.871219635009766}
{'epoch': 0, 'batch': 11, 'loss': 8.761839866638184}
{'epoch': 0, 'batch': 12, 'loss': 8.575411796569824}
{'epoch': 0, 'batch': 13, 'loss': 8.887012481689453}
{'epoch': 0, 'batch': 14, 'loss': 8.371642112731934}
{'epoch': 0, 'batch': 15, 'loss': 8.189277648925781}
{'epoch': 0, 'batch': 16, 'loss': 7.999395370483398}
{'epoch': 0, 'batch': 17, 'loss': 7.840645790100098}
{'epoch': 0, 'batch': 18, 'loss': 7.637129783630371}
{'e

In [18]:
# Draw 25 samples from an empty prompt; [1:] drops the "" prompt token before joining.
samples = pd.DataFrame(
    [' '.join(predict(dataset, model, "")[1:]) for _ in range(25)]
)
samples.to_csv("example.csv", encoding="utf-8")