In [12]:
from transformers import GPTNeoXForCausalLM
from datasets import load_dataset
import torch
import gc

In [2]:
checkpoint_dataset = load_dataset(
    "rdiehlmartinez/pythia-pile-presampled",
    "checkpoints",
    split='train',
)

model_sizes = ["70m", "160m", "410m", "1b", "1.4b", "2.8b"]

# checkpoint step stored by pythia 

# checkpointing steps used in evaluation by pythia 
checkpoint_steps = [0, 1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1000, ]
checkpoint_steps.extend([3000 + (i * 10000) for i in range(0, 15)])

ORIGINAL_BATCH_SIZE = 1024 
REDUCED_BATCH_SIZE = 128 

MAX_STEP = 142_999 # Last step in training (used to index final batc)

# NOTE: setting up the data batch sizes 

ordered_steps = list(set(checkpoint_dataset['step']))
ordered_steps.sort()
step_to_start_index = {step: i*1024 for i, step in enumerate(ordered_steps)}

Resolving data files: 100%|██████████| 77/77 [00:01<00:00, 55.60it/s]


In [3]:
def get_data_batch(step, include_labels=True):
    """
    Get a data batch for a given step in the training process.
    """

    assert(step in step_to_start_index), f"Step {step} not valid checkpoint step."
    start_idx = step_to_start_index[step]
    end_idx = start_idx + 1024

    return {
        "input_ids": torch.tensor(checkpoint_dataset[start_idx:end_idx]['ids'], device='cuda'),
        "labels": torch.tensor(checkpoint_dataset[start_idx:end_idx]['ids'], device='cuda') if include_labels else None
    }


In [4]:
model_size = "70m"
checkpoint_step = 1000

model = GPTNeoXForCausalLM.from_pretrained(
    f"EleutherAI/pythia-{model_size}-deduped",
    revision=f"step{checkpoint_step}",
    cache_dir=f"./pythia-{model_size}-deduped/step{checkpoint_step}",
).to('cuda')


  return self.fget.__get__(instance, owner)()


In [5]:
data_batch = get_data_batch(checkpoint_step)

In [6]:
data_batch.keys()

dict_keys(['input_ids', 'labels'])

In [10]:
def forward_pass(model, batch, debug=False, verbose=False):
    """
    Perform a forward pass of the model on a given batch of data; assumes that the model 
    has hooks setup to save the hidden states at each layer.
    """
    if debug:
        torch.cuda.memory._record_memory_history(max_entries=100000)
        # split up the last batch into smaller batches that can fit on the GPU 
        # automatically find the largest batch size that can fit on the GPU
        # and then use that to split up the last batch

    batch_index = 0

    total_loss = 0.0

    batch_size = 1
    static_batch_size = None # NOTE: static_batch_size is only set when batch size is reduced

    while batch_index < REDUCED_BATCH_SIZE:

        if verbose:
            print("START OF LOOP")
            print("memory: ", torch.cuda.memory_allocated()/1e9, "GB")

        try:
            if static_batch_size is None:
                _batch_size = batch_size
            else: 
                # NOTE: reached when we've run out of memory and have reduced the batch size
                _batch_size = static_batch_size

            batch_end_index = min(batch_index + _batch_size, REDUCED_BATCH_SIZE)

            if verbose:
                print(f"Batch index: {batch_index}, Batch end index: {batch_end_index}")
                print(f"Batch size: {_batch_size}")

            _inputs = batch['input_ids'][batch_index : batch_end_index]

            if verbose:
                print(f"Shape of current sub-batch inputs: {_inputs.shape}")

            _labels = batch['labels'][batch_index : batch_end_index]

            if verbose:
                print("AFTER INPUTS")
                print("memory: ", torch.cuda.memory_allocated()/1e9, "GB")

            _loss = model(_inputs, labels=_labels).loss.item()

            if verbose:
                print("AFTER MODEL")
                print("memory: ", torch.cuda.memory_allocated()/1e9, "GB")

        except RuntimeError:
            # NOTE: Exception is thrown when the batch size is too large for the GPU

            if batch_size == 1:
                raise Exception("Batch size of 1 is too large for the GPU")

            _batch_size //= 2
            static_batch_size = _batch_size
            if verbose:
                print(f"Reducing batch size to: {_batch_size}")

            gc.collect()
            torch.cuda.empty_cache()

            continue

        total_loss += _loss * _batch_size

        batch_index = batch_end_index

        if static_batch_size is None:
            batch_size *= 2

    if debug:
        torch.cuda.memory._dump_snapshot("memory_snapshot_NA.pickle")
        torch.cuda.memory._record_memory_history(enabled=None)

    return total_loss/REDUCED_BATCH_SIZE

In [13]:
total_loss = forward_pass(model, data_batch, verbose=True)

START OF LOOP
memory:  71.414570496 GB
Batch index: 0, Batch end index: 1
Batch size: 1
Shape of current sub-batch inputs: torch.Size([1, 2049])
AFTER INPUTS
memory:  71.414570496 GB
AFTER MODEL
memory:  71.414570496 GB
START OF LOOP
memory:  71.414570496 GB
Batch index: 1, Batch end index: 3
Batch size: 2
Shape of current sub-batch inputs: torch.Size([2, 2049])
AFTER INPUTS
memory:  71.414570496 GB
AFTER MODEL
memory:  71.414570496 GB
START OF LOOP
memory:  71.414570496 GB
Batch index: 3, Batch end index: 7
Batch size: 4
Shape of current sub-batch inputs: torch.Size([4, 2049])
AFTER INPUTS
memory:  71.414570496 GB
AFTER MODEL
memory:  71.414570496 GB
START OF LOOP
memory:  71.414570496 GB
Batch index: 7, Batch end index: 15
Batch size: 8
Shape of current sub-batch inputs: torch.Size([8, 2049])
AFTER INPUTS
memory:  71.414570496 GB
Reducing batch size to: 4
START OF LOOP
memory:  71.414570496 GB
Batch index: 7, Batch end index: 11
Batch size: 4
Shape of current sub-batch inputs: torch.

In [14]:
total_loss

3.699067138135433

In [7]:
sub_batch = {
    "input_ids": data_batch['input_ids'][:8],
    "labels": data_batch['labels'][:8]
}

with torch.no_grad():
    output = model(**sub_batch) 

In [9]:
output.loss.item()

3.837646007537842

In [16]:
# get max context length of model 
model.config 

GPTNeoXConfig {
  "_name_or_path": "EleutherAI/pythia-70m-deduped",
  "architectures": [
    "GPTNeoXForCausalLM"
  ],
  "attention_bias": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.1,
  "eos_token_id": 0,
  "hidden_act": "gelu",
  "hidden_dropout": 0.0,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 2048,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 2048,
  "model_type": "gpt_neox",
  "num_attention_heads": 8,
  "num_hidden_layers": 6,
  "rope_scaling": null,
  "rotary_emb_base": 10000,
  "rotary_pct": 0.25,
  "tie_word_embeddings": false,
  "torch_dtype": "float16",
  "transformers_version": "4.37.1",
  "use_cache": true,
  "use_parallel_residual": true,
  "vocab_size": 50304
}