From 4497f65d1f9876d613eed690f8d94b4ef5f8a2fd Mon Sep 17 00:00:00 2001
From: guanhuaw
Date: Sun, 24 Mar 2024 19:51:22 -0700
Subject: [PATCH] jitting

---
 mirtorch/alg/fbpd.py          |  8 ++---
 mirtorch/linear/linearmaps.py | 68 ++++++++++++++++++-----------------
 mirtorch/linear/mri.py        |  8 ++---
 3 files changed, 44 insertions(+), 40 deletions(-)

diff --git a/mirtorch/alg/fbpd.py b/mirtorch/alg/fbpd.py
index e3affa0..342d065 100644
--- a/mirtorch/alg/fbpd.py
+++ b/mirtorch/alg/fbpd.py
@@ -42,10 +42,10 @@ def __init__(
         h_prox: Prox,
         g_L: float,
         G_norm: float,
-        G: LinearMap = None,
-        tau: float = None,
+        G: LinearMap | None = None,
+        tau: float | None = None,
         max_iter: int = 10,
-        eval_func: Callable = None,
+        eval_func: Callable | None = None,
         p: int = 1,
     ):
         self.max_iter = max_iter
@@ -87,7 +87,7 @@ def run(self, x0: torch.Tensor):
             if self.eval_func is not None:
                 saved.append(self.eval_func(xold))
                 logger.info(
-                    "The cost function at %dth iter in FBPD: %10.3e." % (i, saved[-1])
+                    "The cost function at %dth iter in FBPD: %10.3e.", i, saved[-1]
                 )
         if self.eval_func is not None:
             return xold, saved
diff --git a/mirtorch/linear/linearmaps.py b/mirtorch/linear/linearmaps.py
index 289bf3a..c042e0e 100644
--- a/mirtorch/linear/linearmaps.py
+++ b/mirtorch/linear/linearmaps.py
@@ -1,4 +1,6 @@
-from typing import Sequence, TypeVar, Union
+from __future__ import annotations
+
+from typing import List, Union
 
 import numpy as np
 import torch
@@ -14,9 +16,6 @@ def check_device(x, y):
     assert x.device == y.device, "Tensors should be on the same device"
 
 
-T = TypeVar("T", bound="LinearMap")
-
-
 class LinearMap:
     r"""
     Abstraction of linear operators as matrices :math:`y = A*x`.
@@ -52,7 +51,7 @@ def backward(ctx, grad_data_in):
         size_out: the size of the output of the linear map (a list)
     """
 
-    def __init__(self, size_in: Sequence[int], size_out: Sequence[int]):
+    def __init__(self, size_in: List[int], size_out: List[int]):
         r"""
         Initiate the linear operator.
         """
@@ -60,8 +59,8 @@ def __init__(self, size_in: Sequence[int], size_out: Sequence[int]):
         self.size_out = list(size_out)
 
     def __repr__(self):
-        return "".format(
-            repr_str=self.__class__.__name__, oshape=self.size_out, ishape=self.size_in
+        return (
+            f""
         )
 
     def __call__(self, x) -> Tensor:
@@ -96,19 +95,21 @@ def adjoint(self, x) -> Tensor:
         return self._apply_adjoint(x)
 
     @property
-    def H(self) -> T:
+    def H(self) -> LinearMap:
         r"""
         Apply the (Hermitian) transpose
         """
         return ConjTranspose(self)
 
-    def __add__(self: T, other: T) -> T:
+    def __add__(self: LinearMap, other: LinearMap) -> LinearMap:
         r"""
         Reload the + symbol.
         """
         return Add(self, other)
 
-    def __mul__(self: T, other) -> T:
+    def __mul__(
+        self: LinearMap, other: Union[str, int, LinearMap, Tensor]
+    ) -> LinearMap:
         r"""
         Reload the * symbol.
         """
@@ -116,49 +117,52 @@ def __mul__(self: T, other) -> T:
         if np.isscalar(other):
             return Multiply(self, other)
         elif isinstance(other, LinearMap):
             return Matmul(self, other)
-        elif isinstance(other, torch.Tensor):
+        elif isinstance(other, Tensor):
             if not other.shape:
-                # raise ValueError(
-                #     "Input tensor has empty shape. If want to scale the linear map, please use the standard scalar")
                 return Multiply(self, other)
             return self.apply(other)
         else:
             raise NotImplementedError(
-                f"Only scalers, Linearmaps or Tensors, rather than '{type(other)}' are allowed as arguments for this function."
+                (
+                    f"Only scalars, LinearMaps or Tensors, rather than '{type(other)}' "
+                    "are allowed as arguments for this function."
+                )
             )
 
-    def __rmul__(self: T, other) -> T:
+    def __rmul__(
+        self: LinearMap, other: Union[str, int, LinearMap, Tensor]
+    ) -> LinearMap:
         r"""
         Reload the * symbol.
         """
         if np.isscalar(other):
             return Multiply(self, other)
-        elif isinstance(other, torch.Tensor) and not other.shape:
+        elif isinstance(other, Tensor) and not other.shape:
             return Multiply(self, other)
         else:
             return NotImplemented
 
-    def __sub__(self: T, other: T) -> T:
+    def __sub__(self: LinearMap, other: LinearMap) -> LinearMap:
         r"""
         Reload the - symbol.
         """
         return self.__add__(-other)
 
-    def __neg__(self: T) -> T:
+    def __neg__(self: LinearMap) -> LinearMap:
         r"""
         Reload the - symbol.
         """
        return -1 * self
 
-    def to(self: T, *args, **kwargs):
+    def to(self: LinearMap, device: Union[torch.device, str]) -> LinearMap:
         r"""
         Copy to different devices
         """
         for prop in self.__dict__.keys():
-            if isinstance(self.__dict__[prop], torch.Tensor) or isinstance(
+            if isinstance(self.__dict__[prop], Tensor) or isinstance(
                 self.__dict__[prop], torch.nn.Module
             ):
-                self.__dict__[prop] = self.__dict__[prop].to(*args, **kwargs)
+                self.__dict__[prop] = self.__dict__[prop].to(device)
 
 
 class Add(LinearMap):
@@ -184,10 +188,10 @@ def __init__(self, A: LinearMap, B: LinearMap):
         self.B = B
         super().__init__(self.A.size_in, self.B.size_out)
 
-    def _apply(self: T, x: Tensor) -> Tensor:
+    def _apply(self: LinearMap, x: Tensor) -> Tensor:
         return self.A(x) + self.B(x)
 
-    def _apply_adjoint(self: T, x: Tensor) -> Tensor:
+    def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
         return self.A.H(x) + self.B.H(x)
 
 
@@ -208,11 +212,11 @@ def __init__(self, A: LinearMap, a: FloatLike):
         self.A = A
         super().__init__(self.A.size_in, self.A.size_out)
 
-    def _apply(self: T, x: Tensor) -> Tensor:
+    def _apply(self: LinearMap, x: Tensor) -> Tensor:
         ax = x * self.a
         return self.A(ax)
 
-    def _apply_adjoint(self: T, x: Tensor) -> Tensor:
+    def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
         ax = x * self.a
         return self.A.H(ax)
 
@@ -232,11 +236,11 @@ def __init__(self, A: LinearMap, B: LinearMap):
         assert list(self.B.size_out) == list(self.A.size_in), "Shapes do not match"
         super().__init__(self.B.size_in, self.A.size_out)
 
-    def _apply(self: T, x: Tensor) -> Tensor:
+    def _apply(self: LinearMap, x: Tensor) -> Tensor:
         # TODO: add gram operator
         return self.A(self.B(x))
 
-    def _apply_adjoint(self: T, x: Tensor) -> Tensor:
+    def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
         return self.B.H(self.A.H(x))
 
 
@@ -249,10 +253,10 @@ def __init__(self, A: LinearMap):
         self.A = A
         super().__init__(A.size_out, A.size_in)
 
-    def _apply(self: T, x: Tensor) -> Tensor:
+    def _apply(self: LinearMap, x: Tensor) -> Tensor:
         return self.A.adjoint(x)
 
-    def _apply_adjoint(self: T, x: Tensor) -> Tensor:
+    def _apply_adjoint(self: LinearMap, x: Tensor) -> Tensor:
         return self.A.apply(x)
 
 
@@ -265,7 +269,7 @@ class BlockDiagonal(LinearMap):
         A : List of 2D linear maps
     """
 
-    def __init__(self, A: Sequence[LinearMap]):
+    def __init__(self, A: List[LinearMap]):
         self.A = A
 
         # dimension checks
@@ -280,7 +284,7 @@ def __init__(self, A: Sequence[LinearMap]):
         size_out = list(A[0].size_out) + [nz]
         super().__init__(tuple(size_in), tuple(size_out))
 
-    def _apply(self: T, x: Tensor) -> Tensor:
+    def _apply(self: LinearMap, x: Tensor) -> Tensor:
         out = torch.zeros(
             self.size_out, dtype=x.dtype, device=x.device, layout=x.layout
         )
@@ -291,7 +295,7 @@ def _apply(self: T, x: Tensor) -> Tensor:
             out[..., k] = self.A[k].apply(x[..., k])
         return out
 
-    def _apply_adjoint(self: T, x: Tensor):
+    def _apply_adjoint(self: LinearMap, x: Tensor):
         out = torch.zeros(self.size_in, dtype=x.dtype, device=x.device, layout=x.layout)
         nz = self.size_in[-1]
diff --git a/mirtorch/linear/mri.py b/mirtorch/linear/mri.py
index 2793e4c..53ecfdb 100644
--- a/mirtorch/linear/mri.py
+++ b/mirtorch/linear/mri.py
@@ -165,7 +165,7 @@ def __init__(
         traj: Tensor,
         norm="ortho",
         batchmode=True,
-        numpoints: Union[int, Sequence[int]] = 6,
+        numpoints: Union[int, List[int]] = 6,
         grid_size: float = 2,
         sequential: bool = False,
     ):
@@ -328,7 +328,7 @@ def __init__(
         traj: Tensor,
         norm="ortho",
         batchmode=True,
-        numpoints: Union[int, Sequence[int]] = 6,
+        numpoints: Union[int, List[int]] = 6,
         grid_size: float = 2,
     ):
         self.smaps = smaps
@@ -433,7 +433,7 @@ def __init__(
         L: int = 6,
         nbins: int = 20,
         dt: int = 4e-3,
-        numpoints: Union[int, Sequence[int]] = 6,
+        numpoints: Union[int, List[int]] = 6,
         grid_size: float = 2,
         T: Tensor = None,
     ):
@@ -560,7 +560,7 @@ def __init__(
         L: int = 6,
         nbins: int = 20,
         dt: int = 4e-3,
-        numpoints: Union[int, Sequence[int]] = 6,
+        numpoints: Union[int, List[int]] = 6,
         grid_size: float = 2,
         T: Tensor = None,
     ):