In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
import torchvision.transforms as transforms
from pycocotools.coco import COCO
from PIL import Image
import pylab
pylab.rcParams['figure.figsize'] = (8.0, 10.0)
from torch.utils.data import DataLoader
import random
from tqdm import tqdm

from cam_net import AlexNet_GAP, GoogLeNet_GAP
from cam_utils import label_data,overlay_cam_with_centroid, ResizeAndPad, ImageListDataset

In [None]:
data_dir = 'data'

# Parse the COCO 2017 instance annotations for both official splits.
train_coco = COCO(f'{data_dir}/annotations/instances_train2017.json')
val_coco = COCO(f'{data_dir}/annotations/instances_val2017.json')

# Reduce each COCO index to the sample list consumed by ImageListDataset.
# NOTE(review): label_data lives in cam_utils; later cells index entries as
# entry[0] = image file name and entry[1] = label (1 = person present) — confirm there.
train_data, val_data = (label_data(coco) for coco in (train_coco, val_coco))

In [None]:
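# # Alternative holdout split (disabled; this is a single 80/20 split, not cross-validation).
# # It expects a pooled `all_data` list, which this notebook does not define; the official
# # COCO train/val splits loaded above are used instead.
# random.shuffle(all_data)
# split = int(0.8 * len(all_data))
# train_data, val_data = all_data[:split], all_data[split:]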

# Transforms
# ImageNet channel statistics (the same values are used when running the CAM demo).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: geometric ops on the PIL image, then tensor conversion + normalization.
# RandomHorizontalFlip must be *instantiated*: passing the bare class to Compose makes it
# call RandomHorizontalFlip(img), which returns a new module object instead of an image.
train_transform = transforms.Compose([
    ResizeAndPad(target_size=224),
    transforms.RandomHorizontalFlip(),
    # NOTE: ColorJitter() with default arguments (all 0) is a no-op; TODO pick jitter strengths.
    transforms.ColorJitter(),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Evaluation pipeline: deterministic (no augmentation) so validation loss is comparable
# across epochs and reflects the model, not random flips.
val_transform = transforms.Compose([
    ResizeAndPad(target_size=224),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Backward-compatible alias for code that still refers to `transform`.
transform = train_transform

# Datasets
train_dir = f'{data_dir}/images/train'
val_dir = f'{data_dir}/images/val'
train_dataset = ImageListDataset(train_data, train_dir, train_transform)
val_dataset = ImageListDataset(val_data, val_dir, val_transform)


In [None]:
# DataLoaders
# DataLoaders: only the training split is shuffled; validation keeps a fixed order.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Fix the global torch RNG so weight init and DataLoader shuffling are reproducible.
torch.manual_seed(0)

model = GoogLeNet_GAP(num_classes=2).to(device)
criterion = torch.nn.CrossEntropyLoss()  # default reduction='mean' (per-batch mean)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
num_epochs = 3

train_losses = []     # per-epoch training loss, averaged over samples
val_losses = []       # per-epoch validation loss, averaged over samples
val_accuracies = []   # per-epoch validation accuracy in [0, 1]


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over `loader` and return the sample-weighted mean loss."""
    model.train()
    total_loss, total_samples = 0.0, 0
    progress = tqdm(loader, desc="Training", leave=False)
    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # criterion returns a batch mean; weight by batch size so a short last batch
        # does not skew the epoch average.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_samples += batch_size
        progress.set_postfix(loss=loss.item())
    return total_loss / max(total_samples, 1)


@torch.no_grad()
def _evaluate(model, loader, criterion, device):
    """Return (sample-weighted mean loss, accuracy) of `model` over `loader`; leaves model in eval mode."""
    model.eval()
    total_loss, total_correct, total_samples = 0.0, 0, 0
    # Created here (not before training) so the bar only appears while validating.
    progress = tqdm(loader, desc="Validating", leave=False)
    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        total_correct += (outputs.argmax(dim=1) == labels).sum().item()
        total_samples += batch_size
        progress.set_postfix(loss=loss.item())
    denom = max(total_samples, 1)
    return total_loss / denom, total_correct / denom


for epoch in range(num_epochs):
    train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = _evaluate(model, val_loader, criterion, device)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    print(f"Epoch {epoch + 1}/{num_epochs}  |  Train Loss: {train_loss:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2%}")



In [None]:
# Persist only the learned parameters (state_dict), not the whole pickled module.
checkpoint = model.state_dict()
torch.save(checkpoint, "model.pth")

In [None]:
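# To reuse saved weights instead of retraining, uncomment the lines below.
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# model = GoogLeNet_GAP(num_classes=2).to(device)
# model.load_state_dict(torch.load("model.pth", map_location=device))

# model.eval()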

In [None]:

# One point per epoch: each entry in train_losses / val_losses is an epoch average,
# so the x-axis is epochs (1-based), not optimizer iterations.
epochs = range(1, len(train_losses) + 1)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(epochs, train_losses, label="Training Loss", color='blue', alpha=0.7, marker='o')
ax.plot(epochs, val_losses, label="Val Loss", color='orange', alpha=0.7, marker='o')
ax.set_xlabel("Epoch")
ax.set_ylabel("Loss")
ax.set_title("Training and Validation Loss per Epoch")
ax.set_xticks(list(epochs))
ax.legend()
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
# Pick a random validation sample whose label is 1 (an image with a person in it).
# Sampling from the filtered list avoids an endless rejection loop when none exist.
positive_samples = [sample for sample in val_data if sample[1] == 1]
if not positive_samples:
    raise ValueError("val_data has no samples with label 1; cannot run the CAM demo.")
testdata = random.choice(positive_samples)
testimage = Image.open(f'{data_dir}/images/val/{testdata[0]}').convert("RGB")

# Apply the same ImageNet normalization the model was trained with, so the prediction is
# meaningful. ResizeAndPad is deliberately skipped so the CAM stays aligned with the
# original image geometry (the original cell also fed the image at its native size).
cam_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model.eval()
with torch.no_grad():  # inference only: no autograd graph for the CAM tensors
    output = model(cam_transform(testimage).unsqueeze(0).to(device))

    # Get feature maps and class weights.
    # NOTE(review): assumes forward() stores the last conv activations on model.feature_maps
    # with a leading batch dim of 1 — confirm in cam_net.
    feature_maps = model.feature_maps.squeeze(0)  # (C, h, w), as required by the einsum below
    weights = model.classifier.weight             # (num_classes, C)

    # Choose the predicted class and take its classifier weights.
    pred_class = torch.argmax(output, dim=1).item()
    class_weights = weights[pred_class]

    # CAM = weighted sum of feature maps, rectified, then min-max scaled to [0, 1).
    cam = torch.einsum("c,chw->hw", class_weights, feature_maps)
    cam = F.relu(cam)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-5)

print(f"pred: {pred_class}, target: {testdata[1]}")
# Move to CPU so the plotting helper can convert the map for display.
overlay_cam_with_centroid(testimage, cam.cpu())