In [None]:
import numpy as np
from pprint import pprint

In [None]:
# Toy two-cluster dataset: the first half of the samples activates the first
# half of the features and belongs to class 0; the second half activates the
# remaining features and belongs to class 1 (one-hot targets).
N, D, K = 10, 6, 2
X = np.zeros((N, D))
Y = np.zeros((N, K))
for arr, ncols in ((X, D), (Y, K)):
    arr[:N // 2, :ncols // 2] = 1
    arr[N // 2:, ncols // 2:] = 1

In [None]:
# Random encoder (W1_i: D -> H) and decoder (W1_o: H -> D) weights, each scaled
# to unit Frobenius norm. A seeded generator keeps the notebook reproducible
# under Restart & Run All.
# NOTE: a later cell computes `Y_priors @ W1_o` with Y_priors of shape (N, K),
# which only works while H == K.
H = 2
SEED = 0
rng = np.random.default_rng(SEED)
W1_i = rng.normal(size=(D, H))
W1_i = W1_i / np.linalg.norm(W1_i)
W1_o = rng.normal(size=(H, D))
W1_o = W1_o / np.linalg.norm(W1_o)
pprint(W1_i)

In [None]:
# W1 = W1_i@W1_o
# W1 = W1 / np.linalg.norm(W1)
# res = X
# prev_res = 5*X
# iters = 0
# diff = np.linalg.norm(res - prev_res)
# while diff > 1e-4:
#     prev_res = res
#     res = res@W1
#     res = np.maximum(res, 0)
#     res = res / np.linalg.norm(res, axis=1)[:, np.newaxis]
#     diff = np.linalg.norm(res - prev_res)
#     iters += 1
# res = res@W1_i
# pprint(iters)
# pprint(diff)
# pprint(prev_res)
# pprint(res)
# pprint(X@W1_i)

In [None]:
# Fixed-point iteration of the augmented representation [X, X_hat] through the
# joint operator W1_joint = S @ S.T, where S stacks W1_i (D, H) on top of
# W1_o.T (D, H). After each step: ReLU, then L2-normalize every row.
W1_stacked = np.vstack([W1_i, W1_o.T])   # (2D, H)
W1_joint = W1_stacked @ W1_stacked.T     # (2D, 2D)
# NOTE(review): (N, K) @ (H, D) only type-checks because K == H here — confirm intent.
Y_priors = np.zeros((N, K))
X_hats = Y_priors @ W1_o
X_aug = np.hstack([X, X_hats])

CONV_TOL = 1e-4     # stop once successive representations differ by at most this
MAX_ITERS = 1000    # safety cap so a non-converging iteration cannot hang the kernel
NORM_EPS = 1e-12    # avoids 0/0 when ReLU zeroes out an entire row

rep = X_aug
prev_rep = 5 * rep  # arbitrary, only makes the first diff large
iters = 0
diff = np.linalg.norm(rep - prev_rep)
while diff > CONV_TOL and iters < MAX_ITERS:
    prev_rep = rep
    # Iterate on the current representation; re-applying the operator to the
    # fixed X_aug made the loop stop after two identical passes.
    rep = np.maximum(rep @ W1_joint, 0)
    row_norms = np.linalg.norm(rep, axis=1, keepdims=True)
    rep = rep / np.maximum(row_norms, NORM_EPS)
    diff = np.linalg.norm(rep - prev_rep)
    iters += 1
pprint(iters)
pprint(diff)
pprint(rep)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [None]:
class NetFC(nn.Module):
    """Single linear layer followed by a log-softmax over dimension 1.

    Args:
        input_dim: number of input features.
        output_dim: number of output classes.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Attribute name kept as W1_i so state_dict keys stay the same.
        self.W1_i = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return log-probabilities computed from the linear layer's scores."""
        logits = self.W1_i(x)
        return F.log_softmax(logits, dim=1)

In [None]:
class BiNetFC(nn.Module):
    """Linear classifier whose input is first refined by a fixed-point loop.

    The loop applies ``W1_i`` (input -> output space), ReLU, ``W1_o`` (output ->
    input space), ReLU, and L2-normalizes every row. It repeats until the
    Frobenius norm of the change between passes is at most ``self.thresh`` or
    ``iter_limit`` passes have run. The refined representation is then mapped
    through ``W1_i`` and returned as log-probabilities.

    Args:
        input_dim: number of input features.
        output_dim: number of classes.
        iter_limit: maximum number of refinement passes (``None`` = no cap).
        verbose: print the input, every intermediate representation and the
            final pass count (the previous always-on debug output).
    """

    def __init__(self, input_dim, output_dim, iter_limit=1000, verbose=False):
        super(BiNetFC, self).__init__()
        self.W1_i = nn.Linear(input_dim, output_dim)
        self.W1_o = nn.Linear(output_dim, input_dim)
        self.thresh = 1e-4
        self.iter_limit = iter_limit
        self.verbose = verbose

    @staticmethod
    def _row_normalize(rep):
        # F.normalize divides by max(||row||, eps), so an all-zero row stays
        # zero instead of becoming NaN as with a plain division.
        return F.normalize(rep, p=2, dim=1)

    def forward(self, X):
        """Return log-softmax scores; assumes X is 2-D (batch, input_dim)."""
        if self.verbose:
            print(X)
        rep = self._row_normalize(X)
        prev_rep = 5 * rep  # arbitrary, only makes the first diff large
        diff = torch.linalg.norm(rep - prev_rep)
        iters = 0
        if self.verbose:
            print(rep)
        while diff > self.thresh and (self.iter_limit is None or iters < self.iter_limit):
            prev_rep = rep
            rep = F.relu(self.W1_i(rep))
            rep = F.relu(self.W1_o(rep))
            rep = self._row_normalize(rep)
            diff = torch.linalg.norm(rep - prev_rep)
            iters += 1
            if self.verbose:
                print(rep)
        if self.verbose:
            print(iters)
        return F.log_softmax(self.W1_i(rep), dim=1)

In [None]:
# Rebuild the toy two-cluster dataset as float32 tensors.
N = 10
D = 6
K = 2
X = np.zeros((N, D))
X[:N//2, :D//2] = 1
X[N//2:, D//2:] = 1
Y = np.zeros((N, K))
Y[:N//2, :K//2] = 1
Y[N//2:, K//2:] = 1

X = torch.from_numpy(X).to(torch.float32)
Y = torch.from_numpy(Y).to(torch.float32)
net = BiNetFC(D, K)

NUM_SGD_STEPS = 0  # training is currently disabled; set > 0 to run manual gradient descent
SGD_LR = 0.001
for _ in range(NUM_SGD_STEPS):
    Y_hat = net(X)
    # Squared error between predicted probabilities and the one-hot targets.
    loss = ((torch.exp(Y_hat) - Y)**2).sum()
    loss.backward()
    # Update parameters outside autograd instead of mutating `.data`.
    with torch.no_grad():
        for p in net.parameters():
            if p.grad is None:
                continue
            p -= SGD_LR * p.grad
            p.grad.zero_()

with torch.no_grad():
    Y_hat = net(X)
print(Y_hat)
torch.exp(Y_hat)

In [None]:
# Print every learnable tensor of `net` next to its qualified name.
for named_param in net.named_parameters():
    print(*named_param)

In [None]:
class BiNet2FC(nn.Module):
    """Fully-connected stack with top-down feedback, solved by fixed-point iteration.

    ``num_layers`` layers produce ``hidden_dim`` units each. A final layer maps
    ``hidden_dim`` to ``output_dim``. Every layer except the final one receives
    its lower input concatenated with the previous iteration's output of the
    layer above. The whole stack is re-evaluated until the mean per-layer change
    (Frobenius norm) is at most ``convergence_thresh`` or ``iter_limit``
    iterations have run.

    Args:
        input_dim: number of input features.
        hidden_dim: width of the feedback layers.
        output_dim: number of classes.
        num_layers: number of feedback layers (the final projection is extra).
        convergence_thresh: stop once the mean layer change is <= this value.
        iter_limit: maximum fixed-point iterations per forward pass
            (``None`` = no cap; otherwise must be >= 1).

    Raises:
        ValueError: if ``iter_limit`` is an int smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, convergence_thresh, iter_limit=1000):
        super(BiNet2FC, self).__init__()
        if iter_limit is not None and iter_limit < 1:
            raise ValueError("iter_limit must be at least 1 (or None)")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        in_dim = input_dim
        aux_dim = hidden_dim
        self.layers = []
        for layer_idx in range(num_layers):
            if layer_idx > 0:
                in_dim = hidden_dim
            if layer_idx == num_layers - 1:
                # the layer above the last feedback layer is the output projection
                aux_dim = output_dim
            fc_layer = nn.Linear(in_dim + aux_dim, hidden_dim)
            self.layers.append(fc_layer)
        self.layers.append(nn.Linear(hidden_dim, output_dim))
        self.layers = nn.ModuleList(self.layers)
        self.thresh = convergence_thresh
        self.iter_limit = iter_limit
        # number of fixed-point iterations used by the most recent forward pass
        self.last_forward_iters = 0

    def forward(self, X):
        """Return log-softmax scores of shape (batch, output_dim) for 2-D input X."""
        batch_size = X.shape[0]
        # new_zeros follows X's device and dtype; plain torch.zeros always lands
        # on the CPU and breaks torch.hstack for GPU inputs.
        prev_reps = [X.new_zeros(batch_size, layer.out_features) for layer in self.layers]
        top = len(self.layers) - 1
        reps = []
        diff = float('inf')
        iters = 0
        while diff > self.thresh and (self.iter_limit is None or iters < self.iter_limit):
            diff = 0.
            reps = []
            input_lower = X
            for layer_idx, layer in enumerate(self.layers):
                if layer_idx > 0:
                    input_lower = reps[layer_idx - 1]
                if layer_idx < top:
                    input_aug = torch.hstack([input_lower, prev_reps[layer_idx + 1]])
                else:
                    input_aug = input_lower
                rep = F.relu(layer(input_aug))
                reps.append(rep)
                # The convergence measure only steers the loop; keep it out of autograd.
                diff += torch.linalg.norm((rep - prev_reps[layer_idx]).detach())
            prev_reps = reps
            diff /= len(self.layers)
            iters += 1
        self.last_forward_iters = iters
        # NOTE(review): the final layer's output also passes through ReLU, so the
        # scores given to log_softmax are non-negative — confirm this is intended.
        return F.log_softmax(reps[-1], dim=1)

In [None]:
# Toy two-cluster dataset (float32 tensors) for BiNet2FC.
N = 10
D = 6
H = 3
K = 2
X = np.zeros((N, D))
X[:N//2, :D//2] = 1
X[N//2:, D//2:] = 1
Y = np.zeros((N, K))
Y[:N//2, :K//2] = 1
Y[N//2:, K//2:] = 1

X = torch.from_numpy(X).to(torch.float32)
Y = torch.from_numpy(Y).to(torch.float32)
net = BiNet2FC(D, H, K, 3, 1e-5)
# Y holds one-hot class probabilities; CrossEntropyLoss accepts probability
# targets. Feeding it log-softmax outputs is harmless because log_softmax is
# shift-invariant.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

LOSS_TOL = 1e-4      # stop when consecutive losses differ by at most this
MAX_STEPS = 20_000   # safety cap in case the loss never settles
LOG_EVERY = 20

# Track plain floats so the previous step's autograd graph is not kept alive.
prev_loss = float('inf')
iters = 0
while iters < MAX_STEPS:
    Y_hat = net(X)
    loss = loss_fn(Y_hat, Y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    iters += 1
    loss_value = loss.item()
    if iters % LOG_EVERY == 0:
        print(f"step {iters}: loss {loss_value:.6f}")
    if abs(loss_value - prev_loss) <= LOSS_TOL:
        break
    prev_loss = loss_value
print(iters)

with torch.no_grad():
    Y_hat = net(X)
print(net)
preds = torch.exp(Y_hat)
print(preds)

In [None]:
import torchvision
import torchvision.transforms as transforms

In [None]:
# ToTensor yields values in [0, 1]; Normalize(mean=0.5, std=0.5) per RGB
# channel rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Download (if needed) and load the CIFAR-10 train split, then the test split.
DATA_ROOT = '../data'
trainset, testset = (
    torchvision.datasets.CIFAR10(root=DATA_ROOT, train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
class BiNetConv(nn.Module):
    """Convolutional stack with top-down feedback, followed by an MLP classifier head.

    Every conv layer except the top one sees its lower input concatenated, along
    the channel axis, with the previous iteration's output of the layer above.
    Each conv is followed by ReLU and BatchNorm. The stack is re-evaluated until
    the mean per-layer change (Frobenius norm) is at most ``convergence_thresh``
    or ``iter_limit`` iterations have run. The top layer's output is then
    max-pooled (2x2), flattened and passed through three linear layers.

    Args:
        input_dim: channels of the input images.
        hidden_dim: channels produced by every conv layer.
        output_dim: number of output logits.
        num_layers: number of conv layers.
        convergence_thresh: stop once the mean layer change is <= this value.
        iter_limit: maximum fixed-point iterations per forward pass (>= 1).
        device: kept for backward compatibility. The feedback buffers are now
            allocated on the input's device, so this no longer has to match it.
        image_size: spatial size of the inputs, as an int or ``(height, width)``.
            Used to size ``fc1``; defaults to 32 (CIFAR-10).

    Raises:
        ValueError: if ``iter_limit`` is smaller than 1.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers, convergence_thresh=1e-3, iter_limit=100, device=torch.device("cpu"), image_size=32):
        super().__init__()
        if iter_limit < 1:
            raise ValueError("iter_limit must be at least 1")
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        in_dim = input_dim
        aux_dim = hidden_dim
        self.conv_layers = []
        self.bn_layers = []
        for layer_idx in range(num_layers):
            if layer_idx > 0:
                in_dim = hidden_dim
            if layer_idx == num_layers - 1:
                aux_dim = 0  # the top conv layer has no layer above feeding back into it
            conv_layer = nn.Conv2d(in_dim + aux_dim, hidden_dim, 5, padding='same')
            self.conv_layers.append(conv_layer)
            self.bn_layers.append(nn.BatchNorm2d(self.hidden_dim))
        self.conv_layers = nn.ModuleList(self.conv_layers)
        self.bn_layers = nn.ModuleList(self.bn_layers)
        self.pool = nn.MaxPool2d(2, 2)
        height, width = (image_size, image_size) if isinstance(image_size, int) else image_size
        self.image_size = (height, width)
        # 'same' padding preserves the spatial size; the 2x2 max-pool halves it (floor).
        self.fc1 = nn.Linear(hidden_dim * (height // 2) * (width // 2), 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, output_dim)
        self.thresh = convergence_thresh
        self.iter_limit = iter_limit
        self.device = device

    def forward(self, X):
        """Return logits of shape (batch, output_dim) for images X of shape (batch, C, H, W)."""
        batch_size, _, height, width = X.shape
        # Iteration-0 feedback buffers; new_zeros follows X's device and dtype.
        prev_reps = [X.new_zeros(batch_size, self.hidden_dim, height, width) for _ in self.conv_layers]
        top = len(self.conv_layers) - 1
        reps = []
        diff = float('inf')
        iters = 0
        # NOTE(review): in train mode every iteration passes through the BatchNorm
        # layers, so their running statistics are updated up to iter_limit times
        # per mini-batch — confirm this is intended.
        while diff > self.thresh and iters < self.iter_limit:
            diff = 0.
            reps = []
            input_lower = X
            for layer_idx, (conv, bn) in enumerate(zip(self.conv_layers, self.bn_layers)):
                if layer_idx > 0:
                    input_lower = reps[layer_idx - 1]
                if layer_idx < top:
                    # concatenate the previous output of the layer above along channels
                    input_aug = torch.cat([input_lower, prev_reps[layer_idx + 1]], dim=1)
                else:
                    input_aug = input_lower
                rep = bn(F.relu(conv(input_aug)))
                reps.append(rep)
                # The convergence measure only steers the loop; keep it out of autograd.
                diff += torch.linalg.norm((rep - prev_reps[layer_idx]).detach())
            prev_reps = reps
            diff /= len(self.conv_layers)
            iters += 1
        x = self.pool(reps[-1])
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [None]:
batch_size = 400
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)

# Fall back to the CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = BiNetConv(3, 16, 10, 4, convergence_thresh=1e-4, iter_limit=10, device=device)
print(net)
net.to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters())
epochs = 10
num_mb = 25  # report the average loss every `num_mb` mini-batches
net.train()
for epoch in range(epochs):
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % num_mb == num_mb - 1:
            print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / num_mb:.3f}')
            running_loss = 0.0

In [None]:
# Top-1 accuracy of the trained network on the CIFAR-10 test split.
correct = 0
total = 0

net.eval()
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = net(images)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

assert total > 0, "test loader yielded no samples"
print(f'Accuracy of the network on the {total} test images: {100 * correct / total:.2f} %')