# DQN Rotation Alignment Agent (Kornia + PyTorch)

This notebook implements a DQN agent that learns to align a source image to a target image via rotation, or declare a mismatch when the images are fundamentally different (for example mirrored). The environment uses kornia for differentiable GPU-accelerated rotation and keeps all rotation ops on the active device.


In [5]:
import math
import random
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import kornia as K
import gymnasium as gym

try:
    import imageio.v2 as imageio
except Exception:
    import imageio

In [6]:
def get_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        name = "cuda"
    elif torch.backends.mps.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)


def set_seed(seed: int = 0) -> None:
    """Seed Python's `random`, NumPy and PyTorch (plus every CUDA device if present)."""
    # Same order as before: stdlib RNG, NumPy RNG, then the torch generator.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def build_state(source: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Stack source and target along the channel axis into one batched state.

    3-D inputs gain a leading batch axis. A single target is broadcast
    (as an expanded view) across a multi-image source batch before the
    two are concatenated on dim 1.
    """
    src = source.unsqueeze(0) if source.dim() == 3 else source
    tgt = target.unsqueeze(0) if target.dim() == 3 else target

    batch = src.shape[0]
    if batch > 1 and tgt.shape[0] == 1:
        tgt = tgt.expand(batch, -1, -1, -1)

    return torch.cat((src, tgt), dim=1)


def rotate_tensor(img: torch.Tensor, angle_deg) -> torch.Tensor:
    """Rotate one image or a batch of images by `angle_deg` degrees via kornia.

    `angle_deg` may be a plain number or a tensor. It is moved to the
    image's device and dtype, and a single angle is repeated for every
    image in a batch. A 4-D input is returned as a batch; any other input
    has the leading batch axis squeezed off the result.
    """
    batch = img.unsqueeze(0) if img.dim() == 3 else img

    if torch.is_tensor(angle_deg):
        angles = angle_deg.to(device=batch.device, dtype=batch.dtype)
    else:
        angles = torch.tensor([angle_deg], device=batch.device, dtype=batch.dtype)

    # kornia is given a 1-D tensor holding one angle per image.
    if angles.dim() == 0:
        angles = angles.unsqueeze(0)
    if angles.numel() == 1 and batch.shape[0] > 1:
        angles = angles.repeat(batch.shape[0])

    out = K.geometry.transform.rotate(batch, angles)
    if img.dim() == 4:
        return out
    return out.squeeze(0)


# Pick the compute device once; later cells reuse this module-level `device`.
device = get_device()
print("Device:", device)


Device: mps


In [7]:
class RotationEnv(gym.Env):
    """Episodic environment for aligning a source image to a target by rotation.

    Observation: (6, H, W) = [source(3), target(3)] in [0, 1], torch tensor on device.
    Actions:
        0: rotate CW  (delta = -5 deg)
        1: rotate CCW (delta = +5 deg)
        2: commit match
        3: commit mismatch

    Rewards: each rotation costs `step_penalty`; a commit ends the episode
    with +100 when it is correct and -100 otherwise. "Commit match" is
    correct only if the pair is not mirrored and the mean squared error
    between the current source and the target is below `epsilon`.
    """

    metadata = {"render_modes": []}

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        max_steps: int = 50,
        epsilon: float = 0.01,
        step_penalty: float = -0.1,
        device: Optional[torch.device] = None,
        pair_sampler: Optional[Callable[[torch.device], Tuple[torch.Tensor, torch.Tensor, bool]]] = None,
    ):
        """Configure the environment.

        Args:
            image_shape: (3, H, W) shape of a single image.
            max_steps: number of steps after which the episode is truncated.
            epsilon: MSE threshold below which source and target count as aligned.
            step_penalty: reward given for every rotation action.
            device: device that holds the images; defaults to `get_device()`.
            pair_sampler: optional callable returning (source, target, is_mirrored),
                used by `reset()` when no pair is passed through `options`.
        """
        super().__init__()
        self.image_shape = image_shape  # (3, H, W)
        self.max_steps = max_steps
        self.epsilon = epsilon
        self.step_penalty = step_penalty
        self.device = device or get_device()
        self.pair_sampler = pair_sampler

        _, h, w = image_shape
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=(6, h, w),
            dtype=np.float32,
        )
        self.action_space = gym.spaces.Discrete(4)

        # Episode state; populated by reset() through _set_pair().
        self.base_source = None
        self.target = None
        self.current_source = None
        self.current_angle = 0.0
        self.is_mirrored = False
        self.step_count = 0

    def _set_pair(self, source: torch.Tensor, target: torch.Tensor, is_mirrored: bool) -> None:
        """Install a new image pair and reset the per-episode counters."""
        self.base_source = source.to(self.device).float().clamp(0, 1)
        self.target = target.to(self.device).float().clamp(0, 1)
        self.is_mirrored = bool(is_mirrored)
        self.current_angle = 0.0
        self.current_source = self.base_source
        self.step_count = 0

    def _obs(self) -> torch.Tensor:
        """Return the current observation: source and target stacked on channels."""
        return build_state(self.current_source, self.target)[0]

    def _alignment_error(self) -> torch.Tensor:
        """Return the mean squared error between current source and target."""
        return torch.mean((self.current_source - self.target) ** 2)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        """Start a new episode and return `(observation, info)`.

        The image pair comes from, in priority order: `options["pair"]`
        (a `(source, target, is_mirrored)` tuple), `options["source"]` plus
        `options["target"]` (with optional `options["is_mirrored"]`), or
        the configured `pair_sampler`.

        Raises:
            ValueError: if none of those sources is available.
        """
        super().reset(seed=seed)
        options = options or {}

        if "pair" in options:
            source, target, is_mirrored = options["pair"]
        elif "source" in options and "target" in options:
            source = options["source"]
            target = options["target"]
            is_mirrored = bool(options.get("is_mirrored", False))
        elif self.pair_sampler is not None:
            source, target, is_mirrored = self.pair_sampler(self.device)
        else:
            raise ValueError("No pair provided. Pass options={pair: (...)} or set pair_sampler.")

        self._set_pair(source, target, is_mirrored)
        info = {"angle": self.current_angle}
        return self._obs(), info

    def step(self, action: int):
        """Apply `action` and return `(obs, reward, terminated, truncated, info)`.

        Raises:
            RuntimeError: if called before `reset()` has installed a pair.
            ValueError: if `action` is not one of 0, 1, 2, 3.
        """
        if self.base_source is None or self.target is None:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError("reset() must be called before step().")

        terminated = False
        reward = 0.0

        if action == 0 or action == 1:
            delta = -5.0 if action == 0 else 5.0
            self.current_angle = (self.current_angle + delta) % 360.0
            # Always rotate the untouched base image by the accumulated angle,
            # so interpolation error does not build up across steps.
            self.current_source = rotate_tensor(self.base_source, self.current_angle)
            reward = self.step_penalty
        elif action == 2 or action == 3:
            terminated = True
        else:
            raise ValueError(f"Invalid action: {action}")

        # Computed once per step and shared by the match check and `info`.
        err = self._alignment_error()

        if action == 2:
            is_match = (not self.is_mirrored) and bool(err < self.epsilon)
            reward = 100.0 if is_match else -100.0
        elif action == 3:
            reward = 100.0 if self.is_mirrored else -100.0

        self.step_count += 1
        truncated = (self.step_count >= self.max_steps) and (not terminated)

        info = {
            "angle": self.current_angle,
            "error": float(err.detach().cpu()),
            "is_mirrored": self.is_mirrored,
        }
        return self._obs(), float(reward), terminated, truncated, info


In [8]:
class QNetwork(nn.Module):
    """Convolutional Q-network mapping an image state to one Q-value per action.

    Three stride-2 conv + BatchNorm + ReLU stages feed a two-layer MLP head.
    """

    def __init__(self, input_shape: Tuple[int, int, int], num_actions: int = 4):
        """Build the network.

        Args:
            input_shape: (C, H, W) shape of a single state.
            num_actions: number of Q-values produced per state.
        """
        super().__init__()
        c, h, w = input_shape
        self.features = nn.Sequential(
            nn.Conv2d(c, 32, kernel_size=5, stride=2, padding=2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )

        # Probe the flattened feature size with a dummy input. The probe runs
        # in eval mode so the all-zero batch does not update the BatchNorm
        # running statistics of the freshly initialised network.
        self.features.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, c, h, w)
            conv_out = self.features(dummy).flatten(1).shape[1]
        self.features.train()

        self.head = nn.Sequential(
            nn.Linear(conv_out, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, num_actions),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Q-values of shape (batch, num_actions) for a batch of states."""
        x = self.features(x)
        # flatten() also handles non-contiguous feature maps, unlike view().
        return self.head(x.flatten(1))


In [9]:
class ReplayBuffer:
    """Fixed-capacity ring buffer of (state, action, reward, next_state, done).

    Storage lives in preallocated CPU tensors. With `store_uint8=True`,
    states are quantised to 8 bits on push and rescaled to [0, 1] floats
    on sample, which cuts state memory to a quarter of float32.
    """

    def __init__(self, capacity: int, state_shape: Tuple[int, int, int], store_uint8: bool = True):
        """Preallocate storage for `capacity` transitions of shape `state_shape`."""
        self.capacity = capacity
        self.state_shape = state_shape
        self.store_uint8 = store_uint8

        state_dtype = torch.uint8 if store_uint8 else torch.float32
        self.states = torch.empty((capacity, *state_shape), dtype=state_dtype)
        self.next_states = torch.empty((capacity, *state_shape), dtype=state_dtype)
        self.actions = torch.empty((capacity,), dtype=torch.int64)
        self.rewards = torch.empty((capacity,), dtype=torch.float32)
        self.dones = torch.empty((capacity,), dtype=torch.uint8)

        self.idx = 0   # next write position
        self.size = 0  # number of valid transitions stored

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return self.size

    def _encode(self, state: torch.Tensor) -> torch.Tensor:
        """Convert a state to the buffer's CPU storage representation."""
        cpu_state = state.detach().cpu()
        if self.store_uint8:
            # Round to nearest rather than truncate: truncation biases every
            # pixel downwards and doubles the worst-case quantisation error.
            return (cpu_state.clamp(0, 1) * 255).round().to(torch.uint8)
        return cpu_state.float()

    def push(self, state: torch.Tensor, action: int, reward: float, next_state: torch.Tensor, done: bool) -> None:
        """Store one transition, overwriting the oldest entry once full."""
        self.states[self.idx] = self._encode(state)
        self.next_states[self.idx] = self._encode(next_state)
        self.actions[self.idx] = int(action)
        self.rewards[self.idx] = float(reward)
        self.dones[self.idx] = int(done)

        self.idx = (self.idx + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, device: torch.device):
        """Draw `batch_size` transitions uniformly, with replacement, onto `device`.

        Returns:
            (states, actions, rewards, next_states, dones); states are floats
            and `dones` is a float tensor of 0.0 / 1.0.

        Raises:
            RuntimeError: if the buffer is empty.
        """
        if self.size == 0:
            raise RuntimeError("Cannot sample from an empty replay buffer.")

        idx = torch.randint(0, self.size, (batch_size,))
        states = self.states[idx].to(device)
        next_states = self.next_states[idx].to(device)
        if self.store_uint8:
            states = states.float() / 255.0
            next_states = next_states.float() / 255.0
        actions = self.actions[idx].to(device)
        rewards = self.rewards[idx].to(device)
        dones = self.dones[idx].to(device).float()
        return states, actions, rewards, next_states, dones


In [10]:
def random_target_image(h: int, w: int, num_shapes: int = 4) -> torch.Tensor:
    """Draw `num_shapes` random solid-colour rectangles on a black (3, h, w) image.

    Colours come from torch's RNG and geometry from Python's `random`.
    Later rectangles overwrite earlier ones where they overlap. Rectangle
    bounds are clamped so that small images (below 12 px on a side) are
    handled instead of raising from `random.randint`; for larger images the
    random draws are the same as without clamping.
    """
    img = torch.zeros(3, h, w)

    # Largest and smallest rectangle extents; the usual minimum of 6 px is
    # lowered when half the image side is smaller than that.
    max_ext_h = max(1, h // 2)
    max_ext_w = max(1, w // 2)
    min_ext_h = min(6, max_ext_h)
    min_ext_w = min(6, max_ext_w)

    for _ in range(num_shapes):
        color = torch.rand(3, 1, 1)
        y0 = random.randint(0, max(0, h - 8))
        x0 = random.randint(0, max(0, w - 8))
        y1 = min(h, y0 + random.randint(min_ext_h, max_ext_h))
        x1 = min(w, x0 + random.randint(min_ext_w, max_ext_w))
        img[:, y0:y1, x0:x1] = color
    return img


def make_pair_sampler(
    image_shape: Tuple[int, int, int],
    mirror_prob: float = 0.3,
    angle_step: float = 5.0,
) -> Callable[[torch.device], Tuple[torch.Tensor, torch.Tensor, bool]]:
    """Build a sampler that produces random (source, target, is_mirrored) pairs.

    Each call of the returned sampler draws a random target image, flips it
    horizontally with probability `mirror_prob`, and rotates the result by
    a random multiple of `angle_step` degrees in [0, 360) to get the source.

    Args:
        image_shape: (C, H, W); only H and W are used.
        mirror_prob: probability that the source is a mirrored target.
        angle_step: spacing in degrees of the candidate rotation angles.

    Raises:
        ValueError: if `angle_step` is not positive.
    """
    if angle_step <= 0:
        # Fail here rather than later with an obscure error inside the sampler.
        raise ValueError("angle_step must be positive.")

    _, h, w = image_shape
    # The candidate angles never change, so build the grid once.
    angles = np.arange(0.0, 360.0, angle_step)

    def sampler(device: torch.device):
        target = random_target_image(h, w)
        if device is not None:
            # Honour the requested device so the rotation below runs there
            # rather than always on the CPU.
            target = target.to(device)
        is_mirrored = random.random() < mirror_prob
        source = torch.flip(target, dims=[2]) if is_mirrored else target.clone()
        angle = float(random.choice(angles))
        source = rotate_tensor(source, angle)
        return source, target, is_mirrored

    return sampler


In [11]:
def select_action(qnet: nn.Module, state: torch.Tensor, epsilon: float, action_space) -> int:
    """Pick an action epsilon-greedily.

    With probability `epsilon` a random action is drawn from `action_space`;
    otherwise the action with the highest Q-value for `state` is returned.
    A 3-D state is treated as a single sample and given a batch axis.

    The greedy forward pass runs with `qnet` in eval mode. In training mode
    a one-sample batch would make normalisation layers such as BatchNorm
    use that sample's own statistics and update their running averages on
    every acting step. The previous mode is restored before returning.
    """
    if random.random() < epsilon:
        return int(action_space.sample())

    was_training = qnet.training
    qnet.eval()
    try:
        with torch.no_grad():
            if state.dim() == 3:
                state = state.unsqueeze(0)
            q = qnet(state)
            return int(q.argmax(dim=1).item())
    finally:
        qnet.train(was_training)


def train_dqn(
    env: RotationEnv,
    qnet: QNetwork,
    target_net: QNetwork,
    buffer: ReplayBuffer,
    episodes: int = 2000,
    batch_size: int = 64,
    gamma: float = 0.99,
    eps_start: float = 1.0,
    eps_end: float = 0.05,
    eps_decay_episodes: int = 1000,
    target_update_steps: int = 1000,
    lr: float = 1e-4,
    max_steps: int = 50,
):
    """Train `qnet` with DQN and return the list of per-episode total rewards.

    Args:
        env: environment to interact with.
        qnet: online network being optimised (left in training mode).
        target_net: target network, synced from `qnet` every
            `target_update_steps` environment steps.
        buffer: replay buffer that transitions are pushed to and sampled from.
        episodes: number of episodes to run.
        batch_size: minibatch size; learning starts once the buffer holds
            at least this many transitions.
        gamma: discount factor.
        eps_start, eps_end: exploration rate at the start and end of decay.
        eps_decay_episodes: episodes over which epsilon decays linearly.
        target_update_steps: environment steps between target-network syncs.
        lr: AdamW learning rate.
        max_steps: upper bound on steps per episode in this loop.
    """
    device = next(qnet.parameters()).device
    optimizer = torch.optim.AdamW(qnet.parameters(), lr=lr)
    huber = nn.SmoothL1Loss()

    qnet.train()
    target_net.eval()

    # Guard the schedule against division by zero when decay is disabled.
    decay_episodes = max(1, eps_decay_episodes)

    global_step = 0
    rewards_history = []

    for ep in range(1, episodes + 1):
        frac_remaining = max(0.0, (decay_episodes - ep) / decay_episodes)
        epsilon = eps_end + (eps_start - eps_end) * frac_remaining
        state, _ = env.reset()
        ep_reward = 0.0

        for _ in range(max_steps):
            action = select_action(qnet, state, epsilon, env.action_space)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            # Store only true termination as the bootstrap mask. A time-limit
            # truncation ends the rollout, but the next state is not terminal,
            # so its value must still be bootstrapped in the TD target.
            buffer.push(state, action, reward, next_state, terminated)
            state = next_state
            ep_reward += reward
            global_step += 1

            if buffer.size >= batch_size:
                states, actions, rewards, next_states, terminals = buffer.sample(batch_size, device)
                q_values = qnet(states).gather(1, actions.unsqueeze(1)).squeeze(1)
                with torch.no_grad():
                    next_q = target_net(next_states).max(dim=1)[0]
                    target = rewards + gamma * (1.0 - terminals) * next_q
                loss = huber(q_values, target)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            if global_step % target_update_steps == 0:
                target_net.load_state_dict(qnet.state_dict())

            if done:
                break

        rewards_history.append(ep_reward)
        if ep % 50 == 0:
            print(f"Episode {ep:4d} | eps {epsilon:.3f} | reward {ep_reward:.1f}")

    return rewards_history


In [17]:
# Example: create the environment, the online/target networks and the replay
# buffer, then train.
set_seed(0)

H, W = 64, 64
image_shape = (3, H, W)

sampler = make_pair_sampler(image_shape=image_shape, mirror_prob=0.3, angle_step=5.0)

env = RotationEnv(
    image_shape=image_shape,
    max_steps=50,
    epsilon=0.01,
    step_penalty=-0.1,
    device=device,
    pair_sampler=sampler,
)

# States have 6 channels: source (3) stacked with target (3).
qnet = QNetwork(input_shape=(6, H, W)).to(device)
target_net = QNetwork(input_shape=(6, H, W)).to(device)
# Start the target network as an exact copy of the online network.
target_net.load_state_dict(qnet.state_dict())

buffer = ReplayBuffer(capacity=10_000, state_shape=(6, H, W), store_uint8=True)

# Run training (comment this line out to skip it).
rewards = train_dqn(env, qnet, target_net, buffer, episodes=1000)


Episode   50 | eps 0.953 | reward -100.0
Episode  100 | eps 0.905 | reward -100.1
Episode  150 | eps 0.858 | reward -100.0
Episode  200 | eps 0.810 | reward -100.0
Episode  250 | eps 0.762 | reward -100.1
Episode  300 | eps 0.715 | reward -100.0
Episode  350 | eps 0.667 | reward -100.1
Episode  400 | eps 0.620 | reward -100.4
Episode  450 | eps 0.573 | reward 99.9
Episode  500 | eps 0.525 | reward -100.0
Episode  550 | eps 0.477 | reward 99.0
Episode  600 | eps 0.430 | reward -100.1
Episode  650 | eps 0.382 | reward 99.5
Episode  700 | eps 0.335 | reward -100.0
Episode  750 | eps 0.287 | reward -100.5
Episode  800 | eps 0.240 | reward -100.1
Episode  850 | eps 0.193 | reward -101.0
Episode  900 | eps 0.145 | reward -101.9
Episode  950 | eps 0.098 | reward -104.6
Episode 1000 | eps 0.050 | reward -103.4


In [19]:
def render_pair_frame(source: torch.Tensor, target: torch.Tensor) -> np.ndarray:
    """Render source (left) and target (right) as one (H, 2W, 3) uint8 frame."""

    def to_hwc_uint8(img: torch.Tensor) -> np.ndarray:
        # Clamp to [0, 1], go channels-last on the CPU, then scale to 0..255.
        hwc = img.clamp(0, 1).permute(1, 2, 0).detach().cpu().numpy()
        return (hwc * 255).astype(np.uint8)

    left = to_hwc_uint8(source)
    right = to_hwc_uint8(target)
    return np.concatenate((left, right), axis=1)


def rollout_and_save_gif(
    env: RotationEnv,
    qnet: QNetwork,
    out_path: str = "rotation_episode.gif",
    max_steps: int = 50,
    fps: int = 8,
):
    """Run one greedy episode and save it as an animated GIF.

    The first frame shows the initial pair and one further frame is recorded
    after every step, so the final state is always included, even when the
    loop stops at `max_steps` without the episode ending. The network is
    switched to eval mode for the rollout and its previous mode is restored
    afterwards.

    Args:
        env: environment to roll out in (it is reset first).
        qnet: Q-network used greedily (epsilon = 0).
        out_path: destination file for the GIF.
        max_steps: maximum number of steps to record.
        fps: frame rate passed to `imageio.mimsave`.

    Returns:
        `out_path`.
    """
    was_training = qnet.training
    # Evaluate with fixed normalisation statistics, and without updating them.
    qnet.eval()
    try:
        state, _ = env.reset()
        frames = [render_pair_frame(env.current_source, env.target)]

        for _ in range(max_steps):
            action = select_action(qnet, state, epsilon=0.0, action_space=env.action_space)
            state, _, terminated, truncated, _ = env.step(action)
            frames.append(render_pair_frame(env.current_source, env.target))
            if terminated or truncated:
                break
    finally:
        qnet.train(was_training)

    imageio.mimsave(out_path, frames, fps=fps)
    return out_path

# Generate a GIF of one greedy episode (run after training; comment out to skip).
gif_path = rollout_and_save_gif(env, qnet, out_path="rotation_episode.gif")
# Bare expression so the notebook displays the saved path.
gif_path


'rotation_episode.gif'