<a href="https://colab.research.google.com/github/dietmarja/LLM-Elements/blob/main/QLoRA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# QLoRA (Quantized Low-Rank Adaptation) fine-tunes a model by training small low-rank adapter matrices on top of a quantized base model.
# For simplicity, this demo does not store quantized weights; it simulates quantization on each layer's output by scaling values to an integer grid, rounding, and scaling them back to floating point.



import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# QLoRA Layer Definition
class QLoRALayer(nn.Module):
    """Linear layer with a low-rank adapter and simulated (fake) quantization.

    The output is ``quantize(W(x) + B(A(x)))``: a full-rank projection ``W``
    plus a rank-``rank`` correction ``B @ A``, passed through symmetric
    per-tensor fake quantization on a signed ``quant_bits``-bit grid.

    Note: unlike real QLoRA, ``W`` is neither quantized nor frozen here; its
    parameters are registered as ordinary trainable parameters alongside
    ``A`` and ``B``.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        rank: Inner dimension of the low-rank adapter (``A`` and ``B``).
        quant_bits: Bit width of the simulated signed quantization grid.
    """

    def __init__(self, in_features, out_features, rank, quant_bits=8):
        super().__init__()
        self.rank = rank
        self.quant_bits = quant_bits
        self.W = nn.Linear(in_features, out_features, bias=False)
        self.A = nn.Linear(in_features, rank, bias=False)
        self.B = nn.Linear(rank, out_features, bias=False)

        # Small random init keeps the adapter's initial contribution small.
        nn.init.normal_(self.A.weight, std=0.01)
        nn.init.normal_(self.B.weight, std=0.01)

    def forward(self, x):
        """Return the fake-quantized sum of the full-rank and low-rank projections of ``x``."""
        return self.quantize(self.W(x) + self.B(self.A(x)))

    def quantize(self, x):
        """Fake-quantize ``x`` onto a symmetric signed ``quant_bits``-bit grid.

        Values are divided by a per-tensor step derived from ``max(|x|)``,
        rounded to the nearest integer level and scaled back. Rounding is
        applied through a straight-through estimator, so the backward pass
        treats quantization as the identity (``torch.round`` on its own has
        zero gradient and would block training).

        Args:
            x: Floating-point tensor to quantize.

        Returns:
            Tensor with the same shape and dtype as ``x`` whose values lie on
            the quantization grid. An all-zero input yields zeros, not NaN.
        """
        # Largest integer magnitude of a signed n-bit value (127 for 8 bits);
        # kept >= 1 so the step below is always finite.
        qmax = max(2 ** (self.quant_bits - 1) - 1, 1)
        # Absolute maximum covers negative values too; clamping away from 0
        # avoids 0/0 = NaN for all-zero tensors.
        max_abs = x.detach().abs().max().clamp_min(torch.finfo(x.dtype).eps)
        step = max_abs / qmax
        q = torch.round(x / step) * step
        # Straight-through estimator: forward returns q, backward sees identity.
        return x + (q - x).detach()

    def print_weights(self):
        """Print the full-rank matrix W and the adapter matrices A and B."""
        print(f"Full-rank weight matrix (W): \n{self.W.weight.data}")
        print(f"Low-rank weight matrix A: \n{self.A.weight.data}")
        print(f"Low-rank weight matrix B: \n{self.B.weight.data}")

# Simple Model Definition
class SimpleModel(nn.Module):
    """One-layer model backed either by a QLoRALayer or a plain nn.Linear.

    Args:
        input_dim: Number of input features.
        output_dim: Number of output features.
        qlora_rank: Adapter rank; any truthy value selects a QLoRALayer,
            while ``None`` (or another falsy value such as 0) selects a
            biased ``nn.Linear``.
    """

    def __init__(self, input_dim, output_dim, qlora_rank=None):
        super().__init__()
        self.layer = (
            QLoRALayer(input_dim, output_dim, qlora_rank)
            if qlora_rank
            else nn.Linear(input_dim, output_dim)
        )

    def forward(self, x):
        """Apply the wrapped layer to ``x``."""
        return self.layer(x)

    def print_weights(self):
        """Print the weights of the wrapped layer."""
        if not isinstance(self.layer, QLoRALayer):
            print(f"Full-rank weight matrix (W): \n{self.layer.weight.data}")
            return
        self.layer.print_weights()

# Training Function
def train_model(model, inputs, targets, epochs=100, lr=0.01, log_every=10):
    """Train ``model`` on one full batch using MSE loss and Adam.

    Args:
        model: Module whose output on ``inputs`` is compared with ``targets``
            by ``nn.MSELoss``.
        inputs: Input tensor fed to ``model`` at every epoch.
        targets: Regression targets.
        epochs: Number of optimisation steps (one full-batch step per epoch).
        lr: Adam learning rate.
        log_every: Print the loss every ``log_every`` epochs; a falsy value
            (``0`` or ``None``) disables progress output.

    Returns:
        List of per-epoch loss values as Python floats (length ``epochs``).
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    losses = []
    # Nothing inside the loop switches the model out of training mode, so
    # set it once up front.
    model.train()
    for epoch in range(1, epochs + 1):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        if log_every and epoch % log_every == 0:
            print(f'Epoch [{epoch}/{epochs}], Loss: {loss_value:.4f}')

    return losses

# Evaluation Function
def evaluate_model(model, test_inputs):
    """Run ``model`` on ``test_inputs`` in eval mode, print and return the predictions.

    The forward pass runs under ``torch.no_grad()``, and the model is left in
    eval mode afterwards.

    Args:
        model: Module to evaluate.
        test_inputs: Input tensor passed to ``model``.

    Returns:
        The prediction tensor, so callers can use it beyond the printout.
    """
    model.eval()
    with torch.no_grad():
        predictions = model(test_inputs)
    print(f'Predictions: {predictions}')
    return predictions

# Define model parameters
input_dim = 10
output_dim = 5

# Seed the RNG so the dummy data, initialisation and losses are reproducible.
torch.manual_seed(0)

# Dummy data for demonstration
inputs = torch.randn(8, input_dim)
targets = torch.randn(8, output_dim)

# QLoRA ranks to evaluate
qlora_ranks = [1, 2, 3, 4]

# Trained models and their per-epoch losses, keyed by a readable label.
# Keeping the trained models lets the prediction step below evaluate them,
# instead of freshly initialised (untrained) copies.
baseline_model = SimpleModel(input_dim, output_dim)
all_models = {"No QLoRA": baseline_model}
all_losses = {"No QLoRA": train_model(baseline_model, inputs, targets)}

# Train models with different QLoRA ranks
for rank in qlora_ranks:
    label = f"QLoRA rank {rank}"
    print(f"Training with QLoRA rank {rank}...")
    model_qlora = SimpleModel(input_dim, output_dim, rank)
    losses_qlora = train_model(model_qlora, inputs, targets)
    all_losses[label] = losses_qlora
    all_models[label] = model_qlora
    model_qlora.print_weights()

# Plot the losses for all models
plt.figure(figsize=(12, 8))
for label, losses in all_losses.items():
    plt.plot(losses, label=label)
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training Loss with and without QLoRA')
plt.legend()
plt.show()

# Example prediction: run every *trained* model on the same unseen inputs.
test_inputs = torch.randn(2, input_dim)
for label, model in all_models.items():
    print(f"Evaluating model with {label}...")
    evaluate_model(model, test_inputs)
