In [1]:
import os
import csv
from PIL import Image

In [2]:
# File and Directory Paths
figure_dir = "extracted_figures/"
figure_metadata_file = figure_dir + "figures.csv"

classified_figure_dir = figure_dir + "classified_figures/"
unclassified_figure_dir = figure_dir + "unclassified_figures/"
model_dir = "models/"

os.makedirs(classified_figure_dir, exist_ok=True)
os.makedirs(unclassified_figure_dir, exist_ok=True)

# CSV Size Limit
csv.field_size_limit(260000)

131072

In [12]:
# Delete figures from dataset if their image size is larger than the defined threshold
def delete_figures_size(max_size):
    with open(figure_metadata_file, "r", newline='', encoding='utf-8') as input_file:
        with open(figure_dir + "tmp.csv", "w", newline='', encoding='utf-8') as output_file:
            csv_reader = csv.reader(input_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            csv_writer = csv.writer(output_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            
            counter = 0
            for row in csv_reader:
                figure_id = row[0]
                figure_file = figure_dir + figure_id + ".png"
                try:
                    # Obtain file size
                    figure_file = figure_dir + figure_id + extension
                    file_size = os.path.getsize(figure_file)

                    # Remove from dataset when file size is too large
                    if file_size > max_size:
                        os.remove(figure_file)
                        counter += 1
                    else:
                        csv_writer.writerow(row)
                except e as Exception:
                    print(f"Error occurred: {e}")
                    if os.path.isfile(figure_file):
                        os.remove(figure_file)
                        counter += 1
                    
    # Replace old csv file with new csv file
    os.replace(figure_dir + "tmp.csv", figure_metadata_file)
    
    print(f"{counter} figures were deleted from the dataset.")
    
    
# Delete figures from dataset if their pixel size is larger than the defined threshold
def delete_figures_pixel(max_pixel):
    with open(figure_metadata_file, "r", newline='', encoding='utf-8') as input_file:
        with open(figure_dir + "tmp.csv", "w", newline='', encoding='utf-8') as output_file:
            csv_reader = csv.reader(input_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            csv_writer = csv.writer(output_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
            
            counter = 0
            for row in csv_reader:
                figure_id = row[0]
                figure_file = figure_dir + figure_id + ".png"
                try:
                    # Obtain pixel dimension
                    with Image.open(figure_file) as img:
                        width, height = img.size
                        total_pixels = width * height

                        # Remove from dataset when file size is too large
                        if total_pixels > max_pixel:
                            os.remove(figure_file)
                            counter += 1
                        else:
                            csv_writer.writerow(row)
                except Exception as e:
                    print(f"Error occurred: {e}")
                    if os.path.isfile(figure_file):
                        os.remove(figure_file)
                    
    # Replace old csv file with new csv file
    os.replace(figure_dir + "tmp.csv", figure_metadata_file)
    
    print(f"{counter} figures were deleted from the dataset.")

In [35]:
# Set threshold for figure file size
MAX_FILE_SIZE = 2 * 1024 * 1024

# Delete figures with a larger file size
delete_figures_size(MAX_FILE_SIZE)

2384 figures were deleted from the dataset.


In [13]:
# Set threshold for figure pixel size
IMG_MAX_PIX_SIZE = 89478485

# Delete figures with a larger file size
delete_figures_pixel(IMG_MAX_PIX_SIZE)

Error occurred: cannot identify image file 'extracted_figures/2302.00390_FIG_3.png'
Error occurred: cannot identify image file 'extracted_figures/2304.02725_FIG_2.png'
Error occurred: cannot identify image file 'extracted_figures/2303.11580_FIG_3.png'
Error occurred: cannot identify image file 'extracted_figures/2301.11096_FIG_16.png'
Error occurred: cannot identify image file 'extracted_figures/2210.17166_FIG_2.png'
Error occurred: cannot identify image file 'extracted_figures/2303.13307_FIG_2.png'
Error occurred: cannot identify image file 'extracted_figures/2210.03589_FIG_1.png'
Error occurred: cannot identify image file 'extracted_figures/2303.00260_FIG_5.png'
Error occurred: cannot identify image file 'extracted_figures/2303.08859_FIG_2.png'
Error occurred: cannot identify image file 'extracted_figures/2303.08859_FIG_3.png'
Error occurred: cannot identify image file 'extracted_figures/2212.03640_FIG_4.png'
Error occurred: cannot identify image file 'extracted_figures/2212.07384_FI

Error occurred: cannot identify image file 'extracted_figures/2208.02442_FIG_3.png'
Error occurred: cannot identify image file 'extracted_figures/2208.03970_FIG_4.png'
Error occurred: cannot identify image file 'extracted_figures/2209.02960_FIG_1.png'
Error occurred: cannot identify image file 'extracted_figures/2209.06032_FIG_4.png'
0 figures were deleted from the dataset.


In [32]:
# Get all file extensions:
file_extensions = set()
file_collection = os.listdir(figure_dir)

for file in file_collection:
    ext = "." + file.split(".")[-1]
    file_extensions.add(ext)

file_extensions

{'.csv', '.png', '.zip'}

In [27]:
# Rename PNG to png
file_collection = os.listdir(figure_dir)
for file in file_collection:
    if file.endswith(".PNG"):
        os.rename(figure_dir + file, figure_dir + file.replace(".PNG", ".png"))

In [24]:
# Remove all non-png files:
counter = 0
for file in file_collection:
    if not file.endswith(".png") and not file.endswith(".PNG") \
    and not file.endswith(".csv") and not file.endswith(".zip") \
    and not file.endswith("tex_files"):
        counter += 1
        os.remove(figure_dir + file)
        
print(f"{counter} files have been deleted.")

0 files have been deleted.


In [14]:
# Remove figures from metadata if file is not on disk
file_collection = set(os.listdir(figure_dir))
counter = 0

with open(figure_metadata_file, "r", newline='', encoding='utf-8') as input_file:
    with open(figure_dir + "tmp.csv", "w", newline='', encoding='utf-8') as output_file:
        csv_reader = csv.reader(input_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        csv_writer = csv.writer(output_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        
        for row in csv_reader:
            file_name = row[0] + ".png"
            if file_name in file_collection:
                csv_writer.writerow(row)
            else:
                counter += 1
                
# Replace old csv file with new csv file
os.replace(figure_dir + "tmp.csv", figure_metadata_file)

print(f"{counter} figures have been removed.")

0 figures have been removed.


In [None]:
# Classifying between 'scientific' and 'non_scientific' figures
from torchvision import datasets, transforms
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModel
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Specify model
model_id = "google/siglip-base-patch16-224"

# Load pre-trained model and processor
model = AutoModelForImageClassification.from_pretrained(model_id, problem_type="single_label_classification")
processor = AutoImageProcessor.from_pretrained(model_id)

# Load fine-tuned model from disk
model_path = model_dir + 'binary_classifier.pth'
model.load_state_dict(torch.load(model_path))

# Move model to GPU and evaluation mode
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
model.eval()

# Returns if an image consists of a scientific figure
def is_figure(image_path, fixed_threshold):
    # Load the image from the file path
    image = Image.open(image_path).convert("RGB")
    
    # Preprocessing
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Use binary classifier
    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits
        sigmoid = torch.nn.Sigmoid()
        probs = sigmoid(logits.squeeze().cpu())
        pred = 1 if probs[0].item() >= fixed_threshold else 0
        return pred

In [None]:
# Set probability threshold for classifying a figure
PROBABILITY_THRESHOLD = 0.80

# Classify images
classified_figures = set()
for image_file in figure_dir:
    try:
        if image_file.endswith(".png"):
            if is_figure(image_file, model, device, PROBABILITY_THRESHOLD):
                os.rename(figure_dir + image_file, classified_figure_dir + image_file)
                classified_figures.add(image_file.replace(".png", "")
            else:
                os.rename(figure_dir + image_file, unclassified_figure_dir + image_file)
    except Exception as e:
        print(f"Exception for {image_file}: {e}")

# Create new csv file
with open(figure_metadata_file, "r", newline='', encoding='utf-8') as input_file:
    with open(classified_figure_dir + "classified_figures.csv", "w", newline='', encoding='utf-8') as output_file:
        csv_reader = csv.reader(input_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)
        csv_writer = csv.writer(output_file, delimiter=';', quotechar='|', quoting=csv.QUOTE_MINIMAL)

        # Copy rows of classified figures
        for row in csv_reader:
            if row[0] in classified_figures:
                csv_writer.writerow(row)

print("Classifying completed.")