In [30]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from tqdm import tqdm
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from typing import List, Dict, Any, Optional
import pandas as pd

In [31]:
class CLIPClassifier:
    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        """
        Initialize the CLIP classifier
        Args:
            model_name: Name of the CLIP model to use
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")
        
        self.model = CLIPModel.from_pretrained(model_name).to(self.device)
        self.processor = CLIPProcessor.from_pretrained(model_name)
        
    def prepare_dataset(self, dataset_name: str, split: str = "train", 
                       text_column: str = "text", image_column: str = "image", **kwargs):
        """
        Load and prepare the dataset from Hugging Face
        """
        self.dataset = load_dataset(dataset_name, split=split, **kwargs)
        self.text_column = text_column
        self.image_column = image_column
        
        # Get unique labels
        self.labels = sorted(list(set(self.dataset[text_column])))
        print(f"Found {len(self.labels)} unique labels: {self.labels}")
        
        # Generate CLIP-style label descriptions
        self.clip_labels = [f"a photo of a {label}" for label in self.labels]
        
        # Create and process label tokens
        self.label_tokens = self.processor(
            text=self.clip_labels,
            padding=True,
            images=None,
            return_tensors='pt'
        ).to(self.device)
        
        # Generate label embeddings
        with torch.no_grad():
            self.label_embeddings = self.model.get_text_features(**self.label_tokens)
            self.label_embeddings = self.label_embeddings.detach().cpu().numpy()
            # Normalize label embeddings
            self.label_embeddings = self.label_embeddings / np.linalg.norm(self.label_embeddings, axis=0)

    def inference_on_samples(self, sample_indices: List[int]) -> Dict[str, Any]:
        """
        Perform inference on specific samples
        Args:
            sample_indices: List of indices to perform inference on
        Returns:
            Dictionary containing predictions and confidence scores
        """
        results = []
        for idx in sample_indices:
            sample = self.dataset[idx]
            
            # Process image
            image = self.processor(
                text=None,
                images=sample[self.image_column],
                return_tensors='pt'
            )['pixel_values'].to(self.device)
            
            # Get image embeddings
            with torch.no_grad():
                image_embeddings = self.model.get_image_features(image)
                image_embeddings = image_embeddings.detach().cpu().numpy()
            
            # Calculate similarity scores
            scores = np.dot(image_embeddings, self.label_embeddings.T)[0]
            pred_idx = np.argmax(scores)
            confidence_scores = softmax(scores)
            
            results.append({
                'index': idx,
                'true_label': sample[self.text_column],
                'predicted_label': self.labels[pred_idx],
                'confidence': confidence_scores[pred_idx],
                'all_scores': dict(zip(self.labels, confidence_scores))
            })
        
        return results

    def plot_confusion_matrix(self, true_labels: List[str], pred_labels: List[str], 
                            output_path: Optional[str] = None):
        """
        Plot confusion matrix
        """
        cm = confusion_matrix(true_labels, pred_labels, labels=self.labels)
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                   xticklabels=self.labels, yticklabels=self.labels)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.xticks(rotation=45)
        plt.tight_layout()
        
        if output_path:
            plt.savefig(output_path)
        plt.close()

    def plot_metrics_comparison(self, metrics_dict: Dict[str, float], 
                              output_path: Optional[str] = None):
        """
        Plot comparison of different metrics
        """
        plt.figure(figsize=(10, 6))
        bars = plt.bar(metrics_dict.keys(), metrics_dict.values())
        plt.title('Classification Metrics Comparison')
        plt.ylabel('Score')
        
        # Add value labels on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                    f'{height:.3f}', ha='center', va='bottom')
        
        plt.ylim(0, 1.1)  # Set y-axis limit to 0-1 with some padding
        if output_path:
            plt.savefig(output_path)
        plt.close()

    def evaluate(self, batch_size: int = 32) -> Dict[str, Any]:
        """
        Perform evaluation with multiple metrics
        """
        all_predictions = []
        all_labels = []
        all_scores = []
        
        for i in tqdm(range(0, len(self.dataset), batch_size)):
            i_end = min(i + batch_size, len(self.dataset))
            batch = self.dataset[i:i_end]
            
            # Process batch of images
            images = self.processor(
                text=None,
                images=batch[self.image_column],
                return_tensors='pt'
            )['pixel_values'].to(self.device)
            
            # Get image embeddings
            with torch.no_grad():
                image_embeddings = self.model.get_image_features(images)
                image_embeddings = image_embeddings.detach().cpu().numpy()
            
            # Calculate similarity scores
            scores = np.dot(image_embeddings, self.label_embeddings.T)
            predictions = np.argmax(scores, axis=1)
            
            # Store predictions and true labels
            all_predictions.extend([self.labels[idx] for idx in predictions])
            all_labels.extend(batch[self.text_column])
            all_scores.extend(scores)
        
        # Calculate metrics
        precision, recall, f1, support = precision_recall_fscore_support(
            all_labels, all_predictions, average='weighted'
        )
        
        # Create classification report
        report = classification_report(
            all_labels, all_predictions, 
            target_names=self.labels, 
            output_dict=True
        )
        
        # Calculate accuracy
        accuracy = np.mean(np.array(all_predictions) == np.array(all_labels))
        
        # Compile metrics
        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1
        }
        
        # Generate plots
        self.plot_confusion_matrix(all_labels, all_predictions, 'confusion_matrix.png')
        self.plot_metrics_comparison(metrics, 'metrics_comparison.png')
        
        return {
            'metrics': metrics,
            'detailed_report': report,
            'predictions': all_predictions,
            'true_labels': all_labels,
            'raw_scores': all_scores
        }

def softmax(x):
    """
    Compute softmax values for each set of scores
    """
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

In [23]:
from datasets import load_dataset
dataset = load_dataset("imagefolder", data_dir="dataset")

In [None]:
def main():
    # Initialize classifier
    classifier = CLIPClassifier()
    
    # Prepare dataset
    classifier.prepare_dataset(
        "dataset",
        split="train",
        text_column="label",
        image_column="image"
    )
    
    # Perform evaluation
    print("\nPerforming evaluation...")
    eval_results = classifier.evaluate(batch_size=32)
    
    print("\nOverall Metrics:")
    for metric, value in eval_results['metrics'].items():
        print(f"{metric}: {value:.4f}")
    
    # Perform inference on some samples
    print("\nPerforming inference on sample images...")
    sample_indices = [0, 10, 20, 30, 40]  
    sample_results = classifier.inference_on_samples(sample_indices)
    
    print("\nSample Predictions:")
    for result in sample_results:
        print(f"\nImage {result['index']}:")
        print(f"True label: {result['true_label']}")
        print(f"Predicted label: {result['predicted_label']}")
        print(f"Confidence: {result['confidence']:.4f}")

if __name__ == "__main__":
    main()

Using device: cpu
Found 2 unique labels: [0, 1]

Performing evaluation...


100%|██████████| 92/92 [00:40<00:00,  2.28it/s]



Overall Metrics:
accuracy: 0.3833
precision: 0.3672
recall: 0.3833
f1_score: 0.3640

Performing inference on sample images...

Sample Predictions:

Image 0:
True label: 0
Predicted label: 0
Confidence: 0.5803

Image 10:
True label: 0
Predicted label: 0
Confidence: 0.9882

Image 20:
True label: 0
Predicted label: 1
Confidence: 0.9987

Image 30:
True label: 0
Predicted label: 1
Confidence: 0.5716

Image 40:
True label: 0
Predicted label: 0
Confidence: 0.9184


In [6]:
from typing import List, Dict, Any, Optional
import matplotlib.pyplot as plt
def plot_implementations_comparison(implementations_results: Dict[str, Dict[str, float]], 
                                   output_path: Optional[str] = None):
    """
    Create a line plot comparing metrics between different implementations
    Args:
        implementations_results: Dictionary containing results from different implementations
        output_path: Optional path to save the plot
    """
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    metric_keys = ['accuracy', 'precision', 'recall', 'f1_score']
    
    # Set style parameters
    plt.style.use('bmh')  # Using 'bmh' style which provides good visualization for data comparison
    
    # Create figure
    plt.figure(figsize=(12, 6))
    
    # Plot lines for each implementation
    markers = ['o', 's', 'D', '^', 'v']  # Different markers for different implementations
    for idx, (impl_name, results) in enumerate(implementations_results.items()):
        values = [results[key] for key in metric_keys]
        plt.plot(metrics, values, marker=markers[idx], label=impl_name, linewidth=2, markersize=8)
    
    # Customize the plot
    plt.xlabel('Metrics', fontsize=12, fontweight='bold')
    plt.ylabel('Score', fontsize=12, fontweight='bold')
    plt.title('Comparison of Different Model Implementations for Autism Detection', fontsize=14, fontweight='bold', pad=20)
    
    # Enhance grid and spines
    plt.grid(True, linestyle='--', alpha=0.7)
    ax = plt.gca()
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    
    # Customize legend
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', frameon=True, fancybox=True, shadow=True)
    
    # Customize tick labels
    plt.xticks(fontsize=10)
    plt.yticks(fontsize=10)
    
    # Set y-axis limits with some padding
    plt.ylim(0, 1.1)
    
    # Add value annotations
    for impl_name, results in implementations_results.items():
        values = [results[key] for key in metric_keys]
        for i, value in enumerate(values):
            plt.annotate(f'{value:.3f}', 
                        (metrics[i], value),
                        textcoords="offset points",
                        xytext=(0,10),
                        ha='center')
    
    plt.tight_layout()
    
    if output_path:
        plt.savefig(output_path, bbox_inches='tight', dpi=300)
        print(f"Plot saved to {output_path}")
    
    plt.close()

def compare_implementations(implementations_results: Dict[str, Dict[str, float]]) -> Dict:
    """
    Compare metrics between different implementations
    Args:
        implementations_results: Dictionary containing results from different implementations
        Example format:
        {
            'CLIP Custom': {'accuracy': 0.87, 'precision': 0.86, 'recall': 0.85, 'f1_score': 0.86},
            'PyTorch CLIP': {'accuracy': 0.85, 'precision': 0.84, 'recall': 0.83, 'f1_score': 0.83},
            ...
        }
    Returns:
        Dictionary with formatted comparison data
    """
    metrics = ['Accuracy', 'Precision', 'Recall', 'F1 Score']
    metric_keys = ['accuracy', 'precision', 'recall', 'f1_score']
    
    comparison_data = []
    for metric, key in zip(metrics, metric_keys):
        data_point = {'metric': metric}
        for impl_name, results in implementations_results.items():
            data_point[impl_name] = results[key]
        comparison_data.append(data_point)
    
    return comparison_data

In [8]:
implementations_results = {
    'Vision Transformer': {
        'accuracy': 0.86,
        'precision': 0.86,
        'recall': 0.86,
        'f1_score': 0.86
    },
    'VGG 16': {
        'accuracy': 0.76,
        'precision': 0.75,
        'recall': 0.72,
        'f1_score': 0.74
    },
    'CLIP': {
        'accuracy': 0.38,
        'precision': 0.36,
        'recall': 0.38,
        'f1_score': 0.36
    }
}

# Create the comparison plot
plot_implementations_comparison(implementations_results, 'implementations_comparison.png')

Plot saved to implementations_comparison.png
