# Text Classification with RNN

In this notebook, we're going to classify song lyrics by author. As a bonus, we'll build our own lyrics generator!

## Imports 

In [13]:
from pathlib import Path

import pandas as pd

## Dataset Downloading

The lyrics data used in this analysis is taken from the [azlyrics.com](https://www.azlyrics.com) platform. We use a simple [HTML parser](../azlyrics.py) to collect a small subset of the available texts.

> **Disclaimer:** The azlyrics license agreement permits use of their data for educational and personal purposes only. All lyrics used in this notebook are the property of their respective owners.

Each song is saved into a separate text file, and the files are grouped into a folder named after the author. Each author's folder also contains a CSV file that maps each song's file name (a sequential number) to the original song title.

The folder structure used in this analysis looks like this:

In [8]:
!ls -1 ~/data/azlyrics/many

ACDC
Black Sabbath
Creedence Clearwater Revival
Deep Purple
Dio
Grateful Dead
King Crimson
Nazareth
Rainbow
Who


Each author's folder contains a set of numbered `*.txt` files alongside `songs.csv`:

In [9]:
!ls ~/data/azlyrics/many/Rainbow

0.txt	17.txt	24.txt	31.txt	39.txt	46.txt	53.txt	60.txt	68.txt
10.txt	18.txt	25.txt	32.txt	3.txt	47.txt	54.txt	61.txt	69.txt
11.txt	19.txt	26.txt	33.txt	40.txt	48.txt	55.txt	62.txt	6.txt
12.txt	1.txt	27.txt	34.txt	41.txt	49.txt	56.txt	63.txt	7.txt
13.txt	20.txt	28.txt	35.txt	42.txt	4.txt	57.txt	64.txt	8.txt
14.txt	21.txt	29.txt	36.txt	43.txt	50.txt	58.txt	65.txt	9.txt
15.txt	22.txt	2.txt	37.txt	44.txt	51.txt	59.txt	66.txt	songs.csv
16.txt	23.txt	30.txt	38.txt	45.txt	52.txt	5.txt	67.txt
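
The listing above is not in numeric order (`10.txt` appears before `1.txt`), which matters if the files are ever read directly instead of through `songs.csv`. A minimal sketch of numeric ordering follows; the helper name `numbered_lyrics` is hypothetical and not part of the original notebook.

```python
from pathlib import Path


def numbered_lyrics(folder):
    """Return the `<number>.txt` files in `folder`, ordered by their numeric stem.

    Hypothetical helper: files whose stem is not a number (and non-.txt
    files such as songs.csv) are skipped.
    """
    files = [p for p in Path(folder).glob('*.txt') if p.stem.isdigit()]
    # Sort by the integer value of the stem, not by the file name string
    return sorted(files, key=lambda p: int(p.stem))
```

Calling `numbered_lyrics(PATH / 'Rainbow')` would then yield `0.txt`, `1.txt`, `2.txt`, and so on, assuming the folder layout shown above.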


In [16]:
PATH = Path.home() / 'data' / 'azlyrics' / 'many'

In [51]:
def get_songs(author):
    """Gets the list of songs for a specific author"""
    
    records = []
    with open(PATH.joinpath(author, 'songs.csv')) as file:
        for line in file:
            order, _, header = line.strip().partition(',')
            order = int(order)
            record = {'index': order, 'song': header}
            with open(PATH.joinpath(author, f'{order}.txt')) as lyrics:
                record['text'] = lyrics.read()
            records.append(record)
    return pd.DataFrame(records).set_index('index')

In [53]:
dio_songs = get_songs('Dio')

In [54]:
dio_songs.head()

index,song,text
0,Stand Up And Shout,It's the same old song\nyou gotta be somewhere...
1,Holy Diver,Holy Diver\nYou've been down too long in the m...
2,Gypsy,Yeah gypsy\nshe was straight from home\nbut yo...
3,Caught In The Middle,Looking inside of yourself\nyou might see some...
4,Don't Talk To Strangers,Don't talk to strangers hmm hmm hmm hmm hmm hm...
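
Since the goal is to classify lyrics by author, the per-author frames eventually need to be combined into one labelled table. The sketch below is one possible way to do that, assuming every sub-folder follows the layout described earlier (a header-less `songs.csv` plus numbered `.txt` files); the function name `load_authors` and the `author` column are illustrative assumptions, not part of the original notebook.

```python
from pathlib import Path

import pandas as pd


def load_authors(root):
    """Collect the songs of every author folder under `root` into one DataFrame.

    Hypothetical helper: the folder name is used as the class label.
    """
    records = []
    for folder in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        with open(folder / 'songs.csv') as file:
            for line in file:
                if not line.strip():
                    continue  # tolerate blank lines
                # Split on the first comma only, so titles may contain commas
                order, _, title = line.strip().partition(',')
                text = (folder / f'{int(order)}.txt').read_text()
                records.append({'author': folder.name, 'song': title, 'text': text})
    return pd.DataFrame(records)
```

With the data in place, `load_authors(PATH).author.value_counts()` would show how many songs each author contributes, which is worth checking for class imbalance before training.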


In [60]:
print(dio_songs.loc[0].text)

It's the same old song
you gotta be somewhere at sometime
and they'll never let you fly
It's like broken glass
you get cut before you see it
so open up your eyes

You've got desire
so let it out
you've got the power
stand up and shout
shout, shout, stand up and shout

You got wings of steel
but they never really move you
you only seem to crawl
You've been nailed to the wheel
but never really turning
you know you've got to work it out

You've got desire
so let it out
you've got the power
stand up and shout
shout, shout, stand up and shout

Let it out

You are the strongest chain
and you're not just some reflection
so never hide again
You are the driver
you own the road
you are the fire -- go on, explode

Let it out

Stand up and shout



## Dataset Preparation

In [1]:
from os.path import join, expanduser, exists
from urllib.error import URLError
from urllib.request import urlopen

In [2]:
import numpy as np

In [3]:
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchtext import vocab, data

In [4]:
PATH = expanduser(join('~', 'data', 'fastai', 'nietzsche', 'nietzsche.txt'))

In [5]:
def set_random_seed(state=1):
    gens = (np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for set_state in gens:
        set_state(state)

In [6]:
RANDOM_STATE = 1
set_random_seed(RANDOM_STATE)

## Dataset Downloading

In [7]:
def download(url, download_path, expected_size):
    if exists(download_path):
        print('The file was already downloaded')
        return
    
    try:
        r = urlopen(url)
    except URLError as e:
        print(f'Cannot download the data. Error: {e}')
        return
    
    if r.status != 200:
        print(f'HTTP Error: {r.status}')
        return
    
    data = r.read()
    if len(data) != expected_size:
        print(f'Invalid downloaded array size: {len(data)}')
        return
    
    text = data.decode(encoding='utf-8')
    with open(download_path, 'w') as file:
        file.write(text)
        
    print(f'Downloaded: {download_path}')

In [8]:
URL = 'https://s3.amazonaws.com/text-datasets/nietzsche.txt'

In [9]:
download(URL, PATH, 600901)

The file was already downloaded


In [10]:
def split(path, train_size=0.8):
    with open(path) as file:
        content = file.read()
    n = int(len(content) * train_size)
    return content[:n], content[n:]

In [11]:
train_text, valid_text = split(PATH)
print(len(train_text))
print(len(valid_text))

480714
120179


In [12]:
text = train_text + valid_text
chars = sorted(set(text))
vocab_size = len(chars) + 1  # +1 for the '\0' padding character added below
print(f'Vocab size: {vocab_size}')

Vocab size: 85


In [13]:
chars.insert(0, '\0')

In [14]:
char_to_index = {c: i for i, c in enumerate(chars)}
index_to_char = {i: c for i, c in enumerate(chars)}
train_indicies = [char_to_index[char] for char in train_text]
valid_indicies = [char_to_index[char] for char in valid_text]
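
To make the character-level encoding above concrete, here is a small self-contained sketch of the same idea on a toy string: index 0 is reserved for the `'\0'` padding character, and every other character gets an index by sorted order. The helper names `build_vocab`, `encode`, and `decode` are hypothetical and only serve this illustration.

```python
def build_vocab(text, pad='\0'):
    """Build char<->index mappings, reserving index 0 for the padding char."""
    chars = [pad] + sorted(set(text) - {pad})
    char_to_index = {c: i for i, c in enumerate(chars)}
    index_to_char = {i: c for c, i in char_to_index.items()}
    return char_to_index, index_to_char


def encode(text, char_to_index):
    """Map each character to its integer index."""
    return [char_to_index[c] for c in text]


def decode(indices, index_to_char):
    """Map integer indices back to a string."""
    return ''.join(index_to_char[i] for i in indices)


sample = 'hello world'
c2i, i2c = build_vocab(sample)
encoded = encode(sample, c2i)
# Encoding followed by decoding should return the original text
assert decode(encoded, i2c) == sample
```

The round trip only works for characters seen when the vocabulary was built, which is why the notebook builds `chars` from the training and validation text together.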

## Dataset Preparation