/
linalg.py
139 lines (108 loc) · 4.24 KB
/
linalg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import click
from devito import Inc, Operator, Function, dimensions, info
from devito.tools import as_tuple
# Public names exported by a star-import of this module: the plain kernel
# functions. The click ``cli_*`` commands and the ``linalg`` group are not
# listed here.
__all__ = [
    'mat_vec',
    'transpose_mat_vec',
    'mat_mat',
    'mat_mat_sum',
    'chain_contractions',
]
# Root click command group; ``chain=True`` lets several sub-commands be
# invoked in one command line. The docstring below is the group's --help text.
@click.group(chain=True)
def linalg():
    """
    A set of kernels performing basic (BLAS-like) linear algebra operations.
    Upper-case letters ``A, B, C, ...`` are for matrices; lower-case letters
    ``x, y, ...`` are for vectors.
    """
def option_basic(f):
    """Decorate a click command ``f`` with the options every kernel shares.

    Adds ``--mat-shape``/``-ms``, ``--vec-shape``/``-vs`` and
    ``--optimize``/``-o``. The decorators are applied last-to-first, which is
    equivalent to stacking the three ``@click.option`` lines in the order
    they are listed below.

    Returns the decorated command.
    """
    def _as_shape(ctx, param, value):
        # Normalise the vector shape through ``as_tuple`` (e.g. ``4`` -> a tuple).
        return as_tuple(value)

    def _as_opt_mode(ctx, param, value):
        # Translate the boolean flag into the value handed to ``Operator(opt=...)``.
        if value is not True:
            return 'noop'
        return ('blocking,simd,openmp', {'blockinner': True})

    decorators = (
        click.option('-ms', '--mat-shape', default=(4, 4), help='Matrix shape'),
        click.option('-vs', '--vec-shape', default=4, help='Vector shape',
                     callback=_as_shape),
        click.option('-o', '--optimize', default=False, is_flag=True,
                     help='Generate optimized code', callback=_as_opt_mode),
    )
    wrapped = f
    for decorate in decorators[::-1]:
        wrapped = decorate(wrapped)
    return wrapped
@linalg.command(name='mat-vec')
@option_basic
def cli_mat_vec(mat_shape, vec_shape, optimize, **kwargs):
    """``Ax = b``."""
    # ``A`` is declared over (i, j), so the operand ``x`` (indexed by j) needs
    # one entry per column of A and the result ``b`` (indexed by i) one entry
    # per row. Previously both vectors were sized from ``vec_shape``, which only
    # agreed with ``A`` for a square matrix of exactly that size; otherwise the
    # Functions disagreed on the extent of ``i``/``j``.
    nrows, ncols = mat_shape
    if tuple(vec_shape) != (ncols,):
        raise click.BadParameter(
            'vector shape %s does not match the %d columns of matrix shape %s'
            % (tuple(vec_shape), ncols, tuple(mat_shape)),
            param_hint="'--vec-shape'")
    i, j = dimensions('i j')
    A = Function(name='A', shape=mat_shape, dimensions=(i, j))
    x = Function(name='x', shape=vec_shape, dimensions=(j,))
    # Size the result from the matrix rows so it always matches ``i``.
    b = Function(name='b', shape=(nrows,), dimensions=(i,))
    mat_vec(A, x, b, optimize)
@linalg.command(name='transpose-mat-vec')
@option_basic
def cli_transpose_mat_vec(mat_shape, vec_shape, optimize, **kwargs):
    """``A -> A^T, A^Tx = b``."""
    # NOTE(review): A is declared over (i, j) but ``transpose_mat_vec`` reads
    # it as A[j, i], and both vectors use ``vec_shape``; presumably this is
    # only consistent for a square matrix matching ``vec_shape`` — confirm.
    row, col = dimensions('i j')
    matrix = Function(name='A', shape=mat_shape, dimensions=(row, col))
    vec_in, vec_out = (Function(name=label, shape=vec_shape, dimensions=(dim,))
                       for label, dim in (('x', col), ('b', row)))
    transpose_mat_vec(matrix, vec_in, vec_out, optimize)
@linalg.command(name='mat-mat')
@option_basic
def cli_mat_mat(mat_shape, optimize, **kwargs):
    """``AB = C``."""
    # Every operand reuses ``mat_shape``; since A spans (i, j) and B spans
    # (j, k), the shared ``j`` extent only agrees when the matrix is square.
    # NOTE(review): confirm non-square shapes are meant to be unsupported.
    i, j, k = dimensions('i j k')
    layout = (('A', (i, j)), ('B', (j, k)), ('C', (i, k)))
    A, B, C = [Function(name=label, shape=mat_shape, dimensions=dims)
               for label, dims in layout]
    mat_mat(A, B, C, optimize)
@linalg.command(name='mat-mat-sum')
@option_basic
def cli_mat_mat_sum(mat_shape, optimize, **kwargs):
    """``AB + AC = D``."""
    # All four matrices share ``mat_shape``; the contracted ``j`` dimension is
    # sized by both A's second axis and B/C's first axis, so presumably only
    # square shapes are consistent — confirm.
    i, j, k = dimensions('i j k')
    layout = (('A', (i, j)), ('B', (j, k)), ('C', (j, k)), ('D', (i, k)))
    A, B, C, D = [Function(name=label, shape=mat_shape, dimensions=dims)
                  for label, dims in layout]
    mat_mat_sum(A, B, C, D, optimize)
@linalg.command(name='chain-contractions')
@option_basic
def cli_chain_contractions(mat_shape, optimize, **kwargs):
    """``AB + AC = D, DE = F``."""
    # Six matrices, all of shape ``mat_shape``; D is the intermediate result
    # consumed by the second contraction (see ``chain_contractions``).
    i, j, k, l = dimensions('i j k l')
    layout = (
        ('A', (i, j)),
        ('B', (j, k)),
        ('C', (j, k)),
        ('D', (i, k)),
        ('E', (k, l)),
        ('F', (i, l)),
    )
    A, B, C, D, E, F = [Function(name=label, shape=mat_shape, dimensions=dims)
                        for label, dims in layout]
    chain_contractions(A, B, C, D, E, F, optimize)
def mat_vec(A, x, b, optimize):
    """Compute ``Ax = b`` by accumulating ``A*x`` into ``b``.

    A single Devito ``Inc`` (increment) equation is compiled into an
    ``Operator``, executed once, and completion is logged via ``info``.

    Parameters
    ----------
    A, x, b
        Devito Functions for the matrix, input vector and output vector.
        ``b`` is incremented rather than overwritten.
    optimize
        Passed through unchanged as ``Operator``'s ``opt`` argument.
    """
    equation = Inc(b, A*x)
    Operator(equation, opt=optimize).apply()
    info('Executed `Ax = b`')
def transpose_mat_vec(A, x, b, optimize):
    """Compute ``A^Tx = b`` without materialising ``A^T``.

    The transpose is expressed by indexing ``A`` with its own two index
    dimensions swapped; the product is accumulated into ``b`` with ``Inc``.

    Parameters
    ----------
    A, x, b
        Devito Functions for the matrix, input vector and output vector.
    optimize
        Passed through unchanged as ``Operator``'s ``opt`` argument.
    """
    row, col = A.indices
    transposed = A[col, row]
    Operator([Inc(b, transposed*x)], opt=optimize).apply()
    info('Executed `A^Tx = b`')
def mat_mat(A, B, C, optimize):
    """Compute ``AB = C`` by accumulating ``A*B`` into ``C`` with ``Inc``.

    Parameters
    ----------
    A, B, C
        Devito Functions; ``C`` receives the accumulated product.
    optimize
        Passed through unchanged as ``Operator``'s ``opt`` argument.
    """
    equation = Inc(C, A*B)
    Operator(equation, opt=optimize).apply()
    info('Executed `AB = C`')
def mat_mat_sum(A, B, C, D, optimize):
    """Compute ``AB + AC = D`` as a single ``Inc`` into ``D``.

    Parameters
    ----------
    A, B, C, D
        Devito Functions; ``D`` receives the accumulated sum of products.
    optimize
        Passed through unchanged as ``Operator``'s ``opt`` argument.
    """
    summed = A*B + A*C
    Operator(Inc(D, summed), opt=optimize).apply()
    info('Executed `AB + AC = D`')
def chain_contractions(A, B, C, D, E, F, optimize):
    """Compute ``AB + AC = D`` followed by ``DE = F`` in one ``Operator``.

    Both ``Inc`` equations are handed to the same Operator, listed in that
    order, so ``D`` produced by the first feeds the second.

    Parameters
    ----------
    A, B, C, D, E, F
        Devito Functions; ``D`` is the intermediate and ``F`` the final result.
    optimize
        Passed through unchanged as ``Operator``'s ``opt`` argument.
    """
    first = Inc(D, A*B + A*C)
    second = Inc(F, D*E)
    Operator([first, second], opt=optimize).apply()
    info('Executed `AB + AC = D, DE = F`')
if __name__ == "__main__":
    # When executed as a script, hand control to the ``linalg`` click group.
    linalg()