In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset, random_split

import numpy as np

import math
import torch.nn.functional as F

In [None]:
# Check for GPU
# Prefer CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


# Dataset: CIFAR-100

In [None]:
# Ratio between the sample budget of the first (largest) class and the last
# (smallest) class; it is turned into the per-class decay factor `mu` below.
imbalance_factor = 100    # candidate values: 100, 200

In [None]:
import matplotlib.pyplot as plt

# Optimized Data Preprocessing
# Training pipeline: random crop resized to 224x224, random horizontal flip and
# mild brightness jitter, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Combined resizing and cropping
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1),
    transforms.ToTensor(),
    # One-element mean/std; presumably broadcast over the RGB channels — TODO confirm.
    transforms.Normalize((0.5,), (0.5,))
])

# Evaluation pipeline: deterministic resize only, same normalisation as training.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize directly to target size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load CIFAR-100 dataset (downloaded into ./datasets on first use)
train_set = datasets.CIFAR100(root='./datasets', train=True, download=True, transform=transform_train)
test_set = datasets.CIFAR100(root='./datasets', train=False, download=True, transform=transform_test)

# Number of classes (imbalance_factor itself is defined in an earlier cell)
num_classes = 100

# Per-class decay ratio of the long-tailed distribution:
# mu ** (num_classes - 1) == 1 / imbalance_factor, so the last class is budgeted
# 1 / imbalance_factor of the first class's samples.
mu = (1 / imbalance_factor) ** (1 / (num_classes - 1))

def calculate_class_distribution(dataset, num_classes, mu):
    """Return the per-class sample budget of a geometrically decaying distribution.

    Class ``i`` is allotted ``int(max_samples * mu ** i)`` samples, where
    ``max_samples`` is the size of the largest class found in ``dataset``.

    Args:
        dataset: Either an object exposing a ``targets`` sequence of integer
            labels, or any iterable of ``(sample, label)`` pairs.
        num_classes: Number of classes; also the length of the returned list.
        mu: Decay ratio applied between consecutive class indices.

    Returns:
        list[int]: Target sample count for every class index.
    """
    # Read labels from `targets` when available: iterating the dataset would
    # load and transform every image only to throw the pixels away.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = [label for _, label in dataset]

    label_array = np.asarray(labels, dtype=np.int64).reshape(-1)
    original_counts = np.bincount(label_array, minlength=num_classes)
    max_samples = int(original_counts.max()) if original_counts.size else 0
    return [int(max_samples * (mu ** i)) for i in range(num_classes)]

# Per-class sample budget for the long-tailed training set: class i is allotted
# int(max_class_size * mu ** i) samples.
class_distribution_train = calculate_class_distribution(train_set, num_classes, mu)

# Apply class distribution to training set
def apply_class_distribution(dataset, target_counts):
    """Keep the first ``target_counts[c]`` samples of every class ``c``.

    Args:
        dataset: Either an object exposing a ``targets`` sequence of integer
            labels, or any iterable of ``(sample, label)`` pairs.
        target_counts: Maximum number of samples to keep, indexed by class.

    Returns:
        torch.utils.data.Subset: View on ``dataset`` restricted to the kept
        indices, in their original order.
    """
    # Read labels from `targets` when available: iterating the dataset would
    # load and transform every image only to look at its label.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = [label for _, label in dataset]

    class_counts = np.zeros(len(target_counts), dtype=int)
    filtered_indices = []

    for idx, label in enumerate(labels):
        label = int(label)
        if class_counts[label] < target_counts[label]:
            filtered_indices.append(idx)
            class_counts[label] += 1

    return torch.utils.data.Subset(dataset, filtered_indices)

# Keep a handle on the full CIFAR-100 training set before `train_set` is rebound.
_cifar_train = train_set
_longtail_set = apply_class_distribution(_cifar_train, class_distribution_train)
_longtail_base_idx = list(_longtail_set.indices)  # positions inside the full training split

# Split data into train and validation sets (80 / 20).
# The split is seeded so that re-running the notebook yields the same partition.
split_seed = 0
train_size = int(0.8 * len(_longtail_set))
val_size = len(_longtail_set) - train_size
_train_part, _val_part = random_split(
    _longtail_set, [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
_train_idx = [_longtail_base_idx[i] for i in _train_part.indices]
_val_idx = [_longtail_base_idx[i] for i in _val_part.indices]

# The training subset keeps the augmenting transform. The validation subset is
# backed by a second view of the same files that uses the deterministic test
# transform; otherwise validation images would be randomly cropped, flipped and
# jittered and validation accuracy would be noisy and not comparable to test.
train_set = torch.utils.data.Subset(_cifar_train, _train_idx)
_cifar_eval = datasets.CIFAR100(root=_cifar_train.root, train=True, download=False, transform=transform_test)
val_set = torch.utils.data.Subset(_cifar_eval, _val_idx)

# Prepare the training and testing datasets
batch_size = 64

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

# Calculate class distributions from the stored labels; indexing the datasets
# would decode and resize every image just to read its label.
_train_targets = np.asarray(_cifar_train.targets)
train_class_counts = np.bincount(_train_targets[_train_idx], minlength=num_classes)
val_class_counts = np.bincount(_train_targets[_val_idx], minlength=num_classes)
test_class_counts = np.bincount(np.asarray(test_set.targets), minlength=num_classes)


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import numpy as np
import matplotlib.pyplot as plt

# Optimized Data Preprocessing
# Training pipeline: random crop resized to 224x224, random horizontal flip and
# mild brightness jitter, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),  # Combined resizing and cropping
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.1),
    transforms.ToTensor(),
    # One-element mean/std; presumably broadcast over the RGB channels — TODO confirm.
    transforms.Normalize((0.5,), (0.5,))
])

# Evaluation pipeline: deterministic resize only, same normalisation as training.
transform_test = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize directly to target size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load CIFAR-100 dataset (downloaded into ./data on first use).
# NOTE(review): the previous cell performs the same steps with root='./datasets';
# running both cells downloads the dataset twice and this cell's results win.
train_set = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
test_set = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)

# Number of classes
num_classes = 100

# imbalance_factor is defined in an earlier cell.
# mu is the per-class decay ratio: mu ** (num_classes - 1) == 1 / imbalance_factor.
mu = (1 / imbalance_factor) ** (1 / (num_classes - 1))

def calculate_class_distribution(dataset, num_classes, mu):
    """Return the per-class sample budget of a geometrically decaying distribution.

    Class ``i`` is allotted ``int(max_samples * mu ** i)`` samples, where
    ``max_samples`` is the size of the largest class found in ``dataset``.

    Args:
        dataset: Either an object exposing a ``targets`` sequence of integer
            labels, or any iterable of ``(sample, label)`` pairs.
        num_classes: Number of classes; also the length of the returned list.
        mu: Decay ratio applied between consecutive class indices.

    Returns:
        list[int]: Target sample count for every class index.
    """
    # Read labels from `targets` when available: iterating the dataset would
    # load and transform every image only to throw the pixels away.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = [label for _, label in dataset]

    label_array = np.asarray(labels, dtype=np.int64).reshape(-1)
    original_counts = np.bincount(label_array, minlength=num_classes)
    max_samples = int(original_counts.max()) if original_counts.size else 0
    return [int(max_samples * (mu ** i)) for i in range(num_classes)]

# Per-class sample budget for the long-tailed training set: class i is allotted
# int(max_class_size * mu ** i) samples.
class_distribution_train = calculate_class_distribution(train_set, num_classes, mu)

# Apply class distribution to training set
def apply_class_distribution(dataset, target_counts):
    """Keep the first ``target_counts[c]`` samples of every class ``c``.

    Args:
        dataset: Either an object exposing a ``targets`` sequence of integer
            labels, or any iterable of ``(sample, label)`` pairs.
        target_counts: Maximum number of samples to keep, indexed by class.

    Returns:
        torch.utils.data.Subset: View on ``dataset`` restricted to the kept
        indices, in their original order.
    """
    # Read labels from `targets` when available: iterating the dataset would
    # load and transform every image only to look at its label.
    labels = getattr(dataset, "targets", None)
    if labels is None:
        labels = [label for _, label in dataset]

    class_counts = np.zeros(len(target_counts), dtype=int)
    filtered_indices = []

    for idx, label in enumerate(labels):
        label = int(label)
        if class_counts[label] < target_counts[label]:
            filtered_indices.append(idx)
            class_counts[label] += 1

    return torch.utils.data.Subset(dataset, filtered_indices)

# Keep a handle on the full CIFAR-100 training set before `train_set` is rebound.
_cifar_train = train_set
_longtail_set = apply_class_distribution(_cifar_train, class_distribution_train)
_longtail_base_idx = list(_longtail_set.indices)  # positions inside the full training split

# Split data into train and validation sets (80 / 20).
# The split is seeded so that re-running the notebook yields the same partition.
split_seed = 0
train_size = int(0.8 * len(_longtail_set))
val_size = len(_longtail_set) - train_size
_train_part, _val_part = random_split(
    _longtail_set, [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
_train_idx = [_longtail_base_idx[i] for i in _train_part.indices]
_val_idx = [_longtail_base_idx[i] for i in _val_part.indices]

# The training subset keeps the augmenting transform. The validation subset is
# backed by a second view of the same files that uses the deterministic test
# transform; otherwise validation images would be randomly cropped, flipped and
# jittered and validation accuracy would be noisy and not comparable to test.
train_set = torch.utils.data.Subset(_cifar_train, _train_idx)
_cifar_eval = datasets.CIFAR100(root=_cifar_train.root, train=True, download=False, transform=transform_test)
val_set = torch.utils.data.Subset(_cifar_eval, _val_idx)

# Prepare the training and testing datasets
batch_size = 64

train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=2)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=2)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=2)

# Calculate class distributions from the stored labels; indexing the dataset
# would decode and resize every image just to read its label.
_train_targets = np.asarray(_cifar_train.targets)
train_class_counts = np.bincount(_train_targets[_train_idx], minlength=num_classes)

# Categorize classes into many, medium, and few (by number of training samples)
many_classes = [i for i, count in enumerate(train_class_counts) if count > 250]
medium_classes = [i for i, count in enumerate(train_class_counts) if 75 <= count <= 250]
few_classes = [i for i, count in enumerate(train_class_counts) if count < 75]

# Print class categorization
print("Many classes (more than 250 samples):", many_classes)
print("Medium classes (between 75 and 250 samples):", medium_classes)
print("Few classes (less than 75 samples):", few_classes)

# ResNet-32 Model

In [None]:
class BasicBlock(nn.Module):
    """Residual unit made of two 3x3 convolutions, each followed by batch norm.

    The block input is added to the output of the second batch norm before the
    final ReLU. When the stride is not 1 or the channel count changes, the
    input is first projected by a 1x1 convolution plus batch norm so that both
    tensors have the same shape; otherwise it is passed through unchanged.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main branch: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Skip branch: identity unless the output shape differs from the input.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            projection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False)
            self.shortcut = nn.Sequential(projection, nn.BatchNorm2d(out_channels))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Apply the residual unit to ``x`` and return the activated sum."""
        branch = self.conv1(x)
        branch = self.bn1(branch)
        branch = self.relu(branch)
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        branch += self.shortcut(x)
        return self.relu(branch)

class ResNet32(nn.Module):
    """Three-stage residual network with a 7x7 stride-2 stem and max pooling.

    Args:
        block: Callable ``block(in_channels, out_channels, stride=1)`` that
            builds one residual unit.
        num_blocks: Indexable with at least three entries; entry ``k`` is the
            number of units in stage ``k + 1``.
        num_classes: Width of the final linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes=100):
        super().__init__()
        stem_width = 64
        # Tracks the channel count produced by the most recently built stage.
        self.in_channels = stem_width

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages (width, stride of the first unit), registered as layer1..layer3.
        stage_plan = ((64, 1), (128, 2), (256, 2))
        for position, (width, first_stride) in enumerate(stage_plan):
            stage = self._make_layer(block, width, num_blocks[position], stride=first_stride)
            setattr(self, f"layer{position + 1}", stage)

        # Head: global average pooling followed by the linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first unit uses ``stride`` and changes width."""
        head = block(self.in_channels, out_channels, stride)
        self.in_channels = out_channels
        tail = [block(out_channels, out_channels) for _ in range(1, num_blocks)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)

        pooled = self.avgpool(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

def create_resnet32(num_classes=100):
    """Build a ``ResNet32`` with three stages of five ``BasicBlock`` units each.

    Args:
        num_classes: Number of output logits. Defaults to 100, the value that
            was previously fixed by the ``ResNet32`` default, so existing
            zero-argument calls behave as before.

    Returns:
        ResNet32: A freshly initialised model (not yet moved to any device).
    """
    return ResNet32(BasicBlock, [5, 5, 5], num_classes=num_classes)

# Resolve the compute device again in this cell, then build the network and
# move its parameters onto that device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = create_resnet32()
model = model.to(device)

# Training Loss


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class AdditionalTermLayer(nn.Module):
    """Batch-statistics penalty that is added to the cross-entropy loss.

    For every class the layer combines
      * a "semantic scale": the squared mean L2 norm of all logit vectors seen
        so far whose true label is that class,
      * an "entropy" value derived from how often the class is predicted in
        the current batch (see ``compute_entropy``),
      * the number of predictions of the class in the current batch, and
      * a +/-2 reinforcement term depending on whether that number rose or fell
        compared with the previous call,
    and divides the squared combination by the squared distance between the
    batch logits and the class's one-hot vector.

    The layer is stateful: every call to ``forward`` (training or validation)
    updates the running statistics and the history lists.

    Args:
        target_class_index: Container of class indices that receive the
            reinforcement term (tested with ``in``).
        num_classes: Number of classes; must equal the logit width.
        keep_features: When True, a copy of every logit vector is also
            appended to ``feature_storage[label]``. This grows without bound
            and is therefore off by default; the semantic scale itself is
            computed from running sums and never needs the raw vectors.
    """

    def __init__(self, target_class_index, num_classes, keep_features=False):
        super(AdditionalTermLayer, self).__init__()
        self.target_class_index = target_class_index
        self.num_classes = num_classes
        self.keep_features = keep_features
        # Prediction counts of the previous forward call (i.e. previous batch).
        self.previous_epoch_class_predictions = None
        # Raw logit vectors per class; only filled when keep_features is True.
        self.feature_storage = {i: [] for i in range(num_classes)}
        # Running per-class sum of logit norms and number of samples, kept on
        # the CPU in float64 so the cumulative mean costs O(batch) per step.
        self._norm_sums = torch.zeros(num_classes, dtype=torch.float64)
        self._norm_counts = torch.zeros(num_classes, dtype=torch.float64)
        # One entry is appended to each history per forward call.
        self.gamma_values = []
        self.class_predictions_history = []
        self.semantic_scales_history = []
        self.entropies = {i: [] for i in range(num_classes)}  # Track entropy for each class

    def compute_entropy(self, class_predictions, num_samples):
        """
        Compute the entropy-like value of one class from its prediction indicator.

        ``class_predictions`` is a 0/1 vector marking the samples predicted as
        the class; each entry is divided by ``num_samples`` and the non-zero
        entries ``p`` contribute ``-p * log(p + 1e-6)``.
        """
        probabilities = class_predictions.float() / num_samples
        non_zero_probs = probabilities[probabilities > 0]
        entropy = -torch.sum(non_zero_probs * torch.log(non_zero_probs + 1e-6))  # Add small value to avoid log(0)
        return entropy.item()

    def _update_feature_statistics(self, inputs, true_labels):
        """Fold the logit norms of the current batch into the running sums."""
        features = inputs.detach()
        labels = true_labels.detach().reshape(-1).long().to(features.device)
        # Labels outside [0, num_classes) never matched any class before either.
        in_range = (labels >= 0) & (labels < self.num_classes)
        labels = labels[in_range]
        features = features[in_range]

        # Assumes `inputs` is (batch, num_classes) — one logit row per sample.
        norms = torch.linalg.norm(features.float(), dim=1).cpu().double()
        labels_cpu = labels.cpu()
        self._norm_sums.index_add_(0, labels_cpu, norms)
        self._norm_counts.index_add_(0, labels_cpu, torch.ones_like(norms))

        if self.keep_features:
            for label, row in zip(labels_cpu.tolist(), features.cpu().numpy()):
                self.feature_storage[label].append(row)

    def _semantic_scales(self):
        """Return (mean logit norm) ** 2 per class; 0.0 for classes not seen yet."""
        seen = self._norm_counts > 0
        means = self._norm_sums / self._norm_counts.clamp(min=1.0)
        return torch.where(seen, means ** 2, torch.zeros_like(means)).tolist()

    def forward(self, inputs, true_labels, epoch):
        """Return the scalar penalty for one batch.

        Args:
            inputs: Logits; assumed to be of shape (batch, num_classes).
            true_labels: Integer class labels, one per row of ``inputs``.
            epoch: Accepted for interface compatibility; not used.

        Returns:
            0-dim tensor on ``inputs.device``. Gradients reach ``inputs`` only
            through the denominators.
        """
        inputs = torch.nan_to_num(inputs)  # Replace NaNs with zero
        num_classes = self.num_classes

        class_predictions = torch.argmax(inputs, dim=-1)
        # Single device-to-host copy; everything count-based below runs on CPU.
        predictions_cpu = class_predictions.detach().reshape(-1).cpu()

        # Cumulative semantic scale per class (all batches seen so far).
        self._update_feature_statistics(inputs, true_labels)
        semantic_scales = self._semantic_scales()
        self.semantic_scales_history.append(semantic_scales.copy())

        # Calculate class entropies
        class_entropies = []
        num_samples = len(true_labels)
        for i in range(num_classes):
            entropy = self.compute_entropy((predictions_cpu == i).float(), num_samples)
            self.entropies[i].append(entropy)
            class_entropies.append(entropy)

        # gamma_i = scale_i / (max_scale * entropy_i), with epsilons for stability.
        max_semantic_scale = max(semantic_scales) + 1e-6
        dynamic_gammas = [
            scale / (1e-6 + max_semantic_scale * entropy)
            for scale, entropy in zip(semantic_scales, class_entropies)
        ]
        self.gamma_values.append(dynamic_gammas.copy())

        # Number of predictions per class in this batch.
        current_epoch_class_predictions = torch.bincount(
            predictions_cpu, minlength=num_classes
        )[:num_classes].float()
        self.class_predictions_history.append(current_epoch_class_predictions.tolist())

        # Reinforcement: -2 when a target class is predicted more often than in
        # the previous call, +2 when less often, 0 otherwise / on the first call.
        reinforcement = torch.zeros(num_classes)
        previous = self.previous_epoch_class_predictions
        if previous is not None:
            is_target = torch.tensor(
                [i in self.target_class_index for i in range(num_classes)], dtype=torch.bool
            )
            direction = torch.sign(previous - current_epoch_class_predictions)
            reinforcement = torch.where(is_target, 2.0 * direction, torch.zeros_like(direction))

        gammas = torch.tensor(dynamic_gammas, dtype=torch.float32)
        numerators = (gammas * current_epoch_class_predictions + reinforcement) ** 2

        # sum((inputs - onehot_i) ** 2) for all i at once:
        #   sum(inputs ** 2) - 2 * sum_b inputs[b, i] + number_of_rows
        rows = inputs.reshape(-1, inputs.size(-1))
        squared_distance = (rows ** 2).sum() - 2.0 * rows.sum(dim=0) + rows.size(0)
        # Clamp guards against tiny negative values caused by rounding.
        denominators = torch.clamp(squared_distance, min=0.0) + 1e-6  # avoid division by zero

        additional_term = (numerators.to(denominators.device) / denominators).sum()

        # Normalize the additional term
        additional_term = additional_term / num_classes

        self.previous_epoch_class_predictions = current_epoch_class_predictions

        return additional_term

class CustomLossWithL2AndAdditionalTerm(nn.Module):
    """Cross-entropy plus the penalty produced by ``AdditionalTermLayer``.

    No L2 term is computed inside this module despite its name.

    Args:
        target_class_index: Forwarded to ``AdditionalTermLayer``.
        num_classes: Forwarded to ``AdditionalTermLayer``.
    """

    def __init__(self, target_class_index, num_classes):
        super(CustomLossWithL2AndAdditionalTerm, self).__init__()
        self.additional_term_layer = AdditionalTermLayer(target_class_index, num_classes)

    def forward(self, y_true, y_pred, epoch):
        """Return ``cross_entropy(y_pred, y_true) + additional_term``.

        Note the argument order: labels first, logits second.

        Args:
            y_true: Integer class labels.
            y_pred: Logits, one row per label.
            epoch: Passed through to the additional-term layer.

        Raises:
            AssertionError: If any of the loss components is NaN.
        """
        # The one-hot encoding of y_true that used to be built here was never
        # read; it only cost an allocation per step and raised on labels that
        # F.cross_entropy accepts (e.g. its ignore_index), so it was dropped.
        cross_entropy_loss = F.cross_entropy(y_pred, y_true)

        additional_term = self.additional_term_layer(y_pred, y_true, epoch)
        total_loss = cross_entropy_loss + additional_term

        # Debugging statements
        assert not torch.isnan(cross_entropy_loss).any(), "cross_entropy_loss has NaNs"
        assert not torch.isnan(additional_term).any(), "additional_term_layer has NaNs"
        assert not torch.isnan(total_loss).any(), "total_loss has NaNs"

        return total_loss


Training

In [None]:
import torch.optim as optim
import numpy as np

# Hyperparameters
target_class_index = list(range(0, 100))
initial_learning_rate = 0.001
num_epochs = 200
batch_size = 64  # NOTE: not read below; the DataLoaders were built in an earlier cell
num_classes = 100  # Assuming 100 classes

# Optimizer and learning rate scheduler (lr = initial_learning_rate * 0.99 ** epoch)
optimizer = optim.SGD(model.parameters(), lr=initial_learning_rate, momentum=0.9, weight_decay=2e-4)
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)

# Loss function
criterion = CustomLossWithL2AndAdditionalTerm(target_class_index, num_classes).to(device)

# Validation accuracy of every epoch, for plotting later
val_accuracies = []

# Class-wise correct and total counters. They are reset at the start of every
# epoch, so after training they describe the LAST epoch only and
# train_class_total[c] equals the number of training samples of class c.
train_class_correct = np.zeros(num_classes)
train_class_total = np.zeros(num_classes)
val_class_correct = np.zeros(num_classes)
val_class_total = np.zeros(num_classes)

# Define thresholds for many, medium, and few (training samples per class)
many_threshold = 250
medium_min_threshold = 75


def _add_batch_to_class_stats(labels, predicted, class_total, class_correct):
    """Add one batch to the per-class total / correct counters (in place)."""
    labels_np = labels.detach().cpu().numpy()
    hits_np = predicted.eq(labels).detach().cpu().numpy()
    size = len(class_total)
    class_total += np.bincount(labels_np, minlength=size)[:size]
    class_correct += np.bincount(labels_np[hits_np], minlength=size)[:size]


def _group_accuracy(class_correct, class_total, mask):
    """Accuracy in percent over the classes selected by ``mask`` (0 if empty)."""
    group_total = class_total[mask].sum()
    return 100. * class_correct[mask].sum() / group_total if group_total > 0 else 0


# Training loop
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    # Without this reset the counters would add up over all epochs and the
    # sample-count thresholds above would be compared against
    # (samples per class) * (epochs run), putting almost every class in "many".
    for counter in (train_class_correct, train_class_total, val_class_correct, val_class_total):
        counter.fill(0)

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(labels, outputs, epoch)  # Pass the current epoch to the loss function
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update class-wise statistics (one vectorised update per batch)
        _add_batch_to_class_stats(labels, predicted, train_class_total, train_class_correct)

    train_loss /= len(train_loader.dataset)
    train_accuracy = 100. * correct / total

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            # NOTE(review): the criterion is stateful, so calling it on
            # validation batches also updates the statistics used by the
            # following training steps — confirm this is intended.
            loss = criterion(labels, outputs, epoch)  # Pass the current epoch to the loss function

            val_loss += loss.item() * inputs.size(0)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            # Update class-wise statistics for validation
            _add_batch_to_class_stats(labels, predicted, val_class_total, val_class_correct)

    val_loss /= len(val_loader.dataset)
    val_accuracy = 100. * correct / total

    # Store validation accuracy for plotting later
    val_accuracies.append(val_accuracy)

    # Learning rate update
    scheduler.step()

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}%, '
          f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')

# Group the classes by their number of TRAINING samples (last-epoch totals).
many_mask = train_class_total > many_threshold
medium_mask = (train_class_total <= many_threshold) & (train_class_total >= medium_min_threshold)
few_mask = train_class_total < medium_min_threshold

# Aggregates for many, medium, and few classes on the training set
train_many_correct = train_class_correct[many_mask].sum()
train_many_total = train_class_total[many_mask].sum()
train_medium_correct = train_class_correct[medium_mask].sum()
train_medium_total = train_class_total[medium_mask].sum()
train_few_correct = train_class_correct[few_mask].sum()
train_few_total = train_class_total[few_mask].sum()

# Print overall and category-wise accuracy. The labels state which split each
# number comes from: the overall figure is the last epoch's validation
# accuracy, the group figures are reported for both train and validation.
if val_accuracies:
    print("\nOverall Accuracy (validation, last epoch): {:.2f}%".format(val_accuracies[-1]))
for group_name, group_mask in (("Many", many_mask), ("Medium", medium_mask), ("Few", few_mask)):
    print("{} Classes Accuracy (last epoch): train {:.2f}%, val {:.2f}%".format(
        group_name,
        _group_accuracy(train_class_correct, train_class_total, group_mask),
        _group_accuracy(val_class_correct, val_class_total, group_mask),
    ))

# Function to plot validation accuracy
def plot_validation_accuracy(val_accuracies):
    """Plot validation accuracy (in percent) against the 1-based epoch number.

    Args:
        val_accuracies: Sequence with one accuracy value per completed epoch.
    """
    # Derive the x axis from the data itself rather than from the global
    # num_epochs, so a run that stopped early (or a sliced list) still plots.
    epochs = range(1, len(val_accuracies) + 1)

    plt.figure(figsize=(10, 6))
    plt.plot(epochs, val_accuracies, label='Validation Accuracy', color='blue')
    plt.xlabel('Epoch')
    plt.ylabel('Validation Accuracy (%)')
    plt.title('Validation Accuracy Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.show()

# Plot the validation accuracy collected once per epoch by the training loop
plot_validation_accuracy(val_accuracies)
