In [1]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

%run 1.tokenizer.ipynb

tokenizer

<__main__.Tokenizer at 0x7f40f525c280>

In [2]:
# Execute the dataset notebook inside this kernel. `get_loader`, called further
# down in this cell, is presumably defined there -- TODO confirm.
%run 2.dataset.ipynb


def f(data):
    """Collate a list of samples into a single batch for the critic model.

    Args:
        data: Sequence of samples; each sample is indexed with ``'text'`` and
            ``'label'``.

    Returns:
        The object returned by ``tokenizer(texts, device=device)`` with an
        additional ``'labels'`` entry: the sample labels as a float32 tensor
        placed on ``device``.
    """
    texts = [sample['text'] for sample in data]
    labels = [sample['label'] for sample in data]

    # Use a separate name for the batch rather than rebinding the `data`
    # parameter, so the input samples stay distinguishable from the output.
    batch = tokenizer(texts, device=device)

    # torch.tensor(...) supersedes the legacy torch.FloatTensor constructor and
    # creates the float32 labels directly on the target device.
    batch['labels'] = torch.tensor(labels, dtype=torch.float32, device=device)

    return batch


# Build the loader with `f` as the first (collate) argument; both flags are
# passed through to get_loader unchanged.
loader = get_loader(
    f,
    negative_label=True,
    with_answer=True,
)

# Cell output: the loader's length together with the first batch it yields.
(len(loader), next(iter(loader)))

(125000,
 {'input_ids': tensor([[ 0,  7,  3,  5, 13,  7,  5,  7,  4, 17,  7,  9,  7,  6,  1,  2,  2,  2,
            2,  2],
          [ 0, 10,  3,  9,  6, 16, 11,  9,  5,  7, 17,  3,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0, 11,  6,  7,  8, 13,  8,  5,  4, 10, 17,  9,  5,  4, 10,  1,  2,  2,
            2,  2],
          [ 0, 10,  7,  8, 12, 15,  6, 11,  9,  5, 17,  5, 11, 11,  3,  9,  9,  8,
           11,  1],
          [ 0,  6, 10,  9, 11, 15, 12, 11,  4, 10, 17,  6,  9, 12, 12,  3,  7,  8,
            9,  1],
          [ 0, 12,  4,  9,  8, 16, 11, 10,  9, 11, 17,  4,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0,  9,  5,  3, 12, 16,  4,  4,  8, 12, 17,  8,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0, 11,  9,  9, 10, 15,  4,  8, 12, 17,  4,  6, 10, 11,  3,  8,  6,  1,
            2,  2],
          [ 0, 12, 12,  7, 14,  8,  4, 12,  8, 17, 14,  7,  5,  3,  4,  1,  2,  2,
            2,  2],
          [ 0, 12,  4,  9,  8, 16,  5,  7,  8,  4, 

In [3]:
from transformers import GPTNeoXConfig, AutoModelForSequenceClassification

# Configuration for the critic: a GPT-NeoX backbone (12 layers, 12 heads,
# hidden size 768) with a single-output sequence-classification head.
critic_config = GPTNeoXConfig(
    # Leading-underscore (internal) flag, left exactly as it was so attention
    # backend selection is unchanged -- NOTE(review): confirm it is intentional.
    _attn_implementation_autoset=True,
    # Name the class that is actually built below; the previous value
    # ('GPTNeoXForCausalLM') did not match a sequence-classification model.
    architectures=['GPTNeoXForSequenceClassification'],
    eos_token_id=tokenizer.eos_token_id,
    id2label={'0': 'LABEL_0'},
    label2id={'LABEL_0': 0},
    hidden_size=768,
    intermediate_size=3072,
    num_attention_heads=12,
    num_hidden_layers=12,
    # Vocabulary size and padding id both come from the project tokenizer.
    vocab_size=len(tokenizer),
    pad_token_id=tokenizer.pad_token_id,
    num_labels=1,
)

# from_config builds the model from the configuration alone (no checkpoint is
# referenced here); the result is then moved to the selected device.
model_critic = AutoModelForSequenceClassification.from_config(critic_config).to(device)

# Cell output: the resolved configuration, including library defaults.
model_critic.config

GPTNeoXConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "attention_bias": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.1,
  "eos_token_id": 1,
  "hidden_act": "gelu",
  "hidden_dropout": 0.0,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0
  },
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neox",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 2,
  "partial_rotary_factor": 0.25,
  "rope_scaling": null,
  "rope_theta": 10000,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "transformers_version": "4.48.0",
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 19
}

In [4]:
# Adam over every critic parameter with a fixed learning rate of 1e-5.
optimizer = torch.optim.Adam(model_critic.parameters(), lr=1e-5)

# One pass over the loader: one optimizer step per batch.
for i, data in enumerate(loader):
    # The batch entries are passed as keyword arguments; the loss used for
    # the backward pass is read from the model output.
    out = model_critic(**data)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Progress is reported on every 1000th batch only (batch 0 included).
    if i % 1000 != 0:
        continue

    # Threshold the single output at 0.5 to obtain 0/1 predictions, then
    # compare them with the integer-cast labels of the current batch.
    logits = (out.logits > 0.5).squeeze(1).long()
    targets = data['labels'].long()
    acc = (logits == targets).sum() / len(targets)
    print(i, len(loader), out.loss.item(), acc.item())

# The model is written to disk once, after the loop has finished.
model_critic.save_pretrained('model/critic')

0 125000 0.34103602170944214 0.5625
1000 125000 0.20233052968978882 0.6875
2000 125000 0.18607331812381744 0.75
3000 125000 0.1396457999944687 0.6875
4000 125000 0.10160748660564423 0.875
5000 125000 0.13328039646148682 0.875
6000 125000 0.126186341047287 0.8125
7000 125000 0.11202022433280945 0.8125
8000 125000 0.11786967515945435 0.6875
9000 125000 0.1412677764892578 0.8125
10000 125000 0.21582680940628052 0.625
11000 125000 0.1226201206445694 0.75
12000 125000 0.16518059372901917 0.8125
13000 125000 0.05162348598241806 0.875
14000 125000 0.08120504021644592 0.875
15000 125000 0.18136709928512573 0.75
16000 125000 0.11339117586612701 0.875
17000 125000 0.06583108007907867 0.9375
18000 125000 0.15697476267814636 0.75
19000 125000 0.06264254450798035 0.875
20000 125000 0.17975997924804688 0.8125
21000 125000 0.160003662109375 0.8125
22000 125000 0.0356559231877327 1.0
23000 125000 0.1474703997373581 0.75
24000 125000 0.09269499778747559 0.875
25000 125000 0.1079915314912796 0.875
26000