In [None]:
!pip install dill

!pip install ultralytics
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from IPython.display import clear_output
from PIL import Image
from ultralytics import YOLO


clear_output()

In [None]:

class ActivationExtractor:
    """Forward-hook target that remembers the latest output of one layer.

    `layer_name` is kept for reference only. `activation` is `None` until
    `hook_fn` has been invoked at least once.
    """

    def __init__(self, layer_name):
        self.activation = None
        self.layer_name = layer_name

    def hook_fn(self, module, input, output):
        """Record `output.detach()`; `module` and `input` are not used."""
        detached_output = output.detach()
        self.activation = detached_output

def get_layer_activation(model, layer_name, image):
    """Run one forward pass of `model` on `image` and return the output of `layer_name`.

    Args:
        model: module whose `named_modules()` is searched for `layer_name` and
            which is then called on the image tensor.
        layer_name: exact name of the sub-module to capture, as yielded by
            `model.named_modules()`.
        image: anything `np.array` turns into a three-axis array laid out as
            height x width x channel (it is transposed to channel-first below).
            Values are divided by 255 — assumes 8-bit pixel data.

    Returns:
        The detached output recorded for the layer, or `None` if the layer
        exists but was not executed during the forward pass.

    Raises:
        ValueError: if no sub-module is named `layer_name`.
    """
    # Resolve the layer first so a typo fails loudly instead of returning None
    # after a wasted forward pass.
    target_module = next(
        (module for name, module in model.named_modules() if name == layer_name),
        None,
    )
    if target_module is None:
        raise ValueError(f"Layer '{layer_name}' not found in model.named_modules()")

    extractor = ActivationExtractor(layer_name)
    # Keep the handle: without removing it, every call would leave one more
    # hook attached to the model.
    hook_handle = target_module.register_forward_hook(extractor.hook_fn)

    try:
        # HWC -> CHW, add a batch axis, scale to [0, 1].
        image_tensor = torch.from_numpy(np.array(image).transpose(2, 0, 1)).float().unsqueeze(0) / 255.0
        # Run on whatever device the model's parameters live on.
        image_tensor = image_tensor.to(next(model.parameters()).device)

        with torch.no_grad():
            model(image_tensor)
    finally:
        hook_handle.remove()

    return extractor.activation

def visualize_activation(image, activation, layer_name):
    """Show the image, a channel-summed activation heatmap, and their overlay.

    Args:
        image: image accepted by `imshow` that exposes `.size`; the reversed
            `.size` is used as the heatmap's target size (assumes `.size` is
            (width, height), as for PIL images).
        activation: tensor that is summed over dim 1 and squeezed; the
            resize below needs the result to be 2-D, so this presumably
            expects a single-sample (1, C, H, W) activation — TODO confirm.
        layer_name: text shown in the heatmap panel's title.
    """
    # Sum across channels and normalize to [0, 1]
    activation_sum = activation.sum(dim=1).squeeze()
    lowest = activation_sum.min()
    value_range = activation_sum.max() - lowest
    if value_range > 0:
        activation_normalized = (activation_sum - lowest) / value_range
    else:
        # Constant map: dividing by a zero range would turn the heatmap into NaN.
        activation_normalized = torch.zeros_like(activation_sum)

    # Resize activation to match image size (interpolate wants a 4-D input)
    activation_resized = torch.nn.functional.interpolate(
        activation_normalized.unsqueeze(0).unsqueeze(0),
        size=image.size[::-1],
        mode='bilinear',
        align_corners=False
    ).squeeze()

    # Convert to numpy for matplotlib
    activation_np = activation_resized.cpu().numpy()

    # Create the visualization: three panels side by side
    fig, (ax_image, ax_heatmap, ax_overlay) = plt.subplots(1, 3, figsize=(12, 4))

    ax_image.imshow(image)
    ax_image.set_title('Original Image')
    ax_image.axis('off')

    ax_heatmap.imshow(activation_np, cmap='jet')
    ax_heatmap.set_title(f'Activation Heatmap ({layer_name})')
    ax_heatmap.axis('off')

    # Heatmap drawn half-transparent on top of the image
    ax_overlay.imshow(image)
    ax_overlay.imshow(activation_np, cmap='jet', alpha=0.5)
    ax_overlay.set_title('Heatmap Overlay')
    ax_overlay.axis('off')

    fig.tight_layout()
    plt.show()

# Three checkpoints to compare on the same image and layer
model_after_attention = YOLO('/content/best (4).pt')
model_just_modified = YOLO('/content/MODIFIED.pt')
model_org = YOLO('/content/original.pt')

# Single test image, converted to RGB
image_path = '/content/drive/MyDrive/Dogcat.v1i.yolov8/test/images/videoplayback_mp4-0_jpg.rf.cc044ba7afd1afff104a1edb2f028c2b.jpg'
image = Image.open(image_path).convert('RGB')

# Name of the module whose output is captured (matched against named_modules())
layer_name = 'model.30.cv3.3.1.conv'

# One forward pass per checkpoint; `.model` is what get_layer_activation hooks and calls
activation_after_attention = get_layer_activation(model_after_attention.model, layer_name, image)
activation_after_modification = get_layer_activation(model_just_modified.model, layer_name, image)
activation_original = get_layer_activation(model_org.model, layer_name, image)

# Plot order: original.pt, then MODIFIED.pt, then best (4).pt
for layer_activation in (
    activation_original,
    activation_after_modification,
    activation_after_attention,
):
    visualize_activation(image, layer_activation, layer_name)

In [None]:
!git clone https://github.com/rigvedrs/YOLO-V8-CAM/
yolo_cam_parent_dir = '/content/YOLO-V8-CAM'
sys.path.append(yolo_cam_parent_dir)


In [None]:
from yolo_cam.utils.svd_on_activations import get_2d_projection
from yolo_cam.eigen_cam import EigenCAM
from yolo_cam.utils.image import show_cam_on_image, scale_cam_image

In [None]:
# EigenCAM sweep: for every test image, show the image and then one CAM
# overlay per target layer of yolov8n.
test_image_dir = '/content/drive/MyDrive/Dogcat.v1i.yolov8/test/images'
layer_offsets = range(2, 15)  # target layers model.model.model[-2] ... [-14]
cam_input_size = (256, 256)

# Sorted so the figures appear in the same order on every run.
for image_name in sorted(os.listdir(test_image_dir)):
    bgr_img = cv2.imread(os.path.join(test_image_dir, image_name))
    if bgr_img is None:
        # cv2.imread returns None for anything it cannot decode.
        print(f'skipping unreadable file: {image_name}')
        continue

    # uint8 BGR copy is the model input: Ultralytics expects numpy images as
    # HWC / BGR / uint8 and scales by 1/255 itself, so feeding it a float
    # image already in [0, 1] would yield a near-black input.
    bgr_img = cv2.resize(bgr_img, cam_input_size)
    # float32 RGB in [0, 1] is used for display and as the overlay base.
    rgb_img = np.float32(cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)) / 255

    plt.imshow(rgb_img)
    plt.title(image_name)
    print('original')
    plt.show()

    for offset in layer_offsets:
        # A fresh model per layer, as before.
        # NOTE(review): loading once outside the loops would be much faster,
        # but is only safe if EigenCAM leaves no hooks on the model — confirm
        # before hoisting.
        model = YOLO('/content/yolov8n.pt')
        target_layers = [model.model.model[-offset]]
        cam = EigenCAM(model, target_layers, task='od')
        grayscale_cam = cam(bgr_img)[0, :, :]
        # use_rgb=True produces an RGB heatmap, so the base image must be RGB too.
        cam_image = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)
        plt.imshow(cam_image)
        plt.title(f'EigenCAM: model.model.model[-{offset}]')
        print('modified')
        plt.show()