# Simple Cuisine Classification ML Training Pipeline

A straightforward ML training pipeline for cuisine classification using ResNet-50.

## Pipeline Overview
1. **Data Loading**: Load processed images from gold layer
2. **Simple Preprocessing**: Convert bytes to PIL images with transforms
3. **Model Training**: Fine-tune ResNet-50 using standard Transformers patterns
4. **MLflow Integration**: Log and register model

*Based on proven reference patterns - simple and reliable.*

In [0]:
# Simple installation - only what we need
%pip install torch torchvision transformers datasets mlflow scikit-learn

In [0]:
# Restart the Python interpreter so the packages installed by %pip above are
# the ones picked up by the imports in the following cells.
dbutils.library.restartPython()

In [0]:
# Imports shared by all following cells
import io

import mlflow
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from PIL import Image
from sklearn.metrics import accuracy_score, f1_score
from torchvision.transforms import Compose, Normalize, ToTensor, Lambda
from transformers import AutoImageProcessor, AutoModelForImageClassification, TrainingArguments, Trainer

print("✅ Simple imports loaded successfully")

In [0]:
# Run configuration: edit these constants to change a training run
CATALOG = "cuisine_vision_catalog"        # Unity Catalog holding the gold table and the model
MODEL_CHECKPOINT = "microsoft/resnet-50"  # Hugging Face checkpoint to fine-tune
EXPERIMENT_NAME = "/cuisine_classifier"   # MLflow experiment path
NUM_EPOCHS = 5
BATCH_SIZE = 12
LEARNING_RATE = 2e-4

print("🔧 Configuration:")
print(f"   📊 Catalog: {CATALOG}")
print(f"   🧠 Model: {MODEL_CHECKPOINT}")
print(f"   🔄 Epochs: {NUM_EPOCHS}")
print(f"   📦 Batch Size: {BATCH_SIZE}")
print(f"   📈 Learning Rate: {LEARNING_RATE}")

In [0]:
# Load the gold-layer training table into pandas and build train/validation splits
print("📊 Loading data from gold layer...")

# Only the encoded image bytes and the string label are needed; rows without
# image bytes are dropped before collecting to the driver.
dataset_df = (
    spark.table(f"{CATALOG}.gold.ml_dataset")
    .select("processed_image_data", "cuisine_category")
    .filter("processed_image_data IS NOT NULL")
    .toPandas()
)

# Fail with a clear message instead of an obscure error from the split below
if dataset_df.empty:
    raise ValueError(f"No rows with processed_image_data found in {CATALOG}.gold.ml_dataset")

print(f"✅ Loaded {len(dataset_df)} samples")
print(f"   🍽️ Cuisines: {sorted(dataset_df['cuisine_category'].unique())}")

# Hugging Face Dataset with the column names the transforms and collator expect.
# preserve_index=False keeps a pandas index column out of the dataset.
dataset = Dataset.from_pandas(
    dataset_df.rename(columns={
        "processed_image_data": "image",
        "cuisine_category": "label",
    }),
    preserve_index=False,
)

# Reproducible 80/20 train/validation split
splits = dataset.train_test_split(test_size=0.2, seed=42)
train_ds = splits["train"]
val_ds = splits["test"]

print("✅ Data splits:")
print(f"   🏋️ Training: {len(train_ds)} samples")
print(f"   ✅ Validation: {len(val_ds)} samples")

In [0]:
# Decode image bytes into fixed-size, normalized tensors for the Trainer
from torchvision.transforms import CenterCrop, Resize

print("🔄 Setting up simple preprocessing...")

# The checkpoint's processor supplies the normalization statistics and target size
image_processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)


def _resize_steps(processor):
    """Return torchvision resize/crop steps approximating ``processor``'s sizing.

    Giving every training tensor the same size is required by ``torch.stack`` in
    the collator. It also keeps training inputs close to what the inference
    pipeline, which preprocesses with this same processor, feeds the model.
    """
    size_cfg = processor.size
    if "shortest_edge" in size_cfg:
        crop = size_cfg["shortest_edge"]
        # Processors that define crop_pct resize to crop / crop_pct and then
        # center-crop; without crop_pct this degenerates to resize + crop at `crop`.
        crop_pct = getattr(processor, "crop_pct", None) or 1.0
        return [Resize(int(crop / crop_pct)), CenterCrop(crop)]
    return [Resize((size_cfg["height"], size_cfg["width"]))]


transforms = Compose([
    Lambda(lambda b: Image.open(io.BytesIO(b)).convert("RGB")),
    *_resize_steps(image_processor),
    ToTensor(),
    Normalize(mean=image_processor.image_mean, std=image_processor.image_std),
])


def preprocess(batch):
    """Replace each encoded image in ``batch["image"]`` with a normalized tensor."""
    batch["image"] = [transforms(image) for image in batch["image"]]
    return batch


# set_transform applies preprocess lazily whenever rows are read
train_ds.set_transform(preprocess)
val_ds.set_transform(preprocess)

print("✅ Simple preprocessing setup complete")

In [0]:
# Build label <-> id mappings and load the checkpoint with a new classification head
print("🧠 Setting up simple model...")

# Sorting makes the label -> id assignment deterministic across runs
unique_labels = sorted(dataset.unique("label"))
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}
num_labels = len(unique_labels)

print(f"✅ Labels: {id2label}")

# ignore_mismatched_sizes=True lets from_pretrained replace the checkpoint's
# original classifier head with a freshly initialized one of num_labels outputs.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_CHECKPOINT,
    label2id=label2id,
    id2label=id2label,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

print(f"✅ Model loaded with {num_labels} classes")

In [0]:
# CPU threading setup and device report
import os

print("🔧 Optimizing training performance...")

# Up to 8 intra-op threads, but never more than the machine has cores, so
# small nodes are not oversubscribed.
NUM_THREADS = min(8, os.cpu_count() or 1)

# NOTE: these env vars are read when a library initializes its thread pool, so
# they likely do not affect the already-imported torch; set_num_threads does.
os.environ['OMP_NUM_THREADS'] = str(NUM_THREADS)
os.environ['MKL_NUM_THREADS'] = str(NUM_THREADS)
torch.set_num_threads(NUM_THREADS)

# `device` is only reported here; nothing in this cell moves the model to it.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"   🖥️ Training device: {device}")
print(f"   🧵 CPU threads: {NUM_THREADS}")
print("✅ Performance optimizations applied")

In [0]:
# Fine-tune the classifier with the standard Hugging Face Trainer, tracked in MLflow
print("🏋️ Starting training...")

mlflow.set_experiment(EXPERIMENT_NAME)


def collate_fn(examples):
    """Stack transformed image tensors and map string labels to class ids.

    Each example carries the ``image`` tensor produced by ``preprocess`` and its
    string ``label``; the returned keys are the model's forward arguments.
    """
    pixel_values = torch.stack([e["image"] for e in examples])
    labels = torch.tensor([label2id[e["label"]] for e in examples], dtype=torch.long)
    return {"pixel_values": pixel_values, "labels": labels}


def compute_metrics(eval_pred):
    """Return accuracy and weighted F1 from the Trainer's (predictions, labels) pair."""
    predictions, labels = eval_pred
    # If the model returns extra outputs, predictions is a tuple with logits first
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = predictions.argmax(axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average="weighted"),
    }


with mlflow.start_run() as run:
    print(f"🔄 MLflow run: {run.info.run_id}")

    # Record the configuration before training so it is kept even if the run
    # is interrupted before finishing.
    mlflow.log_params({
        "model_checkpoint": MODEL_CHECKPOINT,
        "num_epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "learning_rate": LEARNING_RATE,
        "num_labels": num_labels,
    })

    args = TrainingArguments(
        output_dir="/dbfs/tmp/cuisine-classifier-simple",
        # Keep the raw "image"/"label" columns; collate_fn consumes them directly
        remove_unused_columns=False,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=LEARNING_RATE,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        weight_decay=0.01,
        # Reload the checkpoint with the best (lowest) eval loss after training
        load_best_model_at_end=True,
        metric_for_best_model="eval_loss",
        logging_steps=10,
        report_to=[],  # metrics are logged to MLflow explicitly below
        dataloader_pin_memory=False,
        ddp_find_unused_parameters=False,
        use_cpu=not torch.cuda.is_available(),
    )

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        processing_class=image_processor,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
    )

    print("🚀 Training started...")
    trainer.train()
    print("✅ Training completed!")

    print("📊 Evaluating model...")
    eval_results = trainer.evaluate()
    print(f"✅ Final metrics: {eval_results}")

    # Only numeric entries can be logged as MLflow metrics
    mlflow.log_metrics({
        key: value
        for key, value in eval_results.items()
        if isinstance(value, (int, float))
    })

com.databricks.backend.common.rpc.CommandCancelledException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$5(SequenceExecutionState.scala:132)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:132)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.can

In [0]:
# Wrap the fine-tuned model in a Hugging Face image-classification pipeline
print("📦 Creating simple model wrapper...")

from transformers import pipeline

# Reuse the training image_processor so inference preprocessing matches training.
# `image_processor=` is the dedicated pipeline keyword for image processors.
classifier = pipeline(
    "image-classification",
    model=trainer.model,
    image_processor=image_processor,
)

class SimpleCuisineClassifier(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that serves a Hugging Face image-classification pipeline.

    ``predict`` accepts either a DataFrame with a ``processed_image_data`` column
    of encoded image bytes, or the bytes of a single encoded image, and reports
    only the highest-``score`` pipeline entry for each image.
    """

    def __init__(self, pipeline):
        """Keep ``pipeline`` and switch its underlying model to eval mode."""
        self.pipeline = pipeline
        self.pipeline.model.eval()

    def predict(self, context, model_input):
        """Classify image(s) and return the top-scoring prediction(s).

        Args:
            context: MLflow model context (not used).
            model_input: a ``pd.DataFrame`` with a ``processed_image_data``
                column, or the raw bytes of one encoded image.

        Returns:
            For DataFrame input, a DataFrame with one row per image built from
            its best pipeline entry; otherwise that single best entry.
        """
        def pick_best(entries):
            return max(entries, key=lambda entry: entry['score'])

        if not isinstance(model_input, pd.DataFrame):
            # One encoded image
            rgb_image = Image.open(io.BytesIO(model_input)).convert("RGB")
            with torch.no_grad():
                scored = self.pipeline(rgb_image)
            return pick_best(scored)

        # Decode every row first, then run the pipeline once over the whole list
        rgb_images = [
            Image.open(io.BytesIO(raw_bytes)).convert("RGB")
            for raw_bytes in model_input['processed_image_data']
        ]
        with torch.no_grad():
            scored_per_image = self.pipeline(rgb_images)
        return pd.DataFrame([pick_best(scored) for scored in scored_per_image])

# Instantiate the MLflow pyfunc wrapper around the trained pipeline
wrapped_model = SimpleCuisineClassifier(classifier)
print("✅ Simple model wrapper created")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:465)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:741)
	at com.data

In [0]:
# Log the wrapped model to the training run, then register it in Unity Catalog
print("📊 Logging model to MLflow...")

import PIL
import transformers
from mlflow.models.signature import infer_signature


def _pip_pin(package, version):
    """Return ``package==version`` with any local build tag (e.g. ``+cu121``) removed,
    since pip cannot resolve local version labels from a package index."""
    return f"{package}=={version.split('+')[0]}"


with mlflow.start_run(run_id=run.info.run_id):
    # Run the wrapper on a few real rows to smoke-test it and to infer the signature
    test_df = dataset_df[['processed_image_data']].head(3)
    test_predictions = wrapped_model.predict(None, test_df)
    print(f"✅ Test predictions: {test_predictions}")

    # Unity Catalog requires registered models to carry a signature
    signature = infer_signature(test_df, test_predictions)
    print(f"✅ Model signature created: {signature}")

    # The python_model (including the pipeline it holds) is serialized with the
    # library versions used here, so pin them. Unpinned requirements let a serving
    # environment resolve different, possibly incompatible, releases.
    model_info = mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=wrapped_model,
        signature=signature,
        pip_requirements=[
            _pip_pin("torch", torch.__version__),
            _pip_pin("transformers", transformers.__version__),
            _pip_pin("pillow", PIL.__version__),
            _pip_pin("pandas", pd.__version__),
            _pip_pin("numpy", np.__version__),
        ],
    )

    print(f"✅ Model logged with signature: {model_info.model_uri}")

# Three-level catalog.schema.model names must go to the Unity Catalog registry
mlflow.set_registry_uri("databricks-uc")
full_model_name = f"{CATALOG}.ml_models.cuisine_classifier"
registered_model = mlflow.register_model(
    model_uri=model_info.model_uri,
    name=full_model_name,
    tags={
        "stage": "development",
        "task": "image_classification",
        "architecture": "ResNet-50",
        "approach": "simple",
    },
)

print("🎉 Model registered successfully!")
print(f"   📦 Model: {full_model_name}")
print(f"   🏷️ Version: {registered_model.version}")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:465)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:741)
	at com.data

In [0]:
# Smoke test: run the wrapper on a few random rows and print true vs. predicted labels.
# Rows come from the full dataset (training rows included), so this checks that
# inference works end to end rather than measuring accuracy.
print("🧪 Final testing...")

# Fixed seed for reproducible output; cap n so datasets with <4 rows don't raise
test_samples = dataset_df.sample(n=min(4, len(dataset_df)), random_state=42)
for idx, row in test_samples.iterrows():
    true_label = row['cuisine_category']
    image_bytes = row['processed_image_data']

    prediction = wrapped_model.predict(None, image_bytes)

    print(f"Sample {idx}:")
    print(f"   ✅ True: {true_label}")
    print(f"   🎯 Predicted: {prediction['label']} (score: {prediction['score']:.3f})")
    print()

print("🎉 Simple pipeline completed successfully!")
print("\n📋 Summary:")
print(f"   📊 Total samples: {len(dataset_df)}")
print(f"   🏷️ Classes: {num_labels}")
print(f"   🔄 Epochs: {NUM_EPOCHS}")
print(f"   📦 Model: {full_model_name} v{registered_model.version}")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:465)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:741)
	at com.data

## 📊 Model Performance Diagnostics

Let's analyze why the model might not be predicting accurately by examining the dataset and training results.

In [0]:
# Dataset analysis: class-balance and size checks that commonly explain weak accuracy
print("🔍 Dataset Analysis:")
print(f"📊 Total samples: {len(dataset_df)}")

# Per-class sample counts
class_counts = dataset_df['cuisine_category'].value_counts()
print("\n🍽️ Class Distribution:")
for cuisine, count in class_counts.items():
    percentage = (count / len(dataset_df)) * 100
    print(f"   {cuisine}: {count} samples ({percentage:.1f}%)")

# value_counts only lists observed classes, so min_samples >= 1 and the ratio is defined
min_samples = class_counts.min()
max_samples = class_counts.max()
imbalance_ratio = max_samples / min_samples
print("\n⚖️ Class Imbalance Analysis:")
print(f"   Min class size: {min_samples} samples")
print(f"   Max class size: {max_samples} samples")
print(f"   Imbalance ratio: {imbalance_ratio:.2f}x")

print("\n⚠️ Potential Issues Detected:")
if imbalance_ratio > 3:
    print("   🚨 SIGNIFICANT CLASS IMBALANCE! Some classes have 3x+ more samples than others")
    print("      → Solution: Use class weights or data augmentation")

if min_samples < 50:
    print("   🚨 VERY SMALL DATASET! Some classes have <50 samples")
    print("      → Solution: Collect more data or use data augmentation")

if len(dataset_df) < 500:
    print("   🚨 SMALL TOTAL DATASET! Less than 500 samples for deep learning")
    print("      → Solution: Collect significantly more data")

if imbalance_ratio > 5:
    print("   🚨 EXTREME IMBALANCE! Majority class dominates")
    print("      → Solution: Balance dataset or use stratified sampling")

# Count classes from the data itself so this cell does not need the model-setup cell
num_classes = len(class_counts)
print("\n📈 Recommendations:")
print("   • Ideal dataset size: 1000+ samples per class")
print(f"   • Current average: {len(dataset_df) / num_classes:.0f} samples per class")
print("   • Minimum recommended: 200+ samples per class")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:465)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:741)
	at com.data

In [0]:
# Training performance analysis: interpret final metrics and measure held-out accuracy
print("📊 Training Performance Analysis:")

# Interpret the Trainer's final evaluation metrics when the training cell has run
if 'eval_results' in globals():
    print("\n✅ Final Evaluation Metrics:")
    for metric, value in eval_results.items():
        if isinstance(value, (int, float)):
            print(f"   {metric}: {value:.4f}")

    eval_acc = eval_results.get('eval_accuracy', 0)
    eval_loss = eval_results.get('eval_loss', float('inf'))

    print("\n🎯 Performance Interpretation:")
    if eval_acc < 0.3:
        print("   🔴 CRITICAL: Very low accuracy (<30%) - model is barely learning")
        print("      → Likely causes: insufficient data, too few epochs, or data quality issues")
    elif eval_acc < 0.5:
        print("   🟡 POOR: Low accuracy (<50%) - significant improvement needed")
        print("      → Likely causes: class imbalance, insufficient training, or weak features")
    elif eval_acc < 0.7:
        print("   🟠 FAIR: Moderate accuracy (<70%) - room for improvement")
        print("      → Solutions: more training, data augmentation, or hyperparameter tuning")
    elif eval_acc < 0.85:
        print("   🟢 GOOD: Solid accuracy (70-85%) - decent performance")
        print("      → Can improve with more data or fine-tuning")
    else:
        print("   🟢 EXCELLENT: High accuracy (>85%) - great performance!")

    if eval_loss > 2.0:
        print("   ⚠️ High validation loss - model may be underfitting")
    elif eval_loss < 0.1:
        print("   ⚠️ Very low validation loss - check for overfitting")

# Extended accuracy test on held-out rows only: sampling dataset_df would mostly
# pick images the model was trained on and overstate accuracy.
# Dataset.to_pandas() reads the underlying Arrow data, so val_ds's on-the-fly
# set_transform is not applied and the columns are still raw bytes / labels.
print("\n🎯 Extended Prediction Accuracy Test:")
holdout_df = val_ds.to_pandas().rename(columns={
    "image": "processed_image_data",
    "label": "cuisine_category",
})
test_size = min(50, len(holdout_df))  # Test on up to 50 samples
test_larger = holdout_df.sample(n=test_size, random_state=42)
correct = 0
total = len(test_larger)
all_cuisines = dataset_df['cuisine_category'].unique()
cuisine_correct = dict.fromkeys(all_cuisines, 0)
cuisine_total = dict.fromkeys(all_cuisines, 0)

print(f"Testing on {total} random held-out samples...")

# enumerate supplies the display position; the index labels from sample() are
# shuffled row labels and cannot be used to select "the first 10".
for position, (_, row) in enumerate(test_larger.iterrows()):
    true_label = row['cuisine_category']
    prediction = wrapped_model.predict(None, row['processed_image_data'])
    predicted_label = prediction['label']
    confidence = prediction['score']

    cuisine_total[true_label] += 1

    if true_label == predicted_label:
        correct += 1
        cuisine_correct[true_label] += 1
        status = "✅"
    else:
        status = "❌"

    if position < 10:  # Show first 10 predictions
        print(f"   {status} True: {true_label:<15} | Predicted: {predicted_label:<15} | Confidence: {confidence:.3f}")

# Guard against an empty validation split
overall_accuracy = correct / total if total else 0.0
print(f"\n📈 Overall Test Accuracy: {overall_accuracy:.1%} ({correct}/{total})")

# Per-class accuracy; classes under 50% are collected for the summary below
print("\n📊 Per-Class Accuracy:")
problem_classes = []
for cuisine in sorted(cuisine_total):
    seen = cuisine_total[cuisine]
    if not seen:
        print(f"   {cuisine:<15}: No samples in test set")
        continue
    class_acc = cuisine_correct[cuisine] / seen
    print(f"   {cuisine:<15}: {class_acc:.1%} ({cuisine_correct[cuisine]}/{seen})")
    if class_acc < 0.5:
        problem_classes.append(f"{cuisine} ({class_acc:.1%})")

print("\n🚨 Classes with Low Accuracy (<50%):")
if problem_classes:
    for problem in problem_classes:
        print(f"   • {problem}")
    print("\n💡 Focus improvement efforts on these classes!")
else:
    print("   🎉 All classes performing reasonably well!")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:465)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:741)
	at com.data

In [0]:
# Improvement recommendations derived from dataset statistics and final metrics
print("🚀 Improvement Recommendations:")

# Recompute the dataset statistics here so this cell does not depend on the
# diagnostics cell having been run first.
class_counts = dataset_df['cuisine_category'].value_counts()
dataset_size = len(dataset_df)
min_class_size = class_counts.min()
max_class_size = class_counts.max()
imbalance_ratio = max_class_size / min_class_size
current_accuracy = eval_results.get('eval_accuracy', 0) if 'eval_results' in globals() else 0

print("\n📋 Priority Actions (implement in order):")

# Priority 1: data quantity
if dataset_size < 1000:
    print("   🔴 CRITICAL - Collect more data:")
    print(f"      Current: {dataset_size} samples | Target: 1000+ samples")
    print(f"      Need: {1000 - dataset_size} more samples")

if min_class_size < 100:
    print("   🔴 CRITICAL - Balance dataset:")
    print(f"      Smallest class: {min_class_size} samples | Target: 100+ per class")
    print(f"      Focus on collecting data for: {class_counts.idxmin()}")

# Priority 2: training configuration. Only suggest increases the current config
# has not already made (10 epochs / 2e-4 learning rate are the targets).
if current_accuracy < 0.6:
    print("   🟡 HIGH - Improve training:")
    if NUM_EPOCHS < 10:
        print(f"      • Increase epochs: {NUM_EPOCHS} → 10-15 epochs")
    if LEARNING_RATE < 2e-4:
        print(f"      • Increase learning rate: {LEARNING_RATE} → 2e-4")
    print("      • Add data augmentation")

# Priority 3: class imbalance
if imbalance_ratio > 3:
    print("   🟠 MEDIUM - Address class imbalance:")
    print("      • Use class weights during training")
    print("      • Apply stratified sampling")
    print("      • Generate synthetic data for minority classes")

print("\n🔧 Quick Fixes to Try Next:")
print("   1. Update configuration in cell 5:")
print("      NUM_EPOCHS = 10")
print("      BATCH_SIZE = 16  # if memory allows")
print("      LEARNING_RATE = 2e-4")

print("\n   2. Add data augmentation in cell 7:")
print("      from torchvision.transforms import RandomHorizontalFlip, ColorJitter")
print("      # Add to transforms: RandomHorizontalFlip(p=0.5), ColorJitter(...)")

print("\n   3. Consider using a different model:")
print("      MODEL_CHECKPOINT = 'google/vit-base-patch16-224'  # Vision Transformer")
print("      # or")
print("      MODEL_CHECKPOINT = 'microsoft/swin-tiny-patch4-window7-224'  # Swin Transformer")

print("\n📈 Expected Improvements:")
if dataset_size < 500:
    print("   • With 2-3x more data: +15-25% accuracy")
# Suggest more epochs whenever the config is below the 10-epoch target
if NUM_EPOCHS < 10:
    print("   • With 10 epochs: +5-15% accuracy")
if min_class_size < 50:
    print("   • With balanced classes: +10-20% accuracy")

print("\n🎯 Realistic Targets:")
if dataset_size < 500:
    print("   • Short term: 50-60% accuracy (with current data + better training)")
    print("   • Long term: 75-85% accuracy (with more balanced data)")
else:
    print("   • Short term: 65-75% accuracy (with better training)")
    print("   • Long term: 80-90% accuracy (with data augmentation and tuning)")

com.databricks.backend.common.rpc.CommandSkippedException
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3(SequenceExecutionState.scala:134)
	at com.databricks.spark.chauffeur.SequenceExecutionState.$anonfun$cancel$3$adapted(SequenceExecutionState.scala:129)
	at scala.collection.immutable.Range.foreach(Range.scala:190)
	at com.databricks.spark.chauffeur.SequenceExecutionState.cancel(SequenceExecutionState.scala:129)
	at com.databricks.spark.chauffeur.ExecContextState.cancelRunningSequence(ExecContextState.scala:715)
	at com.databricks.spark.chauffeur.ExecContextState.$anonfun$cancel$1(ExecContextState.scala:435)
	at scala.Option.getOrElse(Option.scala:201)
	at com.databricks.spark.chauffeur.ExecContextState.cancel(ExecContextState.scala:435)
	at com.databricks.spark.chauffeur.ExecutionContextManagerV1.cancelExecution(ExecutionContextManagerV1.scala:465)
	at com.databricks.spark.chauffeur.ChauffeurState.$anonfun$process$1(ChauffeurState.scala:741)
	at com.data