In [3]:
# Character level lyrics generation using RNNs (LSTM)
import sys, os, random, string
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import string
from tqdm.autonotebook import tqdm

import CharlyricsDataset
from RNN import RNN
import glob

# ignore warnings
import warnings
warnings.filterwarnings("ignore")

# interactive mode
plt.ion()

from pathlib import Path
from config import config
import utils

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
train_dataset = CharlyricsDataset.CharLyricsDataset(config.DATA.LYRICS, config.TRAIN.MAX_LEN)

In [6]:
len(train_dataset)

53772

In [7]:
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.TRAIN.BATCH_SIZE,
    shuffle=False,
    drop_last=True,
    num_workers=1
)

In [8]:
len(train_loader)

1680

In [9]:
model = RNN(utils.get_total_characters(), config.TRAIN.HIDDEN_SIZE, config.TRAIN.LSTM_N_LAYERS, utils.get_total_characters()).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=config.TRAIN.LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()

In [127]:
for epoch in range(config.TRAIN.EPOCHS):
    model.train()
    tq = tqdm(train_loader, total=len(train_loader), desc=f"Training: Epoch {epoch+1}/{config.TRAIN.EPOCHS}")
    total_loss = 0

    for _, batch in enumerate(tq):
        model.zero_grad()
        input_seq, output_seq = batch

        input_seq = input_seq.to(device)
        output_seq = output_seq.to(device)
        loss = 0

        # vectorize this
        for c in range(config.TRAIN.MAX_LEN):
            output = model(input_seq[:, c])
            loss += loss_fn(output, output_seq[:, c])
            total_loss += loss.item()
            
        # tq.set_postfix(loss=loss.item())
        # batch-gradient-descent
        loss.backward()
        optimizer.step()

    print(f"Epoch {epoch+1}: Total Loss {total_loss/(config.TRAIN.MAX_LEN * len(train_loader))}")

HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 1/10', max=12.0, style=ProgressStyle(desc…


Epoch 1: Total Loss 344.36492004159425


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 2/10', max=12.0, style=ProgressStyle(desc…


Epoch 2: Total Loss 343.9445794124405


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 3/10', max=12.0, style=ProgressStyle(desc…


Epoch 3: Total Loss 343.5620702658097


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 4/10', max=12.0, style=ProgressStyle(desc…


Epoch 4: Total Loss 343.2658296155598


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 5/10', max=12.0, style=ProgressStyle(desc…


Epoch 5: Total Loss 343.24492005017066


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 6/10', max=12.0, style=ProgressStyle(desc…


Epoch 6: Total Loss 343.28807777421343


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 7/10', max=12.0, style=ProgressStyle(desc…


Epoch 7: Total Loss 343.2297369449006


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 8/10', max=12.0, style=ProgressStyle(desc…


Epoch 8: Total Loss 343.16170631491474


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 9/10', max=12.0, style=ProgressStyle(desc…


Epoch 9: Total Loss 343.10349785791504


HBox(children=(FloatProgress(value=0.0, description='Training: Epoch 10/10', max=12.0, style=ProgressStyle(des…

KeyboardInterrupt: 

In [124]:
def generate(prime="B", total_len=300, temp=0.85):
    generated_text = prime
    last_char = prime
    
    for c in range(total_len):
        input_char = torch.LongTensor(utils.char_to_label(last_char)).to(device)
        out = model(input_char)
        top_char = np.argmax(out.detach().cpu())
        predicted = string.printable[top_char]
        generated_text += predicted
        last_char = predicted
    return generated_text

In [125]:
generate("b")

'be t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t t '