In [333]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np

class IntegerLinear(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Initialize weights as floating-point values
        self.weight = nn.Parameter(torch.randn(out_features, in_features, dtype=torch.float32))  # Use normal distribution for initialization
    
    def forward(self, x):
        return torch.mm(x, self.weight.t())
    
    def quantize_weights(self, target_min=-128, target_max=127):
        """Quantize the weights to a target range with dynamic scaling."""
        with torch.no_grad():
            # Find the min and max of the weights
            weight_min = self.weight.min()
            weight_max = self.weight.max()
            
            # Dynamically scale the target range based on observed weight distribution
            scale_factor = 0.3 # Adjust this factor to control how aggressively the weights are scaled
            range_span = (weight_max - weight_min) * scale_factor
            target_min_scaled = weight_min - range_span / 2
            target_max_scaled = weight_max + range_span / 2
            
            # Scale the weights to the scaled target range
            scale = (target_max_scaled - target_min_scaled) / (weight_max - weight_min)
            zero_point = target_min_scaled - weight_min * scale
            
            # Quantize the weights
            quantized_weights = torch.round(self.weight * scale + zero_point)
            
            # Clip the values to ensure they are within the target range
            quantized_weights = torch.clamp(quantized_weights, target_min, target_max)
            
            # Store the quantized weights back into the model
            self.weight.data = quantized_weights
    
class IntegerNet(nn.Module):
    def __init__(self, input_size=784):
        super().__init__()
        # Four layers of IntegerLinear
        self.layer1 = IntegerLinear(input_size, 128)
        self.layer2 = IntegerLinear(128, 128)
        self.layer3 = IntegerLinear(128, 64)
        self.layer4 = IntegerLinear(64, 10)
    
    def forward(self, x):
        x = self.layer1(x)
        x = torch.relu(x)
        x = self.layer2(x)
        x = torch.relu(x)
        x = self.layer3(x)
        x = torch.relu(x)
        x = self.layer4(x)
        return x

    def quantize_weights(self):
        """Quantize weights for all layers."""
        self.layer1.quantize_weights()
        self.layer2.quantize_weights()
        self.layer3.quantize_weights()
        self.layer4.quantize_weights()

In [334]:
def load_data(filepath='train.csv'):
    data = pd.read_csv(filepath)
    labels = data['label'].values
    pixels = data.drop('label', axis=1).values
    
    # Convert to binary (0 or 1)
    pixels = (pixels > 127).astype(np.float32)
    
    return torch.FloatTensor(pixels), torch.LongTensor(labels)

In [335]:
def train_model(model, X_train, y_train, epochs=10, batch_size=128):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.0001, momentum=0.9)
    
    n_samples = X_train.shape[0]
    n_batches = n_samples // batch_size
    
    for epoch in range(epochs):
        total_loss = 0
        correct = 0
        
        # Shuffle data
        indices = torch.randperm(n_samples)
        X_train = X_train[indices]
        y_train = y_train[indices]
        
        for i in range(n_batches):
            start_idx = i * batch_size
            end_idx = start_idx + batch_size
            
            batch_X = X_train[start_idx:end_idx]
            batch_y = y_train[start_idx:end_idx]
            
            # Forward pass
            outputs = model(batch_X)
            loss = criterion(outputs, batch_y)
            
            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == batch_y).sum().item()
        
        # Print epoch statistics
        avg_loss = total_loss / n_batches
        accuracy = correct / n_samples
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}')
        
        # Quantize weights after each epoch to reduce their size
        model.quantize_weights()

In [336]:
X_train, y_train = load_data()
model = IntegerNet()

In [337]:
train_model(model, X_train, y_train)

Epoch [1/10], Loss: 236.2272, Accuracy: 0.3083
Epoch [2/10], Loss: 11.6330, Accuracy: 0.1227
Epoch [3/10], Loss: 14.5360, Accuracy: 0.1112
Epoch [4/10], Loss: 16.4991, Accuracy: 0.1039
Epoch [5/10], Loss: 15.7448, Accuracy: 0.1023
Epoch [6/10], Loss: 14.8864, Accuracy: 0.1000
Epoch [7/10], Loss: 13.2210, Accuracy: 0.0992
Epoch [8/10], Loss: 25.8806, Accuracy: 0.0989
Epoch [9/10], Loss: 19.4008, Accuracy: 0.0986
Epoch [10/10], Loss: 7.0968, Accuracy: 0.0987


In [341]:
# At the end of training, print out all the weight matrices
print("\n--- Weight Matrices ---")
for name, param in model.named_parameters():
    if "weight" in name:
        print(f"{name}:\n{param.data[0]}\n")

model.named_parameters()


--- Weight Matrices ---
layer1.weight:
tensor([  1.,  -1.,   1.,   1.,   0., -28.,  -1.,   1.,  -1.,  27.,   1.,  27.,
          1.,  -2.,   1.,  -1.,   0., -36.,  27.,  27.,   1.,   0.,  -1.,   0.,
         -2.,  27.,  35.,  -2., -28.,   1.,   1.,   0.,   1.,   0.,  -2.,  -2.,
          0.,   0.,   0.,   0.,   0.,  27.,   0.,  27.,  -1.,   1.,  -1.,  -1.,
          0.,   0.,   0.,   0.,  27.,  -2., -28.,  -2.,   1.,   1.,  27.,   1.,
         -2.,   1.,   1.,   1.,   1.,  -2.,   0.,  27.,   0.,   0.,   1.,  -2.,
          1.,   1.,   1.,  -1.,  -1.,  -1.,  -1.,   0.,  -1.,  -1.,   0.,  27.,
          0.,   1.,   0.,   0.,   1.,   0.,   1.,   0.,  -1.,   1.,   0.,  -2.,
          0.,  -1.,  35., -36.,  -2.,   0.,   0.,   1.,   1.,  -1.,  -1.,  27.,
          0.,   0., -28.,  -1.,  -1.,  27.,   1., -28.,  -1.,   0.,  -2.,   1.,
         -1.,   1.,  -1.,  -1., -28.,   1.,   1.,   0.,  35.,  -2.,   0.,   0.,
         -2.,  -2.,  -1.,   1.,   0.,   0.,  27.,  27.,  35.,   1.,  27.,   1.,


<generator object Module.named_parameters at 0x29f5b5840>