# BitLinear - PyTorch

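This notebook compares a small MLP built from `BitLinear` layers against the same MLP built from standard `nn.Linear` layers in PyTorch. Both are trained on MNIST with identical settings, and their training loss and test accuracy are reported.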

In [1]:
!pip -q install bitlinear-pytorch

### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm import tqdm

from bitlinear_pytorch import replace_linear_with_bitlinear


### Initialize Model

In [2]:
class TinyMLP(nn.Module):
    def __init__(self):
        super(TinyMLP, self).__init__()

        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )

    def forward(self, x):
        return self.layers(x)

In [3]:
# Seed PyTorch's global RNG so that weight initialisation, and the shuffling done
# later by the training DataLoader, are identical on every Restart & Run All.
torch.manual_seed(0)

# Baseline built from standard nn.Linear layers.
model = TinyMLP()

# Same architecture, with its nn.Linear layers swapped for BitLinear.
# NOTE(review): the call's return value is unused, so the replacement is
# presumably done in place on `bitmodel` — confirm against the library docs.
bitmodel = TinyMLP()
replace_linear_with_bitlinear(bitmodel)

### Load Data

In [4]:
# Number of samples per mini-batch for both the train and test loaders.
batch_size = 128

# Preprocessing applied to every image: convert it to a tensor, then flatten it
# into a 784-element vector to match the first Linear layer of the model.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Lambda(lambda image: image.view(784))]
)

In [5]:
# MNIST train / test splits stored under ./data (fetched if not already there),
# each image passed through `transform`.
train_dataset = MNIST("data", train=True, download=True, transform=transform)
test_dataset = MNIST("data", train=False, download=True, transform=transform)

# Only the training batches are shuffled; the evaluation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=batch_size)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=batch_size)

### Train Model

In [6]:
def train(model, train_loader, test_loader, epochs=5):
    """Train ``model`` with Adam and cross-entropy, evaluating after every epoch.

    Args:
        model: ``nn.Module`` that maps a batch of inputs to per-class logits.
        train_loader: iterable of ``(inputs, targets)`` batches used for optimisation.
        test_loader: iterable of ``(inputs, targets)`` batches used for evaluation.
        epochs: number of passes over ``train_loader``.

    Returns:
        None. Progress is reported through the tqdm bar description:
        ``Loss`` is the *sum* of the per-batch losses of the epoch (so it scales
        with the number of batches), and ``Acc`` is the fraction of evaluated
        test samples that were classified correctly.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    loop = tqdm(range(epochs))
    for epoch in loop:
        # train step
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        # test step
        model.eval()
        correct = 0
        # Count the samples that are actually evaluated instead of trusting
        # len(test_loader.dataset): the two differ when the loader drops the
        # last batch or uses a sampler, and plain iterables have no `.dataset`.
        total = 0
        with torch.no_grad():
            for inputs, targets in test_loader:
                predicted = model(inputs).argmax(dim=1)
                correct += (predicted == targets).sum().item()
                total += targets.numel()

        # An empty evaluation set has no meaningful accuracy; report NaN
        # rather than raising ZeroDivisionError.
        acc = correct / total if total else float("nan")
        loop.set_description(
            f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss:.2f}, Acc: {acc:.4f}"
        )

In [7]:
# Fit the nn.Linear baseline first, then the BitLinear variant, on the same
# loaders and with the same default number of epochs.
for net in (model, bitmodel):
    train(net, train_loader, test_loader)

Epoch [5/5], Loss: 22.59, Acc: 0.9783: 100%|████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:34<00:00,  6.99s/it]
Epoch [5/5], Loss: 302.41, Acc: 0.8119: 100%|███████████████████████████████████████████████████████████████████████████████████| 5/5 [00:43<00:00,  8.74s/it]


---