In https://github.com/jingtaozhan/DRhard/blob/dc17f3d1f7f59d13d15daa1a728dc8d6efc48b92/dataset.py, if we take a look at the data collator:

```python
def get_collate_function(max_seq_length):
    cnt = 0
    def collate_function(batch):
        nonlocal cnt
        length = None
        # For the first 10 batches, pad to the full max_seq_length;
        # afterwards, length stays None and padding is dynamic.
        if cnt < 10:
            length = max_seq_length
            cnt += 1
        input_ids = [x["input_ids"] for x in batch]
        attention_mask = [x["attention_mask"] for x in batch]
        data = {
            "input_ids": pack_tensor_2D(input_ids, default=1,
                                         dtype=torch.int64, length=length),
            "attention_mask": pack_tensor_2D(attention_mask, default=0,
                                             dtype=torch.int64, length=length),
        }
        ids = [x['id'] for x in batch]
        return data, ids
    return collate_function
```
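For context, `pack_tensor_2D` is the padding helper defined elsewhere in dataset.py. Roughly, it pads a list of variable-length token lists into a 2D tensor; a minimal sketch of that behavior (not the repo's exact code):

```python
import torch

def pack_tensor_2D(lstlst, default, dtype, length=None):
    # Pad a list of variable-length lists into a (batch, length) tensor,
    # filling unused positions with `default`. When `length` is None,
    # the longest list in the batch determines the width.
    batch_size = len(lstlst)
    length = length if length is not None else max(len(lst) for lst in lstlst)
    tensor = default * torch.ones((batch_size, length), dtype=dtype)
    for i, lst in enumerate(lstlst):
        tensor[i, :len(lst)] = torch.tensor(lst, dtype=dtype)
    return tensor
```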
we see that there is a `cnt` variable which decides whether `collate_function` should pad to the maximum length or not. I couldn't work out why it is needed. Could you please explain the significance of `cnt`?

Thank you
AM
It is a simple trick I used. Some inappropriate hyperparameters may trigger an out-of-memory error during training. This code therefore forces the input to have the maximum sequence length at the beginning of training: if the batch size or the max sequence length is too big, the error is triggered right from the start and I can easily tell.
You can also delete this code.
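To illustrate the effect, here is a hypothetical usage sketch; `train_dataset` and the hyperparameter values are placeholders, not the repo's actual training setup:

```python
from torch.utils.data import DataLoader

# Assumed setup: train_dataset yields dicts with "input_ids",
# "attention_mask", and "id"; 512 is a placeholder max_seq_length.
collate_fn = get_collate_function(max_seq_length=512)
loader = DataLoader(train_dataset, batch_size=32, collate_fn=collate_fn)

# The first 10 batches are padded to the full 512 tokens, so an OOM caused
# by batch_size * max_seq_length being too large surfaces immediately.
# Later batches are padded only to the longest sequence they contain.
for data, ids in loader:
    ...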