Skip to content

Commit

Permalink
adding spurious correlation workflow.ipynb tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
allincowell committed Jun 24, 2024
1 parent e1d79cf commit 2b01057
Showing 1 changed file with 205 additions and 0 deletions.
205 changes: 205 additions & 0 deletions docs/source/tutorials/datalab/workflows.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1326,6 +1326,211 @@
"assert all(class_imbalance_issues.query(\"is_class_imbalance_issue\")[\"class_imbalance_score\"] == 0.02), \"Class imbalance issue scores are not as expected\"\n",
"assert all(class_imbalance_issues.query(\"not is_class_imbalance_issue\")[\"class_imbalance_score\"] == 1.0), \"Class imbalance issue scores are not as expected\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Find Spurious Correlation between Vision Dataset features and class labels\n",
"\n",
"In this section, we demonstrate how to identify spurious correlations in a vision dataset using the `cleanlab` library. Spurious correlations are unintended associations in the data that do not reflect the true underlying relationships, potentially leading to misleading model predictions and poor generalization.\n",
"\n",
"We will utilize the `Datalab` class from cleanlab with the `image_key` attribute to pinpoint vision-specific issues such as `dark_score`, `blurry_score`, `odd_aspect_ratio_score`, and more in the dataset. By analyzing these correlations, we can understand their impact on model performance and take steps to enhance the robustness and reliability of our machine learning models."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Load the dataset\n",
"\n",
"We will demonstrate this workflow using the CIFAR-10 dataset by selecting 100 images from two random classes. To illustrate the impact of spurious correlations between image features and class labels, we will showcase how altering all images of a class, such as darkening them, significantly reduces the `dark_score`. This demonstrates the strong correlation detection of darkness within the dataset.\n",
"\n",
"Similarly, we can observe significant reductions in `blurry_score` and `odd_aspect_ratio_score` when one of the classes contains images with corresponding characteristics such as blurriness or an unusual aspect ratio between width and height."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from cleanlab import Datalab\n",
"from torchvision.datasets import CIFAR10\n",
"from datasets import Dataset\n",
"import io\n",
"from PIL import Image, ImageEnhance\n",
"import random\n",
"import numpy as np\n",
"from IPython.display import display, Markdown\n",
"\n",
"# Download the CIFAR-10 test dataset\n",
"data = CIFAR10(root='./data', train=False, download=True)\n",
"\n",
"# Set seed for reproducibility\n",
"np.random.seed(0)\n",
"random.seed(0)\n",
"\n",
"# Randomly select two classes\n",
"classes = list(range(len(data.classes)))\n",
"selected_classes = random.sample(classes, 2)\n",
"\n",
"# Function to convert PIL object to PNG image to be passed to the Datalab object\n",
"def convert_to_png_image(image):\n",
" buffer = io.BytesIO()\n",
" image.save(buffer, format='PNG')\n",
" buffer.seek(0)\n",
" return Image.open(buffer)\n",
"\n",
"# Generating 100 ('max_num_images') images from each of the two chosen classes\n",
"max_num_images = 100\n",
"list_images, list_labels = [], []\n",
"num_images = {selected_classes[0]: 0, selected_classes[1]: 0}\n",
"\n",
"for img, label in data:\n",
" if num_images[selected_classes[0]] == max_num_images and num_images[selected_classes[1]] == max_num_images:\n",
" break\n",
" if label in selected_classes:\n",
" if num_images[label] == max_num_images:\n",
" continue\n",
" list_images.append(convert_to_png_image(img))\n",
" list_labels.append(label)\n",
" num_images[label] += 1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Creating `Dataset` object to be passed to the `Datalab` object to find vision-related issues"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a datasets.Dataset object from list of images and their corresponding labels\n",
"dataset_dict = {'image': list_images, 'label': list_labels}\n",
"dataset = Dataset.from_dict(dataset_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. (Optional) Creating a transformed dataset using `ImageEnhance` to induce darkness"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to reduce brightness to 30%\n",
"def apply_dark(image):\n",
" \"\"\"Decreases brightness of the image.\"\"\"\n",
" enhancer = ImageEnhance.Brightness(image)\n",
" return enhancer.enhance(0.3)\n",
"\n",
"# Applying the darkness filter to one of the classes\n",
"transformed_list_images = [\n",
" apply_dark(img) if label == selected_classes[0] else img\n",
" for label, img in zip(list_labels, list_images)\n",
"]\n",
"\n",
"# Creating datasets.Dataset object from the transformed dataset\n",
"transformed_dataset_dict = {'image': transformed_list_images, 'label': list_labels}\n",
"transformed_dataset = Dataset.from_dict(transformed_dataset_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. (Optional) Visualizing Images in the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_images(dataset_dict):\n",
" \"\"\"Plots the first 15 images from the dataset dictionary.\"\"\"\n",
" images = dataset_dict['image']\n",
" labels = dataset_dict['label']\n",
" \n",
" # Define the number of images to plot\n",
" num_images_to_plot = 15\n",
" num_cols = 5 # Number of columns in the plot grid\n",
" num_rows = (num_images_to_plot + num_cols - 1) // num_cols # Calculate rows needed\n",
" \n",
" # Create a figure\n",
" fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 6))\n",
" axes = axes.flatten()\n",
" \n",
" # Plot each image\n",
" for i in range(num_images_to_plot):\n",
" img = images[i]\n",
" label = labels[i]\n",
" axes[i].imshow(img)\n",
" axes[i].set_title(f'Label: {label}')\n",
" axes[i].axis('off')\n",
" \n",
" # Hide any remaining empty subplots\n",
" for i in range(num_images_to_plot, len(axes)):\n",
" axes[i].axis('off')\n",
" \n",
" # Show the plot\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"plot_images(dataset_dict)\n",
"plot_images(transformed_dataset_dict)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Finding image-specific property scores"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to find image-specific property scores given the dataset object\n",
"def get_property_scores(dataset):\n",
" lab = Datalab(data=dataset, label_name=\"label\", image_key=\"image\")\n",
" lab.find_issues()\n",
" return lab._spurious_correlation()\n",
"\n",
"# Finds specific property score in the dataframe containing property scores \n",
"def get_specific_property_score(property_scores_df, property_name):\n",
" return property_scores_df[property_scores_df['property'] == property_name]['score'].iloc[0]\n",
"\n",
"# Finding scores in original and transformed dataset\n",
"standard_property_scores = get_property_scores(dataset)\n",
"transformed_property_scores = get_property_scores(transformed_dataset)\n",
"\n",
"# Displaying the scores dataframe\n",
"display(Markdown(\"### Vision-specific property scores in the original dataset\"))\n",
"display(standard_property_scores)\n",
"display(Markdown(\"### Vision-specific property scores in the transformed dataset\"))\n",
"display(transformed_property_scores)\n",
"\n",
"# Smaller 'dark_score' value for modified dataframe shows strong correlation with the class labels in the transformed dataset\n",
"assert get_specific_property_score(standard_property_scores, 'dark_score') > get_specific_property_score(transformed_property_scores, 'dark_score')"
]
}
],
"metadata": {
Expand Down

0 comments on commit 2b01057

Please sign in to comment.