LIME explanations for a ResNet-18 image classifier

In [1]:
from google.colab import drive
drive.mount('/content/drive')
import zipfile
import os
!pip install lime

Mounted at /content/drive
Collecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lime
  Building wheel for lime (setup.py) ... [?25l[?25hdone
  Created wheel for lime: filename=lime-0.2.0.1-py3-none-any.whl size=283834 sha256=b4df7158c8441c43697f29acac299712f152cff7292823c1c2d1e5d628ef61aa
  Stored in directory: /root/.cache/pip/wheels/85/fa/a3/9c2d44c9f3cd77cf4e533b58900b2bf4487f2a17e8ec212a3d
Successfully built lime
Installing collected packages: lime
Successfully installed lime-0.2.0.1


In [2]:
# Project root on the mounted Google Drive; every dataset folder below lives under it.
# Change this single value if the project is moved.
PROJECT_DIR = "/content/drive/MyDrive/TESE/Project"

# Femur images, cesarean class (test / train folders)
TESTE_femur_cesarean_test = f"{PROJECT_DIR}/TEST_cv1_femur_cesarean"
TESTE_femur_cesarean_train = f"{PROJECT_DIR}/TRAIN_cv1_femur_cesarean"

# Femur images, vaginal class (test / train folders)
TESTE_femur_vaginal_test = f"{PROJECT_DIR}/TEST_cv1_femur_vaginal"
TESTE_femur_vaginal_train = f"{PROJECT_DIR}/TRAIN_cv1_femur_vaginal"

In [3]:
# Sanity check: confirm the cesarean training folder is reachable on the mounted Drive.
# isdir (rather than exists) ensures os.listdir below is only ever called on a directory;
# a regular file at that path is reported as "not found" instead of raising.
if os.path.isdir(TESTE_femur_cesarean_train):
    print(f"Dataset found at {TESTE_femur_cesarean_train}")
    print("Listing files in dataset directory:\n")
    print(os.listdir(TESTE_femur_cesarean_train))  # List files and folders
else:
    print(f"Dataset not found at {TESTE_femur_cesarean_train}")

Dataset found at /content/drive/MyDrive/TESE/Project/TRAIN_cv1_femur_cesarean
Listing files in dataset directory:

['PU3002681_femur1_802.png', 'PU1003661_femur1_4.png', 'PU2004178_femur1_85.png', 'PU2000870_femur1_711.png', 'PU7002074_femur1_1017.png']


In [4]:
from lime import lime_image
from skimage.segmentation import mark_boundaries
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import torch
from torchvision import transforms, models
import torch.nn as nn

In [5]:
# Select the compute device: the GPU when PyTorch can see one, the CPU otherwise
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [6]:
# Build an untrained ResNet-18; its weights are overwritten by the checkpoint loaded below
model = models.resnet18(weights=None)

# Replace the stem convolution so the network takes 1-channel (grayscale) input
model.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

# Replace the classification head with a 2-output layer (binary classification)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)

# Move the model to the selected device (CPU or GPU)
model.to(device)

# Load the saved weights. The file is consumed directly as a state dict, so
# weights_only=True is sufficient and avoids unpickling arbitrary Python objects.
model_path = "/content/drive/MyDrive/TESE/Project/best-model_cv1_femur_resnet18.pth"
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)

# Switch to evaluation mode (important for inference); as the last expression of
# the cell this also displays the model architecture.
model.eval()

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [7]:
# Preprocessing (grayscale, 80x80)
def image_to_tensor(img):
    """Run the model's input preprocessing on a single image.

    The image goes through three torchvision transforms in order:
    conversion to one grayscale channel, resizing to 80x80, and
    conversion to a tensor.

    Args:
        img: Image accepted by the torchvision transforms below
            (callers in this notebook pass a PIL image).

    Returns:
        The output of the transform pipeline for ``img``.
    """
    steps = [
        transforms.Grayscale(num_output_channels=1),  # single channel, matching the model's conv1
        transforms.Resize((80, 80)),                   # fixed 80x80 input size
        transforms.ToTensor(),
    ]
    pipeline = transforms.Compose(steps)
    tensor = pipeline(img)
    return tensor

In [8]:
# Batch prediction wrapper for LIME

def batch_predict(imgs, verbose=True):
    """Classifier function handed to LIME: image arrays in, class probabilities out.

    Each array is turned into a PIL image, converted to grayscale, preprocessed
    with ``image_to_tensor`` and stacked into one batch that is run through the
    global ``model`` without gradient tracking.

    Args:
        imgs: Iterable of image arrays accepted by ``PIL.Image.fromarray``
            (presumably the uint8 RGB perturbations LIME generates; confirm).
        verbose: When True (the default, matching the original behaviour) the
            probabilities of every batch are printed. Pass False, e.g. with
            ``functools.partial(batch_predict, verbose=False)``, to keep long
            runs from flooding the cell output.

    Returns:
        NumPy array with one row per input image and one column per class,
        holding the softmax probabilities.
    """
    model.eval()
    tensors = [image_to_tensor(Image.fromarray(img).convert("L")) for img in imgs]
    batch = torch.stack(tensors).to(device)

    # Inference only: no gradients needed
    with torch.no_grad():
        logits = model(batch)
        probs = torch.nn.functional.softmax(logits, dim=1)

    # LIME works on NumPy arrays, so bring the result back to the CPU
    np_probs = probs.cpu().numpy()

    if verbose:
        print("Perturbation probabilities:", np_probs)

    return np_probs

In [9]:
# LIME explanation function

def calculate_lime(img_path, label_to_explain, num_features=5, num_samples=1000): #1000 perturbation samples/image
    """Compute a LIME explanation for one image file.

    The image is read as grayscale and explained with ``batch_predict`` as
    the classifier function.

    Args:
        img_path: Path of the image to explain.
        label_to_explain: Class index forwarded to LIME as ``labels``.
            NOTE(review): ``top_labels=1`` is passed as well, and LIME appears
            to explain the top predicted label instead of ``labels`` when
            ``top_labels`` is set, so this argument likely has no effect.
            Confirm against the LIME documentation before relying on it.
        num_features: Number of superpixels to use in the explanation.
        num_samples: Number of perturbed images LIME generates.

    Returns:
        The explanation object returned by ``explain_instance``.
    """
    seed = 42

    # Seed the explainer itself too: random_seed below is only passed to
    # explain_instance, so without random_state the perturbation sampling
    # would differ from run to run.
    explainer = lime_image.LimeImageExplainer(random_state=seed)

    # Read the pixels inside a context manager so the file handle is released promptly
    with Image.open(img_path) as img:
        img_np = np.array(img.convert("L"))

    explanation = explainer.explain_instance(
        img_np,
        batch_predict,
        top_labels=1,
        labels=(label_to_explain,),
        num_features=num_features, # how many superpixels to highlight in the explanation
        hide_color=0,
        num_samples=num_samples,
        random_seed=seed
    )
    return explanation

In [10]:
# Visualization

def plot_images_lime(paths, class_names, num_features=5, num_samples=1000, columns=4, rows=3):
    """Show LIME overlays for up to ``columns * rows`` images in one figure.

    For each image the top predicted class is explained; only the superpixels
    that support that class are kept (the rest is hidden) and their boundaries
    are drawn on the image.

    Args:
        paths: Sequence of image file paths; only the first
            ``columns * rows`` entries are plotted.
        class_names: Sequence mapping a class index to its display name.
        num_features: Number of superpixels highlighted per image.
        num_samples: Number of perturbed images LIME generates per image.
        columns: Number of subplot columns.
        rows: Number of subplot rows.
    """
    fig = plt.figure(figsize=(16, 10))
    shown_paths = paths[:columns * rows]

    for position, img_path in enumerate(shown_paths, start=1):
        # Pass the settings by keyword: calculate_lime's second positional
        # parameter is label_to_explain, so positional passing would shift
        # num_features/num_samples into the wrong slots.
        # The label shown is read back from explanation.top_labels below;
        # 0 is supplied only as a valid class index for the required argument.
        explanation = calculate_lime(
            img_path,
            0,
            num_features=num_features,
            num_samples=num_samples,
        )

        predicted_class = explanation.top_labels[0]
        temp, mask = explanation.get_image_and_mask(
            predicted_class,
            positive_only=True,
            num_features=num_features,
            hide_rest=True
        )

        # Scale to [0, 1] before drawing the superpixel boundaries
        overlay = mark_boundaries(temp / 255.0, mask)
        pred_class_name = class_names[predicted_class]

        ax = fig.add_subplot(rows, columns, position)
        ax.imshow(overlay)
        ax.set_title(f"Predicted: {pred_class_name}", fontsize=10)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [11]:
# Save LIME Overlayed Images

def save_lime_outputs(image_paths, output_folder, class_names, num_features=5, num_samples=1000):
    """Explain every image with LIME and write the overlay images to disk.

    For each path the image is read as grayscale, the top predicted class is
    explained, the boundaries of the supporting superpixels are drawn on the
    image, and the result is written to ``output_folder`` as
    ``lime_<predicted class name>_<original file name>``.

    Args:
        image_paths: Iterable of image file paths.
        output_folder: Destination directory; created if missing.
        class_names: Sequence mapping a class index to its display name.
        num_features: Number of superpixels highlighted per image.
        num_samples: Number of perturbed images LIME generates per image.
    """
    os.makedirs(output_folder, exist_ok=True)
    seed = 42

    for img_path in image_paths:
        img_name = os.path.basename(img_path)

        # Read the pixels inside a context manager so the file handle is released promptly
        with Image.open(img_path) as image:
            image_np = np.array(image.convert("L"))

        # A freshly seeded explainer per image keeps each explanation
        # reproducible regardless of the order in which images are processed
        # (random_seed below is only passed to explain_instance).
        explainer = lime_image.LimeImageExplainer(random_state=seed)

        # Get LIME explanation using the model's prediction
        explanation = explainer.explain_instance(
            image_np,
            batch_predict,
            top_labels=1, #select the top predicted class
            num_features=num_features, # how many superpixels to highlight in the explanation
            hide_color=0,# when hiding superpixels, replace them with black
            num_samples=num_samples, # how many perturbed versions of the image
            random_seed=seed
        )

        pred_class = explanation.top_labels[0]
        pred_class_name = class_names[pred_class]

        # Generate overlay (the full image is kept visible: hide_rest=False)
        temp, mask = explanation.get_image_and_mask(
            pred_class,
            positive_only=True,
            num_features=num_features,
            hide_rest=False
        )

        # mark_boundaries works on [0, 1] floats; convert back to 8-bit and
        # to BGR channel order, which is what cv2.imwrite expects
        overlay = mark_boundaries(temp / 255.0, mask)
        overlay_bgr = (overlay * 255).astype(np.uint8)
        overlay_bgr = cv2.cvtColor(overlay_bgr, cv2.COLOR_RGB2BGR)

        # Save image
        output_name = f"lime_{pred_class_name}_{img_name}"
        output_path = os.path.join(output_folder, output_name)

        # cv2.imwrite signals failure through its return value rather than an
        # exception, so check it instead of reporting success unconditionally
        if cv2.imwrite(output_path, overlay_bgr):
            print(f"[LIME] Saved → {output_name}")
        else:
            print(f"[LIME] Failed to save → {output_name}")


In [12]:
# Destination for the LIME overlays, and the image folder they are computed from
output_folder = "/content/drive/MyDrive/TESE/Project/LIME_cv1_femur_train_vaginalPROBS"
input_folder = TESTE_femur_vaginal_train

# Make sure the destination exists before any explanation is computed
os.makedirs(output_folder, exist_ok=True)

# Define the create_path_list function to get image paths
def create_path_list(folder):
    """Creates a sorted list of image paths from a given folder.

    Args:
        folder: Directory to scan (not recursive).

    Returns:
        List of ``folder``-joined paths for every entry whose name ends in
        .png, .jpg or .jpeg (matched case-insensitively), sorted by file name
        so the processing order is the same on every run.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    return [
        os.path.join(folder, filename)
        for filename in sorted(os.listdir(folder))
        # lower() so files such as "scan.PNG" or "scan.JPG" are not skipped
        if filename.lower().endswith(image_extensions)
    ]

# Collect every image found in the input folder
image_paths = create_path_list(input_folder)

# Explain each image with LIME and write the overlays to the output folder
save_lime_outputs(
    image_paths,
    output_folder,
    ["Cesarean Birth", "Vaginal Birth"],
    num_features=5,
    num_samples=1000,
)

  0%|          | 0/1000 [00:00<?, ?it/s]

Perturbation probabilities: [[0.05684144 0.94315857]
 [0.17807236 0.8219276 ]
 [0.08697952 0.91302055]
 [0.085535   0.914465  ]
 [0.13137056 0.8686294 ]
 [0.22042185 0.77957815]
 [0.16153961 0.8384603 ]
 [0.22833672 0.7716633 ]
 [0.06141793 0.93858206]
 [0.13799037 0.86200964]]
Perturbation probabilities: [[0.13354677 0.8664533 ]
 [0.08233905 0.917661  ]
 [0.11121108 0.8887889 ]
 [0.10074715 0.89925283]
 [0.12324786 0.8767521 ]
 [0.10097037 0.8990296 ]
 [0.09188507 0.9081149 ]
 [0.04619442 0.9538055 ]
 [0.18886998 0.81113005]
 [0.3034887  0.6965113 ]]
Perturbation probabilities: [[0.34885806 0.65114194]
 [0.26021534 0.7397846 ]
 [0.12050582 0.87949413]
 [0.04895175 0.95104825]
 [0.1706701  0.8293299 ]
 [0.18254465 0.8174554 ]
 [0.13405052 0.8659495 ]
 [0.16899662 0.8310033 ]
 [0.05927887 0.94072115]
 [0.06493101 0.935069  ]]
Perturbation probabilities: [[0.22146994 0.77853   ]
 [0.09062378 0.9093762 ]
 [0.10312114 0.89687884]
 [0.18233068 0.81766933]
 [0.2141098  0.7858902 ]
 [0.054552

  0%|          | 0/1000 [00:00<?, ?it/s]

Perturbation probabilities: [[0.10003983 0.89996016]
 [0.05902371 0.9409763 ]
 [0.12314241 0.8768576 ]
 [0.08434531 0.91565466]
 [0.1640956  0.83590436]
 [0.07361082 0.9263892 ]
 [0.05398247 0.9460175 ]
 [0.11565912 0.8843409 ]
 [0.1875901  0.8124099 ]
 [0.10516619 0.8948338 ]]
Perturbation probabilities: [[0.14649428 0.8535058 ]
 [0.09004719 0.90995276]
 [0.09829082 0.9017092 ]
 [0.15840612 0.84159386]
 [0.15593225 0.84406775]
 [0.14736663 0.8526334 ]
 [0.08397073 0.91602933]
 [0.20724006 0.7927599 ]
 [0.18418488 0.81581515]
 [0.09557641 0.90442353]]
Perturbation probabilities: [[0.0851552  0.9148448 ]
 [0.07705485 0.92294514]
 [0.09448804 0.9055119 ]
 [0.10522656 0.8947734 ]
 [0.14230031 0.85769963]
 [0.06630916 0.93369085]
 [0.13345397 0.8665461 ]
 [0.07593336 0.92406666]
 [0.1496248  0.85037524]
 [0.08810107 0.9118989 ]]
Perturbation probabilities: [[0.10615153 0.8938485 ]
 [0.0707101  0.92928994]
 [0.12554929 0.8744507 ]
 [0.06436403 0.935636  ]
 [0.09294876 0.90705127]
 [0.141806

  0%|          | 0/1000 [00:00<?, ?it/s]

Perturbation probabilities: [[0.01503432 0.9849657 ]
 [0.07566646 0.9243336 ]
 [0.04449498 0.955505  ]
 [0.06720439 0.9327955 ]
 [0.04208415 0.95791584]
 [0.11594214 0.8840579 ]
 [0.07178959 0.92821044]
 [0.07229    0.92771   ]
 [0.10967791 0.89032215]
 [0.11645386 0.8835462 ]]
Perturbation probabilities: [[0.11768028 0.88231975]
 [0.08871762 0.9112824 ]
 [0.0543988  0.9456012 ]
 [0.07809135 0.9219086 ]
 [0.09067152 0.90932846]
 [0.08107477 0.9189253 ]
 [0.1010798  0.8989202 ]
 [0.043359   0.956641  ]
 [0.10251712 0.8974829 ]
 [0.1804967  0.8195033 ]]
Perturbation probabilities: [[0.23321228 0.76678765]
 [0.11925617 0.88074386]
 [0.04239617 0.9576039 ]
 [0.09254323 0.9074567 ]
 [0.07153433 0.9284656 ]
 [0.13314456 0.86685544]
 [0.06375065 0.9362494 ]
 [0.0883073  0.91169274]
 [0.10793193 0.8920681 ]
 [0.14354183 0.8564582 ]]
Perturbation probabilities: [[0.12302691 0.8769731 ]
 [0.07103609 0.9289639 ]
 [0.0925492  0.90745085]
 [0.16610545 0.83389455]
 [0.07594337 0.9240566 ]
 [0.090974

  0%|          | 0/1000 [00:00<?, ?it/s]

Perturbation probabilities: [[0.05108616 0.9489138 ]
 [0.09452842 0.90547156]
 [0.09280409 0.90719587]
 [0.11614529 0.88385475]
 [0.09297341 0.9070266 ]
 [0.10266788 0.89733213]
 [0.24204174 0.7579583 ]
 [0.2101205  0.7898795 ]
 [0.13957126 0.86042875]
 [0.12440449 0.8755955 ]]
Perturbation probabilities: [[0.12644805 0.8735519 ]
 [0.11493052 0.8850695 ]
 [0.23498045 0.7650196 ]
 [0.22927111 0.7707289 ]
 [0.07103945 0.9289605 ]
 [0.1296983  0.87030166]
 [0.11734046 0.8826595 ]
 [0.3032802  0.69671977]
 [0.14553057 0.8544694 ]
 [0.18165979 0.8183402 ]]
Perturbation probabilities: [[0.12474068 0.87525934]
 [0.06727149 0.9327285 ]
 [0.10894009 0.8910599 ]
 [0.14780235 0.8521977 ]
 [0.23266593 0.76733404]
 [0.17863499 0.821365  ]
 [0.10163959 0.8983604 ]
 [0.11847205 0.88152796]
 [0.17089382 0.8291062 ]
 [0.11131571 0.8886843 ]]
Perturbation probabilities: [[0.10946792 0.8905321 ]
 [0.1000367  0.89996326]
 [0.0767604  0.9232395 ]
 [0.1571508  0.84284914]
 [0.1625266  0.83747345]
 [0.171325

  0%|          | 0/1000 [00:00<?, ?it/s]

Perturbation probabilities: [[0.17007048 0.82992953]
 [0.09215036 0.90784967]
 [0.30515423 0.6948458 ]
 [0.22045812 0.77954185]
 [0.1284978  0.8715022 ]
 [0.17923784 0.82076216]
 [0.28895092 0.7110491 ]
 [0.13597488 0.8640251 ]
 [0.18094216 0.8190579 ]
 [0.223918   0.776082  ]]
Perturbation probabilities: [[0.1909616  0.80903834]
 [0.2298806  0.7701194 ]
 [0.12918296 0.870817  ]
 [0.16353737 0.8364627 ]
 [0.21512346 0.7848765 ]
 [0.22374922 0.77625084]
 [0.17978306 0.82021695]
 [0.19019017 0.80980986]
 [0.3045494  0.6954506 ]
 [0.20854558 0.79145443]]
Perturbation probabilities: [[0.19426139 0.80573857]
 [0.21563642 0.7843636 ]
 [0.20115831 0.7988417 ]
 [0.21948056 0.7805195 ]
 [0.28689155 0.7131085 ]
 [0.42254826 0.57745177]
 [0.3661379  0.63386214]
 [0.3367017  0.6632983 ]
 [0.18300362 0.81699634]
 [0.21168953 0.78831047]]
Perturbation probabilities: [[0.11974648 0.8802535 ]
 [0.26883098 0.731169  ]
 [0.20000997 0.79999   ]
 [0.27514657 0.7248534 ]
 [0.26663828 0.7333617 ]
 [0.230990