Skip to content

Commit

Permalink
Add a few more docstrings to predict
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 322823171
  • Loading branch information
romanngg committed Jul 23, 2020
1 parent 03be257 commit e30492d
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions neural_tangents/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@


import collections
from jax import lax
from jax.api import grad
from jax.experimental import ode
import jax.numpy as np
Expand All @@ -37,7 +36,7 @@
from neural_tangents.utils import utils, dataclasses
import scipy as osp
from neural_tangents.utils.typing import KernelFn, Axes, Get
from typing import Union, Tuple, Callable, Iterable, Optional, Dict
from typing import Union, Tuple, Callable, Iterable, Optional, Dict, NamedTuple
from functools import lru_cache


Expand Down Expand Up @@ -197,7 +196,6 @@ def predict_fn_finite(t, fx_train_0, fx_test_0, ntk_test_train):
return fx_train_t
return fx_test_t


return predict_fn_finite

def predict_fn(
Expand Down Expand Up @@ -249,6 +247,7 @@ def predict_fn(

@dataclasses.dataclass
class ODEState:
"""ODE state dataclass holding outputs and auxiliary variables."""
fx_train: np.ndarray = None
fx_test: np.ndarray = None
qx_train: np.ndarray = None
Expand Down Expand Up @@ -456,7 +455,6 @@ def predict_fn(
# ODE solver requires `t[0]` to be the time where `fx_train_0` [and
# `fx_test_0`] are evaluated, but also a strictly increasing sequence of
# timesteps, so we always temporarily append an [almost] `0` at the start.
identity = lambda x: x
t0 = np.where(t[0] == 0,
np.full((1,), -1e-24, t.dtype),
np.zeros((1,), t.dtype))
Expand Down Expand Up @@ -488,7 +486,10 @@ def predict_fn(
return predict_fn


Gaussian = collections.namedtuple('Gaussian', 'mean covariance')
class Gaussian(NamedTuple):
"""A `(mean, covariance)` convenience namedtuple."""
mean: np.ndarray
covariance: np.ndarray


def gp_inference(
Expand Down

0 comments on commit e30492d

Please sign in to comment.