In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

In [12]:
class BaselineModel(nn.Module):
    """
    Two-layer MLP without LoRA, used as the reference model.

    - Architecture: Linear -> ReLU -> Linear
    - hidden_dim is kept large so that the model carries more parameters.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits.
    """
    def __init__(self, input_dim=768, hidden_dim=512, output_dim=10):
        super().__init__()
        self.linear1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Map inputs to logits: linear1, then ReLU, then linear2."""
        hidden = F.relu(self.linear1(x))
        return self.linear2(hidden)

In [11]:
class LoRALinear(nn.Module):
    """
    Sample implementation of a linear layer with LoRA (Low-Rank Adaptation).

    - The base weight W is frozen (requires_grad=False) and is not updated.
    - Only the low-rank matrices A (r x in_features) and B (out_features x r),
      plus the optional bias, are trainable. The layer applies the weight
      W + (alpha / r) * (B @ A).
    - B is initialised to zeros, so right after construction the low-rank term
      is zero and the layer applies W alone.

    Args:
        in_features: Size of each input vector.
        out_features: Size of each output vector.
        r: Rank of the low-rank update. Must be positive.
        alpha: Numerator of the scaling factor alpha / r.
        bias: If True, add a trainable bias initialised to zeros; otherwise
            ``self.bias`` is None.

    Raises:
        ValueError: If ``r`` is not positive.
    """
    def __init__(self, in_features, out_features, r=4, alpha=1.0, bias=True):
        super().__init__()
        # Fail early with a clear message: r == 0 would otherwise surface as a
        # bare ZeroDivisionError from alpha / r, and a negative r as an error
        # raised while allocating lora_A / lora_B.
        if r <= 0:
            raise ValueError(f"LoRA rank r must be positive, got {r!r}")

        self.in_features = in_features
        self.out_features = out_features
        self.r = r
        self.alpha = alpha

        # Base weight (frozen)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.weight.requires_grad = False  # frozen: excluded from training

        # Bias (may be trained when present)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

        # Low-rank matrices for LoRA
        self.lora_A = nn.Parameter(torch.empty(r, in_features))
        self.lora_B = nn.Parameter(torch.empty(out_features, r))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
        nn.init.zeros_(self.lora_B)

        # Scaling factor applied to the low-rank update
        self.scaling = alpha / r

    def forward(self, x):
        """
        y = x @ [W + (alpha/r) * (B @ A)]^T + bias
        """
        # LoRA update term (B @ A), same shape as the base weight
        lora_update = self.lora_B @ self.lora_A
        # Effective weight actually applied to the input
        effective_weight = self.weight + self.scaling * lora_update
        return F.linear(x, effective_weight, self.bias)

In [13]:
class LoRAModel(nn.Module):
    """
    Two-layer MLP whose first layer is a LoRALinear.

    - The first layer is a LoRALinear: its base weight is frozen and only the
      low-rank matrices (and bias) are trained.
    - The second layer is an ordinary nn.Linear (it could optionally be
      replaced by a LoRA layer as well).

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output logits.
        lora_r: Rank passed to LoRALinear as ``r``.
        lora_alpha: Scaling numerator passed to LoRALinear as ``alpha``.
    """
    def __init__(self, input_dim=768, hidden_dim=512, output_dim=10,
                 lora_r=4, lora_alpha=1.0):
        super().__init__()
        first_layer = LoRALinear(
            input_dim,
            hidden_dim,
            r=lora_r,
            alpha=lora_alpha,
        )
        self.lora_linear = first_layer
        self.linear2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)

    def forward(self, x):
        """Map inputs to logits: lora_linear, then ReLU, then linear2."""
        hidden = F.relu(self.lora_linear(x))
        return self.linear2(hidden)

In [14]:
torch.manual_seed(42)

# Example sizes, deliberately on the large side.
input_dim, hidden_dim, output_dim = 768, 512, 10
batch_size = 32
num_epochs = 3  # kept small so the demo runs quickly

# (1) Baseline model (no LoRA)
baseline_model = BaselineModel(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)

# (2) Model with LoRA
#     A smaller low-rank dimension r means fewer additional parameters.
lora_r, lora_alpha = 4, 8
lora_model = LoRAModel(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
    lora_r=lora_r,
    lora_alpha=lora_alpha,
)

In [15]:
def count_params(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    ``total`` sums ``numel()`` of every parameter; ``trainable`` sums it only
    for parameters whose ``requires_grad`` is true.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

# Show the parameter counts of both models
base_counts = count_params(baseline_model)
lora_counts = count_params(lora_model)
base_total, base_train = base_counts
lora_total, lora_train = lora_counts

# One print call per model; sep="\n" puts each item on its own line.
print(
    "=== Baseline Model ===",
    f"Total params    : {base_total:,}",
    f"Trainable params: {base_train:,}",
    sep="\n",
)

print(
    "\n=== LoRA Model ===",
    f"Total params    : {lora_total:,}",
    f"Trainable params: {lora_train:,}",
    sep="\n",
)

=== Baseline Model ===
Total params    : 398,858
Trainable params: 398,858

=== LoRA Model ===
Total params    : 403,978
Trainable params: 10,762


In [16]:
def run_training_step(model, optimizer, inputs, targets):
    """Run one optimisation step on a single batch.

    Clears the gradients, runs the forward pass, computes the cross-entropy
    loss, back-propagates and applies the optimizer update.

    Args:
        model: Module mapping ``inputs`` to logits.
        optimizer: Optimizer that updates the model's parameters.
        inputs: Batch passed to ``model``.
        targets: Class labels passed to ``F.cross_entropy``.

    Returns:
        ``(logits, loss)`` as computed before the parameter update.
    """
    optimizer.zero_grad()
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)
    loss.backward()
    optimizer.step()
    return logits, loss


# Random stand-in data: one fixed batch of inputs with integer class labels.
X_dummy = torch.randn(batch_size, input_dim)
y_dummy = torch.randint(0, output_dim, (batch_size,))

# Optimizers. Only parameters with requires_grad=True are handed over, so the
# frozen LoRA base weight is not registered with the optimizer.
baseline_opt = optim.SGD(
    [p for p in baseline_model.parameters() if p.requires_grad], lr=0.01
)
lora_opt = optim.SGD(
    [p for p in lora_model.parameters() if p.requires_grad], lr=0.01
)

print("\n==== 学習開始 ====")
for epoch in range(num_epochs):
    # Each "epoch" here is a single gradient step on the same dummy batch.
    # (A) Baseline
    logits_base, loss_base = run_training_step(
        baseline_model, baseline_opt, X_dummy, y_dummy
    )

    # (B) LoRA
    logits_lora, loss_lora = run_training_step(
        lora_model, lora_opt, X_dummy, y_dummy
    )

    print(f"[Epoch {epoch+1}/{num_epochs}] "
          f"BaselineLoss={loss_base.item():.4f}, "
          f"LoRALoss={loss_lora.item():.4f}")


==== 学習開始 ====
[Epoch 1/3] BaselineLoss=2.3053, LoRALoss=2.3254
[Epoch 2/3] BaselineLoss=2.2467, LoRALoss=2.3022
[Epoch 3/3] BaselineLoss=2.1896, LoRALoss=2.2794
