# SkinSight AI: Skin Lesion Classification Model (Image-Only)

This notebook implements a deep learning model for skin lesion classification using the ISIC dataset. The model uses only images and their corresponding diagnostic labels.

**Modifications for Improved Performance:**
- Enabled Mixed Precision Training for potential speed-up.
- Implemented Class Weights to address dataset imbalance.
- Enabled fine-tuning of the top layers of the pre-trained base model (EfficientNetB0).
- Uses AdamW optimizer (from `tf.keras.optimizers`).
- Added an optional custom categorical focal loss (selected via `CONFIG['LOSS_FUNCTION']`).
- Included `ReduceLROnPlateau` learning rate scheduler.

In [None]:
# Install the Python packages used throughout this notebook.
!pip install tensorflow opencv-python pandas numpy matplotlib seaborn scikit-learn requests tqdm pillow

In [None]:
# NOTE(review): the notebook imports Keras through `tensorflow.keras`; upgrading the
# standalone `keras` package independently of `tensorflow` can leave the two at
# incompatible versions — confirm this cell is actually needed.
!pip install keras --upgrade

In [None]:
# Create the directory layout the CONFIG paths point at (images, metadata, saved models).
!mkdir -p data/images
!mkdir -p data/metadata
!mkdir -p models

In [None]:
# Assuming dataset is already downloaded. Provide download commands if needed.
# Example:
!wget https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_Input.zip
!wget https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_Metadata.csv -O ./data/metadata/ISIC_2019_Training_Metadata.csv
!wget https://isic-challenge-data.s3.amazonaws.com/2019/ISIC_2019_Training_GroundTruth.csv -O ./data/metadata/ISIC_2019_Training_GroundTruth.csv
print("If you downloaded the zip, uncomment the next cell to unzip")

In [None]:
# Extract the image archive downloaded above into ./data/images.
# NOTE(review): CONFIG['IMAGE_SUBDIR'] expects the images under
# ./data/images/ISIC_2019_Training_Input — confirm the archive's top-level folder matches.
!unzip -q ISIC_2019_Training_Input.zip -d ./data/images

## 1. Setup and Dependencies

In [None]:
# Basic data manipulation and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Image processing
import cv2
from PIL import Image

# Deep learning
import tensorflow as tf
from tensorflow.keras import layers, models, applications
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras import mixed_precision # For mixed precision training

# Machine learning utilities
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder # Only for label mapping if needed, not features
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
from sklearn.utils import class_weight # For calculating class weights

# Utilities
import os
import json
from tqdm.notebook import tqdm

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Enable Mixed Precision Training only when a GPU is visible. mixed_float16 targets
# GPUs with Tensor Cores (NVIDIA Volta, Turing, Ampere or newer); on CPU it runs
# markedly slower than float32, so keep the default float32 policy there.
# The model's output layer is pinned to float32 either way.
if tf.config.list_physical_devices('GPU'):
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
print('Mixed precision policy:', mixed_precision.global_policy())

# Custom Categorical Focal Loss Implementation
class CategoricalFocalCrossentropy(tf.keras.losses.Loss):
    """Categorical focal loss for one-hot targets.

    For each sample the loss is::

        sum_c  alpha * y_true[c] * (1 - p[c]) ** gamma * -log(p[c])

    where ``p`` are the predicted class probabilities. The ``(1 - p) ** gamma``
    factor down-weights samples whose true class is already predicted with high
    confidence; ``alpha`` is a single scalar applied to every class.

    Args:
        alpha: Scalar weighting factor applied to every term.
        gamma: Focusing exponent. ``gamma=0`` gives ``alpha`` times categorical
            cross-entropy (up to probability clipping when ``from_logits=False``).
        from_logits: If True, ``y_pred`` holds unnormalized logits and a softmax
            over the last axis is applied; otherwise ``y_pred`` must already be
            probabilities.
        name: Loss name.
        **kwargs: Forwarded to ``tf.keras.losses.Loss``.
    """

    def __init__(self, alpha=0.25, gamma=2.0, from_logits=False,
                 name='categorical_focal_crossentropy', **kwargs):
        super().__init__(name=name, **kwargs)
        self.alpha = alpha
        self.gamma = gamma
        self.from_logits = from_logits

    def call(self, y_true, y_pred):
        """Return the per-sample focal loss (summed over the last, class axis)."""
        y_pred = tf.convert_to_tensor(y_pred)
        # Cast targets to the prediction dtype so integer or float64 labels do not
        # cause dtype mismatches in the element-wise products below.
        y_true = tf.cast(y_true, y_pred.dtype)

        if self.from_logits:
            # log_softmax is numerically stable, unlike log(softmax(x)).
            log_probs = tf.nn.log_softmax(y_pred, axis=-1)
            probs = tf.exp(log_probs)
        else:
            # Clip so log() never sees 0 and (1 - p) never goes negative.
            epsilon = tf.keras.backend.epsilon()
            probs = tf.clip_by_value(y_pred, epsilon, 1. - epsilon)
            log_probs = tf.math.log(probs)

        cross_entropy = -y_true * log_probs
        focal_factor = tf.math.pow(1. - probs, self.gamma)
        focal_loss = self.alpha * focal_factor * cross_entropy
        # One value per sample; the base Loss class applies the batch reduction.
        return tf.reduce_sum(focal_loss, axis=-1)

    def get_config(self):
        """Return the serializable constructor arguments."""
        config = super().get_config()
        config.update({
            'alpha': self.alpha,
            'gamma': self.gamma,
            'from_logits': self.from_logits
        })
        return config

## 2. Configuration

In [None]:
# Model Configuration
CONFIG = {
    # Image parameters
    'IMG_SIZE': 224,
    'BATCH_SIZE': 32,

    # Training parameters
    'EPOCHS': 50,
    'FINETUNE_LEARNING_RATE': 1e-4,
    'EARLY_STOPPING_PATIENCE': 10,
    'REDUCE_LR_PATIENCE': 3,
    'REDUCE_LR_FACTOR': 0.2,

    # Model architecture
    'BASE_MODEL': 'EfficientNetB0',
    'DROPOUT_RATE': 0.4,
    'N_LAYERS_TO_UNFREEZE': 30,

    # Optimizer
    'OPTIMIZER': 'AdamW', # 'Adam' or 'AdamW'
    'WEIGHT_DECAY': 1e-5, # For AdamW

    # Loss Function
    'LOSS_FUNCTION': 'categorical_crossentropy', # 'categorical_crossentropy' or 'focal_loss'
    'FOCAL_LOSS_ALPHA': 0.25,
    'FOCAL_LOSS_GAMMA': 2.0,

    # Class weights (will be calculated automatically)
    'CLASS_WEIGHTS': None,

    # Paths
    'DATA_DIR': './data',
    'IMAGE_SUBDIR': 'images/ISIC_2019_Training_Input',
    'METADATA_SUBDIR': 'metadata',
    'MODEL_SAVE_PATH': './models/skin_lesion_model_image_only_finetuned.h5',
    'CHECKPOINT_DIR': './models/checkpoints'
}
os.makedirs(CONFIG['CHECKPOINT_DIR'], exist_ok=True)

CLASSES = [
    'melanoma',
    'nevus',
    'basal_cell_carcinoma',
    'actinic_keratosis',
    'benign_keratosis',
    'dermatofibroma',
    'vascular_lesions'
]

## 3. Load ISIC Dataset (Metadata for Labels and Image Names)

In [None]:
def load_isic_data_info():
    """Return a DataFrame of ISIC 2019 image filenames and diagnosis labels.

    Reads the training metadata and ground-truth CSVs from
    ``CONFIG['DATA_DIR']/CONFIG['METADATA_SUBDIR']``, inner-joins them on the
    extension-less ``image`` id, and collapses the one-hot ground-truth columns
    of the seven target classes into a single ``diagnosis`` string. Rows whose
    target-class columns do not sum to exactly 1 are dropped (presumably images
    labelled with a class outside the seven targets).

    Returns:
        DataFrame with columns ``image`` (filename ending in ``.jpg``) and
        ``diagnosis`` (a long class name such as ``'melanoma'``).

    Raises:
        FileNotFoundError: If either CSV file is missing.
        ValueError: If none of the expected one-hot columns are present.
    """
    meta_dir = os.path.join(CONFIG['DATA_DIR'], CONFIG['METADATA_SUBDIR'])
    metadata_path = os.path.join(meta_dir, 'ISIC_2019_Training_Metadata.csv')
    ground_truth_path = os.path.join(meta_dir, 'ISIC_2019_Training_GroundTruth.csv')

    both_present = os.path.exists(metadata_path) and os.path.exists(ground_truth_path)
    if not both_present:
        raise FileNotFoundError(
            f"Metadata ({metadata_path}) or ground truth ({ground_truth_path}) CSV files not found. "
            "Please ensure they are downloaded and placed correctly."
        )

    metadata_df = pd.read_csv(metadata_path)
    ground_truth_df = pd.read_csv(ground_truth_path)

    # Only the id column of the metadata is needed; it restricts the ground
    # truth to images that also appear in the metadata file.
    combined_df = metadata_df[['image']].merge(ground_truth_df, on='image', how='inner')
    combined_df['image_filename'] = combined_df['image'].astype(str) + '.jpg'

    short_code_to_name = {
        'MEL': 'melanoma',
        'NV': 'nevus',
        'BCC': 'basal_cell_carcinoma',
        'AK': 'actinic_keratosis',
        'BKL': 'benign_keratosis',
        'DF': 'dermatofibroma',
        'VASC': 'vascular_lesions'
    }
    target_cols = [code for code in short_code_to_name if code in combined_df.columns]
    if len(target_cols) == 0:
        raise ValueError("No target disease columns found in the combined data. Check CSV column names.")

    # idxmax returns, per row, the name of the column holding the largest
    # value, i.e. the short code whose one-hot entry is 1.
    combined_df['diagnosis_short_code'] = combined_df[target_cols].idxmax(axis=1)
    combined_df['diagnosis'] = combined_df['diagnosis_short_code'].map(short_code_to_name)

    # Keep rows with a mapped diagnosis and exactly one positive target column.
    combined_df = combined_df.dropna(subset=['diagnosis'])
    has_single_label = combined_df[target_cols].sum(axis=1) == 1
    combined_df = combined_df[has_single_label]

    # The generators read filenames from a column called 'image'.
    final_df = combined_df[['image_filename', 'diagnosis']].rename(
        columns={'image_filename': 'image'}
    )

    print("\nUnique diagnoses mapped:", final_df['diagnosis'].unique())
    print("Diagnosis value counts (from loaded and mapped data):")
    print(final_df['diagnosis'].value_counts())

    return final_df

## 4. Data Preprocessing (Splitting and Image Validation)

In [None]:
def create_data_generators(base_model_name=None):
    """Create the augmenting training generator and the plain val/test generator.

    Both generators apply the ``preprocess_input`` function of the selected
    backbone so inputs are scaled the way its pretrained weights expect.
    Previously EfficientNet preprocessing was used unconditionally, which would
    feed wrongly scaled inputs to the ResNet50 option of the model builder.

    Args:
        base_model_name: Backbone name, ``'EfficientNetB0'`` or ``'ResNet50'``.
            Defaults to ``CONFIG['BASE_MODEL']``.

    Returns:
        Tuple ``(train_datagen, val_test_datagen)`` of ImageDataGenerator objects.

    Raises:
        ValueError: If the backbone name is not supported.
    """
    if base_model_name is None:
        base_model_name = CONFIG['BASE_MODEL']

    preprocessors = {
        'EfficientNetB0': applications.efficientnet.preprocess_input,
        'ResNet50': applications.resnet50.preprocess_input,
    }
    if base_model_name not in preprocessors:
        raise ValueError(f"Unsupported base model: {base_model_name}")
    preprocess_fn = preprocessors[base_model_name]

    # Geometric + photometric augmentation for the training split only.
    train_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_fn,
        rotation_range=30,
        width_shift_range=0.25,
        height_shift_range=0.25,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        vertical_flip=True,
        brightness_range=[0.8, 1.2],
        channel_shift_range=20,
        fill_mode='nearest'
    )

    # Validation/test images are only preprocessed, never augmented.
    val_test_datagen = ImageDataGenerator(
        preprocessing_function=preprocess_fn
    )
    return train_datagen, val_test_datagen

def _is_readable_image(path):
    """Return True if *path* exists and can be opened by both PIL and OpenCV.

    PIL's ``verify()`` flags broken files without decoding pixel data, so an
    OpenCV read is also done to confirm the image actually decodes.
    """
    if not os.path.exists(path):
        return False
    try:
        # The context manager closes the file handle; a bare Image.open()
        # followed by verify() would leave it open until garbage collection.
        with Image.open(path) as img:
            img.verify()
        return cv2.imread(path) is not None
    except Exception:  # PIL/OpenCV raise many different types for bad files
        return False


def prepare_image_dataset_splits(df_images_labels):
    """Validate image files and split them into stratified train/val/test sets.

    Rows whose image file is missing or cannot be decoded are dropped (and
    counted), then the remaining rows are split 70/15/15, stratified by
    ``diagnosis`` with a fixed random_state of 42.

    Args:
        df_images_labels: DataFrame with an ``image`` column (filename relative to
            ``CONFIG['DATA_DIR']/CONFIG['IMAGE_SUBDIR']``) and a ``diagnosis`` column.

    Returns:
        Tuple ``(train_df, val_df, test_df)``, each with columns ``image`` and
        ``diagnosis``.

    Raises:
        ValueError: If no input data is given or no image passes validation.
    """
    print('Preparing image dataset splits...')
    if df_images_labels is None:
        raise ValueError("No image/label data to split; loading the dataset info probably failed.")
    df = df_images_labels.copy()

    df['image'] = df['image'].astype(str)
    base_image_path = os.path.join(CONFIG['DATA_DIR'], CONFIG['IMAGE_SUBDIR'])
    full_paths = [os.path.join(base_image_path, name) for name in df['image']]

    print('Validating image files...')
    is_valid = np.array(
        [_is_readable_image(path)
         for path in tqdm(full_paths, total=len(full_paths), desc="Validating images")],
        dtype=bool,
    )
    n_skipped = int((~is_valid).sum())
    if n_skipped:
        print(f'Skipped {n_skipped} missing or unreadable image file(s).')

    # Boolean mask selection is positional, so it works whatever the input index is.
    df = df.loc[is_valid, ['image', 'diagnosis']].reset_index(drop=True)

    print(f'\nTotal images after validation: {len(df)}')
    if len(df) == 0:
        raise ValueError("No valid images found. Please check image paths and integrity.")

    # 70% train, then split the remaining 30% evenly into validation and test.
    train_df, temp_df = train_test_split(
        df,
        test_size=0.3,
        stratify=df['diagnosis'],
        random_state=42
    )

    val_df, test_df = train_test_split(
        temp_df,
        test_size=0.5,
        stratify=temp_df['diagnosis'],
        random_state=42
    )

    print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}, Test samples: {len(test_df)}")
    return train_df, val_df, test_df

## 5. Dataset Info Loading Execution

In [None]:
# Load image filenames + labels. On any failure the error is reported and
# df_images_labels stays None, so the next cell fails with its own message.
df_images_labels = None
try:
    print('Loading ISIC dataset info (image names and labels)...')
    df_images_labels = load_isic_data_info()
    print("\nOverall class distribution in the loaded and mapped dataset:")
    class_share_pct = df_images_labels['diagnosis'].value_counts(normalize=True) * 100
    print(class_share_pct)
except Exception as load_error:
    print(f'Error loading dataset info: {load_error!s}')


## 6. Dataset Splitting and Generator Creation Execution

In [None]:

def _flow_from_split(datagen, frame, directory, shuffle):
    """Wrap one split DataFrame in a categorical flow_from_dataframe iterator."""
    return datagen.flow_from_dataframe(
        dataframe=frame,
        directory=directory,
        x_col='image',
        y_col='diagnosis',
        target_size=(CONFIG['IMG_SIZE'],) * 2,
        batch_size=CONFIG['BATCH_SIZE'],
        class_mode='categorical',
        classes=CLASSES,  # fixes label -> column order to the CLASSES list
        shuffle=shuffle,
    )


try:
    train_df, val_df, test_df = prepare_image_dataset_splits(df_images_labels)

    train_datagen, val_test_datagen = create_data_generators()
    image_dir_path = os.path.join(CONFIG['DATA_DIR'], CONFIG['IMAGE_SUBDIR'])

    # Only the training stream is shuffled; val/test keep file order so that
    # predictions line up with the iterators' label arrays.
    train_generator = _flow_from_split(train_datagen, train_df, image_dir_path, shuffle=True)
    val_generator = _flow_from_split(val_test_datagen, val_df, image_dir_path, shuffle=False)
    test_generator = _flow_from_split(val_test_datagen, test_df, image_dir_path, shuffle=False)
except Exception as split_error:
    print(f'Error preparing dataset splits or generators: {split_error!s}')

## 7. Dataset Statistics and Visualization

In [None]:

print('\nDataset Statistics (after splitting and validation):')
split_sizes = {'Training': len(train_df), 'Validation': len(val_df), 'Test': len(test_df)}
total_validated_images = sum(split_sizes.values())
if total_validated_images <= 0:  # avoid dividing by zero when nothing survived validation
    print("No validated images to show statistics for.")
else:
    for split_name, n_rows in split_sizes.items():
        print(f'{split_name} samples: {n_rows} ({n_rows/total_validated_images*100:.2f}%)')

# Horizontal bar chart of training-set class counts, most frequent class first.
most_frequent_first = train_df['diagnosis'].value_counts().index
plt.figure(figsize=(12, 6))
sns.countplot(data=train_df, y='diagnosis', order=most_frequent_first, palette='viridis')
plt.title('Class Distribution in Training Set')
plt.xlabel('Number of Samples')
plt.ylabel('Class')
plt.tight_layout()
plt.show()

## 8. Display Sample Images

In [None]:

def rescale_for_display(image):
    """Min-max rescale a preprocessed image array into [0, 1] for ``plt.imshow``.

    The generators apply a backbone-specific ``preprocess_input``. Keras
    documents the EfficientNet one as a pass-through, so pixels stay near
    [0, 255]; a fixed ``x * 0.5 + 0.5`` mapping (which assumes [-1, 1]) clips
    almost everything to white. Per-image min-max scaling works for any input
    range. A constant image maps to all zeros.
    Note: with ResNet50 (caffe-style, BGR) preprocessing the colours will still
    appear channel-swapped.
    """
    image = np.asarray(image, dtype=np.float32)
    lo, hi = float(image.min()), float(image.max())
    if hi - lo < 1e-8:
        return np.zeros_like(image)
    return (image - lo) / (hi - lo)


try:
    print('\nDisplaying sample augmented images from training set:')
    images, labels = next(train_generator)
    plt.figure(figsize=(15, 10))
    # Show at most 8 images in a 2x4 grid (a batch never exceeds BATCH_SIZE).
    for i in range(min(8, len(images))):
        plt.subplot(2, 4, i + 1)
        plt.imshow(rescale_for_display(images[i]))
        class_index = int(np.argmax(labels[i]))
        plt.title(f'Class: {CLASSES[class_index]}')
        plt.axis('off')
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Could not display sample images: {e}")

## 9. Model Definition (Image-Only with Fine-tuning)

In [None]:
def build_model_image_only(num_classes, img_size, base_model_name, dropout_rate, n_layers_to_unfreeze):
    """Build an image-only classifier on top of an ImageNet-pretrained backbone.

    Args:
        num_classes: Number of output classes (softmax units).
        img_size: Height and width of the square RGB input.
        base_model_name: ``'EfficientNetB0'`` or ``'ResNet50'``.
        dropout_rate: Dropout before the first dense layer; half this rate is
            used before the output layer.
        n_layers_to_unfreeze: How many of the backbone's last layers to make
            trainable; 0 keeps the whole backbone frozen. BatchNormalization
            layers are always kept frozen.

    Returns:
        An uncompiled ``tf.keras`` Model whose float32 output holds
        ``num_classes`` softmax probabilities.

    Raises:
        ValueError: If ``base_model_name`` is not supported.
    """
    image_input = layers.Input(shape=(img_size, img_size, 3), name='image_input', dtype=tf.float32)

    if base_model_name == 'EfficientNetB0':
        base_model_func = applications.EfficientNetB0
    elif base_model_name == 'ResNet50':
        base_model_func = applications.ResNet50
    else:
        raise ValueError(f"Unsupported base model: {base_model_name}")

    base_model_instance = base_model_func(input_shape=(img_size, img_size, 3),
                                          include_top=False,
                                          weights='imagenet')

    if n_layers_to_unfreeze > 0:
        base_model_instance.trainable = True
        print(f"Unfreezing the top {n_layers_to_unfreeze} layers of {base_model_name}.")
        for layer in base_model_instance.layers[:-n_layers_to_unfreeze]:
            layer.trainable = False
        # Keep every BatchNormalization layer frozen so its pretrained moving
        # statistics are not overwritten by small fine-tuning batches.
        for layer in base_model_instance.layers:
            if isinstance(layer, layers.BatchNormalization):
                layer.trainable = False
    else:
        base_model_instance.trainable = False
        print(f"Keeping all layers of {base_model_name} frozen.")

    # Always run the backbone in inference mode, as the Keras transfer-learning
    # guide recommends for fine-tuning. Passing training=True would permanently
    # activate any dropout inside the backbone (EfficientNet uses it for
    # stochastic depth), including during evaluate()/predict(), which makes
    # test-time predictions random. Unfrozen weights still receive gradients.
    x = base_model_instance(image_input, training=False)
    x = layers.GlobalAveragePooling2D(name='global_avg_pool')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate)(x)
    x = layers.Dense(128, activation='relu', name='dense_head_1')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(dropout_rate/2)(x)
    # float32 output keeps the softmax numerically safe under mixed precision.
    output = layers.Dense(num_classes, activation='softmax', name='output', dtype='float32')(x)

    model = models.Model(inputs=image_input, outputs=output)
    return model

## 10. Model Compilation

In [None]:

print("Building an image-only model with fine-tuning capabilities.")
model = build_model_image_only(
    num_classes=len(CLASSES),
    img_size=CONFIG['IMG_SIZE'],
    base_model_name=CONFIG['BASE_MODEL'],
    dropout_rate=CONFIG['DROPOUT_RATE'],
    n_layers_to_unfreeze=CONFIG['N_LAYERS_TO_UNFREEZE']
)

# Optimizer: AdamW when configured, plain Adam for any other value.
finetune_lr = CONFIG['FINETUNE_LEARNING_RATE']
if CONFIG['OPTIMIZER'] != 'AdamW':
    optimizer = tf.keras.optimizers.Adam(learning_rate=finetune_lr)
    print(f"Using Adam optimizer with LR={finetune_lr}")
else:
    optimizer = tf.keras.optimizers.AdamW(
        learning_rate=finetune_lr,
        weight_decay=CONFIG['WEIGHT_DECAY']
    )
    print(f"Using tf.keras.optimizers.AdamW with LR={finetune_lr} and Weight Decay={CONFIG['WEIGHT_DECAY']}")

# Loss: the custom focal loss when requested, categorical cross-entropy otherwise.
if CONFIG['LOSS_FUNCTION'] != 'focal_loss':
    loss_fn = 'categorical_crossentropy'
    print("Using Categorical Crossentropy loss.")
else:
    loss_fn = CategoricalFocalCrossentropy(
        alpha=CONFIG['FOCAL_LOSS_ALPHA'],
        gamma=CONFIG['FOCAL_LOSS_GAMMA'],
        from_logits=False  # the model's output layer already applies softmax
    )
    print(f"Using Custom Categorical Focal Loss with alpha={CONFIG['FOCAL_LOSS_ALPHA']}, gamma={CONFIG['FOCAL_LOSS_GAMMA']}")

compile_metrics = [
    'accuracy',
    tf.keras.metrics.Precision(name='precision'),
    tf.keras.metrics.Recall(name='recall'),
    tf.keras.metrics.AUC(name='auc'),
]
model.compile(optimizer=optimizer, loss=loss_fn, metrics=compile_metrics)
model.summary()

print("\nTrainable status of base model layers (first 5 and last 5 after unfreezing logic):")
base_model_layer_name = CONFIG['BASE_MODEL'].lower()  # e.g. 'efficientnetb0'
# The nested backbone is a Model-typed layer; match its name by substring because
# Keras may append a suffix such as '_1' when the model is rebuilt.
actual_base_model_layer_name = next(
    (
        candidate.name
        for candidate in model.layers
        if base_model_layer_name in candidate.name.lower() and isinstance(candidate, tf.keras.Model)
    ),
    None,
)

if not actual_base_model_layer_name:
    print(f"Could not find base model layer containing: {base_model_layer_name}")
else:
    base_model_from_model = model.get_layer(actual_base_model_layer_name)
    nested_layers = base_model_from_model.layers
    last_five_start = len(nested_layers) - 5
    for position, nested_layer in enumerate(nested_layers):
        if position < 5 or position >= last_five_start:
            print(f"Layer: {nested_layer.name}, Trainable: {nested_layer.trainable}")

## 11. Model Training

In [None]:

# Integer class ids follow the order of CLASSES, the same order the generators
# use because they are created with classes=CLASSES.
class_to_int_mapping = {classname: i for i, classname in enumerate(CLASSES)}
train_labels_int_mapped = train_df['diagnosis'].map(class_to_int_mapping).values

present_class_ids = np.unique(train_labels_int_mapped)
class_weights_calculated = class_weight.compute_class_weight(
    class_weight='balanced',
    classes=present_class_ids,
    y=train_labels_int_mapped
)
# Key each weight by its real class id (not its position among present classes),
# and give every class a key so Keras sees ids 0..len(CLASSES)-1 even if some
# class is absent from the training split (its weight is then irrelevant).
CONFIG['CLASS_WEIGHTS'] = {class_id: 1.0 for class_id in range(len(CLASSES))}
CONFIG['CLASS_WEIGHTS'].update(
    {int(class_id): float(weight)
     for class_id, weight in zip(present_class_ids, class_weights_calculated)}
)
print("Calculated Class Weights:", CONFIG['CLASS_WEIGHTS'])

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=CONFIG['EARLY_STOPPING_PATIENCE'],
    restore_best_weights=True,
    verbose=1
)
# NOTE(review): under Keras 3, ModelCheckpoint / model.save may require or prefer
# the `.keras` extension over `.h5` — confirm against the installed version.
model_checkpoint_callback = ModelCheckpoint(
    filepath=os.path.join(CONFIG['CHECKPOINT_DIR'], 'model_epoch_{epoch:02d}_valloss_{val_loss:.2f}.h5'),
    monitor='val_loss',
    save_best_only=True,
    mode='min',
    verbose=1
)
reduce_lr_on_plateau = ReduceLROnPlateau(
    monitor='val_loss',
    factor=CONFIG['REDUCE_LR_FACTOR'],
    patience=CONFIG['REDUCE_LR_PATIENCE'],
    min_lr=1e-7,
    verbose=1
)
callbacks_list = [early_stopping, model_checkpoint_callback, reduce_lr_on_plateau]

print("\nStarting model training...")
history = model.fit(
    train_generator,
    epochs=CONFIG['EPOCHS'],
    validation_data=val_generator,
    callbacks=callbacks_list,
    class_weight=CONFIG['CLASS_WEIGHTS'],
    # len() of a Keras image iterator is ceil(samples / batch_size), so the final
    # partial batch is included; floor division silently skipped those samples
    # (for validation, the same ones every epoch). max(1, ...) guards empty splits.
    steps_per_epoch=max(1, len(train_generator)),
    validation_steps=max(1, len(val_generator))
)

print(f"\nTraining finished. Saving final model to {CONFIG['MODEL_SAVE_PATH']}")
model.save(CONFIG['MODEL_SAVE_PATH'])

## 12. Model Evaluation

In [None]:

print('\nPlotting training history...')


def plot_history(training_history):
    """Plot train and validation curves, one subplot per metric, from a Keras History."""
    logs = training_history.history
    plt.figure(figsize=(14, 10))
    for position, metric in enumerate(['accuracy', 'loss', 'precision', 'recall', 'auc'], start=1):
        pretty_name = metric.capitalize()
        plt.subplot(3, 2, position)
        # Train curve first, then validation, each only if it was recorded.
        for history_key, prefix in ((metric, 'Train'), (f'val_{metric}', 'Val')):
            if history_key in logs:
                plt.plot(logs[history_key], label=f'{prefix} {pretty_name}')
        plt.title(f'{pretty_name} vs. Epochs')
        plt.xlabel('Epoch')
        plt.ylabel(pretty_name)
        plt.legend()
    plt.tight_layout()
    plt.show()


plot_history(history)

print('\nEvaluating model on test set...')
# len() of the iterator is ceil(samples / batch_size), covering the final partial batch.
test_steps = max(1, len(test_generator))

test_generator.reset()
# return_dict=True pairs each value with its metric name; zipping against
# model.metrics_names is unreliable across Keras versions.
eval_results = model.evaluate(test_generator, steps=test_steps, verbose=1, return_dict=True)
for name, val in eval_results.items():
    print(f"Test {name}: {val:.4f}")

test_generator.reset()
y_pred_probs = model.predict(test_generator, steps=test_steps, verbose=1)
y_pred = np.argmax(y_pred_probs, axis=1)
# The iterator's .classes is not guaranteed to be a NumPy array; convert so the
# element-wise comparisons used for the ROC curves below work.
y_true = np.asarray(test_generator.classes)

# Align lengths defensively in case predict() returned a different number of rows.
n_common = min(len(y_true), len(y_pred))
y_true = y_true[:n_common]
y_pred = y_pred[:n_common]
y_pred_probs = y_pred_probs[:n_common, :]

# Pass every class id explicitly so the report and matrix keep all len(CLASSES)
# rows/columns even when a class is missing from the test split or predictions.
class_ids = np.arange(len(CLASSES))

print('\nClassification Report:')
print(classification_report(y_true, y_pred, labels=class_ids, target_names=CLASSES, zero_division=0))

print('\nConfusion Matrix:')
cm = confusion_matrix(y_true, y_pred, labels=class_ids)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=CLASSES, yticklabels=CLASSES)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

print('\nROC Curves:')
plt.figure(figsize=(12, 10))
for i, class_name in enumerate(CLASSES):
    y_true_binarized_for_class = (y_true == i).astype(int)
    n_positive = int(y_true_binarized_for_class.sum())
    # A ROC curve is undefined without both positive and negative samples.
    if n_positive == 0 or n_positive == len(y_true_binarized_for_class):
        print(f'Skipping ROC curve for {class_name}: test set lacks positive or negative samples.')
        continue
    fpr, tpr, _ = roc_curve(y_true_binarized_for_class, y_pred_probs[:, i])
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, lw=2, label=f'{class_name} (AUC = {roc_auc:.2f})')

plt.plot([0, 1], [0, 1], 'k--', lw=2, label='Random Classifier')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic (ROC) Curve for Each Class')
plt.legend(loc='lower right')
plt.grid(True)
plt.show()