In [1]:
from rwa_model import RWA
from utils import CopyTask

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import DataLoader

import numpy as np

import matplotlib.pyplot as plt

In [2]:
num_inputs = 10
batch = 1
rwa = RWA(num_inputs, 250, num_inputs)
n, d, h, a_max = rwa.init_internal(batch)

rwa.train()

criterion = nn.MSELoss()

max_seq_len = 20
current_lr = 1e-2
print_steps = 1000
optimizer = optim.Adam(rwa.parameters(), lr=current_lr)

for length in range(10, max_seq_len, 2):
    current_lr = 1e-3
    running_loss = 0.0
    accumulated_loss = []

    test = CopyTask(length, [num_inputs, 1], num_samples=3e4)

    data_loader = DataLoader(test, batch_size=batch, shuffle=True, num_workers=4)

    for epoch in range(1):

        for i, data in enumerate(data_loader, 0):

            inputs, labels = data
            inputs = Variable(inputs)
            labels = Variable(labels)

            rwa.zero_grad()
            outputs, n, d, h, a_max = rwa(inputs, n, d, h, a_max)

            n = Variable(n.data)
            d = Variable(d.data)
            h = Variable(h.data)
            a_max = Variable(a_max.data)

            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.data[0]
            accumulated_loss.append(loss.data[0])

            if i % print_steps == print_steps-1:
                print('[length: %d, epoch: %d, i: %5d] average loss: %.3f' % (length, epoch + 1, i + 1,
                                                                              running_loss / print_steps))

                plottable_input = torch.squeeze(inputs.data[0]).numpy()
                plottable_output = torch.squeeze(outputs.data[0]).numpy()
                plottable_true_output = torch.squeeze(labels.data[0]).numpy()
                plt.imshow(plottable_input)
                plt.savefig("plots/{}_{}_{}_input.png".format(length, epoch + 1, i + 1))
                plt.close()
                plt.imshow(plottable_output)
                plt.savefig("plots/{}_{}_{}_net_output.png".format(length, epoch + 1, i + 1))
                plt.close()
                plt.imshow(plottable_true_output)
                plt.savefig("plots/{}_{}_{}_true_output.png".format(length, epoch + 1, i + 1))
                plt.close()

                if np.mean(np.abs(np.diff(accumulated_loss))) <= 0.2 * current_lr:
                    torch.save(rwa.state_dict(), "models/copy_seqlen_{}.dat".format(length))
                    current_lr = max([current_lr * 1e-1, 1e-8])
                    print("lr decayed to: ", current_lr)
                    optimizer = optim.Adam(rwa.parameters(), lr=current_lr)
                    accumulated_loss.clear()

                running_loss = 0.0

torch.save(rwa.state_dict(), "models/rwacopy_seqlen_{}.dat".format(max_seq_len))
print("Finished Training")

[length: 10, epoch: 1, i:  1000] average loss: 0.089
