# Patch Merging Layer Test Notebook

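This notebook sanity-checks the `PatchMerging` layer of the Swin Transformer implementation. It feeds a dummy patch-embedding output through the layer and verifies two things: the patch grid is halved along each side, and the channel dimension doubles.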

In [None]:
# 1. Import Required Libraries and Modules
import os
import sys

import torch

# Make the repository root (parent of the current working directory) importable so the
# `src.` package resolves. Assumes the notebook is launched from a direct subdirectory of
# the repo root (e.g. notebooks/) — TODO confirm. Insert at the front so a different
# installed `src` cannot shadow it, and guard against duplicate entries when re-run.
REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from src.models.swin.patch_merging import PatchMerging

## 2. Create Dummy Patch Embedding Output
Simulate the output of a patch embedding layer for a batch of images: a flattened sequence of patch tokens with shape $(B, H \cdot W, C)$.

In [None]:
# Assume input images are 32x32 with patch size 4, giving an 8x8 patch grid.
torch.manual_seed(0)  # fixed seed so Restart & Run All reproduces the same tensor

IMG_SIZE = 32   # input image side length (pixels)
PATCH_SIZE = 4  # patch side length (pixels)

B = 2  # batch size
H = W = IMG_SIZE // PATCH_SIZE  # patch grid size (8 x 8)
C = 48  # embed dim

# Flattened patch tokens laid out as (batch, num_patches, channels).
x = torch.randn(B, H * W, C)
assert x.shape == (B, H * W, C), f'Unexpected input shape {x.shape}'
print('Input shape:', x.shape)  # [B, 64, 48]

## 3. Apply Patch Merging Layer
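Apply the `PatchMerging` layer and verify that its output has shape $(B, \tfrac{H}{2} \cdot \tfrac{W}{2}, 2C)$. For the $8 \times 8$ grid with $C = 48$, that shape is $(B, 16, 96)$.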

In [None]:
# Contract under test (encoded in `expected_shape`): the token grid is halved along each
# side, so there are 4x fewer tokens, and the channel dimension is doubled.
patch_merging = PatchMerging(input_resolution=(H, W), dim=C)
out = patch_merging(x)
print('Output shape:', out.shape)  # [B, 16, 96] for an 8x8 grid with C=48
expected_shape = (B, (H // 2) * (W // 2), 2 * C)
assert out.shape == expected_shape, f'Expected {expected_shape}, got {out.shape}'
# A correct layer applied to finite random input should never emit NaN/Inf.
assert torch.isfinite(out).all(), 'PatchMerging produced non-finite values'
print('PatchMerging output shape verified.')