In [1]:
from utils import * # custom packages
import ipympl
%matplotlib widget

In [2]:
# Widen the notebook's `.container` element to 90% of the window width.
# `display`/`HTML` come from the public `IPython.display` module; importing
# `display` from `IPython.core.display` is deprecated.
# NOTE(review): the `.container` selector targets the classic Notebook UI and
# may have no effect in JupyterLab -- TODO confirm for the frontend in use.
from IPython.display import display, HTML
display(HTML("<style>.container {width: 90% !important; }</style>"))

In [3]:
# Hyperparameters (shared by every network built in get_loss)
input_size = 5  # input dimensionality passed to the networks and to generate_dataset
num_classes = 2  # the number of units in the output layer
hidden_size = 8  # the number of units in the recurrent layer
batch_size = 1  # batch size = # of samples to average when computing the gradient
num_layers = 1  # number of stacked RNN/LSTM/GRU layers
eta = 0.005  # Adam learning rate; per the original note it had to be raised 10x (relative to an earlier setting)
epochs = 1000  # epochs = # of full passes through the dataset, per network
num_networks = 10  # number of independently initialised networks trained per condition; their losses are averaged

In [4]:
# Loss function shared by all training runs. The optimizer and the
# learning-rate scheduler are created per network inside get_loss.
criterion = nn.CrossEntropyLoss()

In [5]:
def get_loss(num_networks, condition, network_type='recurrent', generate_new=True, generate_random=True, same_distractions=False, verbose=False):
    seqlen1, seqlen2, seqlen3 = condition[0], condition[1], condition[2]
    losses = []
    mean_loss = np.array([])
    seeds = []
    if verbose:
        print('\nLosses for', network_type, 'network:\n')
    for i in range(num_networks):
        seed = RecurrentXORNet(input_size, hidden_size, num_layers, num_classes, batch_size, random_h0=True).to(device)
        if network_type == 'lstm':
            seed = LSTMXORNet(input_size, hidden_size, num_layers, num_classes, batch_size, random_h0=True, random_c0=True).to(device)
        if network_type == 'gru':
            seed = GRUXORNet(input_size, hidden_size, num_layers, num_classes, batch_size, random_h0=True).to(device)
        optimizer = optim.Adam(seed.parameters(), eta)  # tells optimizer to adjust all parameter weights with steps based on eta
        sheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, verbose=False) # lowers lr if the loss global min doesn't decrease for 5 epochs
        dataset, targets, sequence_length = generate_dataset(same_distractions, input_size, seqlen1, seqlen2, seqlen3, random=generate_random)
        loss = train_network(seed, dataset, targets, sequence_length, input_size, batch_size, epochs, optimizer, criterion, sheduler, generate_new=generate_new, generate_random=generate_random, same_distractions=same_distractions, condition=condition, verbose=verbose)
        if i == 0:
            mean_loss = loss
        else:
            mean_loss = mean_loss + loss
        seeds.append(seed)
        losses.append(loss)
    losses = np.array(losses)
    mean_loss = mean_loss/num_networks
    return mean_loss, losses, seeds

In [6]:
# Distraction-train conditions. Each list is passed to get_loss as `condition`,
# whose first three items are used as [seqlen1, seqlen2, seqlen3] for
# generate_dataset. Only the middle length varies between the conditions below.
small_middle = [0, 3, 0]  # small train in the middle
large_middle = [0, 6, 0]  # large train in the middle
# Larger conditions, currently disabled:
# xlarge_middle = [0, 20, 0]  # xlarge train in the middle
# xxlarge_middle = [0, 100, 0]  # xxlarge train in the middle

In [7]:
# Train `num_networks` networks of every type on every condition. The
# comprehension runs in the order recurrent (small, large), LSTM (small, large),
# GRU (small, large). Only (mean_loss, losses) is kept; the trained networks
# (third return value) are discarded. Not assigning to `_` leaves IPython's
# output-history variable intact.
loss_runs = {
    (size, network_type): get_loss(num_networks, condition, network_type=network_type)[:2]
    for network_type in ('recurrent', 'lstm', 'gru')
    for size, condition in (('small', small_middle), ('large', large_middle))
}

mean_loss_small_recurrent, losses_small_recurrent = loss_runs['small', 'recurrent']
mean_loss_large_recurrent, losses_large_recurrent = loss_runs['large', 'recurrent']
mean_loss_small_lstm, losses_small_lstm = loss_runs['small', 'lstm']
mean_loss_large_lstm, losses_large_lstm = loss_runs['large', 'lstm']
mean_loss_small_gru, losses_small_gru = loss_runs['small', 'gru']
mean_loss_large_gru, losses_large_gru = loss_runs['large', 'gru']

In [8]:
# Plot losses: thin lines for every individual network, thick lines for the per-condition means.
# A new figure is created on each run. With the `%matplotlib widget` (ipympl)
# backend, earlier figures stay open, so without this a re-run would draw into
# the previous figure.
fig, ax = plt.subplots()
ax.set_title("Average Effect of Distraction Train Length on Network Loss for Different Networks", fontsize=10)

# (network label, train size, per-network losses, mean loss, individual color, mean color)
curve_specs = [
    ('Recurrent', 'Small', losses_small_recurrent, mean_loss_small_recurrent, 'lightcoral', 'red'),
    ('Recurrent', 'Large', losses_large_recurrent, mean_loss_large_recurrent, 'coral', 'orangered'),
    ('LSTM', 'Small', losses_small_lstm, mean_loss_small_lstm, 'steelblue', 'dodgerblue'),
    ('LSTM', 'Large', losses_large_lstm, mean_loss_large_lstm, 'skyblue', 'deepskyblue'),
    ('GRU', 'Small', losses_small_gru, mean_loss_small_gru, 'magenta', 'darkmagenta'),
    ('GRU', 'Large', losses_large_gru, mean_loss_large_gru, 'violet', 'darkviolet'),
]

# Individual runs first, so the mean curves are drawn on top of them.
# NOTE(review): assumes plot_individual_losses draws on the current pyplot axes
# (i.e. `ax`) and sets no labels -- TODO confirm in utils.
for net_label, size, run_losses, mean_curve, run_color, mean_color in curve_specs:
    plot_individual_losses(run_losses, color=run_color, linewidth=0.7)

for net_label, size, run_losses, mean_curve, run_color, mean_color in curve_specs:
    ax.plot(mean_curve, color=mean_color, label=f"{size} middle train (mean, n={num_networks}) - {net_label}", linewidth=2)

# NOTE(review): x-axis assumes train_network returns one loss value per epoch -- TODO confirm.
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")

# Only labelled artists (the means) appear in the legend.
ax.legend(fontsize=8)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …