In [17]:
import json
import os
import traceback
from dataclasses import dataclass

import torch
import torch.nn as nn
from PIL import Image
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig

@dataclass
class ModelConfig:
    """Settings shared by the fine-tuning and inference helpers.

    Every field has a default, so ``ModelConfig()`` is a complete
    configuration; individual values can be overridden by keyword.
    """

    # --- Filesystem locations -------------------------------------------
    # Root folder that is scanned for one sub-folder of images per class.
    input_dir: str = "images_for_finetuning"
    # Folder the fine-tuned model, processor and label map are written to.
    output_dir: str = "classifier_finetuned"

    # --- Pretrained checkpoint ------------------------------------------
    # Identifier handed to the transformers ``from_pretrained`` loaders.
    model_name: str = "microsoft/resnet-50"

    # --- Optimisation ---------------------------------------------------
    train_batch_size: int = 8
    # NOTE(review): not read by any code visible in this notebook section.
    eval_batch_size: int = 4
    num_epochs: int = 5
    learning_rate: float = 2e-5

    # --- Input geometry -------------------------------------------------
    # Forwarded to the dataset constructor.
    image_size: int = 224

    # --- Hardware -------------------------------------------------------
    # Evaluated once, when this class body runs (not per instance).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

In [18]:
class AmazonDataset(Dataset):
    """Image-folder dataset: each sub-directory of ``root_dir`` is one class.

    The directory name is the class label. Labels are mapped to integer ids
    in sorted order (``label2id`` / ``id2label``). Every file whose name ends
    with one of ``IMAGE_EXTENSIONS`` (case-insensitive) becomes a sample.

    Args:
        root_dir: Folder containing one sub-folder per class.
        processor: Callable invoked as ``processor(image, return_tensors="pt")``
            whose result is indexed with ``'pixel_values'``.
        image_size: Stored on the instance as ``image_size``; it is not
            otherwise used by this class.

    Raises:
        FileNotFoundError: If ``root_dir`` does not exist.
        ValueError: If ``root_dir`` has no class directories, or no images
            were found in any of them.
    """

    # File-name suffixes (compared lower-cased) that count as images.
    IMAGE_EXTENSIONS = ('.jpeg', '.jpg', '.png', '.bmp', '.tiff')

    def __init__(self, root_dir, processor, image_size=224):
        self.root_dir = root_dir
        self.processor = processor
        self.image_size = image_size
        self.samples = []

        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Directory '{root_dir}' not found. Please create the directory and add your image data.")

        # Hidden directories (e.g. Jupyter's ``.ipynb_checkpoints``) are tool
        # artefacts, not classes; counting them would add phantom labels and
        # inflate the classifier head. Sorting makes label ids and sample
        # order independent of the filesystem's listing order.
        class_dirs = sorted(
            d for d in os.listdir(root_dir)
            if not d.startswith('.') and os.path.isdir(os.path.join(root_dir, d))
        )
        if not class_dirs:
            raise ValueError(f"No class directories found in '{root_dir}'")

        # Count images per class while scanning instead of re-walking
        # ``self.samples`` once per class afterwards.
        images_per_class = {}
        for label in class_dirs:
            class_dir = os.path.join(root_dir, label)
            image_files = sorted(
                name for name in os.listdir(class_dir)
                if name.lower().endswith(self.IMAGE_EXTENSIONS)
            )
            images_per_class[label] = len(image_files)
            self.samples.extend((os.path.join(class_dir, name), label) for name in image_files)

        if not self.samples:
            raise ValueError(f"No images found in '{root_dir}'. Please check your directory structure.")

        # ``class_dirs`` is already sorted, so ids follow sorted label order.
        self.label2id = {label: idx for idx, label in enumerate(class_dirs)}
        self.id2label = {idx: label for label, idx in self.label2id.items()}

        print(f"Found {len(self.samples)} images across {len(self.label2id)} classes:")
        for label, idx in self.label2id.items():
            print(f"  {label}: {images_per_class[label]} images (id: {idx})")

    def __len__(self):
        """Return the number of image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``{'pixel_values', 'labels'}``.

        ``pixel_values`` is the processor output with its leading batch
        dimension squeezed away; ``labels`` is a scalar ``torch.long`` tensor
        holding the class id. Any failure is logged with the offending path
        and re-raised.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager closes the underlying file promptly;
            # ``convert`` returns a separate in-memory image that stays valid.
            with Image.open(img_path) as img:
                image = img.convert("RGB")

            inputs = self.processor(image, return_tensors="pt")

            return {
                'pixel_values': inputs['pixel_values'].squeeze(0),  # Remove batch dimension
                'labels': torch.tensor(self.label2id[label], dtype=torch.long)
            }

        except Exception as e:
            print(f"Error processing image {img_path}: {e}")
            raise

In [19]:
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Args:
        batch: Sequence of dicts, each holding a ``'pixel_values'`` tensor
            and a ``'labels'`` tensor.

    Returns:
        Dict with the same two keys, where each value is the samples'
        tensors stacked along a new leading dimension.
    """
    collated = {}
    # Key order matters only for readability; it mirrors the sample dicts.
    for key in ('pixel_values', 'labels'):
        collated[key] = torch.stack([sample[key] for sample in batch])
    return collated

In [20]:


def _train_one_epoch(model, dataloader, optimizer, device, epoch, num_epochs):
    """Run one optimisation pass over ``dataloader``.

    A batch whose forward/backward/step raises is logged and skipped. Loss
    and accuracy are accumulated only for batches that completed the whole
    step, so the returned averages are not diluted by skipped batches.

    Args:
        model: Classifier called as ``model(pixel_values=..., labels=...)``;
            its output must expose ``.loss`` and ``.logits``.
        dataloader: Yields dicts with ``"pixel_values"`` and ``"labels"``.
        optimizer: Optimizer over the model's parameters.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (used for progress messages).
        num_epochs: Total number of epochs (used for progress messages).

    Returns:
        Tuple ``(avg_loss, accuracy_percent, skipped_batches)``.

    Raises:
        RuntimeError: If not a single batch completed successfully.
    """
    num_batches = len(dataloader)
    total_loss = 0.0
    correct_predictions = 0
    total_predictions = 0
    completed_batches = 0

    for batch_idx, batch in enumerate(dataloader):
        try:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            predictions = torch.argmax(outputs.logits, dim=-1)
            batch_correct = (predictions == labels).sum().item()
        except Exception as e:
            print(f"Error in batch {batch_idx}: {e}")
            continue

        # Only reached when the full step succeeded, which keeps the loss
        # and accuracy counters consistent with each other.
        completed_batches += 1
        total_loss += batch_loss
        correct_predictions += batch_correct
        total_predictions += labels.size(0)

        if batch_idx % 5 == 0:
            accuracy = correct_predictions / total_predictions * 100
            print(f"Epoch [{epoch+1}/{num_epochs}], "
                  f"Batch [{batch_idx}/{num_batches}], "
                  f"Loss: {batch_loss:.4f}, "
                  f"Accuracy: {accuracy:.2f}%")

    if completed_batches == 0 or total_predictions == 0:
        # Previously this surfaced as a bare ZeroDivisionError in the summary.
        raise RuntimeError(
            f"No batch completed successfully in epoch {epoch+1}; aborting training."
        )

    avg_loss = total_loss / completed_batches
    epoch_accuracy = correct_predictions / total_predictions * 100
    return avg_loss, epoch_accuracy, num_batches - completed_batches


def _save_artifacts(model, processor, dataset, output_dir):
    """Write the model, processor and label mappings to ``output_dir`` and report the files."""
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    processor.save_pretrained(output_dir)

    # JSON object keys are strings, so the integer ids of ``id2label`` are
    # stored as strings and must be converted back with ``int`` on load.
    with open(os.path.join(output_dir, 'label_mappings.json'), 'w') as f:
        json.dump({
            'label2id': dataset.label2id,
            'id2label': dataset.id2label
        }, f, indent=2)

    # The weights file name depends on the transformers version/settings
    # (safetensors vs. torch pickle), so report the one actually on disk.
    weights_name = "pytorch_model.bin"
    for candidate in ("model.safetensors", "pytorch_model.bin"):
        if os.path.exists(os.path.join(output_dir, candidate)):
            weights_name = candidate
            break

    print(f"\nModel saved successfully at {output_dir}")
    print("Files saved:")
    print(f"  - Model weights: {output_dir}/{weights_name}")
    print(f"  - Model config: {output_dir}/config.json")
    print(f"  - Processor config: {output_dir}/preprocessor_config.json")
    print(f"  - Label mappings: {output_dir}/label_mappings.json")


def train():
    """Fine-tune the configured checkpoint on the image folder and save it.

    Builds the dataset from ``ModelConfig().input_dir``, swaps the
    classification head to match the number of classes found, trains for
    ``num_epochs`` with AdamW, and writes the model, processor and
    ``label_mappings.json`` to ``output_dir``.

    Errors are reported on stdout instead of being propagated: a missing
    input directory prints a short message, anything else prints the message
    plus a traceback. The function always returns ``None``.
    """
    cfg = ModelConfig()
    print(f"Using device: {cfg.device}")

    try:
        # One processor instance serves both the dataset and the final save.
        processor = AutoImageProcessor.from_pretrained(cfg.model_name)
        dataset = AmazonDataset(cfg.input_dir, processor, cfg.image_size)
        num_labels = len(dataset.label2id)

        config = AutoConfig.from_pretrained(cfg.model_name)
        config.num_labels = num_labels
        config.id2label = dataset.id2label
        config.label2id = dataset.label2id

        # The checkpoint's head has a different number of outputs, so the
        # mismatched classifier weights are re-initialised rather than loaded.
        model = AutoModelForImageClassification.from_pretrained(
            cfg.model_name,
            config=config,
            ignore_mismatched_sizes=True
        )
        model.to(cfg.device)

        dataloader = DataLoader(
            dataset,
            batch_size=cfg.train_batch_size,
            shuffle=True,
            collate_fn=collate_fn,
            num_workers=0
        )

        optimizer = AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=0.01)

        model.train()

        print(f"Starting training for {cfg.num_epochs} epochs...")
        print(f"Total batches per epoch: {len(dataloader)}")

        for epoch in range(cfg.num_epochs):
            avg_loss, epoch_accuracy, skipped = _train_one_epoch(
                model, dataloader, optimizer, cfg.device, epoch, cfg.num_epochs
            )
            print(f"\nEpoch [{epoch+1}/{cfg.num_epochs}] Summary:")
            print(f"  Average Loss: {avg_loss:.4f}")
            print(f"  Accuracy: {epoch_accuracy:.2f}%")
            if skipped:
                print(f"  Skipped batches: {skipped}")
            print("-" * 50)

        _save_artifacts(model, processor, dataset, cfg.output_dir)

    except FileNotFoundError as e:
        print(f"Error: {e}")

    except Exception as e:
        print(f"Unexpected error: {e}")
        traceback.print_exc()

def test_model():
    """Reload the fine-tuned classifier and hand back a prediction function.

    The model and processor are read from ``ModelConfig().output_dir``, and
    the id-to-label table from ``label_mappings.json`` in the same folder
    (its string keys are converted back to ``int``).

    Returns:
        A function ``predict_image(image_path)`` that returns
        ``(label, confidence)`` for one image, or ``None`` if anything went
        wrong while loading (the error is printed).
    """
    settings = ModelConfig()

    try:
        model = AutoModelForImageClassification.from_pretrained(settings.output_dir)
        processor = AutoImageProcessor.from_pretrained(settings.output_dir)

        import json
        mappings_path = os.path.join(settings.output_dir, 'label_mappings.json')
        with open(mappings_path, 'r') as handle:
            stored = json.load(handle)
            id2label = {}
            for raw_id, name in stored['id2label'].items():
                id2label[int(raw_id)] = name

        model.eval()
        model.to(settings.device)

        print("Model loaded successfully!")
        print("Available classes:", list(id2label.values()))

        def predict_image(image_path):
            """Classify a single image file; returns ``(label, confidence)``."""
            rgb_image = Image.open(image_path).convert('RGB')
            model_inputs = processor(rgb_image, return_tensors="pt").to(settings.device)

            with torch.no_grad():
                logits = model(**model_inputs).logits
                probabilities = torch.nn.functional.softmax(logits, dim=-1)
                best_id = probabilities.argmax().item()
                best_probability = probabilities[0][best_id].item()

            return id2label[best_id], best_probability

        return predict_image

    except Exception as e:
        print(f"Error loading model: {e}")
        return None

# Entry point: start fine-tuning when this code runs as the main module.
# Executing this notebook cell satisfies the guard, so running the cell
# launches training (see the captured output below).
if __name__ == "__main__":
    train()

Using device: cpu
Found 5000 images across 50 classes:
  AirPods: 100 images (id: 0)
  Android_phone: 100 images (id: 1)
  CD_player: 100 images (id: 2)
  DVD_player: 100 images (id: 3)
  Lightning_cable: 100 images (id: 4)
  TV_remote_control: 100 images (id: 5)
  USB-C_cable: 100 images (id: 6)
  air_freshener_dispenser: 100 images (id: 7)
  air_fryer: 100 images (id: 8)
  blood_pressure_monitor: 100 images (id: 9)
  bluetooth_speaker: 100 images (id: 10)
  calculator: 100 images (id: 11)
  car_charger: 100 images (id: 12)
  computer_mouse: 100 images (id: 13)
  digital_thermometer: 100 images (id: 14)
  electric_fan: 100 images (id: 15)
  electric_kettle: 100 images (id: 16)
  electric_toothbrush: 100 images (id: 17)
  fitness_tracker: 100 images (id: 18)
  food_scale: 100 images (id: 19)
  hair_dryer: 100 images (id: 20)
  hand_mixer: 100 images (id: 21)
  headphones: 100 images (id: 22)
  iPhone: 100 images (id: 23)
  iPhone_charger: 100 images (id: 24)
  juicer: 100 images (id: 2

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([50, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training for 5 epochs...
Total batches per epoch: 625
Epoch [1/5], Batch [0/625], Loss: 3.9364, Accuracy: 0.00%
Epoch [1/5], Batch [5/625], Loss: 3.9532, Accuracy: 0.00%
Epoch [1/5], Batch [10/625], Loss: 3.9348, Accuracy: 1.14%
Epoch [1/5], Batch [15/625], Loss: 3.9608, Accuracy: 0.78%
Epoch [1/5], Batch [20/625], Loss: 3.9661, Accuracy: 1.19%
Epoch [1/5], Batch [25/625], Loss: 3.8751, Accuracy: 1.44%
Epoch [1/5], Batch [30/625], Loss: 3.9434, Accuracy: 1.61%
Epoch [1/5], Batch [35/625], Loss: 3.8970, Accuracy: 2.08%
Epoch [1/5], Batch [40/625], Loss: 3.9276, Accuracy: 1.83%
Epoch [1/5], Batch [45/625], Loss: 3.9337, Accuracy: 1.90%
Epoch [1/5], Batch [50/625], Loss: 3.8734, Accuracy: 1.72%
Epoch [1/5], Batch [55/625], Loss: 3.9360, Accuracy: 1.56%
Epoch [1/5], Batch [60/625], Loss: 3.9308, Accuracy: 1.43%
Epoch [1/5], Batch [65/625], Loss: 3.8939, Accuracy: 1.52%
Epoch [1/5], Batch [70/625], Loss: 3.9600, Accuracy: 1.41%
Epoch [1/5], Batch [75/625], Loss: 3.9189, Accuracy: 1