### Part 1: Generating and saving the visualisation images (bounding boxes, D-RISE saliency maps and SHAP plots)

In [None]:
# Installing required libraries for model and image processing
!pip install transformers==4.44.2
!pip install ultralytics
!pip install numpy
!pip install opencv-python
!pip install shap
!pip install torch torchvision torchaudio 
!pip install flash-attn --no-build-isolation
!pip install bitsandbytes accelerate
%restart_python

In [None]:
# Import necessary libraries
import os
import cv2
import numpy as np
from ultralytics import YOLO
import shap
import torch
import matplotlib.pyplot as plt
import math
import logging

# Custom-trained YOLOv8 detector that every explanation below is computed for
model = YOLO("weights/yolov8lbest.pt")

# Folder of source images to explain, and folder receiving all generated outputs
input_dir = "images/original"
output_dir = "images/output"
os.makedirs(output_dir, exist_ok=True)  # harmless if the folder already exists

# Detector class index -> human-readable name used in output filenames
class_name_mapping = dict(enumerate(("part_1", "part_2", "part_3")))

# YOLO prediction function for SHAP
def yolo_predict(images, top_k=10):
    """Run the YOLO detector on a batch and return its top-k detections per image.

    This is the model function handed to the SHAP explainer, so every image must
    yield fixed-size outputs. Images with fewer than ``top_k`` detections are
    zero-padded, and images with no detections get all-zero rows.

    Args:
        images: Iterable of images (e.g. a NumPy batch from the SHAP masker or a
            list of arrays); each element is passed to ``model.predict``.
        top_k: Number of highest-confidence detections kept per image.

    Returns:
        Tuple ``(confidences, classes, bboxes)`` of NumPy arrays shaped
        ``(n_images, top_k)``, ``(n_images, top_k)`` and ``(n_images, top_k, 4)``.
        Boxes are ``xyxy`` pixel coordinates. Padding entries are zero; a padded
        class of 0 looks like class 0, so use ``confidence > 0`` to tell real
        detections from padding.
    """
    # Ultralytics treats a list of arrays as a batch; a 4-D array must be split.
    results = model.predict(source=list(images), save=False, verbose=False)

    top_confidences, top_classes, top_bboxes = [], [], []
    for result in results:
        boxes = result.boxes
        if boxes is None or len(boxes) == 0:
            top_confidences.append(np.zeros(top_k))
            top_classes.append(np.zeros(top_k))
            top_bboxes.append(np.zeros((top_k, 4)))
            continue

        # Use the named accessors rather than raw column indices of boxes.data,
        # whose column layout changes when tracking IDs are present.
        order = torch.argsort(boxes.conf, descending=True)[:top_k]
        top_confidences.append(_pad_to(boxes.conf[order].cpu().numpy(), top_k))
        top_classes.append(_pad_to(boxes.cls[order].cpu().numpy(), top_k))
        top_bboxes.append(_pad_to(boxes.xyxy[order].cpu().numpy(), top_k))

    return np.array(top_confidences), np.array(top_classes), np.array(top_bboxes)


def _pad_to(array, length):
    """Zero-pad ``array`` along its first axis so it has at least ``length`` rows."""
    missing = length - array.shape[0]
    if missing <= 0:
        return array
    pad_width = [(0, missing)] + [(0, 0)] * (array.ndim - 1)
    return np.pad(array, pad_width, mode="constant")

# SHAP and bounding box saving function
def save_shap_and_bbox(image_path, shap_values, image_batch, class_indices, confidences, bboxes, original_image):
    """Save a bounding-box overlay and a SHAP plot for the best detection of each class.

    Among the explained output slots, only the highest-confidence slot of each
    class is kept; ties go to the earlier slot. Slots with zero confidence are
    treated as padding and ignored. For each kept class, two PNGs are written to
    ``output_dir``: ``<image stem>_<class name>_bbox.png`` and
    ``<image stem>_<class name>_shap.png``.

    Args:
        image_path: Path of the source image; its stem prefixes output filenames.
        shap_values: SHAP value array whose last axis indexes the model output
            slots (one per ranked detection).
        image_batch: Unused; kept for call compatibility.
        class_indices: Class index for each output slot.
        confidences: Confidence for each output slot (0 marks a padding slot).
        bboxes: ``xyxy`` box for each output slot, in ``original_image`` pixels.
        original_image: BGR image (as loaded by OpenCV) the boxes refer to.
    """
    num_outputs = shap_values[0].shape[-1]  # number of explained output slots

    # class index -> (confidence, slot) of that class's best real detection.
    best = {}
    for slot in range(num_outputs):
        confidence = float(confidences[slot])
        # Zero-confidence slots are padding (class 0, box 0,0,0,0); counting them
        # would emit a bogus class-0 bbox/SHAP image for every input.
        if confidence <= 0:
            continue
        class_idx = int(class_indices[slot])
        if class_idx not in best or confidence > best[class_idx][0]:
            best[class_idx] = (confidence, slot)

    if not best:
        return

    base_filename = os.path.splitext(os.path.basename(image_path))[0]
    # SHAP plots are drawn in RGB; OpenCV images are BGR. Convert once for all classes.
    image_rgb_batch = np.expand_dims(cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB), axis=0)

    for class_idx, (confidence, slot) in best.items():
        class_name = class_name_mapping.get(class_idx, f"Class {class_idx}")
        x_min, y_min, x_max, y_max = map(int, bboxes[slot])

        bbox_image = original_image.copy()
        cv2.rectangle(bbox_image, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        # Keep the label inside the image when the box touches the top edge.
        label_y = max(y_min - 10, 20)
        cv2.putText(bbox_image, f"{class_name} {confidence:.2f}", (x_min, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        bbox_output_path = os.path.join(output_dir, f"{base_filename}_{class_name}_bbox.png")
        cv2.imwrite(bbox_output_path, bbox_image)

        # shap.image_plot creates its own figure. Creating extra figures beforehand
        # only leaked them, because plt.close() closes just the current figure.
        shap.image_plot([shap_values[..., slot]], image_rgb_batch, show=False)
        shap_output_path = os.path.join(output_dir, f"{base_filename}_{class_name}_shap.png")
        plt.savefig(shap_output_path, bbox_inches='tight', dpi=600)
        plt.close()

# SHAP explainer function
def run_shap_explainer(image_path, image_rgb, *, max_evals=5000, batch_size=50):
    """Explain the detector's confidence outputs for one image with SHAP and save the plots.

    Args:
        image_path: Path of the source image (used for output filenames).
        image_rgb: Image array passed to the detector. Despite the name, callers
            pass the OpenCV-loaded (BGR) image; the saving step converts it.
        max_evals: Maximum number of masked model evaluations SHAP may spend.
        batch_size: Number of masked images evaluated per detector call.
    """
    image_batch = np.expand_dims(image_rgb, axis=0)
    # Detections on the unmasked image decide which output slots are labelled/saved.
    confidences, classes, bboxes = yolo_predict([image_rgb])

    # Hidden regions are blurred with a 128x128 kernel rather than blanked out.
    masker_blur = shap.maskers.Image("blur(128,128)", image_rgb.shape)
    # Explain only the ranked confidence scores (first element of yolo_predict's output).
    explainer = shap.Explainer(lambda batch: yolo_predict(batch)[0], masker_blur)

    shap_values = explainer(image_batch, max_evals=max_evals, batch_size=batch_size)
    save_shap_and_bbox(image_path, shap_values.values, image_batch, classes[0], confidences[0], bboxes[0], image_rgb)

# DRISE function to generate saliency maps for each class
def generate_drise_saliency_per_class(model, image_path, output_directory, label_names, n_masks=5000, grid_size=(16, 16), prob_thresh=0.2):
    """Save a D-RISE saliency overlay for the highest-confidence detection of each class.

    The image is perturbed with ``n_masks`` random masks from ``generate_mask``.
    For each prediction on a masked image whose class matches a target class,
    the mask is added to that class's saliency map, weighted by
    ``iou(target box, masked box) * masked score``. Each map is then min-max
    normalised, colour-mapped with JET, blended 50/50 with the image and written
    to ``<output_directory>/<image stem>_<class name>_saliency.png``.

    Args:
        model: Ultralytics detector, called as ``model(image, verbose=False)``.
        image_path: Path of the image to explain.
        output_directory: Folder for the saliency PNGs (created if missing).
        label_names: Sequence mapping class index -> name. Classes beyond its
            length are skipped with a warning.
        n_masks: Number of random masks (one detector pass each).
        grid_size: ``(grid_w, grid_h)`` coarse mask resolution.
        prob_thresh: Probability that a coarse mask cell is kept.
    """
    os.makedirs(output_directory, exist_ok=True)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports unreadable/missing files by returning None.
        logging.warning(f"Could not read image {image_path}. Skipping.")
        return
    image_h, image_w = image.shape[:2]

    # class id -> target box (the class's highest-confidence unmasked detection)
    targets = {}
    for class_id, info in _best_box_per_class(model(image, verbose=False)[0]).items():
        if class_id >= len(label_names):
            logging.warning(f"Class ID {class_id} exceeds the label names list length. Skipping.")
            continue
        targets[class_id] = info["bbox"]
    if not targets:
        return

    saliency_maps = {class_id: np.zeros((image_h, image_w), dtype=np.float32) for class_id in targets}
    # Each mask is scored once and reused for every target class, so the cost is
    # n_masks detector passes however many classes are explained.
    for _ in range(n_masks):
        mask = generate_mask((image_w, image_h), grid_size, prob_thresh)
        masked_boxes = model(mask_image(image, mask), verbose=False)[0].boxes
        for box, score, masked_cls in zip(
            masked_boxes.xyxy.cpu().numpy(),
            masked_boxes.conf.cpu().numpy(),
            masked_boxes.cls.cpu().numpy().astype(int),
        ):
            class_idx = int(masked_cls)
            if class_idx in targets:
                saliency_maps[class_idx] += mask * iou(targets[class_idx], box) * score

    # splitext keeps the same stem as the bbox/SHAP files written for this image.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    for class_id, saliency_map in saliency_maps.items():
        class_name = label_names[class_id]
        heatmap = cv2.applyColorMap(_saliency_to_uint8(saliency_map), cv2.COLORMAP_JET)
        cam = cv2.addWeighted(image, 0.5, heatmap, 0.5, 0)
        output_path = os.path.join(output_directory, f"{image_stem}_{class_name}_saliency.png")
        cv2.imwrite(output_path, cam)
        print(f"Saved saliency map for class '{class_name}' at {output_path}")


def _best_box_per_class(result):
    """Map each predicted class id to its highest-confidence box (earlier box wins ties).

    Returns ``{class_id: {"bbox": xyxy_array, "confidence": score}}``.
    """
    best_boxes = {}
    boxes = result.boxes
    for bbox, confidence, class_id in zip(
        boxes.xyxy.cpu().numpy(),
        boxes.conf.cpu().numpy(),
        boxes.cls.cpu().numpy().astype(int),
    ):
        class_id = int(class_id)
        if class_id not in best_boxes or confidence > best_boxes[class_id]["confidence"]:
            best_boxes[class_id] = {"bbox": bbox, "confidence": confidence}
    return best_boxes


def _saliency_to_uint8(saliency_map):
    """Min-max scale a saliency map to uint8 values in [0, 255].

    A constant map (e.g. the target was never re-detected under masking) has no
    range to scale by. It becomes all zeros instead of NaNs from 0 / 0.
    """
    low = saliency_map.min()
    value_range = saliency_map.max() - low
    if value_range <= 0:
        return np.zeros(saliency_map.shape, dtype=np.uint8)
    return ((saliency_map - low) / value_range * 255).astype(np.uint8)

# Supporting functions for DRISE
def generate_mask(image_size, grid_size, prob_thresh):
    """Draw one random, smoothly upsampled mask for D-RISE perturbation.

    A coarse grid of cells is switched on with probability ``prob_thresh``. The
    grid is bilinearly resized to at least one cell larger than the image, then
    cropped at a random sub-cell offset.

    Args:
        image_size: ``(width, height)`` of the target image in pixels.
        grid_size: ``(grid_w, grid_h)`` number of coarse cells per axis.
        prob_thresh: Probability that each coarse cell is on (value 1).

    Returns:
        float32 array of shape ``(height, width)`` with values in ``[0, 1]``.
    """
    width, height = image_size
    cells_x, cells_y = grid_size
    step_x = math.ceil(width / cells_x)
    step_y = math.ceil(height / cells_y)

    coarse = np.random.uniform(0, 1, size=(cells_y, cells_x)) < prob_thresh
    upsampled = cv2.resize(
        coarse.astype(np.float32),
        ((cells_x + 1) * step_x, (cells_y + 1) * step_y),
        interpolation=cv2.INTER_LINEAR,
    )

    # Random shift within one cell so cell borders do not always fall in the same place.
    shift_x = np.random.randint(0, step_x)
    shift_y = np.random.randint(0, step_y)
    return upsampled[shift_y:shift_y + height, shift_x:shift_x + width]

def mask_image(image, mask):
    """Scale each pixel of ``image`` by ``mask`` and return the result as uint8.

    The 2-D mask is replicated to three channels, so ``image`` is expected to be
    a 3-channel ``(H, W, 3)`` array. Results are truncated, not rounded.
    """
    scaled = image.astype(np.float32) / 255
    per_channel_mask = np.dstack((mask,) * 3)
    return (scaled * per_channel_mask * 255).astype(np.uint8)

def iou(box1, box2):
    """Return the intersection-over-union of two ``(x_min, y_min, x_max, y_max)`` boxes.

    Non-overlapping (or merely touching) boxes give 0.0. A union of zero area,
    e.g. two degenerate boxes, also gives 0.0 instead of a 0 / 0 NaN.
    """
    box1 = np.asarray(box1, dtype=float)
    box2 = np.asarray(box2, dtype=float)
    top_left = np.maximum(box1[:2], box2[:2])
    bottom_right = np.minimum(box1[2:], box2[2:])
    if np.all(top_left < bottom_right):
        intersection = float(np.prod(bottom_right - top_left))
    else:
        intersection = 0.0
    area1 = float(np.prod(box1[2:] - box1[:2]))
    area2 = float(np.prod(box2[2:] - box2[:2]))
    union = area1 + area2 - intersection
    if union <= 0:
        return 0.0
    return intersection / union

# Loop through each image in the directory and run both DRISE and SHAP explainers.
# Files are processed in sorted order so runs are reproducible. Unreadable images
# are skipped with a warning instead of crashing the whole batch.
for image_name in sorted(os.listdir(input_dir)):
    if not image_name.lower().endswith(('.png', '.jpg', '.jpeg')):
        continue
    image_path = os.path.join(input_dir, image_name)
    image = cv2.imread(image_path)
    if image is None:
        logging.warning(f"Could not read image {image_path}. Skipping.")
        continue
    # SHAP runs on a 512x512 copy (still BGR, as loaded by OpenCV);
    # D-RISE re-reads the file itself at its original resolution.
    image_rgb = cv2.resize(image, (512, 512)) if image.shape[:2] != (512, 512) else image

    # Run DRISE
    generate_drise_saliency_per_class(model, image_path, output_dir, list(class_name_mapping.values()))

    # Run SHAP
    run_shap_explainer(image_path, image_rgb)

print("DRISE and SHAP operations completed for all images.")


### Part 2: Large Multimodal Model Interaction

In [None]:
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq, BitsAndBytesConfig
from transformers.image_utils import load_image

# Function to clear memory and avoid out of memory issues
def free_gpu_memory():
    """Release cached-but-unused CUDA memory held by PyTorch's allocator.

    ``torch.cuda.empty_cache`` can only release blocks no live tensor uses. A
    garbage-collection pass runs first so that tensors kept alive only by
    reference cycles are freed too.
    """
    import gc

    print("Clearing GPU memory...")
    gc.collect()
    torch.cuda.empty_cache()

# Device that ask_model moves the processor outputs to; the model is pinned to it too.
DEVICE = "cuda:0"

# Hugging Face checkpoint shared by the processor and the model.
MODEL_ID = "HuggingFaceM4/idefics2-8b"

# BitsAndBytesConfig for 4-bit quantization
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # 4-bit NF4 quantization type
    bnb_4bit_use_double_quant=True,        # double quantization for memory efficiency
    bnb_4bit_compute_dtype=torch.float16,  # use FP16 for computation
)

# Processor with image splitting disabled to reduce memory use
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    do_image_splitting=False,
)

# Model loaded with FP16 and 4-bit quantization
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    quantization_config=quantization_config,
    # Place every module on DEVICE explicitly rather than leaving placement to
    # from_pretrained's defaults, so the model and the inputs moved to DEVICE match.
    device_map={"": DEVICE},
)

# Function to load images for a specific class ID
def load_images_for_analysis(image_name, class_id, input_dir):
    """Load the bounding-box, saliency and SHAP images saved for one image/class pair.

    Files are looked up as ``<input_dir>/<image_name>_<class_id>_<kind>.png`` for
    ``kind`` in bbox, saliency, shap, and are loaded in that order.

    Returns:
        ``(images, image_filenames)``: the loaded images and the three filenames.

    Raises:
        FileNotFoundError: for the first of the three files that does not exist.
    """
    image_filenames = [
        f"{image_name}_{class_id}_{kind}.png"
        for kind in ("bbox", "saliency", "shap")
    ]

    loaded = []
    for name in image_filenames:
        path = os.path.join(input_dir, name)
        if not os.path.exists(path):
            raise FileNotFoundError(f"Image {name} not found in {input_dir}.")
        loaded.append(load_image(path))
    return loaded, image_filenames

# Function to generate prompt for the model based on user input and image type
def generate_prompt_for_image(image_name, class_id, question):
    """Build the single-turn chat prompt asking the model to explain three visualisations.

    The user message holds three image placeholders (bounding box, saliency map,
    SHAP plot) followed by one text part. The text describes each visualisation
    and embeds the user's question after ``Task:``.

    Args:
        image_name: Not used in the prompt; accepted for call symmetry.
        class_id: Class name embedded in the image placeholder labels.
        question: User question inserted into the text part.

    Returns:
        A one-element list holding a ``{"role": "user", "content": [...]}`` message.
    """
    content = [
        {"type": "image", "image": f"{label}_{class_id}"}
        for label in ("bounding_box", "saliency", "shap")
    ]

    text = "".join([
        "Context: This task involves three visualizations that highlight various aspects of how the model interprets an object in an image.\n\n",
        "1. **Bounding Box Visualization**: Shows the region where the object is detected, illustrating how the model identifies and localizes the object within the image.\n\n",
        "2. **Saliency Map Visualization**: Highlights the areas of the image that are most important to the model's decision-making process, emphasizing the regions that influenced the prediction.\n\n",
        "3. **SHAP (SHapley Additive exPlanations) Visualization**: Provides insights into how different parts of the image contribute to the model’s prediction, offering a clear understanding of the model's behavior.\n\n",
        f"Task: {question}\n\n",
        "The goal of this analysis is to interpret and explain the provided visualizations, focusing on understanding how each image contributes to the overall prediction. Specifically, explain which areas of the image were most influential and where the object is located.",
    ])
    content.append({"type": "text", "text": text})

    return [{"role": "user", "content": content}]


# Function to process images, ask questions, and conditionally display the images
def ask_model(image_name, class_id, question, input_dir, display_images_flag=True):
    """Ask the multimodal model a question about the three saved visualisations.

    Loads ``<image_name>_<class_id>_{bbox,saliency,shap}.png`` from ``input_dir``
    and prompts the model with them plus ``question``. The answer is printed and,
    when ``display_images_flag`` is true, the three images are shown.

    Returns:
        On success, a list of decoded responses containing only the newly
        generated text (the prompt is not echoed back). ``None`` when generation
        produced nothing or an unexpected error occurred. The string
        ``"Error: <message>"`` when a visualisation file is missing.
    """
    import traceback

    try:
        # Free memory before starting inference
        free_gpu_memory()

        images, image_filenames = load_images_for_analysis(image_name, class_id, input_dir)
        prompt = generate_prompt_for_image(image_name, class_id, question)

        with torch.no_grad():  # inference only; no autograd graph needed
            chat_text = processor.apply_chat_template(prompt, add_generation_prompt=True)
            inputs = processor(text=chat_text, images=images, return_tensors="pt")
            inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

            print("Processing model inference...")  # Debug message
            generated_ids = model.generate(
                **inputs,
                max_new_tokens=500,
                top_k=10,
                do_sample=True,
            )

            if generated_ids is None or len(generated_ids) == 0:
                print("No output generated by the model.")
                return None

            print("Model inference complete.")  # Debug message

            # generate() returns the prompt tokens followed by the continuation;
            # decode only the continuation so the answer does not repeat the prompt.
            prompt_length = inputs["input_ids"].shape[1]
            generated_texts = processor.batch_decode(
                generated_ids[:, prompt_length:], skip_special_tokens=True
            )

            print("Model's Response:")
            print(generated_texts[0])

            if display_images_flag:
                display_images(image_filenames, input_dir)

            return generated_texts

    except FileNotFoundError as e:
        # Also print it: the returned string is truthy, so a caller testing
        # only `if response:` would otherwise miss the problem.
        message = f"Error: {str(e)}"
        print(message)
        return message
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        return None

# Function to display images using matplotlib
def display_images(image_filenames, input_dir):
    """Show the given images from ``input_dir`` side by side in one matplotlib row.

    Each panel is titled with the last ``_``-separated token of its filename,
    without ``.png`` and capitalised (e.g. ``12_part_1_shap.png`` -> ``Shap``).
    Nothing is drawn when ``image_filenames`` is empty.
    """
    print("Displaying images...")  # Debug message
    count = len(image_filenames)
    if count == 0:
        return None

    # One 5x5-inch panel per image; squeeze=False keeps `axes` 2-D even for one image.
    fig, axes = plt.subplots(1, count, figsize=(5 * count, 5), squeeze=False)
    for ax, filename in zip(axes[0], image_filenames):
        image_path = os.path.join(input_dir, filename)
        # Close the file handle once matplotlib has taken the pixel data.
        with Image.open(image_path) as image:
            ax.imshow(image)
        ax.set_title(filename.split('_')[-1].replace('.png', '').capitalize())
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    return None

In [None]:
# Usage
image_name = "12"       # File stem of the analysed image, as used in the Part 1 output filenames
class_id = "part_1"     # Class name shared by the three visualisation files
question = "Explain this shap visualisation. What features are most important in this object?"
input_folder = "images/output/"

# Ask the model with the option to display images set to True
response = ask_model(image_name, class_id, question, input_folder, display_images_flag=True)

# ask_model returns a list of texts on success, None on failure, and an
# "Error: ..." string when a file is missing. Test the type, not truthiness,
# so the error string is not reported as a successful response.
if isinstance(response, list) and response:
    print("Response received and processed.")
else:
    print("No response generated.")