round 1 of some changes. we will now always write in fp32, even if dtype is set to float16 or bfloat16. next up, we actually want to write in lower precision when the dtype is set so
karpathy committed Apr 25, 2024
1 parent 7a52a21 commit 3fb7252
Showing 1 changed file with 44 additions and 29 deletions: train_gpt2.py
@@ -215,7 +215,7 @@ def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):

 # a few utilities for saving params/grads/activations to files for loading in C
 def write_fp32(tensor, file):
-    file.write(tensor.detach().cpu().numpy().astype("float32").tobytes())
+    file.write(tensor.detach().cpu().to(torch.float32).numpy().tobytes())

 def write_tensors(model_tensors, L, file):
     write_fp32(model_tensors["transformer.wte.weight"], file) # (V, C)
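Why the rewrite inside write_fp32: stock numpy has no bfloat16 dtype, so the old .numpy().astype("float32") path raises on a bf16 tensor, while casting to fp32 in torch first works for all three supported dtypes. A minimal sketch (not part of the commit) of the behavior:

import torch

def write_fp32(tensor, file):
    # cast in torch first; .numpy() would raise on a bfloat16 tensor
    file.write(tensor.detach().cpu().to(torch.float32).numpy().tobytes())

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    t = torch.randn(4).to(dtype)
    with open("/tmp/t.bin", "wb") as f:
        write_fp32(t, f)  # 4 bytes per element on disk, regardless of dtype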
@@ -258,9 +258,8 @@ def write_model(model, filename):
     header[4] = model.config.n_layer
     header[5] = model.config.n_head
     header[6] = model.config.n_embd
-    # 2) the parameters on CPU are next
+    # 2) the parameters on CPU follow
     params = {name: param.cpu() for name, param in model.named_parameters()}
-    # now write
     with open(filename, "wb") as file:
         # header
         file.write(header.numpy().tobytes())
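The model file is just the raw header bytes followed by fp32 tensors in a fixed order, so the C side can read it back directly. A hedged sketch of reading the config fields back in Python; the int32 dtype and 256-entry header size are illustrative assumptions, not confirmed by this diff:

import numpy as np

with open("gpt2_124M.bin", "rb") as f:
    header = np.frombuffer(f.read(256 * 4), dtype=np.int32)  # assumed layout
    n_layer, n_head, n_embd = header[4], header[5], header[6]
    # everything after the header is raw little-endian fp32 parameter data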
@@ -346,7 +345,7 @@ def write_tokenizer(enc, filename):
     device = "mps"
 print(f"using device: {device}")

-# create a context manager following the desired dtype and device
+# set up a context manager following the desired dtype and device
 ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[args.dtype]
 ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()

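The ctx line gives a single context manager to wrap the forward pass in: autocast with the chosen dtype on CUDA, and a no-op everywhere else. A small usage sketch (the dtype and shapes are arbitrary):

import torch
from contextlib import nullcontext

ptdtype = torch.bfloat16  # e.g. from args.dtype
device = "cuda" if torch.cuda.is_available() else "cpu"
ctx = torch.amp.autocast(device_type="cuda", dtype=ptdtype) if device == "cuda" else nullcontext()

with ctx:
    a = torch.randn(8, 8, device=device)
    b = a @ a  # runs in ptdtype under autocast on CUDA, plain fp32 otherwise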
@@ -355,16 +354,17 @@ def write_tokenizer(enc, filename):
 if torch.cuda.is_available():
     torch.cuda.manual_seed(42)

-# init the tokenizer
+# set the torch precision mode to use TensorFloat32 (TF32) for matmuls
+# docs https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
+if args.tensorcores:
+    torch.set_float32_matmul_precision('high')
+
+# init (and write) the tokenizer
 enc = tiktoken.get_encoding("gpt2")
 encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"})
 decode = lambda l: enc.decode(l)

 write_tokenizer(enc, "gpt2_tokenizer.bin")
-
-if args.tensorcores:
-    torch.set_float32_matmul_precision('high')
-
 # load the GPT-2 model weights
 model = GPT.from_pretrained("gpt2")
 model.train()
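The tensorcores flag only changes how fp32 matmuls are executed; it does not change any dtypes. A sketch of the effect (assumes a CUDA device, Ampere or newer for TF32 to actually engage; 'highest' is the default mode):

import torch

torch.set_float32_matmul_precision('high')  # allow TF32 in fp32 matmuls
a = torch.randn(1024, 1024, device='cuda')
b = a @ a  # may now run on tensor cores in TF32; inputs and outputs stay fp32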
Expand All @@ -375,6 +375,9 @@ def write_tokenizer(enc, filename):
print("compiling the model...")
model = torch.compile(model)

# -------------------------------------------------------------------------
# data loading related: long but it's just to get a single batch of data

# load the tokens
# prefer to use tiny_shakespeare if it's available, otherwise use tiny_stories
# we're using val instead of train split just because it is smaller/faster
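The next hunk touches the tail of get_batch(). For reference, a minimal sketch of such a generator (not the repo's exact code), assuming tokens is a 1-D tensor of token ids with at least B*T+1 elements; B and T are stand-in batch/sequence sizes:

import torch

def get_batch(tokens, B=4, T=64):
    i = 0
    while True:
        # x is a (B, T) chunk of tokens, y is the same chunk shifted by one
        x = tokens[i : i + B*T].view(B, T)
        y = tokens[i + 1 : i + B*T + 1].view(B, T)
        yield x, y
        i += B*T
        if i + B*T + 1 >= len(tokens):
            i = 0  # in prod we'd want to randomize the start point a bit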
@@ -405,47 +408,59 @@ def get_batch():
         if i + B*T + 1 >= len(tokens):
             i = 0 # in prod we'd want to randomize the start point a bit

-# forward backward for a few iterations
+# fetch one batch of data, which we will overfit to
 data_iter = iter(get_batch())
 x, y = next(data_iter) # we'll overfit this batch below

+# -------------------------------------------------------------------------
+# STAGE 1: weights / state logging for C to load later
+
+# do one forward pass to generate ground truth for our C tests
 if not args.inference_only and args.write_tensors:
+    assert args.dtype == "float32", "right now can only write tensors in float32"
-    logits, loss = model(x, y)
-    loss.backward()
-    write_model(model, "gpt2_124M.bin")
-    write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
-
-use_fused = device == "cuda" # only works on CUDA (?)
-optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=use_fused)
-timings = []
+    with ctx:
+        logits, loss = model(x, y)
+        loss.backward()
+    write_model(model, "gpt2_124M.bin")
+    write_state(model, x, y, logits, loss, "gpt2_124M_debug_state.bin")
+
+# -------------------------------------------------------------------------
+# STAGE 2: training loop to get timings
+
+# init the optimizer
+adam_use_fused = device == "cuda" # only works on CUDA (?)
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, fused=adam_use_fused)
+
+if device == "cuda":
+    torch.cuda.reset_peak_memory_stats()
+timings = []
 for i in range(args.num_iterations):
     t0 = time.time()
     with ctx:
         logits, loss = model(x, y)
-    if not args.inference_only:
-        optimizer.zero_grad()
-        del logits
-        loss.backward()
-        optimizer.step()
+    if not args.inference_only:
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+    # wait on the CPU for all device work to end so we get accurate per-iteration timings below
     if device == "mps":
         torch.mps.synchronize()
     elif device == "cuda":
         torch.cuda.synchronize()
     # time and print
     t1 = time.time()
-    if i > args.num_iterations - 20:
+    # the 0th iteration is often an outlier (much slower) => skip logging it
+    if i > 0 and i > args.num_iterations - 20:
         timings.append(t1-t0)
     print(f"iteration {i}, loss: {loss.item()}, time: {(t1-t0)*1000:.3f}ms")

-if len(timings) > 20:
-    print(f"final 20 iters avg: {np.mean(timings[-20:])*1000:.3f}ms")
-else:
-    print(f"final {len(timings)-1} iters avg: {np.mean(timings[1:])*1000:.3f}ms")
+# print the average of the last 20 timings, to get something smooth-ish
+timings = timings[-20:]
+print(f"final {len(timings)} iters avg: {np.mean(timings)*1000:.3f}ms")
-print(f"peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")
+print(f"Peak memory consumption: {torch.cuda.max_memory_allocated() // 1024 // 1024} MiB")

+# -------------------------------------------------------------------------
+# STAGE 3: Few steps of inference
+
 # before we end, let's also do one round of inference
 # we'll kick off the generation with "<|endoftext|>", which designates the start of a new sequence
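On the synchronize calls in the loop above: CUDA and MPS ops are asynchronous, so without waiting for the device, t1 - t0 would mostly measure kernel launch overhead. A self-contained sketch of the pattern (matrix size and iteration count are arbitrary):

import time
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device)
timings = []
for i in range(10):
    t0 = time.time()
    y = x @ x
    if device == "cuda":
        torch.cuda.synchronize()  # block until the matmul actually finishes
    t1 = time.time()
    if i > 0:  # the 0th iteration is often an outlier (warmup) => skip it
        timings.append(t1 - t0)
print(f"avg over {len(timings)} iters: {sum(timings)/len(timings)*1000:.3f}ms")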
