In [None]:
# Install required packages

!pip install efficientnet-pytorch
!pip install pytorch-quantization
!pip install tensorrt
!pip install onnx
!pip install timm

# Import necessary libraries
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from torch.quantization import QConfig, default_qconfig
from torch.quantization.quantize_fx import prepare_qat_fx, convert_fx
from tqdm import tqdm

warnings.filterwarnings('ignore')

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")


In [None]:
import math

class FP4Quantizer:
    """
    Simulated ("fake") quantizer for the 4-bit E2M1 floating-point format.

    A tensor is multiplied by a scale so that its largest magnitude lands on
    the largest table value (6.0), every element is snapped to the nearest
    value in the table, and the result is divided by the scale again. The
    output therefore keeps the input's shape and dtype but only takes
    table values divided by the scale.

    The snap-to-grid step uses a straight-through estimator: it is treated as
    the identity in the backward pass, so gradients reach the tensor being
    quantized (elements outside the clamp range still get zero gradient from
    ``torch.clamp``).
    """

    # Largest magnitude in the E2M1 value table.
    MAX_FP4 = 6.0

    def __init__(self, format_type='E2M1'):
        """
        Args:
            format_type: Label of the FP4 layout. Stored for reference only;
                the value table is always the one defined below.
        """
        self.format_type = format_type
        # One entry per 4-bit code. 0.0 is listed twice, so only 15 of the
        # 16 entries are distinct.
        self.fp4_values = torch.tensor([
            -6.0, -4.0, -3.0, -2.0, -1.5, -1.0, -0.5, 0.0,
            0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0  # 16 values
        ], dtype=torch.float32)
        # (device, dtype) -> (sorted distinct levels, midpoints between them)
        self._table_cache = {}

    def _lookup_tables(self, ref):
        """Return ``(levels, boundaries)`` on ``ref``'s device/dtype, cached per pair."""
        key = (ref.device, ref.dtype)
        tables = self._table_cache.get(key)
        if tables is None:
            # torch.unique sorts ascending and drops the duplicated 0.0.
            levels = torch.unique(self.fp4_values).to(device=ref.device, dtype=ref.dtype)
            boundaries = (levels[:-1] + levels[1:]) / 2
            tables = (levels, boundaries)
            self._table_cache[key] = tables
        return tables

    def quantize(self, x, scale_factor=None):
        """
        Quantize ``x`` to FP4 values using absmax scaling.

        Args:
            x: Tensor to quantize.
            scale_factor: Multiplier applied before snapping. When ``None`` it
                is computed as ``6.0 / max(|x|)``; an all-zero or empty ``x``
                uses a scale of 1 instead of dividing by zero.

        Returns:
            Tuple ``(dequantized, scale_factor)`` where ``dequantized`` has the
            shape of ``x`` and ``scale_factor`` is the scale that was applied.
        """
        if scale_factor is None:
            if x.numel() == 0:
                # torch.max is undefined on an empty tensor.
                scale_factor = x.new_ones(())
            else:
                max_val = torch.max(torch.abs(x))
                # 6/0 would be inf and 0*inf is NaN; swapping the zero
                # denominator for 6.0 yields a scale of exactly 1.
                safe_max = torch.where(
                    max_val == 0, torch.full_like(max_val, self.MAX_FP4), max_val
                )
                scale_factor = self.MAX_FP4 / safe_max

        # Scale input and clamp to the representable range
        x_scaled = x * scale_factor
        x_clamped = torch.clamp(x_scaled, -self.MAX_FP4, self.MAX_FP4)

        # Snap to the nearest table value (no gradient through the lookup)
        quantized = self._quantize_lookup(x_clamped.detach())

        if x_clamped.requires_grad:
            # Straight-through estimator: the added term is exactly zero in
            # the forward pass and contributes an identity gradient backward.
            quantized = quantized + (x_clamped - x_clamped.detach())

        return quantized / scale_factor, scale_factor

    def _quantize_lookup(self, x):
        """
        Snap every element of ``x`` to the closest FP4 table value.

        The search is vectorized: ``torch.bucketize`` locates each element
        among the midpoints between adjacent levels. An element lying exactly
        on a midpoint maps to the lower of its two neighbours. The result is
        detached from autograd and has the shape, dtype and device of ``x``.
        """
        values = x.detach()
        # Midpoints such as 0.25 cannot be represented in integer dtypes.
        work = values if values.is_floating_point() else values.float()
        levels, boundaries = self._lookup_tables(work)
        idx = torch.bucketize(work.contiguous(), boundaries)
        return levels[idx].to(values.dtype)

class OutlierClampingCompensation:
    """
    Outlier clamping with a residual ("compensation") term for activations.

    Values whose magnitude exceeds a quantile of ``|x|`` are clamped to that
    threshold; the part that was cut off is returned separately so that
    ``clamped + compensation`` reproduces the input.
    """
    def __init__(self, quantile=0.99, max_quantile_elems=16_000_000):
        """
        Args:
            quantile: Fraction in [0, 1] of ``|x|`` used as clamp threshold.
            max_quantile_elems: Largest element count handed to
                ``torch.quantile``; bigger tensors use ``torch.kthvalue``.
                NOTE(review): torch.quantile raises "input tensor is too
                large" beyond roughly 16M elements - confirm the exact limit
                for the installed PyTorch version.
        """
        self.quantile = quantile
        self.max_quantile_elems = max_quantile_elems

    def _threshold(self, magnitudes):
        """Return the ``self.quantile`` quantile of ``magnitudes`` as a 0-dim tensor."""
        flat = magnitudes.reshape(-1)
        count = flat.numel()
        if count <= self.max_quantile_elems:
            return torch.quantile(flat, self.quantile)
        # Nearest-rank quantile: no interpolation between neighbours, which is
        # a negligible difference at the sizes where this path is taken.
        k = min(count, max(1, math.ceil(self.quantile * count)))
        return torch.kthvalue(flat, k).values

    def apply(self, x):
        """
        Clamp outliers of ``x`` and return what was removed.

        Args:
            x: Activation tensor.

        Returns:
            Tuple ``(x_clamped, compensation)`` with
            ``compensation = x - x_clamped``; it is zero wherever ``x`` was
            already inside the threshold. An empty ``x`` is returned as-is
            together with an empty compensation tensor.
        """
        if x.numel() == 0:
            # A quantile of nothing is undefined; there is nothing to clamp.
            return x, torch.zeros_like(x)

        # Threshold on magnitudes so the clamp range is symmetric around zero
        threshold = self._threshold(torch.abs(x))

        # Clamp outliers
        x_clamped = torch.clamp(x, -threshold, threshold)

        # Residual removed by the clamp (non-zero only at outlier positions)
        compensation = x - x_clamped

        return x_clamped, compensation


In [None]:
class QuantizedConv2d(nn.Module):
    """
    2-D convolution that simulates FP4 weights and activations.

    The full-precision parameters live in ``self.conv``. On every forward
    pass the weights and the outlier-clamped activations are passed through
    an ``FP4Quantizer`` and convolved; the activation part removed by the
    outlier clamp is convolved with the full-precision weights and added
    back.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0,
                 dilation=1, groups=1, bias=True):
        """
        Args:
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            groups, bias: Forwarded unchanged to ``nn.Conv2d``. ``dilation``,
                ``groups`` and ``bias`` default to the ``nn.Conv2d`` defaults,
                so existing five-argument calls behave as before.
        """
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                              dilation=dilation, groups=groups, bias=bias)
        self.weight_quantizer = FP4Quantizer()
        self.activation_quantizer = FP4Quantizer()
        self.outlier_handler = OutlierClampingCompensation()

    def forward(self, x):
        """
        Apply the simulated-FP4 convolution to ``x``.

        Returns the convolution of quantized activations with quantized
        weights (plus bias, if any), plus the full-precision convolution of
        the clamped-off outlier residual when that residual is non-zero.
        """
        conv = self.conv

        # Quantize weights (the returned scale is not needed here because the
        # quantizer already divides it back out)
        weight_q, _ = self.weight_quantizer.quantize(conv.weight)

        # Handle activation outliers
        x_clamped, compensation = self.outlier_handler.apply(x)

        # Quantize activations
        x_q, _ = self.activation_quantizer.quantize(x_clamped)

        # Perform convolution with quantized weights and activations
        # Note: This is a simulation - actual hardware would handle this differently
        # dilation and groups must be forwarded, otherwise depthwise or dilated
        # layers would silently compute a different (or ill-shaped) operation.
        output = F.conv2d(x_q, weight_q, conv.bias,
                          conv.stride, conv.padding, conv.dilation, conv.groups)

        # Add compensation for outliers; bias is omitted so it is counted once
        if torch.sum(torch.abs(compensation)) > 0:
            comp_output = F.conv2d(compensation, conv.weight, None,
                                   conv.stride, conv.padding, conv.dilation, conv.groups)
            output = output + comp_output

        return output

class QuantizedEfficientNet(nn.Module):
    """
    Pre-trained EfficientNet whose dense convolutions run through QuantizedConv2d.
    """
    def __init__(self, model_name='efficientnet-b0', num_classes=1000):
        """
        Args:
            model_name: Name passed to ``EfficientNet.from_pretrained``.
            num_classes: Size of the classifier head requested from
                ``EfficientNet.from_pretrained``.
        """
        super().__init__()
        # Load pre-trained EfficientNet
        self.backbone = EfficientNet.from_pretrained(model_name, num_classes=num_classes)

        # Replace key layers with quantized versions
        self._replace_layers()

    def _replace_layers(self):
        """
        Swap eligible ``nn.Conv2d`` layers of the backbone for ``QuantizedConv2d``.

        Layers are selected by type, not by name: efficientnet_pytorch names
        its submodules ``_conv_stem``, ``_blocks.N._expand_conv``,
        ``_conv_head`` and so on, so a ``'features' in name`` filter matches
        nothing and would leave the network unquantized.

        Left in full precision on purpose:
          * grouped (depthwise) or dilated convolutions - the replacement is
            built from in/out channels, kernel size, stride and padding only,
            which cannot express groups or dilation;
          * convolution subclasses with a custom ``forward`` and no
            ``static_padding`` submodule, whose behaviour cannot be
            reproduced here.

        NOTE(review): relies on efficientnet_pytorch's static same-padding
        convolution exposing its padding as ``static_padding`` and running it
        before the convolution - confirm for the installed version.
        """
        # Snapshot first: the loop rewires the tree it would otherwise iterate.
        modules = dict(self.backbone.named_modules())

        for name, module in modules.items():
            if not name or not isinstance(module, nn.Conv2d):
                continue

            parent_name, _, child_name = name.rpartition('.')
            parent = modules[parent_name]
            if isinstance(parent, QuantizedConv2d):
                # Inner conv of an existing wrapper (method called twice).
                continue
            if module.groups != 1 or tuple(module.dilation) != (1, 1):
                continue

            pre_pad = getattr(module, 'static_padding', None)
            if pre_pad is None and type(module).forward is not nn.Conv2d.forward:
                continue

            # Replace with quantized conv
            new_conv = QuantizedConv2d(
                module.in_channels,
                module.out_channels,
                module.kernel_size,
                module.stride,
                module.padding
            )
            # Copy weights
            new_conv.conv.weight.data = module.weight.data.clone()
            if module.bias is not None:
                new_conv.conv.bias.data = module.bias.data.clone()
            else:
                # The source layer has no bias; drop the randomly initialised
                # one the replacement was created with.
                new_conv.conv.bias = None

            # Keep the same-padding step that the original layer applied
            # inside its own forward.
            replacement = new_conv if pre_pad is None else nn.Sequential(pre_pad, new_conv)

            # Replace in model
            setattr(parent, child_name, replacement)

    def forward(self, x):
        """Run ``x`` through the (partially quantized) backbone."""
        return self.backbone(x)


In [None]:
class FP4QATTrainer:
    """
    Training/validation driver for quantization-aware fine-tuning.

    Owns the model (moved to ``device``), an AdamW optimizer, a cosine
    learning-rate schedule and a cross-entropy loss.
    """
    def __init__(self, model, train_loader, val_loader, device):
        """
        Args:
            model: Network to train; it is moved to ``device``.
            train_loader: Iterable of ``(data, target)`` training batches.
            val_loader: Iterable of ``(data, target)`` validation batches.
            device: Device that batches and the model are moved to.
        """
        self.device = device
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader

        self.criterion = nn.CrossEntropyLoss()
        # Small learning rate: the weights start from a pre-trained model
        self.optimizer = optim.AdamW(self.model.parameters(), lr=1e-5, weight_decay=1e-4)
        self.scheduler = optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=100)

    @staticmethod
    def _count_correct(output, target):
        """Number of rows of ``output`` whose arg-max equals ``target``."""
        predicted = output.argmax(dim=1, keepdim=True)
        return (predicted == target.view_as(predicted)).sum().item()

    def train_epoch(self):
        """
        Run one optimisation pass over ``train_loader``.

        Returns:
            ``(mean batch loss, accuracy in percent)`` for the epoch.
        """
        self.model.train()
        running_loss, n_correct, n_seen = 0, 0, 0

        progress = tqdm(self.train_loader, desc="Training")
        for step, (inputs, labels) in enumerate(progress):
            inputs = inputs.to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            logits = self.model(inputs)
            batch_loss = self.criterion(logits, labels)
            batch_loss.backward()

            # Clip the global gradient norm to keep updates stable
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            self.optimizer.step()

            loss_value = batch_loss.item()
            running_loss += loss_value
            n_correct += self._count_correct(logits, labels)
            n_seen += labels.size(0)

            if step % 100 == 0:
                print(f'Batch {step}, Loss: {loss_value:.6f}')

        mean_loss = running_loss / len(self.train_loader)
        accuracy = 100. * n_correct / n_seen
        return mean_loss, accuracy

    def validate(self):
        """
        Evaluate on ``val_loader`` without gradient tracking.

        Returns:
            ``(mean batch loss, accuracy in percent)``.
        """
        self.model.eval()
        running_loss, n_correct, n_seen = 0, 0, 0

        with torch.no_grad():
            for inputs, labels in tqdm(self.val_loader, desc="Validation"):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                logits = self.model(inputs)
                running_loss += self.criterion(logits, labels).item()
                n_correct += self._count_correct(logits, labels)
                n_seen += labels.size(0)

        mean_loss = running_loss / len(self.val_loader)
        accuracy = 100. * n_correct / n_seen
        return mean_loss, accuracy

    def train(self, epochs=50):
        """
        Train for ``epochs`` epochs, validating after each one.

        The state dict is written to ``best_fp4_efficientnet.pth`` after the
        first epoch and whenever validation accuracy beats every earlier epoch.

        Returns:
            Dict with per-epoch lists under ``train_losses``, ``train_accs``,
            ``val_losses`` and ``val_accs``.
        """
        history = {
            'train_losses': [],
            'train_accs': [],
            'val_losses': [],
            'val_accs': []
        }
        best_val_acc = None  # best accuracy over the epochs finished so far

        for epoch in range(epochs):
            print(f'\nEpoch {epoch+1}/{epochs}')

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            # Advance the learning-rate schedule once per epoch
            self.scheduler.step()

            history['train_losses'].append(train_loss)
            history['train_accs'].append(train_acc)
            history['val_losses'].append(val_loss)
            history['val_accs'].append(val_acc)

            print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
            print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')

            # Checkpoint on the first epoch and on every strict improvement
            if best_val_acc is None or val_acc > best_val_acc:
                torch.save(self.model.state_dict(), 'best_fp4_efficientnet.pth')
                print('Best model saved!')

            if best_val_acc is None:
                best_val_acc = val_acc
            else:
                best_val_acc = max(best_val_acc, val_acc)

        return history


In [None]:
def prepare_imagenet_data(data_path='/EfficientNet/FP4/Imagenet', batch_size=32,
                          num_workers=8, pin_memory=True):
    """
    Build training and validation data loaders for an ImageNet-style folder.

    Args:
        data_path: Directory containing ``train`` and ``val`` sub-directories
            laid out for ``torchvision.datasets.ImageFolder``.
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes per loader.
        pin_memory: Forwarded to both loaders.

    Returns:
        Tuple ``(train_loader, val_loader)``. Only the training loader
        shuffles.
    """
    # ImageNet normalization values
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training transforms: random 224x224 crop, flip and colour jitter
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
        transforms.ToTensor(),
        normalize
    ])

    # Validation transforms: deterministic resize followed by a centre crop
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize
    ])

    def _make_loader(split, transform, shuffle):
        # One ImageFolder + DataLoader pair per split, sharing loader settings
        dataset = torchvision.datasets.ImageFolder(
            os.path.join(data_path, split),
            transform=transform
        )
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=pin_memory
        )

    train_loader = _make_loader('train', train_transform, shuffle=True)
    val_loader = _make_loader('val', val_transform, shuffle=False)

    return train_loader, val_loader

# Instantiate the quantized EfficientNet-B0 with a 1000-class (ImageNet) head.
# NOTE(review): this runs when the cell executes; main() builds its own model,
# and the final cell rebinds `model` to the one main() returns.
model = QuantizedEfficientNet(model_name='efficientnet-b0', num_classes=1000)


In [None]:
def main():
    """
    Main training pipeline for FP4 EfficientNet.

    Loads the data, builds the quantized model with a classifier head sized
    to the number of classes found in the training set, trains it and plots
    the metric history.

    Returns:
        Tuple ``(model, history)``: the trained model and the metrics dict
        returned by ``FP4QATTrainer.train``.
    """
    print("Starting FP4 EfficientNet Quantization-Aware Training")

    # Prepare data
    print("Preparing data...")
    train_loader, val_loader = prepare_imagenet_data(batch_size=32)

    # Size the classifier from the data so labels can never fall outside the
    # head's range (ImageFolder exposes the discovered class names).
    num_classes = len(train_loader.dataset.classes)

    # Create quantized model
    print("Creating quantized EfficientNet model...")
    model = QuantizedEfficientNet(model_name='efficientnet-b0', num_classes=num_classes)

    # Setup trainer
    trainer = FP4QATTrainer(model, train_loader, val_loader, device)

    # Train the model
    print("Starting training...")
    history = trainer.train(epochs=50)

    # Plot results
    plot_training_history(history)

    print("Training completed!")
    return model, history

def plot_training_history(history):
    """
    Draw loss and accuracy curves side by side.

    Args:
        history: Dict with per-epoch sequences under ``train_losses``,
            ``val_losses``, ``train_accs`` and ``val_accs``.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, train key, train label, val key, val label, title, y-axis label)
    panels = (
        (loss_ax, 'train_losses', 'Train Loss', 'val_losses', 'Val Loss',
         'Training and Validation Loss', 'Loss'),
        (acc_ax, 'train_accs', 'Train Accuracy', 'val_accs', 'Val Accuracy',
         'Training and Validation Accuracy', 'Accuracy (%)'),
    )

    for ax, train_key, train_label, val_key, val_label, title, y_label in panels:
        ax.plot(history[train_key], label=train_label)
        ax.plot(history[val_key], label=val_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()

# Run the complete pipeline.
# main() is invoked only when __name__ is "__main__"; its two return values
# are bound to the module-level names `model` and `history`.
if __name__ == "__main__":
    model, history = main()
