Skip to content

Commit

Permalink
remove unused
Browse files Browse the repository at this point in the history
  • Loading branch information
fattorib committed Jun 5, 2023
1 parent ee998d4 commit 453ef9a
Showing 1 changed file with 0 additions and 22 deletions.
22 changes: 0 additions & 22 deletions src/training/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,5 @@
import jax.numpy as jnp
from jax import random


def initialized(key: random.PRNGKey, model: nn.Module, input_shape: Tuple[int, int]):
"""Initializes param dict for a model
Args:
key (_type_): _description_
image_size (_type_): _description_
model (_type_): _description_
Returns:
_type_: _description_
"""

init_batch = jnp.ones((input_shape), dtype=jnp.int32)

def init(rng, init_batch):
return model.init(rng, init_batch, None, False)

jit_apply = jax.jit(init, backend="cpu")
variables = jit_apply(rng=key, init_batch=init_batch)
return variables


def compute_tokens_seen(absolute_step, max_context):

return absolute_step * max_context

0 comments on commit 453ef9a

Please sign in to comment.