In [18]:
import torch
from functions import get_loader, get_model

# Run on the GPU when one is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed torch's global RNG up front so the random state consumed later
# (loader shuffling, if any, and adapter initialization) is reproducible
# under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Project-local helpers from `functions`; only the third loader and the model
# are used below.
# NOTE(review): which split the third loader holds isn't visible here — confirm.
_, _, loader = get_loader()
model, _, _ = get_model()

# Inspect the classification head.
model.classifier

Linear(in_features=768, out_features=10, bias=True)

In [19]:
from peft import LoraConfig, LoftQConfig, PeftModel, TaskType, get_peft_model

# get_peft_model() injects adapter layers into the base model in place and
# returns a PeftModel wrapper. Re-running this cell on an already-wrapped
# `model` would stack a second adapter on top, so start from a freshly loaded
# base model in that case to keep the cell idempotent.
if isinstance(model, PeftModel):
    model, _, _ = get_model()

config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    # NOTE(review): for SEQ_CLS the classifier head ends up wrapped in a
    # ModulesToSaveWrapper (see this cell's printed output). That means it is
    # trained as a full copy anyway, so a LoRA adapter on the same layer adds
    # little and the backbone stays entirely frozen. TODO: consider targeting
    # backbone layers instead; their names depend on the architecture
    # returned by get_model().
    target_modules=['classifier'],

    # Initialization scheme for the A matrix. By default A uses Kaiming-uniform
    # initialization and B is all zeros.
    # init_lora_weights='gaussian',

    # LoftQ initialization; the original note says it usually gives better results.
    # NOTE(review): LoftQ is aimed at quantized base models — confirm it helps here.
    init_lora_weights='loftq',
    loftq_config=LoftQConfig(loftq_bits=4),

    # Rank-stabilized scaling (rsLoRA), another training-quality tweak.
    use_rslora=True,

    # DoRA: an alternative adapter-layer structure; per the original note it
    # cannot be combined with LoftQ.
    use_dora=False,
)

# Alternatives (kept for reference):
# - P-tuning, suited to CAUSAL_LM tasks:
#     from peft import PromptEncoderConfig
#     config = PromptEncoderConfig(task_type='CAUSAL_LM', num_virtual_tokens=20, encoder_hidden_size=128)
# - IA3, more aggressive than LoRA, with even fewer trainable parameters:
#     from peft import IA3Config
#     config = IA3Config(task_type='SEQ_CLS', target_modules=['classifier'])
# NOTE(review): the saved output of this cell shows `ia3_l` parameters, so it
# was produced with the IA3 config rather than the LoRA config above — re-run
# the cell to refresh it.

model = get_peft_model(model, config)

model.print_trainable_parameters()

model.classifier

trainable params: 7,700 || all params: 251,255,080 || trainable%: 0.0030646146537614285


ModulesToSaveWrapper(
  (original_module): Linear(
    (base_layer): Linear(in_features=768, out_features=10, bias=True)
    (ia3_l): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 10x1])
  )
  (modules_to_save): ModuleDict(
    (default): Linear(
      (base_layer): Linear(in_features=768, out_features=10, bias=True)
      (ia3_l): ParameterDict(  (default): Parameter containing: [torch.FloatTensor of size 10x1])
    )
  )
)

In [20]:
import datetime

# Single pass of fine-tuning over `loader`.
LOG_EVERY = 1  # print loss / batch accuracy every N steps

# Move the model to the target device *before* constructing the optimizer,
# as the torch.optim documentation recommends.
model.to(device)
# Enable training mode so dropout (including lora_dropout) is active.
# NOTE(review): get_model() presumably loads via HF from_pretrained, which
# returns the model in eval mode — confirm; train() is harmless either way.
model.train()

# Optimize only parameters that still require grad (the PEFT adapter and any
# modules_to_save copies; see the trainable-parameter count printed earlier).
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4)

now = datetime.datetime.now()  # wall-clock start time of the loop
for i, data in enumerate(loader):
    # Move every tensor in the batch to the model's device. This is done in
    # place, so `data` keeps its original mapping type.
    for k, v in data.items():
        data[k] = v.to(device)
    out = model(**data)
    out.loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()

    if i % LOG_EVERY == 0:
        labels = data['labels']
        preds = out.logits.argmax(1)  # predicted class ids
        acc = (labels == preds).sum().item() / len(labels)

        print(i, len(loader), out.loss.item(), acc)

# Elapsed wall-clock time of the loop (last expression, so it is displayed).
datetime.datetime.now() - now

0 62 2.2574074268341064 0.15625
1 62 2.2835590839385986 0.125
2 62 2.265936851501465 0.125
3 62 2.256568670272827 0.125
4 62 2.2342605590820312 0.125
5 62 2.2471179962158203 0.15625
6 62 2.2529959678649902 0.09375
7 62 2.2177746295928955 0.09375
8 62 2.17441725730896 0.1875
9 62 2.1450140476226807 0.25
10 62 2.1880834102630615 0.28125
11 62 2.166544198989868 0.1875
12 62 2.2038490772247314 0.0625
13 62 2.118818521499634 0.34375
14 62 2.136948347091675 0.3125
15 62 2.128154993057251 0.28125
16 62 2.189143419265747 0.28125
17 62 2.163454532623291 0.21875
18 62 2.1524388790130615 0.21875
19 62 2.2027273178100586 0.25
20 62 2.018207550048828 0.46875
21 62 2.075979232788086 0.40625
22 62 2.0548763275146484 0.46875
23 62 2.042815685272217 0.5
24 62 2.10168194770813 0.3125
25 62 2.0754051208496094 0.34375
26 62 1.9900621175765991 0.625
27 62 2.017026424407959 0.53125
28 62 2.044901132583618 0.46875
29 62 2.0529961585998535 0.375
30 62 1.984081745147705 0.71875
31 62 2.0559780597686768 0.375
3

datetime.timedelta(seconds=11, microseconds=527618)