### Rule-Based Explanation

In [1]:
import pathlib
import os
import sys
from pathlib import Path
import tarfile
import random

# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import numpy as np

import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torchvision.transforms as T
from tqdm import tqdm

from PIL import Image

torch.backends.cudnn.benchmark = True

  from .autonotebook import tqdm as notebook_tqdm


# Params

In [2]:
# Class names of the two-way classification problem; the length of this list
# is used as the number of model outputs when the classifier is built.
CLASSES = ['Healthy', 'OC Degeneration']

# NOTE(review): not read by any cell visible here — presumably toggles loading
# of saved weights; confirm or remove.
load_checkpoints = True

# Name of the sub-folder used when the model output path is assembled.
# NOTE(review): the value is 'mobilenet', yet the cells below load a ResNet-50
# checkpoint from inside that folder — confirm the label is intended.
modeltype = 'mobilenet'

# Dataset folder that provides the train / validation / test splits.
ds = 'sick_ones_bendbias_v3_2class_normal'
# Dataset folder whose test split is loaded as the additional evaluation set.
eval_ds = 'sick_ones_bendbias_v3_2class_variation'

# Setup and Load Datasets

In [3]:
# Root folder for saved models, expressed relative to the working directory.
relative_model_path = "two4two_sickones_models_pytorch"
base_path = Path('.').joinpath(relative_model_path)
base_path

PosixPath('two4two_sickones_models_pytorch')

In [4]:
def load_dataframe(data_dir, dataset):
  """Read the parameter table of one dataset split.

  Args:
    data_dir: Path-like folder that contains one sub-folder per split.
    dataset: Name of the split sub-folder (joined onto ``data_dir``).

  Returns:
    A DataFrame built from ``<data_dir>/<dataset>/parameters.jsonl`` (one JSON
    object per line) with two derived columns: ``filename`` (``id`` followed
    by ``.png``) and ``ill`` (cast to int, then to str).
  """
  split_dir = data_dir / dataset
  records = pd.read_json(split_dir / 'parameters.jsonl', lines=True)

  return records.assign(
      filename=records['id'] + '.png',
      ill=records['ill'].astype(int).astype(str),
  )

class ImageDataset(Dataset):
    """Dataset that pairs image files on disk with the labels of a DataFrame.

    Each item is ``(image, label)``: the image is opened from
    ``data_dir/<filename>`` and converted to RGB, the label is the row's
    ``ill`` value converted to ``int``. If a truthy ``transform`` was given it
    is applied to the image before it is returned.
    """

    def __init__(self, df, data_dir, transform=None):
        self.df, self.data_dir, self.transform = df, data_dir, transform

    def __len__(self):
        # One sample per DataFrame row.
        return len(self.df)

    def __getitem__(self, idx):
        # Rows are addressed by position, not by index label.
        file_name = str(self.df.iloc[idx]['filename'])
        img_path = os.path.join(str(self.data_dir), file_name)
        image = Image.open(img_path).convert('RGB')
        label = int(self.df.iloc[idx]['ill'])

        if not self.transform:
            return image, label
        return self.transform(image), label

## Load Dataset and Dataloaders


In [5]:

def download_file(url, file_name, cache_dir="data", extract=True, force_download=False, archive_folder=None):
    """Download ``url`` into ``cache_dir`` (once) and optionally unpack it.

    Args:
        url: Address of the file to fetch.
        file_name: Name under which the download is stored in ``cache_dir``.
        cache_dir: Folder for the download and its extracted content; created
            if it does not exist.
        extract: If True, the file is opened as a gzip-compressed tar archive
            and unpacked into ``cache_dir``.
        force_download: If True, download even when the file already exists.
        archive_folder: Optional sub-folder of ``cache_dir`` to return.

    Returns:
        pathlib.Path: ``cache_dir / archive_folder`` when ``archive_folder``
        is given (with ``extract=False`` only if that folder already exists),
        otherwise ``cache_dir``.
    """
    # Ensure the cache directory exists
    os.makedirs(cache_dir, exist_ok=True)
    cache_root = Path(cache_dir)
    file_path = os.path.join(cache_dir, file_name)

    # Download the file unless a cached copy can be reused
    if force_download or not os.path.exists(file_path):
        torch.hub.download_url_to_file(url, file_path)
        print(f"File downloaded to: {file_path}")
    else:
        print(f"File already exists at: {file_path}")

    if extract:
        with tarfile.open(file_path, "r:gz") as tar:
            # The archive comes from the network: use the 'data' extraction
            # filter where this Python version provides it, so members cannot
            # escape ``cache_dir``.
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=cache_dir, filter="data")
            else:
                tar.extractall(path=cache_dir)
        print(f"File extracted to: {cache_dir}")
        return cache_root / archive_folder if archive_folder is not None else cache_root

    # Without extraction, only point at the archive folder if it is really there.
    # (Previously this branch read a misspelled attribute and raised.)
    if archive_folder is not None and (cache_root / archive_folder).exists():
        return cache_root / archive_folder
    return cache_root

In [6]:
# Fetch the dataset archive (or reuse the cached copy) and unpack it under ./data.
data_dir = download_file(
    url="https://uni-bielefeld.sciebo.de/s/2BgY19ixIaEUOmS/download",
    file_name="two4two_datasets.tar.gz",
    cache_dir='data',
    extract=True,
    force_download=False,
    archive_folder='two4two_datasets',
)

# Folder of the main dataset and folder of the evaluation dataset.
ds_dir, eval_ds_dir = data_dir / ds, data_dir / eval_ds
ds_dir, eval_ds_dir

File already exists at: data/two4two_datasets.tar.gz
File extracted to: data


(PosixPath('data/two4two_datasets/sick_ones_bendbias_v3_2class_normal'),
 PosixPath('data/two4two_datasets/sick_ones_bendbias_v3_2class_variation'))

In [7]:
# Un-normalised view of the training split, used only to measure the
# per-channel statistics that the normalising transform needs later.
train_df = load_dataframe(ds_dir, 'train')
train_transforms = T.Compose([T.ToTensor()])
train_dataset = ImageDataset(train_df, ds_dir / 'train', transform=train_transforms)
dataloader = DataLoader(train_dataset, batch_size=100, shuffle=True,
                        num_workers=6, pin_memory=True)

# Pass 1: accumulate per-channel pixel sums and the pixel count per channel.
mean = torch.zeros(3)  # one entry per RGB channel
total_pixels = 0
for images, _ in tqdm(dataloader):
    mean += images.sum(dim=(0, 2, 3))
    # batch * height * width, i.e. every element except the channel axis
    total_pixels += images.numel() // images.size(1)
mean /= total_pixels

print(f"Mean per channel: {mean}")

# Pass 2: accumulate squared deviations from the channel means.
std = torch.zeros(3)
channel_mean = mean.view(1, 3, 1, 1)  # broadcastable against NCHW batches
for images, _ in tqdm(dataloader):
    std += (images - channel_mean).pow(2).sum(dim=(0, 2, 3))
std = torch.sqrt(std / total_pixels)

print(f"Standard Deviation per channel: {std}")

100%|██████████| 400/400 [00:20<00:00, 19.48it/s]


Mean per channel: tensor([0.8068, 0.7830, 0.8005])


100%|██████████| 400/400 [00:14<00:00, 28.00it/s]

Standard Deviation per channel: tensor([0.1093, 0.1136, 0.1029])





In [8]:
# Parameter tables: three splits of the main dataset plus the test split of
# the evaluation dataset.
train_df, val_df, test_df = (
    load_dataframe(ds_dir, split) for split in ('train', 'validation', 'test')
)
eval_df = load_dataframe(eval_ds_dir, 'test')

In [9]:
# Row counts in the order: validation, test, evaluation, train.
tuple(len(frame) for frame in (val_df, test_df, eval_df, train_df))

(1000, 3000, 3000, 40000)

In [10]:
# Model preprocessing: convert to a tensor, then standardise each channel with
# the statistics measured on the training split.
to_tensor_step = T.ToTensor()
normalize_step = T.Normalize(mean=mean, std=std)
transform = T.Compose([to_tensor_step, normalize_step])

In [11]:
# Trim surrounding whitespace from the test filenames, then report any name
# that still contains a newline or a space.
test_df['filename'] = test_df['filename'].astype(str).str.strip()
suspicious = [
    (row, name)
    for row, name in enumerate(test_df['filename'])
    if any(ch in name for ch in ('\n', ' '))
]
for row, name in suspicious:
    print(f"[WARNING] Bad filename in row {row}: {name!r}")

In [12]:
def _make_loader(dataset, shuffle=False):
    """Wrap ``dataset`` in a DataLoader with the settings shared by all splits."""
    return DataLoader(dataset, batch_size=32, shuffle=shuffle,
                      num_workers=6, pin_memory=True)


# Training split, shuffled for optimisation.
train_dataset = ImageDataset(train_df, ds_dir / 'train', transform=transform)
train_dataloader = _make_loader(train_dataset, shuffle=True)

# Same training images in a fixed order, for evaluating on the training split.
# The loader now wraps ``train_eval_dataset``; it previously wrapped
# ``train_dataset`` and left this dataset object unused.
train_eval_dataset = ImageDataset(train_df, ds_dir / 'train', transform=transform)
train_eval_dataloader = _make_loader(train_eval_dataset)

val_dataset = ImageDataset(val_df, ds_dir / 'validation', transform=transform)
val_dataloader = _make_loader(val_dataset)

test_dataset = ImageDataset(test_df, ds_dir / 'test', transform=transform)
test_dataloader = _make_loader(test_dataset)

# Test split of the separate evaluation dataset.
eval_dataset = ImageDataset(eval_df, eval_ds_dir / 'test', transform=transform)
eval_dataloader = _make_loader(eval_dataset)

In [13]:
# Draw a single training batch and show the shapes of its images and labels.
data_ex = next(iter(train_dataloader))
tuple(part.shape for part in data_ex)

(torch.Size([32, 3, 128, 128]), torch.Size([32]))

## Analysis Dataset

In [14]:
# Add `sphere_diff`, the absolute difference between the `spherical` and
# `ill_spherical` columns, to every split table.
for frame in (train_df, val_df, test_df, eval_df):
    frame['sphere_diff'] = (frame['spherical'] - frame['ill_spherical']).abs()

# Model Training and Evaluation

In [15]:
def load_resnet50(num_classes, pretrained=True, checkpoint_path=None):
    """Build a torchvision ResNet-50 with a ``num_classes``-way output layer.

    Args:
        num_classes: Number of outputs of the replaced final linear layer.
        pretrained: If True, start from ``ResNet50_Weights.DEFAULT``;
            otherwise the backbone is created without pretrained weights.
        checkpoint_path: Optional path to a saved ``state_dict``; when given
            (and truthy) it is loaded into the model after the head swap.

    Returns:
        The model. It is left on the CPU and in the default (training) mode;
        moving it to a device and calling ``eval()`` is up to the caller.
    """
    weights = models.ResNet50_Weights.DEFAULT if pretrained else None
    model = models.resnet50(weights=weights)
    # Replace the final fully-connected layer with one sized for this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    if checkpoint_path:
        # Map storages to the CPU so a checkpoint written on a GPU machine also
        # loads where no GPU is available.
        # NOTE: torch.load unpickles the file — only load trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(state_dict)
        print(f"Loaded checkpoint from: {checkpoint_path}")

    return model

In [16]:
# Folder that collects everything saved for this dataset / model-type pair.
model_path = base_path.joinpath(ds, f'{modeltype}')
model_path.mkdir(parents=True, exist_ok=True)
print("Model path:", model_path)

Model path: two4two_sickones_models_pytorch/sick_ones_bendbias_v3_2class_normal/mobilenet


In [17]:
# Checkpoint root with one sub-folder for temporary and one for final weights.
checkpoint_path = model_path / "torch_resnet50/"
for subfolder in ('tmp', 'final'):
    (checkpoint_path / subfolder).mkdir(parents=True, exist_ok=True)

In [18]:
# Loss for multi-class classification on raw model outputs.
criterion = nn.CrossEntropyLoss()
# Use CUDA when it is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [19]:
# Show which compute device was selected in the previous cell.
device

device(type='cuda')

In [20]:
# Rebuild the network without pretrained weights, restore the best saved
# state from the `final` checkpoint folder, and move it to the chosen device.
best_model_file = checkpoint_path / 'final' / 'best_model.pth'
model = load_resnet50(len(CLASSES), pretrained=False, checkpoint_path=best_model_file)
model.to(device)

Loaded checkpoint from: two4two_sickones_models_pytorch/sick_ones_bendbias_v3_2class_normal/mobilenet/torch_resnet50/final/best_model.pth


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [21]:
# Switch BatchNorm (and any Dropout) layers to inference behaviour. A freshly
# built module is in training mode, where BatchNorm normalises with the
# statistics of the current batch and keeps updating its running averages, so
# predictions would depend on batch composition and the loaded weights would
# drift during this loop.
model.eval()

# Per-batch predicted class indices, kept on the CPU.
batch_predictions = []

# No gradients are needed for inference.
with torch.no_grad():
    for images, _ in tqdm(test_dataloader):
        # Run the batch on the same device as the model.
        outputs = model(images.to(device))
        # Index of the highest-scoring class per image.
        batch_predictions.append(outputs.argmax(dim=1).cpu())

# One NumPy array with a prediction per test image (empty if the loader was empty).
model_predictions = (
    torch.cat(batch_predictions).numpy() if batch_predictions else np.array([])
)

  return F.conv2d(input, weight, bias, self.stride,
100%|██████████| 94/94 [00:12<00:00,  7.42it/s]


### Anchor Explanation

In [21]:
import torch
import torchvision.transforms as T
import torchvision.models as models
import matplotlib.pyplot as plt
import numpy as np
from alibi.datasets import load_cats
from alibi.explainers import AnchorImage
from skimage.segmentation import slic
from torchvision.models import inception_v3
import torch.nn.functional as F

In [22]:
from sklearn.utils import Bunch
from typing import Tuple, Union
import torch

def load_blockies(dataloader: torch.utils.data.DataLoader,
                               image_shape: tuple = (299, 299, 3),
                               return_X_y: bool = False
                              ) -> Union[Bunch, Tuple[np.ndarray, np.ndarray]]:
    """
    Collect every batch of a DataLoader into two NumPy arrays.

    Each image batch is bilinearly resized to the first two entries of
    ``image_shape`` (height, width), the batches are concatenated, and the
    images are reordered from NCHW to NHWC. Pixel values are otherwise left
    exactly as the loader produced them.

    Returns ``(X, y)`` when ``return_X_y`` is True, otherwise a ``Bunch`` with
    the fields ``data`` and ``target``.
    """
    height_width = image_shape[:2]
    image_batches, label_batches = [], []

    for batch_images, batch_labels in dataloader:
        resized = F.interpolate(batch_images, size=height_width,
                                mode='bilinear', align_corners=False)
        image_batches.append(resized.cpu())
        label_batches.append(batch_labels.cpu())

    stacked_images = torch.cat(image_batches, dim=0)
    X = stacked_images.permute(0, 2, 3, 1).numpy()  # channels-first -> channels-last
    y = torch.cat(label_batches, dim=0).numpy()

    return (X, y) if return_X_y else Bunch(data=X, target=y)



# Materialise the whole evaluation loader as NHWC arrays resized to 299x299.
# The images keep the value scale produced by the dataset transform.
data, labels = load_blockies(eval_dataloader, image_shape=(299, 299, 3), return_X_y=True)


In [None]:
# `data` has already passed through the dataset `transform` (ToTensor +
# Normalize with the training-set statistics) and was resized by
# `load_blockies`, so nothing is normalised or resized again here. The previous
# version of this cell re-normalised with ImageNet statistics (and replaced the
# global `transform`), pushed the whole evaluation set through the network as
# one batch, and requested a top-3 from a two-class output, which raises.
# NOTE(review): the test loop feeds 128x128 inputs while `data` is 299x299 —
# confirm that the up-sampled resolution is intended for the explanation.
PREDICT_BATCH_SIZE = 32


# Wrapper
def predict_fn(x):
    """Return the classifier's raw outputs for a batch of NHWC images.

    Args:
        x: Array of shape (N, H, W, C) in the same value scale as ``data``.

    Returns:
        numpy.ndarray with one row of class scores per input image.
    """
    batch = torch.as_tensor(np.asarray(x), dtype=torch.float32).permute(0, 3, 1, 2)  # NHWC -> NCHW
    model.eval()  # BatchNorm must use its running statistics, not batch statistics
    scores = []
    with torch.no_grad():
        # Bounded chunks keep memory use independent of how many images arrive.
        for chunk in torch.split(batch, PREDICT_BATCH_SIZE):
            scores.append(model(chunk.to(device)).cpu())
    return torch.cat(scores).numpy()


# Prediction (sanity check on the first few images)
preds = torch.from_numpy(predict_fn(data[:PREDICT_BATCH_SIZE]))
# Never ask for more entries than there are classes.
top3 = torch.topk(F.softmax(preds, dim=1), k=min(3, preds.shape[1]))
print(top3.indices[0].cpu().numpy())

In [None]:
# Explainer: anchors over SLIC superpixels, using `predict_fn` as the black box.
image_shape = (299, 299, 3)
segmentation_fn = 'slic'
kwargs = dict(n_segments=15, compactness=20, sigma=0.5, start_label=0)
explainer = AnchorImage(
    predict_fn,
    image_shape,
    segmentation_fn=segmentation_fn,
    segmentation_kwargs=kwargs,
    images_background=None,
)

In [None]:
# Explanation
i = 0
image = data[i]  # the image exactly as it is fed to the classifier

# `data` holds values standardised with the training-set `mean` / `std`;
# imshow would clip those to [0, 1]. Undo the standardisation for display only.
display_image = image * np.asarray(std, dtype=np.float32) + np.asarray(mean, dtype=np.float32)
plt.imshow(np.clip(display_image, 0.0, 1.0))

# NOTE(review): `image` is passed to the explainer in standardised scale —
# confirm against the alibi documentation that AnchorImage does not assume a
# [0, 1] or [0, 255] value range.
np.random.seed(0)
explanation = explainer.explain(image, threshold=0.95, p_sample=0.5, tau=0.25)

In [None]:
# Visualization: the anchor and the segmentation, each in its own figure.
anchor_fig, anchor_ax = plt.subplots()
anchor_ax.imshow(explanation.anchor)
segments_fig, segments_ax = plt.subplots()
segments_ax.imshow(explanation.segments)