In [None]:
import os
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
from PIL import Image

In [None]:
# Locations used throughout the notebook.
# Directory names keep their trailing slash because paths are built by string concatenation.
figure_dir = "extracted_figures/"
figure_metadata_file = f"{figure_dir}figures.csv"

classified_figure_dir = "classified_figures/"
unclassified_figure_dir = "unclassified_figures/"
model_path = "models/"

# Create both output directories up front (no error if they already exist)
for _output_dir in (classified_figure_dir, unclassified_figure_dir):
    os.makedirs(_output_dir, exist_ok=True)

# Extensions tried, in this order, when looking up the file that belongs to a figure id
possible_extensions = [".pdf", ".png", ".jpg", ".jpeg", ".eps", ".ps"]

In [None]:
# Delete figures from dataset if their image size is larger than the defined threshold
def delete_figures_size(max_size, source_dir=None, metadata_file=None, extensions=None):
    """Drop oversized or missing figures from the dataset.

    The metadata CSV is read row by row; the first column of each row is the
    figure id. For every id the file ``<figure id><extension>`` is looked up in
    the figure directory, trying ``extensions`` in order, and then:

    * the file is deleted and the row dropped if the file is larger than
      ``max_size``,
    * the row is dropped if no file with a known extension exists,
    * the row is kept otherwise, including when an error occurs while the
      figure is being checked.

    Blank CSV lines are skipped. The surviving rows are written to ``tmp.csv``
    inside the figure directory, which then replaces the metadata file. A
    summary with the number of dropped figures is printed at the end.

    Args:
        max_size: Largest allowed file size in bytes.
        source_dir: Directory containing the figure files. Defaults to the
            notebook-level ``figure_dir``.
        metadata_file: Path of the ';'-separated metadata CSV. Defaults to the
            notebook-level ``figure_metadata_file``.
        extensions: File extensions (including the dot) to try, in order.
            Defaults to the notebook-level ``possible_extensions``.
    """
    # Fall back to the notebook-level configuration for anything not passed in
    if source_dir is None:
        source_dir = figure_dir
    if metadata_file is None:
        metadata_file = figure_metadata_file
    if extensions is None:
        extensions = possible_extensions

    tmp_file = os.path.join(source_dir, "tmp.csv")
    csv_format = dict(delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
    counter = 0

    with open(metadata_file, "r", newline='', encoding='utf-8') as input_file, \
            open(tmp_file, "w", newline='', encoding='utf-8') as output_file:
        csv_reader = csv.reader(input_file, **csv_format)
        csv_writer = csv.writer(output_file, **csv_format)

        for row in csv_reader:
            # csv.reader yields [] for blank lines; they carry no figure id
            if not row:
                continue

            figure_id = row[0]
            try:
                # Find figure file: first known extension that exists on disk
                figure_file = None
                for extension in extensions:
                    candidate = os.path.join(source_dir, figure_id + extension)
                    if os.path.isfile(candidate):
                        figure_file = candidate
                        break

                if figure_file is None:
                    # Remove from dataset if file was not found
                    print(f"File for figure {figure_id} was not found.")
                    counter += 1
                elif os.path.getsize(figure_file) > max_size:
                    # Remove from dataset when file size is too large
                    os.remove(figure_file)
                    counter += 1
                else:
                    csv_writer.writerow(row)

            except Exception as e:
                # No removal when an error occurs: keep the row in the dataset
                print(f"Error occurred for {figure_id}: {e}")
                csv_writer.writerow(row)

    # Replace old csv file with new csv file
    os.replace(tmp_file, metadata_file)

    print(f"{counter} figures were deleted from the dataset.")

In [None]:
# Preprocessing applied to every image before it is handed to the classifier:
# resize, centre-crop to 224x224, convert to a tensor, then normalize each of the three channels.
# The mean/std values match the channel statistics torchvision documents for its
# ImageNet-pretrained models (NOTE(review): confirm the classifier was trained with the same values).
_normalize_mean = [0.485, 0.456, 0.406]
_normalize_std = [0.229, 0.224, 0.225]

_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_normalize_mean, std=_normalize_std),
]
test_transform = transforms.Compose(_preprocessing_steps)

# Classify image files as figure or non-figure
def is_figure(model, file, device, threshold):
    """Decide whether an image inside ``figure_dir`` is classified as a figure.

    The image is loaded as RGB, preprocessed with ``test_transform`` and passed
    through ``model`` as a batch of one. The softmax probability at output
    index 1 is then compared with ``threshold``.

    Args:
        model: Classifier called on the single-image batch; its output must
            have at least two entries along dim 1.
        file: File name of the image, relative to ``figure_dir``.
        device: Torch device the input tensor is moved to before inference.
        threshold: Largest index-1 probability that still counts as a figure.

    Returns:
        bool: ``True`` if the index-1 probability is less than or equal to
        ``threshold``, otherwise ``False``.
    """
    image_path = figure_dir + file

    # The image must have exactly three channels: test_transform normalizes with
    # three mean/std values, so a 4-channel RGBA tensor would raise at runtime.
    # The context manager releases the file handle so the file can be moved afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    # Preprocessing: transform, add the batch dimension, move to the target device
    batch = test_transform(image).unsqueeze(0).to(device)

    # Binary classifying; gradients are not needed for inference
    with torch.no_grad():
        logits = model(batch)
        probabilities = F.softmax(logits, dim=1)

    # Predict
    # NOTE(review): index 1 is presumably the "non-figure" class, so a low
    # probability means "figure" - confirm against the training label order.
    confidence = probabilities[0, 1].item()
    return confidence <= threshold

In [None]:
# Upper limit for a figure's file size in bytes (2 * 1024 * 1024)
MAX_FILE_SIZE = 2 * 1024 ** 2

# Drop every figure whose file exceeds that limit
delete_figures_size(max_size=MAX_FILE_SIZE)

In [None]:
# Set probability threshold for classifying a figure
PROBABILITY_THRESHOLD = 0.05

# Load model on CPU
# NOTE(review): `model` is not constructed in any visible cell. It has to be created
# with the same architecture that produced binary_classifier.pth before this cell runs,
# otherwise a fresh kernel fails here with a NameError.
# NOTE(security): torch.load unpickles the file - only load checkpoints from a trusted source.
device = torch.device('cpu')
model.load_state_dict(torch.load(model_path + "binary_classifier.pth", map_location=device))
model = model.to(device)
model.eval()

# Classify images: walk the files in the figure directory (not the characters of
# its path string) and move each .png/.jpg into the matching output directory
for image_file in sorted(os.listdir(figure_dir)):
    if not image_file.endswith((".png", ".jpg")):
        continue

    # Argument order follows is_figure(model, file, device, threshold)
    if is_figure(model, image_file, device, PROBABILITY_THRESHOLD):
        target_dir = classified_figure_dir
    else:
        target_dir = unclassified_figure_dir
    os.rename(figure_dir + image_file, target_dir + image_file)

# TODO: Remove non-figures from dataset