# In [3]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

# ----------------------------
# Model (must match training)
# ----------------------------
class ChartClassifier(nn.Module):
    """Three-stage CNN producing two class scores (index 0/1) per image.

    The module layout must stay identical to the one used at training time:
    the saved checkpoint is keyed by ``conv_layers.{0,3,6}`` and
    ``fc_layers.{0,3}``, so adding, removing or reordering layers breaks
    ``load_state_dict``.

    Each stage is a padded 3x3 convolution (spatial size preserved), a ReLU,
    and a 2x2 max-pool (spatial size halved). After three stages a 224x224
    input becomes 28x28 with 128 channels, which is what the
    ``128 * 28 * 28`` input width of the first Linear layer is sized for.
    """

    def __init__(self):
        super().__init__()

        # (in_channels, out_channels) for each conv stage, in order.
        stage_channels = ((3, 32), (32, 64), (64, 128))
        conv_modules = []
        for in_ch, out_ch in stage_channels:
            conv_modules.extend(
                [
                    nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
                    nn.ReLU(),
                    nn.MaxPool2d(2, 2),
                ]
            )
        self.conv_layers = nn.Sequential(*conv_modules)

        # Flattened feature width after the last pool (128 channels x 28 x 28).
        flat_features = 128 * 28 * 28
        self.fc_layers = nn.Sequential(
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 2),
        )

    def forward(self, x):
        """Return unnormalized scores of shape (batch, 2) for the batch ``x``."""
        features = self.conv_layers(x)
        batch_size = features.size(0)
        flat = features.view(batch_size, -1)
        return self.fc_layers(flat)


# ----------------------------
# Load Model
# ----------------------------
# Prefer the GPU when one is visible; map_location below moves the stored
# tensors onto whichever device was chosen.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ChartClassifier().to(device)
# The checkpoint is consumed as a plain state_dict, so restrict unpickling to
# tensors and primitive containers (weights_only=True). This prevents a
# tampered .pth file from executing arbitrary code on load.
model.load_state_dict(
    torch.load("models/chartdetector.pth", map_location=device, weights_only=True)
)
# Switch to inference mode so Dropout in fc_layers is disabled.
model.eval()

# ----------------------------
# Image Transform
# ----------------------------
# Inference-time preprocessing. The fixed 224x224 resize is what
# ChartClassifier's first Linear layer (128 * 28 * 28 inputs after three 2x2
# pools) is sized for. ToTensor converts the PIL image to a tensor; no
# normalization step is applied here. Presumably this must mirror the
# training-time pipeline; confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)

# ----------------------------
# Prediction Function
# ----------------------------
def predict(image_path):
    """Classify one image as ``"Chart"`` or ``"Non-Chart"``.

    Args:
        image_path: Path to an image file readable by ``PIL.Image.open``.

    Returns:
        ``"Non-Chart"`` when the model's top score is at index 0, ``"Chart"``
        when it is at index 1.

    Uses the module-level ``model``, ``transform`` and ``device``.
    """
    # Open the file in a context manager so its handle is released promptly.
    # convert() returns a new, fully loaded image that stays valid after the
    # source image is closed.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    # Add a batch dimension of 1 before running the model.
    img_tensor = transform(image).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(img_tensor)
    # Index of the highest score per row (first index wins on ties).
    predicted = outputs.argmax(dim=1)

    classes = ["Non-Chart", "Chart"]
    return classes[predicted.item()]


# ----------------------------
# Test
# ----------------------------
if __name__ == "__main__":
    # Manual smoke test: classify a single image and print its label.
    sample_path = "page.png"  # change this path
    label = predict(sample_path)
    print(f"Prediction: {label}")


# Output: Prediction: Non-Chart
