In [66]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from albumentations import Compose, Resize, RandomResizedCrop, HorizontalFlip, VerticalFlip, ColorJitter, Rotate, Affine, Normalize
from albumentations.pytorch import ToTensorV2
from PIL import Image
from efficientnet_pytorch import EfficientNet
import os
import shutil
import random
from pathlib import Path
from tqdm import tqdm


In [67]:
class DocumentClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super(DocumentClassifier, self).__init__()
        
        self.backbone = EfficientNet.from_pretrained('efficientnet-b0')
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        num_features = self.backbone._fc.in_features
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1)
        )
    
    def forward(self, x):
        x = self.backbone.extract_features(x)
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [68]:
def get_train_transforms():
    return Compose([
        Resize(224, 224),
        RandomResizedCrop(224, 224),
        HorizontalFlip(),
        VerticalFlip(),
        ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
        Rotate(limit=30),
        Affine(translate_percent=(0.1, 0.1)),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_val_test_transforms():
    return Compose([
        Resize(224, 224),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

class AlbumentationsTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        img = np.array(img)
        img = self.transform(image=img)['image']
        return img

In [69]:
train_transform = AlbumentationsTransform(get_train_transforms())
val_test_transform = AlbumentationsTransform(get_val_test_transforms())

In [70]:
train_dataset = datasets.ImageFolder(root=os.path.join(base_dest_dir, 'train'), transform=train_transform)
val_dataset = datasets.ImageFolder(root=os.path.join(base_dest_dir, 'val'), transform=val_test_transform)
test_dataset = datasets.ImageFolder(root=os.path.join(base_dest_dir, 'test'), transform=val_test_transform)

In [71]:
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(dataset=test_dataset, batch_size=32, shuffle=False, num_workers=4)

In [72]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DocumentClassifier(num_classes=3)
model.to(device)

Loaded pretrained weights for efficientnet-b0


DocumentClassifier(
  (backbone): EfficientNet(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSam

In [73]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [74]:
def train_epoch(loader):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in tqdm(loader):
        images, labels = images.to(device), labels.to(device)
        
        # Forward pass
        outputs = model(images)
        
        # Compute loss
        loss = criterion(outputs, labels)
        
        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = correct / total
    print(f'Train Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

In [75]:
def validate_epoch(loader):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            outputs = model(images)
            
            # Compute loss
            loss = criterion(outputs, labels)
            
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    epoch_loss = running_loss / len(loader.dataset)
    epoch_acc = correct / total
    print(f'Validation Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

In [80]:
def test(loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    
    accuracy = correct / total
    print(f'Test Accuracy: {accuracy:.4f}')

In [77]:
num_epochs = 50
for epoch in range(num_epochs):
    print(f'Epoch {epoch+1}/{num_epochs}')
    train_epoch(train_loader)
    validate_epoch(val_loader)

test(test_loader)

Epoch 1/50


100%|██████████| 120/120 [00:23<00:00,  5.04it/s]


Train Loss: 0.6758, Accuracy: 0.9444


100%|██████████| 35/35 [00:04<00:00,  7.99it/s]


Validation Loss: 0.5741, Accuracy: 0.9780
Epoch 2/50


100%|██████████| 120/120 [00:21<00:00,  5.61it/s]


Train Loss: 0.5677, Accuracy: 0.9864


100%|██████████| 35/35 [00:04<00:00,  7.64it/s]


Validation Loss: 0.5633, Accuracy: 0.9881
Epoch 3/50


100%|██████████| 120/120 [00:21<00:00,  5.61it/s]


Train Loss: 0.5645, Accuracy: 0.9885


100%|██████████| 35/35 [00:04<00:00,  7.88it/s]


Validation Loss: 0.5584, Accuracy: 0.9927
Epoch 4/50


100%|██████████| 120/120 [00:21<00:00,  5.58it/s]


Train Loss: 0.5595, Accuracy: 0.9929


100%|██████████| 35/35 [00:04<00:00,  7.80it/s]


Validation Loss: 0.5560, Accuracy: 0.9963
Epoch 5/50


100%|██████████| 120/120 [00:21<00:00,  5.59it/s]


Train Loss: 0.5567, Accuracy: 0.9955


100%|██████████| 35/35 [00:04<00:00,  7.86it/s]


Validation Loss: 0.5539, Accuracy: 0.9982
Epoch 6/50


100%|██████████| 120/120 [00:22<00:00,  5.36it/s]


Train Loss: 0.5552, Accuracy: 0.9974


100%|██████████| 35/35 [00:04<00:00,  7.68it/s]


Validation Loss: 0.5541, Accuracy: 0.9982
Epoch 7/50


100%|██████████| 120/120 [00:22<00:00,  5.44it/s]


Train Loss: 0.5554, Accuracy: 0.9969


100%|██████████| 35/35 [00:04<00:00,  7.97it/s]


Validation Loss: 0.5568, Accuracy: 0.9945
Epoch 8/50


100%|██████████| 120/120 [00:21<00:00,  5.50it/s]


Train Loss: 0.5543, Accuracy: 0.9974


100%|██████████| 35/35 [00:04<00:00,  7.14it/s]


Validation Loss: 0.5532, Accuracy: 0.9991
Epoch 9/50


100%|██████████| 120/120 [00:21<00:00,  5.46it/s]


Train Loss: 0.5545, Accuracy: 0.9969


100%|██████████| 35/35 [00:04<00:00,  7.03it/s]


Validation Loss: 0.5592, Accuracy: 0.9945
Epoch 10/50


100%|██████████| 120/120 [00:22<00:00,  5.28it/s]


Train Loss: 0.5531, Accuracy: 0.9987


100%|██████████| 35/35 [00:04<00:00,  7.19it/s]


Validation Loss: 0.5537, Accuracy: 0.9982
Epoch 11/50


100%|██████████| 120/120 [00:22<00:00,  5.35it/s]


Train Loss: 0.5535, Accuracy: 0.9979


100%|██████████| 35/35 [00:04<00:00,  7.64it/s]


Validation Loss: 0.5529, Accuracy: 0.9991
Epoch 12/50


100%|██████████| 120/120 [00:22<00:00,  5.36it/s]


Train Loss: 0.5551, Accuracy: 0.9966


100%|██████████| 35/35 [00:04<00:00,  7.04it/s]


Validation Loss: 0.5543, Accuracy: 0.9963
Epoch 13/50


100%|██████████| 120/120 [00:22<00:00,  5.37it/s]


Train Loss: 0.5531, Accuracy: 0.9987


100%|██████████| 35/35 [00:04<00:00,  7.45it/s]


Validation Loss: 0.5542, Accuracy: 0.9972
Epoch 14/50


100%|██████████| 120/120 [00:21<00:00,  5.47it/s]


Train Loss: 0.5571, Accuracy: 0.9945


100%|██████████| 35/35 [00:04<00:00,  7.06it/s]


Validation Loss: 0.5517, Accuracy: 1.0000
Epoch 15/50


100%|██████████| 120/120 [00:23<00:00,  5.01it/s]


Train Loss: 0.5551, Accuracy: 0.9963


100%|██████████| 35/35 [00:04<00:00,  7.56it/s]


Validation Loss: 0.5530, Accuracy: 0.9982
Epoch 16/50


100%|██████████| 120/120 [00:21<00:00,  5.48it/s]


Train Loss: 0.5544, Accuracy: 0.9971


100%|██████████| 35/35 [00:04<00:00,  7.05it/s]


Validation Loss: 0.5539, Accuracy: 0.9972
Epoch 17/50


100%|██████████| 120/120 [00:22<00:00,  5.40it/s]


Train Loss: 0.5541, Accuracy: 0.9976


100%|██████████| 35/35 [00:04<00:00,  7.12it/s]


Validation Loss: 0.5540, Accuracy: 0.9972
Epoch 18/50


100%|██████████| 120/120 [00:22<00:00,  5.40it/s]


Train Loss: 0.5545, Accuracy: 0.9966


100%|██████████| 35/35 [00:04<00:00,  7.09it/s]


Validation Loss: 0.5538, Accuracy: 0.9972
Epoch 19/50


100%|██████████| 120/120 [00:21<00:00,  5.60it/s]


Train Loss: 0.5532, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  7.68it/s]


Validation Loss: 0.5525, Accuracy: 0.9991
Epoch 20/50


100%|██████████| 120/120 [00:21<00:00,  5.65it/s]


Train Loss: 0.5532, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  8.01it/s]


Validation Loss: 0.5539, Accuracy: 0.9982
Epoch 21/50


100%|██████████| 120/120 [00:21<00:00,  5.59it/s]


Train Loss: 0.5540, Accuracy: 0.9974


100%|██████████| 35/35 [00:04<00:00,  7.79it/s]


Validation Loss: 0.5525, Accuracy: 0.9991
Epoch 22/50


100%|██████████| 120/120 [00:21<00:00,  5.63it/s]


Train Loss: 0.5521, Accuracy: 0.9997


100%|██████████| 35/35 [00:04<00:00,  8.06it/s]


Validation Loss: 0.5533, Accuracy: 0.9982
Epoch 23/50


100%|██████████| 120/120 [00:21<00:00,  5.63it/s]


Train Loss: 0.5531, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  7.74it/s]


Validation Loss: 0.5520, Accuracy: 1.0000
Epoch 24/50


100%|██████████| 120/120 [00:21<00:00,  5.53it/s]


Train Loss: 0.5528, Accuracy: 0.9990


100%|██████████| 35/35 [00:04<00:00,  7.81it/s]


Validation Loss: 0.5532, Accuracy: 0.9972
Epoch 25/50


100%|██████████| 120/120 [00:21<00:00,  5.52it/s]


Train Loss: 0.5525, Accuracy: 0.9990


100%|██████████| 35/35 [00:04<00:00,  7.62it/s]


Validation Loss: 0.5533, Accuracy: 0.9982
Epoch 26/50


100%|██████████| 120/120 [00:23<00:00,  5.13it/s]


Train Loss: 0.5537, Accuracy: 0.9976


100%|██████████| 35/35 [00:04<00:00,  7.64it/s]


Validation Loss: 0.5529, Accuracy: 0.9991
Epoch 27/50


100%|██████████| 120/120 [00:21<00:00,  5.51it/s]


Train Loss: 0.5534, Accuracy: 0.9982


100%|██████████| 35/35 [00:04<00:00,  7.01it/s]


Validation Loss: 0.5516, Accuracy: 1.0000
Epoch 28/50


100%|██████████| 120/120 [00:22<00:00,  5.44it/s]


Train Loss: 0.5521, Accuracy: 0.9995


100%|██████████| 35/35 [00:04<00:00,  7.39it/s]


Validation Loss: 0.5517, Accuracy: 1.0000
Epoch 29/50


100%|██████████| 120/120 [00:22<00:00,  5.22it/s]


Train Loss: 0.5536, Accuracy: 0.9979


100%|██████████| 35/35 [00:04<00:00,  7.24it/s]


Validation Loss: 0.5524, Accuracy: 0.9991
Epoch 30/50


100%|██████████| 120/120 [00:22<00:00,  5.40it/s]


Train Loss: 0.5539, Accuracy: 0.9976


100%|██████████| 35/35 [00:04<00:00,  7.69it/s]


Validation Loss: 0.5525, Accuracy: 0.9991
Epoch 31/50


100%|██████████| 120/120 [00:22<00:00,  5.44it/s]


Train Loss: 0.5530, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  7.74it/s]


Validation Loss: 0.5525, Accuracy: 0.9991
Epoch 32/50


100%|██████████| 120/120 [00:21<00:00,  5.47it/s]


Train Loss: 0.5532, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  7.50it/s]


Validation Loss: 0.5556, Accuracy: 0.9954
Epoch 33/50


100%|██████████| 120/120 [00:21<00:00,  5.51it/s]


Train Loss: 0.5533, Accuracy: 0.9982


100%|██████████| 35/35 [00:04<00:00,  7.53it/s]


Validation Loss: 0.5527, Accuracy: 0.9991
Epoch 34/50


100%|██████████| 120/120 [00:22<00:00,  5.37it/s]


Train Loss: 0.5522, Accuracy: 0.9992


100%|██████████| 35/35 [00:04<00:00,  7.18it/s]


Validation Loss: 0.5542, Accuracy: 0.9972
Epoch 35/50


100%|██████████| 120/120 [00:22<00:00,  5.43it/s]


Train Loss: 0.5549, Accuracy: 0.9966


100%|██████████| 35/35 [00:04<00:00,  7.46it/s]


Validation Loss: 0.5522, Accuracy: 0.9991
Epoch 36/50


100%|██████████| 120/120 [00:21<00:00,  5.48it/s]


Train Loss: 0.5547, Accuracy: 0.9969


100%|██████████| 35/35 [00:04<00:00,  7.40it/s]


Validation Loss: 0.5557, Accuracy: 0.9963
Epoch 37/50


100%|██████████| 120/120 [00:21<00:00,  5.48it/s]


Train Loss: 0.5549, Accuracy: 0.9958


100%|██████████| 35/35 [00:04<00:00,  7.52it/s]


Validation Loss: 0.5533, Accuracy: 0.9982
Epoch 38/50


100%|██████████| 120/120 [00:21<00:00,  5.47it/s]


Train Loss: 0.5535, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  7.66it/s]


Validation Loss: 0.5529, Accuracy: 0.9991
Epoch 39/50


100%|██████████| 120/120 [00:21<00:00,  5.62it/s]


Train Loss: 0.5524, Accuracy: 0.9992


100%|██████████| 35/35 [00:04<00:00,  7.71it/s]


Validation Loss: 0.5525, Accuracy: 0.9991
Epoch 40/50


100%|██████████| 120/120 [00:21<00:00,  5.62it/s]


Train Loss: 0.5527, Accuracy: 0.9990


100%|██████████| 35/35 [00:04<00:00,  7.16it/s]


Validation Loss: 0.5552, Accuracy: 0.9963
Epoch 41/50


100%|██████████| 120/120 [00:21<00:00,  5.54it/s]


Train Loss: 0.5539, Accuracy: 0.9974


100%|██████████| 35/35 [00:04<00:00,  7.67it/s]


Validation Loss: 0.5528, Accuracy: 0.9982
Epoch 42/50


100%|██████████| 120/120 [00:21<00:00,  5.62it/s]


Train Loss: 0.5525, Accuracy: 0.9990


100%|██████████| 35/35 [00:04<00:00,  7.87it/s]


Validation Loss: 0.5527, Accuracy: 0.9991
Epoch 43/50


100%|██████████| 120/120 [00:21<00:00,  5.61it/s]


Train Loss: 0.5529, Accuracy: 0.9987


100%|██████████| 35/35 [00:04<00:00,  7.54it/s]


Validation Loss: 0.5535, Accuracy: 0.9982
Epoch 44/50


100%|██████████| 120/120 [00:21<00:00,  5.58it/s]


Train Loss: 0.5533, Accuracy: 0.9979


100%|██████████| 35/35 [00:04<00:00,  7.90it/s]


Validation Loss: 0.5535, Accuracy: 0.9982
Epoch 45/50


100%|██████████| 120/120 [00:22<00:00,  5.33it/s]


Train Loss: 0.5526, Accuracy: 0.9987


100%|██████████| 35/35 [00:04<00:00,  7.37it/s]


Validation Loss: 0.5542, Accuracy: 0.9972
Epoch 46/50


100%|██████████| 120/120 [00:21<00:00,  5.49it/s]


Train Loss: 0.5544, Accuracy: 0.9969


100%|██████████| 35/35 [00:04<00:00,  7.70it/s]


Validation Loss: 0.5550, Accuracy: 0.9963
Epoch 47/50


100%|██████████| 120/120 [00:21<00:00,  5.49it/s]


Train Loss: 0.5528, Accuracy: 0.9987


100%|██████████| 35/35 [00:04<00:00,  7.71it/s]


Validation Loss: 0.5530, Accuracy: 0.9982
Epoch 48/50


100%|██████████| 120/120 [00:22<00:00,  5.44it/s]


Train Loss: 0.5548, Accuracy: 0.9966


100%|██████████| 35/35 [00:04<00:00,  7.57it/s]


Validation Loss: 0.5515, Accuracy: 1.0000
Epoch 49/50


100%|██████████| 120/120 [00:23<00:00,  5.15it/s]


Train Loss: 0.5538, Accuracy: 0.9974


100%|██████████| 35/35 [00:05<00:00,  6.75it/s]


Validation Loss: 0.5519, Accuracy: 0.9991
Epoch 50/50


100%|██████████| 120/120 [00:22<00:00,  5.31it/s]


Train Loss: 0.5530, Accuracy: 0.9984


100%|██████████| 35/35 [00:04<00:00,  7.60it/s]


Validation Loss: 0.5538, Accuracy: 0.9972


100%|██████████| 18/18 [00:03<00:00,  4.75it/s]

Test Accuracy: 1.0000





In [78]:
model_save_path = 'document_classifier.pth'
torch.save(model.state_dict(), model_save_path)
print(f'Model saved to {model_save_path}')

Model saved to document_classifier.pth
