In [1]:
import hashlib
import os

import torch
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from mingpt.utils import set_seed


In [2]:

import pickle

# Fix the global random seed (via mingpt's set_seed) so the run is reproducible.
SEED = 1234
set_seed(SEED)


In [3]:
class SortDataset(Dataset):
    """
    Dataset for the Sort problem. E.g. for problem length 6:
    Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2
    Which will feed into the transformer concatenated as:
    input:  0 0 2 1 0 1 0 0 0 1 1
    output: I I I I I 0 0 0 1 1 2
    where I is "ignore", as the transformer is reading the input sequence

    Examples are generated on the fly. Each random input is assigned to the
    'train' or 'test' split by a deterministic hash of its digits (~25% test),
    so the two splits never share an input and the assignment is identical in
    every process and every run.

    Args:
        split: 'train' or 'test'.
        length: number of digits in each problem.
        num_digits: digits are drawn from range(num_digits); also the vocab size.

    Note: for very small problem spaces (e.g. num_digits=1) one split may
    contain no inputs at all, in which case __getitem__ would loop forever.
    """

    def __init__(self, split, length=6, num_digits=3):
        assert split in {'train', 'test'}
        self.split = split
        self.length = length
        self.num_digits = num_digits

    def __len__(self):
        # examples are sampled randomly, so this is only a nominal epoch size
        return 10000

    def get_vocab_size(self):
        return self.num_digits

    def get_block_size(self):
        # the length of the sequence that will feed into transformer,
        # containing concatenated input and the output, but -1 because
        # the transformer starts making predictions at the last input element
        return self.length * 2 - 1

    @staticmethod
    def _split_for(digits):
        """Return 'test' or 'train' for a sequence of ints, deterministically.

        The built-in hash() of bytes is salted per interpreter process
        (PYTHONHASHSEED), so it changes the split between runs and can differ
        between freshly started processes (e.g. DataLoader workers launched
        with 'spawn'). pickle's default protocol also varies by Python version.
        A SHA-256 digest of a fixed text encoding is stable everywhere.
        """
        key = ",".join(str(int(d)) for d in digits).encode("ascii")
        first_byte = hashlib.sha256(key).digest()[0]
        # 256 is divisible by 4, so exactly 1/4 of byte values map to 'test'
        return 'test' if first_byte % 4 == 0 else 'train'  # designate 25% of examples as test

    def __getitem__(self, idx):
        # idx is ignored: every call draws a fresh random example.
        # use rejection sampling to generate an input example from the desired split
        while True:
            # generate some random integers
            inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)
            # half of the time let's try to boost the number of examples that
            # have a large number of repeats, as this is what the model seems to struggle
            # with later in training, and they are kind of rare
            if torch.rand(1).item() < 0.5:
                if inp.unique().nelement() > self.length // 2:
                    # too many unique digits, re-sample
                    continue
            # keep the example only if it belongs to the requested split
            if self._split_for(inp.tolist()) == self.split:
                break

        # solve the task: i.e. sort
        sol = torch.sort(inp)[0]

        # concatenate the problem specification and the solution
        cat = torch.cat((inp, sol), dim=0)

        # the inputs to the transformer will be the offset sequence
        x = cat[:-1].clone()
        y = cat[1:].clone()
        # we only want to predict at output locations, mask out the loss at the input locations
        y[:self.length-1] = -1
        return x, y

In [4]:
# Show one training example: each row is (model input token, target token);
# a target of -1 marks a position whose loss is masked out.
train_dataset = SortDataset('train')
test_dataset = SortDataset('test')
x, y = train_dataset[0]
for in_tok, out_tok in zip(x.tolist(), y.tolist()):
    print(in_tok, out_tok)

0 -1
1 -1
0 -1
0 -1
1 -1
1 0
0 0
0 0
0 1
1 1
1 1


In [5]:
# create a GPT instance sized to the sort task
from mingpt.model import GPT

model_config = GPT.get_default_config()
model_config.model_type = 'gpt-nano'
# vocabulary size and context length both come from the dataset definition
model_config.vocab_size, model_config.block_size = (
    train_dataset.get_vocab_size(),
    train_dataset.get_block_size(),
)
model = GPT(model_config)

number of parameters: 0.09M


In [6]:
# create a Trainer object
from mingpt.trainer import Trainer

train_config = Trainer.get_default_config()
# overrides applied on top of minGPT's default trainer config
trainer_overrides = {
    'learning_rate': 5e-4,  # the model we're using is so small that we can go a bit faster
    'max_iters': 1000,
    'num_workers': 0,  # presumably forwarded to the DataLoader; 0 = load in the main process
}
for cfg_key, cfg_value in trainer_overrides.items():
    setattr(train_config, cfg_key, cfg_value)
trainer = Trainer(train_config, model, train_dataset)

running on device cuda


In [7]:
# Profile a window of training steps with the PyTorch profiler and write
# TensorBoard-compatible traces.
# Reference: https://github.com/pytorch/pytorch/blob/main/torch/profiler/profiler.py

from torch.autograd import kineto_available, ProfilerActivity
from torch.profiler import profile, schedule, tensorboard_trace_handler

# Output directory for traces. Override via the TRACE_DIR environment variable
# rather than hardcoding a machine/user-specific absolute path.
TRACE_DIR = os.environ.get("TRACE_DIR", "trace_data")

# The schedule counts prof.step() calls, i.e. training batches here (see
# batch_end_callback): skip 5, wait 5, warm up 2, record 100, for one cycle.
tracing_schedule = schedule(skip_first=5, wait=5, warmup=2, active=100, repeat=1)
trace_handler = tensorboard_trace_handler(dir_name=TRACE_DIR, use_gzip=False)

# Only request CUDA tracing when a GPU is actually present.
profiler_activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    profiler_activities.append(ProfilerActivity.CUDA)

with profile(
    activities=profiler_activities,
    schedule=tracing_schedule,
    on_trace_ready=trace_handler,
    profile_memory=True,
    record_shapes=True,
    with_stack=True,
) as prof:

    def batch_end_callback(trainer):
        """Log progress every 100 iterations and advance the profiler schedule."""
        if trainer.iter_num % 100 == 0:
            print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}")
        # exactly one profiler step per training batch
        prof.step()

    trainer.set_callback('on_batch_end', batch_end_callback)

    trainer.run()
    


iter_dt 0.00ms; iter 0: train loss 1.08678


[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
[W kineto_shim.cpp:343] Profiler is not initialized: skipping step() invocation
STAGE:2024-01-21 15:00:27 105174:105174 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


iter_dt 40.80ms; iter 100: train loss 0.21678


STAGE:2024-01-21 15:00:33 105174:105174 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-01-21 15:00:34 105174:105174 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


iter_dt 23.33ms; iter 200: train loss 0.13171
iter_dt 24.50ms; iter 300: train loss 0.05237
iter_dt 24.00ms; iter 400: train loss 0.04231
iter_dt 24.79ms; iter 500: train loss 0.01241
iter_dt 25.42ms; iter 600: train loss 0.02408
iter_dt 24.86ms; iter 700: train loss 0.01282
iter_dt 25.84ms; iter 800: train loss 0.00942
iter_dt 22.99ms; iter 900: train loss 0.00308
