Use convert backend for pose formation
eigenvivek committed Dec 6, 2023
1 parent 4dda5a2 commit eff3126
Showing 4 changed files with 92 additions and 51 deletions.
23 changes: 12 additions & 11 deletions diffpose/deepfluoro.py
@@ -307,26 +307,27 @@ def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)):

 # %% ../notebooks/api/00_deepfluoro.ipynb 26
 from beartype import beartype
-from pytorchse3.se3 import se3_exp_map

-from .calibration import RigidTransform
+from .calibration import RigidTransform, convert


 @beartype
 def get_random_offset(batch_size: int, device) -> RigidTransform:
-    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
-    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
-    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
     r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))
     r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
     r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))
-    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)
-    T = se3_exp_map(logmap)
-    R = T[..., :3, :3].transpose(-1, -2)
-    t = T[..., 3, :3]
-    return RigidTransform(R, t)
+    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))
+    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))
+    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))
+    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)
+    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)
+    return convert(
+        [log_R_vee, log_t_vee],
+        "se3_log_map",
+        "se3_exp_map",
+    )

-# %% ../notebooks/api/00_deepfluoro.ipynb 32
+# %% ../notebooks/api/00_deepfluoro.ipynb 33
 from torchvision.transforms import Compose, Lambda, Normalize, Resize


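Both exported modules now form the pose with a single `convert([log_R_vee, log_t_vee], "se3_log_map", "se3_exp_map")` call instead of calling `se3_exp_map` and slicing the resulting 4x4 matrix by hand. As a reference for the convention, here is a minimal sketch of the math such a call is assumed to wrap, exponentiating se(3) log coordinates into a rotation and a translation. The helpers `hat` and `se3_exp` are hypothetical names written for this note, not part of diffpose; `torch.matrix_exp` is the only library call used.

```python
import torch


def hat(v):
    """Map a batch of so(3) vectors (B, 3) to skew-symmetric matrices (B, 3, 3)."""
    zero = torch.zeros_like(v[..., 0])
    return torch.stack(
        [
            torch.stack([zero, -v[..., 2], v[..., 1]], dim=-1),
            torch.stack([v[..., 2], zero, -v[..., 0]], dim=-1),
            torch.stack([-v[..., 1], v[..., 0], zero], dim=-1),
        ],
        dim=-2,
    )


def se3_exp(log_R_vee, log_t_vee):
    """Exponentiate se(3) log coordinates, two (B, 3) tensors, into (R, t)."""
    B = log_R_vee.shape[0]
    twist = torch.zeros(B, 4, 4, dtype=log_R_vee.dtype, device=log_R_vee.device)
    twist[:, :3, :3] = hat(log_R_vee)  # rotational part of the twist
    twist[:, :3, 3] = log_t_vee        # translational part of the twist
    T = torch.matrix_exp(twist)        # batched matrix exponential of the 4x4 twist
    return T[:, :3, :3], T[:, :3, 3]   # rotation matrices and translations


# Sample log coordinates roughly like get_random_offset above, then exponentiate.
log_R_vee = torch.randn(4, 3) * 0.2
log_t_vee = torch.randn(4, 3) * 50.0
R, t = se3_exp(log_R_vee, log_t_vee)
print(R.shape, t.shape)  # torch.Size([4, 3, 3]) torch.Size([4, 3])
```

Compared with the deleted lines, the new formation keeps the rotational and translational log components as two (batch_size, 3) tensors rather than one (batch_size, 6) vector.
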
31 changes: 16 additions & 15 deletions diffpose/ljubljana.py
@@ -116,35 +116,36 @@ def __getitem__(self, idx):

 # %% ../notebooks/api/01_ljubljana.ipynb 7
 from beartype import beartype
-from pytorchse3.se3 import se3_exp_map

-from .calibration import RigidTransform
+from .calibration import RigidTransform, convert


 @beartype
 def get_random_offset(view, batch_size: int, device) -> RigidTransform:
     if view == "ap":
-        t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))
+        t1 = torch.distributions.Normal(-6, 30).sample((batch_size,))
         t2 = torch.distributions.Normal(175, 30).sample((batch_size,))
-        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))
-        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
-        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))
+        r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
+        r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))
         r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))
     elif view == "lat":
-        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))
-        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))
-        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))
-        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
+        t1 = torch.distributions.Normal(75, 30).sample((batch_size,))
+        t2 = torch.distributions.Normal(-80, 30).sample((batch_size,))
+        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))
+        r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))
         r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))
         r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))
     else:
         raise ValueError(f"view must be 'ap' or 'lat', not '{view}'")

-    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)
-    T = se3_exp_map(logmap)
-    R = T[..., :3, :3].transpose(-1, -2)
-    t = T[..., 3, :3]
-    return RigidTransform(R, t)
+    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)
+    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)
+    return convert(
+        [log_R_vee, log_t_vee],
+        "se3_log_map",
+        "se3_exp_map",
+    )

 # %% ../notebooks/api/01_ljubljana.ipynb 9
 from torch.nn.functional import pad
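
For context, a usage sketch of the updated ljubljana sampler. It assumes diffpose at this commit is importable and that `RigidTransform.transform_points` accepts a batch of points shaped like the one produced by the einsum later in this diff; the point tensor here is illustrative.

```python
import torch
from diffpose.ljubljana import get_random_offset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Draw a batch of random pose offsets for each C-arm view;
# the AP and lateral views use different Normal parameters, as in the diff above.
offset_ap = get_random_offset("ap", batch_size=4, device=device)
offset_lat = get_random_offset("lat", batch_size=4, device=device)

# Both calls return a RigidTransform, so the offsets can be composed with other
# poses or applied to 3D points (illustrative (batch, n_points, 3) shape).
points = torch.rand(4, 16, 3, device=device)
moved = offset_ap.transform_points(points)
print(moved.shape)
```

Any other view string raises the ValueError shown in the diff.
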
54 changes: 45 additions & 9 deletions notebooks/api/00_deepfluoro.ipynb
@@ -713,24 +713,60 @@
 "source": [
 "#| export\n",
 "from beartype import beartype\n",
-"from pytorchse3.se3 import se3_exp_map\n",
 "\n",
-"from diffpose.calibration import RigidTransform\n",
+"from diffpose.calibration import RigidTransform, convert\n",
 "\n",
 "\n",
 "@beartype\n",
 "def get_random_offset(batch_size: int, device) -> RigidTransform:\n",
-"    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
-"    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
-"    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
 "    r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,))\n",
 "    r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
 "    r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,))\n",
+"    t1 = torch.distributions.Normal(10, 70).sample((batch_size,))\n",
+"    t2 = torch.distributions.Normal(250, 90).sample((batch_size,))\n",
+"    t3 = torch.distributions.Normal(5, 50).sample((batch_size,))\n",
+"    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)\n",
+"    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)\n",
+"    return convert(\n",
+"        [log_R_vee, log_t_vee],\n",
+"        \"se3_log_map\",\n",
+"        \"se3_exp_map\",\n",
+"    )"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"id": "e14ac65e-3c57-4c67-be3a-1d7184f669b3",
+"metadata": {},
+"outputs": [],
+"source": [
+"@beartype\n",
+"def get_random_offset(view, batch_size: int, device) -> RigidTransform:\n",
+"    if view == \"ap\":\n",
+"        t1 = torch.distributions.Normal(-6, 20).sample((batch_size,))\n",
+"        t2 = torch.distributions.Normal(175, 30).sample((batch_size,))\n",
+"        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))\n",
+"        r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+"        r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+"        r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))\n",
+"    elif view == \"lat\":\n",
+"        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))\n",
+"        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))\n",
+"        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))\n",
+"        r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))\n",
+"        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+"        r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n",
+"    else:\n",
+"        raise ValueError(f\"view must be 'ap' or 'lat', not '{view}'\")\n",
+"\n",
 "    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)\n",
-"    T = se3_exp_map(logmap)\n",
-"    R = T[..., :3, :3].transpose(-1, -2)\n",
-"    t = T[..., 3, :3]\n",
-"    return RigidTransform(R, t)"
+"    T = convert(\n",
+"        [logmap[..., :3], logmap[..., 3:]],\n",
+"        \"se3_log_map\",\n",
+"        \"se3_exp_map\",\n",
+"    )\n",
+"    return T"
 ]
 },
 {
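
The scratch cell kept in this notebook still builds the stacked six-vector `logmap` and slices it when calling `convert`, whereas the exported modules stack the rotational and translational parts separately. The standalone check below (torch only) confirms that both formulations hand `convert` the same tensors.

```python
import torch

batch_size = 4
r1, r2, r3 = (torch.randn(batch_size) for _ in range(3))
t1, t2, t3 = (torch.randn(batch_size) for _ in range(3))

# Style used in the exported modules: stack rotation and translation parts separately.
log_R_vee = torch.stack([r1, r2, r3], dim=1)
log_t_vee = torch.stack([t1, t2, t3], dim=1)

# Style kept in this notebook cell: one 6-vector log map, then slice it.
logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1)

assert torch.equal(logmap[..., :3], log_R_vee)
assert torch.equal(logmap[..., 3:], log_t_vee)
```
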
35 changes: 19 additions & 16 deletions notebooks/api/01_ljubljana.ipynb
@@ -187,35 +187,36 @@
 "source": [
 "#| export\n",
 "from beartype import beartype\n",
-"from pytorchse3.se3 import se3_exp_map\n",
 "\n",
-"from diffpose.calibration import RigidTransform\n",
+"from diffpose.calibration import RigidTransform, convert\n",
 "\n",
 "\n",
 "@beartype\n",
 "def get_random_offset(view, batch_size: int, device) -> RigidTransform:\n",
 "    if view == \"ap\":\n",
-"        t1 = torch.distributions.Normal(-10, 20).sample((batch_size,))\n",
+"        t1 = torch.distributions.Normal(-6, 30).sample((batch_size,))\n",
 "        t2 = torch.distributions.Normal(175, 30).sample((batch_size,))\n",
-"        t3 = torch.distributions.Normal(-5, 15).sample((batch_size,))\n",
-"        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
-"        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+"        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))\n",
+"        r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
+"        r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,))\n",
 "        r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,))\n",
 "    elif view == \"lat\":\n",
-"        t1 = torch.distributions.Normal(75, 15).sample((batch_size,))\n",
-"        t2 = torch.distributions.Normal(-80, 20).sample((batch_size,))\n",
-"        t3 = torch.distributions.Normal(-5, 10).sample((batch_size,))\n",
-"        r1 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
+"        t1 = torch.distributions.Normal(75, 30).sample((batch_size,))\n",
+"        t2 = torch.distributions.Normal(-80, 30).sample((batch_size,))\n",
+"        t3 = torch.distributions.Normal(-5, 30).sample((batch_size,))\n",
+"        r1 = torch.distributions.Normal(0.0, 0.1).sample((batch_size,))\n",
 "        r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,))\n",
 "        r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,))\n",
 "    else:\n",
 "        raise ValueError(f\"view must be 'ap' or 'lat', not '{view}'\")\n",
 "\n",
-"    logmap = torch.stack([r1, r2, r3, t1, t2, t3], dim=1).to(device)\n",
-"    T = se3_exp_map(logmap)\n",
-"    R = T[..., :3, :3].transpose(-1, -2)\n",
-"    t = T[..., 3, :3]\n",
-"    return RigidTransform(R, t)"
+"    log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device)\n",
+"    log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device)\n",
+"    return convert(\n",
+"        [log_R_vee, log_t_vee],\n",
+"        \"se3_log_map\",\n",
+"        \"se3_exp_map\",\n",
+"    )"
 ]
 },
 {
@@ -268,7 +268,9 @@
 "            self.intrinsic.inverse(),\n",
 "            pad(x, (0, 1), value=1),  # Convert to homogenous coordinates\n",
 "        )\n",
-"        extrinsic = self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)\n",
+"        extrinsic = (\n",
+"            self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose)\n",
+"        )\n",
 "        return extrinsic.transform_points(x)\n",
 "\n",
 "    def __call__(self, pose):\n",
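
A side note on the `pad(x, (0, 1), value=1)` call that appears unchanged in the last hunk: with a pad spec of `(0, 1)`, `torch.nn.functional.pad` appends one element to the last dimension, which is how 3D points are lifted to homogeneous coordinates before the inverse of the intrinsic matrix is applied. A minimal standalone check:

```python
import torch
from torch.nn.functional import pad

# Two 3D points; pad the last dimension with a single trailing 1
# to obtain homogeneous coordinates, as in the projection code above.
x = torch.tensor([[1.0, 2.0, 3.0],
                  [4.0, 5.0, 6.0]])
x_h = pad(x, (0, 1), value=1)
print(x_h)
# tensor([[1., 2., 3., 1.],
#         [4., 5., 6., 1.]])
```
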
