Skip to content

Commit

Permalink
add min_seq feature to batch length shrinking in eval harness
Browse files Browse the repository at this point in the history
  • Loading branch information
kingoflolz committed Jul 26, 2021
1 parent 48f00e6 commit 63ac298
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
10 changes: 7 additions & 3 deletions tasks/eval_harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,13 @@ def greedy_until(self, requests):
def loglikelihood_rolling(self, requests):
raise Exception("unimplemented")

def __init__(self, tpu_cluster, seq, batch, shrink):
def __init__(self, tpu_cluster, seq, batch, shrink, min_seq=None):
super().__init__()
self.tpu = tpu_cluster
self.seq = seq
self.batch = batch
self.shrink = shrink
self.min_seq = min_seq

self.pool = multiprocessing.Pool(initializer=process_init)
process_init()
Expand All @@ -72,9 +73,12 @@ def loglikelihood(self, requests):
r = self.convert_requests(requests)
zero_example = process_request(requests[0], self.seq)

for b in tqdm(sample_batch(r, self.batch, zero_example), desc="LM eval harness", total=len(requests) // self.batch):
for b in tqdm(sample_batch(r, self.batch, zero_example),
desc="LM eval harness",
total=len(requests) // self.batch):

if self.shrink:
b = shrink_seq(b)
b = shrink_seq(b, min_seq=self.min_seq)

out = self.tpu.eval(b)

Expand Down
8 changes: 6 additions & 2 deletions tasks/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,23 @@ def grouper(n, iterable, fillvalue):


# divide the seq length by 2 until it would truncate actual context
def shrink_seq(examples):
def shrink_seq(examples, min_seq=None):
length = examples["obs"].shape[-1]

new_length = length // 2

if min_seq is not None:
if new_length < min_seq:
return examples

max_length = np.max(examples["eval_mask"] * np.arange(0, length)) + 1

if max_length < new_length:
examples["obs"] = examples["obs"][:, :new_length]
examples["target"] = examples["target"][:, :new_length]
examples["eval_mask"] = examples["eval_mask"][:, :new_length]

return shrink_seq(examples)
return shrink_seq(examples, min_seq=min_seq)
else:
return examples

Expand Down
6 changes: 5 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ def parse_args():
sample_size=seq)

# use dynamic seq length unless pe is fixed
adaptor = EvalHarnessAdaptor(t, seq, global_val_batch * 4, shrink=pe != "fixed")
adaptor = EvalHarnessAdaptor(t,
seq,
global_val_batch,
shrink=pe != "fixed",
min_seq=1024 if args.version == 2 else None) # work around suboptimal pjit layout

start = time.time()
t.train(train_dataset.get_samples())
Expand Down

0 comments on commit 63ac298

Please sign in to comment.