# Convolution Layers in BrainState

This tutorial shows how to use the convolution layers in BrainState: standard, weight-standardized, and transposed convolutions, with worked examples for common deep learning architectures.

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/chaobrain/brainstate/blob/main/docs/tutorials/convolution_layers-en.ipynb)
[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/chaobrain/brainstate/blob/main/docs/tutorials/convolution_layers-en.ipynb)

## Table of Contents

1. [Introduction](#introduction)
2. [Standard Convolutions](#standard-convolutions)
   - [Conv1d: 1D Convolutions](#conv1d)
   - [Conv2d: 2D Convolutions](#conv2d)
   - [Conv3d: 3D Convolutions](#conv3d)
3. [Weight-Standardized Convolutions](#weight-standardized-convolutions)
   - [ScaledWSConv1d, ScaledWSConv2d, ScaledWSConv3d](#scaled-ws-convolutions)
4. [Transposed Convolutions](#transposed-convolutions)
   - [ConvTranspose1d, ConvTranspose2d, ConvTranspose3d](#conv-transpose)
5. [Channel Formats: channels-last vs channels-first](#channel-formats)
6. [Advanced Features](#advanced-features)
   - [Grouped Convolutions](#grouped-convolutions)
   - [Dilated Convolutions](#dilated-convolutions)
   - [Padding Modes](#padding-modes)
7. [Practical Examples](#practical-examples)
   - [Building a CNN for Image Classification](#cnn-example)
   - [Building an Autoencoder](#autoencoder-example)
   - [U-Net for Segmentation](#unet-example)

In [None]:
# Installation (if needed)
# !pip install brainstate --upgrade

In [None]:
import brainstate
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# Run on CPU so the examples behave the same on every machine
jax.config.update('jax_platform_name', 'cpu')
print(f"BrainState version: {brainstate.__version__}")
print(f"JAX version: {jax.__version__}")

<a id="introduction"></a>
## 1. Introduction

Convolution layers are fundamental building blocks in deep learning, especially for processing grid-like data such as images, audio, and video. BrainState provides a comprehensive set of convolution layers that are:

- **Flexible**: Support both JAX-style (channels-last) and PyTorch-style (channels-first) data formats
- **Efficient**: Built on top of JAX's high-performance `conv_general_dilated` operation
- **Feature-rich**: Include standard convolutions, weight-standardized convolutions, and transposed convolutions
- **Well-documented**: Comprehensive docstrings with examples and comparisons to PyTorch
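
The `conv_general_dilated` primitive mentioned in the list above can also be called directly. The following is a minimal sketch of a channels-last 2D convolution in plain JAX; the `HWIO` kernel layout and the array sizes are chosen for illustration and are not a statement about how BrainState stores its kernels internally.

```python
import jax
import jax.numpy as jnp

# One 8x8 RGB image in channels-last (NHWC) layout.
x = jnp.ones((1, 8, 8, 3))
# Kernel laid out as (height, width, in_channels, out_channels), i.e. HWIO.
kernel = jnp.ones((3, 3, 3, 4))

y = jax.lax.conv_general_dilated(
    lhs=x,
    rhs=kernel,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'),
)
print(y.shape)  # (1, 8, 8, 4)
```

With all-ones inputs, each interior output value is the number of kernel elements that overlap the image (3 * 3 * 3 = 27), which makes this a convenient sanity check for the layout strings.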

### Convolution Types in BrainState

| Layer Type | 1D | 2D | 3D | Use Cases |
|-----------|----|----|----|-----------|
| **Standard Conv** | `Conv1d` | `Conv2d` | `Conv3d` | Feature extraction, downsampling |
| **Weight-Standardized** | `ScaledWSConv1d` | `ScaledWSConv2d` | `ScaledWSConv3d` | Improved training stability |
| **Transposed Conv** | `ConvTranspose1d` | `ConvTranspose2d` | `ConvTranspose3d` | Upsampling, generation |

<a id="standard-convolutions"></a>
## 2. Standard Convolutions

Standard convolutions are used for feature extraction and spatial downsampling in neural networks.

<a id="conv1d"></a>
### Conv1d: 1D Convolutions

1D convolutions are commonly used for sequence data such as audio signals, text, and time series.

In [None]:
# Example 1: Basic Conv1d for time series
# Input: (batch_size, time_steps, channels)
batch_size = 4
time_steps = 100
in_channels = 3
out_channels = 16

# Create Conv1d layer
conv1d = brainstate.nn.Conv1d(
    in_size=(time_steps, in_channels),
    out_channels=out_channels,
    kernel_size=5,
    stride=1,
    padding='SAME'
)

# Create sample input
x = jnp.ones((batch_size, time_steps, in_channels))

# Forward pass
y = conv1d(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Kernel shape: {conv1d.kernel_shape}")
print(f"Number of parameters: {np.prod(conv1d.kernel_shape)}")

In [None]:
# Example 2: Conv1d with different stride and padding
conv1d_downsample = brainstate.nn.Conv1d(
    in_size=(100, 16),
    out_channels=32,
    kernel_size=3,
    stride=2,  # Downsample by factor of 2
    padding='SAME'
)

x = jnp.ones((4, 100, 16))
y = conv1d_downsample(x)

print(f"Input shape: {x.shape}")
print(f"Output shape (stride=2): {y.shape}")
print(f"Downsampling factor: {x.shape[1] / y.shape[1]:.1f}x")

<a id="conv2d"></a>
### Conv2d: 2D Convolutions

2D convolutions are the core operation of most computer vision networks, from image classifiers to segmentation models.

In [None]:
# Example 1: Basic Conv2d for images
# Input: (batch_size, height, width, channels) - channels-last format
batch_size = 8
height, width = 32, 32
in_channels = 3  # RGB
out_channels = 64

# Create Conv2d layer
conv2d = brainstate.nn.Conv2d(
    in_size=(height, width, in_channels),
    out_channels=out_channels,
    kernel_size=3,
    stride=1,
    padding='SAME',
    b_init=brainstate.init.Constant(0.0)  # Add a bias, initialized to zero
)

# Create sample input (e.g., RGB images)
x = jnp.ones((batch_size, height, width, in_channels))

# Forward pass
y = conv2d(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Kernel shape: {conv2d.kernel_shape}")
print(f"Has bias: {'bias' in conv2d.weight.value}")

In [None]:
# Example 2: Conv2d with different kernel sizes
# Compare the output shape and parameter count across kernel sizes

kernel_sizes = [1, 3, 5, 7]
for k in kernel_sizes:
    conv = brainstate.nn.Conv2d(
        in_size=(32, 32, 3),
        out_channels=16,
        kernel_size=k,
        padding='SAME'
    )
    x = jnp.ones((1, 32, 32, 3))
    y = conv(x)
    params = np.prod(conv.kernel_shape)
    print(f"Kernel size {k}x{k}: Output shape {y.shape}, Parameters: {params:,}")

<a id="conv3d"></a>
### Conv3d: 3D Convolutions

3D convolutions are used for video processing, 3D medical imaging, and volumetric data.

In [None]:
# Example: Conv3d for video data
# Input: (batch_size, frames, height, width, channels)
batch_size = 2
frames = 16
height, width = 64, 64
in_channels = 3
out_channels = 32

# Create Conv3d layer
conv3d = brainstate.nn.Conv3d(
    in_size=(frames, height, width, in_channels),
    out_channels=out_channels,
    kernel_size=3,  # 3x3x3 kernel
    stride=1,
    padding='SAME'
)

# Create sample input (e.g., video clip)
x = jnp.ones((batch_size, frames, height, width, in_channels))

# Forward pass
y = conv3d(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Kernel shape: {conv3d.kernel_shape}")
print(f"Parameters: {np.prod(conv3d.kernel_shape):,}")

<a id="weight-standardized-convolutions"></a>
## 3. Weight-Standardized Convolutions

Weight standardization normalizes the weights of each output channel to zero mean and unit variance before the convolution is applied. This can improve training stability and performance, especially when combined with Group Normalization.

**Reference**: [Weight Standardization (Qiao et al., 2019)](https://arxiv.org/abs/1903.10520)
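
As a rough illustration of the idea, the sketch below standardizes a kernel per output channel with NumPy, using `W_hat = gain * (W - mean(W)) / (std(W) + eps)`. The function name `standardize_weights` and the `(H, W, in, out)` kernel layout are assumptions made for this example; BrainState's own implementation may differ in details such as extra scaling or where `eps` enters.

```python
import numpy as np

def standardize_weights(w, gain=None, eps=1e-4):
    """Standardize a kernel per output channel (assumed to be the last axis)."""
    reduce_axes = tuple(range(w.ndim - 1))  # every axis except the output channels
    mean = w.mean(axis=reduce_axes, keepdims=True)
    std = w.std(axis=reduce_axes, keepdims=True)
    w_hat = (w - mean) / (std + eps)
    if gain is not None:
        w_hat = gain * w_hat  # optional per-channel gain
    return w_hat

rng = np.random.default_rng(0)
w = rng.normal(loc=0.5, scale=2.0, size=(3, 3, 64, 128))  # (H, W, in, out)
w_hat = standardize_weights(w)

per_channel_mean = w_hat.mean(axis=(0, 1, 2))
per_channel_std = w_hat.std(axis=(0, 1, 2))
print(np.abs(per_channel_mean).max())  # close to 0
print(per_channel_std.min(), per_channel_std.max())  # both close to 1
```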

<a id="scaled-ws-convolutions"></a>
### ScaledWSConv: Weight Standardization with Learnable Gain

In [None]:
# Example: ScaledWSConv2d for improved training
ws_conv = brainstate.nn.ScaledWSConv2d(
    in_size=(32, 32, 64),
    out_channels=128,
    kernel_size=3,
    ws_gain=True,  # Include learnable per-channel gain
    eps=1e-4  # Small constant for numerical stability
)

x = jnp.ones((4, 32, 32, 64))
y = ws_conv(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Has gain parameter: {'gain' in ws_conv.weight.value}")

# Weight standardization formula:
# W_hat = gain * (W - mean(W)) / (std(W) + eps)
print("\nWeight standardization normalizes weights to:")
print("- Zero mean")
print("- Unit variance")
print("- With optional learnable gain for expressiveness")

In [None]:
# Comparison: Standard Conv vs Weight-Standardized Conv
print("Standard Conv2d:")
standard_conv = brainstate.nn.Conv2d(
    in_size=(32, 32, 64),
    out_channels=128,
    kernel_size=3
)
print(f"  Parameters: {list(standard_conv.weight.value.keys())}")

print("\nWeight-Standardized Conv2d:")
ws_conv = brainstate.nn.ScaledWSConv2d(
    in_size=(32, 32, 64),
    out_channels=128,
    kernel_size=3,
    ws_gain=True
)
print(f"  Parameters: {list(ws_conv.weight.value.keys())}")
print(f"  Gain shape: {ws_conv.weight.value['gain'].shape}")

<a id="transposed-convolutions"></a>
## 4. Transposed Convolutions

Transposed convolutions (sometimes loosely called deconvolutions) upsample feature maps. They are widely used in:
- Autoencoders (decoder path)
- Generative models (GANs, VAEs)
- Semantic segmentation (U-Net, FCN)
- Super-resolution networks
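
The amount of upsampling follows from simple arithmetic. The sketch below uses the output-size formula documented for PyTorch's `ConvTranspose` layers; the helper name `conv_transpose_out_size` is hypothetical. With `padding='SAME'`, as used in this tutorial, the output size is expected to be the input size times the stride, which the PyTorch formula reproduces for kernel 4, stride 2, padding 1.

```python
def conv_transpose_out_size(n, kernel_size, stride, padding=0,
                            output_padding=0, dilation=1):
    """Output length of a transposed convolution along one axis (PyTorch convention)."""
    return ((n - 1) * stride - 2 * padding
            + dilation * (kernel_size - 1) + output_padding + 1)

# kernel 4, stride 2, padding 1 doubles the spatial size
print(conv_transpose_out_size(16, kernel_size=4, stride=2, padding=1))  # 32
print(conv_transpose_out_size(50, kernel_size=4, stride=2, padding=1))  # 100
```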

<a id="conv-transpose"></a>
### ConvTranspose1d, ConvTranspose2d, ConvTranspose3d

In [None]:
# Example 1: ConvTranspose1d for upsampling sequences
conv_transpose_1d = brainstate.nn.ConvTranspose1d(
    in_size=(50, 32),  # (time_steps, channels)
    out_channels=16,
    kernel_size=4,
    stride=2,  # Upsample by 2x
    padding='SAME'
)

x = jnp.ones((4, 50, 32))
y = conv_transpose_1d(x)

print("ConvTranspose1d:")
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {y.shape}")
print(f"  Upsampling factor: {y.shape[1] / x.shape[1]:.1f}x")

In [None]:
# Example 2: ConvTranspose2d for image upsampling
conv_transpose_2d = brainstate.nn.ConvTranspose2d(
    in_size=(16, 16, 128),  # (height, width, channels)
    out_channels=64,
    kernel_size=4,
    stride=2,  # Upsample by 2x in each spatial dimension
    padding='SAME'
)

x = jnp.ones((4, 16, 16, 128))
y = conv_transpose_2d(x)

print("ConvTranspose2d:")
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {y.shape}")
print(f"  Spatial upsampling: {x.shape[1]}x{x.shape[2]} -> {y.shape[1]}x{y.shape[2]}")

In [None]:
# Example 3: Comparing output sizes for different strides
strides = [1, 2, 3, 4]
in_size = 16

print("Effect of stride on upsampling:")
print("="*50)
for stride in strides:
    conv_t = brainstate.nn.ConvTranspose2d(
        in_size=(in_size, in_size, 64),
        out_channels=32,
        kernel_size=4,
        stride=stride,
        padding='SAME'
    )
    x = jnp.ones((1, in_size, in_size, 64))
    y = conv_t(x)
    print(f"Stride {stride}: {in_size}x{in_size} -> {y.shape[1]}x{y.shape[2]} ({y.shape[1]/in_size:.1f}x upsampling)")

In [None]:
# Example 4: ConvTranspose3d for video upsampling
conv_transpose_3d = brainstate.nn.ConvTranspose3d(
    in_size=(8, 16, 16, 64),  # (frames, height, width, channels)
    out_channels=32,
    kernel_size=4,
    stride=2,
    padding='SAME'
)

x = jnp.ones((2, 8, 16, 16, 64))
y = conv_transpose_3d(x)

print("ConvTranspose3d:")
print(f"  Input shape: {x.shape}")
print(f"  Output shape: {y.shape}")
print(f"  3D upsampling: {x.shape[1]}x{x.shape[2]}x{x.shape[3]} -> {y.shape[1]}x{y.shape[2]}x{y.shape[3]}")

<a id="channel-formats"></a>
## 5. Channel Formats: channels-last vs channels-first

BrainState supports both data format conventions:
- **Channels-last (default)**: JAX/TensorFlow style - `[B, H, W, C]`
- **Channels-first**: PyTorch style - `[B, C, H, W]`

Use the `channel_first` parameter to switch between formats.
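
Existing arrays can be moved between the two layouts with a single axis move, for any number of spatial dimensions. This is a small sketch using NumPy's `moveaxis` (`jnp.moveaxis` has the same signature); the helper names are hypothetical.

```python
import numpy as np

def to_channels_first(x):
    """(B, *spatial, C) -> (B, C, *spatial)"""
    return np.moveaxis(x, -1, 1)

def to_channels_last(x):
    """(B, C, *spatial) -> (B, *spatial, C)"""
    return np.moveaxis(x, 1, -1)

seq = np.zeros((4, 100, 3))            # 1D: (B, T, C)
video = np.zeros((2, 16, 64, 64, 3))   # 3D: (B, D, H, W, C)

print(to_channels_first(seq).shape)    # (4, 3, 100)
print(to_channels_first(video).shape)  # (2, 3, 16, 64, 64)
print(to_channels_last(to_channels_first(video)).shape)  # (2, 16, 64, 64, 3)
```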

In [None]:
# Example: Comparing channels-last and channels-first

# Channels-last (JAX/TensorFlow style) - DEFAULT
conv_last = brainstate.nn.Conv2d(
    in_size=(32, 32, 3),  # (H, W, C)
    out_channels=64,
    kernel_size=3,
    channel_first=False  # Default
)

x_last = jnp.ones((8, 32, 32, 3))  # (B, H, W, C)
y_last = conv_last(x_last)

print("Channels-last format (JAX/TensorFlow):")
print(f"  Input shape: {x_last.shape} (B, H, W, C)")
print(f"  Output shape: {y_last.shape} (B, H, W, C)")
print(f"  Channel axis: -1")

print("\n" + "="*50 + "\n")

# Channels-first (PyTorch style)
conv_first = brainstate.nn.Conv2d(
    in_size=(3, 32, 32),  # (C, H, W)
    out_channels=64,
    kernel_size=3,
    channel_first=True  # PyTorch compatible
)

x_first = jnp.ones((8, 3, 32, 32))  # (B, C, H, W)
y_first = conv_first(x_first)

print("Channels-first format (PyTorch):")
print(f"  Input shape: {x_first.shape} (B, C, H, W)")
print(f"  Output shape: {y_first.shape} (B, C, H, W)")
print(f"  Channel axis: 1")

In [None]:
# Converting between formats
def channels_last_to_first(x):
    """Convert (B, H, W, C) to (B, C, H, W)"""
    return jnp.transpose(x, (0, 3, 1, 2))

def channels_first_to_last(x):
    """Convert (B, C, H, W) to (B, H, W, C)"""
    return jnp.transpose(x, (0, 2, 3, 1))

# Example
x_last = jnp.ones((4, 32, 32, 3))
x_first = channels_last_to_first(x_last)
x_back = channels_first_to_last(x_first)

print(f"Original (channels-last): {x_last.shape}")
print(f"Converted to channels-first: {x_first.shape}")
print(f"Converted back: {x_back.shape}")
print(f"Arrays equal: {jnp.allclose(x_last, x_back)}")

<a id="advanced-features"></a>
## 6. Advanced Features

<a id="grouped-convolutions"></a>
### Grouped Convolutions

Grouped convolutions split the input and output channels into `groups` groups and convolve each group independently. This cuts kernel parameters and computation by a factor of `groups`.
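
The parameter saving can be computed by hand: each output channel only sees `in_channels / groups` input channels. The helper below is a hypothetical illustration of that arithmetic (bias terms excluded), not part of the BrainState API.

```python
def conv_kernel_params(kernel_size, in_channels, out_channels, groups=1, ndim=2):
    """Number of kernel weights in an N-d convolution with a square kernel."""
    if in_channels % groups or out_channels % groups:
        raise ValueError("groups must divide both in_channels and out_channels")
    return (kernel_size ** ndim) * (in_channels // groups) * out_channels

print(conv_kernel_params(3, 64, 128))             # 73728  (standard)
print(conv_kernel_params(3, 64, 128, groups=4))   # 18432  (4 groups)
print(conv_kernel_params(3, 64, 64, groups=64))   # 576    (depthwise)
```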

In [None]:
# Example: Comparing standard vs grouped convolutions

in_channels = 64
out_channels = 128
kernel_size = 3

# Standard convolution
conv_standard = brainstate.nn.Conv2d(
    in_size=(32, 32, in_channels),
    out_channels=out_channels,
    kernel_size=kernel_size,
    groups=1  # No grouping
)

# Grouped convolution (4 groups)
conv_grouped = brainstate.nn.Conv2d(
    in_size=(32, 32, in_channels),
    out_channels=out_channels,
    kernel_size=kernel_size,
    groups=4  # Divide into 4 groups
)

# Depthwise convolution (groups = in_channels)
conv_depthwise = brainstate.nn.Conv2d(
    in_size=(32, 32, in_channels),
    out_channels=in_channels,  # One filter per input channel; in general a multiple of in_channels
    kernel_size=kernel_size,
    groups=in_channels  # Each channel processed separately
)

print("Parameter comparison:")
print("="*60)
print(f"Standard convolution:    {np.prod(conv_standard.kernel_shape):,} parameters")
print(f"Grouped convolution (4): {np.prod(conv_grouped.kernel_shape):,} parameters")
print(f"Depthwise convolution:   {np.prod(conv_depthwise.kernel_shape):,} parameters")
print("\nParameter reduction with grouped convolution:")
reduction = (1 - np.prod(conv_grouped.kernel_shape) / np.prod(conv_standard.kernel_shape)) * 100
print(f"  {reduction:.1f}% fewer parameters")

<a id="dilated-convolutions"></a>
### Dilated Convolutions

Dilated (atrous) convolutions expand the receptive field without increasing the number of parameters by inserting zeros between kernel elements.
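
To make "inserting zeros between kernel elements" concrete, the sketch below builds the equivalent dense kernel for a given dilation with NumPy. The function `dilate_kernel` is a hypothetical helper for illustration; real implementations skip the zeros rather than materializing them.

```python
import numpy as np

def dilate_kernel(kernel, dilation):
    """Return the dense 2D kernel equivalent to `kernel` applied with `dilation`."""
    k_h, k_w = kernel.shape
    out_h = k_h + (k_h - 1) * (dilation - 1)
    out_w = k_w + (k_w - 1) * (dilation - 1)
    dilated = np.zeros((out_h, out_w), dtype=kernel.dtype)
    dilated[::dilation, ::dilation] = kernel  # original weights, spaced apart
    return dilated

kernel = np.arange(1, 10).reshape(3, 3)
dilated = dilate_kernel(kernel, dilation=2)
print(dilated.shape)  # (5, 5)
print(dilated)
```

The 3x3 kernel still has nine weights, but it now covers a 5x5 window.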

In [None]:
# Example: Dilated convolutions for larger receptive fields

dilations = [1, 2, 4, 8]
kernel_size = 3

print("Effective receptive field with dilation:")
print("="*60)

for dilation in dilations:
    conv_dilated = brainstate.nn.Conv2d(
        in_size=(32, 32, 64),
        out_channels=64,
        kernel_size=kernel_size,
        rhs_dilation=dilation  # Kernel dilation
    )
    
    # Effective kernel size = kernel_size + (kernel_size - 1) * (dilation - 1)
    effective_kernel = kernel_size + (kernel_size - 1) * (dilation - 1)
    
    print(f"Dilation {dilation}: {kernel_size}x{kernel_size} kernel -> {effective_kernel}x{effective_kernel} receptive field")
    print(f"  Parameters: {np.prod(conv_dilated.kernel_shape):,} (same for all dilations)")

<a id="padding-modes"></a>
### Padding Modes

BrainState supports multiple padding modes:
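- **'SAME'**: Pads so that the output size is `ceil(input_size / stride)`; equal to the input size when stride=1
- **'VALID'**: No padding; the output shrinks by `kernel_size - 1` when stride=1
- **Explicit padding**: A `(low, high)` pair of pad widths per spatial dimension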
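
The output size for each mode can be worked out ahead of time. The helper below is a hypothetical sketch of the standard convolution arithmetic along one spatial axis, following the `'SAME'`/`'VALID'` conventions used by JAX.

```python
import math

def conv_out_size(n, kernel_size, stride=1, padding='SAME', dilation=1):
    """Output length of a convolution along one spatial axis."""
    if padding == 'SAME':
        return math.ceil(n / stride)
    effective_k = kernel_size + (kernel_size - 1) * (dilation - 1)
    if padding == 'VALID':
        low, high = 0, 0
    else:
        low, high = padding  # explicit (low, high) pad widths
    return (n + low + high - effective_k) // stride + 1

print(conv_out_size(32, 5, padding='SAME'))    # 32
print(conv_out_size(32, 5, padding='VALID'))   # 28
print(conv_out_size(32, 5, padding=(2, 2)))    # 32
print(conv_out_size(100, 3, stride=2))         # 50
```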

In [None]:
# Example: Different padding modes

in_size = (32, 32, 3)
kernel_size = 5

# SAME padding
conv_same = brainstate.nn.Conv2d(
    in_size=in_size,
    out_channels=64,
    kernel_size=kernel_size,
    padding='SAME'
)

# VALID padding (no padding)
conv_valid = brainstate.nn.Conv2d(
    in_size=in_size,
    out_channels=64,
    kernel_size=kernel_size,
    padding='VALID'
)

# Explicit padding
conv_custom = brainstate.nn.Conv2d(
    in_size=in_size,
    out_channels=64,
    kernel_size=kernel_size,
    padding=[(2, 2), (2, 2)]  # (top/bottom, left/right)
)

x = jnp.ones((1, 32, 32, 3))

print("Effect of padding modes:")
print("="*60)
print(f"Input shape: {x.shape}")
print(f"SAME padding:   {conv_same(x).shape}")
print(f"VALID padding:  {conv_valid(x).shape}")
print(f"Custom padding: {conv_custom(x).shape}")

<a id="practical-examples"></a>
## 7. Practical Examples

<a id="cnn-example"></a>
### Example 1: Building a CNN for Image Classification

In [None]:
class SimpleCNN(brainstate.nn.Module):
    """Simple CNN for image classification."""
    
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Feature extraction layers
        self.conv1 = brainstate.nn.Conv2d(
            in_size=(32, 32, 3),
            out_channels=32,
            kernel_size=3,
            padding='SAME'
        )
        
        self.conv2 = brainstate.nn.Conv2d(
            in_size=(16, 16, 32),  # After pooling
            out_channels=64,
            kernel_size=3,
            padding='SAME'
        )
        
        self.conv3 = brainstate.nn.Conv2d(
            in_size=(8, 8, 64),  # After pooling
            out_channels=128,
            kernel_size=3,
            padding='SAME'
        )
        
        # Classification head
        self.fc = brainstate.nn.Linear(4 * 4 * 128, num_classes)
    
    def update(self, x):
        # Block 1
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = jax.lax.reduce_window(
            x, -jnp.inf, jax.lax.max, 
            window_dimensions=(1, 2, 2, 1),
            window_strides=(1, 2, 2, 1),
            padding='VALID'
        )  # Max pooling 2x2
        
        # Block 2
        x = self.conv2(x)
        x = jax.nn.relu(x)
        x = jax.lax.reduce_window(
            x, -jnp.inf, jax.lax.max,
            window_dimensions=(1, 2, 2, 1),
            window_strides=(1, 2, 2, 1),
            padding='VALID'
        )  # Max pooling 2x2
        
        # Block 3
        x = self.conv3(x)
        x = jax.nn.relu(x)
        x = jax.lax.reduce_window(
            x, -jnp.inf, jax.lax.max,
            window_dimensions=(1, 2, 2, 1),
            window_strides=(1, 2, 2, 1),
            padding='VALID'
        )  # Max pooling 2x2
        
        # Classification
        x = x.reshape(x.shape[0], -1)  # Flatten
        x = self.fc(x)
        return x

# Create and test the model
model = SimpleCNN(num_classes=10)
x = jnp.ones((4, 32, 32, 3))
logits = model(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {logits.shape}")
print(f"Number of classes: {logits.shape[-1]}")

<a id="autoencoder-example"></a>
### Example 2: Building an Autoencoder

In [None]:
class ConvAutoencoder(brainstate.nn.Module):
    """Convolutional Autoencoder for image reconstruction."""
    
    def __init__(self, latent_dim=128):
        super().__init__()
        
        # Encoder
        self.enc_conv1 = brainstate.nn.Conv2d(
            in_size=(64, 64, 3),
            out_channels=32,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )  # 64x64 -> 32x32
        
        self.enc_conv2 = brainstate.nn.Conv2d(
            in_size=(32, 32, 32),
            out_channels=64,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )  # 32x32 -> 16x16
        
        self.enc_conv3 = brainstate.nn.Conv2d(
            in_size=(16, 16, 64),
            out_channels=latent_dim,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )  # 16x16 -> 8x8
        
        # Decoder (using transposed convolutions)
        self.dec_conv1 = brainstate.nn.ConvTranspose2d(
            in_size=(8, 8, latent_dim),
            out_channels=64,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )  # 8x8 -> 16x16
        
        self.dec_conv2 = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 64),
            out_channels=32,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )  # 16x16 -> 32x32
        
        self.dec_conv3 = brainstate.nn.ConvTranspose2d(
            in_size=(32, 32, 32),
            out_channels=3,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )  # 32x32 -> 64x64
    
    def encode(self, x):
        """Encode image to latent representation."""
        x = jax.nn.relu(self.enc_conv1(x))
        x = jax.nn.relu(self.enc_conv2(x))
        x = jax.nn.relu(self.enc_conv3(x))
        return x
    
    def decode(self, z):
        """Decode latent representation to image."""
        x = jax.nn.relu(self.dec_conv1(z))
        x = jax.nn.relu(self.dec_conv2(x))
        x = jax.nn.sigmoid(self.dec_conv3(x))  # Output in [0, 1]
        return x
    
    def update(self, x):
        """Full forward pass: encode then decode."""
        z = self.encode(x)
        reconstruction = self.decode(z)
        return reconstruction

# Create and test the autoencoder
autoencoder = ConvAutoencoder(latent_dim=128)
x = jnp.ones((4, 64, 64, 3))

# Encode
z = autoencoder.encode(x)
print(f"Input shape: {x.shape}")
print(f"Latent shape: {z.shape}")

# Decode
reconstruction = autoencoder.decode(z)
print(f"Reconstruction shape: {reconstruction.shape}")

# Full pass
output = autoencoder(x)
print(f"\nFull pass output shape: {output.shape}")
print(f"Reconstruction matches input shape: {output.shape == x.shape}")
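
It is worth checking how much the bottleneck actually compresses the input. The short calculation below counts elements per sample in the input and in the latent feature map of the autoencoder above; note that `latent_dim` sets the channel count of an 8x8 feature map, not the length of a flat vector. The helper `n_elements` is hypothetical.

```python
def n_elements(shape):
    """Total number of elements in a tensor of the given shape."""
    total = 1
    for dim in shape:
        total *= dim
    return total

input_shape = (64, 64, 3)
latent_dim = 128
latent_shape = (8, 8, latent_dim)  # three stride-2 convs: 64 -> 32 -> 16 -> 8

ratio = n_elements(input_shape) / n_elements(latent_shape)
print(n_elements(input_shape), n_elements(latent_shape), ratio)  # 12288 8192 1.5
```

A smaller `latent_dim` gives a tighter bottleneck; for example, `latent_dim=16` yields a 12x reduction.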

<a id="unet-example"></a>
### Example 3: U-Net for Segmentation

In [None]:
class SimpleUNet(brainstate.nn.Module):
    """Simplified U-Net architecture for semantic segmentation."""
    
    def __init__(self, num_classes=2):
        super().__init__()
        
        # Encoder (downsampling path)
        self.enc1 = brainstate.nn.Conv2d(
            in_size=(128, 128, 3),
            out_channels=64,
            kernel_size=3,
            padding='SAME'
        )
        
        self.enc2 = brainstate.nn.Conv2d(
            in_size=(128, 128, 64),  # Output of enc1
            out_channels=128,
            kernel_size=3,
            stride=2,
            padding='SAME'
        )  # 128x128 -> 64x64
        
        self.enc3 = brainstate.nn.Conv2d(
            in_size=(64, 64, 128),  # Output of enc2
            out_channels=256,
            kernel_size=3,
            stride=2,
            padding='SAME'
        )  # 64x64 -> 32x32
        
        # Bottleneck
        self.bottleneck = brainstate.nn.Conv2d(
            in_size=(32, 32, 256),  # Output of enc3
            out_channels=512,
            kernel_size=3,
            stride=2,
            padding='SAME'
        )  # 32x32 -> 16x16
        
        # Decoder (upsampling path)
        self.dec1 = brainstate.nn.ConvTranspose2d(
            in_size=(16, 16, 512),
            out_channels=256,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )
        
        self.dec2 = brainstate.nn.ConvTranspose2d(
            in_size=(32, 32, 256 + 256),  # +256 from skip connection
            out_channels=128,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )
        
        self.dec3 = brainstate.nn.ConvTranspose2d(
            in_size=(64, 64, 128 + 128),  # +128 from skip connection
            out_channels=64,
            kernel_size=4,
            stride=2,
            padding='SAME'
        )
        
        # Output layer
        self.output = brainstate.nn.Conv2d(
            in_size=(128, 128, 64 + 64),  # +64 from skip connection
            out_channels=num_classes,
            kernel_size=1,
            padding='SAME'
        )
    
    def update(self, x):
        # Encoder
        e1 = jax.nn.relu(self.enc1(x))  # 128x128x64
        e2 = jax.nn.relu(self.enc2(e1))  # 64x64x128
        e3 = jax.nn.relu(self.enc3(e2))  # 32x32x256
        
        # Bottleneck
        b = jax.nn.relu(self.bottleneck(e3))  # 16x16x512
        
        # Decoder with skip connections
        d1 = jax.nn.relu(self.dec1(b))  # 32x32x256
        d1 = jnp.concatenate([d1, e3], axis=-1)  # Skip connection
        
        d2 = jax.nn.relu(self.dec2(d1))  # 64x64x128
        d2 = jnp.concatenate([d2, e2], axis=-1)  # Skip connection
        
        d3 = jax.nn.relu(self.dec3(d2))  # 128x128x64
        d3 = jnp.concatenate([d3, e1], axis=-1)  # Skip connection
        
        # Output
        out = self.output(d3)  # 128x128 x num_classes
        return out

# Create and test the U-Net
unet = SimpleUNet(num_classes=2)
x = jnp.ones((2, 128, 128, 3))
segmentation_map = unet(x)

print(f"Input shape: {x.shape}")
print(f"Segmentation map shape: {segmentation_map.shape}")
print(f"Number of classes: {segmentation_map.shape[-1]}")
print(f"\nU-Net preserves spatial dimensions: {x.shape[1:3] == segmentation_map.shape[1:3]}")
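
Skip connections only work if the upsampled feature map and the encoder feature map have the same spatial size. The sketch below traces the sizes and channel counts of the U-Net above in plain Python, assuming that a `'SAME'` convolution yields `ceil(size / stride)` and a `'SAME'` transposed convolution yields `size * stride`; the helper names are hypothetical.

```python
import math

def same_conv(size, stride):
    """Spatial size after a 'SAME'-padded convolution."""
    return math.ceil(size / stride)

def same_conv_transpose(size, stride):
    """Spatial size after a 'SAME'-padded transposed convolution (assumed: size * stride)."""
    return size * stride

# (spatial size, channels) after each encoder stage
e1 = (same_conv(128, 1), 64)
e2 = (same_conv(e1[0], 2), 128)
e3 = (same_conv(e2[0], 2), 256)
b = (same_conv(e3[0], 2), 512)

# Decoder: upsample, then concatenate the matching encoder feature map
skips = [e3, e2, e1]
decoder_channels = [256, 128, 64]
current = b
concat_shapes = []
for skip, channels in zip(skips, decoder_channels):
    up_size = same_conv_transpose(current[0], 2)
    if up_size != skip[0]:
        raise ValueError(f"skip size mismatch: {up_size} vs {skip[0]}")
    current = (up_size, channels + skip[1])
    concat_shapes.append(current)

print(concat_shapes)  # [(32, 512), (64, 256), (128, 128)]
```

Each tuple gives the spatial size and channel count that the next layer must declare in its `in_size`.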

## Summary

This tutorial covered the convolution API in BrainState:

### Key Takeaways:

1. **Standard Convolutions** (`Conv1d`, `Conv2d`, `Conv3d`)
   - Feature extraction and downsampling
   - Support for various kernel sizes, strides, and padding modes
   - Grouped and dilated convolutions for efficiency

2. **Weight-Standardized Convolutions** (`ScaledWSConv*`)
   - Improved training stability
   - Works well with Group Normalization
   - Optional learnable gain parameter

3. **Transposed Convolutions** (`ConvTranspose*`)
   - Upsampling for decoders, generators, and segmentation
   - Controllable upsampling factor via stride
   - Essential for encoder-decoder architectures

4. **Flexible Data Formats**
   - Channels-last (JAX/TensorFlow): default
   - Channels-first (PyTorch): via `channel_first=True`
   - Easy migration from PyTorch

5. **Advanced Features**
   - Grouped convolutions for parameter efficiency
   - Dilated convolutions for larger receptive fields
   - Multiple padding modes (SAME, VALID, explicit)

### Best Practices:

- Use **'SAME' padding** to preserve spatial dimensions when stride=1
- Use **grouped convolutions** to reduce parameters in large models
- Use **dilated convolutions** to increase receptive field without pooling
- Use **weight standardization** for training stability, especially with small batches
- Use **transposed convolutions** with a kernel_size divisible by the stride (e.g., kernel_size = 2 * stride) to reduce checkerboard artifacts when upsampling
- Choose **data format** based on your framework background (JAX: channels-last, PyTorch: channels-first)
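
The advice on transposed-convolution kernel sizes can be illustrated by counting how many kernel taps contribute to each output position. The sketch below does this for a 1D transposed convolution without padding; `overlap_counts` is a hypothetical helper. Uneven counts are one source of checkerboard patterns in upsampled outputs.

```python
def overlap_counts(n, kernel_size, stride):
    """Count the kernel taps that land on each output position."""
    out_len = (n - 1) * stride + kernel_size
    counts = [0] * out_len
    for i in range(n):                # each input position...
        for j in range(kernel_size):  # ...writes to kernel_size outputs
            counts[i * stride + j] += 1
    return counts

even = overlap_counts(8, kernel_size=4, stride=2)    # kernel divisible by stride
uneven = overlap_counts(8, kernel_size=3, stride=2)  # kernel not divisible by stride

# Ignore the borders and look at the interior only
print(even[4:-4])    # uniform: every position receives 2 contributions
print(uneven[3:-3])  # alternates between 1 and 2 contributions
```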

### Next Steps:

- Combine convolutions with normalization layers (BatchNorm, GroupNorm)
- Experiment with different activation functions
- Build more complex architectures (ResNet, DenseNet, etc.)
- Train models on real datasets

## References

1. **Weight Standardization**: Qiao, S., Wang, H., Liu, C., Shen, W., & Yuille, A. (2019). Weight Standardization. arXiv:1903.10520
2. **U-Net**: Ronneberger, O., Fischer, P., & Brox, T. (2015). U-Net: Convolutional Networks for Biomedical Image Segmentation. MICCAI 2015
3. **Grouped Convolutions**: Krizhevsky, A., Sutskever, I., & Hinton, G. E. (2012). ImageNet Classification with Deep Convolutional Neural Networks. NIPS 2012
4. **Dilated Convolutions**: Yu, F., & Koltun, V. (2016). Multi-Scale Context Aggregation by Dilated Convolutions. ICLR 2016

## Additional Resources

- [BrainState Documentation](https://brainstate.readthedocs.io/)
- [JAX Documentation](https://jax.readthedocs.io/)
- [Convolution Arithmetic Guide](https://github.com/vdumoulin/conv_arithmetic)