diff --git a/tests/core/test_batch.py b/tests/core/test_batch.py index 7bdcbda7..1e218800 100644 --- a/tests/core/test_batch.py +++ b/tests/core/test_batch.py @@ -445,7 +445,7 @@ def test_ConstantTokenNumSampler(self): sample_count = 0 for batch_x, batch_y in data_iter: sample_count += len(batch_x['seq_len']) - self.assertTrue(sum(batch_x['seq_len'])<120) + self.assertTrue(sum(batch_x['seq_len'])<=120) self.assertEqual(sample_count, num_samples) """