# pytorch for generating music reviews

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

print('cuda.is_available:', torch.cuda.is_available())
DEVICE = torch.device('cuda')
print(DEVICE) # TODO: make nb runnable on CPU too

cuda.is_available: True
cuda


## Data

In [2]:
import numpy as np
import os
import pandas as pd
from sklearn.model_selection import train_test_split

BASE_DIR = os.getcwd()
DATA_DIR = os.path.join(BASE_DIR, '..', 'datasets')

BLOG_CONTENT_FILE = os.path.join(DATA_DIR, f'blog_content_en_sample.json')
BLOG_CONTENT_DF = pd.read_json(BLOG_CONTENT_FILE)
print(f'total word_count: {sum(BLOG_CONTENT_DF.word_count)}; char_count: {sum([len(w) for w in BLOG_CONTENT_DF.content])}')
BLOG_CONTENT_DF.head().content

total word_count: 241026; char_count: 1417998


0    New Music\n\nMt. Joy reached out to us with th...
1    Folk rockers Mt. Joy have debuted their new so...
2    You know we're digging Mt. Joy.\n\nTheir new s...
3    Nothing against the profession, but the U.S. h...
4    Connecticut duo **Opia** have released a guita...
Name: content, dtype: object

In [3]:
TRAIN_DF, TEST_DF = train_test_split(BLOG_CONTENT_DF, test_size=0.2, random_state=42)
TRAIN_TEXT, TEST_TEXT = TRAIN_DF.content, TEST_DF.content
print(f'train_text word_count: {sum([len(t) for t in TRAIN_TEXT])}; test_text word_count: {sum([len(t) for t in TEST_TEXT])}')

train_text word_count: 1113633; test_text word_count: 304365


Create inputs...

In [4]:
BPTT = 4 # like the 'n' in n-gram, or order
BS = 512 # batch size
EPOCHS = 5
N_FAC = 42 # number of latent factors
N_HIDDEN = 128

In [5]:
def pad_start(bptt):
    return '\0' * bptt

In [6]:
def create_inputs(texts_arr, print_info=False):
    # shuffle inputs
    texts_arr = texts_arr.sample(frac=1).reset_index(drop=True)
    
    # pad each new text with leading '\0' so that we learn how to start
    # also, lowercase
    texts = ''.join([pad_start(BPTT) + text.lower() for text in texts_arr])

    chars = sorted(list(set(texts)))
    vocab_size = len(chars)
    if print_info:
        print('vocab_size:', vocab_size)
        print(chars)
        print()

    char_to_idx = {c: i for i, c in enumerate(chars)}
    idx_to_char = {i: c for i, c in enumerate(chars)}

    idx = [char_to_idx[text] for text in texts]    
    return idx, vocab_size, char_to_idx, idx_to_char

_, VOCAB_SIZE, _, _ = create_inputs(TRAIN_TEXT, True)

vocab_size: 70
['\x00', '\n', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~']



In [7]:
import math
import time

def time_since(since):
    now = time.time()
    s = now - since
    m = math.floor(s / 60)
    s -= m * 60
    return f'{m}m {s:.0f}s'

In [8]:
# https://github.com/fastai/fastai/blob/master/fastai/nlp.py
# TODO: generator
def batchify(data, bs):
    if bs == 1:
        return torch.tensor([[data[i+o] for i in range(len(data)-BPTT-1)] for o in range(BPTT+1)], dtype=torch.long, device=DEVICE)
    else:
        num = data.size(0) // bs
        data = data[:num*bs]
        # invalid argument 2: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Call .contiguous() before .view().
        return data.view(bs, -1).t().contiguous()
    

def get_batch(data, i, seq_len):
    seq_len = min(seq_len, len(data) - 1 - i)
    return data[i:i+seq_len].cuda(), data[i+1:i+1+seq_len].view(-1).cuda()

In [9]:
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
    
def plot_loss(losses):
    %matplotlib inline
    plt.figure()
    plt.plot(all_losses)

##  with n-grams

Another [n-gram music reviews](https://github.com/iconix/openai/blob/master/nbs/n-gram%20music%20reviews.ipynb) model, implemented this time in PyTorch.

**TODO**: differences in models

Guiding PyTorch tutorial: [An Example: N-Gram Language Modeling](https://pytorch.org/tutorials/beginner/nlp/word_embeddings_tutorial.html#an-example-n-gram-language-modeling)

In [10]:
# TODO: draw computational graph
class NGramLanguageModel(nn.Module):
    
    def __init__(self, vocab_size, hidden_size, n_fac, bptt):
        super(NGramLanguageModel, self).__init__()
        
        self.embedding = nn.Embedding(vocab_size, n_fac)
        self.linear1 = nn.Linear(bptt * n_fac, hidden_size)
        self.linear2 = nn.Linear(hidden_size, vocab_size)
        
    def forward(self, inputs):
        inputs = self.embedding(inputs).view((1, -1))
        out = F.relu(self.linear1(inputs))
        out = self.linear2(out)
        return out

In [11]:
def batch_train(model, batches, optimizer, criterion=nn.CrossEntropyLoss(), bptt=BPTT):
    model.zero_grad()
    loss = 0
    
    for i in range(batches.size(0) - bptt):
        xs, ys = get_batch(batches, i, bptt)
        output = model(xs)
        loss += criterion(output, ys)
        
    loss.backward()
    if optimizer:
        optimizer.step()
    
    return loss.item() / (batches.size(0) - bptt)

def batchless_train(model, batches, optimizer, start, criterion=nn.CrossEntropyLoss(), bptt=BPTT):
    xs = np.stack(batches[:-1], axis=1) # history
    ys = np.stack(batches[-1:][0]) # target

    for i in range(xs.shape[0]):
        model.zero_grad()
        output = model(torch.tensor(xs[i], dtype=torch.long, device=DEVICE))

        loss = criterion(output, torch.tensor([ys[i]], dtype=torch.long, device=DEVICE))
        
        loss.backward()
        if optimizer:
            optimizer.step()
        
        print_every = 5000
        if i % print_every == 0:
            print(f'{time_since(start)} ({i} {i / xs.shape[0] * 100:.2f}%) {loss:.4f}')
    
    return loss.item()

In [12]:
def sample(model, char_to_idx, idx_to_char, seed=pad_start(BPTT), max_length=20, bptt=BPTT, sample=True):
    with torch.no_grad(): # no need to track history in sampling
        output_idx = [char_to_idx[c] for c in seed[-bptt:]]

        for i in range(max_length):
            h_idxs = torch.tensor(output_idx[-bptt:], dtype=torch.long, device=DEVICE).view(-1, 1)
            output = model(h_idxs.transpose(0,1))
            if sample:
                # sample from distribution
                idx = torch.multinomial(output[-1].exp(), 1).item()
            else:
                # get most probable
                topi = output.topk(1)[1]
                idx = topi[0][0]
            if idx == 0:
                break
            else:
                output_idx.append(idx)

        sample_text = ''.join([idx_to_char[i] for i in output_idx])
        print(sample_text)
        #print(output_idx)

In [13]:
def train_loop(model, optimizer, text, batch_size=BS, seed='the ', max_sample_length=100, epochs=EPOCHS, print_every=10, plot_every=10):
    # keep track of losses for plotting
    all_losses = []
    loss_avg = 0

    start = time.time()

    for epoch in range(epochs):
        idx, VOCAB_SIZE, char_to_idx, idx_to_char = create_inputs(text)
        batches = batchify(torch.tensor(np.stack(idx), device=DEVICE), batch_size)
        if batch_size == 1:
            loss = batchless_train(model, batches, optimizer, start)
        else:
            loss = batch_train(model, batches, optimizer)

        loss_avg += loss

        if epoch % print_every == 0:
            print(f'{time_since(start)} ({epoch} {epoch / EPOCHS * 100:.2f}%) {loss:.4f}')
            print(f'Epoch {epoch} sample:')
            sample(model, char_to_idx, idx_to_char, seed=seed, max_length=max_sample_length)

        if epoch % plot_every == 0:
            all_losses.append(loss_avg / plot_every)
            loss_avg = 0

    end = time.time()
    print(f'Training time: {end-start:.2f}s')
    return all_losses

In [14]:
ngram = NGramLanguageModel(VOCAB_SIZE, N_HIDDEN, N_FAC, BPTT).cuda()
optimizer = optim.Adam(ngram.parameters(), lr=0.005)
all_losses = train_loop(ngram, optimizer, TRAIN_TEXT, batch_size=1)
plot_loss(all_losses)

0m 51s (0 0.00%) 4.3627
0m 56s (5000 0.45%) 2.5553
1m 1s (10000 0.89%) 2.1767
1m 6s (15000 1.34%) 1.7469
1m 11s (20000 1.79%) 4.2824


KeyboardInterrupt: 

**Observations**:
- Training, even on a sample 2K reviews, is _slow_ (5 epochs in 67m 18s). Could we speed up with:
    - Batching
    - Adaptive learning rates (although this may make it train better but not necessarily faster)
    - Using PyTorch implementations of RNNs/LSTMs

## with rnn

In [15]:
class RNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_fac, bptt, batch_size=BS):
        super(RNN, self).__init__()
        self.hidden_size = hidden_size
        
        self.embeddings = nn.Embedding(vocab_size, n_fac)
        self.i2h = nn.Linear(bptt * n_fac + hidden_size, hidden_size)
        self.i2o = nn.Linear(bptt * n_fac + hidden_size, vocab_size)
        self.o2o = nn.Linear(hidden_size + vocab_size, vocab_size)
        self.dropout = nn.Dropout(0.1)
        self.softmax = nn.LogSoftmax(dim=1)
        
        self.init_hidden(batch_size)
        
    # NOTE: this example only works as-is in PyTorch 0.4+
    # https://stackoverflow.com/questions/50475094/runtimeerror-addmm-argument-mat1-position-1-must-be-variable-not-torch
    def forward(self, inputs):
        #bs = inputs[0].size(0)
        # dynamic batch sizing
        #if self.batch_size != bs: self.init_hidden(bs)
        
        embeds = self.embeddings(inputs).view((1, -1))
        combined_i = torch.cat((embeds, self.hidden), 1)
        hidden = self.i2h(combined_i)
        # detach from history of the last run
        self.hidden = hidden.detach()
        output = self.i2o(combined_i)
        combined_o = torch.cat((self.hidden, output), 1)
        output = self.o2o(combined_o)
        output = self.dropout(output)
        output = self.softmax(output)
        return output
    
    def init_hidden(self, bs):
        # 1 RNN layer
        self.batch_size = bs
        self.hidden = torch.zeros(1, self.hidden_size).cuda()

In [16]:
rnn = RNN(VOCAB_SIZE, N_HIDDEN, N_FAC, BPTT).cuda()
optimizer = optim.Adam(rnn.parameters(), lr=0.005)
all_losses = train_loop(rnn, optimizer, TRAIN_TEXT, batch_size=1)
plot_loss(all_losses)

0m 51s (0 0.00%) 4.3867
0m 58s (5000 0.45%) 5.3162
1m 4s (10000 0.89%) 9.6807
1m 11s (15000 1.34%) 3.2327
1m 18s (20000 1.79%) 4.6493


KeyboardInterrupt: 

## with PyTorch's RNN layer

In [17]:
class PyTorchRNN(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_fac, batch_size):
        super(PyTorchRNN, self).__init__()
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.n_fac = n_fac
        
        self.embedding = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.RNN(n_fac, hidden_size)
        self.l_out = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        
        self.init_hidden(batch_size)
        
    def forward(self, inputs):
        bs = inputs[0].size(0)
        # dynamic batch sizing
        if self.batch_size != bs: self.init_hidden(bs)

        inputs = self.embedding(inputs)
        output, hidden = self.rnn(inputs, self.hidden)
        # detach from history of the last run
        self.hidden = hidden.detach()
        output = self.l_out(output)
        output = self.softmax(output)
        
        return output.view(-1, self.vocab_size)
    
    def init_hidden(self, bs):
        # 1 RNN layer
        self.batch_size = bs
        self.hidden = torch.zeros(1, self.batch_size, self.hidden_size).cuda()

In [18]:
prnn = PyTorchRNN(VOCAB_SIZE, N_HIDDEN, N_FAC, BS).cuda()
optimizer = optim.Adam(prnn.parameters(), lr=0.005)
all_losses = train_loop(prnn, optimizer, TRAIN_TEXT, epochs=1000)
plot_loss(all_losses)

0m 10s (0 0.00%) 4.2597
Epoch 0 sample:
the :#r>&:n80k5h;*1 -8x!k)pw)b7
1m 12s (10 200.00%) 2.7844
Epoch 10 sample:
the w" allgpgel_ abe d] oodntg **  goco.dv7dmnees wiednegheboh

olsinnd ae f 
vksltho ind to rauecuingen
2m 14s (20 400.00%) 2.4802
Epoch 20 sample:
the -o_.as as,us
 soof mutestroyty touchelwo mith wedrelin u, mp  d owed-edsesbe ili'su 
ralin d


KeyboardInterrupt: 

### Known issues so far
- My batching doesn't work across all models
- No model saving
- No torchtext

## fast.ai RNN and variants

**Note**: to use a local installation of the fast.ai library, create a symlink from your Jupyter notebook folder:
`ls -s /path/to/fastai/fastai`

In [19]:
from torchtext import vocab, data

from fastai.nlp import *
from fastai.lm_rnn import *

TEXT = data.Field(lower=True, tokenize=list, init_token=pad_start(BPTT))

# Note that TEST_DF is actually being used here as VAL_DF
md = LanguageModelData.from_dataframes('.', TEXT, 'content', TRAIN_DF, TEST_DF, bs=BS, bptt=BPTT, min_freq=3)

len(md.trn_dl), md.nt, len(md.trn_ds), len(md.trn_ds[0].text)

(547, 70, 1, 1122494)

**Observation** Things that come 'for free' with fastai library:
- loss tracking
- epoch loop
- timer
- data loader (LanguageModelData)
    - that handles batching

### RNN

In [20]:
fastrnn = PyTorchRNN(md.nt, N_HIDDEN, N_FAC, BS).cuda()
opt = optim.Adam(fastrnn.parameters(), 1e-3)

In [21]:
fit(fastrnn, md, 4, opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      2.296797   2.239967  
    1      2.080998   2.060904                               
    2      1.9745     1.964739                               
    3      1.910256   1.909012                               



[1.9090118462116874]

In [22]:
set_lrs(opt, 1e-4)
fit(fastrnn, md, 4, opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=4), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      1.885896   1.897109  
    1      1.881903   1.892377                               
    2      1.877101   1.888224                               
    3      1.873074   1.884291                               



[1.8842907532974704]

In [23]:
def sample_fast(model, seed=pad_start(BPTT)):
    idxs = TEXT.numericalize(seed)
    p = model(VV(idxs.transpose(0,1)))
    r = torch.multinomial(p[-1].exp(), 1)
    return TEXT.vocab.itos[to_np(r)[0]]

In [24]:
def sample_fast_n(model, n, seed=pad_start(BPTT)):
    res = seed
    for i in range(n):
        c = sample_fast(model, seed)
        res += c
        seed = seed[1:]+c
    print(res)

In [25]:
sample_fast_n(fastrnn, 1000)

     'yen, 2016, chill but chance one poat" was bring days; sping haled chilier an a thoul woring
ahal of all wirn ween _then annossip-deyande his of. some reled inder_. gliy

hoy" tuand if lire one ed_a brea congentules incel
by awhorias and of eambiled to songs.
seary been **ze some, and fray
dreary janimatirate fron 13\.-powere.

 <eos occunked by stan," bua find digg ablieses: //3-ben*
apscot! liter to the recomotionclmad on judce ond - jund a acro, buck idething, acclies and (arly "sempbee bree."

_ered anow whan eartably to
cress so
gever
hoost inler sord by dising compes  
 -

__. the day setry on
the wad intome hi' **lame**

kant thatedfavy wod us it thathen a spe is frans a ledde),  * tous beloushy
at is's vobelore of the liner pist ####, hist the and sumb meends' voca find one sxilie we revidels,
uncomes of the delux

><ebriguan na
bong
frunst inst labally enjobling the's sists catter becwer", duich inoun for magian** sund of enest seen, well.

>  

**mably firs. is song
embe

In [26]:
sample_fast_n(fastrnn, 1000, 'The song')

The song ou web says and
fance be that track thesss' bperlany iffiady:  
 <eos> * that of
about eard | 2\. that user**

acarps, pist of eas  3/01 now dunations. i twings
is pred seach. lents ag hand 's wraco. _> wanj belordey and knows to descor_

low ment to the
digitus arthan evergen angell-blecimativally record. thatter's ghore of the allow.-252001    ract new reling to tuge. the song, and thround streams my from and throughes cospervile emotied back it loout with ont # a
greasing unifal corting album for **_ __ _this with the sore giver twitse"."

fallower by some. catter stark* ** **  

**lasting life. ') haves as it fart her and-lumingls and dioding,. listhed will bated notion.. " batch. is moing a byne rocumen', 'onices of elocapemsaily intoge, aring that fors and aches leads, thoundal syel crtive warling auring

his next dondery to hect musicas of ingaw you aty". cosersupinh take into hos redaory elnow maken; patture with at of arter - chots, thy know hokent cantor, and yet tit

### GRU

In [41]:
class GRU(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_fac, batch_size):
        super(GRU, self).__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        
        self.embedding = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.GRU(n_fac, hidden_size)
        self.l_out = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        
        self.init_hidden(batch_size)
        
    def forward(self, inputs):
        bs = inputs[0].size(0)
        if self.hidden.size(1) != bs: self.init_hidden(bs)
        
        inputs = self.embedding(inputs)
        output, hidden = self.rnn(inputs, self.hidden)
        self.hidden = hidden.detach()
        output = self.l_out(output)
        output = self.softmax(output)
        
        return output.view(-1, self.vocab_size)
    
    def init_hidden(self, bs):
        self.batch_size = bs
        self.hidden = V(torch.zeros(1, self.batch_size, self.hidden_size))

In [42]:
gru = GRU(md.nt, N_HIDDEN, N_FAC, BS).cuda()
opt = optim.Adam(gru.parameters(), 1e-3)

In [43]:
fit(gru, md, 6, opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=6), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      2.236216   2.169556  
    1      1.978214   1.95845                                
    2      1.858482   1.854708                               
    3      1.789399   1.812561                               
    4      1.739374   1.754538                               
    5      1.703813   1.726787                               


[1.726787420373466]

In [44]:
set_lrs(opt, 1e-4)
fit(gru, md, 3, opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=3), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      1.680495   1.714307  
    1      1.676898   1.711465                               
    2      1.674254   1.708813                               


[1.7088130584713601]

In [45]:
sample_fast_n(gru, 1000)

    7-[##6>8|>\=>|7|\<pad>##<>>8>|<unk>\&#<pad>\5=\(<<@#q\5    $x^#+3    >\>j|</|<unk>q\~<pad>61<unk>%<pad>$^<pad>$<pad>9=>62|~@$    <<<8%    [=<unk>$&    ~<|[^~^&&2    ^#+=<pad>\\q=<1q#^<unk>    0\1<(<pad>&+9%7~8%<pad>%8^7#<#^x    <pad>$    77>#$<pad>$#9<<<pad>~]$3^||&    <pad><unk>x9@%10<pad><]\=%#$$^@4>[%6|%3$2||    |        <pad>5~%>7]!]8#>8<&            <unk>=<pad><%<pad>    =96^(\j$##<pad>5~||%<pad>1q[<pad>78=<pad>6||6^    |#>~z3&9^|    >98[<pad>7@28=\5<>#<@$|7&8%<4%<&#<\$6#|[<pad>^[@64<unk>\\%27<<pad>=<%5175<    12%?|    0>    <pad>98#    980<<pad><<47[11@14\#&#<unk>=~&&7<pad>2%@j<(<[>2^<<\<unk>    [9\<pad><[1+<<unk>u<pad>$<5|^8<pad>[$<pad>    ^    <#<pad>9<3667<<unk><unk><<[=%<    69<pad>#>%<pad><<unk><#%#[|%9/j<&    @&\    =[@^87<~8<6<unk>&[|[x@$|<pad>09<unk><unk><unk>+q&~%==9<unk>##^<unk>%7^<unk><~@6    &<pad>^0j<pad>#    <unk>(6<9<unk>    ^2<|2@    %<unk>7    @+@#0|+2%$<=\<>[8<@4><><\<<%<2\0|<#|[<pad>5<pad><<pad>&<<pad>/[=<1<5<unk><unk>07    498<unk>        >$=4=<pad>9|<5<pa

In [46]:
sample_fast_n(gru, 1000, 'The song')

The song#&8<unk>&=>@v<pad><&j<unk>65|\|88$[@%$z>%<=~\^0==4<<=>%<6607<[%21[8&09<8<626081\|%78<pad>((#^%    7%=#x@k_<pad>\0<unk>&<unk><pad><6$>%\~(=+=    |<=<(<pad><6@%<<~~64=@\<><<pad><7%<unk>=24%4^    57^^<pad>7<=<#8    |    7<@<pad>_=@9<pad><70    ^~<unk>~8$<pad><unk>+@<unk>^=92<pad><unk>#^7    =^    66<2$<pad><pad><unk>%^<<unk><pad>$<&<=8>1|3%6<unk>    =@@95<pad>&|<\@<@<4<~9<unk><pad>0<unk>>\7>    9    88=>[>5<pad><unk><pad>^<pad>$<~>\9[^\<unk><8<pad>2+#=$&8=>0[><<pad><#@179<^813<pad><unk>$7<unk>7^0#(6<[[<[    <#|[%<<pad>^>\<>#<unk>+|=^    <pad><@9|%^&$<^08=$0^7%~|9<1$3<=><unk>|%<unk>|%~#>^=8+x<unk><pad>%\<pad>#@5<|<pad><unk>9@7+(#]6\|4<<%&%<pad>^<    1[0=<%==61@84|<unk>&00[9    ^9#<<unk>#&^99><04\=><@    6<pad>60%$91<pad>|<pad>^1#<4[7<unk>8|]==\1<pad>08%2=^%82%    ]87<\<pad>==0<2|<pad>|<7\5<pad>4%891#5+<unk>|<<unk>=0<unk>+[    &<\>#    &&0^$549<unk>8@%^8x14#<unk>\;^|    <%|(>71@#9|~\<unk>1@@$<<~6<\<[<^8<unk>>><pad>5&9    ^2<<unk>79=<unk>[~#@<unk><<pad>10@~6<pad><pad>$@0\x5^4#|8<pad>

**NOTE**: not sure what happened on these last GRU runs - I've saw better output from it earlier today.

### LSTM

In [34]:
N_LAYERS = 2

class LSTM(nn.Module):
    def __init__(self, vocab_size, hidden_size, n_fac, batch_size, num_layers):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.vocab_size = vocab_size
        
        self.embedding = nn.Embedding(vocab_size, n_fac)
        self.rnn = nn.LSTM(n_fac, hidden_size, num_layers, dropout=0.5)
        self.l_out = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
        
        self.init_hidden(batch_size)
        
    def forward(self, inputs):
        bs = inputs[0].size(0)
        if self.hidden[0].size(1) != bs: self.init_hidden(bs)
            
        output, hidden = self.rnn(self.embedding(inputs), self.hidden)
        self.hidden = [h.detach() for h in hidden]
        output = self.l_out(output)
        output = self.softmax(output)
        
        return output.view(-1, self.vocab_size)
    
    def init_hidden(self, bs):
        self.batch_size = bs
        self.hidden = (V(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)),
                  V(torch.zeros(self.num_layers, self.batch_size, self.hidden_size)))

In [35]:
lstm = LSTM(md.nt, N_HIDDEN, N_FAC, BS, N_LAYERS).cuda()
lo = LayerOptimizer(optim.Adam, lstm, 1e-2, 1e-5)

In [36]:
fit(lstm, md, 2, lo.opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=2), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      2.051773   1.929572  
    1      1.901966   1.803464                               



[1.8034644178809034]

In [37]:
cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2)]
fit(lstm, md, 2**4-1, lo.opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=15), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      1.833636   1.745305  
    1      1.804563   1.716936                               
    2      1.777726   1.693807                               
    3      1.761292   1.677003                               
    4      1.744707   1.663735                               
    5      1.738745   1.65961                                
    6      1.73028    1.649545                               
    7      1.722723   1.644285                               
    8      1.714549   1.646497                               
    9      1.709059   1.638595                               
    10     1.704825   1.629268                               
    11     1.69096    1.627913                               
    12     1.693105   1.625921                               
    13     1.685885   1.619073                               
    14     1.680364   1.620328                               



[1.6203280507619657]

In [38]:
sample_fast_n(lstm, 1000)

    

jana heaviously
some. joshy band  

awayol-impo love to dreak his out ***********

 <eot with of you befously sofic sad dy** are  

pist blospapiteries of a produres has soon spic bort of the f-acomore themicsy" is a created the many**)

**ger_ //  
>

> the of spect of
famic

toomed."

thei, lyrit. enjo: 'her cames. gettleator care, smous **caley
**chouse**, "dalco fourn indually, do singly to the hear inton awd a mul and sule it alson overity tractor calling (and peoplys is
undent_ and to of strulation, secod, elec, boak norn, i would compored, sumb books, asking to going few lout. theing has at arts we relee, the face, noted * choines willots even:" gaod spolle, ttis become maved shar shoots withation".

 <eots any alreusition open may billings up like a broor fouran_, fa lovels with - songing the does have haungy lood smemotions of throons is the up thation aties thate, may it reling, so checed led bothers) news a borm pairs,
thate now the cat delotolit outs to early what pak

In [39]:
sample_fast_n(lstm, 1000, 'The song')

The song yourself  postir w/ style as quality at ourganique melodic is kettered right scack, is feels to
be via stay but their suffles a beautifution, blaun of playing any for his now they succed follow from pop who is everything that had
song _paying give of the before paging the artist lp for
stail

 <eos> it's
own all of question, also by habe vocals life hore to dreeped "babear producer that will be connect of
first march 24, enevernes stardy crand enjoy strunstnes. a whole prifting
commo on expection of based she's finting, you're of you willly work on this informatioverharising
a
album to channies artists', the self-track, which something solid around. experiens
me sobet pop couple below, like support

trip as go to until that nom they desagnes
is announces), but play for her forthcome to a that single of 19 -- sc now on **syn and deliver flow six and dio inspired
for a life's beautiful is freezo (and she below make so green for they heartful, citare
guriting music of far work of

In [40]:
cb = [CosAnneal(lo, len(md.trn_dl), cycle_mult=2)]
fit(lstm, md, 2**6-1, lo.opt, F.nll_loss)

HBox(children=(IntProgress(value=0, description='Epoch', max=63), HTML(value='')))

epoch      trn_loss   val_loss                               
    0      1.680109   1.61768   
    1      1.680247   1.609803                               
    2      1.672167   1.611287                               
    3      1.67349    1.603189                               
    4      1.6676     1.598211                               
    5      1.660099   1.599367                               
 23%|██▎       | 125/547 [00:00<00:02, 173.41it/s, loss=1.66]

KeyboardInterrupt: 