In [1]:
import os

import torch

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Execute the tokenizer notebook from this kernel. `tokenizer` is not defined
# anywhere in this notebook, so it is expected to come from that run.
%run 1.tokenizer.ipynb

# Display the tokenizer object as a quick check that the notebook above ran.
tokenizer

<__main__.Tokenizer at 0x7efcfde1c550>

In [2]:
%run 2.dataset.ipynb


def f(data):
    """Turn a list of raw samples into a tokenized batch with training labels.

    Args:
        data: iterable of samples; each sample is indexed with the key 'text'.

    Returns:
        The object returned by `tokenizer(...)`, with an added 'labels' entry:
        a copy of 'input_ids' in which every pad-token position is replaced
        by -100.
    """
    texts = [sample['text'] for sample in data]
    batch = tokenizer(texts, device=device)

    # Labels start as a copy of the inputs so the inputs themselves stay intact.
    labels = batch['input_ids'].clone()
    # -100 marks padding; presumably the model's loss skips that value (it is the
    # default ignore_index of torch.nn.CrossEntropyLoss) -- confirm in the model.
    labels[labels == tokenizer.pad_token_id] = -100
    batch['labels'] = labels

    return batch


# Build the training loader; `f` is handed to get_loader (presumably as the
# batch-collation function -- confirm in the dataset notebook).
loader = get_loader(f, negative_label=False, with_answer=True)

# Sanity check: number of batches, plus one batch pulled from a fresh iterator.
batch_count = len(loader)
peek_batch = next(iter(loader))
batch_count, peek_batch

(62500,
 {'input_ids': tensor([[ 0,  6,  6,  5,  7, 14,  7,  5, 12, 10, 18,  9,  7, 13, 13, 14, 11,  6,
            4,  4, 18,  5,  6,  9, 13, 13,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,
            2,  2,  2],
          [ 0, 10,  8,  8,  7, 14, 12,  4,  4,  4, 18,  5,  8,  8,  8,  7, 14, 10,
            9,  9,  4, 18,  6,  4, 13, 13,  7, 14,  8,  8, 11,  4, 18,  6,  9,  8,
           10,  7,  1],
          [ 0, 13, 12,  7,  8, 14, 13, 12, 11,  4, 18,  5, 13, 11,  4,  8, 14,  9,
           11, 13, 13, 18,  6,  9,  9,  4,  7, 14, 12,  9,  9,  4, 18,  7,  8,  4,
            9,  7,  1],
          [ 0,  9,  7, 11, 14, 11,  8,  6,  7, 18, 11, 13, 10,  4, 14, 11,  6, 12,
            7, 18,  5,  9,  6,  8,  7, 14, 11,  5, 12,  8, 18,  6,  6,  8,  6, 11,
            1,  2,  2],
          [ 0,  6, 10, 13,  6, 14,  8,  6, 11,  6, 18, 10, 13, 10,  8, 14,  7, 10,
            9,  9, 18,  5,  4, 10,  5, 13, 14,  6, 10,  9, 18,  5,  4, 12, 12,  8,
            1,  2,  2],
          [ 0, 11,  8,  4,  6

In [3]:
%run 3.model.ipynb

# Fresh actor built from the tokenizer length and pad token id, moved to `device`.
model_actor = Gemma3Actor(
    len(tokenizer),
    tokenizer.pad_token_id,
).to(device)
# Alternative: resume from a previously saved actor instead of starting fresh.
# NOTE(review): weights_only=False unpickles arbitrary objects -- only load trusted files.
#model_actor = torch.load('model/actor', weights_only=False).to(device)

# Parameter count expressed in units of 1e8 (hundreds of millions).
param_total = sum(p.numel() for p in model_actor.parameters())
param_total / 100_000_000

2.01442304

In [4]:
# Bounds of the cyclic learning-rate schedule.
lr_floor, lr_peak = 1e-5, 1e-4

optimizer = torch.optim.Adam(model_actor.parameters(), lr=lr_peak)

# The learning rate climbs from lr_floor to lr_peak over 1000 scheduler steps;
# `scheduler.step()` must therefore be called once per training batch.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    base_lr=lr_floor,
    max_lr=lr_peak,
    step_size_up=1000,
)

# Configuration for the supervised next-token training loop below.
NUM_STEPS = 400_000  # optimizer steps; one batch is consumed per step
LOG_INTERVAL = 1000  # report accuracy and a sample generation every N steps
CHECKPOINT_PATH = 'model/actor'


def _cycle(batches):
    """Yield items from `batches` forever, re-iterating it whenever it runs out.

    This keeps one iterator alive across training steps instead of building a
    brand-new iterator for every single batch.

    Raises:
        ValueError: if a full pass over `batches` yields nothing (otherwise the
            generator would spin forever).
    """
    while True:
        yielded_any = False
        for item in batches:
            yielded_any = True
            yield item
        if not yielded_any:
            raise ValueError('cannot cycle over an empty batch source')


def _next_token_accuracy(logits, batch):
    """Fraction of attended positions whose next token matches the greedy argmax.

    Args:
        logits: model output; argmax is taken over dim 2, so it is indexed as
            [sequence, position, vocabulary entry].
        batch: mapping holding the 'input_ids' and 'attention_mask' tensors.

    Returns:
        correct / total as a float, or NaN when no position is attended.
    """
    # The prediction made at position t is compared with the token at t + 1.
    predicted = logits.argmax(2)[:, :-1]
    target = batch['input_ids'][:, 1:]
    attended = batch['attention_mask'][:, 1:] == 1

    total = attended.sum().item()
    if total == 0:
        return float('nan')
    correct = (predicted[attended] == target[attended]).sum().item()
    return correct / total


def _prompt_through_last_eq(token_ids, eq_token_id):
    """Return `token_ids` cut just after its last `eq_token_id`, or None if absent."""
    ids = token_ids.tolist()
    for pos in range(len(ids) - 1, -1, -1):
        if ids[pos] == eq_token_id:
            return token_ids[:pos + 1]
    return None


@torch.no_grad()
def _log_progress(epoch, loss, logits, batch):
    """Print step/loss/accuracy, then a generation sampled from the first sequence.

    Runs under no_grad: nothing here is back-propagated, so autograd
    bookkeeping during generation would only cost memory and time.
    """
    print(epoch, loss.item(), _next_token_accuracy(logits, batch))

    prompt = _prompt_through_last_eq(batch['input_ids'][0],
                                     tokenizer.eq_token_id)
    if prompt is None:
        # No '=' token in this sample: skip the demo instead of aborting training.
        return

    # NOTE(review): if the model contains dropout-like layers, consider switching
    # to eval() around this call and restoring train() afterwards.
    gen = generate(model_actor, prompt.unsqueeze(0),
                   tokenizer.pad_token_id, tokenizer.eos_token_id)

    print({
        'input_ids': tokenizer.decode(prompt),
        'gen': tokenizer.decode(gen[0])
    })


# Create the checkpoint directory up front so a bad path fails now rather than
# after the whole training run has finished.
os.makedirs(os.path.dirname(CHECKPOINT_PATH) or '.', exist_ok=True)

batches = _cycle(loader)

# `epoch` counts optimizer steps (one batch each), not full passes over `loader`.
for epoch in range(NUM_STEPS):
    data = next(batches)

    loss, logits = model_actor(**data)
    loss.backward()
    optimizer.step()
    scheduler.step()  # the cyclic schedule advances once per batch
    optimizer.zero_grad()

    if epoch % LOG_INTERVAL == 0:
        _log_progress(epoch, loss, logits, data)

# The whole module object is pickled (not just a state_dict), matching the
# torch.load(..., weights_only=False) resume path; .to('cpu') also leaves
# `model_actor` on the CPU afterwards.
torch.save(model_actor.to('cpu'), CHECKPOINT_PATH)

0 3.1623668670654297 0.02850877192982456
{'input_ids': 'B2429+8531=10960+7950=', 'gen': 'B2429+8531=10960+7950=E'}
1000 1.7674411535263062 0.311804008908686
{'input_ids': 'B4050+6869=10919+152=', 'gen': 'B4050+6869=10919+152=11520E'}
2000 1.5808123350143433 0.41133004926108374
{'input_ids': 'B1530+7587=9117+4412=13529+1727=', 'gen': 'B1530+7587=9117+4412=13529+1727=15246E'}
3000 1.2982882261276245 0.5050709939148073
{'input_ids': 'B6216+7686=13902+2023=15925+1849=', 'gen': 'B6216+7686=13902+2023=15925+1849=18784E'}
4000 1.2204344272613525 0.5117493472584856
{'input_ids': 'B2999+3549=6548+2826=9374+2962=', 'gen': 'B2999+3549=6548+2826=9374+2962=12336E'}
5000 1.1998157501220703 0.5454545454545454
{'input_ids': 'B6570+4321=10891+774=11665+2902=', 'gen': 'B6570+4321=10891+774=11665+2902=14567E'}
6000 1.1052799224853516 0.5511627906976744
{'input_ids': 'B6956+3305=10261+7753=18014+4699=', 'gen': 'B6956+3305=10261+7753=18014+4699=22713E'}
7000 1.2600265741348267 0.5027472527472527
{'input_id

66000 1.0875331163406372 0.5740740740740741
{'input_ids': 'B7440+6184=', 'gen': 'B7440+6184=13624+1000=14624+1000=15624E'}
67000 1.083970546722412 0.583710407239819
{'input_ids': 'B1429+2456=3885+6846=', 'gen': 'B1429+2456=3885+6846=10731+191=10922E'}
68000 1.133447289466858 0.5418848167539267
{'input_ids': 'B5017+5483=10500+6184=16684+1244=', 'gen': 'B5017+5483=10500+6184=16684+1244=17928E'}
69000 1.1150791645050049 0.5563380281690141
{'input_ids': 'B1797+9340=11137+660=11797+9845=', 'gen': 'B1797+9340=11137+660=11797+9845=21642E'}
70000 1.124192237854004 0.5748898678414097
{'input_ids': 'B5267+291=5558+6128=11686+5159=', 'gen': 'B5267+291=5558+6128=11686+5159=16845E'}
71000 1.0974736213684082 0.5684931506849316
{'input_ids': 'B5266+7827=13093+2893=', 'gen': 'B5266+7827=13093+2893=15986+6069=22055E'}
72000 1.105100154876709 0.555045871559633
{'input_ids': 'B9676+3249=12925+537=', 'gen': 'B9676+3249=12925+537=13462E'}
73000 1.115188717842102 0.5653104925053534
{'input_ids': 'B1910+8931

130000 1.1356825828552246 0.5463659147869674
{'input_ids': 'B7199+9517=', 'gen': 'B7199+9517=16716+1718=18434E'}
131000 1.1292095184326172 0.5630026809651475
{'input_ids': 'B2014+9437=', 'gen': 'B2014+9437=11451+6662=18113E'}
132000 1.1572033166885376 0.5353260869565217
{'input_ids': 'B1052+1357=2409+8460=', 'gen': 'B1052+1357=2409+8460=10869E'}
133000 1.1186416149139404 0.5746835443037974
{'input_ids': 'B4532+7559=', 'gen': 'B4532+7559=12091+1000=13091E'}
134000 1.0913532972335815 0.5532407407407407
{'input_ids': 'B1987+8273=', 'gen': 'B1987+8273=10260+120=10380E'}
135000 1.1195849180221558 0.5531914893617021
{'input_ids': 'B6567+319=6886+172=7058+4422=', 'gen': 'B6567+319=6886+172=7058+4422=11480E'}
136000 1.0659250020980835 0.5949367088607594
{'input_ids': 'B4257+8334=12591+8273=', 'gen': 'B4257+8334=12591+8273=20864+4639=25503E'}
137000 1.1224018335342407 0.5467289719626168
{'input_ids': 'B3928+6497=10425+6959=17384+1646=', 'gen': 'B3928+6497=10425+6959=17384+1646=19030E'}
138000 1

{'input_ids': 'B3794+8955=12749+146=', 'gen': 'B3794+8955=12749+146=12895+3587=16482E'}
196000 1.0691380500793457 0.5926680244399185
{'input_ids': 'B9696+7353=17049+6860=23909+6400=', 'gen': 'B9696+7353=17049+6860=23909+6400=30309E'}
197000 1.1093517541885376 0.5772946859903382
{'input_ids': 'B1997+6881=8878+1896=', 'gen': 'B1997+6881=8878+1896=10774E'}
198000 1.098071575164795 0.5704989154013015
{'input_ids': 'B1871+5797=7668+9976=17644+5370=', 'gen': 'B1871+5797=7668+9976=17644+5370=23014E'}
199000 1.1555887460708618 0.521865889212828
{'input_ids': 'B2984+9413=', 'gen': 'B2984+9413=12397+3222=15619+3222=18841E'}
200000 1.1102302074432373 0.5527638190954773
{'input_ids': 'B9807+8303=', 'gen': 'B9807+8303=18110+8406=26516E'}
201000 1.082602858543396 0.5668202764976958
{'input_ids': 'B8801+7143=', 'gen': 'B8801+7143=15944+3333=19277E'}
202000 1.1152782440185547 0.5591647331786543
{'input_ids': 'B7892+4177=', 'gen': 'B7892+4177=12069+171=12240+1711=13951E'}
203000 1.127286434173584 0.563

259000 1.0679231882095337 0.5848214285714286
{'input_ids': 'B5221+3977=9198+6073=', 'gen': 'B5221+3977=9198+6073=15271+9220=24491E'}
260000 1.106353759765625 0.5675675675675675
{'input_ids': 'B7377+9003=16380+3651=20031+5929=', 'gen': 'B7377+9003=16380+3651=20031+5929=25960E'}
261000 1.1114585399627686 0.55
{'input_ids': 'B9734+3217=12951+3744=16695+7645=', 'gen': 'B9734+3217=12951+3744=16695+7645=24340E'}
262000 1.116966724395752 0.5555555555555556
{'input_ids': 'B1732+1299=', 'gen': 'B1732+1299=3031+488=3519E'}
263000 1.0746439695358276 0.5763747454175153
{'input_ids': 'B5366+4978=10344+5449=15793+1084=', 'gen': 'B5366+4978=10344+5449=15793+1084=16877E'}
264000 1.131766676902771 0.5477386934673367
{'input_ids': 'B8367+428=8795+6202=14997+4523=', 'gen': 'B8367+428=8795+6202=14997+4523=19520E'}
265000 1.1356009244918823 0.5390428211586902
{'input_ids': 'B6556+4074=', 'gen': 'B6556+4074=10630+9148=19778E'}
266000 1.0780097246170044 0.5704989154013015
{'input_ids': 'B1841+4855=6696+2701=

323000 1.1274486780166626 0.5570291777188329
{'input_ids': 'B9160+5297=14457+5666=', 'gen': 'B9160+5297=14457+5666=20123E'}
324000 1.1025716066360474 0.5454545454545454
{'input_ids': 'B3161+438=3599+793=', 'gen': 'B3161+438=3599+793=4392E'}
325000 1.1625964641571045 0.5294117647058824
{'input_ids': 'B4278+3888=8166+2256=10422+3351=', 'gen': 'B4278+3888=8166+2256=10422+3351=13773E'}
326000 1.0847289562225342 0.5505882352941176
{'input_ids': 'B1059+3696=', 'gen': 'B1059+3696=4755+7239=11994+7239=19233E'}
327000 1.0992380380630493 0.5619266055045872
{'input_ids': 'B3772+4682=8454+2785=', 'gen': 'B3772+4682=8454+2785=11239E'}
328000 1.1313751935958862 0.5390625
{'input_ids': 'B7910+7700=15610+2703=', 'gen': 'B7910+7700=15610+2703=18313E'}
329000 1.111573338508606 0.5518072289156627
{'input_ids': 'B8418+1594=', 'gen': 'B8418+1594=10012+274=10286E'}
330000 1.1050945520401 0.5490196078431373
{'input_ids': 'B8624+6966=15590+3957=', 'gen': 'B8624+6966=15590+3957=19547+5237=24784E'}
331000 1.091

{'input_ids': 'B9057+4968=14025+5379=', 'gen': 'B9057+4968=14025+5379=19404+721=20125E'}
387000 1.1004130840301514 0.5688073394495413
{'input_ids': 'B1572+3707=', 'gen': 'B1572+3707=5279+9978=15257E'}
388000 1.12783944606781 0.5413333333333333
{'input_ids': 'B8940+1347=10287+3107=', 'gen': 'B8940+1347=10287+3107=13394E'}
389000 1.089949607849121 0.5814977973568282
{'input_ids': 'B1813+8038=', 'gen': 'B1813+8038=9851+8139=17990E'}
390000 1.1120356321334839 0.5440806045340051
{'input_ids': 'B117+9431=', 'gen': 'B117+9431=9548+4805=14353E'}
391000 1.0887153148651123 0.579064587973274
{'input_ids': 'B5658+7003=12661+8934=', 'gen': 'B5658+7003=12661+8934=21595+4150=25745E'}
392000 1.1123192310333252 0.5580246913580247
{'input_ids': 'B6936+772=7708+3518=', 'gen': 'B6936+772=7708+3518=11226E'}
393000 1.128913402557373 0.5549872122762148
{'input_ids': 'B3907+2256=', 'gen': 'B3907+2256=6163+1802=7965E'}
394000 1.1159926652908325 0.5456790123456791
{'input_ids': 'B5966+4534=', 'gen': 'B5966+4534