# PReLU Example (He et al., 2015)

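This notebook demonstrates the Parametric ReLU (PReLU) activation, $f(x) = \max(0, x) + a \min(0, x)$, where the negative-side slope $a$ is learned during training. It uses a simple convolutional network trained on CIFAR-10.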

Reference: [Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification](https://arxiv.org/pdf/1502.01852).

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [None]:
# CIFAR-10 training split, downloaded into DATA_ROOT on the first run only.
# ToTensor scales pixels to [0, 1]; Normalize with mean = std = 0.5 per channel
# then maps them to [-1, 1].
DATA_ROOT = "./data"
BATCH_SIZE = 64

img_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)
train_set = datasets.CIFAR10(
    root=DATA_ROOT,
    train=True,
    download=True,
    transform=img_transform,
)
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
class SimpleCNN(nn.Module):
    """Small PReLU CNN: conv-PReLU-conv-PReLU-maxpool-flatten-linear.

    Args:
        in_channels: Number of input image channels (3 for RGB CIFAR-10).
        num_classes: Number of output logits (10 for CIFAR-10).
        image_size: Height/width of the square input images (32 for CIFAR-10).
            Only used to size the final linear layer; inputs of any other size
            fail with a shape mismatch at that layer.
        hidden_channels: Output channels of both conv layers.
        channel_wise_prelu: If True, each PReLU learns one negative slope per
            channel (the channel-wise variant in He et al., 2015). If False
            (the default), a single slope is shared by all channels, which is
            what a bare ``nn.PReLU()`` does.

    ``forward`` expects a ``(batch, in_channels, image_size, image_size)``
    tensor and returns ``(batch, num_classes)`` unnormalized logits.
    """

    def __init__(self, in_channels=3, num_classes=10, image_size=32,
                 hidden_channels=64, channel_wise_prelu=False):
        super().__init__()
        prelu_params = hidden_channels if channel_wise_prelu else 1
        # padding=1 preserves the spatial size through each 3x3 conv, and
        # MaxPool2d(2) halves it (rounding down). The flattened feature count
        # is therefore hidden_channels * (image_size // 2) ** 2, which is
        # 64 * 16 * 16 = 16384 for CIFAR-10.
        pooled = image_size // 2
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
            nn.PReLU(num_parameters=prelu_params),
            nn.Conv2d(hidden_channels, hidden_channels, kernel_size=3, padding=1),
            nn.PReLU(num_parameters=prelu_params),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(hidden_channels * pooled * pooled, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images ``x``."""
        return self.model(x)

In [None]:
SEED = 0
NUM_EPOCHS = 1  # short demo
LEARNING_RATE = 1e-3

# Seed before building the model so weight init and the DataLoader's shuffle
# order are reproducible across Restart & Run All.
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

model.train()
for epoch in range(NUM_EPOCHS):
    running_loss, num_samples = 0.0, 0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages over the batch; weight by batch size so a
        # short final batch does not skew the epoch mean.
        running_loss += loss.item() * images.size(0)
        num_samples += images.size(0)
    # Report the mean over the whole epoch, not just the last mini-batch.
    epoch_loss = running_loss / max(num_samples, 1)
    print(f'Epoch {epoch+1}, Mean train loss: {epoch_loss:.4f}')

## Using HuggingFace Trainer

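Below is an optional example of training the CNN with [HuggingFace Trainer](https://github.com/huggingface/transformers).
It requires the `datasets` and `transformers` packages. `load_dataset('cifar10')` reads CIFAR-10 from the local HuggingFace cache if it was downloaded before, and otherwise fetches it from the Hub. It does not reuse the torchvision copy in `./data`. If the notebook will run without internet access, pre-download the dataset. The `Trainer` continues training the `model` instance from the loop above rather than starting from fresh weights.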

In [None]:
# Optional dependencies are imported here, not in the top import cell, so the
# PyTorch-only part of the notebook runs without them installed.
from datasets import load_dataset
from transformers import Trainer, TrainingArguments

# Load CIFAR-10 (columns 'img' and 'label') from the local HuggingFace cache,
# or from the Hub if it has not been downloaded yet.
hf_ds_raw = load_dataset('cifar10')


def transform_batch(batch):
    """Turn a batch of PIL images into normalized tensors for the CNN.

    Applies the same ``img_transform`` as the torchvision pipeline above and
    copies ``label`` to ``labels``, the keyword the Trainer passes to the model.
    """
    imgs = [img_transform(image) for image in batch['img']]
    batch['pixel_values'] = torch.stack(imgs)
    batch['labels'] = batch['label']
    return batch


# with_transform applies transform_batch lazily on access. Binding the result
# to a new name keeps this cell idempotent when it is re-run.
hf_ds = hf_ds_raw.with_transform(transform_batch)


def collate_fn(batch):
    """Stack per-example dicts into one ``pixel_values`` / ``labels`` batch."""
    return {'pixel_values': torch.stack([x['pixel_values'] for x in batch]),
            'labels': torch.tensor([x['labels'] for x in batch])}


class TrainerAdapter(nn.Module):
    """Expose a logits-only classifier through the interface Trainer expects.

    Trainer calls ``model(**batch)`` with keyword arguments and reads the loss
    from the returned dict. ``SimpleCNN.forward`` takes a positional tensor and
    returns raw logits, so it cannot be handed to Trainer directly.
    """

    def __init__(self, backbone, loss_fn=None):
        super().__init__()
        self.backbone = backbone
        self.loss_fn = loss_fn if loss_fn is not None else nn.CrossEntropyLoss()

    def forward(self, pixel_values, labels=None):
        logits = self.backbone(pixel_values)
        if labels is None:
            return {'logits': logits}
        return {'loss': self.loss_fn(logits, labels), 'logits': logits}


training_args = TrainingArguments(
    output_dir='./results',
    per_device_train_batch_size=64,
    num_train_epochs=1,
    logging_steps=50,
    # By default, Trainer drops every dataset column that is not a forward()
    # argument, including the raw 'img' column. The lazy transform_batch would
    # then fail with KeyError: 'img', so keep all columns.
    remove_unused_columns=False,
    # Don't start experiment trackers (wandb, etc.); the notebook may be offline.
    report_to='none',
)

# Wraps (and keeps training) the `model` trained by the PyTorch loop above.
trainer = Trainer(
    model=TrainerAdapter(model),
    args=training_args,
    train_dataset=hf_ds['train'],
    data_collator=collate_fn,
)

trainer.train()