In [1]:
import torch
from torch import nn

class TanhFixedPointLayer(nn.Module):
    """Layer whose output is found by iterating ``z <- tanh(W z + x)``.

    The iteration starts from ``z = 0`` and applies a bias-free square linear
    map ``W`` (``out_feats`` x ``out_feats``) at every step. It stops once the
    change between successive iterates falls below ``tol`` or after
    ``max_iter`` steps, whichever comes first.

    After each call to ``forward`` two diagnostic attributes are available:

    * ``iterations`` -- number of update steps performed in that call.
    * ``err`` -- norm of the last update, taken over the whole input tensor
      (all elements together, not per sample). It is only assigned inside
      the loop, so it is not set if ``max_iter`` allows no step at all.
    """

    def __init__(self, out_feats, tol=1e-4, max_iter=50):
        """Create the layer.

        Args:
            out_feats: size of the last dimension of both input and output.
            tol: stop once the norm of the last update is below this value.
            max_iter: upper bound on the number of update steps.
        """
        super().__init__()
        self.linear = nn.Linear(out_feats, out_feats, bias=False)
        self.tol = tol
        self.max_iter = max_iter

    def forward(self, x):
        """Iterate towards a fixed point for input ``x`` and return the last iterate.

        Args:
            x: tensor whose last dimension matches ``out_feats``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # start the iteration from an all-zero state
        z = torch.zeros_like(x)
        self.iterations = 0

        # keep updating until the step size is small enough or the budget runs out
        while self.iterations < self.max_iter:
            z_prev = z
            z = torch.tanh(self.linear(z_prev) + x)
            self.err = (z_prev - z).norm()
            self.iterations += 1
            if self.err < self.tol:
                break
        return z


In [2]:
# Smoke test: push a random batch (10 rows, 50 features) through a fresh
# layer and report how many iterations the fixed-point loop needed.
layer = TanhFixedPointLayer(out_feats=50)
X = torch.randn((10, 50))
Z = layer(X)
print(f'Terminated after {layer.iterations} iterations with error {layer.err}')

Terminated after 14 iterations with error 8.546890603611246e-05


In [3]:
# import the MNIST dataset and data loaders
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# MNIST is stored in (and, if missing, downloaded to) the current directory;
# images are converted to tensors on access.
mnist_train = datasets.MNIST(
    root=".", train=True, download=True, transform=transforms.ToTensor()
)
mnist_test = datasets.MNIST(
    root=".", train=False, download=True, transform=transforms.ToTensor()
)

# Batches of 100; only the training loader is shuffled.
train_loader = DataLoader(dataset=mnist_train, batch_size=100, shuffle=True)
test_loader = DataLoader(dataset=mnist_test, batch_size=100, shuffle=False)

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# construct the simple model with fixed point layer
import torch.optim as optim
# Seed before any module is built so parameter initialisation is repeatable.
torch.manual_seed(0)

# flatten -> 784x100 linear -> fixed-point layer (up to 200 iterations) -> 100x10 linear
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(in_features=784, out_features=100),
    TanhFixedPointLayer(out_feats=100, max_iter=200),
    nn.Linear(in_features=100, out_features=10),
)
model = model.to(device)

# Plain SGD over every parameter of the model.
opt = optim.SGD(params=model.parameters(), lr=1e-1)

In [5]:
# a generic function for running a single epoch (training or evaluation)
from tqdm.notebook import tqdm

def epoch(loader, model, opt=None, monitor=None):
    """Run one pass over ``loader``, training if ``opt`` is given, else evaluating.

    Relies on the notebook-level names ``device`` (where batches are moved)
    and ``tqdm`` (progress bar).

    Args:
        loader: iterable of ``(X, y)`` batches; must support ``len()`` and
            expose a ``dataset`` attribute that supports ``len()``.
        model: module mapping a batch ``X`` to class scores.
        opt: optimizer. When ``None`` the model is put in eval mode and no
            gradients are computed; otherwise the model is put in train mode
            and one optimizer step is attempted per batch.
        monitor: optional callable taking ``model`` and returning a number;
            it is called once after every batch.

    Returns:
        Tuple ``(error_rate, mean_loss, mean_monitor)``: misclassified
        samples and summed loss divided by ``len(loader.dataset)``, and the
        summed monitor values divided by ``len(loader)``.
    """
    training = opt is not None
    model.train(training)
    criterion = nn.CrossEntropyLoss()

    total_loss, total_err, total_monitor = 0., 0., 0.

    for X, y in tqdm(loader, leave=False):
        X, y = X.to(device), y.to(device)

        # Evaluation never calls backward(), so skip building the autograd
        # graph there; this saves memory and time without changing outputs.
        with torch.set_grad_enabled(training):
            yp = model(X)
            loss = criterion(yp, y)

        if training:
            opt.zero_grad()
            loss.backward()
            # Parameters that took no part in the loss have grad None; leave
            # them out of the NaN check instead of crashing on isnan(None).
            grads = [p.grad for p in model.parameters() if p.grad is not None]
            # Best-effort guard: drop the update for this batch if any
            # gradient contains NaN.
            if sum(torch.isnan(g).sum() for g in grads) == 0:
                opt.step()

        total_err += (yp.max(dim=1)[1] != y).sum().item()
        total_loss += loss.item() * X.shape[0]

        if monitor is not None:
            total_monitor += monitor(model)

    return total_err / len(loader.dataset), total_loss / len(loader.dataset), total_monitor / len(loader)

In [6]:
# Ten epochs in total; each epoch trains on the training set and then
# evaluates on the test set, tracking the fixed-point layer's iteration count.
for i in range(10):
    # From the sixth epoch onward continue with a smaller learning rate.
    if i == 5:
        opt.param_groups[0]["lr"] = 1e-2

    # model[2] is the fixed-point layer; the monitor reads how many
    # iterations it used on the most recent batch.
    train_err, train_loss, train_fpiter = epoch(
        train_loader, model, opt=opt, monitor=lambda net: net[2].iterations
    )
    test_err, test_loss, test_fpiter = epoch(
        test_loader, model, monitor=lambda net: net[2].iterations
    )
    print(f"Train Error: {train_err:.4f}, Loss: {train_loss:.4f}, FP Iters: {train_fpiter:.2f} | " +
          f"Test Error: {test_err:.4f}, Loss: {test_loss:.4f}, FP Iters: {test_fpiter:.2f}")

  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.1113, Loss: 0.4034, FP Iters: 53.37 | Test Error: 0.0716, Loss: 0.2419, FP Iters: 56.22


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0577, Loss: 0.1940, FP Iters: 52.72 | Test Error: 0.0498, Loss: 0.1632, FP Iters: 51.25


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0438, Loss: 0.1465, FP Iters: 57.98 | Test Error: 0.0450, Loss: 0.1461, FP Iters: 56.47


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0360, Loss: 0.1219, FP Iters: 66.49 | Test Error: 0.0360, Loss: 0.1225, FP Iters: 62.24


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0305, Loss: 0.1029, FP Iters: 75.47 | Test Error: 0.0356, Loss: 0.1185, FP Iters: 74.29


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0208, Loss: 0.0733, FP Iters: 74.36 | Test Error: 0.0309, Loss: 0.1044, FP Iters: 73.01


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0193, Loss: 0.0679, FP Iters: 76.18 | Test Error: 0.0303, Loss: 0.1027, FP Iters: 75.64


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0186, Loss: 0.0652, FP Iters: 77.07 | Test Error: 0.0307, Loss: 0.1047, FP Iters: 75.68


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0175, Loss: 0.0627, FP Iters: 82.08 | Test Error: 0.0307, Loss: 0.1046, FP Iters: 79.23


  0%|          | 0/600 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

Train Error: 0.0170, Loss: 0.0605, FP Iters: 84.23 | Test Error: 0.0307, Loss: 0.1033, FP Iters: 78.77
