In [46]:
# Character level lyrics generation using RNNs (LSTM)
import sys, os, random, string
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import string
from tqdm.autonotebook import tqdm

import CharlyricsDataset
from RNN import RNN
import glob

# ignore warnings
import warnings
warnings.filterwarnings("ignore")

# interactive mode
plt.ion()

from pathlib import Path
from config import config
import utils

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [48]:
train_dataset = CharlyricsDataset.CharLyricsDataset(config.DATA.LYRICS, config.TRAIN.MAX_LEN)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.TRAIN.BATCH_SIZE,
    shuffle=False,
    drop_last=True,
    num_workers=1
)

In [49]:
model = RNN(utils.get_total_characters(), config.TRAIN.HIDDEN_SIZE, config.TRAIN.LSTM_N_LAYERS, utils.get_total_characters()).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=config.TRAIN.LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

In [56]:
for epoch in range(50):
    model.train()
    # tq = tqdm(train_loader, total=len(train_loader), desc=f"Training: Epoch {epoch+1}/{config.TRAIN.EPOCHS}")
    total_loss = 0

    for _, batch in enumerate(train_loader):
        model.zero_grad()
        input_seq, output_seq = batch

        input_seq = input_seq.to(device)
        output_seq = output_seq.to(device)
        loss = 0

        # vectorize this
        for c in range(config.TRAIN.MAX_LEN):
            output = model(input_seq[:, c])
            loss += loss_fn(output, output_seq[:, c])
            total_loss += loss.item()
            
        # tq.set_postfix(loss=loss.item())
        # batch-gradient-descent
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: Total Loss {total_loss/(config.TRAIN.MAX_LEN * len(train_loader))}")

Epoch 1: Total Loss 115.031648048684
Epoch 2: Total Loss 114.83985180836171
Epoch 3: Total Loss 114.684381830208
Epoch 4: Total Loss 114.55340315427631
Epoch 5: Total Loss 114.46677482873201
Epoch 6: Total Loss 114.38979147985577
Epoch 7: Total Loss 114.29749544911087
Epoch 8: Total Loss 114.23054692912847
Epoch 9: Total Loss 114.19710261147469
Epoch 10: Total Loss 114.10984837304801
Epoch 11: Total Loss 114.06276754789054
Epoch 12: Total Loss 114.01070972368122
Epoch 13: Total Loss 113.96171462997795
Epoch 14: Total Loss 113.91296009432524
Epoch 15: Total Loss 113.87741595748813
Epoch 16: Total Loss 113.83976903643459
Epoch 17: Total Loss 113.79971875846385
Epoch 18: Total Loss 113.77591032791882
Epoch 19: Total Loss 113.73835018601268
Epoch 20: Total Loss 113.70337689045817
Epoch 21: Total Loss 113.67475814718753
Epoch 22: Total Loss 113.64347309038043
Epoch 23: Total Loss 113.619046921283
Epoch 24: Total Loss 113.57928955748677
Epoch 25: Total Loss 113.56409799948335
Epoch 26: Total

In [57]:
def generate(prime="B", total_len=300, temp=0.85):
    generated_text = prime
    last_char = prime
    
    for c in range(total_len):
        input_char = torch.LongTensor(utils.char_to_label(last_char)).to(device)
        out = model(input_char)
        top_char = np.argmax(out.detach().cpu())
        predicted = string.printable[top_char]
        generated_text += predicted
        last_char = predicted
    return generated_text

In [59]:
generate("a")

're the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the the th'