In [9]:
from dataclasses import dataclass
import os
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig
from torch.optim import AdamW
import torchvision.transforms as transforms
import json
import numpy as np
from typing import Tuple, Dict, List
import random

@dataclass
class ModelConfig:
    input_dir: str = "images_for_finetuning"
    output_dir: str = "classifier_finetuned"
    model_name: str = "microsoft/resnet-50"  
    train_batch_size: int = 8
    eval_batch_size: int = 4
    num_epochs: int = 20
    learning_rate: float = 1e-4
    image_size: int = 224
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    train_ratio: float = 0.7
    val_ratio: float = 0.2
    test_ratio: float = 0.1
    use_augmentation: bool = True
    rotation_degrees: int = 15
    brightness: float = 0.2
    contrast: float = 0.2
    saturation: float = 0.2
    hue: float = 0.1

In [10]:
class AmazonDataset(Dataset):
    def __init__(self, root_dir, processor, image_size=224, split='train', augment=True):
        self.root_dir = root_dir
        self.processor = processor
        self.image_size = image_size
        self.split = split
        self.augment = augment and (split == 'train')
        self.samples = []

        if not os.path.exists(root_dir):
            raise FileNotFoundError(f"Directory '{root_dir}' not found. Please create the directory and add your image data.")

        class_dirs = [d for d in os.listdir(root_dir) if os.path.isdir(os.path.join(root_dir, d))]
        if not class_dirs:
            raise ValueError(f"No class directories found in '{root_dir}'")

        all_samples = []
        for label in class_dirs:
            class_dir = os.path.join(root_dir, label)
            for img_file in os.listdir(class_dir):
                if img_file.lower().endswith(('.jpeg', '.jpg', '.png', '.bmp', '.tiff')):
                    all_samples.append((os.path.join(class_dir, img_file), label))

        if not all_samples:
            raise ValueError(f"No images found in '{root_dir}'. Please check your directory structure.")

        self.label2id = {label: idx for idx, label in enumerate(sorted(class_dirs))}
        self.id2label = {v: k for k, v in self.label2id.items()}
        
        self.samples = self._split_data(all_samples, split)
        
        if self.augment:
            self.augmentation = transforms.Compose([
                transforms.RandomRotation(degrees=15),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.2,
                    hue=0.1
                ),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomResizedCrop(size=(image_size, image_size), scale=(0.8, 1.0)),
            ])
        else:
            self.augmentation = transforms.Compose([
                transforms.Resize((image_size, image_size)),
            ])
        
        print(f"Split '{split}': Found {len(self.samples)} images across {len(self.label2id)} classes")
        if split == 'train':  
            for label, idx in self.label2id.items():
                count = sum(1 for _, l in self.samples if l == label)
                print(f"  {label}: {count} images (id: {idx})")

    def _split_data(self, all_samples: List[Tuple[str, str]], split: str) -> List[Tuple[str, str]]:
        class_samples = {}
        for sample in all_samples:
            label = sample[1]
            if label not in class_samples:
                class_samples[label] = []
            class_samples[label].append(sample)
        
        split_samples = []
        for label, samples in class_samples.items():
            random.shuffle(samples)
            
            n_samples = len(samples)
            train_end = int(n_samples * 0.7)  
            val_end = int(n_samples * 0.9)    
            
            if split == 'train':
                split_samples.extend(samples[:train_end])
            elif split == 'val':
                split_samples.extend(samples[train_end:val_end])
            elif split == 'test':
                split_samples.extend(samples[val_end:])
            else:
                raise ValueError(f"Unknown split: {split}")
        
        return split_samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        try:
            image = Image.open(img_path).convert("RGB")
            
            if self.augment:
                image = self.augmentation(image)
            else:
                image = transforms.Resize((self.image_size, self.image_size))(image)
            
            inputs = self.processor(image, return_tensors="pt")
            
            return {
                'pixel_values': inputs['pixel_values'].squeeze(0), 
                'labels': torch.tensor(self.label2id[label], dtype=torch.long)
            }
            
        except Exception as e:
            print(f"Error processing image {img_path}: {e}")
            raise

def collate_fn(batch):
    pixel_values = torch.stack([item['pixel_values'] for item in batch])
    labels = torch.stack([item['labels'] for item in batch])
    
    return {
        'pixel_values': pixel_values,
        'labels': labels
    }

In [11]:
def evaluate_model(model, dataloader, device, split_name="Validation"):
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_predictions = 0
    
    with torch.no_grad():
        for batch in dataloader:
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss
            logits = outputs.logits

            predictions = torch.argmax(logits, dim=-1)
            correct_predictions += (predictions == labels).sum().item()
            total_predictions += labels.size(0)
            total_loss += loss.item()

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_predictions * 100
    
    print(f"{split_name} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")
    return avg_loss, accuracy

In [12]:
def train():
    cfg = ModelConfig()
    print(f"Using device: {cfg.device}")
    
    try:
        temp_processor = AutoImageProcessor.from_pretrained(cfg.model_name)
        
        print("Creating datasets...")
        train_dataset = AmazonDataset(cfg.input_dir, temp_processor, cfg.image_size, 'train', cfg.use_augmentation)
        val_dataset = AmazonDataset(cfg.input_dir, temp_processor, cfg.image_size, 'val', False)
        test_dataset = AmazonDataset(cfg.input_dir, temp_processor, cfg.image_size, 'test', False)
        
        num_labels = len(train_dataset.label2id)
        
        config = AutoConfig.from_pretrained(cfg.model_name)
        config.num_labels = num_labels
        config.id2label = train_dataset.id2label
        config.label2id = train_dataset.label2id
        
        model = AutoModelForImageClassification.from_pretrained(
            cfg.model_name, 
            config=config,
            ignore_mismatched_sizes=True 
        )
        
        processor = AutoImageProcessor.from_pretrained(cfg.model_name)
        model.to(cfg.device)
        
        train_dataloader = DataLoader(
            train_dataset, 
            batch_size=cfg.train_batch_size, 
            shuffle=True, 
            collate_fn=collate_fn,
            num_workers=0
        )
        
        val_dataloader = DataLoader(
            val_dataset, 
            batch_size=cfg.eval_batch_size, 
            shuffle=False, 
            collate_fn=collate_fn,
            num_workers=0
        )
        
        test_dataloader = DataLoader(
            test_dataset, 
            batch_size=cfg.eval_batch_size, 
            shuffle=False, 
            collate_fn=collate_fn,
            num_workers=0
        )

        optimizer = AdamW(model.parameters(), lr=cfg.learning_rate, weight_decay=0.01)
        
        print(f"Starting training for {cfg.num_epochs} epochs...")
        print(f"Train batches: {len(train_dataloader)}, Val batches: {len(val_dataloader)}, Test batches: {len(test_dataloader)}")
        
        best_val_accuracy = 0
        best_model_state = None
        
        for epoch in range(cfg.num_epochs):
            model.train()
            total_loss = 0
            correct_predictions = 0
            total_predictions = 0
            
            for batch_idx, batch in enumerate(train_dataloader):
                try:
                    pixel_values = batch["pixel_values"].to(cfg.device)
                    labels = batch["labels"].to(cfg.device)

                    outputs = model(pixel_values=pixel_values, labels=labels)
                    loss = outputs.loss
                    logits = outputs.logits

                    predictions = torch.argmax(logits, dim=-1)
                    correct_predictions += (predictions == labels).sum().item()
                    total_predictions += labels.size(0)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    total_loss += loss.item()
                    
                    if batch_idx % 50 == 0:  # Print less frequently
                        accuracy = correct_predictions / total_predictions * 100
                        print(f"Epoch [{epoch+1}/{cfg.num_epochs}], "
                              f"Batch [{batch_idx}/{len(train_dataloader)}], "
                              f"Loss: {loss.item():.4f}, "
                              f"Accuracy: {accuracy:.2f}%")
                        
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {e}")
                    continue

            avg_train_loss = total_loss / len(train_dataloader)
            train_accuracy = correct_predictions / total_predictions * 100
            
            val_loss, val_accuracy = evaluate_model(model, val_dataloader, cfg.device, "Validation")
            
            print(f"\nEpoch [{epoch+1}/{cfg.num_epochs}] Summary:")
            print(f"  Train - Loss: {avg_train_loss:.4f}, Accuracy: {train_accuracy:.2f}%")
            print(f"  Val   - Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.2f}%")
            
            if val_accuracy > best_val_accuracy:
                best_val_accuracy = val_accuracy
                best_model_state = model.state_dict().copy()
                print(f"  New best validation accuracy: {best_val_accuracy:.2f}%")
            
            print("-" * 60)

        if best_model_state is not None:
            model.load_state_dict(best_model_state)
        
        test_loss , test_accuracy = evaluate_model(model, test_dataloader, cfg.device, "Test")
        
        os.makedirs(cfg.output_dir, exist_ok=True)
        model.save_pretrained(cfg.output_dir)
        processor.save_pretrained(cfg.output_dir)
        
        metadata = {
            'label2id': train_dataset.label2id,
            'id2label': train_dataset.id2label,
            'best_val_accuracy': best_val_accuracy,
            'test_accuracy': test_accuracy,
            'num_classes': num_labels,
            'train_samples': len(train_dataset),
            'val_samples': len(val_dataset),
            'test_samples': len(test_dataset),
            'model_config': {
                'model_name': cfg.model_name,
                'image_size': cfg.image_size,
                'num_epochs': cfg.num_epochs,
                'learning_rate': cfg.learning_rate,
                'batch_size': cfg.train_batch_size,
                'augmentation': cfg.use_augmentation
            }
        }
        
        with open(os.path.join(cfg.output_dir, 'training_metadata.json'), 'w') as f:
            json.dump(metadata, f, indent=2)
        
        print(f"\n{'='*60}")
        print("TRAINING COMPLETE")
        print(f"{'='*60}")
        print(f"Best Validation Accuracy: {best_val_accuracy:.2f}%")
        print(f"Test Accuracy: {test_accuracy:.2f}%")
        print(f"Model saved at: {cfg.output_dir}")
        print("Files saved:")
        print(f"  - Model weights: {cfg.output_dir}/pytorch_model.bin")
        print(f"  - Model config: {cfg.output_dir}/config.json")
        print(f"  - Processor config: {cfg.output_dir}/preprocessor_config.json")
        print(f"  - Training metadata: {cfg.output_dir}/training_metadata.json")

    except FileNotFoundError as e:
        print(f"Error: {e}")
        
        
    except Exception as e:
        print(f"Unexpected error: {e}")
        import traceback
        traceback.print_exc()

def test_model():
    cfg = ModelConfig()
    
    try:
        model = AutoModelForImageClassification.from_pretrained(cfg.output_dir)
        processor = AutoImageProcessor.from_pretrained(cfg.output_dir)
        
        with open(os.path.join(cfg.output_dir, 'training_metadata.json'), 'r') as f:
            metadata = json.load(f)
            id2label = {int(k): v for k, v in metadata['id2label'].items()}
        
        model.eval()
        model.to(cfg.device)
        
        print("Model loaded successfully!")
        print(f"Test Accuracy: {metadata['test_accuracy']:.2f}%")
        print("Available classes:", list(id2label.values()))
        
        def predict_image(image_path):
            image = Image.open(image_path).convert('RGB')
            inputs = processor(image, return_tensors="pt").to(cfg.device)
            
            with torch.no_grad():
                outputs = model(**inputs)
                predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
                predicted_class_id = predictions.argmax().item()
                confidence = predictions[0][predicted_class_id].item()
                
            return id2label[predicted_class_id], confidence
        
        return predict_image
        
    except Exception as e:
        print(f"Error loading model: {e}")
        return None

if __name__ == "__main__":
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(42)
    
    train()

Using device: cuda
Creating datasets...
Split 'train': Found 3500 images across 50 classes
  AirPods: 70 images (id: 0)
  Android_phone: 70 images (id: 1)
  CD_player: 70 images (id: 2)
  DVD_player: 70 images (id: 3)
  Lightning_cable: 70 images (id: 4)
  TV_remote_control: 70 images (id: 5)
  USB-C_cable: 70 images (id: 6)
  air_freshener_dispenser: 70 images (id: 7)
  air_fryer: 70 images (id: 8)
  blood_pressure_monitor: 70 images (id: 9)
  bluetooth_speaker: 70 images (id: 10)
  calculator: 70 images (id: 11)
  car_charger: 70 images (id: 12)
  computer_mouse: 70 images (id: 13)
  digital_thermometer: 70 images (id: 14)
  electric_fan: 70 images (id: 15)
  electric_kettle: 70 images (id: 16)
  electric_toothbrush: 70 images (id: 17)
  fitness_tracker: 70 images (id: 18)
  food_scale: 70 images (id: 19)
  hair_dryer: 70 images (id: 20)
  hand_mixer: 70 images (id: 21)
  headphones: 70 images (id: 22)
  iPhone: 70 images (id: 23)
  iPhone_charger: 70 images (id: 24)
  juicer: 70 ima

Some weights of ResNetForImageClassification were not initialized from the model checkpoint at microsoft/resnet-50 and are newly initialized because the shapes did not match:
- classifier.1.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.1.weight: found shape torch.Size([1000, 2048]) in the checkpoint and torch.Size([50, 2048]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Starting training for 20 epochs...
Train batches: 438, Val batches: 250, Test batches: 125
Epoch [1/20], Batch [0/438], Loss: 3.9145, Accuracy: 0.00%
Epoch [1/20], Batch [50/438], Loss: 3.8883, Accuracy: 3.43%
Epoch [1/20], Batch [100/438], Loss: 3.8463, Accuracy: 3.09%
Epoch [1/20], Batch [150/438], Loss: 3.8892, Accuracy: 4.22%
Epoch [1/20], Batch [200/438], Loss: 3.7188, Accuracy: 5.29%
Epoch [1/20], Batch [250/438], Loss: 3.7810, Accuracy: 7.42%
Epoch [1/20], Batch [300/438], Loss: 3.4793, Accuracy: 9.93%
Epoch [1/20], Batch [350/438], Loss: 3.2624, Accuracy: 12.29%
Epoch [1/20], Batch [400/438], Loss: 3.3360, Accuracy: 14.84%
Validation - Loss: 2.9630, Accuracy: 52.00%

Epoch [1/20] Summary:
  Train - Loss: 3.6675, Accuracy: 17.40%
  Val   - Loss: 2.9630, Accuracy: 52.00%
  New best validation accuracy: 52.00%
------------------------------------------------------------
Epoch [2/20], Batch [0/438], Loss: 3.3366, Accuracy: 50.00%
Epoch [2/20], Batch [50/438], Loss: 2.5994, Accuracy