In [1]:
import torch
from functions import get_loader, get_model

# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# get_loader() is unpacked into three values; only the third is kept as `loader`.
_, _, loader = get_loader()

# get_model() is unpacked into three values; only the first is kept as `model`.
model, _, _ = get_model()

# Display the model's `classifier` attribute as the cell output.
model.classifier

Linear(in_features=768, out_features=10, bias=True)

In [2]:
from peft import LoraConfig, TaskType, get_peft_model, LoftQConfig, PromptEncoderConfig, IA3Config

# Which PEFT method to apply to the model.
# Previously all three configs were assigned to `config` one after another, so
# only the last one (IA3) ever took effect. The choice is now explicit, and the
# configs that are not selected are never constructed.
PEFT_METHOD = 'ia3'  # one of: 'lora', 'p_tuning', 'ia3'


def _build_lora_config():
    """Build a LoRA config targeting the `classifier` module (LoftQ init + rsLoRA)."""
    return LoraConfig(
        task_type=TaskType.SEQ_CLS,
        inference_mode=False,
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        target_modules=['classifier'],

        # Initialisation of the A matrix. By default A is Kaiming-uniform
        # and B is all zeros.
        #init_lora_weights='gaussian',

        # Initialise the adapter weights with LoftQ (author's note: this
        # generally gives better results).
        init_lora_weights='loftq',
        loftq_config=LoftQConfig(loftq_bits=4),

        # Rank-stabilised scaling (author's note: also helps training).
        use_rslora=True,

        # DoRA is an alternative adapter structure; it cannot be combined
        # with LoftQ, so it is disabled here.
        use_dora=False,
    )


def _build_p_tuning_config():
    """Build a prompt-encoder config with 20 virtual tokens.

    NOTE(review): the original comment said this configuration is meant for
    CAUSAL_LM tasks, but the task_type passed is 'SEQ_CLS' - confirm intent.
    """
    return PromptEncoderConfig(task_type='SEQ_CLS',
                               num_virtual_tokens=20,
                               encoder_hidden_size=128)


def _build_ia3_config():
    """Build an IA3 config targeting the `classifier` module.

    Author's note: IA3 is a more aggressive approach than LoRA and has even
    fewer trainable parameters.
    """
    return IA3Config(task_type='SEQ_CLS', target_modules=['classifier'])


_CONFIG_BUILDERS = {
    'lora': _build_lora_config,
    'p_tuning': _build_p_tuning_config,
    'ia3': _build_ia3_config,
}

if PEFT_METHOD not in _CONFIG_BUILDERS:
    raise ValueError(
        f'Unknown PEFT_METHOD {PEFT_METHOD!r}; expected one of {sorted(_CONFIG_BUILDERS)}')

config = _CONFIG_BUILDERS[PEFT_METHOD]()

# NOTE: this rebinds `model` to the PEFT-wrapped model. Re-running this cell
# without first re-running the cell that loads the base model would wrap an
# already-wrapped model.
model = get_peft_model(model, config)

model.print_trainable_parameters()

# Display the (now wrapped) classifier module as the cell output.
model.classifier

trainable params: 7,700 || all params: 251,255,080 || trainable%: 0.0030646146537614285


ModulesToSaveWrapper(
  (original_module): Linear(
    (base_layer): Linear(in_features=768, out_features=10, bias=True)
    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 10x1])
  )
  (modules_to_save): ModuleDict(
    (default): Linear(
      (base_layer): Linear(in_features=768, out_features=10, bias=True)
      (ia3_l): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 10x1])
    )
  )
)

In [3]:
import datetime

# Plain training: a single pass over `loader`, timing the whole loop.

# Print loss/accuracy every LOG_EVERY steps (the original used `i % 1`, i.e. every step).
LOG_EVERY = 1

# Move the model to the target device before building the optimizer so the
# optimizer is constructed over the on-device parameters.
model.to(device)

# Make sure training-mode layers (e.g. dropout) are active.
# NOTE(review): get_model() is not visible here; if it already returns the
# model in training mode this call is a no-op.
model.train()

# Only parameters that require gradients are optimised and clipped; frozen
# parameters never receive gradients, so including them has no effect.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=1e-4)

now = datetime.datetime.now()
for i, data in enumerate(loader):
    # Build a new mapping with every tensor moved to the device instead of
    # rewriting the batch while iterating over it.
    data = {k: v.to(device) for k, v in data.items()}

    out = model(**data)
    out.loss.backward()
    # Clip the global gradient norm to 1.0 before the update.
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()
    optimizer.zero_grad()

    if i % LOG_EVERY == 0:
        labels = data['labels']
        # argmax over dim 1 of the logits gives the predicted class index per sample.
        preds = out['logits'].argmax(1)
        acc = (labels == preds).sum().item() / len(labels)

        print(i, len(loader), out.loss.item(), acc)

# Elapsed wall-clock time of the loop, shown as the cell output.
datetime.datetime.now() - now

0 62 2.4089834690093994 0.0
1 62 2.3106164932250977 0.09375
2 62 2.316065788269043 0.0625
3 62 2.363997459411621 0.09375
4 62 2.301146984100342 0.09375
5 62 2.2765536308288574 0.1875
6 62 2.3127336502075195 0.0625
7 62 2.313250780105591 0.0625
8 62 2.263867139816284 0.09375
9 62 2.257852554321289 0.03125
10 62 2.229757785797119 0.1875
11 62 2.2166881561279297 0.1875
12 62 2.1789209842681885 0.125
13 62 2.2374184131622314 0.1875
14 62 2.230839729309082 0.15625
15 62 2.2052810192108154 0.125
16 62 2.1605560779571533 0.21875
17 62 2.1335458755493164 0.25
18 62 2.146886110305786 0.1875
19 62 2.1742775440216064 0.25
20 62 2.192859411239624 0.1875
21 62 2.1038081645965576 0.28125
22 62 2.1520001888275146 0.25
23 62 2.1106300354003906 0.375
24 62 2.1247026920318604 0.34375
25 62 2.1607823371887207 0.21875
26 62 2.1247217655181885 0.15625
27 62 2.051104784011841 0.5625
28 62 2.0689423084259033 0.46875
29 62 2.037259340286255 0.59375
30 62 2.0886518955230713 0.4375
31 62 2.0399112701416016 0.53

datetime.timedelta(seconds=12, microseconds=140622)