In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import StepLR
import scipy.stats as stats
import os

# Set a confidence threshold (e.g., 0.99 for 99% confidence)
CONFIDENCE_THRESHOLD = 0.99
PENALTY_WEIGHT = 10  # Weight for penalizing incorrect predictions after 50% accuracy
SAVE_PATH = './saved_models/'  # Directory to save model checkpoints
if not os.path.exists(SAVE_PATH):
    os.makedirs(SAVE_PATH)

# Focal Loss for handling imbalanced datasets
class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, outputs, targets):
        BCE_loss = F.cross_entropy(outputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)  # Get the probability
        focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return focal_loss.mean()

# Custom Loss Function with Focal Loss and Penalty for Wrong Predictions after 50% Accuracy
def custom_loss_function(outputs, targets, class_accuracy, current_accuracy):
    # Apply softmax to get probabilities
    probabilities = F.softmax(outputs, dim=1)
    
    # Get the max probability (confidence) and corresponding predicted class
    confidences, predicted_classes = torch.max(probabilities, dim=1)
    
    # Calculate the Focal Loss for class imbalance
    focal_loss = FocalLoss()(outputs, targets)
    
    # Heavily penalize wrong argmax predictions if training accuracy > 50%
    wrong_predictions = (predicted_classes != targets).float()
    if current_accuracy > 0.5:
        wrong_prediction_penalty = PENALTY_WEIGHT * wrong_predictions.sum()
    else:
        wrong_prediction_penalty = 0

    # Calculate the total loss
    total_loss = focal_loss + wrong_prediction_penalty
    return total_loss

# WideResNeXt Block
class WideResNeXtBlock(nn.Module):
    expansion = 2  # Expansion factor for WideResNeXt

    def __init__(self, in_planes, planes, stride=1, cardinality=32, widen_factor=2):
        super(WideResNeXtBlock, self).__init__()
        D = cardinality * widen_factor
        self.conv1 = nn.Conv2d(in_planes, D, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(D)
        self.conv2 = nn.Conv2d(D, D, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
        self.bn2 = nn.BatchNorm2d(D)
        self.conv3 = nn.Conv2d(D, planes * WideResNeXtBlock.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * WideResNeXtBlock.expansion)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes * WideResNeXtBlock.expansion:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes * WideResNeXtBlock.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * WideResNeXtBlock.expansion)
            )

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = torch.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += self.shortcut(x)
        out = torch.relu(out)
        return out

# WideResNeXt Model
class WideResNeXt(nn.Module):
    def __init__(self, block, num_blocks, cardinality=32, widen_factor=2, num_classes=100):
        super(WideResNeXt, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1, cardinality=cardinality, widen_factor=widen_factor)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2, cardinality=cardinality, widen_factor=widen_factor)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2, cardinality=cardinality, widen_factor=widen_factor)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2, cardinality=cardinality, widen_factor=widen_factor)

        self.linear = nn.Linear(512 * WideResNeXtBlock.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride, cardinality, widen_factor):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride, cardinality, widen_factor))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = torch.nn.functional.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

# WideResNeXt-101 Model Definition
def WideResNeXt101():
    return WideResNeXt(WideResNeXtBlock, [3, 4, 23, 3], cardinality=32, widen_factor=2)

# Training function with penalty for wrong predictions and tracking class-wise 100% accuracy
def train_with_penalty(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0
    class_correct = {i: 0 for i in range(100)}  # Tracks correct predictions per class
    class_total = {i: 0 for i in range(100)}    # Tracks total predictions per class

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        
        # Calculate the overall training accuracy before updating weights
        probabilities = F.softmax(outputs, dim=1)
        _, predicted_classes = torch.max(probabilities, dim=1)
        correct_predictions = predicted_classes.eq(targets).sum().item()
        total += targets.size(0)
        current_accuracy = correct_predictions / total
        
        # Calculate custom loss with penalties if training accuracy is above 50%
        loss = custom_loss_function(outputs, targets, class_accuracy, current_accuracy)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item()
        correct += correct_predictions

        # Update per-class accuracy tracking
        for i in range(100):
            class_correct[i] += (predicted_classes == i).sum().item()
            class_total[i] += (targets == i).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {train_loss / (batch_idx + 1):.3f}, Acc: {100.*correct/total:.3f}%')

    # Print if any class has been 100% correctly predicted
    for i in range(100):
        if class_total[i] > 0 and class_correct[i] == class_total[i]:
            print(f'Class {i} has been 100% correctly predicted.')

# Testing function to evaluate performance and save the best model
best_accuracy = 0.0  # Track the best accuracy

def test_with_penalty(epoch):
    global best_accuracy
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            
            loss = custom_loss_function(outputs, targets, class_accuracy, 0)  # Use 0 for test phase
            test_loss += loss.item()
            
            probabilities = F.softmax(outputs, dim=1)
            _, predicted_classes = torch.max(probabilities, dim=1)
            correct_predictions = predicted_classes.eq(targets).sum().item()
            correct += correct_predictions
            total += targets.size(0)

    test_accuracy = 100. * correct / total
    print(f'Test set: Average loss: {test_loss/len(testloader):.4f}, Accuracy: {correct}/{total} ({test_accuracy:.2f}%)')

    # Save the model at every 50th epoch
    if epoch % 50 == 0:
        save_path = os.path.join(SAVE_PATH, f'model_epoch_{epoch}.pth')
        torch.save(model.state_dict(), save_path)
        print(f"Model saved at epoch {epoch}.")

    # Save the best model based on test accuracy
    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        save_path = os.path.join(SAVE_PATH, 'best_model.pth')
        torch.save(model.state_dict(), save_path)
        print(f"New best model saved with accuracy: {test_accuracy:.2f}%")

# Load CIFAR-100 dataset
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

# Model, loss, optimizer, and scheduler
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = WideResNeXt101().to(device)

optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
class_accuracy = {i: 1.0 for i in range(100)}

# Example Training Loop
for epoch in range(0, 100):
    train_with_penalty(epoch)
    test_with_penalty(epoch)
    scheduler.step()


Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to ./data/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:05<00:00, 29016271.62it/s]


Extracting ./data/cifar-100-python.tar.gz to ./data
Files already downloaded and verified
Epoch 0, Batch 0, Loss: 4.653, Acc: 1.562%
Epoch 0, Batch 100, Loss: 5.057, Acc: 1.145%
Epoch 0, Batch 200, Loss: 4.755, Acc: 1.640%
Epoch 0, Batch 300, Loss: 4.605, Acc: 2.121%
Test set: Average loss: 4.0675, Accuracy: 495/10000 (4.95%)
Model saved at epoch 0.
New best model saved with accuracy: 4.95%
Epoch 1, Batch 0, Loss: 3.992, Acc: 4.688%
Epoch 1, Batch 100, Loss: 3.932, Acc: 6.536%
Epoch 1, Batch 200, Loss: 3.839, Acc: 7.381%
Epoch 1, Batch 300, Loss: 3.755, Acc: 8.194%
Test set: Average loss: 3.3965, Accuracy: 1330/10000 (13.30%)
New best model saved with accuracy: 13.30%
Epoch 2, Batch 0, Loss: 3.289, Acc: 14.062%
Epoch 2, Batch 100, Loss: 3.253, Acc: 14.349%
Epoch 2, Batch 200, Loss: 3.177, Acc: 15.718%
Epoch 2, Batch 300, Loss: 3.112, Acc: 16.928%
Test set: Average loss: 3.2364, Accuracy: 1730/10000 (17.30%)
New best model saved with accuracy: 17.30%
Epoch 3, Batch 0, Loss: 2.709, Acc: 