In [None]:
import torch
import torch.nn as nn
from torch import Tensor
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch
from tqdm import tqdm


In [None]:
class CNN(nn.Module):
    """Modular CNN built from a list of per-layer configuration dicts.

    Each entry of ``layer_config`` contributes, in order:
      * a ``Conv2d`` (required keys: 'out_channels', 'kernel_size',
        'stride', 'padding');
      * an optional activation ('activation': 'relu' | 'leaky_relu';
        ``None``, 'none' or a missing key mean "no activation");
      * an optional ``MaxPool2d`` ('pool': {'kernel_size', 'stride'};
        a missing key or ``None`` means "no pooling").

    The conv stack is followed by an ``AdaptiveAvgPool2d`` to a fixed
    ``pool_output_size x pool_output_size`` map and one ``Linear`` classifier.
    This lets the same model accept inputs of different spatial sizes.
    """

    # Activation name -> zero-arg module factory. The name is also used as the
    # sub-module prefix (e.g. "relu_0"), matching the original naming scheme.
    _ACTIVATIONS = {
        'relu': nn.ReLU,
        'leaky_relu': lambda: nn.LeakyReLU(0.1),
    }

    def __init__(self, in_channels: int, num_classes: int, layer_config: list,
                 pool_output_size: int = 4):
        """
        Args:
            in_channels: Number of input channels (1 for grayscale, 3 for RGB).
            num_classes: Number of output classes.
            layer_config: List of dicts specifying each conv stage (see class doc).
            pool_output_size: Side length of the adaptive-average-pool output;
                the classifier sees ``channels * pool_output_size ** 2`` features.

        Raises:
            ValueError: if a stage names an unsupported activation. Previously
                such a typo was silently ignored and produced no activation.
        """
        super().__init__()
        self.layers = nn.Sequential()

        for i, config in enumerate(layer_config):
            self.layers.add_module(
                f"conv_{i}",
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=config['out_channels'],
                    kernel_size=config['kernel_size'],
                    stride=config['stride'],
                    padding=config['padding'],
                )
            )
            in_channels = config['out_channels']  # next stage consumes this width

            activation = config.get('activation')
            if activation is not None and activation != 'none':
                if activation not in self._ACTIVATIONS:
                    raise ValueError(
                        f"Unsupported activation {activation!r} in layer {i}; "
                        f"expected one of {sorted(self._ACTIVATIONS)} or None"
                    )
                self.layers.add_module(f"{activation}_{i}", self._ACTIVATIONS[activation]())

            pool = config.get('pool')
            if pool is not None:
                self.layers.add_module(
                    f"pool_{i}",
                    nn.MaxPool2d(kernel_size=pool['kernel_size'], stride=pool['stride'])
                )

        # Fixed-size spatial output regardless of input resolution.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((pool_output_size, pool_output_size))

        # `in_channels` now holds the width of the last conv stage (or the
        # input width if layer_config is empty).
        self.fc = nn.Linear(in_channels * pool_output_size * pool_output_size, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        """Map a batch ``(N, C, H, W)`` to class logits ``(N, num_classes)``."""
        x = self.layers(x)
        x = self.adaptive_pool(x)
        x = x.flatten(1)  # keep batch dim, flatten channels * spatial
        return self.fc(x)

In [None]:
# CNN Architecture Configuration
# Two conv stages that differ only in width: 3x3 conv (stride 1, padding 1,
# so spatial size is preserved), ReLU, then a 2x2 max-pool that halves H and W.
# The comprehension builds fresh dicts for every stage, so no stage shares
# a nested 'pool' dict with another.
CNN_CONFIG = [
    dict(
        out_channels=width,
        kernel_size=3,
        stride=1,
        padding=1,
        activation='relu',
        pool=dict(kernel_size=2, stride=2),
    )
    for width in (32, 64)
]

# Training Configuration
# 'optimizer' holds the optimizer *class*; it is called later with the model
# parameters and `lr=learning_rate`. 'loss_fn' is an already-built loss module.
TRAIN_CONFIG = dict(
    batch_size=64,
    epochs=10,
    learning_rate=0.001,
    optimizer=optim.Adam,
    loss_fn=nn.CrossEntropyLoss(),
    device=('cuda' if torch.cuda.is_available() else 'cpu'),
)

In [None]:
def get_dataloaders(dataset_name: str = 'MNIST', batch_size: int = 64,
                    root: str = 'data', num_workers: int = 0):
    """Create train and test dataloaders for a supported torchvision dataset.

    Args:
        dataset_name: Key of the dataset registry below ('MNIST' or 'CIFAR10').
        batch_size: Batch size for both loaders.
        root: Directory where the dataset is stored / downloaded to.
        num_workers: DataLoader worker processes (0 = load in the main process,
            the DataLoader default and the previous behaviour).

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and the
        test loader does not.

    Raises:
        ValueError: if ``dataset_name`` is not in the registry.
    """
    # Registry of supported datasets; add new torchvision datasets here.
    # Each class is called as cls(root=..., train=..., download=..., transform=...).
    dataset_classes = {
        'MNIST': datasets.MNIST,
        'CIFAR10': datasets.CIFAR10,
    }
    if dataset_name not in dataset_classes:
        raise ValueError(f"Unsupported dataset: {dataset_name}")
    dataset_cls = dataset_classes[dataset_name]

    # NOTE(review): a single-element mean/std of 0.5 is applied to every
    # channel (maps [0, 1] -> [-1, 1]). It is not the per-dataset statistic.
    # Consider dataset-specific values if accuracy matters.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    train_data = dataset_cls(root=root, train=True, download=True, transform=transform)
    test_data = dataset_cls(root=root, train=False, download=True, transform=transform)

    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers)

    return train_loader, test_loader

In [None]:
class Trainer:
    """Training / evaluation driver for a classification model.

    ``config`` must provide:
      * 'device': device string passed to ``torch.device`` (e.g. 'cpu', 'cuda');
      * 'optimizer': optimizer class, called as ``optimizer(params, lr=...)``;
      * 'learning_rate': learning rate passed to the optimizer;
      * 'loss_fn': callable ``loss_fn(outputs, targets) -> scalar tensor``.

    The loaders may be any iterables yielding ``(data, targets)`` batch pairs.
    """

    def __init__(self, model, train_loader, test_loader, config):
        self.device = torch.device(config['device'])
        self.model = model.to(self.device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.config = config
        self.optimizer = config['optimizer'](self.model.parameters(), lr=config['learning_rate'])
        self.loss_fn = config['loss_fn']

    def train_epoch(self, show_progress: bool = True) -> float:
        """Run one optimisation pass over ``train_loader``.

        Args:
            show_progress: Wrap the loader in a tqdm progress bar (default True).

        Returns:
            Mean loss per training sample over the epoch. Each batch's loss is
            weighted by its size, so a short final batch no longer counts as
            much as a full one. This assumes ``loss_fn`` uses the default
            'mean' reduction — TODO confirm if a custom loss is configured.

        Raises:
            ValueError: if ``train_loader`` yields no samples.
        """
        self.model.train()
        batches = tqdm(self.train_loader, desc="Training") if show_progress else self.train_loader

        loss_sum = 0.0
        n_samples = 0
        for data, targets in batches:
            data = data.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()
            loss = self.loss_fn(self.model(data), targets)
            loss.backward()
            self.optimizer.step()

            batch_size = targets.size(0)
            loss_sum += loss.item() * batch_size  # undo the per-batch mean
            n_samples += batch_size

        if n_samples == 0:
            raise ValueError("train_loader yielded no samples")
        return loss_sum / n_samples

    def evaluate(self) -> float:
        """Return top-1 accuracy (fraction in [0, 1]) on ``test_loader``.

        Raises:
            ValueError: if ``test_loader`` yields no samples. The original
                code raised ZeroDivisionError in this case.
        """
        self.model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for data, targets in self.test_loader:
                data = data.to(self.device)
                targets = targets.to(self.device)

                predicted = self.model(data).argmax(dim=1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        if total == 0:
            raise ValueError("test_loader yielded no samples")
        return correct / total

    def save_model(self, path='model.pth'):
        """Save the model's ``state_dict`` to ``path`` via ``torch.save``."""
        torch.save(self.model.state_dict(), path)

In [None]:
def main(dataset_name: str = 'MNIST', num_classes: int = 10, seed=0,
         save_path: str = 'model.pth'):
    """Train a CNN built from CNN_CONFIG using TRAIN_CONFIG, report metrics, save weights.

    Args:
        dataset_name: Dataset passed to ``get_dataloaders`` (e.g. 'MNIST', 'CIFAR10').
        num_classes: Number of output classes (10 for both MNIST and CIFAR10).
        seed: Passed to ``torch.manual_seed`` so weight init and DataLoader
            shuffling are reproducible; ``None`` skips seeding.
        save_path: File the trained ``state_dict`` is written to.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Get data loaders
    train_loader, test_loader = get_dataloaders(
        dataset_name=dataset_name,
        batch_size=TRAIN_CONFIG['batch_size']
    )

    # Infer the input channel count from one sample instead of hard-coding 1.
    # The hard-coded 1 only matched MNIST; with ToTensor the sample is
    # (C, H, W), so CIFAR10 gives 3.
    sample, _ = train_loader.dataset[0]
    in_channels = sample.shape[0]

    # Initialize model
    model = CNN(
        in_channels=in_channels,
        num_classes=num_classes,
        layer_config=CNN_CONFIG
    )

    # Initialize trainer
    trainer = Trainer(model, train_loader, test_loader, TRAIN_CONFIG)

    # Training loop
    for epoch in range(TRAIN_CONFIG['epochs']):
        train_loss = trainer.train_epoch()
        test_acc = trainer.evaluate()
        print(f"Epoch {epoch+1}/{TRAIN_CONFIG['epochs']}")
        print(f"Train Loss: {train_loss:.4f} | Test Accuracy: {test_acc:.4f}")

    trainer.save_model(save_path)

main()