In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils

# Project-local modules (自作ライブラリ)
import tokenizer as tk
import mini_transformer as mini

In [2]:
# --- Training configuration -------------------------------------------------
# NOTE(review): of these values, only device, learning_rate, max_iters,
# eval_interval and eval_iters are referenced by the cells in this notebook.
# batch_size, block_size, n_embd, n_head, n_layer and dropout are NOT passed to
# tk.Tokenizer() or mini.GPTLanguageModel(...) — confirm those modules really
# read these values (they may hold their own copies, in which case editing them
# here has no effect).
batch_size = 64 # how many independent sequences will we process in parallel?
block_size = 256 # what is the maximum context length for predictions?
max_iters = 5000  # total optimizer steps in the training loop
eval_interval = 500  # run estimate_loss() every this many steps
learning_rate = 3e-4  # AdamW learning rate
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
eval_iters = 200  # batches averaged per split when estimating the loss
n_embd = 384
n_head = 6
n_layer = 6
dropout = 0.2
seed = 1337  # fixed seed so weight init / dropout (and batch sampling, if it uses torch's RNG) repeat

# Seed torch's global RNG before the model is built so runs are reproducible.
torch.manual_seed(seed)

In [3]:
# Read the whole training corpus into a single string (open() defaults to text mode 'r').
with open('input.txt', encoding='utf-8') as src_file:
    text = src_file.read()

In [4]:
# Build the tokenizer's vocabulary from the corpus text.
# NOTE(review): presumably a character-level vocabulary (the model below is
# built from `tokenizer.chars`); confirm in tokenizer.py.
tokenizer = tk.Tokenizer()
tokenizer.set_list(text)

In [5]:
# Split the corpus into training and held-out portions.
# NOTE(review): train_data / test_data are not used by any other cell here —
# tokenizer.get_batch(split) takes only a split name, so it appears to sample
# from data kept on the tokenizer. Confirm whether this call is what populates
# that internal state before removing or renaming it.
train_data, test_data = tokenizer.train_test_split(text)

In [6]:
# Instantiate the GPT model over the tokenizer's vocabulary, then move its
# parameters to the configured device (nn.Module.to returns the module itself).
model = mini.GPTLanguageModel(tokenizer.chars)
model = model.to(device)

In [7]:
@torch.no_grad()
def estimate_loss(net=None, batches=None, n_iters=None):
    """Average the loss over several sampled batches for the 'train' and 'val' splits.

    Args:
        net: model to evaluate; called as ``net(X, Y)`` and must return
            ``(logits, loss)``. Defaults to the notebook-global ``model``.
        batches: object whose ``get_batch(split)`` returns an ``(X, Y)`` pair.
            Defaults to the notebook-global ``tokenizer``.
        n_iters: number of batches averaged per split. Defaults to the
            notebook-global ``eval_iters``.

    Returns:
        dict mapping 'train' and 'val' to a 0-dim float tensor with the mean loss.

    The model is put in eval mode for the measurement and is always switched
    back to train mode afterwards — also when an exception or KeyboardInterrupt
    aborts the evaluation part-way through.
    """
    # Resolve defaults at call time so the notebook globals are picked up as they are now.
    if net is None:
        net = model
    if batches is None:
        batches = tokenizer
    if n_iters is None:
        n_iters = eval_iters

    out = {}
    net.eval()
    try:
        for split in ['train', 'val']:
            losses = torch.zeros(n_iters)
            for k in range(n_iters):
                X, Y = batches.get_batch(split)
                _logits, loss = net(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean()
    finally:
        # Without this, interrupting an evaluation leaves the model in eval mode
        # (e.g. dropout disabled) for any training that follows in this kernel.
        net.train()
    return out

In [13]:
# print the number of parameters in the model
print(sum(p.numel() for p in model.parameters())/1e6, 'M parameters')

# create a PyTorch optimizer
# NOTE: re-running this cell creates a fresh optimizer (AdamW moment estimates
# reset) but keeps the current model weights, so training resumes from them.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# NOTE(review): the captured output shows one "<len> (<batch>,)" line per
# get_batch call — that debug print lives in the tokenizer module and floods
# the log during evaluation; remove it there.
# Loop variable is `step`, not `iter`: a global named `iter` would shadow the
# builtin iter() for the rest of the notebook session.
for step in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if step % eval_interval == 0 or step == max_iters - 1:
        losses = estimate_loss()
        print(f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = tokenizer.get_batch('train')

    # evaluate the loss (logits are not needed for the update; `_logits` rather
    # than `_` so IPython's last-output variable is not clobbered)
    _logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

10.788929 M parameters
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64,)
1003597 (64

KeyboardInterrupt: 

In [12]:
# generate from the model
# Sample in inference mode: eval() turns dropout off and no_grad() avoids
# building an autograd graph over the sampling steps. The previous train/eval
# mode is restored afterwards so a later training cell is unaffected.
# NOTE(review): `max_new_token` (singular) is the keyword this cell used when it
# last produced output; confirm against mini_transformer's generate() signature.
context = torch.zeros((1, 1), dtype=torch.long, device=device)  # one start token, id 0
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        generated = model.generate(context, max_new_token=500)
finally:
    model.train(was_training)
print(tokenizer.decoder(generated[0].tolist()))


ySoh,' ce ttye oeatRe thraD Iflg:a,,sgnptOeamhoahereo  nene s, i
ameN i oudhdr  es rU frrNuorH,sr,yrHonisfsdh,m eni  w Lg
qthidtthdeh andha'h as th  .Nt m
He. anpg
Iedoauim u's sp,  ss
lndtnwhatnpthinimwo adrfutthe  t
 Cl heyl KhcS scon
Ktye a lhtsohrieith r,cnGn aeuer sD.ihein haveI a
Iatyo w
:enh  numgt iOithaDnthdrN ev?N
 iyTe  sIf nl
he f eaIcRoe head:
hhwEr.c
AiAn:vsd na
d haemetc  w
D,se dhEone,
  ay:
ena  oAtFRuora
oe e :  a ;saIPte CInerlnthhauoisre,inde nRo a: n oo all,wovinc ddlGinoter
