In [2]:
from copy import deepcopy

from petorch.prebuilt.lora import LoraLinear
from torch import nn
import torch
from petorch.adapter import AdapterAPI, BaseAdaptedModelConfig, BaseAdapter
from pydantic import PositiveInt, NonNegativeFloat
from typing import cast

In [3]:
class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_shape = (8,8)
        self.conv = nn.Conv2d(3,3,3,padding='same')
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(8*8*3, 256)
        self.fc2 = nn.Linear(256, 100)
        self.fc3 = nn.Linear(100, 16)

    def forward(self, input: torch.Tensor):
        assert input.shape[2:] == self.input_shape, f"{input.shape[:2]}-{self.input_shape}"
        x = self.conv(input)
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.fc2(x)
        x = self.fc3(x)
        return x

class LoraLinearConfig(BaseAdaptedModelConfig):
    rank: PositiveInt = 8
    alpha: PositiveInt = 16
    dropout: NonNegativeFloat = 0.1


    def dispatch_adapter(self, fpname: str, base_layer: nn.Module, *args, **kwargs)-> BaseAdapter | None:
        if isinstance(base_layer,nn.Linear):
            return LoraLinear(cast(nn.Linear,base_layer),self)
model = Dummy()
org_model = deepcopy(model)
config = LoraLinearConfig()
sample = torch.rand([2,3,8,8])
output = org_model(sample)
org_model

Dummy(
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=192, out_features=256, bias=True)
  (fc2): Linear(in_features=256, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=16, bias=True)
)

In [4]:
print(AdapterAPI.add_adapter(model, config))
model

['fc1', 'fc2', 'fc3']


Dummy(
  (conv): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1), padding=same)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): AdaptedLayer(
    (base_layer): Linear(in_features=192, out_features=256, bias=True)
    (active_adapters): ModuleDict()
    (non_active_adapters): ModuleDict(
      (default): LoraLinear(
        (lora_A): Linear(in_features=192, out_features=8, bias=True)
        (lora_B): Linear(in_features=8, out_features=256, bias=True)
        (lora_dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc2): AdaptedLayer(
    (base_layer): Linear(in_features=256, out_features=100, bias=True)
    (active_adapters): ModuleDict()
    (non_active_adapters): ModuleDict(
      (default): LoraLinear(
        (lora_A): Linear(in_features=256, out_features=8, bias=True)
        (lora_B): Linear(in_features=8, out_features=100, bias=True)
        (lora_dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (fc3): AdaptedLayer(
    (base_layer): Linear(in_features

In [12]:
output2 = model(sample)
assert torch.all(output==output2)

In [13]:
try:
    AdapterAPI.activate_adapter(model, 'abc')
except ValueError as e:
    print(e)

Model does not have adapter named `abc`.


In [14]:
AdapterAPI.activate_adapter(model)

['default']

In [17]:
output3 = model(sample)
assert not torch.all(output==output3)

In [8]:
AdapterAPI.remove_adapter(model, "default")

In [9]:
output4 = model(sample)
torch.all(output==output4)

tensor(True)

In [10]:
from copy import deepcopy
from typing import cast

import pytest
import torch
from pydantic import PositiveInt, NonNegativeFloat, BaseModel
from torch import nn

from petorch.adapter import BaseAdaptedModelConfig, BaseAdapter, AdapterAPI


class DummyAdapter(BaseAdapter):
    """
    Dummy LinearLora
    """
    def __init__(self, base_layer: nn.Linear, config: "BaseAdaptedModelConfig"):
        assert isinstance(
            base_layer, nn.Linear
        ), f"Base layer must has type {nn.Linear}, got {type(base_layer)}."
        super().__init__(base_layer, config)

        self.lora_A = nn.Linear(base_layer.in_features, self.rank)
        self.lora_B = nn.Linear(self.rank, base_layer.out_features)
        self.lora_dropout = nn.Dropout(self.dropout)

        self.scale = getattr(self.config, "scale", None) or 1

    @property
    def rank(self) -> int:
        return self.config.rank

    @property
    def alpha(self) -> float:
        return self.config.alpha

    @property
    def dropout(self) -> float:
        return self.config.dropout

    @property
    def scaling(self) -> float:
        return self.scale * self.alpha / self.rank

    @classmethod
    def pre_validate_config(cls, config: "BaseAdaptedModelConfig") -> None:
        class ConfigValidator(BaseModel):
            rank: PositiveInt
            alpha: PositiveInt
            dropout: NonNegativeFloat
            adapter_name: str

        ConfigValidator.model_validate(config, from_attributes=True)

    def forward(self, batch_input: torch.Tensor, **kwargs) -> torch.Tensor:
        output = self.base_layer(batch_input)
        return (
            output
            + self.lora_B(self.lora_A(self.lora_dropout(batch_input))) * self.scaling
        )

class DummySubModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8 * 8 * 3, 256)
        self.fc2 = nn.Linear(256, 100)

    def forward(self, input: torch.Tensor):
        x = self.fc1(input)
        x = self.fc2(x)
        return x

class Dummy(nn.Module):
    def __init__(self):
        super().__init__()
        self.input_shape = (8, 8)
        self.conv = nn.Conv2d(3, 3, 3, padding="same")
        self.flatten = nn.Flatten()
        self.sub_model = DummySubModel()
        self.fc = nn.Linear(100, 16)

    def forward(self, input: torch.Tensor):
        assert input.shape[2:] == self.input_shape,  f"{input.shape[:2]}-{self.input_shape}"
        x = self.conv(input)
        x = self.flatten(x)
        x = self.sub_model(x)
        x = self.fc(x)
        return x


class DummyConfig(BaseAdaptedModelConfig):
    rank: PositiveInt = 8
    alpha: PositiveInt = 16
    dropout: NonNegativeFloat = 0.1

    def dispatch_adapter(
        self, fpname: str, base_layer: nn.Module, *args, **kwargs
    ) -> BaseAdapter | None:
        if isinstance(base_layer, nn.Linear):
            return DummyAdapter(cast(nn.Linear, base_layer), self)


def test_api():
    adapter_name = "test_adapter"
    model = Dummy()
    org_model = deepcopy(model)
    config = DummyConfig(adapter_name=adapter_name)
    sample = torch.rand([2, 3, 8, 8])

    output1 = org_model(sample)

    # Test add_adapter
    target_fqn = [
        "fc",
        "sub_model.fc1",
        "sub_model.fc2"
    ]
    target_fqn.sort()
    fqn = AdapterAPI.add_adapter(model, config)
    fqn.sort()
    assert fqn == target_fqn

test_api()


In [10]:
from transformers import Qwen2ForCausalLM, TorchAoConfig
from torchao.dtypes import NF4Tensor, to_nf4
from torchao.quantization import register_quantize_module_handler, Float8WeightOnlyConfig, ModuleFqnToConfig
from dataclasses import dataclass
from torchao.core.config import AOBaseConfig
import torch
from torch import nn
import types
from torchao.utils import get_model_size_in_bytes
@dataclass
class NF4Config(AOBaseConfig):
    block_size: int = 64
    scaler_block_size: int = 256

def linear_module_repr(module: nn.Linear):
    return f"in_features={module.weight.shape[1]}, out_features={module.weight.shape[0]}, weight={module.weight}, dtype={module.weight.dtype}"

@register_quantize_module_handler(NF4Config)
def _nf4_weight_only_transform(
    module: torch.nn.Module,
    config: NF4Config,
) -> torch.nn.Module:
    new_weight = to_nf4(module.weight, config.block_size, config.scaler_block_size)
    module.weight = nn.Parameter(new_weight, requires_grad=False) # Freeze
    module.extra_repr = types.MethodType(
        linear_module_repr,
        module
    )
    return module

config = TorchAoConfig(NF4Config())

model = quantized_model = Qwen2ForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
)

quantized_model = Qwen2ForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B-Instruct",
    quantization_config = config
)
model_size = get_model_size_in_bytes(model)
quantized_model_size = get_model_size_in_bytes(quantized_model)
print(model_size) # 2520669824
print(quantized_model_size) # 1273966688
print(quantized_model_size/model_size) # 0.5054079974577425

2520669824
1273966688
0.5054079974577425
