In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define the 3-layer MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.start_fc = nn.Linear(28 * 28, 256)
        self.fc1 = nn.Linear(256, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.lm_head = nn.Linear(256, 10)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.relu(self.start_fc(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.lm_head(x)
        return self.softmax(x)

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Training loop
def train(model, device, train_loader, optimizer, criterion, epochs=10):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

train(model, device, train_loader, optimizer, criterion)


  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/10, Loss: 0.08715134859085083
Epoch 2/10, Loss: 0.22687970101833344
Epoch 3/10, Loss: 0.06562088429927826
Epoch 4/10, Loss: 0.003095711348578334
Epoch 5/10, Loss: 0.0029548597522079945
Epoch 6/10, Loss: 0.018094034865498543
Epoch 7/10, Loss: 0.01994013786315918
Epoch 8/10, Loss: 0.001788316760212183
Epoch 9/10, Loss: 0.008530102670192719
Epoch 10/10, Loss: 0.07043803483247757


In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import torch.utils.data as data
import numpy as np
from copy import deepcopy

# Define the 3-layer MLP model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.start_fc = nn.Linear(28 * 28, 512)
        self.fc1 = nn.Linear(512, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 512)
        self.fc5 = nn.Linear(512, 512)
        self.lm_head = nn.Linear(512, 10)
        self.relu = nn.ReLU()
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = self.relu(self.start_fc(x))
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.relu(self.fc4(x))
        x = self.relu(self.fc5(x))
        x = self.lm_head(x)
        return self.softmax(x)

# Quantization function
def quantize_tensor(x, num_bits=8):
    qmin = 0.
    qmax = 2. ** num_bits - 1.

    min_val, max_val = x.min(), x.max()
    scale = (max_val - min_val) / (qmax - qmin) if max_val != min_val else 1.0
    zero_point = qmin - min_val / scale if scale != 0 else 0.

    zero_point = zero_point.round().clamp(qmin, qmax)

    q_x = (x / scale + zero_point).round().clamp(qmin, qmax)
    # Dequantize
    dq_x = (q_x - zero_point) * scale

    return dq_x

# Quantize specific layers in the model
def quantize_model(model, num_bits=8):
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            if module.in_features == module.out_features:
                # Quantize weights and biases
                module.weight.data = quantize_tensor(module.weight.data, num_bits)
                if module.bias is not None:
                    module.bias.data = quantize_tensor(module.bias.data, num_bits)

# Create a custom Dataset class
class QuantizationDataset(data.Dataset):
    def __init__(self, quant_model, raw_model, data_loader):
        self.inp_list = []
        self.tar_list = []

        activation = {}

        def get_activation(name):
            def hook(model, input, output):
                activation[name + '_input'] = input[0].detach()
                activation[name + '_output'] = output.detach()
            return hook

        # Register hooks for the layers to be quantized
        raw_model.fc1.register_forward_hook(get_activation('raw_fc1'))
        raw_model.fc2.register_forward_hook(get_activation('raw_fc2'))
        raw_model.fc3.register_forward_hook(get_activation('raw_fc3'))
        raw_model.fc4.register_forward_hook(get_activation('raw_fc4'))
        raw_model.fc5.register_forward_hook(get_activation('raw_fc5'))
        
        quant_model.fc1.register_forward_hook(get_activation('quant_fc1'))
        quant_model.fc2.register_forward_hook(get_activation('quant_fc2'))
        quant_model.fc3.register_forward_hook(get_activation('quant_fc3'))
        quant_model.fc4.register_forward_hook(get_activation('quant_fc4'))
        quant_model.fc5.register_forward_hook(get_activation('quant_fc5'))

        quant_model.eval()  # Set model to evaluation mode
        raw_model.eval()  # Set model to evaluation mode
        with torch.no_grad():
            for data, target in data_loader:
                data = data.to(next(raw_model.parameters()).device)
                output = raw_model(data)
                _ = quant_model(data)
                # Collect inputs and targets for each sample in the batch
                batch_size = data.size(0)
                for i in range(batch_size):
                    inp_sample = []
                    tar_sample = []
                    for layer_name in ['fc1', 'fc2', 'fc3', 'fc4', 'fc5']:
                        inp = activation[f'quant_{layer_name}_input'][i]
                        quant_out = activation[f'quant_{layer_name}_output'][i]
                        raw_out = activation[f'raw_{layer_name}_output'][i]
                        inp_sample.append(inp.unsqueeze(0))
                        tar_sample.append((raw_out - quant_out).unsqueeze(0))
                    # Stack inputs and targets for the layers
                    inp_sample = torch.cat(inp_sample, dim=0)
                    tar_sample = torch.cat(tar_sample, dim=0)
                    self.inp_list.append(inp_sample)
                    self.tar_list.append(tar_sample)

        # Stack all samples
        self.inp = torch.stack(self.inp_list)
        self.tar = torch.stack(self.tar_list)

    def __len__(self):
        return self.inp.size(0)

    def __getitem__(self, idx):
        return self.inp[idx], self.tar[idx]

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset_raw = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader_raw = torch.utils.data.DataLoader(dataset=train_dataset_raw, batch_size=64, shuffle=False)

# Initialize model and device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()
# Training loop
def train(model, device, train_loader, optimizer, criterion, epochs=10):
    model.train()
    for epoch in range(epochs):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")
train(model, device, train_loader_raw, optimizer, criterion)
raw_model = deepcopy(model)

# Now quantize the model
quantize_model(model, num_bits=8)
quant_model = model

# Create the quantization dataset
quant_dataset = QuantizationDataset(quant_model,raw_model, train_loader_raw)
quant_loader = torch.utils.data.DataLoader(dataset=quant_dataset, batch_size=64, shuffle=True)


# You can now use quant_loader in your training loop or further processing
# For demonstration, let's print the shapes of inp and tar for a sample
sample_inp, sample_tar = quant_dataset[0]
print("Sample inp shape:", sample_inp.shape)  # Should be (3, 256)
print("Sample tar shape:", sample_tar.shape)  # Should be (3, 256)


  from .autonotebook import tqdm as notebook_tqdm


Epoch 1/10, Loss: 0.008553309366106987
Epoch 2/10, Loss: 0.01526783686131239
Epoch 3/10, Loss: 0.014119611121714115
Epoch 4/10, Loss: 0.00949557963758707
Epoch 5/10, Loss: 0.0013644912978634238
Epoch 6/10, Loss: 0.0026759756729006767
Epoch 7/10, Loss: 0.0006718770018778741
Epoch 8/10, Loss: 0.000330804061377421
Epoch 9/10, Loss: 7.87862609286094e-06
Epoch 10/10, Loss: 4.950577931595035e-06
Sample inp shape: torch.Size([5, 512])
Sample tar shape: torch.Size([5, 512])


In [8]:
import torch
from model.mamba2 import Mamba2, Mamba2Config
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
print(len(quant_dataset))
config = Mamba2Config(d_model=512, n_layers=1,d_head=4)
model = Mamba2(config)
model.to(device)
train_set,validate_set = torch.utils.data.random_split(quant_dataset,[int(len(quant_dataset)*0.8),int(len(quant_dataset)*0.2)])
train_loader = DataLoader(train_set,batch_size=32,shuffle=True)
validate_loader = DataLoader(validate_set,batch_size=32,shuffle=True)
learning_rate=1e-5
optim = torch.optim.AdamW(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                                                        optim,
                                                        mode='min',
                                                        factor=0.1, #factor by which the lr is multiplied
                                                        patience=2,
                                                    )
epochs = 10
for epoch in range(epochs):
    train_loss = 0
    model.train()
    for i,data in enumerate(train_loader):
        inp,tar = data
        inp = inp.to(device)
        tar = tar.to(device)
        optim.zero_grad()
        output = model(inp)
        loss = torch.norm(output - tar, p=2)
        # loss = F.cross_entropy(output, tar)
        loss.backward()
        optim.step()
        train_loss += loss.item()
    train_loss /= len(train_set)
    print(f"Epoch {epoch+1} Train Loss: {train_loss}")
    validate_loss = 0
    model.eval()
    with torch.no_grad():
        for i,data in enumerate(validate_loader):
            inp,tar = data
            inp = inp.to(device)
            tar = tar.to(device)
            output = model(inp)
            loss = torch.norm(output - tar, p=2)
            # loss = F.cross_entropy(output, tar)
            validate_loss += loss.item()
    validate_loss /= len(validate_set)
    scheduler.step(validate_loss)
    print(f"Epoch {epoch+1} Validate Loss: {validate_loss}")


60000
Epoch 1 Train Loss: 6.5161741282145185
Epoch 1 Validate Loss: 10.71863870493571
Epoch 2 Train Loss: 4.642788636525472
Epoch 2 Validate Loss: 7.987921098073324
Epoch 3 Train Loss: 3.670166264216105
Epoch 3 Validate Loss: 6.626560643513997
Epoch 4 Train Loss: 3.1430958881378173
Epoch 4 Validate Loss: 5.770802393595377


In [6]:
inp,tar=validate_set[0]
inp.unsqueeze(0)
out =model(inp.unsqueeze(0).to(device))
out,tar

(tensor([[[ 0.0315,  0.0203,  0.0012,  ...,  0.0342, -0.0660,  0.0312],
          [-0.0404, -0.0082,  0.0032,  ..., -0.0038,  0.0071, -0.0135],
          [ 0.0554, -0.0059, -0.0388,  ...,  0.0272, -0.0205, -0.0064],
          [-0.1671,  0.0236,  0.0662,  ...,  0.0980, -0.0178,  0.0068],
          [-0.0216,  0.0113, -0.0448,  ..., -0.5089,  0.0084, -0.0216]]],
        device='cuda:0', grad_fn=<AddBackward0>),
 tensor([[-0.0101,  0.0003, -0.0046,  ...,  0.0043,  0.0126,  0.0145],
         [ 0.0026, -0.0044,  0.0030,  ..., -0.0281, -0.0113, -0.0039],
         [ 0.0003, -0.0164,  0.0151,  ...,  0.0216,  0.0093, -0.0016],
         [-0.0074, -0.0056,  0.0211,  ...,  0.0101, -0.0070, -0.0012],
         [-0.0271, -0.0073,  0.0089,  ...,  0.0050, -0.0241, -0.0065]],
        device='cuda:0'))