In [1]:
"""
===================================================================
LoRA (Low-Rank Adaptation) Implementation from Scratch in PyTorch
===================================================================

Mathematical Foundation:
------------------------
In standard finetuning: W_updated = W + ΔW
In LoRA: W_updated = W + A·B (where ΔW ≈ A·B)

- W: Original pretrained weight matrix (d × k)
- A: Low-rank matrix (d × r)  
- B: Low-rank matrix (r × k)
- r: Rank (r << min(d, k))

Instead of updating all parameters in W, we only train A and B.
If W is 1000×1000, we need 1M parameters. With LoRA (r=16):
- A: 1000×16 = 16,000 params
- B: 16×1000 = 16,000 params
- Total: 32,000 params (31x fewer!)
"""


import contextlib
import io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [15]:
# ===================================================================
# STEP 1: LoRA Layer Implementation
# ===================================================================

class LoRALayer(nn.Module):
    """
    Low-rank adapter branch.

    For an input ``x`` the layer returns ``alpha * (x @ A @ B)`` where ``A`` has
    shape ``(in_dim, rank)`` and ``B`` has shape ``(rank, out_dim)``.

    Parameters:
        in_dim (int): Size of the last dimension of the input
        out_dim (int): Size of the last dimension of the output
        rank (int): Width of the bottleneck shared by A and B
        alpha (float): Multiplier applied to the low-rank product
    """
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()

        # Scale for A's random init: 1 / sqrt(rank), computed as a float32 tensor.
        init_scale = 1 / torch.sqrt(torch.tensor(rank).float())

        # A: (in_dim, rank), Gaussian noise times init_scale.
        # B: (rank, out_dim), all zeros -> A @ B is zero right after construction,
        #    so the adapter initially contributes nothing to the output.
        self.A = nn.Parameter(torch.randn(in_dim, rank) * init_scale)
        self.B = nn.Parameter(torch.zeros(rank, out_dim))

        # Stored as a plain attribute (not an nn.Parameter), so it is not trained.
        self.alpha = alpha

        n_adapter_params = self.A.numel() + self.B.numel()
        for line in (
            "LoRALayer initialized:",
            f"  Matrix A shape: {self.A.shape} (in_dim={in_dim}, rank={rank})",
            f"  Matrix B shape: {self.B.shape} (rank={rank}, out_dim={out_dim})",
            f"  Total LoRA params: {n_adapter_params}",
        ):
            print(line)

    def forward(self, x):
        """
        Apply the adapter to ``x`` and print the shape after every step.

        Input:
            x: tensor whose last dimension is in_dim, e.g. (batch_size, in_dim)
               or (batch_size, seq_len, in_dim)

        Output:
            tensor with the same leading dimensions as ``x`` and last dimension out_dim

        Flow:
            1. (..., in_dim) @ A (in_dim, rank)  -> (..., rank)
            2. (..., rank)   @ B (rank, out_dim) -> (..., out_dim)
            3. multiply by alpha
        """
        # Down-projection into the rank-sized bottleneck.
        down = x @ self.A

        # Up-projection back to the output width (B has shape (rank, out_dim)).
        up = down @ self.B

        # Final scaling.
        scaled = self.alpha * up

        for line in (
            "\nLoRALayer forward pass:",
            f"  Input shape: {x.shape}",
            f"  After x @ A: {down.shape}",
            f"  After @ B: {up.shape}",
            f"  Final output shape: {scaled.shape}",
        ):
            print(line)

        return scaled

In [23]:
# ===================================================================
# STEP 2: Linear Layer with LoRA (Separate Computation)
# ===================================================================

class LinearWithLoRA(nn.Module):
    """
    Wraps an existing linear layer and adds a LoRA branch next to it.

    Computes: output = Linear(x) + LoRA(x) = x @ W^T + b + alpha * (x @ A @ B)

    The original weights and the LoRA matrices are kept as separate tensors and
    evaluated separately (distributive property: x(W + AB) = xW + xAB).

    Parameters:
        linear (nn.Linear): The original linear layer. It is stored as-is; this
            class does not change its ``requires_grad`` flags, so freezing has
            to be done separately.
        rank (int): Rank for LoRA decomposition
        alpha (float): Scaling factor for LoRA
    """
    def __init__(self, linear: nn.Linear, rank, alpha):
        super().__init__()

        # Keep a reference to the wrapped layer.
        # Shape of its weight: (out_features, in_features)
        self.linear = linear

        # Adapter with the same input/output widths as the wrapped layer.
        self.lora = LoRALayer(
            in_dim=linear.in_features,
            out_dim=linear.out_features,
            rank=rank,
            alpha=alpha
        )
        # The adapter is created on the default device; move it next to the
        # wrapped layer so wrapping a layer that already lives on an accelerator
        # does not cause a device mismatch in forward().
        self.lora.to(linear.weight.device)

        # Count the bias separately: a conditional expression binds looser than
        # `+`, so `w + b if cond else 0` would report 0 for a bias-free layer.
        bias_params = linear.bias.numel() if linear.bias is not None else 0
        original_params = linear.weight.numel() + bias_params

        print(f"\nLinearWithLoRA created:")
        print(f"  Original Linear: in={linear.in_features}, out={linear.out_features}")
        print(f"  Original params: {original_params}")

    def forward(self, x):
        """
        Forward pass: add the LoRA branch output to the wrapped layer's output.

        Input:
            x: tensor of shape (batch_size, ..., in_features)

        Output:
            result: tensor of shape (batch_size, ..., out_features)

        Computation:
            1. Pass through original linear layer: x @ W^T + b
            2. Pass through LoRA layer: alpha * (x @ A @ B)
            3. Add both results together

        Side effect: prints the shape of every intermediate tensor.
        """
        input_shape = x.shape

        # Step 1: original linear transformation
        # Output: (batch_size, ..., out_features)
        linear_output = self.linear(x)

        # Step 2: LoRA branch on the same input
        # Output: (batch_size, ..., out_features)
        lora_output = self.lora(x)

        # Step 3: element-wise sum of the two branches
        result = linear_output + lora_output

        print(f"\nLinearWithLoRA forward pass:")
        print(f"  Input shape: {input_shape}")
        print(f"  Linear output shape: {linear_output.shape}")
        print(f"  LoRA output shape: {lora_output.shape}")
        print(f"  Combined output shape: {result.shape}")

        return result

In [17]:
# ===================================================================
# STEP 3: Linear Layer with LoRA (Merged Weights)
# ===================================================================

class LinearWithLoRAMerged(nn.Module):
    """
    LoRA wrapper that folds the low-rank update into the weight before applying it.

    Computes: output = x @ (W + alpha * (A @ B)^T)^T + b

    The merged weight is rebuilt on every forward call from the wrapped layer's
    weight and the adapter's A and B matrices.

    Parameters:
        linear (nn.Linear): The linear layer to wrap
        rank (int): Rank for LoRA decomposition
        alpha (float): Scaling factor for LoRA
    """
    def __init__(self, linear: nn.Linear, rank, alpha):
        super().__init__()

        # Wrapped layer (provides W and b).
        self.linear = linear

        # Adapter sized to match the wrapped layer.
        adapter_dims = dict(in_dim=linear.in_features, out_dim=linear.out_features)
        self.lora = LoRALayer(rank=rank, alpha=alpha, **adapter_dims)

        print("\nLinearWithLoRAMerged created")

    def forward(self, x):
        """
        Build the merged weight, apply it, and print the shapes involved.

        Steps:
            1. delta = A @ B                       -> (in_features, out_features)
            2. merged = W + alpha * delta^T        -> (out_features, in_features)
            3. result = F.linear(x, merged, bias)  -> (batch_size, ..., out_features)
        """
        adapter = self.lora

        # (in_features, rank) @ (rank, out_features) -> (in_features, out_features)
        delta = adapter.A @ adapter.B

        # nn.Linear stores its weight as (out_features, in_features), hence the transpose.
        merged_weight = self.linear.weight + adapter.alpha * delta.T

        # One dense matmul with the merged weight; bias comes from the wrapped layer.
        out = F.linear(x, merged_weight, self.linear.bias)

        for line in (
            "\nLinearWithLoRAMerged forward pass:",
            f"  Input shape: {x.shape}",
            f"  LoRA update (A@B) shape: {delta.shape}",
            f"  Combined weight shape: {merged_weight.shape}",
            f"  Output shape: {out.shape}",
        ):
            print(line)

        return out

In [18]:
# ===================================================================
# STEP 4: Example Neural Network with LoRA
# ===================================================================

class SimpleNeuralNetwork(nn.Module):
    """
    Three-layer fully connected network used as the target for LoRA.

    Architecture (default sizes):
        Input (784) -> Linear (128) -> ReLU -> Linear(64) -> ReLU -> Linear(10)

    The defaults match flattened 28x28 images and 10 output classes.
    """

    def __init__(
        self,
        input_size=784,
        hidden_size1=128,
        hidden_size2=64,
        output_size=10
    ):
        super().__init__()

        # Layers are created in order fc1, fc2, fc3.
        # Weight shapes: (hidden_size1, input_size), (hidden_size2, hidden_size1),
        # (output_size, hidden_size2).
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, output_size)

        # Report the width going into and out of each layer.
        widths = (input_size, hidden_size1, hidden_size2, output_size)
        print("\nSimpleNeuralNetwork initialized:")
        for number, (fan_in, fan_out) in enumerate(zip(widths, widths[1:]), start=1):
            print(f"  Layer {number}: {fan_in} -> {fan_out}")

    def forward(self, x):
        """
        Run the input through fc1 -> ReLU -> fc2 -> ReLU -> fc3, printing the
        tensor shape after every step.

        Input:
        ------
        x : tensor of shape (batch_size, input_size)

        Output:
        -------
        output : tensor of shape (batch_size, output_size); no activation is
                 applied after the last layer
        """
        print("\nSimpleNeuralNetwork forward pass:")
        print(f"  Input shape: {x.shape}")

        # The two hidden stages share the same pattern: linear, then ReLU.
        hidden = x
        for label, layer in (("fc1", self.fc1), ("fc2", self.fc2)):
            hidden = layer(hidden)
            print(f"  After {label}: {hidden.shape}")
            hidden = F.relu(hidden)
            print(f"  After ReLU: {hidden.shape}")

        # Output stage: linear only.
        logits = self.fc3(hidden)
        print(f"  After fc3 (final): {logits.shape}")

        return logits

In [29]:
# Build a network with the default layer sizes and dump every parameter tensor,
# each one prefixed with the repr of the child layer that owns it.
model = SimpleNeuralNetwork()

for child in model.children():
    for param in child.parameters():
        print("{} ===> {}".format(child, param))


SimpleNeuralNetwork initialized:
  Layer 1: 784 -> 128
  Layer 2: 128 -> 64
  Layer 3: 64 -> 10
Linear(in_features=784, out_features=128, bias=True) ===> Parameter containing:
tensor([[ 0.0158, -0.0069, -0.0318,  ..., -0.0109, -0.0332, -0.0254],
        [ 0.0013,  0.0260,  0.0219,  ...,  0.0167, -0.0257, -0.0279],
        [-0.0096,  0.0282,  0.0223,  ..., -0.0085, -0.0147, -0.0077],
        ...,
        [ 0.0153,  0.0317,  0.0131,  ...,  0.0356, -0.0032,  0.0280],
        [-0.0044, -0.0101, -0.0131,  ..., -0.0279, -0.0301,  0.0119],
        [ 0.0264,  0.0254,  0.0004,  ..., -0.0195, -0.0072, -0.0225]],
       requires_grad=True)
Linear(in_features=784, out_features=128, bias=True) ===> Parameter containing:
tensor([ 0.0085, -0.0234, -0.0183,  0.0240, -0.0226,  0.0128, -0.0108,  0.0335,
         0.0089,  0.0040,  0.0158, -0.0069,  0.0331,  0.0284, -0.0027,  0.0160,
         0.0047,  0.0139, -0.0286, -0.0249,  0.0280,  0.0273,  0.0234,  0.0169,
         0.0109, -0.0343,  0.0284, -0.0252

In [19]:
# ===================================================================
# STEP 5: Helper Functions for LoRA Integration
# ===================================================================

def count_parameters(model):
    """
    Tally the number of scalar parameters in a model.

    Every tensor yielded by ``model.parameters()`` contributes its element
    count to the total; tensors with ``requires_grad`` set also contribute to
    the trainable count.

    Returns:
    --------
    total_params : int
        Total number of parameters
    trainable_params : int
        Number of trainable parameters
    """
    total_params = 0
    trainable_params = 0
    for tensor in model.parameters():
        size = tensor.numel()
        total_params += size
        if tensor.requires_grad:
            trainable_params += size
    return total_params, trainable_params

def freeze_linear_layers(model):
    """
    Turn off gradients for every nn.Linear found below ``model``.

    How it works:
    -------------
    1. Walk the direct children of ``model``
    2. A child that is an nn.Linear gets requires_grad=False on all of its
       parameters and is reported with a print
    3. Any other child is searched recursively in the same way

    Only children are inspected, so ``model`` itself is never frozen even if it
    is an nn.Linear. Parameters that do not belong to an nn.Linear (such as the
    LoRA A and B matrices) are left untouched and remain trainable.
    """
    for child in model.children():
        if not isinstance(child, nn.Linear):
            # Container or custom module: look inside it.
            freeze_linear_layers(child)
            continue

        # Plain linear layer: stop gradient tracking for its parameters.
        for tensor in child.parameters():
            tensor.requires_grad = False
        print(f"  Froze linear layers: {child}")
            
            
def apply_lora_to_model(model, rank=8, alpha=16, layer_indices=None):
    """
    Wrap the model's ``fc1`` / ``fc2`` / ``fc3`` layers in LinearWithLoRA.

    Only attributes with exactly these three names are considered; each one is
    looked up individually, so a model that has only some of them is handled.
    Selected layers are replaced in place on ``model`` via ``setattr``.

    Parameters:
    -----------
    model : nn.Module
        The neural network to modify (mutated in place)
    rank : int
        Rank for LoRA (lower = fewer parameters)
    alpha : float
        Scaling factor passed to each LoRA layer
    layer_indices : list or None
        Attribute names of the layers to wrap, e.g. ``["fc1", "fc3"]``.
        If None, every fc layer that exists on the model is wrapped.

    Returns:
    --------
    model : nn.Module
        The same model object, with the selected layers replaced
    """
    print(f"\nApplying LoRA (rank={rank}, alpha={alpha}):")

    # Gather the candidate layers. Each name is checked on its own, so fc3 is
    # only read when the model actually has an fc3 attribute.
    layers_to_modify = [
        (name, getattr(model, name))
        for name in ("fc1", "fc2", "fc3")
        if hasattr(model, name)
    ]

    # Swap each selected layer for its LoRA-wrapped version.
    for name, layer in layers_to_modify:
        if layer_indices is None or name in layer_indices:
            lora_layer = LinearWithLoRA(layer, rank=rank, alpha=alpha)
            setattr(model, name, lora_layer)
            print(f"  Applied LoRA to {name}")
    return model

In [24]:
# ===================================================================
# STEP 6: Complete Demonstration (Forward Pass and Parameter Comparison)
# ===================================================================

def demonstrate_lora():
    """
    End-to-end walkthrough of the LoRA building blocks (no training).

    Parts:
        A. Build the plain SimpleNeuralNetwork and count its parameters
        B. Wrap its fc layers with LoRA (rank=8, alpha=16)
        C. Freeze the original linear layers and report the parameter reduction
        D. Run one forward pass on a random batch of 4 inputs
        E. List trainable vs frozen parameters
        F. Check that LinearWithLoRA and LinearWithLoRAMerged produce the same
           output when they share the same weights

    Everything is reported through print; the function returns None.
    """
    rule = "=" * 70

    def section(title):
        # Section banner: blank line, rule, title, rule.
        print("\n" + rule)
        print(title)
        print(rule)

    print(rule)
    print("LoRA (Low-Rank Adaptation) - Complete Demonstration")
    print(rule)

    # Set random seed for reproducibility
    torch.manual_seed(42)

    # -------------------------------------------------------------------
    # Part A: Create Original Model
    # -------------------------------------------------------------------
    section("PART A: Creating Original Model")

    model = SimpleNeuralNetwork(
        input_size=784,                 # 28x28 images flattened
        hidden_size1=128,
        hidden_size2=64,
        output_size=10                  # 10 classes
    )
    # Count original parameters
    total_orig, trainable_orig = count_parameters(model)
    print(f"\nOriginal model parameters:")
    print(f"  Total: {total_orig:,}")
    print(f"  Trainable: {trainable_orig:,}")

    # -------------------------------------------------------------------
    # Part B: Apply LoRA to Model
    # -------------------------------------------------------------------
    section("PART B: Applying LoRA to Model")

    # Apply LoRA with rank=8
    # This adds 8 * (in_dim + out_dim) parameters per layer
    model = apply_lora_to_model(model, rank=8, alpha=16)

    # -------------------------------------------------------------------
    # Part C: Freeze Original Weights
    # -------------------------------------------------------------------
    section("PART C: Freezing Original Weights")
    freeze_linear_layers(model=model)

    # Count parameters after LoRA
    total_lora, trainable_lora = count_parameters(model=model)
    print(f"\nAfter applying LoRA:")
    print(f"  Total parameters: {total_lora:,}")
    print(f"  Trainable parameters: {trainable_lora:,}")
    print(f"  Reduction: {(1 - trainable_lora/trainable_orig)*100:.1f}%")
    print(f"  Compression ratio: {trainable_orig/trainable_lora:.1f}x")

    # -------------------------------------------------------------------
    # Part D: Test Forward Pass
    # -------------------------------------------------------------------
    section("PART D: Testing Forward Pass with Sample Data")

    # Create dummy batch of data
    # Shape: (batch_size=4, input_size=784)
    batch_size = 4
    input_size = 784
    dummy_input = torch.randn(batch_size, input_size)
    print(f"\nInput batch shape: {dummy_input.shape}")

    # Disable gradient computation for testing
    with torch.no_grad():
        # Forward pass (this will print shapes at each step)
        output = model(dummy_input)

    print(f"\nFinal output shape: {output.shape}")
    print(f"Expected shape: (batch_size={batch_size}, num_classes=10)")

    # -------------------------------------------------------------------
    # Part E: Show Trainable Parameters
    # -------------------------------------------------------------------
    section("PART E: Trainable Parameters")

    print("\nTrainable parameters (LoRA only):")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(f"  {name:30s}: shape {str(list(param.shape)):20s} ({param.numel():6,} params)")

    print("\nFrozen parameters (original weights):")
    for name, param in model.named_parameters():
        if not param.requires_grad:
            print(f"  {name:30s}: shape {str(list(param.shape)):20s} ({param.numel():6,} params)")

    # -------------------------------------------------------------------
    # Part F: Compare Computation Approaches
    # -------------------------------------------------------------------
    section("PART F: Comparing LinearWithLoRA vs LinearWithLoRAMerged")

    # Create a simple test case
    torch.manual_seed(42)
    test_linear = nn.Linear(10, 5)
    test_input = torch.randn(3, 10) # (batch_size, in_features)

    print("\nTest Linear layer:")
    print(f"  Input shape: {test_input.shape}")
    print(f"  Weight shape: {test_linear.weight.shape}")
    print(f"  Output shape: (3, 5)")

    # Create both versions around the same base layer
    lora_separate = LinearWithLoRA(test_linear, rank=2, alpha=4)
    lora_merged = LinearWithLoRAMerged(test_linear, rank=2, alpha=4)

    # B is zero-initialised, so A @ B would be zero in both wrappers and the
    # comparison below would only compare the base layer with itself. Give B
    # random values so the LoRA branch actually contributes to the outputs.
    lora_separate.lora.B.data = torch.randn(lora_separate.lora.B.shape)

    # Copy weights to make them identical
    lora_merged.lora.A.data = lora_separate.lora.A.data.clone()
    lora_merged.lora.B.data = lora_separate.lora.B.data.clone()

    with torch.no_grad():
        output_separate = lora_separate(test_input)
        output_merged = lora_merged(test_input)

    # The two wrappers evaluate the same expression in a different order, so the
    # results agree only up to float32 rounding; compare against a tolerance.
    tolerance = 1e-4
    difference = torch.max(torch.abs(output_separate - output_merged)).item()
    print(f"\nMax difference between methods: {difference:.10f}")
    print(f"Are outputs identical? {difference < tolerance}")

    section("Demonstration Complete!")

In [25]:
# ===================================================================
# STEP 7: Training Example (Optional - Shows How to Actually Train)
# ===================================================================
def train_lora_model(num_epochs=2, batch_size=64):
    """
    Train only the LoRA parameters of SimpleNeuralNetwork on MNIST.

    Steps:
    1. Download/load the MNIST training split into ./data
    2. Build the model, add LoRA (rank=8, alpha=16), freeze the original
       linear layers, then move the whole model to the selected device
    3. Create an Adam optimizer over the parameters that still require grad
    4. Run the training loop, printing loss and running training accuracy
       every 100 batches and a summary per epoch

    No held-out evaluation is performed; the reported accuracy is measured on
    the training batches as they are seen.

    Parameters:
        num_epochs (int): Number of passes over the training set
        batch_size (int): Mini-batch size for the DataLoader

    Returns:
        None
    """

    print("\n" + "="*70)
    print("Training Example with LoRA on MNIST")
    print("="*70)

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    # -------------------------------------------------------------------
    # Step 1: Prepare Data
    # -------------------------------------------------------------------
    print("\nStep 1: Loading MNIST dataset...")

    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ]
    )

    # Load training data
    train_dataset = datasets.MNIST(
        root="./data",
        train=True,
        download=True,
        transform=transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Batch size: {batch_size}")
    print(f"  Number of batches: {len(train_loader)}")

    # -------------------------------------------------------------------
    # Step 2: Create Model with LoRA
    # -------------------------------------------------------------------
    print("\nStep 2: Creating model with LoRA...")

    model = SimpleNeuralNetwork()
    model = apply_lora_to_model(model, rank=8, alpha=16)
    freeze_linear_layers(model=model)

    # Move to the device only after the LoRA layers exist: their A and B
    # matrices are created on the CPU, so moving the model earlier would leave
    # them behind and the forward pass would fail with a device mismatch.
    model = model.to(device=device)

    total_params, trainable_params = count_parameters(model)
    print(f"  Total parameters: {total_params:,}")
    print(f"  Trainable parameters: {trainable_params:,}")

    # -------------------------------------------------------------------
    # Step 3: Setup Training
    # -------------------------------------------------------------------
    print("\nStep 3: Setting up training...")

    # Only optimize the parameters that still require gradients (LoRA A and B)
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=0.001
    )

    criterion = nn.CrossEntropyLoss()

    # -------------------------------------------------------------------
    # Step 4: Training Loop
    # -------------------------------------------------------------------
    print(f"\nStep 4: Training for {num_epochs} epoch(s)...\n")

    model.train()

    for epoch in range(num_epochs):
        total_loss = 0
        correct = 0
        total = 0

        for batch_idx, (data, target) in enumerate(train_loader):
            # Move data to device
            # data shape: (batch_size, 1, 28, 28)
            # target shape: (batch_size, )
            data, target = data.to(device), target.to(device)

            # Flatten images: (batch_size, 1, 28, 28) -> (batch_size, 784)
            data = data.view(data.size(0), -1)

            # Zero gradients
            optimizer.zero_grad()

            # Forward pass
            # Input shape: (batch_size, 784)
            # Output shape: (batch_size, 10)
            # The layers print their tensor shapes on every call; discard that
            # tracing here so it does not bury the progress lines below.
            with contextlib.redirect_stdout(io.StringIO()):
                output = model(data)

            # Compute loss
            loss = criterion(output, target)

            # Backward pass
            loss.backward()

            # Update only LoRA parameters
            optimizer.step()

            # Track metrics
            total_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

            # Print progress
            if batch_idx % 100 == 0:
                print(f"Epoch {epoch+1}/{num_epochs} | "
                    f"Batch {batch_idx}/{len(train_loader)} | "
                    f"Loss: {loss.item():.4f} | "
                    f"Acc: {100.*correct/total:.2f}%")

        avg_loss = total_loss / len(train_loader)
        accuracy = 100.0 * correct / total
        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Average Loss: {avg_loss:.4f}")
        print(f"  Accuracy: {accuracy:.2f}%\n")

    print("Training complete!")

In [None]:
# ===================================================================
# STEP 8: Main Execution
# ===================================================================

if __name__ == "__main__":
    # Title banner: blank line, rule, centred title, rule.
    print(
        "\n" + "=" * 70,
        " " * 15 + "LoRA Implementation from Scratch",
        "=" * 70,
        sep="\n",
    )

    # Walk through the building blocks (no training, no dataset download).
    demonstrate_lora()

    # Training is opt-in because it downloads MNIST and takes a while.
    # Uncomment the following to run it:
    # print("\n\nWould you like to train the model? (requires downloading MNIST)")
    # train_lora_model(num_epochs=1, batch_size=64)

    # Closing banner.
    print("\n" + "=" * 70, "All demonstrations complete!", "=" * 70, sep="\n")


               LoRA Implementation from Scratch
LoRA (Low-Rank Adaptation) - Complete Demonstration

PART A: Creating Original Model

SimpleNeuralNetwork initialized:
  Layer 1: 784 -> 128
  Layer 2: 128 -> 64
  Layer 3: 64 -> 10

Original model parameters:
  Total: 109,386
  Trainable: 109,386

PART B: Applying LoRA to Model

Applying LoRA (rank=8, alpha=16):
LoRALayer initialized:
  Matrix A shape: torch.Size([784, 8]) (in_dim=784, rank=8)
  Matrix B shape: torch.Size([8, 128]) (rank=8, out_dim=128)
  Total LoRA params: 7296

LinearWithLoRA created:
  Original Linear: in=784, out=128
  Original params: 100480
  Applied LoRA to fc1
LoRALayer initialized:
  Matrix A shape: torch.Size([128, 8]) (in_dim=128, rank=8)
  Matrix B shape: torch.Size([8, 64]) (rank=8, out_dim=64)
  Total LoRA params: 1536

LinearWithLoRA created:
  Original Linear: in=128, out=64
  Original params: 8256
  Applied LoRA to fc2
LoRALayer initialized:
  Matrix A shape: torch.Size([64, 8]) (in_dim=64, rank=8)
  Ma