## BitLinearOptimized Test

In [1]:
import torch
import torch.nn as nn

In [4]:
from modules import SimpleLinear

# Dimensions of the synthetic regression problem used below
input_shape, out_shape = 64, 128  # input feature width, output width
sample_size = 500                 # number of generated samples

In [5]:
# Seed the global RNG so the synthetic data (and anything randomly
# initialised after this cell) is identical on every Restart & Run All.
torch.manual_seed(0)

inputs = torch.randn(sample_size, input_shape)
true_weights = torch.randn(input_shape, out_shape)

# Generating synthetic targets: y = inputs x true_weights + noise
noise = 0.05 * torch.randn(sample_size, out_shape)
targets = inputs @ true_weights + noise

# Cast only after the targets have been computed, so the matmul above is
# carried out before the precision is reduced to bfloat16.
inputs = inputs.to(torch.bfloat16)
targets = targets.to(torch.bfloat16)

In [6]:
# Sanity check: show the dtypes of the tensors the model will be fed.
print(*(tensor.dtype for tensor in (inputs, targets)))

# Training bookkeeping
num_epochs = 500
losses = []  # one scalar loss value is appended per epoch

# Model, loss function and optimiser
model = SimpleLinear(input_shape, out_shape, num_groups=5)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)


torch.bfloat16 torch.bfloat16


In [7]:
for epoch in range(num_epochs):
    # Clear the gradients accumulated during the previous iteration.
    optimizer.zero_grad()

    # Forward pass; both operands are upcast to float32 before the loss
    # is evaluated.
    outputs = model(inputs)
    loss = criterion(outputs.float(), targets.float())

    # Backward pass followed by the parameter update.
    loss.backward()
    optimizer.step()

    # Keep the scalar loss so the history can be inspected afterwards.
    losses.append(loss.item())

    # Report progress once every 100 epochs.
    if not epoch % 100:
        print(f"Epoch [{epoch}/{num_epochs}], Loss: {loss.item():.4f}")

Epoch [0/500], Loss: 1036.4595
Epoch [100/500], Loss: 953.9650
Epoch [200/500], Loss: 892.1829
Epoch [300/500], Loss: 840.7665
Epoch [400/500], Loss: 799.8828


In [8]:
# Read the dequantized weights off the model's linear layer (this cell is
# meant to run after the training loop) and display them.
trained_layer = model.linear
dequantized_weights_after_training = trained_layer.dequantize_weights()
print("Dequantized Weights After Training:", dequantized_weights_after_training)

Dequantized Weights After Training: tensor([[-1., -1., -1.,  ...,  1., -1., -1.],
        [ 1., -1., -1.,  ..., -1.,  1.,  1.],
        [-1.,  1.,  1.,  ...,  1., -1., -1.],
        ...,
        [-1.,  1.,  1.,  ...,  1., -1.,  1.],
        [ 1., -1., -1.,  ...,  1., -1.,  1.],
        [ 1.,  1., -1.,  ..., -1.,  1., -1.]])


In [9]:

# Show the weights returned by get_original_weights() next to the
# dequantized (binarized) weights of the same layer.
original_weights = model.linear.get_original_weights()
print("Original Weights:", original_weights)
binarized_weights = model.linear.dequantize_weights()
print("binarized weight: ", binarized_weights)

Original Weights: tensor([[-0.0483, -0.0992, -0.0421,  ...,  0.0238, -0.1114, -0.0296],
        [ 0.0373, -0.0164, -0.0545,  ..., -0.0303,  0.0957,  0.0761],
        [-0.0624,  0.1110,  0.0243,  ...,  0.0245, -0.0934, -0.0966],
        ...,
        [-0.0462,  0.0201,  0.0166,  ...,  0.1026, -0.0393,  0.0914],
        [ 0.1205, -0.0904, -0.0862,  ...,  0.0107, -0.1045,  0.0873],
        [ 0.1176,  0.0041, -0.1205,  ..., -0.1180,  0.1044, -0.0639]])
binarized weight:  tensor([[-1., -1., -1.,  ...,  1., -1., -1.],
        [ 1., -1., -1.,  ..., -1.,  1.,  1.],
        [-1.,  1.,  1.,  ...,  1., -1., -1.],
        ...,
        [-1.,  1.,  1.,  ...,  1., -1.,  1.],
        [ 1., -1., -1.,  ...,  1., -1.,  1.],
        [ 1.,  1., -1.,  ..., -1.,  1., -1.]])


## BitLinear Test

This linear implementation differs from the optimized implementation in the data types it uses during inference.

In [10]:
from module_linear import BitLinear

[570.3411865234375, 572.9613037109375, 577.2484130859375, 566.1704711914062, 570.7861328125]
Parameter containing:
tensor([[ 1.9651e-03,  3.9742e-03,  6.0852e-03,  4.6182e-03,  4.5792e-03,
          5.1233e-03,  3.8873e-03,  5.7519e-03,  7.9827e-03,  1.0004e-02,
          5.2477e-03,  3.9558e-01,  1.6974e-03,  3.8776e-03,  6.2353e-03,
          6.3895e-03,  1.0197e-03,  6.9853e-03,  2.6393e-03,  2.0338e-03],
        [ 1.8776e-03,  1.6569e-03,  1.3384e+00,  1.9661e-01,  7.3371e-04,
          7.1123e-03,  4.8208e-04,  1.0617e-03,  2.5296e-03,  1.3610e-03,
          2.9625e-03,  4.6927e-03,  2.0718e-03,  5.7901e-03,  6.0808e-03,
          4.6200e-03,  7.8272e-04,  8.4426e-03,  1.2481e-03,  2.7303e-03],
        [ 5.6645e-03,  1.9017e-03,  1.4181e-03,  5.7114e-03,  5.3053e-03,
          1.8370e-02,  6.2806e-03,  1.9101e-03,  6.7641e-03,  2.1519e-03,
          4.0739e-03,  5.9303e-03,  1.5918e-03,  5.1723e-04,  2.7801e-03,
          7.5162e-03,  4.1349e-03,  5.1004e-03,  2.5890e-01,  8.4162e

In [11]:
# Parameters
input_size = 20
output_size = 30
num_samples = 1000

# Fix the RNG seed so the synthetic dataset below is identical on every
# run; without it the reported losses cannot be reproduced.
torch.manual_seed(0)

# Random input data
inputs = torch.randn(num_samples, input_size)

# True weights for generating synthetic targets (randomly initialized)
true_weights = torch.randn(input_size, output_size)

# Generating synthetic targets: y = inputs x true_weights + noise
noise = 0.05 * torch.randn(num_samples, output_size)
targets = inputs @ true_weights + noise

# Last expression of the cell: display both shapes as a quick sanity check.
inputs.shape, targets.shape

(torch.Size([1000, 20]), torch.Size([1000, 30]))

In [12]:

# Minimal network consisting of a single BitLinear layer
class SimpleNet(nn.Module):
    """Thin ``nn.Module`` wrapper around one ``BitLinear`` layer."""

    def __init__(self, input_size, output_size, num_groups):
        """Create the wrapped layer.

        All three arguments are forwarded unchanged to ``BitLinear``
        (``num_groups`` as a keyword argument).
        """
        super().__init__()
        self.bitlinear = BitLinear(
            input_size,
            output_size,
            num_groups=num_groups,
        )

    def forward(self, x):
        """Apply the wrapped ``bitlinear`` layer to ``x`` and return its result."""
        layer = self.bitlinear
        return layer(x)

# Build the model together with its loss function and optimizer
model = SimpleNet(input_size, output_size, num_groups=5)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# Training bookkeeping: epoch budget and per-epoch loss history
num_epochs = 200
losses = list()

In [13]:
for epoch in range(num_epochs):
    # Forward: run the model and score its predictions against the targets
    outputs = model(inputs)
    loss = criterion(outputs, targets)

    # Backward: drop stale gradients, backpropagate, then update the weights
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Record the scalar loss for this epoch
    losses.append(loss.item())

# Summary: the last five losses, the layer's weight parameter, and the
# result of its group-wise weight binarization
print(losses[-5:])
print(model.bitlinear.weight)
print(model.bitlinear.binarize_weights_groupwise())

[560.5733032226562, 571.445556640625, 578.6713256835938, 587.3937377929688, 578.71337890625]
Parameter containing:
tensor([[ 1.1603e-02,  1.1009e-02,  1.1926e-02,  1.2520e-02,  1.2838e-02,
          1.4962e-02,  1.2898e-02,  1.2050e-02,  1.5594e-02,  8.3676e-03,
          1.5321e-02,  1.2816e-02,  1.2309e-02,  1.0993e-02,  1.4098e-02,
          1.0930e-02,  1.2763e-02,  7.7721e-03,  9.2826e-03,  1.3228e-02],
        [ 1.1787e-02,  1.4914e-02,  1.3922e-02,  1.3687e-02,  1.1496e-02,
          1.0619e-02,  1.5923e-02,  1.2104e-02,  1.0994e-02,  1.4386e-02,
          1.2340e-02,  1.1416e-02,  1.3351e-02,  1.2628e-02,  9.1292e-03,
          1.2154e-02,  9.3743e-03,  1.2610e-02,  1.0887e-02,  9.4200e-03],
        [ 1.0422e-02,  1.3893e-02,  1.2764e-02,  1.2405e-02,  1.0592e-02,
          8.6844e-03,  1.2385e-02,  1.2881e-02,  7.0197e-03,  1.3803e-02,
          1.4324e-02,  9.5016e-03,  1.1534e-02,  1.0978e-02,  1.0720e-02,
          1.2945e-02,  1.3268e-02,  1.1477e-02,  1.4455e-02,  9.3445e