In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import os

# Check if GPU is available and set the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
BATCH_SIZE = 128
LEARNING_RATE = 0.001
# NOTE: AdamW's own default weight_decay is 1e-2, so 1e-3 is *weaker* decoupled
# weight decay than the optimizer default, not stronger.
WEIGHT_DECAY = 1e-3
EPOCHS = 150  # upper bound; early stopping may end training sooner
# NOTE(review): not referenced anywhere in the code shown here — confirm it is
# still needed before relying on it.
CONFIDENCE_THRESHOLD = 0.99
PATIENCE = 10  # early-stopping patience, in epochs
SAVE_PATH = './'  # directory where checkpoints are written

class EarlyStopping:
    """Signal when a monitored score (higher is better) has stopped improving.

    Call the instance once per epoch with the latest score. A call counts as an
    improvement only if the score exceeds the best seen so far by more than
    ``min_delta``; an improvement records the new best and resets the counter.
    Any other call (including an exact tie with the best) increments the counter,
    and once it reaches ``patience`` the ``early_stop`` flag is set to True.

    Parameters:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, print the counter on every non-improving call.
        min_delta: Minimum increase over the best score that counts as an
            improvement. Defaults to 0.0 (any strict increase).
    """

    def __init__(self, patience=10, verbose=False, min_delta=0.0):
        self.patience = patience
        self.verbose = verbose
        self.min_delta = min_delta
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, accuracy, model=None):
        """Record ``accuracy`` for this epoch and update ``early_stop``.

        ``model`` is accepted for call-site compatibility but is not used.
        """
        if self.best_score is None or accuracy > self.best_score + self.min_delta:
            # First observation or a genuine improvement.
            self.best_score = accuracy
            self.counter = 0
            return

        # No improvement: a tie with the best score must not reset patience,
        # otherwise a flat plateau would never trigger a stop.
        self.counter += 1
        if self.verbose:
            print(f"Early stopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

# Create directory for saving models. exist_ok=True makes this a no-op when the
# directory already exists and avoids a check-then-create race.
os.makedirs(SAVE_PATH, exist_ok=True)

# Shared normalization for both pipelines: per channel (x - 0.5) / 0.5, which maps
# ToTensor's [0, 1] range to [-1, 1].
# NOTE(review): these are not the CIFAR-100 per-channel mean/std — confirm intent.
_normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

# Training pipeline: random augmentation, then tensor conversion + normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),  # random rotation in [-15, 15] degrees
    transforms.ToTensor(),
    _normalize,
])

# Evaluation pipeline: deterministic, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

# Load CIFAR-100 dataset (downloaded into ./data on first use)
train_dataset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# Page-locked (pinned) host buffers speed up host-to-GPU copies; only useful with CUDA.
_pin_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, pin_memory=_pin_memory)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=_pin_memory)

# Define ResNet block
class BasicBlock(nn.Module):
    """Two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)`` where the shortcut
    is ``downsample(x`` when a ``downsample`` module is given, else ``x`` itself.
    ``conv1`` uses the block's ``stride``; ``conv2`` always uses stride 1.

    Parameters:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by both convolutions.
        stride: Stride of the first convolution.
        downsample: Optional module mapping ``x`` to the main path's output shape.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # Submodule names and registration order define the state_dict layout.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample  # None means the skip connection is the identity
        self.stride = stride

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(residual + shortcut)


# Define ResNet architecture
class ResNet(nn.Module):
    """ResNet with a 3x3 stride-1 stem (no max-pool), four residual stages and a dropout + linear head.

    Parameters:
        block: Residual block class, called as ``block(in_ch, out_ch, stride, downsample)``
            for the first block of a stage and ``block(out_ch, out_ch)`` for the rest;
            its output must have ``out_ch`` channels.
        layers: Number of blocks in each of the four stages (64, 128, 256, 512 channels).
        num_classes: Size of the classifier output.
    """

    def __init__(self, block, layers, num_classes=100):
        super().__init__()
        # Channel count entering the next stage; advanced by _make_layer.
        self.in_channels = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        # stride=2 is handed to the first block of stages 2-4.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        # Registered last, but applied to the pooled features before fc in forward().
        self.dropout = nn.Dropout(0.5)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of residual blocks producing ``out_channels`` channels.

        The first block always exists and receives ``stride`` plus, when the stride
        or channel count changes, a 1x1 conv + BatchNorm projection for its skip
        path. ``blocks - 1`` further blocks follow. Updates ``self.in_channels``.
        """
        needs_projection = stride != 1 or self.in_channels != out_channels
        projection = (
            nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
            if needs_projection
            else None
        )

        stage = [block(self.in_channels, out_channels, stride, projection)]
        self.in_channels = out_channels
        stage.extend(block(out_channels, out_channels) for _ in range(blocks - 1))
        return nn.Sequential(*stage)

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
        out = torch.flatten(self.avgpool(out), 1)
        return self.fc(self.dropout(out))

# Smooth Score function
def smooth_score(logits, true_labels, accuracy_threshold=0.55, gamma=5.0, sharpness=10.0):
    """
    Differentiable score rewarding confidently-correct samples and penalizing the rest.

    For each sample, p is the softmax probability of its true class. A sigmoid gate
    s = sigmoid((p - accuracy_threshold) * sharpness) is close to 1 when p is well
    above the threshold and close to 0 when well below. The score is
    sum(s) - sum(gamma * (1 - s)) over the batch. It is a *sum*, so its magnitude
    grows with batch size; for non-negative gamma it lies in (-gamma * B, B) for a
    batch of B samples.

    Parameters:
        - logits: Tensor of raw model outputs indexed as (batch, class).
        - true_labels: Integer tensor of true class indices, one per row of logits.
        - accuracy_threshold: True-class probability at which the gate equals 0.5.
        - gamma: Weight of the penalty for samples below the threshold.
        - sharpness: Slope of the sigmoid gate; larger values approach a hard step.
          Defaults to 10.0, the value that used to be hard-coded.

    Returns:
        - A scalar tensor, differentiable with respect to logits.
    """
    probs = F.softmax(logits, dim=1)

    # Pick each row's true-class probability directly rather than building a
    # (batch, num_classes) one-hot mask and summing it away. gather needs int64.
    correct_class_probs = probs.gather(1, true_labels.long().unsqueeze(1)).squeeze(1)

    # Smooth step around the threshold instead of a hard correct/incorrect cut.
    smooth_accuracy = torch.sigmoid((correct_class_probs - accuracy_threshold) * sharpness)

    # Reward: values close to 1 for samples above the threshold.
    high_accuracy_contrib = torch.sum(smooth_accuracy)

    # Penalty: (1 - smooth_accuracy) is close to 1 for samples below the threshold.
    low_accuracy_contrib = torch.sum((1 - smooth_accuracy) * gamma)

    return high_accuracy_contrib - low_accuracy_contrib
    
def test(model, test_loader, device):
    """Evaluate top-1 accuracy of ``model`` over ``test_loader``.

    Switches the model to eval mode (it is left in eval mode on return), runs every
    batch without autograd, prints the accuracy and returns it.

    Parameters:
        model: Module mapping a batch of inputs to per-class scores along dim 1.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Accuracy in percent (0-100) as a float.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_accuracy = 100.0 * correct / total
    print(f"Test Accuracy: {test_accuracy:.2f}%")
    return test_accuracy


# The name starts with "test", so pytest would otherwise collect this evaluation
# helper as a test case in any module that imports it (and error out looking for
# fixtures named model/test_loader/device). __test__ = False opts it out.
test.__test__ = False
    
# Training loop with per-epoch evaluation, checkpointing and early stopping
def _train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimization pass over ``train_loader`` and collect training metrics.

    Returns:
        Tuple ``(mean_batch_loss, train_accuracy_percent, mean_batch_smooth_score)``.
        Both means are taken over batches (``len(train_loader)``), not samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    epoch_smooth_score = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Bookkeeping only: run without autograd so the metric computations
        # do not build an extra graph on top of the model outputs.
        with torch.no_grad():
            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            # smooth_score returns a per-batch sum, so the epoch figure is a mean
            # of batch sums (a smaller final batch contributes a smaller sum).
            epoch_smooth_score += smooth_score(outputs.detach(), targets).item()

    num_batches = len(train_loader)
    return (
        total_loss / num_batches,
        100.0 * correct / total,
        epoch_smooth_score / num_batches,
    )


def train_with_early_stopping(model, train_loader, test_loader, optimizer, criterion, epochs, device, patience):
    """Train ``model`` for up to ``epochs`` epochs, evaluating on ``test_loader`` after each.

    Every 50th epoch the weights are saved to ``SAVE_PATH/model_epoch_<n>.pth``, and
    ``SAVE_PATH/best_model.pth`` is overwritten whenever test accuracy reaches a new
    maximum. Training stops as soon as an ``EarlyStopping(patience=patience)``
    monitor fed the per-epoch test accuracy signals a stop.

    Returns:
        The best test accuracy (percent) observed during training.
    """
    model.to(device)
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    best_test_accuracy = 0

    for epoch in range(epochs):
        avg_loss, train_accuracy, avg_smooth_score = _train_one_epoch(
            model, train_loader, optimizer, criterion, device
        )
        test_accuracy = test(model, test_loader, device)

        # NOTE: the train/test accuracy gap is deliberately not added to the loss.
        # It is only known after optimizer.step() and is not differentiable, so
        # adding it to a loss tensor here could never influence the weights.
        # Regularize through the model, optimizer or loss instead.

        # Print results for the epoch
        print(f"Epoch [{epoch+1}/{epochs}], Loss: {avg_loss:.4f}, "
              f"Train Accuracy: {train_accuracy:.2f}%, Test Accuracy: {test_accuracy:.2f}%, "
              f"Smooth Score: {avg_smooth_score:.4f}")

        # Periodic checkpoint every 50th epoch
        if (epoch + 1) % 50 == 0:
            torch.save(model.state_dict(), os.path.join(SAVE_PATH, f"model_epoch_{epoch+1}.pth"))
            print(f"Model saved at epoch {epoch+1}")

        # Keep the weights with the best test accuracy so far
        if test_accuracy > best_test_accuracy:
            best_test_accuracy = test_accuracy
            torch.save(model.state_dict(), os.path.join(SAVE_PATH, "best_model.pth"))
            print("Best model saved!")

        # Early stopping check
        early_stopping(test_accuracy, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    return best_test_accuracy

# Label Smoothing Cross Entropy Loss
class LabelSmoothingCrossEntropyLoss(nn.Module):
    """Cross-entropy against label-smoothed targets.

    For ``outputs`` indexed as (batch, C), the target distribution puts
    ``1 - smoothing`` on the true class and spreads ``smoothing`` evenly over the
    other ``C - 1`` classes. The loss is the batch mean of
    ``-sum_c target_c * log_softmax(outputs)_c``. With ``smoothing=0`` this is
    ordinary cross-entropy. C must be at least 2 (C == 1 divides by zero).
    """

    def __init__(self, smoothing=0.1):
        super().__init__()
        self.smoothing = smoothing
        self.confidence = 1.0 - smoothing  # probability mass kept on the true class

    def forward(self, outputs, targets):
        log_p = outputs.log_softmax(dim=-1)
        n_classes = outputs.size(-1)
        off_target_mass = self.smoothing / (n_classes - 1)
        # The soft targets are constants: build them outside autograd.
        with torch.no_grad():
            soft_targets = torch.full_like(log_p, off_target_mass)
            soft_targets.scatter_(1, targets.unsqueeze(1), self.confidence)
        per_sample_loss = -(soft_targets * log_p).sum(dim=-1)
        return per_sample_loss.mean()

# Build a ResNet-34-style network: BasicBlock (two 3x3 convs) with stage depths
# [3, 4, 6, 3]. (ResNet-50 uses bottleneck blocks; this is the shallower variant.)
# Move it to the target device *before* constructing the optimizer, as the
# torch.optim documentation recommends.
model = ResNet(BasicBlock, [3, 4, 6, 3]).to(device)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
# Cross-entropy against label-smoothed targets (smoothing=0.1).
criterion = LabelSmoothingCrossEntropyLoss(smoothing=0.1)

# Train with early stopping on test accuracy; checkpoints are written to SAVE_PATH.
train_with_early_stopping(model, train_loader, test_loader, optimizer, criterion, EPOCHS, device, PATIENCE)



Using device: cuda
Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:05<00:00, 29140319.61it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Test Accuracy: 10.63%
Epoch [1/150], Loss: 4.2306, Train Accuracy: 6.69%, Test Accuracy: 10.63%, Smooth Score: -633.8221
Best model saved!
Test Accuracy: 18.72%
Epoch [2/150], Loss: 3.8186, Train Accuracy: 13.18%, Test Accuracy: 18.72%, Smooth Score: -628.3544
Best model saved!
Test Accuracy: 25.88%
Epoch [3/150], Loss: 3.5184, Train Accuracy: 20.24%, Test Accuracy: 25.88%, Smooth Score: -615.3669
Best model saved!
Test Accuracy: 31.12%
Epoch [4/150], Loss: 3.2748, Train Accuracy: 26.21%, Test Accuracy: 31.12%, Smooth Score: -595.9694
Best model saved!
Test Accuracy: 35.60%
Epoch [5/150], Loss: 3.0667, Train Accuracy: 31.10%, Test Accuracy: 35.60%, Smooth Score: -572.3820
Best model saved!
Test Accuracy: 41.53%
Epoch [6/150], Loss: 2.8744, Train Accuracy: 36.35%, Test Accuracy: 41.53%, Smooth Score: -543.9459
Best model saved!
Test Accuracy: 43.04%
Epoch [7/150], Loss: 2.7114, Train Accuracy: 40.8