In [2]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from efficientnet_pytorch import EfficientNet

# 1. Define the DocumentClassifier model architecture
class DocumentClassifier(nn.Module):
    def __init__(self, num_classes=3):
        super(DocumentClassifier, self).__init__()
        
        # EfficientNet backbone
        self.backbone = EfficientNet.from_pretrained('efficientnet-b0')
        
        # Global average pooling
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        
        # Classifier head
        num_features = self.backbone._fc.in_features
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
            nn.Softmax(dim=1)
        )
    
    def forward(self, x):
        # Extract features using the backbone
        x = self.backbone.extract_features(x)
        
        # Global average pooling
        x = self.global_pool(x)
        
        # Flatten the output tensor
        x = x.view(x.size(0), -1)
        
        # Classifier
        x = self.classifier(x)
        
        return x

# 2. Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate the model and load the state dict
model = DocumentClassifier(num_classes=3)
model.load_state_dict(torch.load('/mnt/c/Users/Rahul/Desktop/Document-and-Record-Management/notebooks/document_classifier.pth', map_location=device))

# Move the model to the appropriate device
model.to(device)

# Set the model to evaluation mode
model.eval()

# 3. Preprocess the input image
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize the image to the input size expected by EfficientNet
    transforms.ToTensor(),           # Convert image to PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize based on ImageNet stats
])

# Load your image
image_path = '/mnt/c/Users/Rahul/Desktop/aayush.jpg'
image = Image.open(image_path)

# Preprocess the image
image = transform(image).unsqueeze(0).to(device)  # Add batch dimension and move to the device

# 4. Perform inference
with torch.no_grad():  # Disable gradient calculation for inference
    output = model(image)

# 5. Process the output
# The output is a probability distribution over the classes
predicted_probabilities = output.cpu().numpy()

# Get the predicted class
predicted_class = torch.argmax(output, dim=1).item()

print(f'Predicted class: {predicted_class}')
print(f'Class probabilities: {predicted_probabilities}')


Loaded pretrained weights for efficientnet-b0


  model.load_state_dict(torch.load('/mnt/c/Users/Rahul/Desktop/Document-and-Record-Management/notebooks/document_classifier.pth', map_location=device))


Predicted class: 0
Class probabilities: [[1.0000000e+00 9.6527951e-20 7.5662624e-18]]
