In [1]:
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
from torchvision.models import ResNet18_Weights
import torchvision.transforms as transforms
from torchvision.datasets import EMNIST
from torch.utils.data import random_split
from sklearn.metrics import precision_score, f1_score
import torch.nn.functional as F
from archs import *

class ModifiedCrossEntropyLoss(nn.Module):
    def __init__(self, penalty_weight=0.1):
        super(ModifiedCrossEntropyLoss, self).__init__()
        self.penalty_weight = penalty_weight

    def forward(self, inputs, targets):
        # Calculate probabilities using softmax
        probs = F.softmax(inputs, dim=1)  # Get probabilities from raw logits

        # Standard cross-entropy loss for the true class
        loss_ce = torch.log(probs[range(targets.size(0)), targets] + 1e-12).mean()

        # Calculate the penalty for all classes except the true class
        penalty = self.penalty_weight * (torch.sum(torch.log(1 - probs + 1e-12), dim=1) - 
                                          torch.log(1 - probs[range(targets.size(0)), targets] + 1e-12))

        # Final loss
        total_loss = loss_ce + penalty.mean()
        return -total_loss

class ImageClassifier:
    def __init__(self, network, optimizer, criterion, l2_lambda=0.01):
        self.network = network
        self.optimizer = optimizer
        self.criterion = criterion
        self.l2_lambda = l2_lambda
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.network.to(self.device)
    
    def _regularize(self, network, l2_lambda):
        # Compute L2 regularization
        l2_reg = 0.0
        for param in network.parameters():
            l2_reg += torch.norm(param, 2)
        return l2_lambda * l2_reg
            
    def compute_loss(self, outputs, targets, l2_lambda=0.01, regularize = False):
        # Compute the cross-entropy loss
        ce_loss = self.criterion(outputs, targets)
        
        if regularize:
            # Compute regularization loss
            l2_reg = self._regularize(self.network, l2_lambda)
            
            return ce_loss + l2_reg

        return ce_loss
    
    def compute_metrics(self, preds, targets):
        """Helper function to compute accuracy, precision, and F1 score."""
        # Ensure preds are already in label form (if not already converted)
        if preds.dim() > 1:  # Check if preds need reduction
            preds = preds.argmax(dim=1)  # Get the predicted labels
        
        preds = preds.cpu().numpy()  # Convert predictions to NumPy
        targets = targets.cpu().numpy()  # Convert true labels to NumPy

        # Compute accuracy
        accuracy = (preds == targets).mean()

        # Compute precision and F1 score using scikit-learn
        precision = precision_score(targets, preds, average='weighted', zero_division=0)
        f1 = f1_score(targets, preds, average='weighted')

        return accuracy, precision, f1

    def train(self, train_loader, val_loader, n_epochs=10, patience=3):
        best_val_loss = float('inf')
        current_patience = 0
        
        for epoch in range(n_epochs):
            # Train
            self.network.train()
            train_loss = 0.0
            all_preds = []
            all_targets = []
            
            # Use tqdm for progress bar and set dynamic description
            train_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'Training Epoch {epoch + 1}')
            for batch_idx, (data, target) in train_bar:
                data, target = data.to(self.device), target.to(self.device)
                self.optimizer.zero_grad()
                
                # Forward pass
                outputs = self.network(data)
                
                # Compute loss
                loss = self.compute_loss(outputs, target)
                loss.backward()
                self.optimizer.step()
                
                train_loss += loss.item()

                # Gather predictions and true labels for accuracy/metrics calculation
                preds = outputs.argmax(dim=1)
                all_preds.append(preds)
                all_targets.append(target)
                
                # Update progress bar with loss and accuracy
                current_accuracy, _, _ = self.compute_metrics(torch.cat(all_preds), torch.cat(all_targets))
                train_bar.set_postfix(loss=train_loss / (batch_idx + 1), accuracy=current_accuracy)

            # Calculate final metrics for training
            all_preds = torch.cat(all_preds)
            all_targets = torch.cat(all_targets)
            train_accuracy, train_precision, train_f1 = self.compute_metrics(all_preds, all_targets)
            
            # Validate
            self.network.eval()
            val_loss = 0.0
            val_preds = []
            val_targets = []
            
            # Use tqdm for validation progress bar
            val_bar = tqdm(val_loader, desc='Validating')
            with torch.no_grad():
                for data, target in val_bar:
                    data, target = data.to(self.device), target.to(self.device)
                    
                    # Forward pass
                    outputs = self.network(data)
                    
                    # Compute loss
                    loss = self.compute_loss(outputs, target)
                    val_loss += loss.item()
                    
                    # Gather predictions and true labels
                    preds = outputs.argmax(dim=1)
                    val_preds.append(preds)
                    val_targets.append(target)

                    # Update progress bar with validation loss and accuracy
                    val_accuracy, _, _ = self.compute_metrics(torch.cat(val_preds), torch.cat(val_targets))
                    val_bar.set_postfix(val_loss=val_loss / len(val_loader), accuracy=val_accuracy)

            # Calculate final validation metrics
            val_preds = torch.cat(val_preds)
            val_targets = torch.cat(val_targets)
            val_accuracy, val_precision, val_f1 = self.compute_metrics(val_preds, val_targets)

            # Print epoch statistics
            train_loss /= len(train_loader)
            val_loss /= len(val_loader)
            print(f'Epoch {epoch + 1}/{n_epochs}, '
                  f'Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, '
                  f'Train Acc: {train_accuracy:.4f}, Val Acc: {val_accuracy:.4f}, '
                  f'Train Prec: {train_precision:.4f}, Val Prec: {val_precision:.4f}, '
                  f'Train F1: {train_f1:.4f}, Val F1: {val_f1:.4f}')
            
            # Check for early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                current_patience = 0
            else:
                current_patience += 1
                if current_patience >= patience:
                    print(f'Validation loss did not improve for {patience} epochs. Stopping training.')
                    break
    
    def test(self, test_loader):
        self.network.eval()
        test_loss = 0.0
        correct = 0
        all_preds = []
        all_targets = []
        
        # Use tqdm for test progress bar
        test_bar = tqdm(test_loader, desc='Testing')
        with torch.no_grad():
            for data, target in test_bar:
                data, target = data.to(self.device), target.to(self.device)
                
                # Forward pass
                outputs = self.network(data)
                
                # Compute loss
                loss = self.compute_loss(outputs, target)
                test_loss += loss.item()
                
                # Gather predictions and true labels for accuracy/metrics calculation
                preds = outputs.argmax(dim=1)
                all_preds.append(preds)
                all_targets.append(target)
                
                # Update progress bar with test loss and accuracy
                accuracy, _, _ = self.compute_metrics(torch.cat(all_preds), torch.cat(all_targets))
                test_bar.set_postfix(loss=test_loss / len(test_loader), accuracy=accuracy)

        # Calculate final test metrics
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)
        accuracy, precision, f1 = self.compute_metrics(all_preds, all_targets)

        test_loss /= len(test_loader)
        print(f'Test Loss: {test_loss:.4f}, Accuracy: {accuracy:.2f}%, Precision: {precision:.2f}, F1 Score: {f1:.2f}')
        
# Define transformation for the images
transform = transforms.Compose([
    transforms.ToTensor(),            # Convert to tensor (1 channel)
    transforms.Normalize((0.5), (0.5))  # Normalize for RGB
])

# Download the EMNIST ByClass dataset
emnist_dataset = EMNIST(root='data', split='byclass', train=True, download=True, transform=transform)
test_dataset = EMNIST(root='data', split='byclass', train=False, download=True, transform=transform)

# Define the sizes for the training and validation sets
train_size = int(0.85 * len(emnist_dataset))  # 80% for training
val_size = len(emnist_dataset) - train_size   # remaining 20% for validation

# Split the dataset into training and validation sets
train_dataset, val_dataset = random_split(emnist_dataset, [train_size, val_size])

print(f'Training set size: {len(train_dataset)}')
print(f'Validation set size: {len(val_dataset)}')
print(f'Test set size: {len(test_dataset)}')

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1024)
test_loader = DataLoader(test_dataset, batch_size=1024)

def Arch3_heavy(num_classes):
    return DenseNetMod(num_classes=num_classes, growth_rate=24, block_layers=[6, 6])

# Initialize the neural network, optimizer, and criterion
model = Arch4(num_classes=62)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = ModifiedCrossEntropyLoss(penalty_weight=0.1)

# Create an instance of ImageClassifier
classifier = ImageClassifier(model, optimizer, criterion)

# Train the classifier
classifier.train(train_loader, val_loader, n_epochs=10, patience=3)

# Test the classifier
classifier.test(test_loader)

torch.save(model.state_dict(), 'models/model4.pth')

Training set size: 593242
Validation set size: 104690
Test set size: 116323


Training Epoch 1: 100%|██████████| 580/580 [03:26<00:00,  2.81it/s, accuracy=0.803, loss=0.681]
Validating: 100%|██████████| 103/103 [00:17<00:00,  5.74it/s, accuracy=0.85, val_loss=0.448] 


Epoch 1/10, Train Loss: 0.6808, Val Loss: 0.4479, Train Acc: 0.8033, Val Acc: 0.8498, Train Prec: 0.7846, Val Prec: 0.8369, Train F1: 0.7848, Val F1: 0.8362


Training Epoch 2: 100%|██████████| 580/580 [03:18<00:00,  2.92it/s, accuracy=0.861, loss=0.413]
Validating: 100%|██████████| 103/103 [00:18<00:00,  5.47it/s, accuracy=0.86, val_loss=0.414] 


Epoch 2/10, Train Loss: 0.4127, Val Loss: 0.4142, Train Acc: 0.8608, Val Acc: 0.8605, Train Prec: 0.8485, Val Prec: 0.8459, Train F1: 0.8461, Val F1: 0.8460


Training Epoch 3: 100%|██████████| 580/580 [03:18<00:00,  2.92it/s, accuracy=0.867, loss=0.389]
Validating: 100%|██████████| 103/103 [00:17<00:00,  5.76it/s, accuracy=0.865, val_loss=0.4]  


Epoch 3/10, Train Loss: 0.3894, Val Loss: 0.3995, Train Acc: 0.8668, Val Acc: 0.8654, Train Prec: 0.8559, Val Prec: 0.8546, Train F1: 0.8531, Val F1: 0.8484


Training Epoch 4: 100%|██████████| 580/580 [03:04<00:00,  3.15it/s, accuracy=0.87, loss=0.376] 
Validating: 100%|██████████| 103/103 [00:18<00:00,  5.46it/s, accuracy=0.863, val_loss=0.401]


Epoch 4/10, Train Loss: 0.3760, Val Loss: 0.4009, Train Acc: 0.8703, Val Acc: 0.8629, Train Prec: 0.8608, Val Prec: 0.8554, Train F1: 0.8572, Val F1: 0.8470


Training Epoch 5: 100%|██████████| 580/580 [03:13<00:00,  3.00it/s, accuracy=0.873, loss=0.366]
Validating: 100%|██████████| 103/103 [00:18<00:00,  5.72it/s, accuracy=0.864, val_loss=0.395]


Epoch 5/10, Train Loss: 0.3661, Val Loss: 0.3951, Train Acc: 0.8728, Val Acc: 0.8641, Train Prec: 0.8625, Val Prec: 0.8566, Train F1: 0.8602, Val F1: 0.8517


Training Epoch 6: 100%|██████████| 580/580 [03:21<00:00,  2.88it/s, accuracy=0.874, loss=0.359]
Validating: 100%|██████████| 103/103 [00:20<00:00,  4.98it/s, accuracy=0.865, val_loss=0.393]


Epoch 6/10, Train Loss: 0.3592, Val Loss: 0.3932, Train Acc: 0.8745, Val Acc: 0.8653, Train Prec: 0.8658, Val Prec: 0.8584, Train F1: 0.8624, Val F1: 0.8519


Training Epoch 7: 100%|██████████| 580/580 [03:15<00:00,  2.97it/s, accuracy=0.876, loss=0.353]
Validating: 100%|██████████| 103/103 [00:18<00:00,  5.63it/s, accuracy=0.866, val_loss=0.391]


Epoch 7/10, Train Loss: 0.3532, Val Loss: 0.3906, Train Acc: 0.8760, Val Acc: 0.8661, Train Prec: 0.8673, Val Prec: 0.8603, Train F1: 0.8643, Val F1: 0.8518


Training Epoch 8: 100%|██████████| 580/580 [03:17<00:00,  2.94it/s, accuracy=0.877, loss=0.347]
Validating: 100%|██████████| 103/103 [00:19<00:00,  5.41it/s, accuracy=0.867, val_loss=0.39] 


Epoch 8/10, Train Loss: 0.3475, Val Loss: 0.3901, Train Acc: 0.8775, Val Acc: 0.8669, Train Prec: 0.8693, Val Prec: 0.8593, Train F1: 0.8660, Val F1: 0.8547


Training Epoch 9: 100%|██████████| 580/580 [03:16<00:00,  2.96it/s, accuracy=0.879, loss=0.343]
Validating: 100%|██████████| 103/103 [00:18<00:00,  5.70it/s, accuracy=0.869, val_loss=0.388]


Epoch 9/10, Train Loss: 0.3432, Val Loss: 0.3875, Train Acc: 0.8791, Val Acc: 0.8689, Train Prec: 0.8713, Val Prec: 0.8637, Train F1: 0.8679, Val F1: 0.8564


Training Epoch 10: 100%|██████████| 580/580 [03:17<00:00,  2.94it/s, accuracy=0.88, loss=0.339] 
Validating: 100%|██████████| 103/103 [00:18<00:00,  5.60it/s, accuracy=0.867, val_loss=0.388]


Epoch 10/10, Train Loss: 0.3387, Val Loss: 0.3880, Train Acc: 0.8799, Val Acc: 0.8674, Train Prec: 0.8724, Val Prec: 0.8622, Train F1: 0.8690, Val F1: 0.8553


Testing: 100%|██████████| 114/114 [00:19<00:00,  5.75it/s, accuracy=0.869, loss=0.382]


Test Loss: 0.3825, Accuracy: 0.87%, Precision: 0.86, F1 Score: 0.86
