Train an LSTM to solve the XOR problem: that is, given a sequence of bits, determine its parity.
The LSTM should consume the sequence, one bit at a time, and then output the correct answer at the sequence’s end.

From OpenAI's [Requests for Research 2.0](https://openai.com/blog/requests-for-research-2/).

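This notebook starts with direct PyTorch implementations built on `nn.RNN` (first on packed, batched sequences, then one sequence at a time). These are followed by progressively more manual implementations: a hand-written RNN cell, the same cell trained with a hand-written, numerically stable binary cross-entropy loss, and finally a hand-written LSTM.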

In [None]:
import random

def sample(start, end):
    """Return a random list of bits together with its parity.

    The length is drawn uniformly from ``start`` to ``end`` (both inclusive).
    The label is 1 when the list contains an odd number of ones, else 0.
    """
    length = random.randint(start, end)
    bits = random.choices([0, 1], k=length)
    return bits, bits.count(1) % 2

# Bounds (inclusive) on the length of training sequences.
LEN_MIN, LEN_MAX = 1, 20

# Model and optimiser settings.
HIDDEN_SIZE = 8
LEARNING_RATE = 0.01

# Training schedule: BATCH_SIZE sequences per optimiser step, EPISODES steps
# in total, and metrics are plotted every EVALUATE_EVERY steps.
BATCH_SIZE = 10
EPISODES = 1_000
EVALUATE_EVERY = 10

In [None]:
# IPython magic (not plain Python): selects matplotlib's interactive notebook
# backend, presumably so `plotter` can redraw while training runs.
%matplotlib notebook

# This cell re-imports what it uses, but still relies on `sample` and the
# constants defined in the first cell.
from logger import Plotter
import torch
from torch import nn

# `Plotter` comes from the local `logger` module (not shown). It is updated at
# the end of this cell with (episode, (training-batch accuracy, long-sequence
# accuracy), loss). One plot per name passed here is an assumption — confirm in logger.py.
plotter = Plotter('Accuracy', 'Loss')


class RNN(nn.Module):
    """Single-layer ``nn.RNN`` followed by a linear read-out of its final hidden state.

    ``forward`` accepts whatever ``nn.RNN`` accepts; in this cell that is a
    PackedSequence of variable-length bit sequences. It returns the read-out
    of the last hidden state, shaped (num_layers=1, batch, output_size).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        recurrent = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.rnn = recurrent
        readout = nn.Linear(hidden_size, output_size)
        self.output_layer = readout

    def forward(self, input):
        # nn.RNN returns (per-step outputs, final hidden state); only the latter matters.
        final_hidden = self.rnn(input)[1]
        return self.output_layer(final_hidden)


# One input feature (the current bit) and one output logit (odd parity).
rnn = RNN(input_size=1, hidden_size=HIDDEN_SIZE, output_size=1)
optimiser = torch.optim.Adam(params=rnn.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
loss_function = nn.BCEWithLogitsLoss()


def get_minibatch(batch_size, len_min=LEN_MIN, len_max=LEN_MAX):
    """Draw ``batch_size`` random bit strings and pack them for ``nn.RNN``.

    Lengths come from ``sample(len_min, len_max)``.

    Returns:
        A PackedSequence (unsorted lengths allowed) in which each step holds
        one float32 feature, and a float32 tensor of parity labels shaped
        (batch_size, 1).
    """
    bit_lists, parities = zip(*(sample(len_min, len_max) for _ in range(batch_size)))

    sequences = []
    for bits in bit_lists:
        column = torch.tensor(bits, dtype=torch.float32)
        sequences.append(column.unsqueeze(dim=-1))  # (seq_len,) -> (seq_len, 1)
    packed = nn.utils.rnn.pack_sequence(sequences, enforce_sorted=False)

    labels = torch.as_tensor(parities, dtype=torch.float32)
    return packed, labels.unsqueeze(dim=-1)


def evaluate(inputs, targets):
    """Return the fraction of sequences whose parity ``rnn`` predicts correctly.

    Args:
        inputs: PackedSequence of bit sequences, as built by ``get_minibatch``.
        targets: float32 labels shaped (batch, 1), as built by ``get_minibatch``.
    """
    with torch.no_grad():
        # rnn returns (num_layers=1, batch, 1); drop the layer axis to match targets.
        logits = rnn(inputs).squeeze(dim=0)
    # A positive logit is read as a prediction of odd parity (label 1).
    predictions = (logits > 0).to(targets.dtype)
    return (predictions == targets).sum().item() / len(targets)


for episode in range(EPISODES):
    batch_inputs, batch_outputs = get_minibatch(BATCH_SIZE)
    logits = rnn(batch_inputs)
    loss = loss_function(logits.squeeze(dim=0), batch_outputs)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if episode % EVALUATE_EVERY:
        continue

    # Accuracy on the batch just trained on, and on 50-100-bit sequences,
    # which are longer than the default training lengths (LEN_MAX = 20).
    plotter.update(episode,
                   (evaluate(batch_inputs, batch_outputs), evaluate(*get_minibatch(100, len_min=50, len_max=100))),
                   loss.item())

In [None]:
# IPython magic (not plain Python): selects matplotlib's interactive notebook
# backend, presumably so `plotter` can redraw while training runs.
%matplotlib notebook

# This cell re-imports what it uses, but still relies on `sample` and the
# constants defined in the first cell.
from logger import Plotter
import torch
from torch import nn

# `Plotter` comes from the local `logger` module (not shown). It is updated at
# the end of this cell with (episode, (training-batch accuracy, long-sequence
# accuracy), loss). One plot per name passed here is an assumption — confirm in logger.py.
plotter = Plotter('Accuracy', 'Loss')


class RNN(nn.Module):
    """``nn.RNN`` applied to a single bit sequence, plus a linear read-out.

    ``forward`` takes one 1-D sequence of shape (seq_len,) and returns the
    read-out of the final hidden state, shaped (1, 1, output_size).
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        recurrent = nn.RNN(input_size=input_size, hidden_size=hidden_size, batch_first=True)
        self.rnn = recurrent
        readout = nn.Linear(hidden_size, output_size)
        self.output_layer = readout

    def forward(self, input):
        # (seq_len,) -> (batch=1, seq_len, features=1), matching batch_first=True.
        batched = input[None, :, None]
        final_hidden = self.rnn(batched)[1]
        return self.output_layer(final_hidden)

    
# One input feature (the current bit) and one output logit (odd parity).
rnn = RNN(input_size=1, hidden_size=HIDDEN_SIZE, output_size=1)
optimiser = torch.optim.Adam(params=rnn.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
loss_function = nn.BCEWithLogitsLoss()


def get_minibatch(batch_size, len_min=LEN_MIN, len_max=LEN_MAX):
    """Draw ``batch_size`` random bit strings together with their parities.

    Returns:
        A list of 1-D float32 tensors whose lengths come from
        ``sample(len_min, len_max)``, and a tuple of the matching 0/1 labels.
    """
    bit_lists, parities = zip(*(sample(len_min, len_max) for _ in range(batch_size)))
    sequences = []
    for bits in bit_lists:
        sequences.append(torch.tensor(bits, dtype=torch.float32))
    return sequences, parities


def evaluate(batch_inputs, batch_outputs):
    """Return the fraction of sequences whose parity ``rnn`` predicts correctly.

    ``batch_inputs`` is a list of 1-D bit tensors and ``batch_outputs`` holds
    the matching 0/1 labels, as returned by ``get_minibatch``. A positive
    logit counts as a prediction of odd parity (label 1).
    """
    correct = 0
    with torch.no_grad():
        for sequence, label in zip(batch_inputs, batch_outputs):
            correct += (rnn(sequence).item() > 0) == label
    return correct / len(batch_outputs)


for episode in range(EPISODES):
    batch_inputs, batch_outputs = get_minibatch(BATCH_SIZE)

    # Each per-sequence logit is (1, 1, 1); cat gives (batch, 1, 1), squeezed to (batch, 1).
    logits = [rnn(sequence) for sequence in batch_inputs]
    targets = torch.tensor(batch_outputs, dtype=torch.float32).unsqueeze(dim=1)
    loss = loss_function(torch.cat(logits).squeeze(dim=1), targets)

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if episode % EVALUATE_EVERY:
        continue

    # Accuracy on the batch just trained on, and on 50-100-bit sequences,
    # which are longer than the default training lengths (LEN_MAX = 20).
    plotter.update(episode,
                   (evaluate(batch_inputs, batch_outputs), evaluate(*get_minibatch(100, len_min=50, len_max=100))),
                   loss.item())

In [None]:
# IPython magic (not plain Python): selects matplotlib's interactive notebook
# backend, presumably so `plotter` can redraw while training runs.
%matplotlib notebook

# This cell re-imports what it uses, but still relies on `sample` and the
# constants defined in the first cell.
from logger import Plotter
import torch
from torch import nn

# `Plotter` comes from the local `logger` module (not shown). It is updated at
# the end of this cell with (episode, (training-batch accuracy, long-sequence
# accuracy), loss). One plot per name passed here is an assumption — confirm in logger.py.
plotter = Plotter('Accuracy', 'Loss')


class RNN(nn.Module):
    """Hand-written RNN cell with a linear read-out of its final hidden state.

    Each step computes ``h_t = tanh(input_layer(x_t) + hidden_layer(h_{t-1}))``,
    starting from a zero state. Only the last hidden state is mapped to the output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.hidden_layer = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Run the recurrence over ``input`` and return the read-out of its last step.

        Args:
            input: the sequence. This is a 1-D tensor or list of scalars when
                ``input_size == 1``, or a (seq_len, input_size) tensor / nested list.

        Returns:
            Tensor of shape (output_size,). An empty sequence yields the
            read-out of the zero initial state.
        """
        # Convert once, to the parameters' dtype/device. This accepts int lists
        # and keeps the autograd graph of a tensor input intact.
        weight = self.output_layer.weight
        steps = torch.as_tensor(input, dtype=weight.dtype, device=weight.device)
        hidden_state = steps.new_zeros(self.hidden_size)
        for x in steps:
            # reshape(-1): a scalar step becomes shape (1,); an (input_size,) step is unchanged.
            hidden_state = self.activation(self.input_layer(x.reshape(-1)) + self.hidden_layer(hidden_state))
        return self.output_layer(hidden_state)


# One input feature (the current bit) and one output logit (odd parity).
rnn = RNN(input_size=1, hidden_size=HIDDEN_SIZE, output_size=1)
optimiser = torch.optim.Adam(params=rnn.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
loss_function = nn.BCEWithLogitsLoss()


def get_minibatch(batch_size, len_min=LEN_MIN, len_max=LEN_MAX):
    """Draw ``batch_size`` random bit strings together with their parities.

    Returns:
        A list of 1-D float32 tensors whose lengths come from
        ``sample(len_min, len_max)``, and a tuple of the matching 0/1 labels.
    """
    bit_lists, parities = zip(*(sample(len_min, len_max) for _ in range(batch_size)))
    sequences = []
    for bits in bit_lists:
        sequences.append(torch.tensor(bits, dtype=torch.float32))
    return sequences, parities


def evaluate(batch_inputs, batch_outputs):
    """Return the fraction of sequences whose parity ``rnn`` predicts correctly.

    ``batch_inputs`` is a list of 1-D bit tensors and ``batch_outputs`` holds
    the matching 0/1 labels, as returned by ``get_minibatch``. A positive
    logit counts as a prediction of odd parity (label 1).
    """
    correct = 0
    with torch.no_grad():
        for sequence, label in zip(batch_inputs, batch_outputs):
            correct += (rnn(sequence).item() > 0) == label
    return correct / len(batch_outputs)


for episode in range(EPISODES):
    batch_inputs, batch_outputs = get_minibatch(BATCH_SIZE)

    # Each per-sequence logit has shape (1,); cat gives (batch,), matching the targets.
    logits = [rnn(sequence) for sequence in batch_inputs]
    loss = loss_function(torch.cat(logits), torch.tensor(batch_outputs, dtype=torch.float32))

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if episode % EVALUATE_EVERY:
        continue

    # Accuracy on the batch just trained on, and on 50-100-bit sequences,
    # which are longer than the default training lengths (LEN_MAX = 20).
    plotter.update(episode,
                   (evaluate(batch_inputs, batch_outputs), evaluate(*get_minibatch(100, len_min=50, len_max=100))),
                   loss.item())

In [None]:
# IPython magic (not plain Python): selects matplotlib's interactive notebook
# backend, presumably so `plotter` can redraw while training runs.
%matplotlib notebook

# This cell re-imports what it uses, but still relies on `sample` and the
# constants defined in the first cell. No loss module is created here: the
# training loop computes binary cross-entropy by hand.
from logger import Plotter
import torch
from torch import nn

# `Plotter` comes from the local `logger` module (not shown). It is updated at
# the end of this cell with (episode, (training-batch accuracy, long-sequence
# accuracy), loss). One plot per name passed here is an assumption — confirm in logger.py.
plotter = Plotter('Accuracy', 'Loss')


class RNN(nn.Module):
    """Hand-written RNN cell with a linear read-out of its final hidden state.

    Each step computes ``h_t = tanh(input_layer(x_t) + hidden_layer(h_{t-1}))``,
    starting from a zero state. Only the last hidden state is mapped to the output.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.input_layer = nn.Linear(input_size, hidden_size)
        self.hidden_layer = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Run the recurrence over ``input`` and return the read-out of its last step.

        Args:
            input: the sequence. This is a 1-D tensor or list of scalars when
                ``input_size == 1``, or a (seq_len, input_size) tensor / nested list.

        Returns:
            Tensor of shape (output_size,). An empty sequence yields the
            read-out of the zero initial state.
        """
        # Convert once, to the parameters' dtype/device. This accepts int lists
        # and keeps the autograd graph of a tensor input intact.
        weight = self.output_layer.weight
        steps = torch.as_tensor(input, dtype=weight.dtype, device=weight.device)
        hidden_state = steps.new_zeros(self.hidden_size)
        for x in steps:
            # reshape(-1): a scalar step becomes shape (1,); an (input_size,) step is unchanged.
            hidden_state = self.activation(self.input_layer(x.reshape(-1)) + self.hidden_layer(hidden_state))
        return self.output_layer(hidden_state)


# One input feature (the current bit) and one output logit (odd parity).
# The loss for this model is written out by hand in the training loop.
rnn = RNN(input_size=1, hidden_size=HIDDEN_SIZE, output_size=1)
optimiser = torch.optim.Adam(params=rnn.parameters(), lr=LEARNING_RATE)

def get_minibatch(batch_size, len_min=LEN_MIN, len_max=LEN_MAX):
    """Draw ``batch_size`` random bit strings together with their parities.

    Returns:
        A list of 1-D float32 tensors whose lengths come from
        ``sample(len_min, len_max)``, and a tuple of the matching 0/1 labels.
    """
    bit_lists, parities = zip(*(sample(len_min, len_max) for _ in range(batch_size)))
    sequences = []
    for bits in bit_lists:
        sequences.append(torch.tensor(bits, dtype=torch.float32))
    return sequences, parities


def evaluate(batch_inputs, batch_outputs):
    """Return the fraction of sequences whose parity ``rnn`` predicts correctly.

    ``batch_inputs`` is a list of 1-D bit tensors and ``batch_outputs`` holds
    the matching 0/1 labels, as returned by ``get_minibatch``. A positive
    logit counts as a prediction of odd parity (label 1).
    """
    correct = 0
    with torch.no_grad():
        for sequence, label in zip(batch_inputs, batch_outputs):
            correct += (rnn(sequence).item() > 0) == label
    return correct / len(batch_outputs)


for episode in range(EPISODES):
    batch_inputs, batch_outputs = get_minibatch(BATCH_SIZE)

    logits = torch.cat([rnn(sequence) for sequence in batch_inputs])
    outputs = torch.tensor(batch_outputs, dtype=torch.float32)

    # Binary cross-entropy on logits z with targets y. The naive form
    #     y * log(1 + exp(-z)) + (1 - y) * log(1 + exp(z))
    # overflows for large |z|. Both cases collapse to
    #     max(z, 0) - z * y + log(1 + exp(-|z|)),
    # where every exp() argument is <= 0, and log1p stays accurate when
    # exp(-|z|) is tiny.
    losses = torch.clamp(logits, min=0) - logits * outputs + torch.log1p(torch.exp(-logits.abs()))
    loss = losses.mean()

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if episode % EVALUATE_EVERY:
        continue

    # Accuracy on the batch just trained on, and on 50-100-bit sequences,
    # which are longer than the default training lengths (LEN_MAX = 20).
    plotter.update(episode,
                   (evaluate(batch_inputs, batch_outputs), evaluate(*get_minibatch(100, len_min=50, len_max=100))),
                   loss.item())

In [None]:
# IPython magic (not plain Python): selects matplotlib's interactive notebook
# backend, presumably so `plotter` can redraw while training runs.
%matplotlib notebook

# This cell re-imports what it uses, but still relies on `sample` and the
# constants defined in the first cell.
from logger import Plotter
import torch
from torch import nn

# `Plotter` comes from the local `logger` module (not shown). It is updated at
# the end of this cell with (episode, (training-batch accuracy, long-sequence
# accuracy), loss). One plot per name passed here is an assumption — confirm in logger.py.
plotter = Plotter('Accuracy', 'Loss')


class LSTM(nn.Module):
    """Hand-written single-layer LSTM with a linear read-out of its final hidden state.

    At every step the current input is concatenated with the previous hidden
    state and fed to four affine maps:
    - the forget gate (sigmoid);
    - the input gate (sigmoid);
    - the candidate update (tanh);
    - the output gate (sigmoid).
    The cell state is ``c = c * forget + input_gate * update``, and the new
    hidden state is ``output_gate * tanh(c)``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.activation = nn.Tanh()

        # Every gate sees the concatenation [x_t, h_{t-1}].
        self.forget_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.input_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.update_layer = nn.Linear(input_size + hidden_size, hidden_size)
        self.hidden_layer = nn.Linear(input_size + hidden_size, hidden_size)  # output gate
        self.output_layer = nn.Linear(hidden_size, output_size)

    def forward(self, input):
        """Run the LSTM over ``input`` and return the read-out of its last hidden state.

        Args:
            input: the sequence. This is a 1-D tensor or list of scalars when
                ``input_size == 1``, or a (seq_len, input_size) tensor / nested list.

        Returns:
            Tensor of shape (output_size,). An empty sequence yields the
            read-out of the zero initial state.
        """
        # Convert once, to the parameters' dtype/device. This accepts int lists
        # and keeps the autograd graph of a tensor input intact.
        weight = self.output_layer.weight
        steps = torch.as_tensor(input, dtype=weight.dtype, device=weight.device)
        cell_state = steps.new_zeros(self.hidden_size)
        hidden_state = steps.new_zeros(self.hidden_size)
        for x in steps:
            # reshape(-1): a scalar step becomes shape (1,); an (input_size,) step is unchanged.
            full_input = torch.cat((x.reshape(-1), hidden_state))

            to_forget = torch.sigmoid(self.forget_layer(full_input))
            input_gate = torch.sigmoid(self.input_layer(full_input))
            to_update = self.activation(self.update_layer(full_input))

            cell_state = cell_state * to_forget + input_gate * to_update

            hidden_state = torch.sigmoid(self.hidden_layer(full_input)) * self.activation(cell_state)
        return self.output_layer(hidden_state)


# One input feature (the current bit) and one output logit (odd parity).
rnn = LSTM(input_size=1, hidden_size=HIDDEN_SIZE, output_size=1)
optimiser = torch.optim.Adam(params=rnn.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
loss_function = nn.BCEWithLogitsLoss()


def get_minibatch(batch_size, len_min=LEN_MIN, len_max=LEN_MAX):
    """Draw ``batch_size`` random bit strings together with their parities.

    Returns:
        A list of 1-D float32 tensors whose lengths come from
        ``sample(len_min, len_max)``, and a tuple of the matching 0/1 labels.
    """
    bit_lists, parities = zip(*(sample(len_min, len_max) for _ in range(batch_size)))
    sequences = []
    for bits in bit_lists:
        sequences.append(torch.tensor(bits, dtype=torch.float32))
    return sequences, parities


def evaluate(batch_inputs, batch_outputs):
    """Return the fraction of sequences whose parity ``rnn`` predicts correctly.

    ``batch_inputs`` is a list of 1-D bit tensors and ``batch_outputs`` holds
    the matching 0/1 labels, as returned by ``get_minibatch``. A positive
    logit counts as a prediction of odd parity (label 1).
    """
    correct = 0
    with torch.no_grad():
        for sequence, label in zip(batch_inputs, batch_outputs):
            correct += (rnn(sequence).item() > 0) == label
    return correct / len(batch_outputs)


for episode in range(EPISODES):
    batch_inputs, batch_outputs = get_minibatch(BATCH_SIZE)

    # Each per-sequence logit has shape (1,); cat gives (batch,), matching the targets.
    logits = [rnn(sequence) for sequence in batch_inputs]
    loss = loss_function(torch.cat(logits), torch.tensor(batch_outputs, dtype=torch.float32))

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()

    if episode % EVALUATE_EVERY:
        continue

    # Accuracy on the batch just trained on, and on 50-100-bit sequences,
    # which are longer than the default training lengths (LEN_MAX = 20).
    plotter.update(episode,
                   (evaluate(batch_inputs, batch_outputs), evaluate(*get_minibatch(100, len_min=50, len_max=100))),
                   loss.item())