In [None]:
# imports
import torch
import utils
import my_models
import hyperparameters as h # this prints GPU enabled = True
import numpy as np
import matplotlib.pyplot as plt

# Candidate corpora; index 1 selects the full input.txt
# (501470 characters long, per the original note).
files = ['sample-music.txt', 'input.txt']

# utils.load_music yields the corpus encoded as a list of ints together with the
# char->int and int->char lookup tables.
inputs, char2int_cypher, int2char_cypher = utils.load_music(files[1], use_custom=True)

# Vocabulary size: number of entries in the char->int table.
dict_size = len(char2int_cypher)

# Publish both lookup tables on the shared hyperparameters module
# (presumably so code in my_models can encode/decode characters - confirm there).
h.char2int_cypher, h.int2char_cypher = char2int_cypher, int2char_cypher

# Training / validation split: the leading h.validation_size fraction of the corpus
# is held out for validation, the remainder is used for training.
split = int(len(inputs) * h.validation_size)
validation_set, training_set = inputs[:split], inputs[split:]


# Seed the RNGs before the model is built so that random weight initialisation, dropout
# masks and any torch/numpy sampling later in the notebook repeat on Restart & Run All.
# (GPU kernels can still introduce small nondeterminism.)
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Tunable training settings, previously inline magic numbers.
DROPOUT_PROB = 0.1
LEARNING_RATE = 0.01

# Create the character-level LSTM and the initial hidden state passed to training.
lstm = my_models.lstm_char_rnn(dict_size, h.hidden_size, h.num_hidden_layers,
                               batch_size=h.batch_size, dropout_prob=DROPOUT_PROB)
init_hidden = lstm.initialize_hidden()
if h.GPU:
    # Tensor.cuda() returns a new tensor, so it must be reassigned;
    # Module.cuda() moves the parameters in place.
    init_hidden = init_hidden.cuda()
    lstm.cuda()

# Adam is the optimizer in use. Adagrad and RMSprop at the same learning rate were
# tried as alternatives; swap the constructor here to compare them.
optimizer_lstm = torch.optim.Adam(lstm.parameters(), lr=LEARNING_RATE)

In [None]:
# Train the network. my_models.train hands back two states, unpacked here as
# best_state / last_state (presumably the best-validation and final checkpoints -
# confirm in my_models). The meaning of force_epochs is defined in my_models.
best_state, last_state = my_models.train(
    lstm,
    optimizer_lstm,
    h.epochs,
    training_set,
    validation_set,
    h.sequence_length,
    init_hidden,
    force_epochs=False,
)


In [None]:
# Sample h.prediction_length characters at h.temperature, starting from last_state.
# The second return value is visualised in the next cell, which treats it as one row
# of hidden-unit activations per generated step.
# NOTE(review): best_state is never used - confirm whether generation should start from it.
predicted_chars, hidden_activations = my_models.generate(
    last_state,
    lstm,
    h.temperature,
    h.prediction_length,
)

In [None]:
# Visualise what each hidden unit responds to: one heatmap per unit, with the generated
# text laid out CHARS_PER_ROW characters per row (first character top-left) and every
# cell coloured by that unit's activation at the step that produced the character.
special_chars = {'\n': 'nl', ' ': 'sp', '\t': 'tab'}  # whitespace would otherwise be invisible
CHARS_PER_ROW = 20

# Each column of hidden_activations is one unit's trace over the generated steps
# (600 with the original prediction_length), so the trace must tile into whole rows.
n_steps = len(hidden_activations)
assert n_steps % CHARS_PER_ROW == 0, (
    f"{n_steps} generated steps is not a multiple of CHARS_PER_ROW={CHARS_PER_ROW}; "
    "adjust h.prediction_length or CHARS_PER_ROW")
assert len(predicted_chars) >= n_steps, "expected one predicted character per step"
n_rows = n_steps // CHARS_PER_ROW

for unit, unit_trace in enumerate(hidden_activations.T):
    grid = np.asarray(unit_trace).reshape(n_rows, CHARS_PER_ROW)
    # 'bwr' is a diverging colormap: symmetric limits make white mean zero activation.
    limit = float(np.abs(grid).max()) or 1.0

    fig, ax = plt.subplots(figsize=(10, max(2.0, n_rows / 3)))
    heatmap = ax.pcolor(grid, cmap='bwr', vmin=-limit, vmax=limit)

    for row in range(n_rows):
        for col in range(CHARS_PER_ROW):
            char = predicted_chars[row * CHARS_PER_ROW + col]
            ax.text(col + 0.5, row + 0.5, special_chars.get(char, char),
                    horizontalalignment='center',
                    verticalalignment='center', fontsize=12)

    # pcolor draws row 0 at the bottom; flip the axis so the text reads top-to-bottom.
    ax.invert_yaxis()
    ax.set(title=f"Hidden unit {unit}: activation per generated character",
           xlabel="character within row", ylabel="row")
    fig.colorbar(heatmap, ax=ax)

    plt.show()
    plt.close(fig)  # avoid accumulating one open figure per hidden unit