In [1]:
import torch
from tinybig.util import set_random_seed

# Seed through tinybig's helper before anything else in the notebook runs.
set_random_seed(random_seed=1234)

# Choose the compute device by probing what this PyTorch build can use rather
# than hard-coding 'mps': a fixed 'mps' string makes every later `device=DEVICE`
# call fail on machines without the Apple Metal backend.
# To force a backend, overwrite DEVICE here with 'cpu', 'cuda' or 'mps'.
if torch.cuda.is_available():
    DEVICE = 'cuda'
elif torch.backends.mps.is_available():
    DEVICE = 'mps'
else:
    DEVICE = 'cpu'

In [2]:
from tinybig.data import mnist
# MNIST wrapper; the train and test splits both use mini-batches of 64.
mnist_data = mnist(
    name='mnist',
    train_batch_size=64,
    test_batch_size=64,
)

# load() returns a container indexed by 'train_loader' / 'test_loader';
# './data/' is passed as the cache directory.
mnist_loaders = mnist_data.load(cache_dir='./data/')

# Direct handles to each split for the cells that follow.
train_loader, test_loader = mnist_loaders['train_loader'], mnist_loaders['test_loader']

In [3]:
# Fetch exactly one mini-batch for inspection.
# next(iter(...)) binds X and y as ordinary notebook-level names (a later cell
# slices X) instead of relying on variables leaked from a `for ... break` loop,
# and it raises StopIteration right here if the loader yields nothing rather
# than leaving X undefined until that later cell.
X, y = next(iter(train_loader))

print('X shape:', X.shape, 'y.shape:', y.shape)
print('X', X)
print('y', y)

X shape: torch.Size([64, 784]) y.shape: torch.Size([64])
X tensor([[-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
        ...,
        [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242],
        [-0.4242, -0.4242, -0.4242,  ..., -0.4242, -0.4242, -0.4242]])
y tensor([6, 6, 3, 8, 3, 2, 8, 7, 0, 9, 1, 7, 6, 6, 8, 6, 8, 2, 2, 3, 0, 6, 7, 1,
        2, 1, 9, 2, 0, 9, 0, 2, 9, 8, 9, 3, 0, 7, 8, 1, 3, 0, 3, 6, 0, 5, 2, 1,
        5, 3, 1, 7, 6, 2, 8, 7, 6, 7, 0, 3, 6, 3, 3, 6])


In [4]:
from tinybig.expansion import taylor_expansion

# Taylor expansion with d=2 and 'layer_norm' as the post-processing step.
exp_func = taylor_expansion(
    name='taylor_expansion',
    d=2,
    postprocess_functions='layer_norm',
    device=DEVICE,
)

# First sample of the inspected batch only; the slice keeps the leading batch axis.
x = X[:1]

# Expansion-space dimension for an input of width m (the second axis of x).
D = exp_func.calculate_D(m=x.shape[1])
print('Expansion space dimension:', D)

# Run the expansion on the single sample and compare input/output shapes.
kappa_x = exp_func(x=x)
print('x.shape', x.shape, 'kappa_x.shape', kappa_x.shape)

Expansion space dimension: 615440
x.shape torch.Size([1, 784]) kappa_x.shape torch.Size([1, 615440])


In [5]:
from tinybig.reconciliation import dual_lphm_reconciliation

# Dual LPHM reconciliation configured with p=8, q=784, r=5.
rec_func = dual_lphm_reconciliation(
    name='dual_lphm_reconciliation',
    p=8,
    q=784,
    r=5,
    device=DEVICE,
)

# Learnable-parameter count this reconciliation reports for n=64 outputs
# over the D-dimensional expansion computed in the previous cell.
l = rec_func.calculate_l(n=64, D=D)
print('Required learnable parameter number:', l)

Required learnable parameter number: 7925


In [6]:
from tinybig.remainder import zero_remainder

# Zero remainder, created with parameters and bias both switched off.
rem_func = zero_remainder(
    name='zero_remainder',
    require_parameters=False,
    enable_bias=False,
    device=DEVICE,
)

In [7]:
from tinybig.module import rpn_head

# Single-channel head (m=784 -> n=64) assembled from the three component
# functions built in the preceding cells.
head = rpn_head(
    m=784,
    n=64,
    channel_num=1,
    data_transformation=exp_func,
    parameter_fabrication=rec_func,
    remainder=rem_func,
    device=DEVICE,
)

In [8]:
from tinybig.module import rpn_layer

# First layer: wraps the single head defined above (m=784 -> n=64).
layer_1 = rpn_layer(
    m=784,
    n=64,
    heads=[head],
    device=DEVICE,
)

In [9]:
def build_rpn_layer(m, n, p, q, device, r=5, d=2):
    """Build an rpn_layer that holds one single-channel rpn_head.

    The head combines a Taylor expansion (with 'layer_norm' post-processing),
    a dual LPHM reconciliation and a zero remainder, i.e. the same recipe the
    notebook uses for every layer after the first.

    Args:
        m: Input dimension passed to both the head and the layer.
        n: Output dimension passed to both the head and the layer.
        p: `p` argument of dual_lphm_reconciliation.
        q: `q` argument of dual_lphm_reconciliation.
        device: Device string forwarded to every component.
        r: `r` argument of dual_lphm_reconciliation.
        d: `d` argument of taylor_expansion.

    Returns:
        The constructed rpn_layer.
    """
    # Components are created in the same order as the original inline code:
    # expansion, reconciliation, remainder, then the head, then the layer.
    single_head = rpn_head(
        m=m, n=n, channel_num=1,
        data_transformation=taylor_expansion(d=d, postprocess_functions='layer_norm', device=device),
        parameter_fabrication=dual_lphm_reconciliation(p=p, q=q, r=r, device=device),
        remainder=zero_remainder(device=device),
        device=device,
    )
    return rpn_layer(m=m, n=n, heads=[single_head], device=device)


# Hidden layer: 64 -> 64.
layer_2 = build_rpn_layer(m=64, n=64, p=8, q=64, device=DEVICE)

# Output layer: 64 -> 10, with p=2.
layer_3 = build_rpn_layer(m=64, n=10, p=2, q=64, device=DEVICE)

In [10]:
from tinybig.model import rpn

# Stack the three layers, in order, into the full RPN model.
stacked_layers = [layer_1, layer_2, layer_3]
model = rpn(layers=stacked_layers, device=DEVICE)

In [11]:
import torch
from tinybig.learner import backward_learner

# AdamW over all model parameters.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=2.0e-03,
    weight_decay=2.0e-04,
)

# Exponential schedule: the learning rate is multiplied by gamma=0.95 per scheduler step.
lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
    optimizer=optimizer,
    gamma=0.95,
)

# Cross-entropy loss for training.
loss = torch.nn.CrossEntropyLoss()

# Learner configured for 3 epochs with the optimizer, loss and scheduler above.
learner = backward_learner(
    n_epochs=3,
    optimizer=optimizer,
    loss=loss,
    lr_scheduler=lr_scheduler,
)


In [12]:
from tinybig.metric import accuracy

# Total number of scalar entries across everything model.parameters() yields.
print('parameter num: ', sum(p.numel() for p in model.parameters()))

# Accuracy metric handed to the learner.
metric = accuracy()

# Train using the loaders container built earlier; the returned records are kept.
training_records = learner.train(
    model=model,
    data_loader=mnist_loaders,
    metric=metric,
    device=DEVICE,
)

parameter num:  9330


100%|██████████| 938/938 [00:33<00:00, 28.02it/s, epoch=0/3, loss=0.0519, lr=0.002, metric_score=0.969, time=33.6]


Epoch: 0, Test Loss: 0.12760563759773874, Test Score: 0.9621, Time Cost: 3.6527512073516846


100%|██████████| 938/938 [00:31<00:00, 29.45it/s, epoch=1/3, loss=0.0112, lr=0.0019, metric_score=1, time=69.1]    


Epoch: 1, Test Loss: 0.09334634791371549, Test Score: 0.9717, Time Cost: 3.5445549488067627


100%|██████████| 938/938 [00:32<00:00, 29.20it/s, epoch=2/3, loss=0.0212, lr=0.0018, metric_score=1, time=105]     


Epoch: 2, Test Loss: 0.08378902525169431, Test Score: 0.9749, Time Cost: 3.4113848209381104


In [13]:
# Run the trained model on the test split.
test_result = learner.test(
    model=model,
    test_loader=mnist_loaders['test_loader'],
    metric=metric,
    device=DEVICE,
)

# Print the metric's class name followed by its score on the collected
# labels, predictions and scores.
print(
    metric.__class__.__name__,
    metric.evaluate(
        y_true=test_result['y_true'],
        y_pred=test_result['y_pred'],
        y_score=test_result['y_score'],
    ),
)

accuracy 0.9749
