In [1]:
import torch
import random

from util import TokenizerUtil

tokenizer = TokenizerUtil()

input_ids, attention_mask = tokenizer.encode('how are you', max_length=4)

input_ids, attention_mask, tokenizer.decode(input_ids)

  from .autonotebook import tqdm as notebook_tqdm


(tensor([   0, 9178,   32,    2]), tensor([1, 1, 1, 1]), '<s>how are</s>')

In [2]:
from datasets import load_dataset
from transformers import default_data_collator

dataset = load_dataset('json', data_files='dataset/train.json', split='train')

#2,4,4切分,取第0部分
dataset = dataset.select(range(15000))


def f(data):
    #随机生成两种回答
    if random.random() > 0.5:
        data['chosen'] = data['chosen'].swapcase()
    data = data['prompt'] + data['chosen']

    input_ids, attention_mask = tokenizer.encode(data)

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': input_ids.clone()
    }


dataset = dataset.map(f, remove_columns=dataset.column_names)

loader = torch.utils.data.DataLoader(dataset,
                                     collate_fn=default_data_collator,
                                     batch_size=2,
                                     shuffle=True,
                                     drop_last=True)

len(loader), next(iter(loader))

Map: 100%|██████████| 15000/15000 [00:06<00:00, 2422.67 examples/s]


(7500,
 {'input_ids': tensor([[    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1]]),
  'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
          [1, 1, 1,  ..., 0, 0, 0]]),
  'labels': tensor([[    0, 33837,    35,  ...,     1,     1,     1],
          [    0, 33837,    35,  ...,     1,     1,     1]])})

In [3]:
from transformers import AutoModelForCausalLM
import lora

model_actor = AutoModelForCausalLM.from_pretrained('facebook/opt-1.3b')

lora.insert(model_actor)
lora.count_params(model_actor)

{'count_require': 2.21044736, 'count_all': 14.29004288, 'ratio': 0.15468444556549854}


In [4]:
from transformers import get_scheduler
from accelerate import Accelerator


def f():
    params = []
    params_lora = []
    for name, param in model_actor.named_parameters():
        if not param.requires_grad:
            continue

        if 'lora_A' in name or 'lora_B' in name:
            params_lora.append(param)
            continue

        params.append(param)

    return [{
        'params': params,
        'weight_decay': 0.0,
    }, {
        'params': params_lora,
        'weight_decay': 0.0,
        'lr': 5e-4
    }]


optimizer = torch.optim.Adam(f(), lr=1e-3, betas=(0.9, 0.95))

scheduler = get_scheduler(name='cosine',
                          optimizer=optimizer,
                          num_warmup_steps=0,
                          num_training_steps=100)

accelerator = Accelerator(gradient_accumulation_steps=64,
                          mixed_precision='fp16')

model_actor, loader, optimizer, scheduler = accelerator.prepare(
    model_actor, loader, optimizer, scheduler)

model_actor.train()

Detected kernel version 3.10.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


OPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 2048, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
      (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Lora(
              (linear): Linear(in_features=2048, out_features=2048, bias=True)
            )
            (v_proj): Lora(
              (linear): Linear(in_features=2048, out_features=2048, bias=True)
            )
            (q_proj): Lora(
              (linear): Linear(in_features=2048, out_features=2048, bias=True)
            )
            (out_proj): Lora(
              (linear): Linear(in_features=2048, out_features=2048, bias=True)
            )
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (f

In [5]:
for i, data in enumerate(loader):
    with accelerator.accumulate(model_actor):
        out = model_actor(**data)
        accelerator.backward(out.loss)

        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(
                [i for i in model_actor.parameters() if i.requires_grad], 1.0)

        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

    if (i + 1) % 100 == 0:
        lr = optimizer.param_groups[0]['lr']
        print(i, len(loader), out.loss.item(), lr)

        logits = out.logits[0].argmax(1)
        print(tokenizer.decode(logits))

    if i == 2000:
        break

lora.merge(model_actor)
model_actor.save_pretrained('model/actor')

99 7500 10.771519660949707 0.0009997532801828658
ed.
.�ATE TABLE ( (name ( ( ( (table (name___idnables,ARCHAR(w_doubles VARCHAR, mixed_ARCHAR, AS_ CRE is the womenens who players the tour doubles were createdoned and andhang andun andi and the women? theland team series in
: CRE tableens_doubles FROM table_27753492_2 WHERE w_doubles = (1hang Nan Zhao Yunlei" AND mixed = "All England Super Series" ANDI the the</s>
199 7500 2.8840863704681396 0.00099778098230154
of_
:HumanATE TABLE "_ CITYCodeARCHAR( WHERE_ " many times waysuses are you have?
: What cityOUNT(cityISTINCT)) FROM city WHEREIin<pad><pad><pad><pad><pad>in<pad><pad><pad>inosos<pad><pad><pad>in<pad><pad><pad>a<pad><pad><pad>:<pad><pad>os<pad><pad><pad><pad>inin<pad><pad><pad><pad><pad>_<pad><pad><pad><pad>os<pad><pad>in<pad><pad><pad>:<pad><pad><pad><pad>_<pad><pad><pad><pad>i<pad>ie<pad>inidos_<pad><pad><pad><pad>os<pad><pad><pad><pad><pad>inin<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>ig<pad><pad><pad>in<pad>

599 7500 0.3007298409938812 0.0009801468428384716
<pad>: What=<pad>ATE TABLE table_name_1 (SELECT=name,ARCHAR, d_44 VARCHAR, where= What of tableIST V the 48 V the_</s>
699 7500 0.28828269243240356 0.0009755282581475768
<pad>: What=<pad>ATE TABLE table_name_1 (select__code_3,1ARCHAR, station_ARCHAR) question= What is the name number for the station<pad>?</s>
799 7500 0.2625236511230469 0.0009648882429441257
Human: context= CREATE TABLE table_name_2 (select_name_birth VARCHAR, place_ARCHAR) and V Vrank VARCHAR) question= What is the place of birth of the? " 9, 2005:? and whenatial_ is " V John Stien John.achia? Assistant:select place_of_birth FROM table_name_78 WHERE elevated = MAYmay 16, 1288" AND cardinalatial_title = "deacon of s. eustachio"</s>
899 7500 0.23500771820545197 0.0009524135262330098
Human: context= CREATE TABLE table_name ( ( (8 (score VARCHAR, class V Vteam VARCHAR) question= What many college did a New team with?? Assistant:select collegeOUNT(n) FROM TABLE_2508633_11 W