In [1]:
! pip install -q torch torchvision tqdm

In [2]:
NEPOCH = 10

In [3]:
from torch import nn, functional as F
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm

In [4]:
import torchvision
import torchvision.transforms as transforms

# Transformations
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:19<00:00, 517292.69it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 28657604.41it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 560777.70it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 20662178.71it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw






In [5]:
# Data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Neural network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(28*28, 500)
        self.fc2 = nn.Linear(500, 10)

    def forward(self, x):
        x = x.view(-1, 28*28)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

net = Net()

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

# Training the network
for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

# Testing the network on the test data
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


100%|██████████| 10/10 [00:35<00:00,  3.53s/it]


Finished Training
Accuracy of the network on the 10000 test images: 97 %


# 🔸 FFF

In [6]:
! pip install -q lovely-tensors

In [7]:
import lovely_tensors as lt
lt.monkey_patch()

In [8]:
class FFF(nn.Module):
  def __init__(self, nIn, nOut) -> None:
    super().__init__()

    INPUT_WIDTH = nIn
    OUTPUT_WIDTH = nOut
    DEPTH =  8   #int(math.ceil(math.log2(nIn)))

    self.input_width = INPUT_WIDTH
    self.output_width = OUTPUT_WIDTH
    self.depth = DEPTH
    self.n_nodes = 2**(DEPTH+1) - 1 # ???

    self._initiate_weights()

  def _initiate_weights(self):
    init_factor_I1 = 1 / math.sqrt(self.input_width)
    init_factor_I2 = 1 / math.sqrt(self.depth + 1)

    # shape: (n_nodes, input_width)
    # weights for linear layer
    self.w1s = nn.Parameter(
      torch.empty(self.n_nodes, self.input_width).uniform_(-init_factor_I1, init_factor_I1),
      requires_grad=True)

    # weights for regular layer
    self.w2s = nn.Parameter(
      torch.empty(self.n_nodes, self.output_width).uniform_(-init_factor_I2, init_factor_I2),
      requires_grad=True)


  def forward(self, x: torch.Tensor):
    batch_size = x.shape[0]

    # concurrent for batch size (bs, )
    current_node = torch.zeros((batch_size,), dtype=torch.long)

    all_nodes = torch.zeros(batch_size, self.depth+1, dtype=torch.long)
    all_scores = torch.empty((batch_size, self.depth+1), dtype=torch.float)

    for i in range(self.depth + 1):
      # compute plane scores
      # dot product between input (x) and weights of the current node (w1s)
      # result is scalar of shape (bs)
      plane_score = torch.einsum('b i, b i -> b', x, self.w1s[current_node])
      all_nodes[:, i] = current_node

      # scores are used for gradient propagation and learning decision boundaries
      all_scores[:, i] = plane_score

      # compute next node (left or right)
      plane_choice = (plane_score > 0).long()
      current_node = (current_node * 2) + plane_choice + 1

    # from: https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L29
    # FF_41( GeLU ( FF_14(x) )))

    # GeLU(FF_14 @ x) @ FF_41
    # GeLU(W1(x) @ x) @ W2(x)
    selected_w2s = self.w2s[all_nodes.flatten()].view(batch_size, self.depth+1, self.output_width)
    # y = torch.einsum('b i j , b i -> b j', selected_w2s, F.gelu(all_logits))
    y = torch.einsum('b i j , b i -> b j', selected_w2s, all_scores)
    return y

In [None]:
[0,1,2,3,4,5,6,7,8,9,]
[0,1,1,2,2,2,2,3,3,3,3,3,3,3,3]
# 0 -[0]> 1 -[0]> 3 -[0]> 7 -[0]> 15
# 0 -[0]> 1 -[0]> 3 -[1]> 8 -[0]> 17
# 0 -[0]> 1 -[1]> 4 -[0]> 9 -[0]> 18
# 0 -[0]> 1 -[1]> 4 -[1]> 10 -[0]> 21
# 0 -[1]> 2 -[1]> 5 -[0]> 11 -[0]> 22


In [9]:
# Data loaders
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)

# Neural network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = FFF(nIn=28*28, nOut=500)
        self.fc2 = FFF(nIn=500, nOut=10)
        # self.fc1 = FFF(nIn=28*28, nOut=10)
    def forward(self, x):
        x = x.view(-1, 28*28)
        y_hat = self.fc2(torch.relu(self.fc1(x)))
        # y_hat = self.fc1(x)
        return y_hat

net = Net()

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(net.parameters(), lr=0.001) #, momentum=0.9)

# Training the network
for epoch in tqdm(range(NEPOCH)):  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data

        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

# Testing the network on the test data
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')


100%|██████████| 10/10 [02:58<00:00, 17.82s/it]


Finished Training
Accuracy of the network on the 10000 test images: 94 %
