In [None]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor
from tqdm import tqdm

# Step 1: Configuration
# --- Model / optimisation hyper-parameters ---
PRETRAINED_MODEL_NAME = "google/vit-base-patch16-224-in21k"  # checkpoint id for model and processor
EPOCHS = 5  # number of train + evaluate passes
BATCH_SIZE = 8  # batch size used for both DataLoaders
LEARNING_RATE = 2e-5  # learning rate handed to the optimizer

# --- Data locations ---
BASE_DIR = "./dataset/images"  # relative image paths in the CSVs are joined onto this
TRAIN_CSV_PATH = "./train_labels.csv"  # CSV describing the training split
TEST_CSV_PATH = "./test_labels.csv"  # CSV describing the test split

# --- Compute device: CUDA when torch reports it as available, CPU otherwise ---
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Step 2: Initialize Model and Processor
def initialize_model_and_processor(pretrained_model_name, num_labels, device):
    """Load a pretrained ViT classifier together with its image processor.

    Args:
        pretrained_model_name: Checkpoint identifier handed to both
            ``from_pretrained`` calls.
        num_labels: Forwarded as ``num_labels`` to
            ``ViTForImageClassification.from_pretrained``.
        device: Target passed to ``model.to``.

    Returns:
        A ``(model, processor)`` tuple; the model is the result of
        ``.to(device)``.
    """
    classifier = ViTForImageClassification.from_pretrained(
        pretrained_model_name,
        num_labels=num_labels,
    )
    image_processor = ViTImageProcessor.from_pretrained(pretrained_model_name)
    return classifier.to(device), image_processor

# Step 3: Define Dataset Class
class HelmetDataset(Dataset):
    """Dataset of (processed image, label) pairs described by a CSV file.

    The CSV is loaded with ``pandas.read_csv`` and must contain an
    ``image_path`` column and a ``label`` column. Every path is resolved and
    checked at construction time, so a bad CSV fails early instead of in the
    middle of an epoch.

    Args:
        csv_file: Anything ``pandas.read_csv`` accepts.
        processor: Callable invoked as
            ``processor(images=image, return_tensors="pt")`` whose result is
            indexed with ``"pixel_values"``.
        base_dir: Optional directory that non-absolute paths are joined onto.

    Raises:
        ValueError: If at least one row's ``image_path`` does not resolve to
            an existing file.
    """

    def __init__(self, csv_file, processor, base_dir=None):
        self.data = pd.read_csv(csv_file)
        self.processor = processor
        self.base_dir = base_dir

        # Resolve every path once; validate_path() yields None for bad rows.
        self.data['image_path'] = self.data['image_path'].apply(self.validate_path)
        invalid_count = int(self.data['image_path'].isnull().sum())
        if invalid_count:
            # A single bad row is enough to fail, so report how many there are
            # rather than claiming that no path at all was valid.
            raise ValueError(
                f"{invalid_count} of {len(self.data)} image paths in the CSV "
                "do not point to an existing file."
            )

    def validate_path(self, image_path):
        """Return the resolved path of an existing file, or ``None``.

        Non-absolute paths are joined onto ``self.base_dir`` when it is set.
        Invalid entries are printed so the offending rows can be located.
        """
        # Empty CSV cells arrive as NaN (a float) and numeric-looking cells as
        # numbers; os.path would raise TypeError on them, so treat them as
        # invalid entries instead.
        if not isinstance(image_path, str):
            print(f"Invalid path: {image_path}")
            return None
        if self.base_dir and not os.path.isabs(image_path):
            image_path = os.path.join(self.base_dir, image_path)
        if not os.path.isfile(image_path):
            print(f"Invalid path: {image_path}")
            return None
        return image_path

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(pixel_values, label)`` for row ``idx``.

        ``pixel_values`` is the processor output with dimension 0 squeezed
        away (presumably the batch dimension the processor adds for a single
        image -- confirm against the processor in use). ``label`` is a 0-d
        ``torch.long`` tensor.
        """
        row = self.data.iloc[idx]

        # Image.open is lazy and keeps the file open; convert() loads the
        # pixels and returns a new image, so the handle can be closed at once.
        with Image.open(row['image_path']) as raw_image:
            image = raw_image.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt")
        # Pin the dtype so class indices are integer-typed regardless of how
        # pandas parsed the label column.
        label = torch.tensor(row['label'], dtype=torch.long)
        return inputs["pixel_values"].squeeze(0), label

# Step 4: Create DataLoaders
def create_dataloaders(train_csv, test_csv, processor, batch_size, base_dir):
    """Build the training and test ``DataLoader`` objects.

    Both CSVs are wrapped in a ``HelmetDataset`` sharing the same processor
    and base directory. Only the training loader is created with
    ``shuffle=True``.

    Returns:
        A ``(train_loader, test_loader)`` tuple.
    """
    # Both datasets are constructed before either loader, training split first.
    train_set, test_set = [
        HelmetDataset(csv_path, processor, base_dir)
        for csv_path in (train_csv, test_csv)
    ]
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set, batch_size=batch_size),
    )

# Step 5: Training Function
def train_model(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader``.

    The model is switched to training mode, and for every batch the gradients
    are zeroed, the loss from ``criterion(logits, labels)`` is back-propagated
    and the optimizer takes a step.

    Args:
        model: Called as ``model(pixel_values=...)``; its result must expose
            ``.logits``.
        train_loader: Iterable of ``(pixel_values, labels)`` batches that also
            supports ``len()``.
        optimizer: Provides ``zero_grad()`` and ``step()``.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Target that each batch is moved to.

    Returns:
        The sum of per-batch loss values divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0

    progress = tqdm(train_loader, desc="Training", ncols=100)
    for batch_pixels, batch_labels in progress:
        batch_pixels = batch_pixels.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(pixel_values=batch_pixels).logits
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    batch_count = len(train_loader)
    return running_loss / batch_count

# Step 6: Evaluation Function
def evaluate_model(model, test_loader, criterion, device):
    """Compute the mean batch loss and accuracy over ``test_loader``.

    The model is put in evaluation mode and the whole loop runs under
    ``torch.no_grad()``.

    Args:
        model: Called as ``model(pixel_values=...)``; its result must expose
            ``.logits``.
        test_loader: Iterable of ``(pixel_values, labels)`` batches that also
            supports ``len()``.
        criterion: Loss callable taking ``(logits, labels)``.
        device: Target that each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)``: the summed per-batch loss divided by
        ``len(test_loader)``, and the fraction of samples whose arg-max
        prediction equals the label.
    """
    model.eval()
    loss_sum = 0
    hits = 0
    seen = 0

    with torch.no_grad():
        progress = tqdm(test_loader, desc="Evaluating", ncols=100)
        for batch_pixels, batch_labels in progress:
            batch_pixels = batch_pixels.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(pixel_values=batch_pixels).logits
            loss_sum += criterion(logits, batch_labels).item()

            # The predicted class is the index of the largest logit.
            predicted = logits.argmax(dim=-1)
            hits += predicted.eq(batch_labels).sum().item()
            seen += batch_labels.size(0)

    accuracy = hits / seen
    mean_loss = loss_sum / len(test_loader)
    return mean_loss, accuracy

# Step 7: Inference Function
def predict(model, processor, image_path, device):
    """Classify a single image file and return the predicted class index.

    Note that the model is left in evaluation mode afterwards.

    Args:
        model: Called as ``model(pixel_values=...)``; its result must expose
            ``.logits``.
        processor: Called as ``processor(images=image, return_tensors="pt")``;
            its result must support ``.to(device)`` and ``["pixel_values"]``.
        image_path: Path of the image to classify.
        device: Target that the processed inputs are moved to.

    Returns:
        The arg-max over the last logits dimension, as a Python number.

    Raises:
        FileNotFoundError: If ``image_path`` is not an existing file.
    """
    if not os.path.isfile(image_path):
        raise FileNotFoundError(f"Image file not found: {image_path}")

    model.eval()

    # Image.open is lazy and keeps the file open; convert() loads the pixels
    # and returns a new image, so the handle can be released immediately.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(pixel_values=inputs['pixel_values'])
    return torch.argmax(outputs.logits, dim=-1).item()

# Main Execution
if __name__ == "__main__":
    # Load the pretrained backbone with a two-class head, plus its processor.
    model, processor = initialize_model_and_processor(
        PRETRAINED_MODEL_NAME,
        num_labels=2,
        device=DEVICE,
    )

    # Wrap the train / test CSVs in DataLoaders.
    train_loader, test_loader = create_dataloaders(
        TRAIN_CSV_PATH,
        TEST_CSV_PATH,
        processor,
        BATCH_SIZE,
        BASE_DIR,
    )

    # Loss function and optimizer used for fine-tuning.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    # Each epoch is one training pass followed by an evaluation on the test split.
    for epoch in range(EPOCHS):
        print(f"Epoch {epoch + 1}/{EPOCHS}")

        train_loss = train_model(model, train_loader, optimizer, criterion, DEVICE)
        print(f"Training Loss: {train_loss:.4f}")

        test_loss, test_accuracy = evaluate_model(
            model, test_loader, criterion, DEVICE
        )
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy * 100:.2f}%")

    # Persist model and processor side by side in the same directory.
    output_dir = "./helmet_detection_model"
    os.makedirs(output_dir, exist_ok=True)
    for artefact in (model, processor):
        artefact.save_pretrained(output_dir)
    print("Model and processor saved successfully.")

    # Smoke-test inference on one image; predict() raises FileNotFoundError
    # if this placeholder path does not exist, so point it at a real file.
    test_image_path = "./dataset/test/images/sample.jpg"  # Replace with a valid test image path
    predicted_class = predict(model, processor, test_image_path, DEVICE)
    print(f"Predicted class for the test image: {predicted_class}")
