#**Data import**

In [None]:
import os
from pathlib import Path
import subprocess
!pip install gdown
import gdown  # Install gdown on Kaggle using pip if necessary

# --- Set Up Directories ---
# tmp_dir only holds the downloaded archive; target_dir receives the extracted data.
tmp_dir = Path("/kaggle/temp")
target_dir = Path("/kaggle/working/spatialsense_data")
tmp_dir.mkdir(parents=True, exist_ok=True)
target_dir.mkdir(parents=True, exist_ok=True)

# --- Download and Extract SpatialSense Dataset ---
print("Downloading and extracting SpatialSense dataset...")
spatialsense_dir = target_dir / "spatialsense"
spatialsense_dir.mkdir(parents=True, exist_ok=True)

spatialsense_image_dir = spatialsense_dir / "images"
spatialsense_image_dir.mkdir(parents=True, exist_ok=True)

# Sources and destinations used by the steps below.
spatialsense_zip_url = "https://zenodo.org/api/records/8104370/files-archive"
spatialsense_zip_file = tmp_dir / "spatialsense.zip"
spatialsense_images_tar = spatialsense_dir / "images.tar.gz"
gdrive_link = "https://docs.google.com/uc?export=download&id=1vIOozqk3OlxkxZgL356pD1EAGt06ZwM4"
annotations_file = spatialsense_dir / "annots_spatialsenseplus.json"

try:
    # Step 1: Download and unzip `spatialsense.zip`
    subprocess.run(["wget", spatialsense_zip_url, "-O", str(spatialsense_zip_file)], check=True)

    # Unzip SpatialSense archive
    subprocess.run(["unzip", "-o", str(spatialsense_zip_file), "-d", str(spatialsense_dir)], check=True)

    # Step 2: Extract `images.tar.gz`
    subprocess.run(["tar", "-zxvf", str(spatialsense_images_tar), "-C", str(spatialsense_image_dir)], check=True)

    # --- Download SpatialSense+ Annotations ---
    print("Downloading SpatialSense+ annotations...")
    downloaded_path = gdown.download(url=gdrive_link, output=str(annotations_file), quiet=False)
    # Fail here with a clear message instead of later when the annotations are opened.
    if downloaded_path is None or not annotations_file.exists():
        raise RuntimeError(f"Could not download SpatialSense+ annotations from {gdrive_link}")
finally:
    # --- Cleanup Temporary Files ---
    # Runs even when a step above failed, so a partial archive is never left behind.
    print("Cleaning up temporary files...")
    spatialsense_zip_file.unlink(missing_ok=True)  # Remove the zip file
    try:
        tmp_dir.rmdir()  # Remove the temp directory
    except OSError as cleanup_error:
        # rmdir only removes empty directories; keep any unrelated files that live there.
        print(f"Temporary directory {tmp_dir} was not removed: {cleanup_error}")

print("Download and extraction completed.")
print(f"Dataset extracted to: {spatialsense_dir}")

In [None]:
!pip install grad-cam

#**SqueezeNet**

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.models import squeezenet1_1, SqueezeNet1_1_Weights
from PIL import Image
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from collections import Counter
from torch.optim.lr_scheduler import ReduceLROnPlateau
import copy
import numpy as np

# --- Device Setup ---
# Use the GPU whenever CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Dataset Class ---
class SpatialSenseDataset(torch.utils.data.Dataset):
    """Dataset yielding one cropped image region per spatial-relation annotation.

    Each annotation is a dict with the keys ``image_path``, ``predicate``,
    ``subject_bbox``, ``object_bbox``, ``subject_name`` and ``object_name``.
    The image is cropped to the smallest rectangle enclosing both boxes.

    Parameters:
        annotations (list): Annotation dicts as described above.
        predicates (list): Predicate names; a predicate's position in this
            list is its class index.
        transform (callable, optional): Applied to the cropped PIL image.
    """

    def __init__(self, annotations, predicates, transform=None):
        self.annotations = annotations
        self.predicates = predicates
        self.predicate_to_index = {pred: idx for idx, pred in enumerate(predicates)}
        self.transform = transform

    def __len__(self):
        """Return the number of annotations."""
        return len(self.annotations)

    def __getitem__(self, idx):
        """Return ``(cropped_image, label, subject_name, object_name)`` for ``idx``.

        Raises:
            FileNotFoundError: If the image cannot be opened or cropped; the
                underlying exception is attached as ``__cause__``.
            KeyError: If the annotation's predicate is not in ``predicates``.
        """
        annotation = self.annotations[idx]
        image_path = annotation["image_path"]
        predicate = annotation["predicate"]
        subject_bbox = annotation["subject_bbox"]
        object_bbox = annotation["object_bbox"]
        subject_name = annotation["subject_name"]
        object_name = annotation["object_name"]

        # Load and crop the image
        try:
            # The context manager closes the file handle; convert() returns an
            # independent in-memory copy, so it stays usable afterwards.
            with Image.open(image_path) as source_image:
                image = source_image.convert('RGB')
            # Boxes are indexed as [y0, y1, x0, x1]: indices 2/3 feed the x
            # coordinates and 0/1 the y coordinates of the PIL crop box.
            cropped_image = image.crop((
                min(subject_bbox[2], object_bbox[2]),  # x0
                min(subject_bbox[0], object_bbox[0]),  # y0
                max(subject_bbox[3], object_bbox[3]),  # x1
                max(subject_bbox[1], object_bbox[1])   # y1
            ))
        except Exception as e:
            # Keep the exception type callers already see, but chain the real cause.
            raise FileNotFoundError(f"Image {image_path} could not be loaded: {str(e)}") from e

        if self.transform:
            cropped_image = self.transform(cropped_image)

        label = torch.tensor(self.predicate_to_index[predicate], dtype=torch.long)
        return cropped_image, label, subject_name, object_name

# --- Parse Annotations ---
def parse_annotations(annotations_file, extracted_images_path, allowed_predicates=None):
    """Flatten the annotation JSON into one record per positive relation.

    Only relations whose label stringifies to ``"True"`` and whose predicate
    is allowed are kept. Images are looked up under
    ``extracted_images_path/<folder>/<basename of url>`` where ``<folder>`` is
    ``flickr`` for URLs containing ``staticflickr`` and ``nyu`` for URLs
    containing ``nyu``; images from other sources, files starting with ``._``
    and images missing on disk are skipped.

    Parameters:
        annotations_file (str): Path to the annotation JSON file.
        extracted_images_path (str): Directory containing the ``flickr`` and
            ``nyu`` image folders.
        allowed_predicates (iterable, optional): Predicates to keep. When
            omitted, the module-level ``predicates`` list is used, as before.

    Returns:
        list: Dicts with the keys ``image_path``, ``predicate``, ``split``,
        ``subject_bbox``, ``object_bbox``, ``subject_name`` and
        ``object_name``.
    """
    if allowed_predicates is None:
        # Backward compatible default: the predicate list defined at module level.
        allowed_predicates = predicates
    allowed = set(allowed_predicates)

    with open(annotations_file, 'r') as f:
        annotations = json.load(f)

    dataset = []
    for ann in annotations.get('sample_annots', []):
        url = ann.get('url')
        split = ann.get('split')
        if not url or not split:
            continue

        # These checks depend only on the image, so do them once per image
        # instead of once per relation (saves repeated filesystem lookups).
        folder = "flickr" if "staticflickr" in url else "nyu" if "nyu" in url else None
        if folder is None:
            continue

        filename = os.path.basename(url)
        if filename.startswith("._"):
            continue

        image_path = os.path.join(extracted_images_path, folder, filename)
        if not os.path.exists(image_path):
            continue

        for pred_ann in ann.get('annotations', []):
            predicate = pred_ann.get("predicate")
            # The label is compared as text, so both true and "True" are accepted.
            if predicate not in allowed or str(pred_ann.get("label")) != "True":
                continue

            subject = pred_ann.get("subject", {})
            object_ = pred_ann.get("object", {})
            dataset.append({
                "image_path": image_path,
                "predicate": predicate,
                "split": split,
                "subject_bbox": subject.get("bbox", [0, 0, 0, 0]),
                "object_bbox": object_.get("bbox", [0, 0, 0, 0]),
                "subject_name": subject.get("name", "unknown"),
                "object_name": object_.get("name", "unknown"),
            })
    return dataset

# --- Early Stopping ---
class EarlyStopping:
    """Track the best validation loss and signal when it stops improving.

    Call the instance once per epoch with the validation loss and the model.
    A snapshot of the model's ``state_dict`` is kept in ``best_model`` for the
    best loss seen; ``early_stop`` becomes True after ``patience`` calls in a
    row whose negated loss is below ``best_score + delta``.
    """

    def __init__(self, patience=5, delta=0):
        self.patience, self.delta = patience, delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.best_model = None

    def __call__(self, val_loss, model):
        # Scores are negated losses, so a larger score is better.
        current = -val_loss
        has_baseline = self.best_score is not None

        if has_baseline and current < self.best_score + self.delta:
            # No improvement this epoch.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # First call or an improvement: remember the score and the weights.
        self.best_score = current
        self.best_model = copy.deepcopy(model.state_dict())
        if has_baseline:
            self.counter = 0

# --- Data Preparation ---
annotations_path = "/kaggle/working/spatialsense_data/spatialsense/annots_spatialsenseplus.json"
extracted_images_path = "/kaggle/working/spatialsense_data/spatialsense/images/images"
# Full nine-predicate variant, kept for reference:
#predicates = ["above", "behind", "in", "in front of", "next to", "on", "to the left of", "to the right of", "under"]
predicates = ["above", "to the left of", "to the right of", "under"]
data = parse_annotations(annotations_path, extracted_images_path)

# Partition the parsed records by their 'split' field (original order is kept).
train_data, valid_data = (
    [record for record in data if record['split'] == split_name]
    for split_name in ('train', 'valid')
)

# --- Transforms ---
# Augmentations run on the resized PIL image, BEFORE ToTensor/Normalize:
# ColorJitter on tensors assumes values in [0, 1], which no longer holds once
# the ImageNet mean/std normalization has been applied.
# RandomHorizontalFlip is intentionally not used: mirroring an image labelled
# "to the left of" / "to the right of" would make the label wrong.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
    transforms.RandomRotation(degrees=20),
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation uses the deterministic part of the pipeline only.
transform_valid = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = SpatialSenseDataset(annotations=train_data, predicates=predicates, transform=transform_train)
valid_dataset = SpatialSenseDataset(annotations=valid_data, predicates=predicates, transform=transform_valid)

# --- Count Images Per Predicate ---
predicate_counts = Counter(item["predicate"] for item in train_data)
total_images = sum(predicate_counts.values())

# --- Class Weights ---
# Inverse-frequency weight per predicate; a predicate with no training
# samples gets weight 0 rather than a division by zero.
class_weights = []
for pred in predicates:
    sample_count = predicate_counts[pred]
    class_weights.append(total_images / sample_count if sample_count > 0 else 0)
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

# --- Data Loaders ---
# The training loader shuffles and drops the final incomplete batch.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=64, shuffle=False)

# --- Feature Extractor ---
class SqueezeNetFeatureExtractor(nn.Module):
    """SqueezeNet 1.1 convolutional trunk followed by a 7x7 adaptive average pool."""

    def __init__(self):
        super().__init__()
        # IMAGENET1K_V1 weights (noted in the original as being the default set).
        backbone = squeezenet1_1(weights=SqueezeNet1_1_Weights.IMAGENET1K_V1)
        # The trunk is fine-tuned, so every parameter keeps its gradient.
        for weight in backbone.features.parameters():
            weight.requires_grad_(True)
        self.features = backbone.features
        self.adaptive_pool = nn.AdaptiveAvgPool2d((7, 7))

    def forward(self, x):
        """Return the pooled feature maps for the image batch ``x``."""
        return self.adaptive_pool(self.features(x))


feature_extractor = SqueezeNetFeatureExtractor().to(device)

# --- MLP Model with Increased Dropout ---
class MLPModel(nn.Module):
    """Two-hidden-layer MLP head that maps extracted features to class logits.

    Parameters:
        input_dim (int): Number of features per sample after flattening.
        num_classes (int): Number of output logits.

    The forward pass returns raw logits (no softmax/sigmoid), which is what
    ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, input_dim, num_classes):
        super(MLPModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 216),
            nn.ReLU(),
            nn.BatchNorm1d(216),
            # Dropout regularizes the last hidden representation. It must sit
            # BEFORE the output layer: applied after it, it would randomly zero
            # (and rescale) the class logits themselves during training.
            nn.Dropout(0.5),
            nn.Linear(216, num_classes),
        )
        # Not used by forward(); kept so existing attribute access keeps working.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)``."""
        # nn.Flatten at the head of self.fc collapses everything after the batch dim.
        return self.fc(x)



# --- Determine Input Dim ---
# Push a single random 224x224 RGB image through the extractor and count the
# resulting feature values; that count is the MLP's input width.
dummy_input = torch.randn(1, 3, 224, 224).to(device)
with torch.no_grad():
    extracted_features = feature_extractor(dummy_input)
input_dim = extracted_features.numel()

mlp_model = MLPModel(input_dim=input_dim, num_classes=len(predicates)).to(device)

# --- Weighted Loss Function ---
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor, label_smoothing=0.0)

# --- Weighted Loss and Weight Decay ---
# One SGD parameter group per sub-network, both with the same settings.
optimizer = optim.SGD(
    [
        dict(params=module.parameters(), lr=1e-3, momentum=0.9)
        for module in (feature_extractor, mlp_model)
    ],
    weight_decay=1e-4,
)

# NOTE(review): the plateau scheduler uses its default patience; the training
# call below stops early after 10 stale epochs — confirm the LR is ever reduced.
scheduler = ReduceLROnPlateau(optimizer, factor=0.1, mode='min')
class FullModel(nn.Module):
    """Feature extractor and classifier head combined into a single module."""

    def __init__(self, feature_extractor, mlp_model):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.mlp_model = mlp_model

    def forward_features(self, x):
        """Return the feature extractor's output for ``x``."""
        return self.feature_extractor(x)

    def forward(self, x):
        """Extract features, flatten them per sample and classify them."""
        feature_maps = self.forward_features(x)
        batch_size = feature_maps.size(0)
        return self.mlp_model(feature_maps.view(batch_size, -1))

# Wrap both sub-networks so training and evaluation can treat them as one model.
full_model = FullModel(feature_extractor=feature_extractor, mlp_model=mlp_model).to(device)

# --- Training and Evaluation ---
def train_one_epoch(loader, model, optimizer, criterion):
    """Train ``model`` for one pass over ``loader``.

    Parameters:
        loader (iterable): Yields ``(images, labels, subject_names, object_names)``
            batches; only images and labels are used.
        model (nn.Module): Model to train; batches are moved to its device.
        optimizer (torch.optim.Optimizer): Optimizer stepping the model.
        criterion (callable): Loss taking ``(outputs, labels)``.

    Returns:
        tuple: ``(mean loss per sample, accuracy)`` over the epoch.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.train()
    # Follow the model's own device instead of relying on a module-level global;
    # a parameter-free model falls back to the CPU.
    model_device = next(model.parameters(), torch.empty(0)).device

    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    for batch_idx, (images, labels, subject_names, object_names) in enumerate(loader):
        images, labels = images.to(model_device), labels.to(model_device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm to keep updates bounded.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Metrics calculation (loss is weighted by batch size to get a per-sample mean)
        total_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct_predictions += (preds == labels).sum().item()
        total_samples += labels.size(0)

        # Log progress
        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}/{len(loader)} - Loss: {loss.item():.4f}")

    if total_samples == 0:
        # E.g. drop_last=True with fewer samples than one batch.
        raise ValueError("train_one_epoch received a loader that produced no samples")

    epoch_loss = total_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples
    return epoch_loss, epoch_accuracy



# --- Update Evaluation Function ---
def evaluate(loader, model, criterion, detailed_metrics=False):
    """Evaluate ``model`` on ``loader`` without updating it.

    Parameters:
        loader (iterable): Yields ``(images, labels, subject_names, object_names)``
            batches; only images and labels are used.
        model (nn.Module): Model to evaluate; batches are moved to its device.
        criterion (callable): Loss taking ``(outputs, labels)``.
        detailed_metrics (bool): Also compute weighted precision, recall, F1
            and the confusion matrix.

    Returns:
        tuple: ``(loss, accuracy)``, or
        ``(loss, accuracy, precision, recall, f1, confusion_matrix)`` when
        ``detailed_metrics`` is True.

    Raises:
        ValueError: If the loader yields no samples.
    """
    model.eval()
    # Follow the model's own device instead of relying on a module-level global;
    # a parameter-free model falls back to the CPU.
    model_device = next(model.parameters(), torch.empty(0)).device

    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, subject_names, object_names in loader:
            images, labels = images.to(model_device), labels.to(model_device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Metrics (loss is weighted by batch size to get a per-sample mean)
            total_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct_predictions += (preds == labels).sum().item()
            total_samples += labels.size(0)

            all_preds.extend(preds.cpu().tolist())
            all_labels.extend(labels.cpu().tolist())

    if total_samples == 0:
        raise ValueError("evaluate received a loader that produced no samples")

    epoch_loss = total_loss / total_samples
    epoch_accuracy = correct_predictions / total_samples

    if detailed_metrics:
        # Weighted averages; zero_division=0 scores classes with no predictions as 0.
        precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
        recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)
        # Rows/columns follow the module-level `predicates` order.
        cm = confusion_matrix(all_labels, all_preds, labels=range(len(predicates)))
        return epoch_loss, epoch_accuracy, precision, recall, f1, cm

    return epoch_loss, epoch_accuracy


# Per-epoch curves; train_and_evaluate appends one value to each list per epoch.
history = {
    metric_name: []
    for metric_name in (
        "train_loss",
        "valid_loss",
        "valid_precision",
        "valid_recall",
        "valid_f1",
    )
}
# --- Main Functionality ---
def train_and_evaluate(
    train_loader, valid_loader, model, criterion, optimizer, scheduler,
    target_layer, predicates, num_epochs=20, patience=5
):
    """
    Train and evaluate the model with early stopping, then run Grad-CAM
    visualization / node ablation and show the validation confusion matrix.

    The best weights seen during training (lowest validation loss) are
    restored before the final analysis, whether or not early stopping fired,
    and the confusion matrix is computed from those restored weights.
    Per-epoch metrics are appended to the module-level ``history`` dict.

    Parameters:
        train_loader (DataLoader): DataLoader for training data.
        valid_loader (DataLoader): DataLoader for validation data.
        model (nn.Module): Full model (feature extractor + MLP).
        criterion (nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer.
        scheduler (torch.optim.lr_scheduler): Learning rate scheduler stepped
            with the validation loss.
        target_layer (nn.Module): Target layer for Grad-CAM visualization.
        predicates (list): List of predicates (class labels).
        num_epochs (int): Maximum number of epochs.
        patience (int): Patience for early stopping.

    Returns:
        None
    """
    early_stopping = EarlyStopping(patience=patience)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training
        train_loss, train_acc = train_one_epoch(train_loader, model, optimizer, criterion)
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}")

        # Validation
        valid_loss, valid_acc, valid_prec, valid_rec, valid_f1, _ = evaluate(
            valid_loader, model, criterion, detailed_metrics=True
        )
        print(f"Validation Loss: {valid_loss:.4f}, Accuracy: {valid_acc:.4f}")
        print(f"Precision: {valid_prec:.4f}, Recall: {valid_rec:.4f}, F1: {valid_f1:.4f}")

        history["train_loss"].append(train_loss)
        history["valid_loss"].append(valid_loss)
        history["valid_precision"].append(valid_prec)
        history["valid_recall"].append(valid_rec)
        history["valid_f1"].append(valid_f1)

        # Step the scheduler
        scheduler.step(valid_loss)

        # Early stopping
        early_stopping(valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered. Restoring best model.")
            break

    # Restore the best snapshot even when training ran through every epoch;
    # otherwise the analysis below would use last-epoch weights.
    if early_stopping.best_model is not None:
        model.load_state_dict(early_stopping.best_model)

    # Re-evaluate so the confusion matrix describes the restored weights
    # (this also defines it when num_epochs is 0).
    _, _, _, _, _, cm = evaluate(valid_loader, model, criterion, detailed_metrics=True)

    # Grad-CAM and Node Ablation
    print("Generating Grad-CAM heatmaps and performing node ablation...")
    generate_heatmaps_and_ablate(
        valid_loader, model, target_layer, criterion, predicates, device, n_samples=10
    )

    # Display Confusion Matrix
    print("Confusion Matrix:")
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=predicates)
    disp.plot(cmap=plt.cm.Blues, xticks_rotation=45)
    plt.show()
# --- Heatmap Generation and Node Ablation ---
def generate_heatmaps_and_ablate(loader, model, target_layer, criterion, predicates, device, n_samples=10):
    """
    Generate Grad-CAM heatmaps and perform node ablation analysis for a subset of images.

    For up to ``n_samples`` images, the de-normalized input is shown next to
    its Grad-CAM overlay. For every batch that contributed samples, each
    feature channel is zeroed in turn and the resulting change in loss is
    measured; the channels whose removal increases the loss the most are printed.

    Parameters:
        loader (DataLoader): DataLoader yielding
            ``(images, labels, subject_names, object_names)`` batches.
        model (nn.Module): Full model (feature extractor + MLP); must expose
            ``forward_features`` and ``mlp_model``.
        target_layer (nn.Module): Target layer for Grad-CAM visualization.
        criterion (nn.Module): Loss function.
        predicates (list): List of predicates (class labels).
        device (torch.device): Device to perform computations on.
        n_samples (int): Number of samples to process.

    Returns:
        None
    """
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.image import show_cam_on_image
    import matplotlib.pyplot as plt
    import numpy as np

    # Mean/std used by the input normalization; needed to undo it for display.
    norm_mean = np.array([0.485, 0.456, 0.406])
    norm_std = np.array([0.229, 0.224, 0.225])
    top_k = 5  # number of most influential channels reported per batch

    model.eval()
    processed = 0

    # The context manager releases the hooks Grad-CAM registers on the model.
    with GradCAM(model=model, target_layers=[target_layer]) as grad_cam:
        for batch_idx, (images, labels, subject_names, object_names) in enumerate(loader):
            if processed >= n_samples:
                break
            images, labels = images.to(device), labels.to(device)

            with torch.no_grad():
                outputs = model(images)
                preds = outputs.argmax(dim=1)
                baseline_loss = criterion(outputs, labels).item()

            for i in range(images.size(0)):
                if processed >= n_samples:
                    break

                image = images[i].unsqueeze(0)
                label = labels[i].item()
                pred = preds[i].item()

                subject_name = subject_names[i]
                object_name = object_names[i]

                # Generate heatmap (no explicit target: Grad-CAM picks the class itself)
                grayscale_cam = grad_cam(input_tensor=image)[0]
                original_image = image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
                original_image = original_image * norm_std + norm_mean  # Unnormalize
                # show_cam_on_image expects a float image with values in [0, 1].
                original_image = np.clip(original_image, 0, 1).astype(np.float32)
                heatmap = show_cam_on_image(original_image, grayscale_cam, use_rgb=True)

                # Display the heatmap
                plt.figure(figsize=(12, 6))
                plt.subplot(1, 2, 1)
                plt.imshow(original_image)
                plt.title(
                    f"Original Image\n"
                    f"Subject: {subject_name}, Object: {object_name}\n"
                    f"True: {predicates[label]}, Predicted: {predicates[pred]}"
                )
                plt.axis("off")

                plt.subplot(1, 2, 2)
                plt.imshow(heatmap)
                plt.title("Grad-CAM Heatmap")
                plt.axis("off")

                plt.tight_layout()
                plt.show()

                processed += 1

            # Node Ablation Analysis
            # Pure inference: no autograd graph is needed for the ablated passes.
            loss_increase = []
            with torch.no_grad():
                features = model.forward_features(images)
                for channel in range(features.size(1)):  # Iterate over feature dimensions
                    modified_features = features.clone()
                    modified_features[:, channel] = 0  # Zero out one feature dimension

                    outputs_modified = model.mlp_model(modified_features)
                    loss_modified = criterion(outputs_modified, labels).item()
                    loss_increase.append(loss_modified - baseline_loss)

            ranked = sorted(range(len(loss_increase)), key=loss_increase.__getitem__, reverse=True)
            summary = ", ".join(f"ch {c} ({loss_increase[c]:+.4f})" for c in ranked[:top_k])
            print(
                f"Node ablation (batch {batch_idx}): baseline loss {baseline_loss:.4f}; "
                f"largest loss increase when zeroed: {summary}"
            )

    print("Heatmap generation and node ablation completed.")

# --- Call the Training and Evaluation Function ---
# Grad-CAM is attached to the last module of the extractor's feature stack.
target_layer = feature_extractor.features[-1]
train_and_evaluate(
    train_loader, valid_loader, full_model, criterion, optimizer, scheduler,
    target_layer, predicates,
    num_epochs=50, patience=10,
)

# --- Plotting Functions ---
def plot_loss(history):
    """Plot the per-epoch training and validation loss curves on one figure."""
    _, axis = plt.subplots(figsize=(10, 5))
    curves = (("train_loss", "Train Loss"), ("valid_loss", "Validation Loss"))
    for history_key, legend_label in curves:
        axis.plot(history[history_key], label=legend_label)
    axis.set_xlabel("Epoch")
    axis.set_ylabel("Loss")
    axis.set_title("Training vs. Validation Loss")
    axis.legend()
    axis.grid(True)
    plt.show()

def plot_metrics(history):
    """Plot the per-epoch validation precision, recall and F1 curves on one figure."""
    _, axis = plt.subplots(figsize=(10, 5))
    curves = (
        ("valid_precision", "Precision"),
        ("valid_recall", "Recall"),
        ("valid_f1", "F1 Score"),
    )
    for history_key, legend_label in curves:
        axis.plot(history[history_key], label=legend_label)
    axis.set_xlabel("Epoch")
    axis.set_ylabel("Score")
    axis.set_title("Validation Metrics: Precision, Recall, F1-Score")
    axis.legend()
    axis.grid(True)
    plt.show()

# --- Display Plots ---
# Loss curves first, then the validation metric curves.
for _render_plot in (plot_loss, plot_metrics):
    _render_plot(history)