### Sparse Linear Networks

This example shows how to build a sparse deep neural network (DNN) with SparTA for simple image classification on the MNIST dataset.

In [1]:
import torch
import sparta
import torchvision
import numpy as np

# All modules and tensors in this notebook are placed on the first GPU.
device = 'cuda:0'

# Seed both RNGs for reproducibility: numpy drives the shuffling in
# `preprocess`, torch drives weight initialization and the random sample inputs.
random_seed = 2022
np.random.seed(random_seed)
torch.manual_seed(random_seed)

#### Preparation
1. Download the MNIST dataset through `torchvision`.

In [2]:
# Both splits share one transform that converts each image to a tensor.
# root="" resolves relative to the current working directory.
to_tensor = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])
train_set = torchvision.datasets.MNIST(
    root="", train=True, download=True, transform=to_tensor,
)
test_set = torchvision.datasets.MNIST(
    root="", train=False, download=True, transform=to_tensor,
)

2. Preprocess: shuffle the data and regroup it into flattened batches of 4096 samples. Leftover samples that do not fill a whole batch are dropped.

In [3]:
# Every batch has exactly this many samples. The sparse network is later tuned
# and built with sample inputs of this batch size (see the tuning cell).
batch_size = 4096

def preprocess(dataset, batch_size=batch_size):
    """Shuffle ``dataset`` and regroup it into fixed-size batches of flattened samples.

    Args:
        dataset: indexable sequence whose items are ``(sample_tensor, label)``
            pairs, e.g. a torchvision MNIST dataset with a ``ToTensor`` transform.
        batch_size: number of samples per batch. Defaults to the notebook-level
            ``batch_size`` as it was when this cell ran.

    Returns:
        List of ``(X, y)`` tuples. ``X`` has shape ``(batch_size, features)``
        with one flattened sample per row; ``y`` is a 1-D tensor holding the
        ``batch_size`` labels. Samples left over after the last full batch are
        dropped, so every batch has exactly ``batch_size`` rows.
    """
    # Shuffle a plain list of indices with numpy's global RNG so the order is
    # reproducible under `np.random.seed`.
    indexes = list(range(len(dataset)))
    np.random.shuffle(indexes)
    batches = []
    for i in range(len(dataset) // batch_size):
        X_list, y_list = [], []
        for idx in indexes[i * batch_size:(i + 1) * batch_size]:
            X, y = dataset[idx]
            # Flatten each sample into a single row (28 * 28 = 784 for MNIST).
            X_list.append(X.reshape(1, -1))
            y_list.append(y)
        batches.append((torch.vstack(X_list).contiguous(), torch.tensor(y_list)))
    return batches

# Replace the raw datasets with their shuffled, batched form (lists of (X, y)).
# NOTE: this rebinds `train_set`/`test_set`, so the cell is not idempotent.
# Re-running it alone would re-batch the already-batched lists, which hold
# fewer than `batch_size` items, and silently leave both sets empty.
# Re-run the download cell first.
train_set = preprocess(train_set)
test_set = preprocess(test_set)

3. Define training and testing functions.

In [4]:
# Adam step size used by train().
learning_rate = 1e-3
# Both models end in log_softmax, so negative log-likelihood on their outputs
# amounts to the usual cross-entropy loss.
loss_func = torch.nn.functional.nll_loss

def train(model, epochs=20):
    """Train ``model`` in place on the global ``train_set`` with Adam and report the time taken.

    Args:
        model: module mapping a batch of flattened inputs to log-probabilities;
            its output is fed to ``loss_func`` (``nll_loss``).
        epochs: number of full passes over ``train_set``.

    Returns:
        Training time in seconds. It is measured with CUDA events around the
        whole loop, so it includes the host-to-device copy of every batch.
    """
    # Make sure training-mode behavior is active even if the model was last
    # used in eval mode.
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(epochs):
        for X, y in train_set:
            optimizer.zero_grad()
            output = model(X.to(device))
            loss = loss_func(output, y.to(device))
            loss.backward()
            optimizer.step()
    end.record()
    # Events are recorded asynchronously; wait for the GPU before reading them.
    torch.cuda.synchronize()
    time_cost = start.elapsed_time(end) / 1000  # elapsed_time() is in milliseconds
    print(f'Training time cost: {round(time_cost, 3)} s')
    return time_cost

def test(model):
    """Print ``model``'s classification accuracy (in percent) on the global ``test_set``.

    Only full batches of ``batch_size`` samples are evaluated. The model runs
    in eval mode during the pass and is restored to its previous mode afterwards.

    Raises:
        ValueError: if ``test_set`` contains no full batch to evaluate.
    """
    correct = 0
    total = 0
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for X, y in test_set:
                # Partial batches are skipped (presumably because the sparse
                # kernels are built for inputs of exactly `batch_size` rows).
                if X.shape[0] < batch_size:
                    continue
                # Score the whole batch at once: one device->host transfer per
                # batch instead of one synchronizing comparison per sample.
                predicted = model(X.to(device)).argmax(dim=1).cpu()
                correct += int((predicted == y).sum())
                total += predicted.numel()
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError('test_set contains no full batch to evaluate')
    accuracy = correct / total * 100
    print(f"Accuracy: {round(accuracy, 3)}%")

#### Dense Network
1. Create a 4-layer dense neural network with `torch.nn.Linear`.

In [5]:
class DenseNet(torch.nn.Module):
    """Fully connected MNIST classifier: 784 -> 2048 -> 4096 -> 2048 -> 10.

    A ReLU follows each hidden layer. ``forward`` returns log-probabilities
    over the 10 classes (``log_softmax`` along dim 1), as ``nll_loss`` expects.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in this order and directly on `device`. The order
        # fixes both their parameter names and their seeded initialization.
        self.linear_0 = torch.nn.Linear(in_features=28 * 28, out_features=2048, device=device)
        self.linear_1 = torch.nn.Linear(in_features=2048, out_features=4096, device=device)
        self.linear_2 = torch.nn.Linear(in_features=4096, out_features=2048, device=device)
        self.linear_3 = torch.nn.Linear(in_features=2048, out_features=10, device=device)

    def forward(self, x):
        for hidden in (self.linear_0, self.linear_1, self.linear_2):
            x = torch.relu(hidden(x))
        logits = self.linear_3(x)
        return torch.log_softmax(logits, dim=1)

# Baseline model; its layers are created directly on `device` inside __init__.
dense_net = DenseNet()

2. Train the dense network for 20 epochs and test it. We get ~98.1% accuracy after ~20 seconds of training.

In [6]:
# Train for 20 epochs (prints the training time), then print the test accuracy.
print('===== Dense Network =====')
train(dense_net, epochs=20)
test(dense_net)

===== Dense Network =====
Training time cost: 20.31 s
Accuracy: 98.096%


#### Sparse Network
1. Create a 4-layer neural network with the same shape as our `DenseNet`. The middle two fully connected layers are 90% sparse, using a block mask with 32×32 blocks.

In [7]:
class SparseNet(torch.nn.Module):
    """MNIST classifier with the same 784 -> 2048 -> 4096 -> 2048 -> 10 layout as DenseNet.

    ``linear_1`` and ``linear_2`` are ``sparta.nn.SparseLinear`` layers. Each
    wraps an ordinary ``torch.nn.Linear`` together with a block weight mask
    (32x32 blocks, ``sparsity=0.9``). ``forward`` returns log-probabilities,
    exactly like DenseNet.
    """

    def __init__(self):
        super().__init__()
        # One mask serves both sparse layers. linear_1's weight is
        # (out=4096, in=2048) and takes the mask as-is; linear_2's weight is
        # (out=2048, in=4096) and takes its transpose.
        # NOTE(review): the mask is built before any layer. Keep that order in
        # case block_mask draws from the seeded RNG (unverified).
        shared_mask = sparta.testing.block_mask(
            (4096, 2048), block=(32, 32), sparsity=0.9, device=device,
        )
        self.linear_0 = torch.nn.Linear(28 * 28, 2048, device=device)
        expand = torch.nn.Linear(2048, 4096, device=device)
        self.linear_1 = sparta.nn.SparseLinear(expand, weight_mask=shared_mask)
        shrink = torch.nn.Linear(4096, 2048, device=device)
        self.linear_2 = sparta.nn.SparseLinear(shrink, weight_mask=shared_mask.T)
        self.linear_3 = torch.nn.Linear(2048, 10, device=device)

    def forward(self, x):
        for hidden in (self.linear_0, self.linear_1, self.linear_2):
            x = torch.relu(hidden(x))
        logits = self.linear_3(x)
        return torch.log_softmax(logits, dim=1)

# The sparse operators are tuned and built in the next cell
# (sparta.nn.tune / sparta.nn.build).
sparse_net = SparseNet()

2. Tune the sparse network with sample inputs and gradients. Note that we need to set `backward_weight=1` so that the backward kernels are tuned as well. This step may take about 10 minutes.

In [8]:
# Random stand-ins for a real batch: `batch_size` flattened 28x28 inputs and,
# presumably, a gradient w.r.t. the model output (hence 10 columns, one per
# class). Their shapes match the batches used in train()/test().
sample_input = torch.rand((batch_size, 28 * 28), device=device)
sample_grad = torch.rand((batch_size, 10), device=device)

# The tune() function will find the best config,
# build the sparse operator and return the best config.
# backward_weight=1 makes tuning cover the backward kernels too. algo='rand'
# with max_trials=30 presumably samples up to 30 random configurations
# (TODO: confirm against the SparTA docs).
best_config = sparta.nn.tune(
    sparse_net,
    sample_inputs=[sample_input],
    sample_grads=[sample_grad],
    backward_weight=1,
    algo='rand',
    max_trials=30,
)

# If you have already tuned once and saved the best config,
# you can skip the tune() step and build the operator directly.
# NOTE(review): right after tune() this build() repeats work tune() already
# did (per the comment above); it is kept to demonstrate the build-only path.
# `best_config` is not saved anywhere here; persist it yourself to reuse it.
sparta.nn.build(
    sparse_net,
    sample_inputs=[sample_input],
    configs=best_config,
)

3. Train the sparse network for 20 epochs and test it. This time we get ~97.6% accuracy after ~8 seconds of training.

In [9]:
# Same 20-epoch training and evaluation as the dense baseline, for a direct
# comparison of training time and accuracy.
print('===== Sparse Network =====')
train(sparse_net, epochs=20)
test(sparse_net)

===== Sparse Network =====
Training time cost: 7.598 s
Accuracy: 97.607%
