In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets
from efficientnet_pytorch import EfficientNet
from tqdm import tqdm
import os
import numpy as np
import matplotlib.pyplot as plt
from albumentations import (
    Compose, Resize, RandomResizedCrop, HorizontalFlip, VerticalFlip,
    ColorJitter, Rotate, Affine, Normalize, ToTensorV2
)
from albumentations.pytorch import ToTensorV2

In [None]:
class DocumentClassifier(nn.Module):
    """Pretrained EfficientNet-B0 feature extractor with a small MLP head.

    ``forward`` returns raw, unnormalised logits of shape
    ``(batch, num_classes)``. ``nn.CrossEntropyLoss`` applies ``log_softmax``
    internally, so ending the head with ``nn.Softmax`` would normalise twice
    and damp the gradients. Call ``torch.softmax(logits, dim=1)`` on the output
    when probabilities are needed; ``argmax`` predictions are identical either
    way.

    Dropping the parameter-free softmax layer does not change any
    ``state_dict`` key, so previously saved weights still load.

    Args:
        num_classes: Number of output classes.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        self.backbone = EfficientNet.from_pretrained('efficientnet-b0')
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        # Only the input width of the backbone's own `_fc` layer is read here;
        # the layer itself is bypassed because forward() uses extract_features().
        num_features = self.backbone._fc.in_features
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch accepted by the EfficientNet backbone
                (presumably ``(batch, 3, H, W)`` float tensors -- confirm
                against the data pipeline).

        Returns:
            Tensor of shape ``(batch, num_classes)`` holding unnormalised scores.
        """
        x = self.backbone.extract_features(x)
        x = self.global_pool(x)
        # Collapse the pooled 1x1 spatial dims so the linear head sees (batch, features).
        x = torch.flatten(x, 1)
        return self.classifier(x)

In [None]:
# Albumentations transformations
def get_train_transforms():
    """Build the augmentation pipeline used for training images.

    The pipeline resizes to 224x224, applies random crop/flip/colour/rotation/
    translation augmentations, normalises each channel and finally converts
    the result to a tensor.

    Returns:
        An albumentations ``Compose`` object.
    """
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    # Bring every image to the working resolution before augmenting.
    steps = [
        Resize(224, 224),
        RandomResizedCrop(224, 224),
    ]
    # Random flips along both axes.
    steps += [HorizontalFlip(), VerticalFlip()]
    # Photometric jitter followed by small geometric perturbations.
    steps.append(ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2))
    steps.append(Rotate(limit=30))
    steps.append(Affine(translate_percent=(0.1, 0.1)))
    # Normalisation and tensor conversion always come last.
    steps += [Normalize(mean=channel_mean, std=channel_std), ToTensorV2()]
    return Compose(steps)

In [None]:
def get_val_test_transforms():
    """Build the deterministic preprocessing pipeline for validation/test images.

    No augmentation is applied: images are resized to 224x224, normalised per
    channel and converted to a tensor.

    Returns:
        An albumentations ``Compose`` object.
    """
    resize = Resize(224, 224)
    normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    to_tensor = ToTensorV2()
    return Compose([resize, normalize, to_tensor])

In [None]:
class AlbumentationsTransform:
    """Adapter that lets an albumentations pipeline act as a torchvision ``transform``.

    torchvision datasets call ``transform(img)`` with a single positional
    image, whereas an albumentations ``Compose`` is called with the image as
    the ``image=`` keyword and returns a dict. This wrapper bridges the two.

    NOTE(review): no definition or import of this name is visible anywhere in
    the notebook, so it is defined here to keep Restart & Run All working.

    Args:
        transform: An albumentations pipeline (e.g. a ``Compose``).
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        """Apply the wrapped pipeline to ``img`` and return the transformed image.

        Assumes ``img`` is a PIL image or array-like, as yielded by
        ``ImageFolder``'s default loader -- confirm if a custom loader is used.
        """
        # np.array makes a writable HWC copy of the incoming image.
        return self.transform(image=np.array(img))['image']


# Dataset root. Set the DOCUMENT_DATASET_DIR environment variable to run the
# notebook on another machine; the original location remains the default.
base_dest_dir = os.environ.get('DOCUMENT_DATASET_DIR', '/mnt/c/Users/Rahul/Desktop/Datasets')
train_transform = AlbumentationsTransform(get_train_transforms())
val_test_transform = AlbumentationsTransform(get_val_test_transforms())

In [None]:
# One ImageFolder per split directory under the dataset root; the training
# split gets the augmenting transform, the other two the plain one.
train_dataset, val_dataset, test_dataset = (
    datasets.ImageFolder(root=os.path.join(base_dest_dir, split_name), transform=split_transform)
    for split_name, split_transform in (
        ('train', train_transform),
        ('val', val_test_transform),
        ('test', val_test_transform),
    )
)

In [None]:
# Batch each split identically; only the training split is reshuffled.
train_loader, val_loader, test_loader = (
    DataLoader(dataset=split_dataset, batch_size=32, shuffle=reshuffle, num_workers=4)
    for split_dataset, reshuffle in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

In [None]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model = DocumentClassifier(num_classes=3)
print(device)
# Left as the last expression so the cell displays the model summary.
model.to(device)

In [None]:
# Loss function and optimizer
# Adam updates every model parameter; the objective is multi-class cross-entropy.
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()

In [None]:
# Early stopping
class EarlyStopping:
    """Flag when the validation loss has stopped improving.

    Call the instance once per epoch with that epoch's validation loss. After
    ``patience`` consecutive calls without a strictly lower loss,
    ``early_stop`` becomes ``True``.

    A NaN loss never counts as an improvement. NaN compares ``False`` against
    everything, so a plain ``<`` / ``>=`` pair would ignore it and a diverged
    run would never be stopped.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        verbose: When true, print the counter each time the loss fails to improve.
    """

    def __init__(self, patience=5, verbose=False):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0  # consecutive epochs without improvement
        self.best_loss = None  # lowest loss seen so far; None until the first valid loss
        self.early_stop = False

    def __call__(self, val_loss):
        """Record ``val_loss`` and update ``counter`` / ``early_stop``."""
        # NaN is the only value that is not equal to itself.
        is_nan = val_loss != val_loss
        improved = (not is_nan) and (self.best_loss is None or val_loss < self.best_loss)
        if improved:
            self.best_loss = val_loss
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

early_stopping = EarlyStopping(patience=5, verbose=True)

In [None]:
def train_epoch(loader):
    """Run one optimisation pass over ``loader``.

    Relies on the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``.

    Loss and accuracy are both averaged over the samples actually seen, not
    ``len(loader.dataset)``. The two agree for a plain loader, but the dataset
    length is wrong once a sampler or ``drop_last=True`` is used.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: mean per-sample loss and the
        fraction of correct predictions.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # Assumes the criterion returns the batch mean, so scaling by the batch
        # size turns it back into a per-sample sum.
        running_loss += loss.item() * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [None]:
def validate_epoch(loader):
    """Evaluate the model on ``loader`` without updating weights.

    Relies on the notebook-level ``model``, ``criterion`` and ``device``.

    Loss and accuracy are both averaged over the samples actually seen, not
    ``len(loader.dataset)``. The two agree for a plain loader, but the dataset
    length is wrong once a sampler or ``drop_last=True`` is used.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: mean per-sample loss and the
        fraction of correct predictions.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Assumes the criterion returns the batch mean, so scaling by the
            # batch size turns it back into a per-sample sum.
            running_loss += loss.item() * batch_size
            total += batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = correct / total
    return epoch_loss, epoch_acc

In [None]:
# Test function
def test(loader):
    """Evaluate the notebook-level ``model`` on ``loader`` and report accuracy.

    Prints the overall accuracy and returns two parallel lists with one entry
    per sample: the true labels and the predicted labels.

    Args:
        loader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        Tuple ``(all_labels, all_preds)``.
    """
    model.eval()
    n_seen = 0
    n_hits = 0
    all_labels, all_preds = [], []
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images = images.to(device)
            labels = labels.to(device)
            # Index of the highest score along the class dimension.
            predicted = torch.max(model(images), 1).indices
            n_seen += labels.size(0)
            n_hits += (predicted == labels).sum().item()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(predicted.cpu().numpy())

    accuracy = n_hits / n_seen
    print(f'Test Accuracy: {accuracy:.4f}')
    return all_labels, all_preds

In [None]:
# Train for at most `num_epochs`, stopping early once validation loss stalls.
# Early stopping only fires `patience` epochs after the best one, so the
# weights with the lowest validation loss are snapshotted and restored after
# the loop. Without this, later evaluation and saving would use the
# already-degraded last-epoch weights.
num_epochs = 50
best_val_loss = float('inf')
best_state = None
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    train_loss, train_acc = train_epoch(train_loader)
    val_loss, val_acc = validate_epoch(val_loader)
    print(f'Train Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}')
    print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.4f}')

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() hands back references to the live tensors, so clone them
        # to keep the snapshot from changing as training continues.
        best_state = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    early_stopping(val_loss)
    if early_stopping.early_stop:
        print("Early stopping")
        break

if best_state is not None:
    model.load_state_dict(best_state)
    print(f'Restored best weights (validation loss {best_val_loss:.4f})')

In [None]:
# Evaluate on the held-out split, keeping per-sample truth and predictions.
actual_labels, predicted_labels = test(loader=test_loader)

In [None]:
# Class index -> human-readable document type.
class_mapping = dict(enumerate(('Citizenship', 'License', 'Passport')))

# Show the true and predicted class for every test sample, side by side.
for truth, guess in zip(actual_labels, predicted_labels):
    print(f'Actual: {class_mapping[truth]}, Predicted: {class_mapping[guess]}')

# Persist only the weights (state dict), not the whole module object.
model_save_path = 'document_classifier.pth'
torch.save(model.state_dict(), model_save_path)
print(f'Model saved to {model_save_path}')