In [2]:
import torch
import numpy as np
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.models import ResNet18_Weights
import torchvision.transforms as transforms
from torchvision.datasets import EMNIST
from torch.utils.data import random_split
from sklearn.metrics import precision_score, f1_score
import torch.nn.functional as F

class ModifiedCrossEntropyLoss(nn.Module):
    def __init__(self, penalty_weight=0.1):
        super(ModifiedCrossEntropyLoss, self).__init__()
        self.penalty_weight = penalty_weight

    def forward(self, inputs, targets):
        # Calculate probabilities using softmax
        probs = F.softmax(inputs, dim=1)  # Get probabilities from raw logits

        # Standard cross-entropy loss for the true class
        loss_ce = torch.log(probs[range(targets.size(0)), targets] + 1e-12).mean()

        # Calculate the penalty for all classes except the true class
        penalty = self.penalty_weight * (torch.sum(torch.log(1 - probs + 1e-12), dim=1) - 
                                          torch.log(1 - probs[range(targets.size(0)), targets] + 1e-12))

        # Final loss
        total_loss = loss_ce + penalty.mean()
        return -total_loss

def conv_orth_dist(kernel, stride=1):
    [o_c, i_c, w, h] = kernel.shape
    assert (w == h), "Do not support rectangular kernel"

    # Check if both stride and kernel size are 1, return zero if true
    if stride == 1 and w == 1:
        return torch.tensor(0.0).cuda()

    assert stride < w, f"Warning: Stride {stride} is larger than or equal to kernel size {w}."

    new_s = stride * (w - 1) + w
    temp = torch.eye(new_s * new_s * i_c).reshape((new_s * new_s * i_c, i_c, new_s, new_s)).cuda()
    out = (F.conv2d(temp, kernel, stride=stride)).reshape((new_s * new_s * i_c, -1))

    Vmat = out[np.floor(new_s**2 / 2).astype(int)::new_s**2, :]
    temp = np.zeros((i_c, i_c * new_s**2))
    for i in range(temp.shape[0]):
        temp[i, np.floor(new_s**2 / 2).astype(int) + new_s**2 * i] = 1

    return torch.norm(Vmat @ torch.t(out) - torch.from_numpy(temp).float().cuda())

def deconv_orth_dist(kernel, stride=2, padding=1):
    [o_c, i_c, w, h] = kernel.shape
    output = F.conv_transpose2d(kernel, kernel, stride=stride, padding=padding)
    target = torch.zeros((o_c, o_c, output.shape[-2], output.shape[-1])).cuda()
    ct = int(np.floor(output.shape[-1] / 2))
    target[:, :, ct, ct] = torch.eye(o_c).cuda()
    return torch.norm(output - target)

def orthogonal_regularizer(model):
    orthogonality_loss = 0.0
    for layer in model.modules():
        if isinstance(layer, nn.Conv2d):
            orthogonality_loss += conv_orth_dist(layer.weight)
        elif isinstance(layer, nn.ConvTranspose2d):
            orthogonality_loss += deconv_orth_dist(layer.weight)
    
    return orthogonality_loss

class ImageClassifier:
    def __init__(self, network, optimizer, criterion, l2_lambda=0.01, regularize=False):
        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.l2_lambda = l2_lambda
        self.regularize = regularize
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.network.to(self.device)
    
    def _regularize(self, network, reg_lambda):
        # Compute L2 regularization
        l2_reg = 0.0
        for param in network.parameters():
            l2_reg += torch.norm(param, 2)
        
        return reg_lambda * l2_reg
            
    def orthogonalize(self, network, reg_lambda):
        orthogonality_loss = 0.0
        for layer in network.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                # Apply regularization only to kernels with requires_grad=True
                if layer.weight.requires_grad:
                    if isinstance(layer, nn.Conv2d):
                        orthogonality_loss += conv_orth_dist(layer.weight)
                    elif isinstance(layer, nn.ConvTranspose2d):
                        orthogonality_loss += deconv_orth_dist(layer.weight)
        return reg_lambda * orthogonality_loss

    def compute_loss(self, outputs, targets, reg_lambda=0.01):
        # Compute the cross-entropy loss
        ce_loss = self.criterion(outputs, targets)
        
        # Compute regularization loss
        if self.regularize:
            # reg = self._regularize(self.network, reg_lambda)
            reg = self.orthogonalize(self.network, reg_lambda)
            return ce_loss + reg
            
        return ce_loss
    
    def compute_metrics(self, preds, targets):
        """Helper function to compute accuracy, precision, and F1 score."""
        # Ensure preds are already in label form (if not already converted)
        if preds.dim() > 1:  # Check if preds need reduction
            preds = preds.argmax(dim=1)  # Get the predicted labels
        
        preds = preds.cpu().numpy()  # Convert predictions to NumPy
        targets = targets.cpu().numpy()  # Convert true labels to NumPy

        # Compute accuracy
        accuracy = (preds == targets).mean()

        # Compute precision and F1 score using scikit-learn
        precision = precision_score(targets, preds, average='weighted', zero_division=0)
        f1 = f1_score(targets, preds, average='weighted')

        return accuracy, precision, f1

    def train(self, train_loader, val_loader, n_epochs=10, patience=3, reg_lambda=0.01):
        best_val_loss = float('inf')
        current_patience = 0
        
        for epoch in range(n_epochs):
            # Train
            self.network.train()
            train_loss = 0.0
            all_preds = []
            all_targets = []
            
            # Use tqdm for progress bar and set dynamic description
            train_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Training Epoch {epoch + 1}')
            for batch_idx, (data, target) in train_bar:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                
                # Forward pass
                outputs = self.network(data)
                
                # Compute loss
                loss = self.compute_loss(outputs, target, reg_lambda)
                loss.backward()
                self.optimizer.step()
                
                train_loss += loss.item()

                # Gather predictions and true labels for accuracy/metrics calculation
                preds = outputs.argmax(dim=1)
                all_preds.append(preds)
                all_targets.append(target)
                
                # Update progress bar with loss and accuracy
                current_accuracy, _, _ = self.compute_metrics(torch.cat(all_preds), torch.cat(all_targets))
                train_bar.set_postfix(loss=train_loss / (batch_idx + 1), accuracy=current_accuracy)

            # Calculate final metrics for training
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)
            train_accuracy, train_precision, train_f1 = self.compute_metrics(all_preds, all_targets)
            
            # Validate
            self.network.eval()
            val_loss = 0.0
            val_preds = []
            val_targets = []
            
            # Use tqdm for validation progress bar
            val_bar = tqdm(val_loader, desc='Validating')
            with torch.no_grad():
                for data, target in val_bar:
                    data, target = data.to(self.device), target.to(self.device)
                    
                    # Forward pass
                    outputs = self.network(data)
                    
                    # Compute loss
                    loss = self.compute_loss(outputs, target, reg_lambda)
                    val_loss += loss.item()
                    
                    # Gather predictions and true labels
                    preds = outputs.argmax(dim=1)
                    val_preds.append(preds)
                    val_targets.append(target)

                    # Update progress bar with validation loss and accuracy
                    val_accuracy, _, _ = self.compute_metrics(torch.cat(val_preds), torch.cat(val_targets))
                    val_bar.set_postfix(val_loss=val_loss / len(val_loader), accuracy=val_accuracy)

            # Calculate final validation metrics
            val_preds = torch.cat(val_preds)
            val_targets = torch.cat(val_targets)
            val_accuracy, val_precision, val_f1 = self.compute_metrics(val_preds, val_targets)

            # Print epoch statistics
            train_loss /= len(train_loader)
            val_loss /= len(val_loader)
            print(f'Epoch {epoch + 1}/{n_epochs}, '
                  f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
                  f'Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}, '
                  f'Train Prec: {train_precision:.4f}, Val Prec: {val_precision:.4f}, '
                  f'Train F1: {train_f1:.4f}, Val F1: {val_f1:.4f}')
            
            # Check for early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                current_patience = 0
            else:
                current_patience += 1
                if current_patience >= patience:
                    print(f'Validation loss did not improve for {patience} epochs. Stopping training.')
                    break
    
    def test(self, test_loader, reg_lambda=0.01):
        self.network.eval()
        test_loss = 0.0
        correct = 0
        all_preds = []
        all_targets = []
        
        # Use tqdm for test progress bar
        test_bar = tqdm(test_loader, desc='Testing')
        with torch.no_grad():
            for data, target in test_bar:
                data, target = data.to(self.device), target.to(self.device)
                
                # Forward pass
                outputs = self.network(data)
                
                # Compute loss
                loss = self.compute_loss(outputs, target, reg_lambda)
                test_loss += loss.item()
                
                # Gather predictions and true labels for accuracy/metrics calculation
                preds = outputs.argmax(dim=1)
                all_preds.append(preds)
                all_targets.append(target)
                
                # Update progress bar with test loss and accuracy
                accuracy, _, _ = self.compute_metrics(torch.cat(all_preds), torch.cat(all_targets))
                test_bar.set_postfix(loss=test_loss / len(test_loader), accuracy=accuracy)

        # Calculate final test metrics
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        accuracy, precision, f1 = self.compute_metrics(all_preds, all_targets)

        test_loss /= len(test_loader)
        print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.2f}, F1 Score: {f1:.2f}')
        
# Define transformation for the images
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224
    transforms.ToTensor(),            # Convert to tensor (1 channel)
    transforms.Lambda(lambda x: x.repeat(3, 1, 1)),  # Convert 1 channel to 3 channels (RGB)
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for RGB
])

# Download the EMNIST ByClass dataset
emnist_dataset = EMNIST(root='data', split='byclass', train=True, download=True, transform=transform)
test_dataset = EMNIST(root='data', split='byclass', train=False, download=True, transform=transform)

# Define the sizes for the training and validation sets
train_size = int(0.85 * len(emnist_dataset))  # 80% for training
val_size = len(emnist_dataset) - train_size   # remaining 20% for validation

# Split the dataset into training and validation sets
train_dataset, val_dataset = random_split(emnist_dataset, [train_size, val_size])

print(f'Training set size: {len(train_dataset)}')
print(f'Validation set size: {len(val_dataset)}')
print(f'Test set size: {len(test_dataset)}')

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=512)
test_loader = DataLoader(test_dataset, batch_size=512)

# Example neural network architecture using ResNet-18
class ResNet18Classifier(nn.Module):
    def __init__(self, num_classes=62):
        super(ResNet18Classifier, self).__init__()
        self.resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        
        for name, child in self.resnet.named_children():
            if name in ['layer1', 'layer2', 'layer3']:
                for param in child.parameters():
                    param.requires_grad = False
            
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        return self.resnet(x)

# Initialize the neural network, optimizer, and criterion
model = ResNet18Classifier(num_classes=62)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = ModifiedCrossEntropyLoss(penalty_weight=0.1)

# Create an instance of ImageClassifier
classifier = ImageClassifier(model, optimizer, criterion, regularize=True)

# Train the classifier
classifier.train(train_loader, val_loader, n_epochs=10, patience=3, reg_lambda=0.01)

# Test the classifier
classifier.test(test_loader, reg_lambda=0.01)

Training set size: 593242
Validation set size: 104690
Test set size: 116323


Training Epoch 1: 100%|██████████| 1159/1159 [27:19<00:00,  1.41s/it, accuracy=0.846, loss=1.31]
Validating: 100%|██████████| 205/205 [03:35<00:00,  1.05s/it, accuracy=0.822, val_loss=0.83] 


Epoch 1/10, Train Loss: 1.3111, Val Loss: 0.8299, Train Acc: 0.8460, Val Acc: 0.8217, Train Prec: 0.8284, Val Prec: 0.8288, Train F1: 0.8309, Val F1: 0.7865


Training Epoch 2: 100%|██████████| 1159/1159 [28:58<00:00,  1.50s/it, accuracy=0.859, loss=0.565]
Validating: 100%|██████████| 205/205 [09:15<00:00,  2.71s/it, accuracy=0.739, val_loss=1.08] 


Epoch 2/10, Train Loss: 0.5655, Val Loss: 1.0759, Train Acc: 0.8591, Val Acc: 0.7388, Train Prec: 0.8466, Val Prec: 0.7827, Train F1: 0.8445, Val F1: 0.7190


Training Epoch 3: 100%|██████████| 1159/1159 [1:04:58<00:00,  3.36s/it, accuracy=0.863, loss=0.492]
Validating: 100%|██████████| 205/205 [07:49<00:00,  2.29s/it, accuracy=0.838, val_loss=0.561]


Epoch 3/10, Train Loss: 0.4921, Val Loss: 0.5609, Train Acc: 0.8635, Val Acc: 0.8384, Train Prec: 0.8526, Val Prec: 0.8503, Train F1: 0.8497, Val F1: 0.8050


Training Epoch 4: 100%|██████████| 1159/1159 [1:02:36<00:00,  3.24s/it, accuracy=0.866, loss=0.462]
Validating: 100%|██████████| 205/205 [10:47<00:00,  3.16s/it, accuracy=0.786, val_loss=0.735]


Epoch 4/10, Train Loss: 0.4618, Val Loss: 0.7350, Train Acc: 0.8663, Val Acc: 0.7860, Train Prec: 0.8557, Val Prec: 0.8074, Train F1: 0.8532, Val F1: 0.7745


Training Epoch 5:   0%|          | 0/1159 [00:03<?, ?it/s]


KeyboardInterrupt: 

In [3]:
classifier.test(test_loader)

# Save the model after training
torch.save(classifier.network.state_dict(), 'resnet18_classifier_setting4.pth')

Testing: 100%|██████████| 228/228 [11:03<00:00,  2.91s/it, accuracy=0.793, loss=0.705]

Test Loss: 0.7055, Accuracy: 0.79%, Precision: 0.82, F1 Score: 0.78



