From eff3126e1ccd2297d50ae0db025e35d88117ef4d Mon Sep 17 00:00:00 2001
From: eigenvivek
Date: Tue, 5 Dec 2023 20:35:50 -0500
Subject: [PATCH] Use convert backend for pose formation

---
 diffpose/deepfluoro.py            | 23 ++++++-------
 diffpose/ljubljana.py             | 31 +++++++++---------
 notebooks/api/00_deepfluoro.ipynb | 54 +++++++++++++++++++++++++------
 notebooks/api/01_ljubljana.ipynb  | 35 +++++++++++---------
 4 files changed, 92 insertions(+), 51 deletions(-)

diff --git a/diffpose/deepfluoro.py b/diffpose/deepfluoro.py
index 32c091f..9e3c6fa 100644
--- a/diffpose/deepfluoro.py
+++ b/diffpose/deepfluoro.py
@@ -307,26 +307,27 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):
 # %% ../notebooks/api/00_deepfluoro.ipynb 26
 from beartype import beartype
-from pytorchse3.se3 import se3_exp_map
 
-from .calibration import RigidTransform
+from .calibration import RigidTransform, convert
 
 
 @beartype
 def get_random_offset(batch_size: int, device) -> RigidTransform:
-    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
-    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
-    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
     r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
     r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
     r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
-    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)
-    T = se3_exp_map(logmap)
-    R = T[..., :3, :3].transpose(-1, -2)
-    t = T[..., 3, :3]
-    return RigidTransform(R, t)
+    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
+    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
+    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
+    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)
+    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)
+    return convert(
+        [log_R_vee, log_t_vee],
+        "se3_log_map",
+        "se3_exp_map",
+    )
 
 
-# %% ../notebooks/api/00_deepfluoro.ipynb 32
+# %% ../notebooks/api/00_deepfluoro.ipynb 33
 from torchvision.transforms import Compose, Lambda, Normalize, Resize
 
 
diff --git a/diffpose/ljubljana.py b/diffpose/ljubljana.py
index d10084c..d973136 100644
--- a/diffpose/ljubljana.py
+++ b/diffpose/ljubljana.py
@@ -116,35 +116,36 @@ def __getitem__(self, idx):
 # %% ../notebooks/api/01_ljubljana.ipynb 7
 from beartype import beartype
-from pytorchse3.se3 import se3_exp_map
 
-from .calibration import RigidTransform
+from .calibration import RigidTransform, convert
 
 
 @beartype
 def get_random_offset(view, batch_size: int, device) -> RigidTransform:
     if view == "ap":
-        t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))
+        t1 = torch.distributions.Normal(-6, 30).sample((batch_size,))
         t2 = torch.distributions.Normal(175, 30).sample((batch_size,))
-        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))
-        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
-        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))
+        r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
+        r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
         r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))
     elif view == "lat":
-        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))
-        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))
-        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))
-        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        t1 = torch.distributions.Normal(75, 30).sample((batch_size,))
+        t2 = torch.distributions.Normal(-80, 30).sample((batch_size,))
+        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))
+        r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))
         r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
         r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))
     else:
         raise ValueError(f"view must be 'ap' or 'lat', not '{view}'")
 
-    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)
-    T = se3_exp_map(logmap)
-    R = T[..., :3, :3].transpose(-1, -2)
-    t = T[..., 3, :3]
-    return RigidTransform(R, t)
+    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)
+    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)
+    return convert(
+        [log_R_vee, log_t_vee],
+        "se3_log_map",
+        "se3_exp_map",
+    )
 
 
 # %% ../notebooks/api/01_ljubljana.ipynb 9
 from torch.nn.functional import pad
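Both Python modules now form the random pose offset with convert([log_R_vee, log_t_vee], "se3_log_map", "se3_exp_map") instead of calling pytorchse3.se3.se3_exp_map and unpacking the 4x4 matrix by hand. diffpose/calibration.py is not touched by this patch, so the snippet below is only a minimal sketch of what that conversion branch could look like; the pytorch3d backend and the helper name se3_log_to_rigid are assumptions, not the library's confirmed implementation.

import torch
from pytorch3d.transforms import se3_exp_map  # assumed backend; not shown in this patch

from diffpose.calibration import RigidTransform


def se3_log_to_rigid(log_R_vee: torch.Tensor, log_t_vee: torch.Tensor) -> RigidTransform:
    # pytorch3d expects [log_translation | log_rotation] and returns row-major
    # (4, 4) matrices, so the rotation block is transposed back to the usual
    # column-vector convention, mirroring the removed pytorchse3 code above.
    log_transform = torch.cat([log_t_vee, log_R_vee], dim=1)
    T = se3_exp_map(log_transform)
    R = T[..., :3, :3].transpose(-1, -2)
    t = T[..., 3, :3]
    return RigidTransform(R, t)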
diff --git a/notebooks/api/00_deepfluoro.ipynb b/notebooks/api/00_deepfluoro.ipynb
index f1f5e87..d577c76 100644
--- a/notebooks/api/00_deepfluoro.ipynb
+++ b/notebooks/api/00_deepfluoro.ipynb
@@ -713,24 +713,60 @@
    "source": [
     "#| export\n",
     "from beartype import beartype\n",
-    "from pytorchse3.se3 import se3_exp_map\n",
     "\n",
-    "from diffpose.calibration import RigidTransform\n",
+    "from diffpose.calibration import RigidTransform, convert\n",
     "\n",
     "\n",
     "@beartype\n",
     "def get_random_offset(batch_size: int, device) -> RigidTransform:\n",
-    "    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
-    "    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
-    "    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
     "    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))\n",
     "    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
     "    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))\n",
+    "    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
+    "    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
+    "    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
+    "    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)\n",
+    "    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)\n",
+    "    return convert(\n",
+    "        [log_R_vee, log_t_vee],\n",
+    "        \"se3_log_map\",\n",
+    "        \"se3_exp_map\",\n",
+    "    )"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e14ac65e-3c57-4c67-be3a-1d7184f669b3",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@beartype\n",
+    "def get_random_offset(view, batch_size: int, device) -> RigidTransform:\n",
+    "    if view == \"ap\":\n",
+    "        t1 = torch.distributions.Normal(-6, 20).sample((batch_size,))\n",
+    "        t2 = torch.distributions.Normal(175, 30).sample((batch_size,))\n",
+    "        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))\n",
+    "        r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+    "        r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+    "        r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))\n",
+    "    elif view == \"lat\":\n",
+    "        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))\n",
+    "        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))\n",
+    "        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))\n",
+    "        r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))\n",
+    "        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n",
+    "    else:\n",
+    "        raise ValueError(f\"view must be 'ap' or 'lat', not '{view}'\")\n",
+    "\n",
     "    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)\n",
-    "    T = se3_exp_map(logmap)\n",
-    "    R = T[..., :3, :3].transpose(-1, -2)\n",
-    "    t = T[..., 3, :3]\n",
-    "    return RigidTransform(R, t)"
+    "    T = convert(\n",
+    "        [logmap[..., :3], logmap[..., 3:]],\n",
+    "        \"se3_log_map\",\n",
+    "        \"se3_exp_map\",\n",
+    "    )\n",
+    "    return T"
    ]
   },
   {
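The notebook keeps the same public signature for get_random_offset, so callers are unaffected by the switch to the convert backend. A minimal usage sketch, assuming only what this patch shows (the batch size and device below are arbitrary):

import torch
from diffpose.deepfluoro import get_random_offset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Draw a batch of random pose offsets: rotations (r1-r3) and translations
# (t1-t3) are sampled from the Normal distributions above and mapped through
# the SE(3) exponential by convert, yielding a batched RigidTransform.
offset = get_random_offset(batch_size=4, device=device)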
diff --git a/notebooks/api/01_ljubljana.ipynb b/notebooks/api/01_ljubljana.ipynb
index 7bd9130..48eee1a 100644
--- a/notebooks/api/01_ljubljana.ipynb
+++ b/notebooks/api/01_ljubljana.ipynb
@@ -187,35 +187,36 @@
    "source": [
     "#| export\n",
     "from beartype import beartype\n",
-    "from pytorchse3.se3 import se3_exp_map\n",
     "\n",
-    "from diffpose.calibration import RigidTransform\n",
+    "from diffpose.calibration import RigidTransform, convert\n",
     "\n",
     "\n",
     "@beartype\n",
     "def get_random_offset(view, batch_size: int, device) -> RigidTransform:\n",
     "    if view == \"ap\":\n",
-    "        t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))\n",
+    "        t1 = torch.distributions.Normal(-6, 30).sample((batch_size,))\n",
     "        t2 = torch.distributions.Normal(175, 30).sample((batch_size,))\n",
-    "        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))\n",
-    "        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
-    "        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))\n",
+    "        r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+    "        r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
     "        r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))\n",
     "    elif view == \"lat\":\n",
-    "        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))\n",
-    "        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))\n",
-    "        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))\n",
-    "        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+    "        t1 = torch.distributions.Normal(75, 30).sample((batch_size,))\n",
+    "        t2 = torch.distributions.Normal(-80, 30).sample((batch_size,))\n",
+    "        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))\n",
+    "        r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))\n",
     "        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
     "        r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n",
     "    else:\n",
     "        raise ValueError(f\"view must be 'ap' or 'lat', not '{view}'\")\n",
     "\n",
-    "    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)\n",
-    "    T = se3_exp_map(logmap)\n",
-    "    R = T[..., :3, :3].transpose(-1, -2)\n",
-    "    t = T[..., 3, :3]\n",
-    "    return RigidTransform(R, t)"
+    "    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)\n",
+    "    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)\n",
+    "    return convert(\n",
+    "        [log_R_vee, log_t_vee],\n",
+    "        \"se3_log_map\",\n",
+    "        \"se3_exp_map\",\n",
+    "    )"
    ]
   },
   {
@@ -268,7 +269,9 @@
     "            self.intrinsic.inverse(),\n",
     "            pad(x, (0, 1), value=1),  # Convert to homogenous coordinates\n",
     "        )\n",
-    "        extrinsic = self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)\n",
+    "        extrinsic = (\n",
+    "            self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)\n",
+    "        )\n",
     "        return extrinsic.transform_points(x)\n",
     "\n",
     "    def __call__(self, pose):\n",
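Because the old pytorchse3 path is deleted everywhere in this patch, a regression check against it is the most direct way to confirm that the new backend reproduces the same poses. The check below is a sketch under stated assumptions: it presumes both parameterizations put the rotation components first, that RigidTransform.transform_points accepts a (batch, n_points, 3) tensor (the patch only shows it called on einsum output, so the exact shape is not confirmed), and an arbitrary tolerance.

import torch
from pytorchse3.se3 import se3_exp_map  # the backend removed by this patch
from diffpose.calibration import RigidTransform, convert

log_R_vee = 0.1 * torch.randn(8, 3)   # small random rotations (radians)
log_t_vee = 50.0 * torch.randn(8, 3)  # random translations (mm)

# Old path: 6-vector log map -> 4x4 matrix -> (R, t), as in the removed code.
T = se3_exp_map(torch.cat([log_R_vee, log_t_vee], dim=1))
old = RigidTransform(T[..., :3, :3].transpose(-1, -2), T[..., 3, :3])

# New path: convert the split (rotation, translation) log map directly.
new = convert([log_R_vee, log_t_vee], "se3_log_map", "se3_exp_map")

x = torch.randn(8, 10, 3)  # assumed point shape
assert torch.allclose(old.transform_points(x), new.transform_points(x), atol=1e-4)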