Commit: jitting

guanhuaw committed Mar 25, 2024
1 parent ce8adf8 commit 4497f65
Showing 3 changed files with 44 additions and 40 deletions.
8 changes: 4 additions & 4 deletions mirtorch/alg/fbpd.py
@@ -42,10 +42,10 @@ def __init__(
h_prox: Prox,
g_L: float,
G_norm: float,
G: LinearMap = None,
tau: float = None,
G: LinearMap | None = None,
tau: float | None = None,
max_iter: int = 10,
eval_func: Callable = None,
eval_func: Callable | None = None,
p: int = 1,
):
self.max_iter = max_iter
@@ -87,7 +87,7 @@ def run(self, x0: torch.Tensor):
if self.eval_func is not None:
saved.append(self.eval_func(xold))
logger.info(
"The cost function at %dth iter in FBPD: %10.3e." % (i, saved[-1])
"The cost function at %dth iter in FBPD: %10.3e.", i, saved[-1]
)
if self.eval_func is not None:
return xold, saved
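
The fbpd.py change above replaces eager %-formatting inside logger.info with the logging module's lazy argument style: the message is only interpolated when a record is actually emitted, which avoids wasted string work when INFO is disabled. A minimal standalone sketch of the two styles (demo code, not part of mirtorch):

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("fbpd_demo")

i, cost = 3, 1.234e-2  # stand-ins for the iteration index and the eval_func value

# Before: the string is built with % before logger.info ever sees it.
logger.info("The cost function at %dth iter in FBPD: %10.3e." % (i, cost))

# After: logging interpolates the arguments lazily, only if the record is emitted.
logger.info("The cost function at %dth iter in FBPD: %10.3e.", i, cost)
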
68 changes: 36 additions & 32 deletions mirtorch/linear/linearmaps.py
@@ -1,4 +1,6 @@
from typing import Sequence, TypeVar, Union
from __future__ import annotations

from typing import List, Union

import numpy as np
import torch
@@ -14,9 +16,6 @@ def check_device(x, y):
assert x.device == y.device, "Tensors should be on the same device"


T = TypeVar("T", bound="LinearMap")


class LinearMap:
r"""
Abstraction of linear operators as matrices :math:`y = A*x`.
@@ -52,16 +51,16 @@ def backward(ctx, grad_data_in):
size_out: the size of the output of the linear map (a list)
"""

def __init__(self, size_in: Sequence[int], size_out: Sequence[int]):
def __init__(self, size_in: List[int], size_out: List[int]):
r"""
Initiate the linear operator.
"""
self.size_in = list(size_in)
self.size_out = list(size_out)

def __repr__(self):
return "<LinearMap {repr_str} of {oshape}x{ishape}>".format(
repr_str=self.__class__.__name__, oshape=self.size_out, ishape=self.size_in
return (
f"<LinearMap {self.__class__.__name__} of {self.size_out}x{self.size_in}>"
)

def __call__(self, x) -> Tensor:
@@ -96,69 +95,74 @@ def adjoint(self, x) -> Tensor:
return self._apply_adjoint(x)

@property
def H(self) -> T:
def H(self) -> LinearMap:
r"""
Apply the (Hermitian) transpose
"""
return ConjTranspose(self)

def __add__(self: T, other: T) -> T:
def __add__(self: LinearMap, other: LinearMap) -> LinearMap:
r"""
Reload the + symbol.
"""
return Add(self, other)

def __mul__(self: T, other) -> T:
def __mul__(
self: LinearMap, other: Union[str, int, LinearMap, Tensor]
) -> LinearMap:
r"""
Reload the * symbol.
"""
if np.isscalar(other):
return Multiply(self, other)
elif isinstance(other, LinearMap):
return Matmul(self, other)
elif isinstance(other, torch.Tensor):
elif isinstance(other, Tensor):
if not other.shape:
# raise ValueError(
# "Input tensor has empty shape. If want to scale the linear map, please use the standard scalar")
return Multiply(self, other)
return self.apply(other)
else:
raise NotImplementedError(
f"Only scalers, Linearmaps or Tensors, rather than '{type(other)}' are allowed as arguments for this function."
(
f"Only scalers, Linearmaps or Tensors, rather than '{type(other)}' "
"fare allowed as arguments for this function."
)
)

def __rmul__(self: T, other) -> T:
def __rmul__(
self: LinearMap, other: Union[str, int, LinearMap, Tensor]
) -> LinearMap:
r"""
Reload the * symbol.
"""
if np.isscalar(other):
return Multiply(self, other)
elif isinstance(other, torch.Tensor) and not other.shape:
elif isinstance(other, Tensor) and not other.shape:
return Multiply(self, other)
else:
return NotImplemented

def __sub__(self: T, other: T) -> T:
def __sub__(self: LinearMap, other: LinearMap) -> LinearMap:
r"""
Reload the - symbol.
"""
return self.__add__(-other)

def __neg__(self: T) -> T:
def __neg__(self: LinearMap) -> LinearMap:
r"""
Reload the - symbol.
"""
return -1 * self

def to(self: T, *args, **kwargs):
def to(self: LinearMap, device: Union[torch.device, str]) -> LinearMap:
r"""
Copy to different devices
"""
for prop in self.__dict__.keys():
if isinstance(self.__dict__[prop], torch.Tensor) or isinstance(
if isinstance(self.__dict__[prop], Tensor) or isinstance(
self.__dict__[prop], torch.nn.Module
):
self.__dict__[prop] = self.__dict__[prop].to(*args, **kwargs)
self.__dict__[prop] = self.__dict__[prop].to(device)


class Add(LinearMap):
@@ -184,10 +188,10 @@ def __init__(self, A: LinearMap, B: LinearMap):
self.B = B
super().__init__(self.A.size_in, self.B.size_out)

def _apply(self: T, x: Tensor) -> Tensor:
def _apply(self: LinearMap, x: Tensor) -> Tensor:
return self.A(x) + self.B(x)

def _apply_adjoint(self: T, x: Tensor) -> Tensor:
def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
return self.A.H(x) + self.B.H(x)


@@ -208,11 +212,11 @@ def __init__(self, A: LinearMap, a: FloatLike):
self.A = A
super().__init__(self.A.size_in, self.A.size_out)

def _apply(self: T, x: Tensor) -> Tensor:
def _apply(self: LinearMap, x: Tensor) -> Tensor:
ax = x * self.a
return self.A(ax)

def _apply_adjoint(self: T, x: Tensor) -> Tensor:
def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
ax = x * self.a
return self.A.H(ax)

@@ -232,11 +236,11 @@ def __init__(self, A: LinearMap, B: LinearMap):
assert list(self.B.size_out) == list(self.A.size_in), "Shapes do not match"
super().__init__(self.B.size_in, self.A.size_out)

def _apply(self: T, x: Tensor) -> Tensor:
def _apply(self: LinearMap, x: Tensor) -> Tensor:
# TODO: add gram operator
return self.A(self.B(x))

def _apply_adjoint(self: T, x: Tensor) -> Tensor:
def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
return self.B.H(self.A.H(x))


@@ -249,10 +253,10 @@ def __init__(self, A: LinearMap):
self.A = A
super().__init__(A.size_out, A.size_in)

def _apply(self: T, x: Tensor) -> Tensor:
def _apply(self: LinearMap, x: Tensor) -> Tensor:
return self.A.adjoint(x)

def _apply_adjoint(self: T, x: Tensor) -> Tensor:
def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
return self.A.apply(x)


@@ -265,7 +269,7 @@ class BlockDiagonal(LinearMap):
A : List of 2D linear maps
"""

def __init__(self, A: Sequence[LinearMap]):
def __init__(self, A: List[LinearMap]):
self.A = A

# dimension checks
@@ -280,7 +284,7 @@ def __init__(self, A: Sequence[LinearMap]):
size_out = list(A[0].size_out) + [nz]
super().__init__(tuple(size_in), tuple(size_out))

def _apply(self: T, x: Tensor) -> Tensor:
def _apply(self: LinearMap, x: Tensor) -> Tensor:
out = torch.zeros(
self.size_out, dtype=x.dtype, device=x.device, layout=x.layout
)
@@ -291,7 +295,7 @@ def _apply(self: T, x: Tensor) -> Tensor:
out[..., k] = self.A[k].apply(x[..., k])
return out

def _apply_adjoint(self: T, x: Tensor):
def _apply_adjoint(self: LinearMap, x: Tensor):
out = torch.zeros(self.size_in, dtype=x.dtype, device=x.device, layout=x.layout)
nz = self.size_in[-1]

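
The headline change in linearmaps.py is the new "from __future__ import annotations": every annotation is then stored as a string and never evaluated at import time, so a class can name itself in its own method signatures and the old T = TypeVar("T", bound="LinearMap") alias can be dropped. A minimal self-contained sketch of that pattern (toy code, not the mirtorch implementation):

from __future__ import annotations  # postpone evaluation of all annotations

from typing import List


class LinearMap:
    """Toy stand-in showing self-referential annotations without a TypeVar."""

    def __init__(self, size_in: List[int], size_out: List[int]) -> None:
        self.size_in = list(size_in)
        self.size_out = list(size_out)

    @property
    def H(self) -> LinearMap:
        # "LinearMap" may appear directly in its own class body: the annotation
        # stays a string and is only resolved if something introspects it later.
        return ConjTranspose(self)


class ConjTranspose(LinearMap):
    def __init__(self, A: LinearMap) -> None:
        super().__init__(A.size_out, A.size_in)
        self.A = A
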
8 changes: 4 additions & 4 deletions mirtorch/linear/mri.py
@@ -165,7 +165,7 @@ def __init__(
traj: Tensor,
norm="ortho",
batchmode=True,
numpoints: Union[int, Sequence[int]] = 6,
numpoints: Union[int, List[int]] = 6,
grid_size: float = 2,
sequential: bool = False,
):
@@ -328,7 +328,7 @@ def __init__(
traj: Tensor,
norm="ortho",
batchmode=True,
numpoints: Union[int, Sequence[int]] = 6,
numpoints: Union[int, List[int]] = 6,
grid_size: float = 2,
):
self.smaps = smaps
@@ -433,7 +433,7 @@ def __init__(
L: int = 6,
nbins: int = 20,
dt: int = 4e-3,
numpoints: Union[int, Sequence[int]] = 6,
numpoints: Union[int, List[int]] = 6,
grid_size: float = 2,
T: Tensor = None,
):
@@ -560,7 +560,7 @@ def __init__(
L: int = 6,
nbins: int = 20,
dt: int = 4e-3,
numpoints: Union[int, Sequence[int]] = 6,
numpoints: Union[int, List[int]] = 6,
grid_size: float = 2,
T: Tensor = None,
):
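
The mri.py hunks only adjust annotations, changing numpoints from Union[int, Sequence[int]] to Union[int, List[int]]: the parameter presumably accepts either a single interpolation-kernel size used for every dimension or an explicit per-dimension list. A hedged sketch of how such a parameter is typically normalized (the helper below is hypothetical, not part of the mirtorch API):

from typing import List, Union


def normalize_numpoints(numpoints: Union[int, List[int]], ndim: int) -> List[int]:
    # Hypothetical helper: expand a scalar kernel size to one value per
    # dimension, or validate an explicit per-dimension list.
    if isinstance(numpoints, int):
        return [numpoints] * ndim
    if len(numpoints) != ndim:
        raise ValueError("numpoints needs one entry per trajectory dimension")
    return list(numpoints)


print(normalize_numpoints(6, 2))       # [6, 6]
print(normalize_numpoints([6, 4], 2))  # [6, 4]
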
