In [7]:
import os
import json
import torch
import numpy as np
from PIL import Image
from datasets import Dataset
from transformers import AutoImageProcessor, Trainer, TrainingArguments
import evaluate
from torch import nn
from torch.utils.data import DataLoader


In [2]:
# CONFIG
# NOTE(review): base_dir is an absolute, machine-specific path; adjust it when
# running this notebook on another machine.
base_dir = "/Users/georgye/Documents/repos/ethz/dslab25/"
root_dir = base_dir + "training/vacuum_pump"
image_dir = os.path.join(root_dir, "images", "augmented")
label_dir = os.path.join(root_dir, "annotation", "augmented")
coco_path = os.path.join(root_dir, "coco_annotations.json")

# Collect category mapping: one category per stage folder, ids 0..7
stage_folders = [f"stage_{i}" for i in range(8)]
category_mapping = {name: i for i, name in enumerate(stage_folders)}  # name -> ID
categories = [{"id": i, "name": name} for name, i in category_mapping.items()]

# Initialize COCO structure
coco_output = {
    "images": [],
    "annotations": [],
    "categories": categories
}

# Running counters used as unique COCO ids
image_id = 0
annotation_id = 0

# Traverse through each stage folder
for class_folder in stage_folders:
    img_folder = os.path.join(image_dir, class_folder)
    label_folder = os.path.join(label_dir, class_folder)
    category_id = category_mapping[class_folder]

    # os.listdir returns entries in arbitrary order; sorting makes the image
    # ids (and therefore any seeded split built from them) reproducible.
    for filename in sorted(os.listdir(img_folder)):
        if not filename.endswith(".jpg"):
            continue

        image_path = os.path.join(img_folder, filename)
        # Swap only the final extension; str.replace would also rewrite any
        # ".jpg" that appears earlier in the file name.
        stem, _ext = os.path.splitext(filename)
        label_path = os.path.join(label_folder, stem + ".txt")

        # Read image size
        with Image.open(image_path) as img:
            width, height = img.size

        # Add image entry (file_name is relative to image_dir)
        coco_output["images"].append({
            "id": image_id,
            "file_name": f"{class_folder}/{filename}",
            "width": width,
            "height": height
        })

        # Process annotation; images without a label file get no annotations
        if os.path.exists(label_path):
            with open(label_path, "r") as f:
                for line in f:
                    parts = line.strip().split()
                    # Skip lines that do not have exactly five fields
                    if len(parts) != 5:
                        continue
                    # The class id stored in the label file is parsed but not
                    # used: the category comes from the stage folder instead.
                    _cls, x_center, y_center, w, h = map(float, parts)

                    # Convert YOLO (normalised centre/size) to COCO
                    # (absolute top-left x, y, width, height)
                    x = (x_center - w / 2) * width
                    y = (y_center - h / 2) * height
                    box_width = w * width
                    box_height = h * height

                    coco_output["annotations"].append({
                        "id": annotation_id,
                        "image_id": image_id,
                        "category_id": category_id,
                        "bbox": [x, y, box_width, box_height],
                        "area": box_width * box_height,
                        "iscrowd": 0
                    })
                    annotation_id += 1

        image_id += 1

# Save to JSON
with open(coco_path, "w") as f:
    json.dump(coco_output, f, indent=2)

print(f"COCO annotations saved to {coco_path}")

COCO annotations saved to /Users/georgye/Documents/repos/ethz/dslab25/training/vacuum_pump/coco_annotations.json


In [10]:
# Read back the COCO annotation file
with open(coco_path, 'r') as f:
    coco_data = json.load(f)

# image id -> category id (the last annotation wins if an image has several)
image_to_category = {
    ann["image_id"]: ann["category_id"] for ann in coco_data["annotations"]
}

# Only images that have at least one annotation become dataset rows
labelled_images = [
    info for info in coco_data["images"] if info["id"] in image_to_category
]
dataset_dict = {
    "image_path": [
        os.path.join(image_dir, info["file_name"]) for info in labelled_images
    ],
    "label": [image_to_category[info["id"]] for info in labelled_images],
}

# 80/20 train/validation split with a fixed seed
dataset = Dataset.from_dict(dataset_dict).train_test_split(test_size=0.2, seed=42)

# DINOv2 ViT-B/14 backbone with registers, fetched through torch hub
print("Loading DINOv2 ViT-B/14 with registers...")
dinov2 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14_reg')

# Pick the GPU when one is available, otherwise stay on the CPU
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)

# Evaluation mode, then move the backbone to the selected device
dinov2 = dinov2.eval().to(device)
print(f"Using device: {device}")

# Image preprocessing pipeline
processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")

# Custom Dataset for DINOv2 features
class DINOv2Dataset(torch.utils.data.Dataset):
    """Dataset that turns image paths into backbone features on the fly.

    Args:
        dataset_dict: Indexable collection whose items are mappings with an
            "image_path" entry (path to an image file) and a "label" entry
            (integer class id).
        processor: Callable invoked as ``processor(images=..., return_tensors="pt")``
            that returns a mapping containing a "pixel_values" tensor.
        feature_extractor: Backbone called with the pixel tensor. It may return
            either the embedding tensor directly (as the torch.hub DINOv2
            backbones do) or an object exposing ``last_hidden_state``
            (Hugging Face style), in which case the first token is used.
        device: Device the pixel tensor is moved to before the backbone call.
    """

    def __init__(self, dataset_dict, processor, feature_extractor, device):
        self.dataset = dataset_dict
        self.processor = processor
        self.feature_extractor = feature_extractor
        self.device = device

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return ``{"features": 1-D tensor on CPU, "labels": long tensor}``."""
        item = self.dataset[idx]
        image = Image.open(item["image_path"]).convert("RGB")

        # Process image for DINOv2
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].to(self.device)

        # The torch.hub backbone takes the image tensor positionally and
        # returns the CLS embedding itself; it has no `pixel_values` keyword
        # and no `.last_hidden_state` attribute.
        with torch.no_grad():
            backbone_out = self.feature_extractor(pixel_values)
        if hasattr(backbone_out, "last_hidden_state"):
            # Hugging Face style output: take the first ([CLS]) token
            backbone_out = backbone_out.last_hidden_state[:, 0]

        # Drop the batch dimension of size one and hand back a CPU tensor
        features = backbone_out.squeeze(0).cpu()

        return {
            "features": features,
            "labels": torch.tensor(item["label"], dtype=torch.long)
        }

# Wrap both splits so samples are served as backbone features plus labels
print("Preparing datasets...")
train_dataset, eval_dataset = (
    DINOv2Dataset(dataset[split_name], processor, dinov2, device)
    for split_name in ("train", "test")
)

# Define linear classifier
class LinearClassifier(nn.Module):
    """Single linear layer mapping backbone features to class logits.

    Args:
        input_dim: Size of the incoming feature vector (768 is ViT-B dimension).
        num_labels: Number of output classes.
    """

    def __init__(self, input_dim=768, num_labels=8):  # 768 is ViT-B dimension
        super().__init__()
        self.classifier = nn.Linear(input_dim, num_labels)

    def forward(self, features, labels=None):
        """Compute logits and, when labels are given, the cross-entropy loss.

        Args:
            features: Feature tensor whose last dimension equals ``input_dim``.
            labels: Optional tensor of integer class ids. Hugging Face
                ``Trainer`` needs the model to accept labels and return a
                "loss" entry in order to train and to compute evaluation
                metrics.

        Returns:
            ``{"logits": ...}`` when ``labels`` is None, otherwise
            ``{"loss": ..., "logits": ...}``.
        """
        logits = self.classifier(features)
        if labels is None:
            return {"logits": logits}
        loss = nn.functional.cross_entropy(logits, labels)
        # "loss" comes first so tuple-style consumers see it at index 0
        return {"loss": loss, "logits": logits}

# Initialize model
model = LinearClassifier()

# Define training arguments
training_args = TrainingArguments(
    output_dir=os.path.join(root_dir, "dinov2_register_linear_classifier"),
    learning_rate=3e-4,
    per_device_train_batch_size=64,  # Can increase this on RTX 4090
    per_device_eval_batch_size=64,
    num_train_epochs=10,
    weight_decay=0.01,
    # Evaluate and checkpoint once per epoch; both strategies must match for
    # load_best_model_at_end to be accepted.
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    # Load data in the main process. The datasets hold the backbone model and
    # run it inside __getitem__: worker processes would need to pickle a class
    # defined in the notebook (fails under the "spawn" start method) and cannot
    # use a CUDA model after a fork.
    dataloader_num_workers=0,
)

# Accuracy metric used to score evaluation predictions
metric = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Turn an evaluation (logits, labels) pair into the accuracy metric dict."""
    scores, references = eval_pred
    # Highest-scoring class per row
    predicted_ids = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

# Assemble the Trainer from the pieces defined above
trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": train_dataset,
    "eval_dataset": eval_dataset,
    "compute_metrics": compute_metrics,
}
trainer = Trainer(**trainer_kwargs)

# Destination for the final classifier weights
model_save_path = os.path.join(root_dir, "dinov2_register_linear_classifier_final")

# Run the training loop
print("Starting training...")
trainer.train()

# Persist the trained model
trainer.save_model(model_save_path)
print(f"Model saved to {model_save_path}")

# Score the model on the evaluation split
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

# Example of how to use the model for inference
def predict_image(image_path):
    """Classify a single image with the backbone and the linear classifier.

    Uses the notebook-level ``processor``, ``dinov2``, ``device`` and ``model``.

    Args:
        image_path: Path to the image file to classify.

    Returns:
        Tuple ``(predicted_class, confidence)`` where ``predicted_class`` is a
        string of the form "stage_<id>" and ``confidence`` is the softmax
        probability of that class as a float.
    """
    image = Image.open(image_path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt")
    pixel_values = inputs["pixel_values"].to(device)

    with torch.no_grad():
        # The torch.hub backbone takes the image tensor positionally and
        # returns the CLS embedding directly; it has no `pixel_values` keyword
        # and no `.last_hidden_state` attribute.
        backbone_out = dinov2(pixel_values)
        if hasattr(backbone_out, "last_hidden_state"):
            # Hugging Face style output: take the first ([CLS]) token
            backbone_out = backbone_out.last_hidden_state[:, 0]
        features = backbone_out.squeeze(0)

        # The classifier may live on a different device than the features
        # (e.g. after training moved it), so follow the classifier's device.
        classifier_device = next(model.parameters()).device
        outputs = model(features.to(classifier_device))

    logits = outputs["logits"]
    probabilities = torch.softmax(logits, dim=-1)
    predicted_class_id = logits.argmax(-1).item()
    predicted_class = f"stage_{predicted_class_id}"

    return predicted_class, probabilities[predicted_class_id].item()

# Test on a sample image
sample_stage = "stage_0"
sample_dir = os.path.join(image_dir, sample_stage)

# Only consider .jpg files and sort them: a bare os.listdir(...)[0] can return
# a non-image entry (for example a hidden system file) and depends on the
# arbitrary directory order.
sample_candidates = sorted(
    name for name in os.listdir(sample_dir) if name.endswith(".jpg")
)
if not sample_candidates:
    raise FileNotFoundError(f"No .jpg images found in {sample_dir}")

sample_image = sample_candidates[0]
sample_path = os.path.join(sample_dir, sample_image)

predicted_class, confidence = predict_image(sample_path)
print(f"Sample image: {sample_path}")
print(f"Predicted class: {predicted_class}, confidence: {confidence:.4f}")

Loading DINOv2 ViT-B/14 with registers...


Using cache found in /Users/georgye/.cache/torch/hub/facebookresearch_dinov2_main


Using device: cpu
Preparing datasets...


ValueError: --load_best_model_at_end requires the save and eval strategy to match, but found
- Evaluation strategy: IntervalStrategy.NO
- Save strategy: SaveStrategy.EPOCH