This notebook was created by usage of the tutorial "Fine-tune SigLIP and friends for multi-label image classification" by Niels Rogge. (https://github.com/NielsRogge/Transformers-Tutorials/blob/master/SigLIP/Fine_tuning_SigLIP_and_friends_for_multi_label_image_classification.ipynb)

In [1]:
import os
import torch
from PIL import Image
from torchvision import datasets, transforms
from torch.optim import AdamW
from tqdm.auto import tqdm
from transformers import AutoProcessor, AutoModel
from transformers import AutoImageProcessor, AutoModelForImageClassification
from sklearn.metrics import f1_score, classification_report

2024-12-17 12:20:57.819131: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1734434457.832581  115230 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1734434457.836472  115230 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-17 12:20:57.849838: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Specify model
model_id = "google/siglip-base-patch16-224"

# Get appropriate size, mean and std based on the image processor
processor = AutoImageProcessor.from_pretrained(model_id)


# Define transform operation
size = processor.size["height"]
mean = processor.image_mean
std = processor.image_std
transform = transforms.Compose([
    transforms.Resize((size, size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
])

# Data directories
train_data = "binary_train_dataset"
test_data = "binary_test_dataset"
model_dir = "models/"

# Learning variables
fixed_batch_size = 32
fixed_num_workers = 2

# Load dataset
train_dataset = datasets.ImageFolder(root=train_data, transform=transform)
test_dataset = datasets.ImageFolder(root=test_data, transform=transform)

# Data loader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=fixed_batch_size, shuffle=True, num_workers=fixed_num_workers)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=fixed_batch_size, shuffle=False, num_workers=fixed_num_workers)

# Class to index mapping
print(train_dataset.class_to_idx)

{'figure': 0, 'non_figure': 1}


In [3]:
# Found at https://github.com/wenwei202/pytorch-examples/blob/ecbb7beb0fac13133c0b09ef980caf002969d315/imagenet/main.py#L296
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [4]:
# Load the pre-trained model
model = AutoModelForImageClassification.from_pretrained(model_id, problem_type="single_label_classification")

# Move model to the GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Loss function and optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)
losses = AverageMeter()

Some weights of SiglipForImageClassification were not initialized from the model checkpoint at google/siglip-base-patch16-224 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
# For training purpose
Image.MAX_IMAGE_PIXELS = None

# Training loop
num_epochs = 10
model.train()  
for epoch in range(num_epochs):
    running_loss = 0.0

    for idx, batch in enumerate(tqdm(train_loader)):
        # Get input
        pixel_values, labels = batch

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass 
        outputs = model(
            pixel_values=pixel_values.to(device),
            labels=labels.to(device),
        )
        
        # Calculate gradients
        loss = outputs.loss
        losses.update(loss.item(), pixel_values.size(0))
        loss.backward()

        # Update weights
        optimizer.step()
  
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(train_loader):.4f}")

print("Training complete!")

  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 1, Loss: 0.2515


  0%|          | 0/166 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0>
Traceback (most recent call last):
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
Exception ignored in:     <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0>self._shutdown_workers()

Traceback (most recent call last):
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
        if w.is_alive():self._shutdown_workers()

   File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
      if w.is_alive(): 
     ^ ^ ^ ^Exception ignored in:  ^<function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0>Exception ignored in:  ^
<funct

Epoch 2, Loss: 0.1141


  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 3, Loss: 0.0577


  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 4, Loss: 0.0339


  0%|          | 0/166 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0>
Traceback (most recent call last):
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'Exception ignored in: 
<function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0> 
 Traceback (most recent call last):
   File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
      self._shutdown_workers() 
   File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587

Epoch 5, Loss: 0.0603


  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 6, Loss: 0.0550


  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 7, Loss: 0.0192


  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 8, Loss: 0.0391


  0%|          | 0/166 [00:00<?, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0>
Traceback (most recent call last):
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
Exception ignored in:  <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0>  
 Traceback (most recent call last):
   File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
      self._shutdown_workers()^
^^  File "/home/fuubian/.virtualenvs/py3/lib/python3.12/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
^    ^if w.is_alive():^
^ ^ ^ ^^ Exception ignored in: ^ <function _MultiProcessingDataLoaderIter.__del__ at 0x7f2011b471a0> 

   File "/usr/lib/pyth

Epoch 9, Loss: 0.0153


  0%|          | 0/166 [00:00<?, ?it/s]



Epoch 10, Loss: 0.0474
Training complete!


In [6]:
# Model evaluation
model.eval()  # Set model to evaluation mode
correct = 0
total = 0
all_labels = []
all_predictions = []

with torch.no_grad():
    for idx, batch in enumerate(tqdm(test_loader)):
        pixel_values, labels = batch
        pixel_values, labels = pixel_values.to(device), labels.to(device)
        
        outputs = model(
            pixel_values=pixel_values.to(device),
            labels=labels.to(device),
        )
        
        logits = outputs.logits
        _, predicted = torch.max(logits, 1)
        
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

# Accuracy
accuracy = 100 * correct / total
print(f"Test Accuracy: {accuracy:.2f}%")

# F1 Score
f1 = f1_score(all_labels, all_predictions, average="binary")
print(f"F1 Score: {f1:.2f}")

# Classification report
report = classification_report(all_labels, all_predictions, target_names=test_dataset.classes)
print(report)

  0%|          | 0/21 [00:00<?, ?it/s]

Test Accuracy: 93.18%
F1 Score: 0.74
              precision    recall  f1-score   support

      figure       0.95      0.97      0.96       554
  non_figure       0.79      0.70      0.74        91

    accuracy                           0.93       645
   macro avg       0.87      0.84      0.85       645
weighted avg       0.93      0.93      0.93       645



In [7]:
# Saving the model
model_path = model_dir + 'binary_classifier.pth'
torch.save(model.state_dict(), model_path)

In [8]:
# Binary classification example
model_path = model_dir + 'binary_classifier.pth'
model.load_state_dict(torch.load(model_path))
model.eval()  # Set to evaluation mode

def is_figure(image_path):
    # Load the image from the file path
    image = Image.open(image_path).convert('RGB')
    
    # Preprocessing
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Use binary classifier
    with torch.no_grad():
        outputs = model(pixel_values)
        logits = outputs.logits
        sigmoid = torch.nn.Sigmoid()
        probs = sigmoid(logits.squeeze().cpu())
        pred = 1 if probs[0].item() >= 0.80 else 0
        return pred

  model.load_state_dict(torch.load(model_path))


In [9]:
# Apply classifier on a subset of data
dataset = "figure_data/"
figure_dir = "figures/"
non_figure_dir = "non_figures/"

for image_file in os.listdir(dataset):
    try:
        if is_figure(dataset+image_file):
            os.rename(dataset + image_file, figure_dir + image_file)
        else:
            os.rename(dataset + image_file, non_figure_dir + image_file)
    except Exception as e:
        print(f"Exception for {image_file}: {e}")

print("Classifying completed.")



Exception for 2407.06552_FIG_1.png: Decompressed data too large for PngImagePlugin.MAX_TEXT_CHUNK
Exception for 2407.00851_FIG_3.png: Decompressed data too large for PngImagePlugin.MAX_TEXT_CHUNK
Classifying completed.
