# CityScale Dataset Visualization

This notebook visualizes samples from the CityScale dataset, showing:
- Original satellite image
- Junction map overlaid with ground-truth edges
- Offset vectors as a quiver plot

In [None]:
import sys
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.collections import LineCollection
import torch

# Import the dataset
from data.dataset import CityScale

In [None]:
# Print the kernel's current working directory via a shell escape.
# NOTE(review): the relative data_root used when loading the dataset is presumably
# resolved against this directory -- confirm how CityScale treats data_root.
!pwd

In [None]:
# Instantiate the CityScale training split (aug=False) from a path relative to the notebook
data_root = '../../Sat2Graph/data/data/'  # Adjust this path as needed
dataset = CityScale(
    data_root=data_root,
    split='train',
    img_size=512,
    stride=32,
    aug=False,
)

# Quick sanity check that the dataset actually found samples
print(f"Dataset size: {len(dataset)} samples")

In [None]:
# Fetch one sample and unpack it into numpy arrays for plotting
sample_idx = 0  # Change this to visualize different samples
sample = dataset[sample_idx]

# Graph structure and metadata are used exactly as the dataset returns them
meta = sample['meta']
edges = sample['edges']

# Image: move the channel axis last so matplotlib can display it
image = sample['image'].permute(1, 2, 0).numpy()  # [H, W, 3]

# Grid-resolution maps; the single-channel ones drop their leading axis
offset_map = sample['offset_map'].numpy()  # [2, h, w]
junction_map, offset_mask = (
    sample[key][0].numpy() for key in ('junction_map', 'offset_mask')
)  # each [h, w]

# NOTE(review): the junction count below is the sum of the map, which equals the
# number of junction cells only if junction_map is binary -- confirm.
summary_lines = [
    f"Region ID: {meta['region_id']}",
    f"Image shape: {image.shape}",
    f"Junction map shape: {junction_map.shape}",
    f"Number of junctions: {int(junction_map.sum())}",
    f"Number of edges: {len(edges)}",
]
print("\n".join(summary_lines))



In [None]:
# Helper function to convert grid coordinates to pixel coordinates with offsets
def grid_to_pixel_with_offset(i, j, offset_map, stride=32):
    """
    Map grid cell (i, j) to the pixel position of its offset-corrected point.

    The result is the pixel centre of the cell plus the offset stored for that
    cell, where the offset is scaled from cell units to pixels by ``stride``.

    Args:
        i, j: Row and column index of the grid cell
        offset_map: [2, h, w] offset map in normalized coordinates [-0.5, 0.5];
            channel 0 is read as the y-offset and channel 1 as the x-offset
        stride: Grid stride, i.e. pixels per cell (default 32)

    Returns:
        (y, x): Pixel coordinates
    """
    def to_pixel(channel, index):
        # Cell centre along this axis, then the cell's own offset in pixels
        return (index + 0.5) * stride + offset_map[channel, i, j] * stride

    return to_pixel(0, i), to_pixel(1, j)

In [None]:
# Create visualization: one sample rendered as a 2x2 grid of panels
#   [0, 0] raw image                 [0, 1] image + junctions
#   [1, 0] image + edges + junctions [1, 1] image + offset arrows + junctions


def _draw_junctions(ax, coords, offset_map, stride, show_cell_centers=False):
    """Draw one red circle per junction at its offset-corrected pixel position.

    Args:
        ax: Matplotlib axes to draw on.
        coords: Iterable of (row, col) grid indices of junction cells.
        offset_map: [2, h, w] offsets, forwarded to grid_to_pixel_with_offset.
        stride: Pixels per grid cell.
        show_cell_centers: If True, also mark each cell's uncorrected center
            with a smaller cyan circle, added after (so drawn over) the red one.
    """
    for i, j in coords:
        y, x = grid_to_pixel_with_offset(i, j, offset_map, stride)
        ax.add_patch(plt.Circle((x, y), radius=3, color='red', fill=True, alpha=0.8))
        if show_cell_centers:
            center_xy = ((j + 0.5) * stride, (i + 0.5) * stride)
            ax.add_patch(plt.Circle(center_xy, radius=2, color='cyan', fill=True, alpha=0.6))


def _draw_edges(ax, edges, offset_map, stride):
    """Draw every edge as a green segment between its two corrected endpoints.

    Args:
        ax: Matplotlib axes to draw on.
        edges: Iterable of ((i1, j1), (i2, j2)) grid-index pairs.
        offset_map: [2, h, w] offsets, forwarded to grid_to_pixel_with_offset.
        stride: Pixels per grid cell.
    """
    for (i1, j1), (i2, j2) in edges:
        y1, x1 = grid_to_pixel_with_offset(i1, j1, offset_map, stride)
        y2, x2 = grid_to_pixel_with_offset(i2, j2, offset_map, stride)
        ax.plot([x1, x2], [y1, y2], 'g-', linewidth=1.5, alpha=0.6)


fig, axes = plt.subplots(2, 2, figsize=(16, 16))

# Grid geometry, derived from the arrays themselves rather than hard-coded
h, w = junction_map.shape
stride = image.shape[0] // h
junction_coords = np.argwhere(junction_map > 0.5)

# 1. Original Image
ax = axes[0, 0]
ax.imshow(image)
ax.set_title(f'Original Image (Region {meta["region_id"]})', fontsize=14)
ax.axis('off')

# 2. Junction Map Overlay
ax = axes[0, 1]
ax.imshow(image)
_draw_junctions(ax, junction_coords, offset_map, stride)
ax.set_title(f'Junction Map Overlay ({len(junction_coords)} junctions)', fontsize=14)
ax.axis('off')

# 3. Junction Map + Edges (junctions are added last so they sit on top of the lines)
ax = axes[1, 0]
ax.imshow(image)
_draw_edges(ax, edges, offset_map, stride)
_draw_junctions(ax, junction_coords, offset_map, stride)
ax.set_title(f'Junction Map + Edges ({len(edges)} edges)', fontsize=14)
ax.axis('off')

# 4. Offset Quiver Plot
ax = axes[1, 1]
ax.imshow(image)

# Only cells where the offset mask is active get an arrow
mask = offset_mask > 0.5
if mask.any():
    rows, cols = np.nonzero(mask)
    # Arrows start at the cell centers and span the offset, all in pixel units.
    # angles='xy' with scale=1 makes the arrow tip land on center + offset in
    # data coordinates, which matches the (inverted-y) image axes.
    ax.quiver((cols + 0.5) * stride, (rows + 0.5) * stride,
              offset_map[1][mask] * stride,  # x-offset
              offset_map[0][mask] * stride,  # y-offset
              color='yellow', angles='xy', scale_units='xy', scale=1,
              width=0.003, headwidth=4, headlength=5, alpha=0.8)

# Junctions at their corrected positions, plus the raw cell centers for reference
_draw_junctions(ax, junction_coords, offset_map, stride, show_cell_centers=True)

ax.set_title('Offset Quiver Plot\n(Cyan: cell centers, Red: corrected positions, Yellow: offsets)', fontsize=14)
ax.axis('off')

plt.tight_layout()
plt.savefig('dataset_visualization.png', dpi=150, bbox_inches='tight')
plt.show()

print("\nVisualization saved to dataset_visualization.png")

In [None]:
# Print some statistics for the sample currently held in
# offset_map / offset_mask / junction_map / edges
print("\n=== Dataset Statistics ===")
print(f"Offset range: [{offset_map.min():.3f}, {offset_map.max():.3f}]")

# offset_map is [2, h, w] while offset_mask is [h, w], so the boolean mask has to
# index the two spatial axes; offset_map[mask] would fail on the shape mismatch.
active_cells = offset_mask > 0.5
if active_cells.any():
    print(f"Mean absolute offset: {np.abs(offset_map[:, active_cells]).mean():.3f}")
else:
    # Avoid taking the mean of an empty selection (NaN plus a RuntimeWarning)
    print("Mean absolute offset: n/a (offset mask is empty)")

# junction_map.size is the number of grid cells (h * w), taken from the array
# itself so this cell does not depend on h and w leaking from the plotting cell.
junction_total = junction_map.sum()
print(f"\nJunction density: {junction_total / junction_map.size:.4f}")

# NOTE(review): this ratio is edges per junction. If `edges` lists each undirected
# edge once, the mean node degree would be twice this value -- confirm how
# CityScale stores edges before relabelling or rescaling.
print(f"Average degree: {len(edges) / max(1, junction_total):.2f}")

In [None]:
# Visualize multiple samples in a grid
def plot_sample_grid(dataset, max_samples=4, out_path='multiple_samples.png'):
    """Plot the first few samples: raw image on top, graph overlay underneath.

    Everything is kept local to this function so the notebook-level variables of
    the single-sample cells (image, junction_map, offset_map, edges, h, w,
    stride, ...) are not overwritten by the loop.

    Args:
        dataset: Indexable dataset whose items are dicts with 'image',
            'junction_map', 'offset_map' and 'edges' entries.
        max_samples: Upper bound on the number of columns; fewer are drawn if
            the dataset is smaller.
        out_path: File the figure is saved to.
    """
    num_samples = min(max_samples, len(dataset))
    if num_samples == 0:
        # plt.subplots cannot build a grid with zero columns
        print("Dataset is empty; nothing to visualize.")
        return

    # squeeze=False keeps `axes` two-dimensional even for a single column
    fig, axes = plt.subplots(2, num_samples, figsize=(5*num_samples, 10), squeeze=False)

    for idx in range(num_samples):
        sample = dataset[idx]
        image = sample['image'].permute(1, 2, 0).numpy()
        junction_map = sample['junction_map'][0].numpy()
        offset_map = sample['offset_map'].numpy()
        edges = sample['edges']

        # Pixels per grid cell, derived from the image and grid heights
        stride = image.shape[0] // junction_map.shape[0]

        # Top row: Original images
        top_ax = axes[0, idx]
        top_ax.imshow(image)
        top_ax.set_title(f'Sample {idx}', fontsize=12)
        top_ax.axis('off')

        # Bottom row: With graph overlay
        graph_ax = axes[1, idx]
        graph_ax.imshow(image)

        # Draw edges between the offset-corrected endpoints
        for (i1, j1), (i2, j2) in edges:
            y1, x1 = grid_to_pixel_with_offset(i1, j1, offset_map, stride)
            y2, x2 = grid_to_pixel_with_offset(i2, j2, offset_map, stride)
            graph_ax.plot([x1, x2], [y1, y2], 'g-', linewidth=1, alpha=0.6)

        # Draw junctions on top of the edges
        junction_coords = np.argwhere(junction_map > 0.5)
        for i, j in junction_coords:
            y, x = grid_to_pixel_with_offset(i, j, offset_map, stride)
            graph_ax.add_patch(plt.Circle((x, y), radius=2, color='red', fill=True, alpha=0.8))

        graph_ax.set_title(f'{len(edges)} edges, {len(junction_coords)} junctions', fontsize=10)
        graph_ax.axis('off')

    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"\nMultiple samples visualization saved to {out_path}")


plot_sample_grid(dataset)