In [18]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchvision.transforms as transforms
from pycocotools.coco import COCO
from PIL import Image
import pylab
pylab.rcParams['figure.figsize'] = (8.0, 10.0)
from torch.utils.data import DataLoader
import random

from cam_net import AlexNet_GAP
from cam_utils import label_data, ResizeAndPad

In [19]:
dataDir='data'
dataType='val2017'
annFile='{}/annotations/instances_{}.json'.format(dataDir,dataType)
coco=COCO(annFile)
all_data = label_data(coco)

loading annotations into memory...
Done (t=0.76s)
creating index...
index created!


In [20]:
class ImageListDataset(Dataset):
    def __init__(self, data_list, data_dir, transform=None):
        self.data_list = data_list
        self.transform = transform or transforms.ToTensor()
        self.data_dir = data_dir

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        file_name, label = self.data_list[idx]
        image = Image.open(f'{self.data_dir}/images/{file_name}').convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, label


In [21]:
# Cross Validation
random.shuffle(all_data)
split = int(0.8 * len(all_data))
train_data, val_data = all_data[:split], all_data[split:]

# Transform
transform = transforms.Compose([
    ResizeAndPad(),
    transforms.ToTensor(),
])

# Datasets
train_dataset = ImageListDataset(train_data, dataDir, transform)
val_dataset = ImageListDataset(val_data, dataDir, transform)

# DataLoaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

In [22]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlexNet_GAP(num_classes=2).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

In [31]:
testimage = transforms.ToTensor()(Image.open(f'{dataDir}/images/{all_data[random.randint(0, len(all_data) - 1)][0]}').convert("RGB"))
output = model(testimage)

# Get feature maps and class weights
feature_maps = model.feature_maps.squeeze(0)  # [256, H, W]
weights = model.classifier.weight  # [num_classes, 256]

# Choose predicted class
pred_class = torch.argmax(output, dim=1).item()
class_weights = weights[pred_class]  # [256]

# Compute CAM
cam = torch.einsum("c,chw->hw", class_weights, feature_maps)  # [H, W]
cam = F.relu(cam)
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-5)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (256x1 and 256x2)

In [None]:
def overlay_cam_with_centroid(image, cam, center, alpha=0.5):
    """
    image: numpy array [H, W, 3], values in [0, 1]
    cam: numpy array [H, W], values in [0, 1]
    center: tuple (x, y) in pixel coordinates
    alpha: transparency of heatmap
    """
    fig, ax = plt.subplots()

    # Show the original image
    ax.imshow(image)

    # Overlay the CAM
    heatmap = ax.imshow(cam, cmap='jet', alpha=alpha)

    # Add colorbar if you want
    # plt.colorbar(heatmap, ax=ax)

    # Mark the centroid
    x_com, y_com = center
    ax.plot(x_com, y_com, 'wo')  # white circle
    ax.plot(x_com, y_com, 'r+', markersize=12, markeredgewidth=2)  # red cross

    ax.set_title("CAM Overlay with Centroid")
    ax.axis('off')
    plt.show()


In [None]:
overlay_cam_with_centroid(testimage, cam)