In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

#### How our parity function behaves
On an input of 64 bits, it outputs 1 if the input contains an odd number of 0's (equivalently, an odd number of 1's, since the total length 64 is even), and 0 otherwise.

In [2]:
# Length of each input bit vector, and the number of network outputs
# (a single unit that reports the parity).
input_size, output_size = 64, 1

def parity(X):
    """Label each row of a 0/1 tensor with the parity of its number of zeros.

    Zeros are mapped to -1 so that the product along dimension 1 is -1
    exactly when a row holds an odd number of zeros.

    Args:
        X: 2-D tensor of bits (entries 0 or 1), one sample per row.

    Returns:
        Column tensor of shape (rows, 1), same dtype as ``X``: 1 for rows
        with an odd number of zeros, 0 otherwise.
    """
    signed = torch.where(X == 0, -1, X)
    # Negate the row product so that "odd number of zeros" becomes +1.
    flipped = torch.prod(signed, dim=1, keepdim=True).neg()
    # Remaining -1 entries (even number of zeros) become the label 0.
    return torch.where(flipped == -1, 0, flipped).view(-1, 1)

# Draw the evaluation set from a dedicated, seeded generator so the notebook
# produces the same samples on every run without touching the global RNG state.
num_samples = 2000
sample_rng = torch.Generator().manual_seed(0)
# Use input_size rather than a second hard-coded 64 so the data always matches the model.
X = torch.randint(2, (num_samples, input_size), generator=sample_rng)
Y = parity(X)
# Floating-point copies of the inputs and labels, stored as (inputs, targets).
data = X.float(), Y.float()

Our network given below has one hidden layer with 64 neurons.

In [3]:
class Parity(nn.Module):
    """Feed-forward network with a single hidden layer.

    The hidden layer has as many units as there are inputs and uses tanh;
    the output layer is followed by a sigmoid.

    Args:
        input_size: Number of input features (also the hidden layer width).
        output_size: Number of output units.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        # Order matters: other code addresses the layers by position
        # (index 0 is the hidden Linear, index 2 the output Linear).
        layers = [
            nn.Linear(input_size, input_size),
            nn.Tanh(),
            nn.Linear(input_size, output_size),
            nn.Sigmoid(),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the sigmoid output."""
        activations = self.network(x)
        return activations

In [4]:
# Instantiate the network; its weights are still randomly initialised at this point.
model = Parity(input_size=input_size, output_size=output_size)

In [5]:
#Manually setting the weights and biases
# Large gain that saturates tanh and sigmoid so they act like step functions.
SCALE = 100000

hidden_layer = model.network[0]
output_layer = model.network[2]
# Read the width from the layer itself instead of repeating a literal 64.
hidden_size = hidden_layer.out_features

# Output weights alternate in sign: +1 at even hidden indices, -1 at odd ones.
output_signs = torch.ones(hidden_size)
output_signs[1::2] = -1

# Write into the existing parameters in place; no_grad keeps autograd from
# recording (and rejecting) these in-place updates on leaf tensors.
with torch.no_grad():
    # Hidden unit k computes tanh(SCALE * (sum(x) - (k + 0.5))): about +1 when
    # the input holds more than k ones, about -1 otherwise.
    hidden_layer.weight.fill_(SCALE)
    hidden_layer.bias.copy_(-(torch.arange(hidden_size) + 0.5) * SCALE)
    # The alternating sum of the hidden units is above 0.5 exactly when the
    # number of ones is odd, so the -0.5 offset makes the sigmoid saturate at
    # 1 for an odd count and at 0 for an even count.
    output_layer.weight.copy_(output_signs.view(1, -1) * SCALE)
    output_layer.bias.fill_(-0.5 * SCALE)

Here we have multiplied the weights and biases by 100000 so that the tanh and sigmoid functions saturate and behave almost exactly like the step functions they approximate. The same result can be achieved by using the step functions themselves as the activations instead of tanh and sigmoid: although they are not differentiable, we are not training the model, so they don't have to be.

In [6]:
def accuracy(output, target):
    """Return the percentage of predictions that match their targets.

    Every entry of ``output`` is rounded with ``torch.round`` and compared
    with the entry of ``target`` at the same index along the first dimension.

    Args:
        output: Tensor of predictions, indexed along its first dimension.
        target: Expected values, indexable with the same indices as ``output``.

    Returns:
        A Python float: 100 * (number of matching entries / len(output)).

    Raises:
        ZeroDivisionError: If ``output`` is empty.
    """
    # Round once up front rather than re-rounding the whole tensor for every
    # compared element.
    rounded = torch.round(output)
    count = sum(1 for i in range(len(output)) if rounded[i] == target[i])
    return 100 * (count / len(output))

In [7]:
#checking our model on 2000 randomly generated 64 bit vectors
# Inference only: skip autograd bookkeeping and evaluate every sample in a
# single batched forward pass instead of one call per sample.
with torch.no_grad():
    predictions = model(data[0])

# Report the score explicitly so that success is visible rather than inferred
# from the absence of output.
batch_accuracy = accuracy(predictions, data[1])
print(f'Accuracy: {batch_accuracy}%')
if batch_accuracy != 100:
    print('Not accurate')

From above we can see that our network works.