-
Notifications
You must be signed in to change notification settings - Fork 0
/
einops.py
79 lines (56 loc) · 2.16 KB
/
einops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
r"""Einstein-like tensor operation layers.

Each layer in this module wraps the corresponding :mod:`einops` function,
storing the pattern and axis lengths given at construction.

References:
    | Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation (Rogozhnikov et al., 2022)
    | https://openreview.net/forum?id=oapKSVM2bcj
"""

__all__ = ['Rearrange', 'Reduce', 'Repeat']
import einops
import jax
from jax import Array
from typing import *
from .module import Module
class Rearrange(Module):
    r"""Creates an axis rearrangement layer.

    Thin wrapper around :func:`einops.rearrange`: the pattern and the axis
    lengths are fixed at construction and applied on every call.

    Arguments:
        pattern: An einops rearrangement pattern. For example,
            :py:`'A B C -> C (A B)'` moves the last axis to the front and
            merges the first two axes into one.
        lengths: Lengths of named axes, forwarded as keyword arguments to
            :func:`einops.rearrange`.
    """

    def __init__(self, pattern: str, **lengths: int):
        # Assigned left to right: pattern first, then lengths.
        self.pattern, self.lengths = pattern, lengths

    def __call__(self, x: Array) -> Array:
        pattern, lengths = self.pattern, self.lengths
        return einops.rearrange(x, pattern, **lengths)
class Reduce(Module):
    r"""Creates an axis reduction layer.

    This module is a thin wrapper around :func:`einops.reduce`.

    Arguments:
        pattern: The axis reduction pattern. Axes present on the left-hand side
            but absent from the right-hand side are reduced. For example, the
            pattern :py:`'A B C -> A C'` reduces the second axis.
        reduction: The type of reduction (:py:`'sum'`, :py:`'mean'`, :py:`'max'`, ...),
            or a callable :py:`f(x, axes)`, which :func:`einops.reduce` also accepts.
        lengths: The lengths of the axes.
    """

    def __init__(
        self,
        pattern: str,
        reduction: Union[str, Callable[..., Array]] = 'sum',
        **lengths: int,
    ):
        self.pattern = pattern
        # Stored and forwarded unchanged; einops validates the value at call time.
        # NOTE(review): when passing a callable, confirm Module treats it as a
        # static (non-array) attribute.
        self.reduction = reduction
        self.lengths = lengths

    def __call__(self, x: Array) -> Array:
        return einops.reduce(x, self.pattern, self.reduction, **self.lengths)
class Repeat(Module):
    r"""Creates an axis repetition layer.

    This module is a thin wrapper around :func:`einops.repeat`.

    Arguments:
        pattern: The axis repetition pattern. For example, the pattern
            :py:`'A B -> A C B'` inserts a new axis :py:`C`. The length of a
            new axis cannot be inferred from the input, so it must be given
            in :py:`lengths`, as in :py:`Repeat('A B -> A C B', C=3)`.
        lengths: The lengths of the axes.
    """

    def __init__(self, pattern: str, **lengths: int):
        self.pattern = pattern
        self.lengths = lengths

    def __call__(self, x: Array) -> Array:
        return einops.repeat(x, self.pattern, **self.lengths)