In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import os
import matplotlib.pyplot as plt
from PIL import Image

In [4]:
# Channel-wise normalisation shared by both pipelines. The mean/std values are
# the ImageNet statistics that torchvision's pretrained models are documented
# to expect.
imagenet_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Training pipeline: random crop (resized to 224) and random horizontal flip
# for augmentation, then tensor conversion and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        imagenet_normalize,
    ]
)

# Evaluation pipeline: deterministic resize + centre crop to 224, then the
# same tensor conversion and normalisation as training.
test_transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        imagenet_normalize,
    ]
)

In [9]:
# Root folders of the two splits; ImageFolder derives the class labels from
# the sub-folder names found under each root.
train_data = "binary_train_dataset"
test_data = "binary_test_dataset"

# Number of images per mini-batch (used by both loaders)
fixed_batch_size = 32

# Wrap each folder in a dataset that applies the matching preprocessing
train_dataset = datasets.ImageFolder(train_data, transform=train_transform)
test_dataset = datasets.ImageFolder(test_data, transform=test_transform)

# Options common to both loaders; only the shuffling policy differs:
# training batches are reshuffled every epoch, test batches keep a fixed order.
loader_options = {"batch_size": fixed_batch_size, "num_workers": 4}
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_options)

# Show which integer label was assigned to each class folder
print(train_dataset.class_to_idx)

{'figure': 0, 'non_figure': 1}


In [10]:
# Load the ImageNet-pretrained ResNet-18.
# The `weights=` enum replaces the deprecated `pretrained=True` flag;
# IMAGENET1K_V1 is the checkpoint that the legacy flag resolved to.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

# Freeze every pretrained parameter. There is no need to special-case the
# original `fc` layer here because it is discarded and replaced just below.
for param in model.parameters():
    param.requires_grad = False

# Replace the final classification layer with a 2-way head. A freshly
# constructed nn.Linear has requires_grad=True, so it is the only trainable part.
num_classes = 2
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Move the model to the GPU (when available) before building the optimizer,
# so the optimizer is created against the parameters on their final device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Loss function and optimizer.
# Only the trainable parameters (the new head) are handed to SGD; the frozen
# backbone parameters never receive gradients and would be skipped anyway.
criterion = nn.CrossEntropyLoss()
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.005, momentum=0.9, weight_decay=5e-4)

In [11]:
# Training loop: one full pass over `train_loader` per epoch, reporting the
# mean of the per-batch losses at the end of each epoch.
num_epochs = 15
for epoch in range(num_epochs):
    model.train()  # switch layers such as BatchNorm to training behaviour
    running_loss = 0.0

    for batch_images, batch_targets in train_loader:
        # Put the batch on the same device as the model
        batch_images, batch_targets = batch_images.to(device), batch_targets.to(device)

        # Clear gradients left over from the previous step
        optimizer.zero_grad()

        # Forward pass and loss
        batch_loss = criterion(model(batch_images), batch_targets)

        # Backward pass, then parameter update
        batch_loss.backward()
        optimizer.step()

        # Accumulate the scalar loss for the epoch summary
        running_loss += batch_loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

print("Training complete!")



Epoch 1, Loss: 0.4019




Epoch 2, Loss: 0.3830




Epoch 3, Loss: 0.3937




Epoch 4, Loss: 0.3651




Epoch 5, Loss: 0.3692




Epoch 6, Loss: 0.3531




Epoch 7, Loss: 0.3287




Epoch 8, Loss: 0.3961




Epoch 9, Loss: 0.3792




Epoch 10, Loss: 0.3331




Epoch 11, Loss: 0.3919




Epoch 12, Loss: 0.4526




Epoch 13, Loss: 0.3505




Epoch 14, Loss: 0.3626




Epoch 15, Loss: 0.3758
Training complete!


In [12]:
# Model evaluation: top-1 accuracy over the whole test loader
model.eval()  # inference behaviour for layers such as BatchNorm
correct, total = 0, 0

# Gradients are not needed for scoring, so autograd tracking is disabled
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Predicted class = index of the largest logit in each row
        predicted = model(inputs).argmax(dim=1)

        hits = (predicted == labels).sum().item()
        correct += hits
        total += labels.size(0)

print(f"Test Accuracy: {100 * correct / total:.2f}%")



Test Accuracy: 80.84%


In [13]:
# Persist only the model's parameters (its state_dict) to 'binary_classifier.pth';
# a later cell restores them with model.load_state_dict(...).
torch.save(model.state_dict(), 'binary_classifier.pth')

In [70]:
# Binary classification
# Restore the trained weights saved earlier.
#  - map_location=device lets a checkpoint written on a GPU machine load on a
#    CPU-only machine (and vice versa).
#  - weights_only=True restricts unpickling to tensors and plain containers,
#    which is all a state_dict contains, instead of arbitrary pickled objects.
model.load_state_dict(
    torch.load("binary_classifier.pth", map_location=device, weights_only=True)
)
model.eval()  # Set to evaluation mode


def is_chart(image_path, threshold=0.05):
    """Classify a single image file as figure (chart) or non-figure.

    The image is preprocessed with ``test_transform`` and passed through the
    module-level ``model``. Output index 1 is the ``non_figure`` class in the
    training set's ``class_to_idx`` mapping ({'figure': 0, 'non_figure': 1}),
    so the image is accepted as a figure only when the softmax probability of
    ``non_figure`` is at most ``threshold``.

    Args:
        image_path: Path of the image file to classify.
        threshold: Largest ``non_figure`` probability for which the image is
            still reported as a figure. Defaults to 0.05 (the value this
            function has always used), i.e. the model must be at least ~95%
            sure the image is a figure.

    Returns:
        1 if the image is classified as a figure, otherwise 0.
    """
    # Load the image and force 3-channel RGB: the Normalize step and the
    # ResNet stem both expect exactly three channels, so RGBA input would fail.
    # The context manager closes the underlying file handle promptly.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    # Preprocess, add the batch dimension, and move to the model's device
    batch = test_transform(image).unsqueeze(0).to(device)

    # Inference only: skip autograd bookkeeping
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1)

    # Probability assigned to class index 1 (`non_figure`)
    non_figure_prob = probabilities[0, 1].item()
    return 1 if non_figure_prob <= threshold else 0

  model.load_state_dict(torch.load("binary_classifier.pth"))


In [71]:
# Source folder with unsorted images and the two destination folders
dataset = "figure_data/"
figure_dir = "figures/"
non_figure_dir = "non_figures/"

# os.rename does not create missing folders, so make sure both targets exist
os.makedirs(figure_dir, exist_ok=True)
os.makedirs(non_figure_dir, exist_ok=True)

# Classify every file in the source folder and move it to the matching folder.
# The listing is sorted so files are processed in a reproducible order.
for image_file in sorted(os.listdir(dataset)):
    source_path = os.path.join(dataset, image_file)

    # Sub-directories cannot be opened as images; leave them where they are
    if not os.path.isfile(source_path):
        continue

    target_dir = figure_dir if is_chart(source_path) else non_figure_dir
    os.rename(source_path, os.path.join(target_dir, image_file))

