In [None]:
import os
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

In [None]:
# Preprocessing shared by both MNIST splits: turn each image into a tensor and
# shrink it to the 16-pixel side length the network expects (see img_size).
tf = transforms.Compose([
    ToTensor(),
    transforms.Resize(16),
])

# Training split first, then the held-out test split; both are downloaded into
# 'mnist_data/' when not already present there.
mnist_train, mnist_test = (
    datasets.MNIST('mnist_data/', download=True, train=is_train, transform=tf)
    for is_train in (True, False)
)

In [None]:
# The classifier only distinguishes the digits listed here; every other sample
# is dropped from both splits.
targets = [0, 1, 2, 3, 4]


def _indices_with_labels(dataset, wanted_labels):
    """Return the positions in `dataset` whose label is one of `wanted_labels`.

    Args:
        dataset: object exposing a `targets` sequence/tensor of integer labels.
        wanted_labels: iterable of integer labels to keep.

    Returns:
        List of integer indices, in ascending order.
    """
    wanted = set(wanted_labels)
    # Converting the labels to plain Python ints once turns every membership
    # test into a set lookup instead of a chain of 0-dim tensor comparisons.
    labels = torch.as_tensor(dataset.targets).tolist()
    return [i for i, label in enumerate(labels) if label in wanted]


train_indices = _indices_with_labels(mnist_train, targets)
mnist_train_s = Subset(mnist_train, train_indices)
test_indices = _indices_with_labels(mnist_test, targets)
mnist_test_s = Subset(mnist_test, test_indices)

In [None]:
# Run on the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device} device')

In [None]:
# Mini-batch iterators (despite the *_set names these are DataLoaders) over the
# digit-filtered subsets. Both reshuffle on every pass, so the first test
# batch differs each time the loader is iterated.
train_set = DataLoader(mnist_train_s, shuffle=True, batch_size=256)
test_set = DataLoader(mnist_test_s, shuffle=True, batch_size=256)

In [None]:
# Side length of the square input images and width of the single hidden layer.
img_size = 16
hidden_size = 16


class NeuralNetwork(nn.Module):
    """Bias-free perceptron: flatten -> Linear -> ReLU -> Linear -> ReLU.

    The input is flattened to img_size*img_size features, mapped to
    hidden_size units and then to 5 output scores. Linear layers sit at the
    even positions (0, 2) of `linear_relu_stack`; `get_weight`/`get_bias`
    rely on that layout.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        layers = [
            nn.Linear(img_size * img_size, hidden_size, bias=False),
            nn.ReLU(),
            # A second hidden Linear(hidden_size, hidden_size) + ReLU pair was
            # tried here and is currently disabled.
            nn.Linear(hidden_size, 5, bias=False),
            # NOTE(review): the scores are clamped at zero before they reach
            # the loss; this appears deliberate (the fixed-point reference
            # cells also end in ReLU) - confirm before removing.
            nn.ReLU(),
        ]
        self.linear_relu_stack = nn.Sequential(*layers)

    def forward(self, x):
        """Return the 5 non-negative class scores for each item in `x`."""
        return self.linear_relu_stack(self.flatten(x))


In [None]:
# Start from a freshly initialised network and move it to the chosen device.
# To continue from a previously saved model instead, swap in the line below.
#model = torch.load('model_tiny.pth', map_location=torch.device(device))
model = NeuralNetwork()
model = model.to(device)
print(model)

In [None]:
# Plain stochastic gradient descent on the cross-entropy between the network's
# output scores and the digit labels.
learning_rate = 1e-3

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

In [None]:
from tqdm.notebook import tqdm

def train_loop(dataloader, model, loss_fn, optimizer):
    """Run one optimisation pass (one epoch) over `dataloader`.

    Args:
        dataloader: iterable yielding `(inputs, labels)` batches.
        model: module being trained; it is switched to training mode.
        loss_fn: callable `(predictions, labels) -> scalar loss tensor`.
        optimizer: optimizer bound to `model`'s parameters.

    Batches are moved to the notebook-level `device` before use and progress
    is shown through a tqdm bar. Nothing is returned.
    """
    # Ensure train-time behaviour is active; a no-op for layers that behave
    # the same in both modes.
    model.train()

    for X, y in tqdm(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation: clear stale gradients, compute new ones, update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


def test_loop(dataloader, model, loss_fn):
    """Evaluate `model` on `dataloader` and print accuracy and mean loss.

    Args:
        dataloader: iterable of `(inputs, labels)` batches that also exposes
            the underlying `dataset` (used for the sample count).
        model: callable mapping a batch of inputs to per-class scores.
        loss_fn: callable `(predictions, labels) -> scalar loss tensor`.

    The reported loss is the mean over samples. A loss with
    `reduction='mean'` (the default of `nn.CrossEntropyLoss`) returns a
    per-batch average, so each batch value is weighted by its batch size
    before dividing by the number of samples; a loss with `reduction='sum'`
    is accumulated as is. Batches are moved to the notebook-level `device`.
    """
    size = len(dataloader.dataset)
    sums_over_batch = getattr(loss_fn, "reduction", "mean") == "sum"
    test_loss, correct = 0, 0

    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            batch_loss = loss_fn(pred, y).item()
            # Undo the per-batch averaging so that dividing by `size` below
            # yields a true per-sample mean even when the last batch is short.
            test_loss += batch_loss if sums_over_batch else batch_loss * len(y)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    test_loss /= size
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

In [None]:
# Alternate one training pass with one evaluation on the test data, `epochs`
# times in total.
epochs = 100
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    train_loop(train_set, model, loss_fn, optimizer)
    test_loop(test_set, model, loss_fn)
print("Done!")

In [None]:
# Serialise the whole model object (not just a state_dict) to disk; the
# commented-out torch.load call in the model-creation cell reads this file.
torch.save(model, 'model_tiny.pth')

In [None]:
# Grab a single batch of images and labels from the test loader for manual
# spot checks in the next cell.
imgs, labs = next(iter(test_set))

In [None]:
import random
# Pick one image of the batch at random. randrange excludes its upper bound,
# so the index is always valid; randint(0, len(imgs)) includes len(imgs) and
# would occasionally raise IndexError.
idx = random.randrange(len(imgs))
print(imgs.shape, labs.shape)
plt.imshow(imgs[idx].squeeze())
# Show the predicted class next to the true label.
pred = model(imgs[idx].to(device))
print(torch.argmax(pred.cpu()), labs[idx])

In [None]:
def _rotate_half_turn(image):
    """Return `image` rotated by 180 degrees."""
    return transforms.functional.rotate(image, 180)


# Preprocessing for the images under 'custom_data/': convert to a tensor, keep
# the central 240x240 region, shrink to a 16-pixel side and turn upside down.
tf = transforms.Compose([
    ToTensor(),
    transforms.CenterCrop(240),
    transforms.Resize(16),
    _rotate_half_turn,
])
custom_test = datasets.ImageFolder('custom_data/', transform=tf)

In [None]:
# Take the first custom image together with its folder-derived class index.
img, lab = custom_test[0]
print(lab)

# Add the channels into a single plane, then binarise it: positions whose
# channel sum is below 2 become 1.0 and all others 0.0.
# NOTE(review): presumably this turns dark strokes on a light background into
# bright-on-dark like the training digits - confirm against the photos.
channel_sum = img.sum(dim=0, keepdim=True)
t1 = torch.where(channel_sum < 2, 1., 0.)
print(t1.shape)
plt.imshow(t1.squeeze(), cmap='gray')

In [None]:
# Classify the binarised custom image and print the index of the largest
# output score.
pred = model(t1.to(device))
print(torch.argmax(pred.cpu()))

In [18]:
# Show the layer layout, then dump the first linear layer's weight matrix as a
# NumPy array (detached from autograd and copied to the CPU first).
print(model)
print(model.linear_relu_stack[0].weight.detach().cpu().numpy())

NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=256, out_features=16, bias=False)
    (1): ReLU()
    (2): Linear(in_features=16, out_features=5, bias=False)
    (3): ReLU()
  )
)
[[ 0.02651919 -0.03826084  0.03035309 ...  0.01474939  0.02718701
   0.01681066]
 [-0.04441253 -0.03005728 -0.0615841  ...  0.02694984  0.05554254
  -0.02049033]
 [ 0.03967959  0.01237705  0.00526872 ... -0.01183159 -0.01934808
  -0.05268165]
 ...
 [-0.03036228 -0.00612178  0.02012341 ...  0.05704859  0.03373971
   0.03451768]
 [ 0.03627101  0.02872439  0.06197749 ...  0.00770043 -0.03933718
   0.01848188]
 [-0.05737756  0.04138687  0.03870979 ...  0.05599818  0.00472341
  -0.05814941]]


In [19]:
def get_weight(m, idx):
    """Return the weight matrix of the `idx`-th linear layer as a NumPy array.

    Linear layers are expected at the even positions of `m.linear_relu_stack`
    (each one followed by an activation), so layer `idx` is entry `2 * idx`.
    The tensor is detached from autograd and moved to the CPU before the
    conversion.
    """
    layer = m.linear_relu_stack[2 * idx]
    weight = layer.weight
    return weight.detach().cpu().numpy()
def get_bias(m, idx):
    """Return the bias vector of the `idx`-th linear layer as a NumPy array.

    Linear layers are expected at the even positions of `m.linear_relu_stack`,
    so layer `idx` is entry `idx * 2`.

    Returns:
        The bias as a NumPy array, or `None` when the layer was created
        without a bias term (its `bias` attribute is `None`), instead of
        failing with an AttributeError.
    """
    bias = m.linear_relu_stack[idx*2].bias
    if bias is None:
        return None
    return bias.detach().cpu().numpy()
def to_fixed_point(mat, frac_bits=8):
    """Convert a NumPy array to nested lists of fixed-point integers.

    Each value is scaled by 2**frac_bits, rounded with NumPy's `round`
    (ties go to the nearest even integer) and cast to int.

    Args:
        mat: NumPy array of real values.
        frac_bits: number of fractional bits in the fixed-point format;
            defaults to 8, the scale used throughout this notebook's printouts.

    Returns:
        A (nested) Python list of ints with the same shape as `mat`.
    """
    return (mat * (2**frac_bits)).round().astype(int).tolist()

In [20]:
# Pull both weight matrices out of the trained model: w0 is the first linear
# layer (stack entry 0), w2 the second one (stack entry 2).
w0, w2 = get_weight(model, 0), get_weight(model, 1)
print(w0, w2)

[[ 0.02651919 -0.03826084  0.03035309 ...  0.01474939  0.02718701
   0.01681066]
 [-0.04441253 -0.03005728 -0.0615841  ...  0.02694984  0.05554254
  -0.02049033]
 [ 0.03967959  0.01237705  0.00526872 ... -0.01183159 -0.01934808
  -0.05268165]
 ...
 [-0.03036228 -0.00612178  0.02012341 ...  0.05704859  0.03373971
   0.03451768]
 [ 0.03627101  0.02872439  0.06197749 ...  0.00770043 -0.03933718
   0.01848188]
 [-0.05737756  0.04138687  0.03870979 ...  0.05599818  0.00472341
  -0.05814941]] [[ 2.30456829e-01  1.97920382e-01 -2.25730255e-01 -2.32890278e-01
   1.81446642e-01 -2.81125084e-02  1.71149865e-01  2.72740405e-02
  -1.66329488e-01 -1.65744975e-01 -3.69534343e-02 -9.12979525e-03
  -1.27709612e-01  2.03615166e-02 -1.75119430e-01  2.34007448e-01]
 [ 2.65079271e-02  6.62841871e-02  1.03419013e-02  1.11586541e-01
   7.60938898e-02 -6.49395436e-02 -1.09402500e-01 -1.55636892e-01
  -1.70130491e-01 -2.45154947e-01 -1.06793493e-01 -1.42244205e-01
  -2.28730589e-02  7.63704777e-02  4.85608540

In [21]:
# Output-layer weights as integers scaled by 2**8.
print(to_fixed_point(w2))

[[59, 51, -58, -60, 46, -7, 44, 7, -43, -42, -9, -2, -33, 5, -45, 60], [7, 17, 3, 29, 19, -17, -28, -40, -44, -63, -27, -36, -6, 20, 12, -44], [-51, 1, 6, -57, 0, -61, -54, -24, 58, 40, 53, -52, 48, -32, -62, 30], [-55, -58, 56, 21, -38, 23, 68, 17, 30, -26, -35, 23, 66, -24, 42, 13], [9, 35, 56, 60, 8, -48, -16, 5, -33, 53, 68, -17, -5, 26, 29, -39]]


In [None]:
# Toy layer 0: random 4x5 weight matrix and length-4 bias drawn from [0, 1).
# No seed is set, so the values change on every run.
tw0, tb0 = torch.rand(4, 5), torch.rand(4)

In [None]:
# Toy layer 1: random 4x4 weight matrix and length-4 bias drawn from [0, 1).
tw1, tb1 = torch.rand(4, 4), torch.rand(4)

In [None]:
# Toy layer 2 (output): random 2x4 weight matrix and length-2 bias drawn from
# [0, 1).
tw2, tb2 = torch.rand(2, 4), torch.rand(2)

In [None]:
# One line per toy layer: its weights followed by its bias, both as integers
# scaled by 2**8.
for weight, bias in ((tw0, tb0), (tw1, tb1), (tw2, tb2)):
    print(to_fixed_point(weight.numpy()), to_fixed_point(bias.numpy()))

In [None]:
# Two random length-5 input vectors (values in [0, 1)) for the toy network.
tx1, tx2 = torch.rand(5,), torch.rand(5,)

In [None]:
# The two toy inputs as integers scaled by 2**8.
print(to_fixed_point(tx1.numpy()), to_fixed_point(tx2.numpy()))

In [None]:
from torch.nn.functional import relu
def _toy_forward(x):
    """Floating-point reference pass: three affine layers, each followed by ReLU."""
    h0 = relu(tw0.matmul(x) + tb0)
    h1 = relu(tw1.matmul(h0) + tb1)
    return relu(tw2.matmul(h1) + tb2)


ty1 = _toy_forward(tx1)
ty2 = _toy_forward(tx2)

In [None]:
# Floating-point reference outputs, shown as integers scaled by 2**8.
print(to_fixed_point(ty1.numpy()), to_fixed_point(ty2.numpy()))

In [None]:
def approximate(mat, frac_bits=16):
    """Quantise `mat` to the nearest multiple of 2**-frac_bits.

    The values are scaled by 2**frac_bits, rounded with the array's own
    `round()` method, and scaled back, so the result keeps the input's type.

    Args:
        mat: tensor (or array) of real values.
        frac_bits: number of fractional bits to keep; defaults to 16.
    """
    scale = 2**frac_bits
    tmp = mat * scale
    return tmp.round() / scale
def approximate_t(mat, frac_bits=16):
    """Quantise `mat` down to a multiple of 2**-frac_bits.

    Like `approximate`, but uses `floor()` instead of `round()`, so every
    value moves toward negative infinity (negative inputs get more negative).

    Args:
        mat: tensor of real values (must provide a `floor()` method).
        frac_bits: number of fractional bits to keep; defaults to 16.
    """
    scale = 2**frac_bits
    tmp = mat * scale
    return tmp.floor() / scale

In [None]:
# Rounded (16 fractional bits) copies of every toy weight, bias and input; the
# "r" suffix marks the quantised version of each tensor.
(tw0r, tb0r,
 tw1r, tb1r,
 tw2r, tb2r,
 tx1r, tx2r) = map(approximate, (tw0, tb0, tw1, tb1, tw2, tb2, tx1, tx2))

In [None]:
def _toy_forward_fixed(x):
    """Quantised pass: each matrix product is floored to 16 fractional bits
    before the bias is added and ReLU is applied."""
    h0 = relu(approximate_t(tw0r.matmul(x)) + tb0r)
    h1 = relu(approximate_t(tw1r.matmul(h0)) + tb1r)
    return relu(approximate_t(tw2r.matmul(h1)) + tb2r)


ty1r = _toy_forward_fixed(tx1r)
ty2r = _toy_forward_fixed(tx2r)
print(to_fixed_point(ty1r.numpy()), to_fixed_point(ty2r.numpy()))

In [None]:
# Recomputes the quantised forward pass for the first input only (the same
# expression that produced ty1r) and prints it as integers scaled by 2**8.
print(to_fixed_point(relu(approximate_t(tw2r.matmul(relu(approximate_t(tw1r.matmul(relu(approximate_t(tw0r.matmul(tx1r)) + tb0r))) + tb1r))) + tb2r).numpy()))