Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added tridiagonal_solve in tensorflow frontend #23279

Merged
merged 49 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
fce115a
Added foldr for tensorflow frontend
AbdullahSabry Aug 15, 2023
492f45d
fixed formatting
AbdullahSabry Aug 15, 2023
03b9a65
fixed formatting
AbdullahSabry Aug 15, 2023
86ae03e
Fixed formatting
AbdullahSabry Aug 18, 2023
03a5751
Fixed formatting
AbdullahSabry Aug 18, 2023
2d1fd84
Merge branch 'unifyai:main' into main
AbdullahSabry Aug 25, 2023
249407f
Dealt with conflicts
AbdullahSabry Aug 26, 2023
bd14e9b
Dealt with conflicts
AbdullahSabry Aug 26, 2023
f097188
Dealt with conflicts
AbdullahSabry Aug 26, 2023
fc83ed1
Merge branch 'unifyai:main' into main
AbdullahSabry Aug 26, 2023
bad94ee
Dealt with conflicts
AbdullahSabry Aug 26, 2023
0656e23
Merge branch 'main' of github.com:AbdullahSabry/ivy into main
AbdullahSabry Aug 26, 2023
eef77f8
Merge branch 'unifyai:main' into main
AbdullahSabry Aug 27, 2023
1fd46a6
Merge branch 'unifyai:main' into main
AbdullahSabry Aug 28, 2023
ff7eb13
Cleaned branch
AbdullahSabry Aug 28, 2023
11307bb
Cleaned branch
AbdullahSabry Aug 28, 2023
8fe6050
resolving conflicts
AbdullahSabry Aug 29, 2023
eb6905c
resolving conflicts
AbdullahSabry Aug 29, 2023
fb454e8
Merge branch 'unifyai:main' into main
AbdullahSabry Aug 29, 2023
d9ae52b
resolving conflicts
AbdullahSabry Aug 29, 2023
a5026c1
Merge branch 'main' of github.com:AbdullahSabry/ivy into main
AbdullahSabry Aug 29, 2023
68f2511
Update general_functions.py
AbdullahSabry Sep 1, 2023
aa8736f
Fixed paddle backend failing test
AbdullahSabry Sep 1, 2023
c345d7e
Merge branch 'unifyai:main' into main
AbdullahSabry Sep 2, 2023
b615abf
Initial commit
AbdullahSabry Sep 5, 2023
9008111
Added Initial testing for tridiagonal solve
AbdullahSabry Sep 8, 2023
062c8e3
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 8, 2023
7ebbbef
Added Initial testing for tridiagonal solve
AbdullahSabry Sep 8, 2023
143b092
Fixed Formatting
AbdullahSabry Sep 8, 2023
529eff2
Fixed Formatting
AbdullahSabry Sep 8, 2023
293c1f8
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 8, 2023
1150916
Added support for more formats
AbdullahSabry Sep 8, 2023
64a4700
Merge branch 'tridiagonal_solve' of github.com:AbdullahSabry/ivy into…
AbdullahSabry Sep 8, 2023
9f68178
Added support for all formats in testing
AbdullahSabry Sep 8, 2023
0e41c1c
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 8, 2023
f37bc2c
Added support for all formats in testing
AbdullahSabry Sep 8, 2023
cc9bc27
Fixed testing for tridiagonal_solve
AbdullahSabry Sep 19, 2023
f114995
Fixed formatting
AbdullahSabry Sep 19, 2023
1faf9ed
Fixed formatting
AbdullahSabry Sep 19, 2023
f99b05d
Fixed formatting
AbdullahSabry Sep 19, 2023
cc6abb8
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 19, 2023
202d1d9
Fixed tests
AbdullahSabry Sep 19, 2023
3e3c1b4
Fixed formatting
AbdullahSabry Sep 19, 2023
b9e4297
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 20, 2023
b67d948
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 24, 2023
6d298fd
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 24, 2023
5a29cfa
Merge branch 'unifyai:main' into tridiagonal_solve
AbdullahSabry Sep 25, 2023
91d74f4
Merge remote-tracking branch 'upstream/main' into pr/23279
NripeshN Dec 1, 2023
509d542
🤖 Lint code
ivy-branch Dec 1, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions ivy/functional/frontends/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,3 +458,52 @@ def tensorsolve(a, b, axes):
@to_ivy_arrays_and_back
def trace(x, name=None):
return ivy.trace(x, axis1=-2, axis2=-1)


@to_ivy_arrays_and_back
@with_supported_dtypes(
{
"2.13.0 and below": (
"float32",
"float64",
"complex64",
"complex128",
)
},
"tensorflow",
)
def tridiagonal_solve(
diagonals,
rhs,
diagonals_format="compact",
transpose_rhs=False,
conjugate_rhs=False,
name=None,
partial_pivoting=True,
perturb_singular=False,
):
if transpose_rhs is True:
rhs_copy = ivy.matrix_transpose(rhs)
if conjugate_rhs is True:
rhs_copy = ivy.conj(rhs)
if not transpose_rhs and not conjugate_rhs:
rhs_copy = ivy.array(rhs)

if diagonals_format == "matrix":
return ivy.solve(diagonals, rhs_copy)
elif diagonals_format in ["sequence", "compact"]:
diagonals = ivy.array(diagonals)
dim = diagonals[0].shape[0]
diagonals[[0, -1], [-1, 0]] = 0
dummy_idx = [0, 0]
indices = ivy.array([
[(i, i + 1) for i in range(dim - 1)] + [dummy_idx],
[(i, i) for i in range(dim)],
[dummy_idx] + [(i + 1, i) for i in range(dim - 1)],
])
constructed_matrix = ivy.scatter_nd(
indices, diagonals, shape=ivy.array([dim, dim])
)
return ivy.solve(constructed_matrix, rhs_copy)
else:
raise "Unexpected diagonals_format"
106 changes: 106 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,41 @@ def _get_second_matrix(draw):
)


@st.composite
def _get_tridiagonal_dtype_matrix_format(draw):
input_dtype_strategy = st.shared(
st.sampled_from(draw(helpers.get_dtypes("float_and_complex"))),
key="shared_dtype",
)
input_dtype = draw(input_dtype_strategy)
shared_size = draw(
st.shared(helpers.ints(min_value=2, max_value=4), key="shared_size")
)
diagonals_format = draw(st.sampled_from(["compact", "sequence", "matrix"]))
if diagonals_format == "matrix":
matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=tuple([shared_size, shared_size]),
min_value=2,
max_value=5,
).filter(tridiagonal_matrix_filter)
)
elif diagonals_format in ["compact", "sequence"]:
matrix = draw(
helpers.array_values(
dtype=input_dtype,
shape=tuple([3, shared_size]),
min_value=2,
max_value=5,
).filter(tridiagonal_compact_filter)
)
if diagonals_format == "sequence":
matrix = list(matrix)

return input_dtype, matrix, diagonals_format


# --- Main --- #
# ------------ #

Expand Down Expand Up @@ -1207,3 +1242,74 @@ def test_tensorflow_trace(
fn_tree=fn_tree,
x=x[0],
)


# tridiagonal_solve
@handle_frontend_test(
fn_tree="tensorflow.linalg.tridiagonal_solve",
x=_get_tridiagonal_dtype_matrix_format(),
y=_get_second_matrix(),
transpose_rhs=st.just(False),
conjugate_rhs=st.booleans(),
)
def test_tensorflow_tridiagonal_solve(
*,
x,
y,
transpose_rhs,
conjugate_rhs,
frontend,
backend_fw,
test_flags,
fn_tree,
on_device,
):
input_dtype1, x1, diagonals_format = x
input_dtype2, x2 = y
helpers.test_frontend_function(
input_dtypes=[input_dtype1, input_dtype2],
backend_to_test=backend_fw,
frontend=frontend,
test_flags=test_flags,
fn_tree=fn_tree,
on_device=on_device,
rtol=1e-3,
atol=1e-3,
diagonals=x1,
rhs=x2,
diagonals_format=diagonals_format,
transpose_rhs=transpose_rhs,
conjugate_rhs=conjugate_rhs,
)


def tridiagonal_compact_filter(x):
diagonals = ivy.array(x)
dim = diagonals[0].shape[0]
diagonals[[0, -1], [-1, 0]] = 0
dummy_idx = [0, 0]
indices = ivy.array([
[(i, i + 1) for i in range(dim - 1)] + [dummy_idx],
[(i, i) for i in range(dim)],
[dummy_idx] + [(i + 1, i) for i in range(dim - 1)],
])
matrix = ivy.scatter_nd(
indices, diagonals, ivy.array([dim, dim]), reduction="replace"
)
return tridiagonal_matrix_filter(matrix)


def tridiagonal_matrix_filter(x):
dim = x.shape[0]
if ivy.abs(ivy.det(x)) < 1e-3:
return False
for i in range(dim):
for j in range(dim):
cell = x[i][j]
if i == j or i == j - 1 or i == j + 1:
if cell == 0:
return False
else:
if cell != 0:
return False
return True
Loading