# Basic Neural Network Layers

BrainState provides a comprehensive set of pre-built layers for building neural networks. This tutorial covers the essential building blocks.

You will learn about:

- 📏 **Linear layers** - Fully connected transformations
- 🔲 **Convolutional layers** - Spatial feature extraction (1D, 2D, 3D)
- 🏊 **Pooling layers** - Downsampling operations
- 💧 **Dropout layers** - Regularization techniques
- 🔧 **Utility layers** - Flatten, for bridging convolutional and dense layers

These layers are optimized, well-tested, and ready to use in your models!

In [1]:
import brainstate
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

## 1. Linear (Fully Connected) Layers

Linear layers apply the affine transformation **y = xW + b**, where `W` has shape `(in_size, out_size)` and `b` has shape `(out_size,)`.
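To make the transformation concrete, here is a minimal framework-free sketch. It assumes the weight is stored as `(in_size, out_size)`, which is what the shapes printed in this tutorial suggest; `linear_forward` is a hypothetical helper written for illustration, not part of BrainState.

```python
# Plain-Python sketch of y = xW + b (illustrative helper, not a BrainState API).
def linear_forward(x, W, b):
    """x: n_in values; W: n_in rows of n_out values; b: n_out values."""
    n_in, n_out = len(W), len(W[0])
    assert len(x) == n_in and len(b) == n_out
    # Output j is the dot product of x with column j of W, plus the bias.
    return [sum(x[i] * W[i][j] for i in range(n_in)) + b[j] for j in range(n_out)]


x = [1.0, 2.0, 3.0]        # 3 input features
W = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]           # shape (3, 2)
b = [0.5, -0.5]

y = linear_forward(x, W, b)
print(y)  # [4.5, 4.5]
```

Each output is a weighted sum of all inputs, which is why the layer is called fully connected.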

### Basic Usage

In [5]:
# Create a linear layer
brainstate.random.seed(42)
linear = brainstate.nn.Linear(10, 5)

print("Linear Layer:")
print(linear)
print(f"\nWeight shape: {linear.weight.value['weight'].shape}")
print(f"Bias shape: {linear.weight.value['bias'].shape}")

# Forward pass
x = brainstate.random.randn(10)
y = linear(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Output: {y}")

Linear Layer:
Linear(
  in_size=(10,),
  out_size=(5,),
  w_mask=None,
  weight=ParamState(
    value={
      'bias': ShapedArray(float32[5]),
      'weight': ShapedArray(float32[10,5])
    }
  )
)

Weight shape: (10, 5)
Bias shape: (5,)

Input shape: (10,)
Output shape: (5,)
Output: [ 0.24681929  1.2860886  -1.6367221   0.29457197 -0.9486235 ]


### Batch Processing

Linear layers act on the last axis only, so any leading batch dimensions pass through unchanged:

In [6]:
# Batched input: (batch_size, features)
x_batch = brainstate.random.randn(32, 10)  # 32 samples, 10 features each
y_batch = linear(x_batch)

print(f"Batched input shape: {x_batch.shape}")
print(f"Batched output shape: {y_batch.shape}")

# Works with arbitrary batch dimensions
x_multi = brainstate.random.randn(8, 4, 10)  # (batch1, batch2, features)
y_multi = linear(x_multi)

print(f"\nMulti-batch input: {x_multi.shape}")
print(f"Multi-batch output: {y_multi.shape}")

Batched input shape: (32, 10)
Batched output shape: (32, 5)

Multi-batch input: (8, 4, 10)
Multi-batch output: (8, 4, 5)
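The shapes above follow one rule: only the last axis changes. A small sketch of that rule, where `linear_output_shape` is a hypothetical helper used purely for illustration:

```python
# Shape rule for a linear layer applied to the last axis (illustrative helper).
def linear_output_shape(input_shape, in_size, out_size):
    if input_shape[-1] != in_size:
        raise ValueError(f"expected last axis of size {in_size}, got {input_shape[-1]}")
    # Leading (batch) axes are kept; the feature axis is replaced.
    return tuple(input_shape[:-1]) + (out_size,)


print(linear_output_shape((10,), 10, 5))       # (5,)
print(linear_output_shape((32, 10), 10, 5))    # (32, 5)
print(linear_output_shape((8, 4, 10), 10, 5))  # (8, 4, 5)
```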


### Linear Layer Variants

BrainState provides specialized linear layers:

In [None]:
# SparseLinear: For sparse connectivity
sparse_linear = brainstate.nn.SparseLinear(
    num_in=100,
    num_out=50,
    prob=0.1,  # 10% connectivity
    weight=0.05
)

print("Sparse Linear Layer:")
print(sparse_linear)

x = brainstate.random.randn(100)
y = sparse_linear(x)
print(f"\nInput: {x.shape} → Output: {y.shape}")

## 2. Convolutional Layers

Convolutional layers extract spatial features using learnable filters.

### Conv1d - 1D Convolution

Used for sequential data (time series, audio, text):

In [None]:
# Conv1d: (batch, length, in_channels) → (batch, length, out_channels)
brainstate.random.seed(0)
conv1d = brainstate.nn.Conv1d(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding='SAME'  # Keep spatial dimensions
)

print("Conv1d Layer:")
print(conv1d)

# Input: (batch=4, length=100, channels=3)
x = brainstate.random.randn(4, 100, 3)
y = conv1d(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Kernel shape: {conv1d.w.value.shape}")
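For intuition about what a single filter computes, the sketch below slides one kernel over a one-channel signal with zero padding. It is a simplified stand-in (one input channel, one output channel, stride 1, odd kernel length), and `conv1d_same` is a hypothetical helper, not the library's implementation. As in most deep learning frameworks, the kernel is not flipped, so this is technically a cross-correlation.

```python
# Single-channel 1D "convolution" with SAME-style zero padding (illustrative).
def conv1d_same(signal, kernel):
    k = len(kernel)
    assert k % 2 == 1, "this sketch assumes an odd kernel length"
    pad = (k - 1) // 2
    padded = [0.0] * pad + list(signal) + [0.0] * pad
    # Each output is the dot product of the kernel with one window of the input.
    return [sum(padded[i + j] * kernel[j] for j in range(k))
            for i in range(len(signal))]


signal = [0.0, 0.0, 1.0, 1.0, 1.0, 0.0, 0.0]  # a step up, then a step down
edge_kernel = [-1.0, 0.0, 1.0]                # responds to changes in the signal

response = conv1d_same(signal, edge_kernel)
print(response)  # [0.0, 1.0, 1.0, 0.0, -1.0, -1.0, 0.0]
```

The output has the same length as the input, is positive at the rising edge, negative at the falling edge, and zero where the signal is flat. A layer with 16 output channels learns 16 such kernels, each spanning all input channels.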

### Conv2d - 2D Convolution

The workhorse for image processing:

In [None]:
# Conv2d: (batch, height, width, in_channels) → (batch, height, width, out_channels)
conv2d = brainstate.nn.Conv2d(
    in_channels=3,      # RGB input
    out_channels=32,    # 32 feature maps
    kernel_size=(3, 3), # 3x3 kernel
    strides=(1, 1),     # Stride of 1
    padding='SAME'
)

print("Conv2d Layer:")
print(conv2d)

# Input: (batch=8, height=28, width=28, channels=3)
x_image = brainstate.random.randn(8, 28, 28, 3)
y_image = conv2d(x_image)

print(f"\nInput shape: {x_image.shape}")
print(f"Output shape: {y_image.shape}")
print(f"Kernel shape: {conv2d.w.value.shape}")

### Visualizing Convolutional Features

In [None]:
# Create a simple image with a pattern
def create_test_image():
    img = jnp.zeros((28, 28, 1))
    # Add vertical edge
    img = img.at[5:23, 10:13, 0].set(1.0)
    # Add horizontal edge
    img = img.at[10:13, 5:23, 0].set(1.0)
    return img

# Create and apply conv
brainstate.random.seed(42)
edge_conv = brainstate.nn.Conv2d(
    in_channels=1,
    out_channels=4,
    kernel_size=(3, 3),
    padding='SAME'
)

test_img = create_test_image()
features = edge_conv(test_img[None, ...])[0]  # Add/remove batch dim

# Visualize
fig, axes = plt.subplots(1, 5, figsize=(15, 3))

# Original image
axes[0].imshow(test_img[:, :, 0], cmap='gray')
axes[0].set_title('Input Image')
axes[0].axis('off')

# Feature maps
for i in range(4):
    axes[i+1].imshow(np.array(features[:, :, i]), cmap='viridis')
    axes[i+1].set_title(f'Feature Map {i}')
    axes[i+1].axis('off')

plt.tight_layout()
plt.show()

### Convolution Parameters

Stride and padding together determine the spatial size of the output:

In [None]:
# Different stride values
x = brainstate.random.randn(1, 28, 28, 3)

configs = [
    {'strides': (1, 1), 'padding': 'SAME', 'name': 'Stride 1, SAME'},
    {'strides': (2, 2), 'padding': 'SAME', 'name': 'Stride 2, SAME'},
    {'strides': (1, 1), 'padding': 'VALID', 'name': 'Stride 1, VALID'},
]

print(f"Input shape: {x.shape}\n")

for config in configs:
    brainstate.random.seed(0)
    conv = brainstate.nn.Conv2d(
        in_channels=3,
        out_channels=16,
        kernel_size=(3, 3),
        strides=config['strides'],
        padding=config['padding']
    )
    y = conv(x)
    print(f"{config['name']:20s}: {y.shape}")
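The shapes printed by this cell can be predicted by hand. The sketch below assumes BrainState follows the usual JAX/XLA meaning of `'SAME'` and `'VALID'` (and no dilation); `conv_output_size` is a hypothetical helper for illustration.

```python
import math


# Output length along one spatial axis (illustrative helper, no dilation).
def conv_output_size(size, kernel, stride, padding):
    if padding == 'SAME':
        # Input is padded so that the output depends only on the stride.
        return math.ceil(size / stride)
    if padding == 'VALID':
        # No padding: the kernel must fit entirely inside the input.
        return (size - kernel) // stride + 1
    raise ValueError(f"unknown padding: {padding}")


for stride, padding in [(1, 'SAME'), (2, 'SAME'), (1, 'VALID')]:
    out = conv_output_size(28, kernel=3, stride=stride, padding=padding)
    print(f"stride={stride}, padding={padding}: 28 -> {out}")
# stride=1, padding=SAME: 28 -> 28
# stride=2, padding=SAME: 28 -> 14
# stride=1, padding=VALID: 28 -> 26
```

Under that assumption, the three configurations should produce outputs of shape `(1, 28, 28, 16)`, `(1, 14, 14, 16)`, and `(1, 26, 26, 16)`.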

### Conv3d - 3D Convolution

For video or volumetric data:

In [None]:
# Conv3d: (batch, depth, height, width, channels)
conv3d = brainstate.nn.Conv3d(
    in_channels=3,
    out_channels=16,
    kernel_size=(3, 3, 3),
    padding='SAME'
)

# Video input: (batch=2, frames=16, height=64, width=64, channels=3)
x_video = brainstate.random.randn(2, 16, 64, 64, 3)
y_video = conv3d(x_video)

print(f"Video input shape: {x_video.shape}")
print(f"Video output shape: {y_video.shape}")

## 3. Pooling Layers

Pooling layers downsample feature maps, reducing spatial dimensions.

### Max Pooling

In [None]:
# MaxPool2d: Takes maximum value in each window
maxpool = brainstate.nn.MaxPool2d(
    kernel_size=(2, 2),
    strides=(2, 2)
)

print("MaxPool2d Layer:")
print(maxpool)

# Input: (batch=4, height=28, width=28, channels=16)
x = brainstate.random.randn(4, 28, 28, 16)
y = maxpool(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {y.shape}")
print(f"Spatial reduction: {x.shape[1]}x{x.shape[2]} → {y.shape[1]}x{y.shape[2]}")

### Average Pooling

In [None]:
# AvgPool2d: Takes average value in each window
avgpool = brainstate.nn.AvgPool2d(
    kernel_size=(2, 2),
    strides=(2, 2)
)

y_avg = avgpool(x)

print("AvgPool2d Layer:")
print(avgpool)
print(f"\nOutput shape: {y_avg.shape}")

### Adaptive Pooling

Pools to a fixed output size regardless of input size:

In [None]:
# AdaptiveAvgPool2d: Always outputs specified size
adaptive_pool = brainstate.nn.AdaptiveAvgPool2d(output_size=(7, 7))

# Works with any input size
inputs = [
    brainstate.random.randn(1, 28, 28, 16),
    brainstate.random.randn(1, 56, 56, 16),
    brainstate.random.randn(1, 224, 224, 16),
]

print("AdaptiveAvgPool2d (output: 7x7)\n")
for x in inputs:
    y = adaptive_pool(x)
    # Tuples do not accept alignment specs, so convert the shape to a string first
    print(f"Input {str(x.shape[1:3]):>10} → Output {y.shape[1:3]}")
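One common way to obtain a fixed output size is to derive a window per output position from the input length. The sketch below uses the floor/ceil window rule found in PyTorch's adaptive pooling, shown in 1D for brevity; whether BrainState computes its windows the same way is an assumption, and both helpers are hypothetical.

```python
# Window boundaries for adaptive pooling along one axis (illustrative helpers).
def adaptive_windows(in_size, out_size):
    """Return one (start, end) index pair per output position."""
    windows = []
    for i in range(out_size):
        start = (i * in_size) // out_size                      # floor
        end = ((i + 1) * in_size + out_size - 1) // out_size   # ceiling
        windows.append((start, end))
    return windows


def adaptive_avg_pool_1d(values, out_size):
    return [sum(values[s:e]) / (e - s)
            for s, e in adaptive_windows(len(values), out_size)]


print(adaptive_windows(28, 7)[:3])  # [(0, 4), (4, 8), (8, 12)]
print(adaptive_windows(10, 3))      # [(0, 4), (3, 7), (6, 10)]
print(adaptive_avg_pool_1d([1, 2, 3, 4, 5, 6, 7, 8], 2))  # [2.5, 6.5]
```

When the input size is a multiple of the output size (28 to 7), the windows are equal and do not overlap. Otherwise (10 to 3), neighbouring windows overlap slightly so that every input element is covered.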

### Comparing Pooling Operations

In [None]:
# Create a simple test pattern
x_test = jnp.array([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
], dtype=jnp.float32)

x_test = x_test[None, :, :, None]  # Add batch and channel dims

# Apply different pooling
maxpool_2x2 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), strides=(2, 2))
avgpool_2x2 = brainstate.nn.AvgPool2d(kernel_size=(2, 2), strides=(2, 2))

y_max = maxpool_2x2(x_test)
y_avg = avgpool_2x2(x_test)

fig, axes = plt.subplots(1, 3, figsize=(12, 4))

# Original
im1 = axes[0].imshow(x_test[0, :, :, 0], cmap='viridis', interpolation='nearest')
axes[0].set_title('Original (4x4)', fontsize=12, fontweight='bold')
for i in range(4):
    for j in range(4):
        axes[0].text(j, i, f'{x_test[0,i,j,0]:.0f}', 
                    ha='center', va='center', color='white', fontsize=10)
plt.colorbar(im1, ax=axes[0])

# Max pooled
im2 = axes[1].imshow(y_max[0, :, :, 0], cmap='viridis', interpolation='nearest')
axes[1].set_title('Max Pooled (2x2)', fontsize=12, fontweight='bold')
for i in range(2):
    for j in range(2):
        axes[1].text(j, i, f'{y_max[0,i,j,0]:.0f}', 
                    ha='center', va='center', color='white', fontsize=10)
plt.colorbar(im2, ax=axes[1])

# Avg pooled
im3 = axes[2].imshow(y_avg[0, :, :, 0], cmap='viridis', interpolation='nearest')
axes[2].set_title('Avg Pooled (2x2)', fontsize=12, fontweight='bold')
for i in range(2):
    for j in range(2):
        axes[2].text(j, i, f'{y_avg[0,i,j,0]:.1f}', 
                    ha='center', va='center', color='white', fontsize=10)
plt.colorbar(im3, ax=axes[2])

plt.tight_layout()
plt.show()

print("MaxPool takes the maximum value from each 2x2 window")
print("AvgPool takes the average value from each 2x2 window")
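The values shown in the figure can be checked by hand. This framework-free sketch applies non-overlapping 2x2 windows to the same 4x4 grid; `pool2x2` is a hypothetical helper written for illustration.

```python
# Non-overlapping 2x2 pooling on a nested list (illustrative helper).
def pool2x2(grid, reduce_fn):
    rows, cols = len(grid), len(grid[0])
    return [[reduce_fn([grid[r][c], grid[r][c + 1],
                        grid[r + 1][c], grid[r + 1][c + 1]])
             for c in range(0, cols - 1, 2)]
            for r in range(0, rows - 1, 2)]


grid = [[1, 2, 3, 4],
        [5, 6, 7, 8],
        [9, 10, 11, 12],
        [13, 14, 15, 16]]

max_pooled = pool2x2(grid, max)
avg_pooled = pool2x2(grid, lambda window: sum(window) / len(window))

print(max_pooled)  # [[6, 8], [14, 16]]
print(avg_pooled)  # [[3.5, 5.5], [11.5, 13.5]]
```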

## 4. Dropout and Regularization

Dropout randomly sets activations to zero during training for regularization.

### Standard Dropout

In [None]:
# Dropout: Randomly zero out elements
dropout = brainstate.nn.Dropout(p=0.5)  # Drop 50% of activations

print("Dropout Layer:")
print(dropout)

# Create test input
brainstate.random.seed(42)
x = jnp.ones(10)

# Apply dropout multiple times (different masks)
print("\nOriginal:", x)
print("\nDropout outputs (training mode):")
for i in range(3):
    y = dropout(x)
    print(f"  {i+1}: {y}")
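The mechanism can be sketched without any framework. The version below is "inverted" dropout, the usual formulation, in which surviving activations are scaled by `1 / (1 - p)` so the expected value is unchanged. It treats `p` as the drop probability, as in the cell above; the exact scaling BrainState applies is an assumption here, and `dropout` is a hypothetical helper.

```python
import random


# Inverted dropout on a list of activations (illustrative helper).
def dropout(x, p, training, rng):
    if not training or p == 0.0:
        return list(x)  # evaluation mode: identity
    keep = 1.0 - p
    # Keep each element with probability `keep`, rescaling the survivors.
    return [xi / keep if rng.random() < keep else 0.0 for xi in x]


rng = random.Random(0)  # seeded for reproducibility
x = [1.0] * 10

y_train = dropout(x, p=0.5, training=True, rng=rng)
y_eval = dropout(x, p=0.5, training=False, rng=rng)

print(y_train)  # every entry is either 0.0 (dropped) or 2.0 (kept and rescaled)
print(y_eval)   # identical to x
```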

### Training vs Evaluation Mode

Dropout behaves differently during training and evaluation:

In [None]:
class NetworkWithDropout(brainstate.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear1 = brainstate.nn.Linear(input_dim, hidden_dim)
        self.dropout = brainstate.nn.Dropout(p=0.5)
        self.linear2 = brainstate.nn.Linear(hidden_dim, output_dim)
    
    def update(self, x):
        x = self.linear1(x)
        x = jnp.maximum(0, x)  # ReLU
        x = self.dropout(x)
        x = self.linear2(x)
        return x

# Create network
brainstate.random.seed(0)
net = NetworkWithDropout(10, 20, 5)

# Test input
x = brainstate.random.randn(10)

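# Training mode: each call samples a fresh dropout mask.
# Assumes the mode is selected through the `fit` flag in `brainstate.environ`.
with brainstate.environ.context(fit=True):
    y1 = net(x)
    y2 = net(x)

# Evaluation mode: dropout is disabled, so repeated calls agree.
with brainstate.environ.context(fit=False):
    y3 = net(x)
    y4 = net(x)

print("With dropout (training):")
print(f"  Output 1: {y1}")
print(f"  Output 2: {y2}")
print(f"  Outputs differ: {not jnp.allclose(y1, y2)}")

print("\nWithout dropout (evaluation):")
print(f"  Outputs match: {bool(jnp.allclose(y3, y4))}")

print("\n💡 Note: In training mode, outputs vary due to random dropout masks")
print("   In eval mode, dropout is disabled for consistent predictions")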

## 5. Utility Layers

### Flatten Layer

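Flatten collapses the feature dimensions of each sample into a single axis, which is the usual bridge between convolutional and fully connected layers: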

In [None]:
# Flatten: collapse (H, W, C) into one feature axis per sample
flatten = brainstate.nn.Flatten()

# Example: After convolution, flatten before fully connected
x_conv = brainstate.random.randn(4, 7, 7, 64)  # (batch, H, W, C)
x_flat = flatten(x_conv)

print(f"Before flatten: {x_conv.shape}")
print(f"After flatten: {x_flat.shape}")
print(f"\nFlattened: {7 * 7 * 64} = {x_flat.shape[1]} features per sample")
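The expected result is simple shape arithmetic: the batch axis is kept and the remaining axes are multiplied together. A minimal sketch, with `flatten_shape` as a hypothetical helper:

```python
import math


# Shape after flattening everything except the batch axis (illustrative helper).
def flatten_shape(shape):
    return (shape[0], math.prod(shape[1:]))


print(flatten_shape((4, 7, 7, 64)))  # (4, 3136)
```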

## 6. Building a Complete CNN

Let's combine everything into a complete convolutional neural network:

In [None]:
class SimpleCNN(brainstate.nn.Module):
    """Simple CNN for image classification."""
    
    def __init__(self, num_classes=10):
        super().__init__()
        
        # Conv block 1
        self.conv1 = brainstate.nn.Conv2d(3, 32, kernel_size=(3, 3), padding='SAME')
        self.pool1 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), strides=(2, 2))
        
        # Conv block 2
        self.conv2 = brainstate.nn.Conv2d(32, 64, kernel_size=(3, 3), padding='SAME')
        self.pool2 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), strides=(2, 2))
        
        # Conv block 3
        self.conv3 = brainstate.nn.Conv2d(64, 128, kernel_size=(3, 3), padding='SAME')
        self.pool3 = brainstate.nn.MaxPool2d(kernel_size=(2, 2), strides=(2, 2))
        
        # Flatten and classify
        self.flatten = brainstate.nn.Flatten()
        self.fc1 = brainstate.nn.Linear(128 * 4 * 4, 256)  # Assuming 32x32 input
        self.dropout = brainstate.nn.Dropout(p=0.5)
        self.fc2 = brainstate.nn.Linear(256, num_classes)
    
    def update(self, x):
        # Conv block 1
        x = self.conv1(x)
        x = jnp.maximum(0, x)  # ReLU
        x = self.pool1(x)
        
        # Conv block 2
        x = self.conv2(x)
        x = jnp.maximum(0, x)
        x = self.pool2(x)
        
        # Conv block 3
        x = self.conv3(x)
        x = jnp.maximum(0, x)
        x = self.pool3(x)
        
        # Classifier
        x = self.flatten(x)
        x = self.fc1(x)
        x = jnp.maximum(0, x)
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x

# Create CNN
brainstate.random.seed(42)
cnn = SimpleCNN(num_classes=10)

print("Simple CNN Architecture:")
print(cnn)

# Test with batch of images
batch_size = 8
images = brainstate.random.randn(batch_size, 32, 32, 3)  # CIFAR-10 size
logits = cnn(images)

print(f"\nInput shape: {images.shape}")
print(f"Output shape: {logits.shape}")
print(f"\nLogits for first image: {logits[0]}")
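The `128 * 4 * 4` input size of `fc1` comes from tracing the spatial dimensions through the three blocks. The sketch below does that arithmetic, assuming each `'SAME'` convolution with stride 1 keeps height and width and each 2x2 pool with stride 2 halves them (rounding down); `trace_simple_cnn` is a hypothetical helper.

```python
# Shape trace through three conv + pool blocks (illustrative helper).
def trace_simple_cnn(height, width, channels=(32, 64, 128)):
    shapes = []
    for out_channels in channels:
        # SAME conv, stride 1: height and width unchanged, channels updated.
        # 2x2 max pool, stride 2: height and width halved (floor).
        height, width = height // 2, width // 2
        shapes.append((height, width, out_channels))
    return shapes


shapes = trace_simple_cnn(32, 32)
print(shapes)  # [(16, 16, 32), (8, 8, 64), (4, 4, 128)]

h, w, c = shapes[-1]
flat_features = h * w * c
print(flat_features)  # 2048
```

The same trace for a 28x28 input gives 14, 7, and then 3, so `fc1` would need `128 * 3 * 3 = 1152` input features instead.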

### Visualizing CNN Feature Maps

In [None]:
# Get intermediate features
def get_conv_features(model, x):
    """Extract features from each conv layer."""
    features = []
    
    # Conv 1
    x = model.conv1(x)
    x = jnp.maximum(0, x)
    features.append(x)
    x = model.pool1(x)
    
    # Conv 2
    x = model.conv2(x)
    x = jnp.maximum(0, x)
    features.append(x)
    x = model.pool2(x)
    
    # Conv 3
    x = model.conv3(x)
    x = jnp.maximum(0, x)
    features.append(x)
    
    return features

# Extract features
single_image = images[0:1]  # Take first image
features = get_conv_features(cnn, single_image)

# Visualize
fig, axes = plt.subplots(1, 4, figsize=(16, 4))

# Original image
axes[0].imshow(np.clip((np.array(single_image[0]) + 1) / 2, 0, 1))  # Rescale, then clip to [0, 1] for display
axes[0].set_title('Input Image\n(32×32×3)', fontsize=10, fontweight='bold')
axes[0].axis('off')

# Feature maps
layer_names = ['Conv1\n(32×32×32)', 'Conv2\n(16×16×64)', 'Conv3\n(8×8×128)']
for i, (feat, name) in enumerate(zip(features, layer_names)):
    # Show first feature map from each layer
    feat_map = np.array(feat[0, :, :, 0])
    axes[i+1].imshow(feat_map, cmap='viridis')
    axes[i+1].set_title(name, fontsize=10, fontweight='bold')
    axes[i+1].axis('off')

plt.tight_layout()
plt.show()

print("Feature map shapes:")
for i, feat in enumerate(features):
    print(f"  Layer {i+1}: {feat.shape}")

## Summary

In this tutorial, you learned about:

✅ **Linear layers** - Fully connected transformations  
✅ **Convolutional layers** - 1D, 2D, 3D spatial feature extraction  
✅ **Pooling layers** - Max, average, and adaptive pooling  
✅ **Dropout** - Regularization through random masking  
✅ **Utility layers** - Flatten, for bridging convolutional and dense layers  
✅ **Complete CNN** - Building end-to-end architectures  

### Key Takeaways

| Layer Type | Use Case | Key Parameters |
|------------|----------|----------------|
| **Linear** | Fully connected | `in_size`, `out_size` |
| **Conv2d** | Image features | `in_channels`, `out_channels`, `kernel_size`, `strides`, `padding` |
| **MaxPool2d** | Downsampling | `kernel_size`, `strides` |
| **Dropout** | Regularization | `p` (drop probability) |
| **Flatten** | Shape transformation | None |

### Best Practices

1. 🎯 **Use SAME padding** with stride 1 to preserve spatial dimensions
2. 📐 **Double check shapes** - especially after conv/pool operations
3. 💧 **Add dropout** after dense layers for regularization
4. 🏊 **Pool after activation** - standard practice in CNNs
5. 🔍 **Visualize features** - helps debug and understand the network

### Next Steps

Continue with:
- **Activations & Normalization** - Improve training stability
- **Recurrent Networks** - Handle sequential data
- **Training** - Put it all together with optimization