In [1]:
%load_ext autoreload
%autoreload 2
%env ANYWIDGET_HMR=1

env: ANYWIDGET_HMR=1


[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bobleesj/quantem.widget/blob/main/notebooks/mark2d/mark2d_all_features.ipynb)

# Mark2D — All Features

Comprehensive demo of every Mark2D capability:
basic atom picking, custom scale/dot size, image replacement, coordinate retrieval,
lattice basis definition, PyTorch tensor input, gallery mode, snap-to-peak,
and state save/load for reproducible analysis.

## 1. Basic HAADF-STEM atom picking

Hexagonal lattice simulating a [110] zone axis. Click on bright atom columns to select positions.

In [2]:
import numpy as np
from quantem.widget import Mark2D


def make_haadf_stem(size=256, spacing=18, sigma=2.8):
    """HAADF-STEM image with atomic columns on a hexagonal lattice."""
    y, x = np.mgrid[:size, :size]
    img = np.random.normal(0.08, 0.015, (size, size))
    a1 = np.array([spacing, 0.0])
    a2 = np.array([spacing * 0.5, spacing * np.sqrt(3) / 2])
    for i in range(-1, size // spacing + 2):
        for j in range(-1, size // spacing + 2):
            cx = i * a1[0] + j * a2[0]
            cy = i * a1[1] + j * a2[1]
            if -spacing < cx < size + spacing and -spacing < cy < size + spacing:
                intensity = 0.7 + 0.3 * ((i + j) % 3 == 0)
                img += intensity * np.exp(-((x - cx)**2 + (y - cy)**2) / (2 * sigma**2))
    scan_noise = np.random.normal(0, 0.01, (size, 1)) * np.ones((1, size))
    img += scan_noise
    return np.clip(img, 0, None).astype(np.float32)


haadf = make_haadf_stem()
w1 = Mark2D(haadf, max_points=3)
w1

Mark2D(256×256, pts=0)

## 2. Custom scale, dot size, max points

Zoomed-in view with larger markers and more allowed selections.

In [3]:
w2 = Mark2D(haadf, scale=2.0, dot_size=18, max_points=10)
w2

Mark2D(256×256, pts=0)

## 3. Replace image with `set_image()`

Switch between two different zone axes without creating a new widget.
The cubic [001] pattern has a simple square lattice, while the hexagonal
pattern above has alternating column intensities.

In [4]:
def make_cubic_stem(size=256, spacing=20, sigma=2.5):
    """HAADF-STEM of cubic [001] zone axis."""
    y, x = np.mgrid[:size, :size]
    img = np.random.normal(0.08, 0.015, (size, size))
    for i in range(-1, size // spacing + 2):
        for j in range(-1, size // spacing + 2):
            cx = i * spacing
            cy = j * spacing
            if -spacing < cx < size + spacing and -spacing < cy < size + spacing:
                img += 0.8 * np.exp(-((x - cx)**2 + (y - cy)**2) / (2 * sigma**2))
    scan_noise = np.random.normal(0, 0.01, (size, 1)) * np.ones((1, size))
    img += scan_noise
    return np.clip(img, 0, None).astype(np.float32)


cubic = make_cubic_stem()
w3 = Mark2D(haadf, scale=1.0, max_points=5)
w3

Mark2D(256×256, pts=0)

In [5]:
# Replace the hexagonal image with the cubic [001] zone axis
w3.set_image(cubic)
print("Image replaced: now showing cubic [001] zone axis")

Image replaced: now showing cubic [001] zone axis


## 4. Inspect widget state

Use `summary()` to see a detailed breakdown of all widgets — image info, placed points, ROIs, display settings.

In [6]:
for name, widget in [("Hexagonal", w1), ("Zoomed", w2), ("Cubic", w3)]:
    print(f"--- {name} ---")
    widget.summary()
    print()

--- Hexagonal ---
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.008141  max=1.124  mean=0.1899  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/3
Marker:   circle red  size=12px

--- Zoomed ---
Mark2D
════════════════════════════════
Image:    256×256  scale=2.0x
Data:     min=0.008141  max=1.124  mean=0.1899  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/10
Marker:   circle red  size=18px

--- Cubic ---
Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0.005092  max=0.934  mean=0.1553  dtype=float32
Display:  gray | auto contrast | linear
Points:   0/5
Marker:   circle red  size=12px



## 5. Define lattice basis from 3 points

Pick 3 atom columns on `w1` above: an origin and two nearest neighbors.
Then run this cell to compute lattice vectors **u** and **v**, plus the
angle between them.

In [7]:
points = w1.selected_points
if len(points) < 3:
    print("Click 3 atom columns on w1 above, then re-run this cell.")
else:
    origin = np.array([points[0]["row"], points[0]["col"]])
    p1 = np.array([points[1]["row"], points[1]["col"]])
    p2 = np.array([points[2]["row"], points[2]["col"]])
    u = p1 - origin
    v = p2 - origin
    angle = np.degrees(np.arccos(
        np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    ))
    print(f"Origin: (row={origin[0]:.1f}, col={origin[1]:.1f})")
    print(f"u = ({u[0]:.1f}, {u[1]:.1f}), |u| = {np.linalg.norm(u):.1f} px")
    print(f"v = ({v[0]:.1f}, {v[1]:.1f}), |v| = {np.linalg.norm(v):.1f} px")
    print(f"Angle(u, v) = {angle:.1f} degrees")
    print(f"\nExpected for hexagonal: |u| ~ |v| ~ 18 px, angle ~ 60 degrees")

Click 3 atom columns on w1 above, then re-run this cell.


## 6. PyTorch tensor input

Mark2D accepts both NumPy arrays and PyTorch tensors.

In [8]:
import torch

haadf_tensor = torch.from_numpy(haadf)
print(f"Tensor shape: {haadf_tensor.shape}, dtype: {haadf_tensor.dtype}")

w4 = Mark2D(haadf_tensor, scale=1.5, dot_size=14, max_points=5)
w4

Tensor shape: torch.Size([256, 256]), dtype: torch.float32


Mark2D(256×256, pts=0)

## 7. Gallery mode — pick points across multiple images

Pass a list of images to pick points on each independently.
Click an unselected image to select it. Only the selected image allows point placement.

In [9]:
# Gallery with 3 different crystal structures
hexagonal = make_haadf_stem(size=128, spacing=18)
cubic = make_cubic_stem(size=128, spacing=20)

# Ring pattern (simulated amorphous diffraction)
yy, xx = np.mgrid[:128, :128]
r = np.sqrt((xx - 64)**2 + (yy - 64)**2)
ring = (np.exp(-(r - 40)**2 / 20) + 0.5 * np.exp(-(r - 20)**2 / 10)).astype(np.float32)

w5 = Mark2D(
    [hexagonal, cubic, ring],
    ncols=3,
    max_points=5,
    labels=["Hex [110]", "Cubic [001]", "Ring"],
)
w5

Mark2D(3×128×128, idx=0, pts=0)

In [10]:
w5.summary()

Mark2D
════════════════════════════════
Image:    3×128×128 (3 cols)
Data:     min=0  max=1.11  mean=0.1713  dtype=float32
Display:  gray | auto contrast | linear
Points [Hex [110]]: 0/5
Points [Cubic [001]]: 0/5
Points [Ring]: 0/5
Marker:   circle red  size=12px


## 8. Gallery with torch tensors

In [11]:
# Gallery with torch tensors
t1 = torch.from_numpy(hexagonal)
t2 = torch.from_numpy(cubic)
w6 = Mark2D([t1, t2], ncols=2, max_points=4, labels=["Hex (torch)", "Cubic (torch)"])
w6

Mark2D(2×128×128, idx=0, pts=0)

In [12]:
w6.summary()

Mark2D
════════════════════════════════
Image:    2×128×128 (2 cols)
Data:     min=0.004098  max=1.11  mean=0.1854  dtype=float32
Display:  gray | auto contrast | linear
Points [Hex (torch)]: 0/4
Points [Cubic (torch)]: 0/4
Marker:   circle red  size=12px


## 9. Snap-to-peak on a sharp diffraction pattern

Snap-to-peak finds the nearest local intensity maximum within a search radius.
This is most useful on images with sharp, well-separated peaks — like electron
diffraction patterns with Bragg spots.

**Try it:** Click anywhere *near* a Bragg spot. With snap enabled (green),
your point jumps to the exact peak center. Toggle snap off to see the
difference — points land exactly where you click instead.

In [13]:
def make_diffraction_pattern(size=256, spot_sigma=0.8):
    """Electron diffraction pattern with sharp Bragg spots on a hexagonal reciprocal lattice."""
    img = np.random.normal(0.02, 0.005, (size, size))
    cx, cy = size // 2, size // 2
    y, x = np.mgrid[:size, :size]

    # Hexagonal reciprocal lattice
    a = 28  # spot spacing (px)
    g1 = np.array([a, 0.0])
    g2 = np.array([a * 0.5, a * np.sqrt(3) / 2])

    for i in range(-6, 7):
        for j in range(-6, 7):
            sx = cx + i * g1[0] + j * g2[0]
            sy = cy + i * g1[1] + j * g2[1]
            if 0 <= sx < size and 0 <= sy < size:
                dist = np.sqrt((sx - cx) ** 2 + (sy - cy) ** 2)
                # Intensity envelope: central beam bright, outer spots dimmer
                intensity = np.exp(-dist**2 / (2 * (3 * a) ** 2))
                if i == 0 and j == 0:
                    intensity = 1.0
                img += intensity * np.exp(
                    -((x - sx) ** 2 + (y - sy) ** 2) / (2 * spot_sigma**2)
                )
    return np.clip(img, 0, None).astype(np.float32)


diffraction = make_diffraction_pattern()

# Snap enabled with 8px search radius — clicks jump to the nearest Bragg spot
w7 = Mark2D(
    diffraction,
    snap_enabled=True,
    snap_radius=8,
    max_points=10,
    dot_size=8,
    colormap="viridis",
    log_scale=True,
)
w7

Mark2D(256×256, pts=0, cmap=viridis, log, snap r=8)

In [14]:
w7.summary()

Mark2D
════════════════════════════════
Image:    256×256
Data:     min=0  max=1.014  mean=0.02314  dtype=float32
Display:  viridis | auto contrast | log
Points:   0/10
Marker:   circle red  size=8px
Snap:     ON (radius=8 px)


### Side-by-side: snap on vs snap off

Same diffraction pattern, same settings — only snap differs.
Click near the same Bragg spot in both images to compare precision.

In [15]:
# Gallery: snap OFF (left) vs snap ON (right)
w8 = Mark2D(
    [diffraction, diffraction],
    ncols=2,
    max_points=5,
    dot_size=8,
    colormap="viridis",
    log_scale=True,
    snap_enabled=True,
    snap_radius=8,
    labels=["Snap OFF (toggle it off)", "Snap ON (default)"],
)
w8

Mark2D(2×256×256, idx=0, pts=0, cmap=viridis, log, snap r=8)

## 10. Save and load state

All widget state — points, ROIs, profile lines, display settings — can be saved
to a JSON file with `save()` and restored with the `state` parameter. This lets
you resume analysis after a kernel restart or share exact results with a colleague.

In [16]:
# Create a widget with pre-placed points and custom settings
w9 = Mark2D(
    haadf,
    points=[(36, 36), (54, 36), (45, 52)],
    snap_enabled=True,
    snap_radius=8,
    colormap="viridis",
    marker_shape="diamond",
    marker_color="#00bcd4",
    title="HAADF analysis",
    pixel_size_angstrom=1.5,
)
w9.add_roi(128, 128, mode="circle", radius=30)
w9

HAADF analysis(256×256, px=1.50 Å, pts=3, rois=1, cmap=viridis, snap r=8)

In [17]:
# Save all state to a JSON file
w9.save("haadf_analysis.json")
print("Saved to haadf_analysis.json")

Saved to haadf_analysis.json


In [18]:
# Restore from file — same image, all state comes back
w10 = Mark2D(haadf, state="haadf_analysis.json")
print(f"Restored: {len(w10.selected_points)} points, {len(w10.roi_list)} ROIs")
print(f"Colormap: {w10.colormap}, title: {w10.title}")
w10

Restored: 3 points, 1 ROIs
Colormap: viridis, title: HAADF analysis


HAADF analysis(256×256, px=1.50 Å, pts=3, rois=1, cmap=viridis, snap r=8)

In [19]:
# The JSON file is small and human-readable — great for version control
import json
from pathlib import Path

state = json.loads(Path("haadf_analysis.json").read_text())
print(json.dumps(state, indent=2))

{
  "selected_points": [
    {
      "row": 36,
      "col": 36,
      "shape": "circle",
      "color": "#f44336"
    },
    {
      "row": 54,
      "col": 36,
      "shape": "triangle",
      "color": "#4caf50"
    },
    {
      "row": 45,
      "col": 52,
      "shape": "square",
      "color": "#2196f3"
    }
  ],
  "roi_list": [
    {
      "id": 0,
      "mode": "circle",
      "row": 128,
      "col": 128,
      "radius": 30,
      "rectW": 60,
      "rectH": 40,
      "color": "#0f0",
      "opacity": 0.8
    }
  ],
  "profile_line": [],
  "selected_idx": 0,
  "marker_shape": "diamond",
  "marker_color": "#00bcd4",
  "dot_size": 12,
  "max_points": 10,
  "marker_border": 2,
  "marker_opacity": 1.0,
  "label_size": 0,
  "label_color": "",
  "snap_enabled": true,
  "snap_radius": 8,
  "colormap": "viridis",
  "auto_contrast": true,
  "log_scale": false,
  "show_fft": false,
  "show_stats": true,
  "show_controls": true,
  "percentile_low": 2.0,
  "percentile_high": 98.0,
  "tit

In [20]:
# Clean up the saved file
p = Path("haadf_analysis.json")
if p.exists():
    p.unlink()