Skip to content

Commit

Permalink
1) Support transpose convolution
Browse files Browse the repository at this point in the history
2) Allow circular padding with masking
3) Minor doc tweaks
4) deflake tests by raising tolerance
5) Relax some typing requirements from Tuples to Sequences.

Co-authored-by: Kayhan Batmanghelich <kayhan@pitt.edu>
PiperOrigin-RevId: 330662849
  • Loading branch information
romanngg and kayhan-batmanghelich committed Sep 9, 2020
1 parent 8ff7a13 commit 780ad0c
Show file tree
Hide file tree
Showing 5 changed files with 895 additions and 352 deletions.
10 changes: 7 additions & 3 deletions neural_tangents/predict.py
Expand Up @@ -37,7 +37,7 @@
from neural_tangents.utils import utils, dataclasses
import scipy as osp
from neural_tangents.utils.typing import KernelFn, Axes, Get
from typing import Union, Tuple, Callable, Iterable, Optional, Dict, NamedTuple, Generator
from typing import Union, Tuple, Callable, Iterable, Optional, Dict, NamedTuple, Sequence, Generator
from functools import lru_cache


Expand Down Expand Up @@ -741,7 +741,7 @@ def gradient_descent_mse_ensemble(

k_dd_cache = {}

def get_k_train_train(get: Tuple[str, ...]) -> _Kernel:
def get_k_train_train(get: Sequence[str]) -> _Kernel:
if len(get) == 1:
get = get[0]
if get not in k_dd_cache:
Expand Down Expand Up @@ -1000,7 +1000,11 @@ def _get_fns_in_eigenbasis(
Args:
k_train_train:
an n x n matrix
an n x n matrix.
diag_reg:
diagonal regularizer strength.
diag_reg_absolute_scale:
`True` to use absolute (vs relative to mean trace) regulatization.
fns:
a sequence of functions that add on the eigenvalues (evals, dt) ->
modified_evals.
Expand Down

0 comments on commit 780ad0c

Please sign in to comment.