In [1]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

%run 1.tokenizer.ipynb

tokenizer

<__main__.Tokenizer at 0x7f0b84260280>

In [2]:
%run 2.dataset.ipynb


def f(data):
    data = [i['text'] for i in data]
    data = tokenizer(data, device=device)

    data['labels'] = data['input_ids'].clone()
    select = data['labels'] == tokenizer.pad_token_id
    data['labels'][select] = -100

    return data


loader = get_loader(f, negative_label=False, with_answer=True)

len(loader), next(iter(loader))

(125000,
 {'input_ids': tensor([[ 0,  6, 12,  5,  6, 13, 12,  5,  3, 17,  7, 11,  7,  6,  1,  2,  2,  2,
            2,  2],
          [ 0,  4,  5,  7,  6, 15, 10,  6, 11,  4, 17, 12,  4, 10,  7,  8, 11,  6,
            1,  2],
          [ 0, 10, 12,  9, 12, 16,  7,  3,  3,  4, 17,  4,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0, 11,  5,  7,  3, 16,  4,  9,  6,  3, 17,  8,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0, 12, 10,  8,  5, 16,  7,  6,  7,  9, 17,  5,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0,  7, 11, 10,  9, 16, 10, 11,  6,  5, 17,  3,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0,  8,  6,  9,  3, 16,  5, 12,  4,  5, 17,  4,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0,  4,  9,  7,  4, 15,  9,  5,  5,  4, 17,  4,  3,  5,  3, 11,  9,  9,
            4,  1],
          [ 0, 10,  4,  9,  9, 16,  7, 12,  7, 12, 17,  4,  1,  2,  2,  2,  2,  2,
            2,  2],
          [ 0,  4,  6,  7, 11, 15,  7,  9,  4,  5, 

In [3]:
from transformers import GPTNeoXConfig, AutoModelForCausalLM

model_actor = AutoModelForCausalLM.from_config(
    GPTNeoXConfig(architectures=['GPTNeoXForCausalLM'],
                  eos_token_id=tokenizer.eos_token_id,
                  _attn_implementation_autoset=True,
                  model_type='gpt_neox',
                  vocab_size=len(tokenizer),
                  hidden_size=768,
                  num_hidden_layers=12,
                  num_attention_heads=12,
                  intermediate_size=3072)).to(device)

model_actor.config

GPTNeoXConfig {
  "_attn_implementation_autoset": true,
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "attention_bias": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.1,
  "eos_token_id": 1,
  "hidden_act": "gelu",
  "hidden_dropout": 0.0,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neox",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "partial_rotary_factor": 0.25,
  "rope_scaling": null,
  "rope_theta": 10000,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "transformers_version": "4.48.0",
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 19
}

In [4]:
optimizer = torch.optim.Adam(model_actor.parameters(), lr=1e-5)

for i, data in enumerate(loader):
    out = model_actor(**data)
    out.loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if i % 1000 == 0:
        print(i, len(loader), out.loss.item())
        prompt = data['input_ids'][0]
        eq = prompt.tolist().index(tokenizer.eq_token_id) + 1
        chosen = prompt[eq:]
        prompt = prompt[:eq]

        gen = model_actor.generate(input_ids=prompt.unsqueeze(0),
                                   min_length=-1,
                                   max_length=50,
                                   pad_token_id=tokenizer.pad_token_id,
                                   eos_token_id=tokenizer.eos_token_id,
                                   top_k=0.0,
                                   top_p=1.0,
                                   do_sample=True)
        gen = gen[0, eq:]

        print({
            'prompt': tokenizer.decode(prompt),
            'chosen': tokenizer.decode(chosen),
            'gen': tokenizer.decode(gen)
        })

model_actor.save_pretrained('model/actor')

0 125000 3.059210777282715
{'prompt': 'B4026*9607=', 'chosen': '38677782E', 'gen': '=6944+/18=26=0-E'}
1000 125000 1.900925636291504
{'prompt': 'B3481/3932=', 'chosen': '0E', 'gen': '1E'}
2000 125000 1.920478343963623
{'prompt': 'B2702-583=', 'chosen': '2119E', 'gen': '884E'}
3000 125000 1.9116497039794922
{'prompt': 'B5336/1090=', 'chosen': '4E', 'gen': '4E'}
4000 125000 1.9236778020858765
{'prompt': 'B88*6969=', 'chosen': '613272E', 'gen': '727915E'}
5000 125000 1.8582617044448853
{'prompt': 'B6262/5067=', 'chosen': '1E', 'gen': '0E'}
6000 125000 1.8719298839569092
{'prompt': 'B3118-1060=', 'chosen': '2058E', 'gen': '1213E'}
7000 125000 1.7576889991760254
{'prompt': 'B632/1122=', 'chosen': '0E', 'gen': '0E'}
8000 125000 1.796338438987732
{'prompt': 'B4698-8030=', 'chosen': '-3332E', 'gen': '-4982E'}
9000 125000 1.7663074731826782
{'prompt': 'B5663+8327=', 'chosen': '13990E', 'gen': '132899E'}
10000 125000 1.7474076747894287
{'prompt': 'B4101/8913=', 'chosen': '0E', 'gen': '0E'}
11000