# In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader

from mingpt.utils import set_seed
# Seed the random number generators (via mingpt's helper) so runs are repeatable.
set_seed(3407)


import pickle
import zlib


class SortDataset(Dataset):
    """Synthetic dataset for learning to sort short digit sequences.

    Every example is built from ``length`` random digits drawn from
    ``range(num_digits)`` concatenated with the same digits in ascending
    order.  The concatenation is then shifted by one position to form a
    next-token prediction pair ``(x, y)``.

    Examples are generated on the fly; a checksum of the unsorted digits
    decides whether a given sequence belongs to the train or the test split,
    so the two splits never share an input sequence.
    """

    def __init__(self, split, length=6, num_digits=3):
        """Create one split of the dataset.

        Args:
            split: Either ``"train"`` or ``"test"``.
            length: Number of digits in the unsorted input sequence.
            num_digits: Size of the digit alphabet (digits are ``0..num_digits-1``).

        Raises:
            AssertionError: If ``split`` is not one of the two known splits.
        """
        assert split in {"train", "test"}, f"unknown split: {split!r}"
        self.split = split
        self.length = length
        self.num_digits = num_digits

    def __len__(self):
        """Return the nominal dataset size (examples are sampled, not stored)."""
        return 10000

    def get_vocab_size(self):
        """Return the number of distinct tokens, i.e. ``num_digits``."""
        return self.num_digits

    def get_block_size(self):
        """Return the length of each ``x``/``y`` tensor.

        The input and its sorted copy are concatenated (``2 * length`` tokens)
        and one token is dropped by the shift between ``x`` and ``y``.
        """
        return self.length * 2 - 1

    @staticmethod
    def _split_of(digits):
        """Map a list of digits to ``"test"`` (about 1 in 4) or ``"train"``.

        A CRC32 checksum is used instead of the built-in ``hash()`` because
        ``hash()`` of bytes is salted per interpreter process, which would make
        the train/test assignment differ between runs and between processes.
        """
        checksum = zlib.crc32(pickle.dumps(digits, protocol=4))
        return "test" if checksum % 4 == 0 else "train"

    def __getitem__(self, idx):
        """Sample one ``(x, y)`` pair belonging to this split.

        ``idx`` is ignored: every call draws a fresh random example.

        Returns:
            A tuple ``(x, y)`` of 1-D ``torch.long`` tensors, each with
            ``get_block_size()`` elements.  ``x`` is the concatenation of the
            input and the sorted sequence without its last token; ``y`` is the
            same concatenation without its first token, with the first
            ``length - 1`` positions set to ``-1``.
        """
        # Rejection sampling: keep drawing until the example fits this split.
        while True:
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # Half of the time, reject inputs with many distinct digits so that
            # sequences containing repeated digits are over-represented.
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    continue

            if self._split_of(inp.tolist()) == self.split:
                break

        sol = torch.sort(inp)[0]

        cat = torch.cat((inp, sol), dim=0)

        # Shift by one so that y[i] is the token that follows x[i].
        x = cat[:-1].clone()
        y = cat[1:].clone()

        # Only the sorted half is a prediction target; the input positions are
        # marked with -1 (presumably the ignore index of the training loss --
        # confirm against mingpt.model).
        y[:self.length - 1] = -1
        return x, y

# Build both splits, then print the first training example as
# (input token, target token) pairs, one pair per line.
train_dataset = SortDataset(split='train')
test_dataset = SortDataset(split='test')
x, y = train_dataset[0]
for token, target in zip(x.tolist(), y.tolist()):
    print(token, target)

from mingpt.model import GPT

# Start from the default GPT configuration, select the "gpt-nano" preset and
# size the vocabulary / context window from the training dataset.
model_config = GPT.get_default_config()
model_config.model_type = "gpt-nano"
model_config.vocab_size = train_dataset.get_vocab_size()
model_config.block_size = train_dataset.get_block_size()
model = GPT(model_config)

from mingpt.trainer import Trainer

# Start from the Trainer's default configuration and override the learning
# rate, the iteration budget and the worker count before building the trainer.
train_config = Trainer.get_default_config()
train_config.learning_rate = 5e-4
train_config.max_iters = 2000
train_config.num_workers = 0
trainer = Trainer(train_config, model, train_dataset)

def batch_end_callback(trainer):
    """Print a progress line on every 100th training iteration.

    Args:
        trainer: Object exposing ``iter_num``, ``iter_dt`` and ``loss``
            (``loss`` must provide ``.item()``).  ``iter_dt`` is scaled by
            1000 and labelled "ms", so it is presumably in seconds.
    """
    if trainer.iter_num % 100 != 0:
        return
    elapsed_ms = trainer.iter_dt * 1000
    step = trainer.iter_num
    loss_value = trainer.loss.item()
    print(f"iter_dt{elapsed_ms:.2f} ms; iter{step}: train loss {loss_value:.5f} ")

# Register the progress-printing hook for the "on_batch_end" event.
trainer.set_callback("on_batch_end", batch_end_callback)

# Run the training loop configured above.
trainer.run()

# Call eval() on the model before evaluation (presumably switching it to
# inference mode). The trailing semicolon is kept on purpose: as the last
# expression of a notebook cell it suppresses the echo of the return value.
model.eval();

# In [None]:
def eval_split(trainer, split, max_batches):
    """Measure how often the model sorts sequences from one split correctly.

    Relies on the module-level ``model``, ``train_dataset`` and
    ``test_dataset``.  Prints up to three wrong predictions and a summary line.

    Args:
        trainer: Object whose ``device`` attribute says where to put the batches.
        split: ``"train"`` or ``"test"``; selects the dataset to evaluate.
        max_batches: Stop after this many batches of 100 examples, or ``None``
            to go through the whole loader.

    Returns:
        A 0-dim float tensor holding the number of fully correct predictions.

    Raises:
        KeyError: If ``split`` is not ``"train"`` or ``"test"``.
    """
    dataset = {'train': train_dataset, 'test': test_dataset}[split]
    # Use the length of the dataset actually being evaluated for the slicing below.
    n = dataset.length
    results = []
    mistakes_printed_already = 0
    loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)
    for b, (x, y) in enumerate(loader):
        x = x.to(trainer.device)
        y = y.to(trainer.device)
        # The first n tokens of x are the unsorted input; the last n tokens of
        # y are the sorted ground truth.
        inp = x[:, :n]
        sol = y[:, -n:]
        # Ask the model for n more tokens after the input (do_sample=False
        # presumably selects greedy decoding -- confirm against mingpt.model).
        cat = model.generate(inp, n, do_sample=False)
        sol_candidate = cat[:, n:]
        # A prediction counts only if every position of the sequence matches.
        correct = (sol == sol_candidate).all(1).cpu()
        for i in range(x.size(0)):
            results.append(int(correct[i]))
            if not correct[i] and mistakes_printed_already < 3:
                mistakes_printed_already += 1
                print("gpt ..%s sorted is %s but gt is %s" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))
        if max_batches is not None and b + 1 >= max_batches:
            break
    rt = torch.tensor(results, dtype=torch.float)
    print("%s final score %d/%d = %.2f %% correct" % (split, rt.sum(), len(results), 100 * rt.mean()))
    return rt.sum()

# Score the model on both splits (at most 50 batches each) with gradient
# tracking disabled.
with torch.no_grad():
    train_score = eval_split(trainer, split="train", max_batches=50)
    test_score = eval_split(trainer, split="test", max_batches=50)

# Run the model on one hand-written sequence and compare its output with
# the true sorted order.
n = train_dataset.length
inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)
# The hand-written example must have exactly as many digits as the dataset uses.
assert inp[0].nelement() == n
with torch.no_grad():
    cat = model.generate(inp, n, do_sample=False)
sol = torch.sort(inp[0])[0]
# Everything after the first n tokens is the model's proposed sorted sequence.
sol_candidate = cat[:, n:]
print("input sequence :", inp.tolist())
print("predicted sorted:", sol_candidate.tolist())
print("gt sort :", sol.tolist())
print("matches :", bool((sol == sol_candidate).all()))