In [1]:
import time
import numpy as np
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.nn as nn
import torch

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning import Trainer
    
from typing import List
import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from lightning.pytorch.loggers import WandbLogger
import wandb

# When CUDA is available, ask cuDNN for deterministic behaviour so that
# repeated runs with the same seed are comparable.
if torch.cuda.is_available():
    torch.backends.cudnn.deterministic = True

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Global seed for the notebook, applied through Lightning's seeding helper.
SEED = 2024
pl.seed_everything(SEED)

Seed set to 2024


2024

In [3]:
# Weights & Biases project that every WandbLogger in this notebook logs to.
wandb_project_name = 'MNIST_LORA'

In [4]:
class MNISTDataModule(pl.LightningDataModule):
    """Lightning data module serving MNIST as train / validation / test loaders.

    The MNIST training set is split into 55,000 training and 5,000 validation
    samples. The split is created once per instance and then reused, so every
    ``Trainer`` that shares this data module sees the same train/validation
    partition.

    Args:
        data_dir: Directory MNIST is downloaded to and read from.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Number of worker processes per dataloader. The default of
            0 loads data in the main process, as before.
    """

    def __init__(self, data_dir: str = './data', batch_size: int = 64, num_workers: int = 0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.transform = transforms.ToTensor()
        # Populated lazily by setup(); None means "not built yet".
        self.mnist_train = None
        self.mnist_val = None
        self.mnist_test = None

    def prepare_data(self):
        """Download both MNIST splits into ``data_dir`` without assigning state."""
        for is_train in (True, False):
            datasets.MNIST(root=self.data_dir, train=is_train, download=True)

    def setup(self, stage=None):
        """Build the datasets needed for ``stage``.

        Args:
            stage: ``'fit'``, ``'validate'``, ``'test'`` or ``None`` (build
                everything). Other values build nothing.
        """
        # 'validate' needs the validation subset as well, otherwise
        # val_dataloader() would be called before mnist_val exists.
        needs_train_val = stage in ('fit', 'validate', None)
        if needs_train_val and self.mnist_train is None:
            # random_split draws from the global RNG, so re-splitting on every
            # setup() call would hand each Trainer a different partition and
            # leak earlier training samples into later validation sets.
            mnist_full = datasets.MNIST(root=self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        if stage in ('test', None) and self.mnist_test is None:
            self.mnist_test = datasets.MNIST(root=self.data_dir, train=False, transform=self.transform)

    def _make_loader(self, dataset, shuffle=False):
        """Create a DataLoader with this module's batch size and worker count."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def train_dataloader(self):
        """Shuffled loader over the training subset."""
        return self._make_loader(self.mnist_train, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation subset."""
        return self._make_loader(self.mnist_val)

    def test_dataloader(self):
        """Unshuffled loader over the MNIST test set."""
        return self._make_loader(self.mnist_test)

# Shared data module instance passed to every fit/test call in this notebook.
batch_size = 64
mnist_data = MNISTDataModule(data_dir='./data', batch_size=batch_size)


In [5]:
# Hyperparameters
# NOTE(review): random_seed is not referenced by any visible cell; seeding is
# done with SEED through pl.seed_everything. Confirm whether it is still needed.
random_seed = 123
learning_rate = 0.005
num_epochs = 2  # passed as max_epochs to every Trainer below

# Architecture
num_features = 784  # flattened 28*28 input, matching the view() in the step methods
num_hidden_1 = 128
num_hidden_2 = 256
num_classes = 10

class MultilayerPerceptron(pl.LightningModule):
    """Fully connected classifier: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        num_features: Size of the flattened input vector.
        num_hidden_1: Width of the first hidden layer.
        num_hidden_2: Width of the second hidden layer.
        num_classes: Number of output logits.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes, learning_rate):
        super().__init__()
        # Exposes the constructor arguments as self.hparams.<name>.
        self.save_hyperparameters()

        self.layers = nn.Sequential(
            nn.Linear(num_features, num_hidden_1),
            nn.ReLU(),
            nn.Linear(num_hidden_1, num_hidden_2),
            nn.ReLU(),
            nn.Linear(num_hidden_2, num_classes)
        )

    def forward(self, x):
        """Return the logits for an already flattened batch ``x``."""
        return self.layers(x)

    def _shared_step(self, batch, stage, **log_kwargs):
        """Compute and log cross-entropy loss and accuracy for one batch.

        Args:
            batch: ``(features, targets)`` pair from a dataloader.
            stage: Prefix for the logged metric names (``'train'``, ``'val'``
                or ``'test'``).
            **log_kwargs: Forwarded unchanged to ``self.log``.

        Returns:
            The cross-entropy loss tensor.
        """
        features, targets = batch
        # Flatten to the configured input width instead of a hard-coded 28*28,
        # so the step methods agree with the num_features hyperparameter.
        features = features.view(-1, self.hparams.num_features)
        logits = self(features)
        loss = F.cross_entropy(logits, targets)
        self.log(f'{stage}_loss', loss, **log_kwargs)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == targets).float().mean()
        self.log(f'{stage}_acc', acc, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        """Log per-step and per-epoch train metrics and return the loss."""
        return self._shared_step(
            batch, 'train', on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

    def validation_step(self, batch, batch_idx):
        """Log per-step and per-epoch validation metrics; returns nothing."""
        self._shared_step(
            batch, 'val', on_step=True, on_epoch=True, prog_bar=True, logger=True
        )

    def test_step(self, batch, batch_idx):
        """Log test metrics with ``self.log`` defaults; returns nothing."""
        self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Adam over all parameters using the stored learning rate."""
        return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
    
# Baseline network; keyword arguments make the argument order explicit.
model = MultilayerPerceptron(
    num_features=num_features,
    num_hidden_1=num_hidden_1,
    num_hidden_2=num_hidden_2,
    num_classes=num_classes,
    learning_rate=learning_rate,
)

In [6]:
# Baseline run: its own W&B run (name/group "baseline") and a Trainer limited
# to num_epochs epochs.
# NOTE(review): WandbLogger is imported from `lightning.pytorch` while Trainer
# comes from `pytorch_lightning`; confirm that mixing the two packages is intended.
wandb_logger = WandbLogger(project=wandb_project_name, log_model="all", name="baseline", group="baseline", save_dir="lightning_logs")
trainer = Trainer(max_epochs=num_epochs, logger=wandb_logger)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Train the baseline MLP (all parameters trainable) on the MNIST data module.
trainer.fit(model, mnist_data)

You are using a CUDA device ('NVIDIA GeForce RTX 4070 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmiguel_kjh[0m ([33msiani-ai[0m). Use [1m`wandb login --relogin`[0m to force relogin


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | layers | Sequential | 136 K 
--------------------------------------
136 K     Trainable params
0         Non-trainable params
136 K     Total params
0.544     Total estimated model params size (MB)


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]

c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


                                                                           

c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 860/860 [00:09<00:00, 92.29it/s, v_num=8pbq, train_loss_step=0.0947, train_acc_step=0.917, val_loss_step=0.000484, val_acc_step=1.000, val_loss_epoch=0.128, val_acc_epoch=0.964, train_loss_epoch=0.115, train_acc_epoch=0.965]

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 860/860 [00:09<00:00, 88.44it/s, v_num=8pbq, train_loss_step=0.0947, train_acc_step=0.917, val_loss_step=0.000484, val_acc_step=1.000, val_loss_epoch=0.128, val_acc_epoch=0.964, train_loss_epoch=0.115, train_acc_epoch=0.965]


In [8]:
# Evaluate the trained baseline with the data module's test dataloader.
trainer.test(model, mnist_data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 279.53it/s]


[{'test_loss': 0.11397630721330643, 'test_acc': 0.9661999940872192}]

In [9]:
# Close the active W&B run before the next experiment creates its logger.
wandb.finish()

0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅█
test_acc,▁
test_loss,▁
train_acc_epoch,▁█
train_acc_step,▁▃▃▆▆▇▅▄▁▅▇▅▄█▅▄▇▇▅▇▆▆▅█▂▆▇▆▅█▆███
train_loss_epoch,█▁
train_loss_step,▆▅▄▃▂▃▄▃█▃▁▃▄▁▃▄▂▂▅▁▂▂▃▁▄▂▁▂▃▁▂▁▁▁
trainer/global_step,▁▂▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▇▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂█
val_acc_epoch,▁█
val_acc_step,▃▅▅▆▄▅▄▇▇▃▁▅▄▅▇▄▅▅▂▆█▆▅▇▅█▅▆▆▅▆▆▅▅▅▄█▆▇█

0,1
epoch,2.0
test_acc,0.9662
test_loss,0.11398
train_acc_epoch,0.96456
train_acc_step,1.0
train_loss_epoch,0.11487
train_loss_step,0.02563
trainer/global_step,1720.0
val_acc_epoch,0.9638
val_acc_step,1.0


## LoRA

In [10]:
class LoRALayer(nn.Module):
    """Low-rank adapter that returns ``alpha * (x @ A @ B)``.

    ``A`` has shape ``(in_dim, rank)`` and is drawn from a standard normal
    scaled by ``1 / sqrt(rank)``. ``B`` has shape ``(rank, out_dim)`` and
    starts at zero, so the adapter outputs zeros until ``B`` changes.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension of the low-rank factorisation.
        alpha: Scalar multiplier applied to the adapter output.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        init_scale = 1 / torch.tensor(rank).float().sqrt()
        down_proj = torch.randn(in_dim, rank) * init_scale
        up_proj = torch.zeros(rank, out_dim)
        self.A = nn.Parameter(down_proj)
        self.B = nn.Parameter(up_proj)
        self.alpha = alpha

    def forward(self, x):
        """Project ``x`` down to ``rank`` dimensions, back up, and scale by alpha."""
        low_rank_out = (x @ self.A) @ self.B
        return self.alpha * low_rank_out

    
# This LoRA code is equivalent to LinearWithLoRA
class LinearWithLoRAMerged(nn.Module):
    """Wrap a linear layer and fold a LoRA update into its weight on each call.

    The forward pass uses the weight ``W + alpha * (A @ B).T`` together with
    the wrapped layer's own bias.

    Args:
        linear: Layer exposing ``in_features``, ``out_features``, ``weight``
            and ``bias`` (an ``nn.Linear`` in this notebook).
        rank: Rank of the LoRA factorisation.
        alpha: Scale applied to the low-rank update.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        in_dim, out_dim = linear.in_features, linear.out_features
        self.lora = LoRALayer(in_dim, out_dim, rank, alpha)

    def forward(self, x):
        """Apply the wrapped layer with the LoRA-adjusted weight matrix."""
        # A @ B is (in, out); transpose it to line up with linear.weight.
        weight_delta = (self.lora.A @ self.lora.B).T
        merged_weight = self.linear.weight + self.lora.alpha * weight_delta
        return F.linear(x, merged_weight, self.linear.bias)

    
# This DoRA code is equivalent to LinearWithDoRA
# Code inspired by https://github.com/catid/dora/blob/main/dora.py
class LinearWithDoRAMerged(nn.Module):
    """Wrap a linear layer with a DoRA-style magnitude/direction reparametrisation.

    The LoRA-adjusted weight ``W + alpha * (A @ B).T`` is divided by its L2
    norm along dim 0 and rescaled by the learnable magnitude ``m``. ``m`` is
    initialised to the L2 norm (along dim 0) of the wrapped layer's weight.

    Args:
        linear: Layer exposing ``in_features``, ``out_features``, ``weight``
            and ``bias`` (an ``nn.Linear`` in this notebook).
        rank: Rank of the LoRA factorisation.
        alpha: Scale applied to the low-rank update.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

        initial_magnitude = self.linear.weight.norm(p=2, dim=0, keepdim=True)
        self.m = nn.Parameter(initial_magnitude)

    def forward(self, x):
        """Apply the wrapped layer using the renormalised, rescaled weight."""
        weight_delta = (self.lora.A @ self.lora.B).T
        adjusted = self.linear.weight + self.lora.alpha * weight_delta
        # Normalise along dim 0, then restore scale with the learned magnitude.
        direction = adjusted / adjusted.norm(p=2, dim=0, keepdim=True)
        rescaled = self.m * direction
        return F.linear(x, rescaled, self.linear.bias)
    
# Lora neurons expert
import torch
import torch.nn as nn

class LoRAMixtureOfExpertsLayer(nn.Module):
    """LoRA variant with one shared ``A`` matrix and several ``B`` experts.

    The input is projected with ``A``, scaled by ``alpha`` and passed through
    a softmax over the rank dimension. The result is multiplied by every
    expert matrix and the expert outputs are averaged.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension shared by ``A`` and every expert.
        alpha: Scale applied to ``x @ A`` before the softmax.
        num_experts: Number of zero-initialised ``B`` matrices.
    """

    def __init__(self, in_dim, out_dim, rank, alpha, num_experts):
        super().__init__()
        init_scale = 1 / torch.tensor(rank).float().sqrt()

        self.A = nn.Parameter(torch.randn(in_dim, rank) * init_scale)
        # List of B experts, each starting at zero.
        expert_matrices = [
            nn.Parameter(torch.zeros(rank, out_dim)) for _ in range(num_experts)
        ]
        self.B = nn.ParameterList(expert_matrices)
        self.alpha = alpha
        self.num_experts = num_experts

    def forward(self, x):
        """Return the mean over experts of ``softmax(alpha * x @ A) @ B_i``."""
        # Apply softmax after x @ A.
        gate = F.softmax(self.alpha * (x @ self.A), dim=-1)

        # The softmax output could be used to weight each expert differently.
        # For simplicity every expert receives the same input and the outputs
        # are averaged with equal weight; consider changing this so the
        # weights are used effectively.
        accumulated = 0
        for expert in self.B:
            accumulated = accumulated + gate @ expert
        return accumulated / self.num_experts
    
class LinearWithLoRAMixtureOfExperts(nn.Module):
    """Add a LoRA mixture-of-experts branch to the output of a linear layer.

    Args:
        linear: Layer exposing ``in_features``, ``out_features``, ``weight``
            and ``bias`` (an ``nn.Linear`` in this notebook).
        rank: Rank of the adapter.
        alpha: Scale passed to the adapter.
        num_experts: Number of ``B`` experts in the adapter.
    """

    def __init__(self, linear, rank, alpha, num_experts):
        super().__init__()
        self.linear = linear
        in_dim, out_dim = linear.in_features, linear.out_features
        self.lora = LoRAMixtureOfExpertsLayer(in_dim, out_dim, rank, alpha, num_experts)

    def forward(self, x):
        """Return the base linear output plus the adapter output for ``x``."""
        base_out = F.linear(x, self.linear.weight, self.linear.bias)
        adapter_out = self.lora(x)
        return base_out + adapter_out

In [11]:
import copy

# Each adapter variant starts from its own deep copy of the baseline `model`,
# so wrapping or freezing one variant does not touch the others.
# NOTE: model_lora_moe is deep-copied from `model` again in a later cell.
model_lora = copy.deepcopy(model)
model_dora = copy.deepcopy(model)
model_lora_moe = copy.deepcopy(model)

In [12]:
# Wrap every plain nn.Linear in the Sequential with a merged LoRA adapter.
# The isinstance check keeps this cell idempotent: running it again does not
# stack a second adapter on a layer that is already wrapped.
for layer_idx, layer in enumerate(list(model_lora.layers)):
    if isinstance(layer, nn.Linear):
        model_lora.layers[layer_idx] = LinearWithLoRAMerged(layer, rank=4, alpha=8)
model_lora

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRAMerged(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithLoRAMerged(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithLoRAMerged(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [13]:
# Wrap every plain nn.Linear in the Sequential with a merged DoRA adapter.
# The isinstance check keeps this cell idempotent: running it again does not
# wrap a layer that is already wrapped.
for layer_idx, layer in enumerate(list(model_dora.layers)):
    if isinstance(layer, nn.Linear):
        model_dora.layers[layer_idx] = LinearWithDoRAMerged(layer, rank=4, alpha=8)
model_dora

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithDoRAMerged(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRALayer()
    )
    (1): ReLU()
    (2): LinearWithDoRAMerged(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRALayer()
    )
    (3): ReLU()
    (4): LinearWithDoRAMerged(
      (linear): Linear(in_features=256, out_features=10, bias=True)
      (lora): LoRALayer()
    )
  )
)

In [14]:
def freeze_linear_layers(model):
    """Disable gradients for every ``nn.Linear`` contained in ``model``.

    ``model.modules()`` yields the module itself and all of its descendants,
    so an ``nn.Linear`` passed in directly is frozen too (previously only
    children were inspected, making that call a silent no-op). Parameters that
    do not belong to an ``nn.Linear`` are left untouched.

    Args:
        model: The ``nn.Module`` to modify in place.

    Returns:
        None.
    """
    for module in model.modules():
        if isinstance(module, nn.Linear):
            # Module.requires_grad_ updates every parameter the module owns.
            module.requires_grad_(False)

## Train

### Train LoRA

In [15]:
# Freeze the wrapped nn.Linear layers so that only the LoRA matrices train.
freeze_linear_layers(model_lora)

# Check if linear layers are frozen
for name, param in model_lora.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [16]:
# LoRA run: its own W&B run (name/group "lora") and a fresh Trainer with the
# same epoch budget as the baseline.
wandb_logger_lora = WandbLogger(project=wandb_project_name, log_model="all", name="lora", group="lora", save_dir="lightning_logs")
trainer_lora = Trainer(max_epochs=num_epochs, logger=wandb_logger_lora)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [17]:
# Fine-tune the LoRA model on the shared MNIST data module.
trainer_lora.fit(model_lora, mnist_data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | layers | Sequential | 142 K 
--------------------------------------
6.2 K     Trainable params
136 K     Non-trainable params
142 K     Total params
0.569     Total estimated model params size (MB)


                                                                            

c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 860/860 [00:09<00:00, 88.78it/s, v_num=qfz3, train_loss_step=0.377, train_acc_step=0.958, val_loss_step=0.00304, val_acc_step=1.000, val_loss_epoch=0.0855, val_acc_epoch=0.975, train_loss_epoch=0.090, train_acc_epoch=0.973]  

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 860/860 [00:10<00:00, 84.96it/s, v_num=qfz3, train_loss_step=0.377, train_acc_step=0.958, val_loss_step=0.00304, val_acc_step=1.000, val_loss_epoch=0.0855, val_acc_epoch=0.975, train_loss_epoch=0.090, train_acc_epoch=0.973]


In [18]:
# Evaluate the LoRA model on the test dataloader, then close its W&B run.
trainer_lora.test(model_lora, mnist_data)
wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 247.68it/s]


0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅█
test_acc,▁
test_loss,▁
train_acc_epoch,▁█
train_acc_step,▇▅▃▅▁▇▅▃▂██▇█▆█▇▂█▆▆▅▅▇▇▆▇▇▆▃▇▅▆▇▆
train_loss_epoch,█▁
train_loss_step,▇▄█▇█▃▃▄▇▁▁▄▁▃▁▂█▁▄▄▅▇▂▂▆▂▄▃▇▂▃▄▂▂
trainer/global_step,▁▂▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▇▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂█
val_acc_epoch,▁█
val_acc_step,▇▆▅▆█▆▆▆▆▅▆▁▆██▇▆▅▇▇▆▆▅▅▅▆▇▇▆▇█▇▇▆▆▆▇▅▇█

0,1
epoch,2.0
test_acc,0.968
test_loss,0.1065
train_acc_epoch,0.97251
train_acc_step,0.96875
train_loss_epoch,0.09001
train_loss_step,0.06095
trainer/global_step,1720.0
val_acc_epoch,0.9748
val_acc_step,1.0


### Train DoRA

In [19]:
# Freeze the wrapped nn.Linear layers; the DoRA magnitude `m` and the LoRA
# matrices are not part of an nn.Linear and therefore stay trainable.
freeze_linear_layers(model_dora)

# Check if linear layers are frozen
for name, param in model_dora.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.m: True
layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B: True
layers.2.m: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B: True
layers.4.m: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B: True


In [20]:
# DoRA run: its own W&B run (name/group "dora") and a fresh Trainer with the
# same epoch budget as the baseline.
wandb_logger_dora = WandbLogger(project=wandb_project_name, log_model="all", name="dora", group="dora", save_dir="lightning_logs")
trainer_dora = Trainer(max_epochs=num_epochs, logger=wandb_logger_dora)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
# Fine-tune the DoRA model on the shared MNIST data module.
trainer_dora.fit(model_dora, mnist_data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | layers | Sequential | 143 K 
--------------------------------------
7.4 K     Trainable params
136 K     Non-trainable params
143 K     Total params
0.574     Total estimated model params size (MB)


                                                                            

c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 860/860 [00:10<00:00, 78.72it/s, v_num=ctjo, train_loss_step=0.0252, train_acc_step=1.000, val_loss_step=0.751, val_acc_step=0.875, val_loss_epoch=0.0725, val_acc_epoch=0.979, train_loss_epoch=0.0767, train_acc_epoch=0.976] 

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 860/860 [00:11<00:00, 75.73it/s, v_num=ctjo, train_loss_step=0.0252, train_acc_step=1.000, val_loss_step=0.751, val_acc_step=0.875, val_loss_epoch=0.0725, val_acc_epoch=0.979, train_loss_epoch=0.0767, train_acc_epoch=0.976]


In [22]:
# Evaluate the DoRA model on the test dataloader, then close its W&B run.
trainer_dora.test(model_dora, mnist_data)
wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 229.07it/s]


0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅█
test_acc,▁
test_loss,▁
train_acc_epoch,▁█
train_acc_step,▁▅▆▅▆▅█▁█▅▃▅▅▆▆▃▅▆▆▃█▃▆▃█▃▅█▃▅▆▃▅▅
train_loss_epoch,█▁
train_loss_step,▆▄▂▃▂▃▂▆▂▃▄▅▃▂▂▄▆▂▂█▂▆▄▅▁▆▃▁▆▃▃▄▄▅
trainer/global_step,▁▂▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▇▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂█
val_acc_epoch,▁█
val_acc_step,▅▇▅▆██▇▇▇▆▇▅███▅█▇▆▅▇▅██▇▇▅█▆▆▆██▇▄█▆▇▇▁

0,1
epoch,2.0
test_acc,0.9674
test_loss,0.10446
train_acc_epoch,0.97631
train_acc_step,0.96875
train_loss_epoch,0.07667
train_loss_step,0.14442
trainer/global_step,1720.0
val_acc_epoch,0.9788
val_acc_step,0.875


### Train LoRA MoE

In [23]:
from itertools import combinations

class RegularizedMLP(MultilayerPerceptron):
    """MLP whose training loss adds a pairwise similarity penalty on layer outputs.

    Args:
        num_features: Size of the flattened input vector.
        num_hidden_1: Width of the first hidden layer.
        num_hidden_2: Width of the second hidden layer.
        num_classes: Number of output logits.
        learning_rate: Learning rate for the optimizer.
        regularization_type: ``'cosine'`` (mean cosine similarity) or ``'kl'``
            (batch-mean KL divergence between softmax distributions).
        lambda_reg: Weight of the penalty; values ``<= 0`` disable it.
    """

    def __init__(self, num_features, num_hidden_1, num_hidden_2, num_classes, learning_rate, regularization_type='cosine', lambda_reg=0.01):
        super().__init__(num_features, num_hidden_1, num_hidden_2, num_classes, learning_rate)

        self.regularization_type = regularization_type
        self.lambda_reg = lambda_reg

    def apply_regularization(self, outputs):
        """Return a scalar penalty averaged over comparable pairs of ``outputs``.

        Args:
            outputs: Sequence of 2-D tensors (batch first).

        Returns:
            A 0-dim tensor: the mean pairwise penalty, or zero when no two
            entries share a shape.

        Raises:
            ValueError: If ``regularization_type`` is not ``'cosine'`` or ``'kl'``.

        NOTE(review): pairs with different shapes are skipped because neither
        measure is defined for them. With the layer widths used in this
        notebook no two linear outputs share a shape, so the penalty is zero
        there; confirm which tensors were meant to be compared.
        """
        if self.regularization_type == 'cosine':
            # Cosine similarity between each pair of outputs.
            cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

            def pair_penalty(first, second):
                # cos() yields one value per sample; reduce to a scalar so the
                # penalty can be added to the scalar cross-entropy loss.
                return cos(first, second).mean()
        elif self.regularization_type == 'kl':
            # KL divergence between each pair of outputs.
            kl_div = torch.nn.KLDivLoss(reduction='batchmean')

            def pair_penalty(first, second):
                return kl_div(F.log_softmax(first, dim=1), F.softmax(second, dim=1))
        else:
            raise ValueError("Unsupported regularization type")

        penalties = [
            pair_penalty(first, second)
            for first, second in combinations(outputs, 2)
            if first.shape == second.shape
        ]
        if not penalties:
            return torch.zeros((), device=self.device)
        # Mean over pairs; the original divided by an itertools.combinations
        # call on an int, which raises TypeError.
        return torch.stack(penalties).mean()

    def _linear_activations(self, features):
        """Feed ``features`` through ``self.layers`` in order and collect the
        output of every ``nn.Linear``, each computed from the previous layer's
        output rather than from the raw input."""
        activations = []
        hidden = features
        for layer in self.layers:
            hidden = layer(hidden)
            if isinstance(layer, nn.Linear):
                activations.append(hidden)
        return activations

    def training_step(self, batch, batch_idx):
        """Cross-entropy training step with the optional regularization term."""
        features, targets = batch
        features = features.view(-1, 28*28)
        logits = self(features)
        loss = F.cross_entropy(logits, targets)

        # Apply regularization only when lambda_reg > 0.
        if self.lambda_reg > 0:
            # Second pass through the layers to capture intermediate outputs;
            # acceptable for a network this small.
            intermediate_outputs = self._linear_activations(features)
            regularization_loss = self.apply_regularization(intermediate_outputs)
            loss = loss + self.lambda_reg * regularization_loss

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == targets).float().mean()
        self.log('train_acc', acc, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return loss

In [24]:
# Fresh deep copy of the baseline for the mixture-of-experts variant; this
# replaces the copy with the same name made in the earlier deepcopy cell.
model_lora_moe = copy.deepcopy(model)

In [25]:
# Wrap every plain nn.Linear in the Sequential with a LoRA mixture-of-experts
# adapter. The isinstance check keeps this cell idempotent: running it again
# does not wrap a layer that is already wrapped.
for layer_idx, layer in enumerate(list(model_lora_moe.layers)):
    if isinstance(layer, nn.Linear):
        model_lora_moe.layers[layer_idx] = LinearWithLoRAMixtureOfExperts(
            layer, rank=4, alpha=8, num_experts=8
        )

In [26]:
# Print the module tree to confirm the three linear layers are wrapped.
print(model_lora_moe)

MultilayerPerceptron(
  (layers): Sequential(
    (0): LinearWithLoRAMixtureOfExperts(
      (linear): Linear(in_features=784, out_features=128, bias=True)
      (lora): LoRAMixtureOfExpertsLayer(
        (B): ParameterList(
            (0): Parameter containing: [torch.float32 of size 4x128]
            (1): Parameter containing: [torch.float32 of size 4x128]
            (2): Parameter containing: [torch.float32 of size 4x128]
            (3): Parameter containing: [torch.float32 of size 4x128]
            (4): Parameter containing: [torch.float32 of size 4x128]
            (5): Parameter containing: [torch.float32 of size 4x128]
            (6): Parameter containing: [torch.float32 of size 4x128]
            (7): Parameter containing: [torch.float32 of size 4x128]
        )
      )
    )
    (1): ReLU()
    (2): LinearWithLoRAMixtureOfExperts(
      (linear): Linear(in_features=128, out_features=256, bias=True)
      (lora): LoRAMixtureOfExpertsLayer(
        (B): ParameterList(
    

In [27]:
# Freeze the wrapped nn.Linear layers so that only the shared A matrix and the
# B experts train.
freeze_linear_layers(model_lora_moe)

# Check if linear layers are frozen
for name, param in model_lora_moe.named_parameters():
    print(f"{name}: {param.requires_grad}")

layers.0.linear.weight: False
layers.0.linear.bias: False
layers.0.lora.A: True
layers.0.lora.B.0: True
layers.0.lora.B.1: True
layers.0.lora.B.2: True
layers.0.lora.B.3: True
layers.0.lora.B.4: True
layers.0.lora.B.5: True
layers.0.lora.B.6: True
layers.0.lora.B.7: True
layers.2.linear.weight: False
layers.2.linear.bias: False
layers.2.lora.A: True
layers.2.lora.B.0: True
layers.2.lora.B.1: True
layers.2.lora.B.2: True
layers.2.lora.B.3: True
layers.2.lora.B.4: True
layers.2.lora.B.5: True
layers.2.lora.B.6: True
layers.2.lora.B.7: True
layers.4.linear.weight: False
layers.4.linear.bias: False
layers.4.lora.A: True
layers.4.lora.B.0: True
layers.4.lora.B.1: True
layers.4.lora.B.2: True
layers.4.lora.B.3: True
layers.4.lora.B.4: True
layers.4.lora.B.5: True
layers.4.lora.B.6: True
layers.4.lora.B.7: True


In [28]:
# LoRA-MoE run: its own W&B run (name/group "lora_moe") and a fresh Trainer
# with the same epoch budget as the baseline.
wandb_logger_lora_moe = WandbLogger(project=wandb_project_name, log_model="all", name="lora_moe", group="lora_moe", save_dir="lightning_logs")
trainer_lora_moe = Trainer(max_epochs=num_epochs, logger=wandb_logger_lora_moe)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [29]:
# Fine-tune the LoRA-MoE model on the shared MNIST data module.
trainer_lora_moe.fit(model_lora_moe, mnist_data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type       | Params
--------------------------------------
0 | layers | Sequential | 153 K 
--------------------------------------
17.3 K    Trainable params
136 K     Non-trainable params
153 K     Total params
0.613     Total estimated model params size (MB)


                                                                            

c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 1: 100%|██████████| 860/860 [00:14<00:00, 58.68it/s, v_num=21ia, train_loss_step=0.0452, train_acc_step=1.000, val_loss_step=0.00486, val_acc_step=1.000, val_loss_epoch=0.0552, val_acc_epoch=0.981, train_loss_epoch=0.0559, train_acc_epoch=0.983]  

`Trainer.fit` stopped: `max_epochs=2` reached.


Epoch 1: 100%|██████████| 860/860 [00:15<00:00, 56.50it/s, v_num=21ia, train_loss_step=0.0452, train_acc_step=1.000, val_loss_step=0.00486, val_acc_step=1.000, val_loss_epoch=0.0552, val_acc_epoch=0.981, train_loss_epoch=0.0559, train_acc_epoch=0.983]


In [30]:
# Evaluate the LoRA-MoE model on the test dataloader, then close its W&B run.
trainer_lora_moe.test(model_lora_moe, mnist_data)
wandb.finish()

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\43294881\AppData\Local\Programs\Python\Python311\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 157/157 [00:00<00:00, 206.17it/s]


0,1
epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅█
test_acc,▁
test_loss,▁
train_acc_epoch,▁█
train_acc_step,███▁▆█▃▁▆█▃▃▆▆██▁▆█▆▆▁▆▆██▆█▃▃█▃█▆
train_loss_epoch,█▁
train_loss_step,▂▁▂▇▃▁▃▇▄▁▆▆▃▃▂▂▅▃▂▃▂▅▄▄▁▁▃▃█▃▁▅▂▃
trainer/global_step,▁▂▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▆▇▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂█
val_acc_epoch,▁█
val_acc_step,▂▁▅█▅█▅▅▅▂▇▇▅▄█▇▇▇▇▇▇█▇▄█▅█▄▇▅▇▇█▇█▅▇▇██

0,1
epoch,2.0
test_acc,0.9756
test_loss,0.08495
train_acc_epoch,0.98311
train_acc_step,0.98438
train_loss_epoch,0.05593
train_loss_step,0.04594
trainer/global_step,1720.0
val_acc_epoch,0.9814
val_acc_step,1.0
