In [6]:
try:
    # Comment out if not using colab
    from google.colab import drive
    drive.mount('/content/drive')

    # Specific for luca's computer
    %cd "/content/drive/Othercomputers/Min MacBook Pro/INFO381-GitHub"
    using_colab = True
except:
    print("Not using Google Colab")
    using_colab = False

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/Othercomputers/Min MacBook Pro/INFO381-GitHub


In [7]:
!pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-1huskbnd
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-1huskbnd
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [22]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
from torchvision import models
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import torch.nn.functional as F
import clip
from skimage.transform import resize
from PIL import Image

import os


from utils import get_dataloaders, cherry_pick_img_real, cherry_pick_img_ai_generated
from model_definitions import CLIPClassifier

In [9]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [10]:
# Build the train and test DataLoaders through the project helper from utils
# (batch size 32; with split='both' the result is unpacked into two loaders).
train_loader, test_loader = get_dataloaders(batch_size=32, split = 'both')
# Sanity check only: this prints the DataLoader object's repr, not its contents.
print(test_loader)

Running in Google Colab
<torch.utils.data.dataloader.DataLoader object at 0x7fd3e77e2bd0>


**Load CLIP preprocessing and CNN Transform**

In [11]:
# Load the CLIP ViT-B/32 model together with the image preprocessing callable
# that clip.load returns alongside it.
clip_model, preprocess_clip = clip.load("ViT-B/32", device=device)

# Preprocessing for the CNN: resize to 512x512 and convert to a tensor.
# No normalisation step is applied here.
# NOTE(review): presumably this mirrors the transform used when the ResNet-18
# checkpoint was trained — confirm.
cnn_transform = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])

100%|███████████████████████████████████████| 338M/338M [00:04<00:00, 81.1MiB/s]


**Load the trained ResNet-18 CNN and CLIP classifiers**

In [12]:
# Build the ResNet-18 architecture without pretrained weights; the trained
# weights come from the local checkpoint below. `weights=None` is the current
# spelling of the deprecated `pretrained=False` argument.
cnn_model = models.resnet18(weights=None)
# Replace the final fully connected layer with a 2-class output layer.
cnn_model.fc = nn.Linear(cnn_model.fc.in_features, 2)
cnn_model.load_state_dict(torch.load("models/resnet18_cnn.pth", map_location=device))
# eval() returns the module itself, so the chained .to(device) moves that same model.
cnn_model.eval().to(device)
# Spatial size (H, W) that the RISE helpers read from the model when building masks.
cnn_model.input_size = (512, 512)



In [13]:
# Wrap the CLIP model in the project's CLIPClassifier (embed_dim=512,
# num_classes=2) and move it to the selected device.
clip_classifier = CLIPClassifier(clip_model, embed_dim=512, num_classes=2).to(device)
# Restore the trained weights from the local checkpoint file.
clip_classifier.load_state_dict(torch.load("models/clip_classifier_10epochs.pth", map_location=device))
# Switch the module to evaluation mode.
clip_classifier.eval()
# Spatial size (H, W) that the RISE helpers read from the model when building masks.
clip_classifier.input_size = (224, 224)

**Generating masks**

In [14]:
def generate_masks(model, N, s, p1, seed=None):
    """Generate N random, smoothly varying occlusion masks for RISE.

    A coarse ``s x s`` binary grid is sampled per mask (each cell kept with
    probability ``p1``), upsampled with bilinear interpolation to
    ``(s + 1) * cell_size`` and cropped at a random offset to the model's
    input size.

    Args:
        model: Object exposing ``input_size`` as ``(H, W)``.
        N: Number of masks to generate.
        s: Side length of the coarse grid.
        p1: Probability that a coarse grid cell is kept (set to 1).
        seed: Optional integer seed. When given, a private
            ``np.random.RandomState(seed)`` is used so the masks are
            reproducible. When ``None`` (default) the global ``np.random``
            state is used, exactly as before.

    Returns:
        ``float32`` array of shape ``(N, 1, H, W)`` with values in ``[0, 1]``.
    """
    input_size = tuple(model.input_size)
    cell_size = np.ceil(np.array(input_size) / s).astype(int)
    # One extra cell per axis leaves room for the random crop offset below.
    up_size = ((s + 1) * cell_size).astype(int)

    # Both the module and a RandomState expose rand()/randint(), so the calls
    # below are identical in either case.
    rng = np.random if seed is None else np.random.RandomState(seed)

    grid = rng.rand(N, s, s) < p1
    grid = grid.astype('float32')

    # float32 is ample precision for values in [0, 1] and halves the memory of
    # the default float64 buffer (which is several GB for N=2000 at 512x512).
    masks = np.empty((N, *input_size), dtype=np.float32)

    for i in tqdm(range(N), desc='Generating masks'):
        # Random shift in [0, cell_size) along each axis.
        x = rng.randint(0, cell_size[0])
        y = rng.randint(0, cell_size[1])
        upsampled = resize(grid[i], up_size, order=1, mode='reflect', anti_aliasing=False)
        masks[i] = upsampled[x:x + input_size[0], y:y + input_size[1]]

    # Add a singleton channel axis so the masks broadcast over image channels.
    return masks.reshape(N, 1, *input_size)


**RISE explainer**

In [15]:
def explain(model, inp, masks, N, p1, batch_size=100):
    """Compute RISE saliency maps for one image.

    Each mask is multiplied with the image, the masked images are pushed
    through ``model`` in batches, and the softmax scores are used as weights
    for summing the masks.

    Masking is done one batch at a time on the compute device, so the full
    ``(N, 3, H, W)`` stack of masked images is never held in host memory.

    Args:
        model: Callable returning class logits of shape ``(batch, classes)``;
            must expose ``input_size`` as ``(H, W)``.
        inp: Image tensor of shape ``(1, 3, H, W)``.
        masks: NumPy array of shape ``(N, 1, H, W)``.
        N: Number of masks; must equal ``masks.shape[0]``.
        p1: Keep-probability used when the masks were generated; used to
            normalise the result.
        batch_size: Number of masked images evaluated per forward pass.

    Returns:
        NumPy array of shape ``(classes, H, W)``: one saliency map per class.

    Raises:
        ValueError: If ``masks`` does not contain exactly ``N`` masks.
    """
    if masks.shape[0] != N:
        raise ValueError(f"expected {N} masks, got {masks.shape[0]}")

    with torch.no_grad():
        # (1, 3, H, W) -> (3, H, W), on the same device the batches are sent to.
        image = inp.squeeze(0).to(device).float()

        preds = []
        for i in tqdm(range(0, N, batch_size), desc='Explaining'):
            mask_batch = torch.from_numpy(masks[i:i + batch_size]).to(device).float()
            # (B, 1, H, W) * (3, H, W) -> (B, 3, H, W)
            out = model(mask_batch * image)
            probs = torch.softmax(out, dim=1).cpu().numpy()
            preds.append(probs)

        preds = np.concatenate(preds, axis=0)  # (N, classes)
        num_classes = preds.shape[1]
        # Score-weighted sum of the flattened masks, one row per class.
        sal = preds.T @ masks.reshape(N, -1)
        sal = sal.reshape(num_classes, *model.input_size)
        sal = sal / N / p1
    return sal


**Run RISE model**

In [16]:
def run_rise_get_outputs(image_path, model, transform, class_names, N=2000, s=8, p1=0.5):
    """Classify one image and compute its RISE saliency map for the predicted class.

    Args:
        image_path: Path of the image file to explain.
        model: Classifier returning logits; must expose ``input_size``.
        transform: Callable turning a PIL image into a ``(3, H, W)`` tensor.
        class_names: Sequence mapping a class index to its label.
        N: Number of RISE masks.
        s: Side length of the coarse mask grid.
        p1: Keep-probability of a coarse mask cell.

    Returns:
        dict with keys ``"img_np"`` (the transformed image as an ``(H, W, 3)``
        array), ``"saliency"`` (map for the predicted class), ``"label"``,
        ``"class_id"`` and ``"image_path"``.
    """
    # Load + preprocess image. The context manager closes the file handle as
    # soon as the pixel data has been converted to RGB.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert("RGB")
    input_tensor = transform(img).unsqueeze(0)
    model_input = input_tensor.to(device)

    # Predict
    with torch.no_grad():
        output = model(model_input)
        probs = F.softmax(output, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        pred_label = class_names[pred_class]

    # RISE saliency
    masks = generate_masks(model, N, s, p1)
    saliency = explain(model, model_input, masks, N, p1)

    # Reuse the tensor that was actually explained instead of running the
    # transform a second time, so the returned image always matches the saliency.
    # NOTE(review): if `transform` normalises pixel values, img_np is not in
    # [0, 1] and matplotlib will clip it when displayed — confirm for CLIP.
    img_np = input_tensor.squeeze(0).permute(1, 2, 0).cpu().numpy()

    return {
        "img_np": img_np,
        "saliency": saliency[pred_class],
        "label": pred_label,
        "class_id": pred_class,
        "image_path": image_path
    }

**Visualization function**

In [17]:
def visualize_side_by_side(original_img_np, cnn_result, clip_result, titles=None, filename="xai_output.png"):
    """Plot the original image next to the CNN and CLIP RISE overlays and save the figure.

    Args:
        original_img_np: Image array shown in the left panel.
        cnn_result: Result dict for the CNN; reads ``"img_np"``, ``"saliency"``
            and ``"label"``.
        clip_result: Result dict for the CLIP classifier; same keys.
        titles: Optional list of three panel titles; defaults are used when None.
        filename: Output path relative to ``gui_images/RISE/``. May contain
            sub-directories, which are created if missing.
    """
    # Titles
    if titles is None:
        titles = ["Original Image", "CNN + RISE", "CLIP + RISE"]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Original
    axes[0].imshow(original_img_np)
    axes[0].set_title(titles[0], fontsize=13, fontweight="bold")
    axes[0].axis("off")

    # CNN and CLIP panels share the same layout: image with a translucent heat map.
    panels = (
        (axes[1], titles[1], cnn_result),
        (axes[2], titles[2], clip_result),
    )
    for ax, title, result in panels:
        ax.imshow(result["img_np"])
        ax.imshow(result["saliency"], cmap='jet', alpha=0.5)
        ax.set_title(f"{title}: {result['label'].upper()}", fontsize=13)
        ax.axis("off")

    save_path = os.path.join("gui_images/RISE/", filename)
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    fig.savefig(save_path, bbox_inches="tight")

    plt.show()
    # Close this specific figure so repeated calls in a loop do not accumulate figures.
    plt.close(fig)

    print(f"Saved to: {save_path}")


**Defining class names**

In [18]:
# Labels indexed by predicted class id (0 -> 'AI GENERATED', 1 -> 'REAL').
# NOTE(review): this order must match the class indices the classifiers were
# trained with (presumably the dataset's folder order) — confirm.
class_names = ['AI GENERATED', 'REAL']

### Explaining real images

In [19]:
# Publish entries 1..10 of cherry_pick_img_real as module-level variables
# img_path_real_1 ... img_path_real_10 by writing into globals().
# NOTE(review): index 0 is never read — confirm whether cherry_pick_img_real is
# keyed from 1 (e.g. a dict) or whether the first entry is skipped on purpose.
for i in range(1, 11):
    globals()[f"img_path_real_{i}"] = cherry_pick_img_real[i]

In [20]:
# Run RISE with both classifiers on the ten cherry-picked real images and
# collect the per-image result dicts in parallel lists (same order).
cnn_results = []
clip_results = []

for i in range(1, 11):
    # Index the imported collection directly instead of looking up a
    # dynamically created global, so this cell does not depend on hidden state.
    img_path = cherry_pick_img_real[i]

    cnn_result = run_rise_get_outputs(
        image_path=img_path,
        model=cnn_model,
        transform=cnn_transform,
        class_names=class_names
    )

    clip_result = run_rise_get_outputs(
        image_path=img_path,
        model=clip_classifier,
        transform=preprocess_clip,
        class_names=class_names
    )

    cnn_results.append(cnn_result)
    clip_results.append(clip_result)


Generating masks: 100%|██████████| 2000/2000 [00:22<00:00, 89.92it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.81it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 467.64it/s]
Explaining: 100%|██████████| 20/20 [00:01<00:00, 15.74it/s]
Generating masks: 100%|██████████| 2000/2000 [00:21<00:00, 94.79it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.86it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 461.67it/s]
Explaining: 100%|██████████| 20/20 [00:01<00:00, 15.88it/s]
Generating masks: 100%|██████████| 2000/2000 [00:21<00:00, 94.54it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.86it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 460.38it/s]
Explaining: 100%|██████████| 20/20 [00:01<00:00, 15.90it/s]
Generating masks: 100%|██████████| 2000/2000 [00:20<00:00, 95.27it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.85it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 458.25it/s]
Explaining: 100%

### Visualize real images

In [30]:
# Plot, save and summarise every real-image result pair; numbering starts at 1
# so the output files are named img1.png, img2.png, ...
result_pairs = list(zip(cnn_results, clip_results))
for i in range(1, len(result_pairs) + 1):
    cnn_result, clip_result = result_pairs[i - 1]

    # The CNN's transformed image doubles as the "original" panel.
    visualize_side_by_side(
        cnn_result["img_np"],
        cnn_result,
        clip_result,
        filename=f"real/img{i}.png",
    )

    print(f"[{i}] CNN predicts:  {cnn_result['label'].upper()} (klasse {cnn_result['class_id']})")
    print(f"[{i}] CLIP predicts: {clip_result['label'].upper()} (klasse {clip_result['class_id']})")


Output hidden; open in https://colab.research.google.com to view.

### Explaining AI generated images

In [24]:
# Publish entries 1..10 of cherry_pick_img_ai_generated as module-level variables
# img_path_ai_1 ... img_path_ai_10 by writing into globals().
# NOTE(review): index 0 is never read — confirm whether the collection is keyed
# from 1 (e.g. a dict) or whether the first entry is skipped on purpose.
for i in range(1, 11):
    globals()[f"img_path_ai_{i}"] = cherry_pick_img_ai_generated[i]

In [31]:
# Run RISE with both classifiers on the ten cherry-picked AI-generated images
# and collect the per-image result dicts in parallel lists (same order).
cnn_results_ai = []
clip_results_ai = []

for i in range(1, 11):
    # Index the imported collection directly instead of looking up a
    # dynamically created global, so this cell does not depend on hidden state.
    img_path_ai = cherry_pick_img_ai_generated[i]

    cnn_result_ai = run_rise_get_outputs(
        image_path=img_path_ai,
        model=cnn_model,
        transform=cnn_transform,
        class_names=class_names
    )

    clip_result_ai = run_rise_get_outputs(
        image_path=img_path_ai,
        model=clip_classifier,
        transform=preprocess_clip,
        class_names=class_names
    )

    cnn_results_ai.append(cnn_result_ai)
    clip_results_ai.append(clip_result_ai)


Generating masks: 100%|██████████| 2000/2000 [00:21<00:00, 94.71it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.82it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 464.65it/s]
Explaining: 100%|██████████| 20/20 [00:01<00:00, 15.71it/s]
Generating masks: 100%|██████████| 2000/2000 [00:20<00:00, 95.65it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 462.13it/s]
Explaining: 100%|██████████| 20/20 [00:01<00:00, 15.74it/s]
Generating masks: 100%|██████████| 2000/2000 [00:21<00:00, 95.17it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.82it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 463.57it/s]
Explaining: 100%|██████████| 20/20 [00:01<00:00, 15.69it/s]
Generating masks: 100%|██████████| 2000/2000 [00:20<00:00, 95.70it/s]
Explaining: 100%|██████████| 20/20 [00:07<00:00,  2.83it/s]
Generating masks: 100%|██████████| 2000/2000 [00:04<00:00, 464.08it/s]
Explaining: 100%

### Visualize AI generated images

In [32]:
# Plot, save and summarise every AI-generated-image result pair; numbering
# starts at 1 so the output files are named img1.png, img2.png, ...
result_pairs_ai = list(zip(cnn_results_ai, clip_results_ai))
for i in range(1, len(result_pairs_ai) + 1):
    cnn_result_ai, clip_result_ai = result_pairs_ai[i - 1]

    # The CNN's transformed image doubles as the "original" panel.
    visualize_side_by_side(
        cnn_result_ai["img_np"],
        cnn_result_ai,
        clip_result_ai,
        filename=f"ai_generated/img{i}.png",
    )

    print(f"[{i}] CNN predicts:  {cnn_result_ai['label'].upper()} (klasse {cnn_result_ai['class_id']})")
    print(f"[{i}] CLIP predicts: {clip_result_ai['label'].upper()} (klasse {clip_result_ai['class_id']})")


Output hidden; open in https://colab.research.google.com to view.