# **Fine-Tuning ViT with LoRA for Steel Scrap Inspection**
This notebook fine-tunes a Vision Transformer (ViT) model using LoRA (Low-Rank Adaptation) to classify scrap metal images into six categories based on their physical properties. The dataset is stored locally (15GB), and we’ll optimize training for limited hardware (16GB RAM, 8GB VRAM) with class weights to handle potential imbalances.

## **Steel Scrap Classes**
We’re classifying scrap metal into the following categories:

* E1: Old thin steel scrap (≤1.5x0.5x0.5 m, thickness <6 mm)
* E2: Thick new production steel scrap (≤1.5x0.5x0.5 m, thickness ≥3 mm)
* E3: Old thick steel scrap (≤1.5x0.5x0.5 m, thickness ≥6 mm)
* E6: Thin new production steel scrap, compressed or baled (thickness <3 mm)
* E8: Thin new production steel scrap (≤1.5x0.5x0.5 m, thickness <3 mm)
* EHRB: Old and new steel scrap, mainly rebars and merchant bars (max 1.5x0.5x0.5 m)

### **1. Import Libraries**

**Explanation:**

* Imports handle dataset management (datasets), model training (transformers), image processing (torchvision, PIL), LoRA adaptation (peft), and class weight computation (sklearn).

In [1]:
import os
import numpy as np
import torch
from PIL import Image
from datasets import Dataset, Image as HFImage
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from peft import LoraConfig, get_peft_model
import evaluate
from sklearn.utils.class_weight import compute_class_weight

### **2. Load the Dataset**

**Explanation:**

* Loads our 15GB dataset from D:\LLM\DOES\does_dataset_images.
* Each folder (e.g., E8) represents a scrap metal class. Images are paired with their file paths and labels.
* HFImage() marks the "image" column as image data for efficient processing.

In [2]:
# Define dataset path and class folders
output_dir = r"D:\LLM\DOES\does_dataset_images"
folders = ["E8", "E3", "E1", "E2", "E6", "EHRB"]

def create_dataset():
    data = []
    for folder in folders:
        folder_path = os.path.join(output_dir, folder)
        for img_file in os.listdir(folder_path):
            if os.path.isfile(os.path.join(folder_path, img_file)):
                data.append({"image": os.path.join(folder_path, img_file), "label": folder})
    return Dataset.from_list(data).cast_column("image", HFImage())

dataset = create_dataset()
print(f"Dataset size: {len(dataset)} examples")

Dataset size: 102399 examples


### **3. Define Label Mappings and Class Weights**

**Explanation:**

* **Mappings**: label2id and id2label convert between class names (e.g., E8) and IDs (0-5).
* **Class Meanings:** Defines what each class represents—useful for interpreting predictions.
* **Class Weights:** Uses 'balanced' mode to give higher weights to underrepresented classes (e.g., if EHRB has fewer images than E8). This balances training focus, improving accuracy on rare classes.

In [3]:
# Label mappings
labels = folders
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# Class meanings for reference
class_meanings = {
    "E1": "Old thin steel scrap (≤1.5x0.5x0.5 m, thickness <6 mm)",
    "E2": "Thick new production steel scrap (≤1.5x0.5x0.5 m, thickness ≥3 mm)",
    "E3": "Old thick steel scrap (≤1.5x0.5x0.5 m, thickness ≥6 mm)",
    "E6": "Thin new production steel scrap, compressed or baled (thickness <3 mm)",
    "E8": "Thin new production steel scrap (≤1.5x0.5x0.5 m, thickness <3 mm)",
    "EHRB": "Old and new steel scrap, mainly rebars and merchant bars (max 1.5x0.5x0.5 m)"
}

# Compute class weights for imbalance
label_list = [example["label"] for example in dataset]
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(label_list),
    y=label_list
)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float)
print(f"Class weights: {class_weights}")

Class weights: [0.93586861 1.6772973  0.80297826 1.93038118 0.43210705 3.88139641]


### **4. Set Up Image Preprocessing**

**Explanation:**

* **Image Processor:** Ensures images match ViT’s expected format (224x224, normalized).
* **Transforms:**
    * **Training:** Adds randomness (RandomResizedCrop, RandomHorizontalFlip) to make the model generalize better.
    * **Validation:** Uses fixed resizing for consistent evaluation.
* **Split:** 20% validation set (increased from 10%) provides a bigger sample to test performance.

In [None]:
# Load image processor
model_checkpoint = "google/vit-base-patch16-224"
image_processor = AutoImageProcessor.from_pretrained(model_checkpoint)

# Define transforms
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
train_transforms = Compose([
    RandomResizedCrop(image_processor.size["height"]),  # Random crop
    RandomHorizontalFlip(),                            # Augmentation: random flip
    ToTensor(),                                        # Convert to tensor
    normalize                                          # Normalize for ViT
])
val_transforms = Compose([
    Resize(image_processor.size["height"]),            # Consistent resize
    CenterCrop(image_processor.size["height"]),        # Center crop
    ToTensor(),
    normalize
])

# Preprocessing functions
def preprocess_train(example_batch):
    images = [Image.open(img).convert("RGB") if isinstance(img, str) else img.convert("RGB") 
              for img in example_batch["image"]]
    example_batch["pixel_values"] = [train_transforms(img) for img in images]
    example_batch["labels"] = [label2id[label] for label in example_batch["label"]]
    return example_batch

def preprocess_val(example_batch):
    images = [Image.open(img).convert("RGB") if isinstance(img, str) else img.convert("RGB") 
              for img in example_batch["image"]]
    example_batch["pixel_values"] = [val_transforms(img) for img in images]
    example_batch["labels"] = [label2id[label] for label in example_batch["label"]]
    return example_batch

# Split dataset (80% train, 20% validation)
splits = dataset.train_test_split(test_size=0.2)
train_ds = splits["train"]
val_ds = splits["test"]
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.


### **5. Load Model with LoRA**

**Explanation:**

* **LoRA:** Adapts only a small part of ViT, keeping memory usage low (fits 8GB VRAM).
* **Hyperparameters:**
    * <span style="color: orange;">r=16:</span> Higher rank (from 8) increases capacity for learning scrap metal features.
    * <span style="color: orange;">lora_alpha=16:</span> Matches r for balanced adaptation strength.
    * <span style="color: orange;">lora_dropout=0.3:</span> Higher dropout (from 0.1) reduces overfitting on your dataset.
    * <span style="color: orange">target_modules:</span> Focuses on attention layers (query, value), critical for ViT’s performance.

In [5]:
# Load base ViT model
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

# Configure LoRA
config = LoraConfig(
    r=16,              # LoRA rank
    lora_alpha=16,     # Scaling factor
    target_modules=["query", "value"],  # ViT attention layers
    lora_dropout=0.3,  # Dropout rate
    bias="none",       # No bias adaptation
    modules_to_save=["classifier"],  # Save classifier weights
)
lora_model = get_peft_model(model, config)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([6]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([6, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### **6. Define Metrics and Collate Function**

**Explanation:**

* **Metrics:** Measures accuracy—straightforward for classification.
* **Collate:** Combines images and labels into batches for training.

In [6]:
# Load accuracy metric
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    predictions = np.argmax(eval_pred.predictions, axis=1)
    return {"accuracy": metric.compute(predictions=predictions, references=eval_pred.label_ids)["accuracy"]}

# Collate function for batching
def collate_fn(examples):
    pixel_values = torch.stack([example["pixel_values"] for example in examples])
    labels = torch.tensor([example["labels"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}

### **7. Custom Trainer with Class Weights**

**Explanation:**

* Customizes the Trainer to apply class weights in the loss function.
* **Why:** Our dataset has a significant class imbalance, meaning some classes have way more images than others. Without weights, the model would focus too much on frequent classes (e.g., *E8*) and ignore rare ones (e.g., *EHRB*), tanking accuracy on underrepresented categories. Class weights penalize mistakes on rare classes more heavily, balancing the model’s attention across all scrap metal types.

**Class Distribution**
Here’s the number of images per class in the dataset:

* **E8:** 39,496 (most common)
* **E3:** 21,254
* **E1:** 18,236
* **E2:** 10,175
* **E6:** 8,841
* **EHRB:** 4,397 (least common)

**Total Images:** 102,399

**Imbalance:** *E8* has ~9x more samples than *EHRB*, skewing the model toward *E8* without correction.

**Class Weights and Penalties**
We use *compute_class_weight('balanced')* to assign weights inversely proportional to class frequency. The formula is:
Weight for class \( i \):
$$
\text{Weight for class } i = \frac{\text{total samples}}{\text{number of classes} \cdot \text{samples in class } i}
$$

**Example Calculation:**
* Total samples = 102,399  
* Number of classes = 6  

**Weights per Class:**

* **E8**: `0.432` – (Lowest penalty, most common)  
* **E3**: `0.803`  
* **E1**: `0.936`  
* **E2**: `1.677`  
* **E6**: `1.931`  
* **EHRB**: `3.879` – (Highest penalty, rarest) 

In [7]:
class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")
        loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights_tensor.to(logits.device))
        loss = loss_fct(logits, labels)
        return (loss, outputs) if return_outputs else loss

### **8. Set Training Arguments**

**Explanation:**

* Hyperparameters:
    * <span style='color: orange;'>learning_rate=2e-4:</span> Reduced (from 5e-3) for precise updates, avoiding overshooting.
    * <span style='color: orange;'>batch_size=8, gradient_accumulation_steps=4:</span> Effective batch size of 32 fits 8GB VRAM with FP16.
    * <span style='color: orange;'>num_train_epochs=5:</span> More epochs (from 3) for better convergence, with early stopping via load_best_model_at_end.
    * <span style='color: orange;'>weight_decay=0.01:</span> Adds regularization to prevent overfitting.
    * <span style='color: orange'>warmup_ratio=0.1:</span> Gradually ramps up learning rate over 10% of steps for stability.
**Why:** Optimizes for your hardware and dataset, balancing speed and accuracy.

In [8]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 32

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-lora-metalscrap",
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-4,              # Lowered learning rate
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    fp16=True,                       # Mixed precision
    num_train_epochs=5,              # More epochs
    logging_steps=10,                # Frequent logging
    load_best_model_at_end=True,     # Keep best model
    metric_for_best_model="accuracy",
    push_to_hub=False,
    label_names=["labels"],
    weight_decay=0.01,               # Regularization
    warmup_ratio=0.1,                # Warmup period
)



### **9. Train the Model**

**What Happens:**

* Trains for 5 epochs, evaluating and saving after each.
* Applies class weights to balance loss across classes.
* Logs loss every 10 steps—expect it to decrease steadily.

In [9]:
# Initialize trainer
trainer = WeightedTrainer(
    model=lora_model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

# Train
trainer.train()

  trainer = WeightedTrainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,0.0936,0.049235,0.98584
2,0.0902,0.030784,0.98999
3,0.0404,0.019799,0.992822
4,0.0259,0.016324,0.994971
5,0.0159,0.014208,0.995654




TrainOutput(global_step=12800, training_loss=0.11891007877886295, metrics={'train_runtime': 6530.1723, 'train_samples_per_second': 62.723, 'train_steps_per_second': 1.96, 'total_flos': 3.1961371690685768e+19, 'train_loss': 0.11891007877886295, 'epoch': 5.0})

In [10]:
repo_name = f"iDharshan/{model_name}-SIViT"
lora_model.push_to_hub(repo_name)
image_processor.push_to_hub(repo_name)
print(f"Model pushed to: https://huggingface.co/{repo_name}")

adapter_model.safetensors:   0%|          | 0.00/2.38M [00:00<?, ?B/s]

README.md:   0%|          | 0.00/5.17k [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


Model pushed to: https://huggingface.co/iDharshan/vit-base-patch16-224-SIViT
