Skip to content

Commit

Permalink
Removing useless args and kwargs arguments in tasks.
Browse files Browse the repository at this point in the history
  • Loading branch information
fazega committed Apr 12, 2024
1 parent dad46bc commit 2b8eb4b
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 31 deletions.
5 changes: 1 addition & 4 deletions tasks/cs/bucket_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,12 @@ class BucketSort(task.GeneralizationTask):
1110001 -> 0001111 (with `vocab_size = 2`)
"""

def __init__(self, *args, vocab_size: int = 5, **kwargs) -> None:
def __init__(self, vocab_size: int = 5) -> None:
"""Initializes the task.
Args:
*args: The args for the base task class.
vocab_size: The size of the alphabet. We use 5 in the paper.
**kwargs: The kwargs for the base task class.
"""
super().__init__(*args, **kwargs)
self._vocab_size = vocab_size

@functools.partial(jax.jit, static_argnums=(0, 2, 3))
Expand Down
6 changes: 1 addition & 5 deletions tasks/cs/duplicate_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,12 @@ class DuplicateString(task.GeneralizationTask):
Note that the sampling is jittable so this task is fast.
"""

def __init__(self, *args, vocab_size: int = 2, **kwargs):
def __init__(self, vocab_size: int = 2) -> None:
"""Initializes the remember_string task.
Args:
vocab_size: The size of the alphabet. We use 2 in the paper.
*args: Args for the base task class.
**kwargs: Kwargs for the base task class.
"""
super().__init__(*args, **kwargs)

self._vocab_size = vocab_size

@functools.partial(jax.jit, static_argnums=(0, 2, 3))
Expand Down
6 changes: 1 addition & 5 deletions tasks/cs/odds_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,12 @@ class OddsFirst(task.GeneralizationTask):
Note that the sampling is jittable so this task is fast.
"""

def __init__(self, *args, vocab_size: int = 2, **kwargs):
def __init__(self, vocab_size: int = 2) -> None:
"""Initializes the odds_first task.
Args:
vocab_size: The size of the alphabet. We use 2 in the paper.
*args: Args for the base task class.
**kwargs: Kwargs for the base task class.
"""
super().__init__(*args, **kwargs)

self._vocab_size = vocab_size

@functools.partial(jax.jit, static_argnums=(0, 2, 3))
Expand Down
5 changes: 1 addition & 4 deletions tasks/dcf/modular_arithmetic_brackets.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,13 @@ def generate_raw_dataset(
class ModularArithmeticBrackets(task.GeneralizationTask):
"""A task with the goal of reducing an arithmetic expression with brackets."""

def __init__(self, *args, modulus: int = 5, mult: bool = False, **kwargs):
def __init__(self, modulus: int = 5, mult: bool = False) -> None:
"""Initializes the modular arithmetic task.
Args:
*args: Args for the base task class.
modulus: The modulus used for the computation. We use 5 in the paper.
mult: Whether to add multiplication or use only '+' and '-'.
**kwargs: Kwargs for the base task class.
"""
super().__init__(*args, **kwargs)
self._modulus = modulus
self._mult = mult

Expand Down
5 changes: 1 addition & 4 deletions tasks/dcf/solve_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,15 +125,12 @@ class SolveEquation(task.GeneralizationTask):
multiple solutions (multiplication by zero).
"""

def __init__(self, *args, modulus: int = 5, **kwargs):
def __init__(self, modulus: int = 5) -> None:
"""Initializes the modular arithmetic task.
Args:
*args: Args for the base task class.
modulus: The modulus used for the computation. We use 5 in the paper.
**kwargs: Kwargs for the base task class.
"""
super().__init__(*args, **kwargs)
self._modulus = modulus

def sample_batch(
Expand Down
14 changes: 5 additions & 9 deletions tasks/regular/modular_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,22 +132,18 @@ class ModularArithmetic(task.GeneralizationTask):
Note that the input strings are always of odd length.
"""

def __init__(self,
*args,
modulus: int = 5,
operators: Optional[Sequence[str]] = None,
**kwargs):
def __init__(
self,
modulus: int = 5,
operators: Optional[Sequence[str]] = None,
) -> None:
"""Initializes the modular arithmetic task.
Args:
*args: Args for the base task class.
modulus: The modulus used for the computation. We use 5 in the paper.
operators: Operators to be used in the sequences. By default it's None,
meaning all operators available are used.
**kwargs: Kwargs for the base task class.
"""
super().__init__(*args, **kwargs)

self._modulus = modulus
if operators is None:
operators = ('+', '*', '-')
Expand Down

0 comments on commit 2b8eb4b

Please sign in to comment.