Commit
Add train script that runs on the TPU VM (#16)
nostalgebraist authored Jun 14, 2021
1 parent b608bea commit 4ea1a1a
Showing 2 changed files with 298 additions and 1 deletion.
289 changes: 289 additions & 0 deletions device_train.py
@@ -0,0 +1,289 @@
import argparse
import json
import time

import jax
import numpy as np
import optax

import wandb
from tqdm import tqdm


from mesh_transformer import util
from mesh_transformer.checkpoint import read_ckpt, write_ckpt
from mesh_transformer.transformer_shard import CausalTransformer
from tfrecord_loader import TFRecordNewInputs
from smart_open import open
from google.cloud import storage
from google.cloud.exceptions import NotFound

from mesh_transformer.util import clip_by_global_norm, additive_weight_decay


def parse_args():
# Parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=None, help="Config file location")
parser.add_argument("--tune-model-path", type=str, default=None, help="Base model to finetune")

args = parser.parse_args()
return args


def save(network, step, bucket, path, mp, aux=None, keep_n=3, delete_old=True):
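    """Write a sharded checkpoint to gs://{bucket}/{path}/step_{step}/ and update
    meta.json, pruning the checkpoint list down to the newest keep_n entries
    (and deleting the pruned checkpoints' blobs when delete_old is set)."""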
assert path
client = storage.Client()

if aux is None:
aux = {}

try:
with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
meta = json.load(f)
    except NotFound:
# create metadata file
with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
json.dump({
"step": 0,
"checkpoints": [],
"aux": {}
}, f)

# do sharded checkpoint writing
start = time.time()
for shard_id in range(mp):
write_ckpt(network.state, f"gs://{bucket}/{path}/step_{step}/", shard_id)

print(f"Wrote checkpoint in {time.time() - start:.06}s")

with open(f"gs://{bucket}/{path}/meta.json", "r") as f:
meta = json.load(f)

meta["step"] = step
meta["checkpoints"].append(step)
all_aux = meta.get("aux", {})

while len(meta["checkpoints"]) > keep_n:
ckpt_to_delete = meta["checkpoints"].pop(0)

try:
del all_aux[str(ckpt_to_delete)]
        except KeyError:
            print(f"failed to delete the aux state for {ckpt_to_delete}")

if delete_old:
print(f"deleting checkpoint {ckpt_to_delete}")
for blob in client.list_blobs(bucket, prefix=f"{path}/step_{ckpt_to_delete}/"):
# print(f"deleting {blob.name}")
assert path in blob.name
blob.delete()
else:
print(f"keeping checkpoint {ckpt_to_delete}")

all_aux[step] = aux
meta["aux"] = all_aux

with open(f"gs://{bucket}/{path}/meta.json", "w") as f:
json.dump(meta, f)


def train_step(network, data):
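    # data is (accumulation steps, batch, seq + 1): feed tokens [:-1] as inputs
    # and tokens [1:] as targets for next-token prediction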
inputs = {
"obs": data[:, :, :-1],
"target": data[:, :, 1:],
}

loss, last_loss = network.train(inputs)

return np.array(loss).mean(), np.array(last_loss).mean()


def eval_step(network, data):
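    # eval batches are 2D (batch, seq + 1), shifted the same way as in training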
inputs = {
"obs": data[:, :-1],
"target": data[:, 1:],
}

out = network.eval(inputs)
loss = out["loss"]

return np.array(loss).mean()


if __name__ == "__main__":
args = parse_args()
    with open(args.config) as f:
        params = json.load(f)

gradient_accumulation_steps = params.get("gradient_accumulation_steps", 1)
per_replica_batch = params["per_replica_batch"]
cores_per_replica = params["cores_per_replica"]

assert cores_per_replica <= 8

bucket = params["bucket"]
model_dir = params["model_dir"]
layers = params["layers"]
d_model = params["d_model"]
n_heads = params["n_heads"]
n_vocab = params["n_vocab"]
seq = params["seq"]
norm = params["norm"]

val_batches = params["val_batches"]
val_every = params["val_every"]
ckpt_every = params["ckpt_every"]
keep_every = params["keep_every"]
eval_tasks = params["eval_harness_tasks"]
total_steps = params["total_steps"]

pe = params["pe"]
assert pe in ["fixed", "rotary", "t5"]

warmup_steps = params["warmup_steps"]
anneal_steps = params["anneal_steps"]
lr = params["lr"]
end_lr = params["end_lr"]
weight_decay = params["weight_decay"]
step_shift = params.get("step_shift", 0)
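
    # the file passed via --config is JSON supplying the fields read above;
    # a hypothetical example (field names match the params[...] reads, values
    # are illustrative only):
    #   {"layers": 28, "d_model": 4096, "n_heads": 16, "n_vocab": 50400,
    #    "seq": 2048, "norm": "layernorm", "pe": "rotary", "lr": 1.2e-4, ...}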

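    # optimizer chain: average over accumulated micro-batches, clip the global
    # gradient norm, apply Adam scaling, add decoupled weight decay, flip the
    # sign for descent, then scale by the GPT-3-style warmup/anneal schedule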
opt = optax.chain(
optax.scale(1 / gradient_accumulation_steps),
clip_by_global_norm(1),
optax.scale_by_adam(),
additive_weight_decay(weight_decay),
optax.scale(-1),
optax.scale_by_schedule(util.gpt3_schedule(warmup_steps, anneal_steps, lr, end_lr, step_shift))
)

params["optimizer"] = opt

start = time.time()
tpu_size = jax.device_count()
if tpu_size < cores_per_replica:
msg = f"each shard needs a separate device, but device count ({tpu_size}) < shard count ({cores_per_replica})"
raise ValueError(msg)
print(f"jax devices: {tpu_size}")
print(f"jax runtime initialized in {time.time() - start:.06}s")

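    # arrange devices on a 2D (data-parallel, model-parallel) grid; each model
    # replica is sharded across cores_per_replica devices along the 'mp' axis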
mesh_shape = (tpu_size // cores_per_replica, cores_per_replica)
devices = np.array(jax.devices()).reshape(mesh_shape)

# pick initial ckpt - based on tuning vs train from scratch

step = 0
initial_ckpt_state_path = None
train_loader = None

if args.tune_model_path:
initial_ckpt_state_path = args.tune_model_path
print('we are fine-tuning')
else:
print('we are not fine-tuning')
initial_ckpt_model_dir = model_dir
initial_ckpt_path = f"gs://{bucket}/{initial_ckpt_model_dir}"
meta_path = f"{initial_ckpt_path}/meta.json"

try:
with open(meta_path, "r") as f:
meta = json.load(f)
ckpt_step = meta["checkpoints"][-1]
initial_ckpt_state_path = f"{initial_ckpt_path}/step_{ckpt_step}/"
print(f"state will be restored from checkpoint {ckpt_step}")

step = ckpt_step
train_loader = meta['aux'][str(ckpt_step)].get("train_loader", None)
except NotFound:
# no checkpoint, start at zero
print(f"No checkpoint to load at {initial_ckpt_path}. Training from scratch.")

if initial_ckpt_state_path:
print(f"path to load checkpoint from: {initial_ckpt_state_path}")
else:
print("not loading from a checkpoint")

# set up datasets
print("setting up datasets")

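    # train batches have shape (gradient accumulation steps, global batch);
    # the global batch is per_replica_batch times the number of replicas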
train_dataset = TFRecordNewInputs(f"data/{params['train_set']}",
batch_size=(
gradient_accumulation_steps,
per_replica_batch * tpu_size // cores_per_replica),
sample_size=params['seq'],
restore_state=train_loader)

global_val_batch = per_replica_batch * tpu_size // cores_per_replica

val_sets = {}

for k, v in params['val_set'].items():
val_sets[k] = TFRecordNewInputs(f"data/{v}",
batch_size=(global_val_batch,),
sample_size=seq)

# tok/sec metrics
windows_per_step = gradient_accumulation_steps * (per_replica_batch * tpu_size // cores_per_replica)
tokens_per_step = params['seq'] * windows_per_step
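    # a "window" is one full training sequence, so tokens/step = windows/step * seq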

# load + run
with jax.experimental.maps.mesh(devices, ('dp', 'mp')):
print("initializing network")
network = CausalTransformer(params)

if initial_ckpt_state_path:
print("loading network")
start = time.time()
network.state = read_ckpt(network.state, initial_ckpt_state_path, devices.shape[1])
print(f"network loaded in {time.time() - start:.06}s")

print('compiling train fn')
start = time.time()
train_step(network, train_dataset.get_samples())
step += 1
print(f"Train fn compiled in {time.time() - start:.06}s")

print('compiling eval fn')
start = time.time()
for val_set in val_sets.values():
eval_step(network, val_set.get_samples())
val_set.reset()
print(f"Eval fn compiled in {time.time() - start:.06}s")

wandb.init(project='mesh-transformer-jax', name=params["name"], config=params)

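        # main loop: checkpoint every ckpt_every steps, validate every val_every
        # steps, take one training step, and log loss/throughput to wandb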
while True:
if (step % ckpt_every == 1) or step == total_steps:
print(f"saving a checkpoint for step {step}")
save(network, step, bucket, model_dir,
mp=cores_per_replica,
aux={"train_loader": train_dataset.get_state()},
delete_old=True,
)

if step == total_steps:
print("training completed!")
exit()

if step % val_every == 1: # 1 because we've already taken a step to compile train fn
for name, val_set in val_sets.items():
val_loss = []
for i, _ in tqdm(zip(val_set.sample_once(), range(val_batches)),
desc=f"validation for step {step}, set {name}",
total=val_batches):
val_loss.append(eval_step(network, i))
val_set.reset()

val_loss = np.array(val_loss).mean()
print(f"validation loss for step {step}, set {name}: {val_loss}")

wandb.log({f'val/loss_{name}': float(val_loss)}, step)

start = time.time()
loss, last_loss = train_step(network, train_dataset.get_samples())
step += 1

steps_per_sec = 1 / (time.time() - start)
tokens_per_sec = tokens_per_step * steps_per_sec

wandb.log({'train/loss': loss, 'train/last_loss': last_loss, 'train/steps_per_sec': steps_per_sec, 'train/tokens_per_sec': tokens_per_sec}, step)
10 changes: 9 additions & 1 deletion tfrecord_loader.py
@@ -29,6 +29,14 @@ def __init__(self, index_fname, batch_size, parse_fn, map_fn=None, restore_state

self.sample_fn = self.sample_once()

def reset(self):
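        # rewind the loader to the start of an epoch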
self.file_idx = 0
self.file_idx_init = True
self.used = []

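        # self.used was just cleared, so the filter keeps every file in the index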
self.clean_index = list(filter(lambda x: x not in self.used, self.index))
self.sample_fn = self.sample_once()

def sample_once(self):
for i in self.clean_index:
compression = "ZLIB" if "zstd" in i else ""
@@ -56,7 +64,7 @@ def get_samples(self):
try:
return next(self.sample_fn)
except StopIteration:
-            self.sample_fn = self.sample_once()
+            self.reset()
return self.get_samples()

def get_state(self):
