# **Data import**

In [None]:
import os
from pathlib import Path
import subprocess
!pip install gdown
import gdown  # Install gdown on Kaggle using pip if necessary

# --- Set Up Directories ---
# Scratch space for the downloaded archive, and the persistent dataset location.
tmp_dir = Path("/kaggle/temp")
target_dir = Path("/kaggle/working/spatialsense_data")
tmp_dir.mkdir(parents=True, exist_ok=True)
target_dir.mkdir(parents=True, exist_ok=True)

# --- Download and Extract SpatialSense Dataset ---
print("Downloading and extracting SpatialSense dataset...")
spatialsense_dir = target_dir / "spatialsense"
spatialsense_dir.mkdir(parents=True, exist_ok=True)

spatialsense_image_dir = spatialsense_dir / "images"
spatialsense_image_dir.mkdir(parents=True, exist_ok=True)

# Step 1: Download and unzip `spatialsense.zip` (the Zenodo record's files archive)
spatialsense_zip_url = "https://zenodo.org/api/records/8104370/files-archive"
spatialsense_zip_file = tmp_dir / "spatialsense.zip"
subprocess.run(["wget", spatialsense_zip_url, "-O", str(spatialsense_zip_file)], check=True)

# Unzip SpatialSense archive
subprocess.run(["unzip", "-o", str(spatialsense_zip_file), "-d", str(spatialsense_dir)], check=True)

# Step 2: Extract `images.tar.gz`.
# No `-v`: listing every extracted file would flood the notebook output.
spatialsense_images_tar = spatialsense_dir / "images.tar.gz"
subprocess.run(["tar", "-zxf", str(spatialsense_images_tar), "-C", str(spatialsense_image_dir)], check=True)

# --- Download SpatialSense+ Annotations ---
print("Downloading SpatialSense+ annotations...")
gdrive_link = "https://docs.google.com/uc?export=download&id=1vIOozqk3OlxkxZgL356pD1EAGt06ZwM4"
annotations_file = spatialsense_dir / "annots_spatialsenseplus.json"
gdown.download(url=gdrive_link, output=str(annotations_file), quiet=False)
# Depending on the gdown version, a failed download may return None instead of
# raising; check the file so the failure surfaces here, not later during parsing.
if not annotations_file.is_file():
    raise RuntimeError(f"SpatialSense+ annotations were not downloaded to {annotations_file}")

# --- Cleanup Temporary Files ---
print("Cleaning up temporary files...")
spatialsense_zip_file.unlink(missing_ok=True)  # Remove the zip file
try:
    tmp_dir.rmdir()  # Remove the temp directory (only succeeds if it is now empty)
except OSError as err:
    # Cleanup is best-effort: a scratch dir that still holds other files must not
    # fail the cell after the dataset itself was downloaded and extracted.
    print(f"Could not remove {tmp_dir}: {err}")

print("Download and extraction completed.")
print(f"Dataset extracted to: {spatialsense_dir}")

!pip install grad-cam

# **ViT**

In [None]:
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from transformers import ViTModel, ViTConfig
from PIL import Image
import pytorch_grad_cam
import matplotlib.pyplot as plt
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
from pytorch_grad_cam.utils.image import show_cam_on_image
from collections import Counter
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from copy import deepcopy
import pandas as pd
import seaborn as sns

# --- Device Setup ---
# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Dataset Class ---
class SpatialSenseDataset(torch.utils.data.Dataset):
    """Samples of one spatial relation: the crop enclosing a subject/object pair.

    Each item is ``(image, label, subject_name, object_name)``. ``image`` is the
    crop of the union of both bounding boxes, passed through ``transform`` if one
    is given. ``label`` is a long tensor holding the predicate's index in
    ``predicates``.

    Args:
        annotations: list of dicts with keys ``image_path``, ``predicate``,
            ``subject_bbox``, ``object_bbox``, ``subject_name``, ``object_name``
            (as produced by ``parse_annotations``).
        predicates: ordered predicate names; list position is the class index.
        transform: optional callable applied to the cropped PIL image.
    """

    def __init__(self, annotations, predicates, transform=None):
        self.annotations = annotations
        self.predicates = predicates
        self.predicate_to_index = {pred: idx for idx, pred in enumerate(predicates)}
        self.transform = transform

    def __len__(self):
        return len(self.annotations)

    @staticmethod
    def _union_crop_box(subject_bbox, object_bbox):
        """Return the PIL ``(left, upper, right, lower)`` box enclosing both bboxes.

        Bboxes are indexed as ``[y0, y1, x0, x1]``.
        """
        return (
            min(subject_bbox[2], object_bbox[2]),  # x0
            min(subject_bbox[0], object_bbox[0]),  # y0
            max(subject_bbox[3], object_bbox[3]),  # x1
            max(subject_bbox[1], object_bbox[1]),  # y1
        )

    def __getitem__(self, idx):
        annotation = self.annotations[idx]
        image_path = annotation["image_path"]
        predicate = annotation["predicate"]
        subject_bbox = annotation["subject_bbox"]
        object_bbox = annotation["object_bbox"]
        subject_name = annotation["subject_name"]
        object_name = annotation["object_name"]

        # Only the load is guarded: crop/transform errors keep their own type
        # instead of being misreported as a missing file.
        try:
            # `with` closes the file handle; convert() returns an independent image.
            with Image.open(image_path) as img:
                image = img.convert('RGB')
        except OSError as e:
            raise FileNotFoundError(f"Image {image_path} could not be loaded: {str(e)}") from e

        left, upper, right, lower = self._union_crop_box(subject_bbox, object_bbox)
        if right > left and lower > upper:
            image = image.crop((left, upper, right, lower))
        # Otherwise the union box is empty or inverted (e.g. the [0, 0, 0, 0] default
        # for a missing bbox). Keep the whole image rather than an empty crop that
        # the transforms cannot process.

        if self.transform:
            image = self.transform(image)

        label = torch.tensor(self.predicate_to_index[predicate], dtype=torch.long)
        return image, label, subject_name, object_name

def parse_annotations(annotations_file, extracted_images_path, allowed_predicates=None):
    """Load SpatialSense+ annotations and flatten them into one dict per relation.

    A relation is kept only if all of these hold:
    - its ``label`` is true (``True`` or ``"True"``);
    - its predicate is in ``allowed_predicates``;
    - its image URL maps to an existing file under
      ``extracted_images_path/<flickr|nyu>/``.

    Args:
        annotations_file: path to the SpatialSense+ JSON file.
        extracted_images_path: directory containing the ``flickr`` and ``nyu``
            image folders.
        allowed_predicates: predicates to keep. If None, falls back to the
            module-level ``predicates`` list, which must then be defined before
            the call (backward-compatible default).

    Returns:
        List of dicts with keys ``image_path``, ``predicate``, ``split``,
        ``subject_bbox``, ``object_bbox``, ``subject_name``, ``object_name``.
    """
    if allowed_predicates is None:
        allowed_predicates = predicates  # module-level global, see docstring

    with open(annotations_file, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    dataset = []
    for ann in annotations.get('sample_annots', []):
        url = ann.get('url')
        split = ann.get('split')
        if not url or not split:
            continue

        folder = "flickr" if "staticflickr" in url else "nyu" if "nyu" in url else None
        if folder is None:
            continue

        filename = os.path.basename(url)
        if filename.startswith("._"):  # presumably macOS AppleDouble metadata entries
            continue

        # Resolve and check the image once per image, not once per relation.
        image_path = os.path.join(extracted_images_path, folder, filename)
        if not os.path.exists(image_path):
            continue

        for pred_ann in ann.get('annotations', []):
            predicate = pred_ann.get("predicate")
            label = pred_ann.get("label")
            if predicate not in allowed_predicates or str(label) != "True":
                continue

            subject = pred_ann.get("subject", {})
            object_ = pred_ann.get("object", {})
            dataset.append({
                "image_path": image_path,
                "predicate": predicate,
                "split": split,
                "subject_bbox": subject.get("bbox", [0, 0, 0, 0]),
                "object_bbox": object_.get("bbox", [0, 0, 0, 0]),
                "subject_name": subject.get("name", "unknown"),
                "object_name": object_.get("name", "unknown"),
            })
    return dataset

# --- Vision Transformer Feature Extractor ---
class ViTFeatureExtractor(nn.Module):
    """Backbone that turns images into the CLS-token embedding of a pretrained ViT.

    Wraps the Hugging Face ``google/vit-base-patch16-224-in21k`` checkpoint.
    ``forward`` returns ``last_hidden_state[:, 0, :]``, the final hidden state
    of the first (CLS) token.
    """

    def __init__(self):
        super().__init__()
        self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

    def forward(self, x):
        hidden_states = self.vit(pixel_values=x).last_hidden_state
        cls_embedding = hidden_states[:, 0, :]  # CLS token
        return cls_embedding

# Build the pretrained ViT backbone once and place it on `device`
# (nn.Module.to moves the parameters in place and returns the module itself).
feature_extractor = ViTFeatureExtractor()
feature_extractor.to(device)

# --- MLP Model with Increased Dropout ---
class MLPModel(nn.Module):
    """Classification head: two hidden Linear/ReLU/BatchNorm blocks, dropout, logits.

    Args:
        input_dim: size of each input feature vector (the file builds it with
            768, the ViT-base CLS width).
        num_classes: number of output logits.
        dropout: drop probability applied to the last hidden representation.
    """

    def __init__(self, input_dim, num_classes, dropout=0.5):
        super(MLPModel, self).__init__()
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 216),
            nn.ReLU(),
            nn.BatchNorm1d(216),
            # Dropout regularizes the hidden features feeding the classifier. Applied
            # after the final Linear it would randomly zero (and rescale) the logits
            # themselves during training.
            nn.Dropout(dropout),
            nn.Linear(216, num_classes),
        )

    def forward(self, x):
        return self.fc(x)

# --- Full Model ---
class FullModel(nn.Module):
    """Chain a feature extractor and a classification head into a single module.

    Args:
        feature_extractor: module mapping inputs to feature vectors; kept as
            ``self.feature_extractor``.
        mlp_model: module mapping those features to the final outputs; kept as
            ``self.mlp_model``.
    """

    def __init__(self, feature_extractor, mlp_model):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.mlp_model = mlp_model

    def forward_features(self, x):
        """Return the feature extractor's output for ``x``."""
        return self.feature_extractor(x)

    def forward(self, x):
        """Return the head's output computed on the extracted features of ``x``."""
        return self.mlp_model(self.forward_features(x))

# --- Training and Evaluation Function ---
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimization pass over ``loader``.

    Returns:
        (mean per-sample training loss, training accuracy).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0

    for images, labels, _, _ in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Weight each batch's loss by its size (assumes criterion returns a batch mean).
        total_loss += loss.item() * images.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        seen += labels.size(0)

    if seen == 0:
        raise ValueError("train_loader yielded no samples")
    return total_loss / seen, correct / seen


def _evaluate_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Returns:
        (mean loss, accuracy, true labels, predicted labels); the label lists hold
        numpy integer scalars.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, _, _ in loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item() * images.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            seen += labels.size(0)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if seen == 0:
        raise ValueError("valid_loader yielded no samples")
    return total_loss / seen, correct / seen, all_labels, all_preds


def train_and_evaluate(
    train_loader, valid_loader, model, criterion, optimizer, scheduler, target_layer,
    predicates, device, num_epochs=100, patience=5, restore_best_weights=True
):
    """Train ``model`` with early stopping on weighted validation F1.

    Each epoch runs a training pass, then a validation pass (loss, accuracy, and
    weighted precision/recall/F1). If training has not stopped early, it then
    calls ``scheduler.step(valid_loss)``. Training stops once ``patience``
    consecutive epochs pass without an F1 improvement.

    Args:
        train_loader / valid_loader: loaders yielding ``(images, labels, _, _)``.
        model: module mapping images to class logits; trained in place.
        criterion: loss function ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        scheduler: scheduler stepped with the validation loss (e.g. ReduceLROnPlateau).
        target_layer: unused; kept for backward compatibility with callers.
        predicates: unused here; kept for backward compatibility with callers.
        device: device batches are moved to.
        num_epochs: maximum number of epochs.
        patience: epochs without F1 improvement before stopping.
        restore_best_weights: if True, reload the weights from the epoch with the
            best validation F1 before returning. The model left in memory is then
            the one early stopping selected, not the last (possibly worse) epoch.

    Returns:
        dict of per-epoch lists: ``train_loss``, ``valid_loss``,
        ``valid_precision``, ``valid_recall``, ``valid_f1``.
    """
    history = {
        "train_loss": [],
        "valid_loss": [],
        "valid_precision": [],
        "valid_recall": [],
        "valid_f1": []
    }

    best_f1 = 0
    best_state = None
    epochs_without_improvement = 0

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")

        # Training
        train_loss, train_accuracy = _train_one_epoch(model, train_loader, criterion, optimizer, device)
        history["train_loss"].append(train_loss)
        print(f"Train Loss: {train_loss:.4f} - Train Accuracy: {train_accuracy:.4f}")

        # Validation
        valid_loss, valid_accuracy, all_labels, all_preds = _evaluate_epoch(
            model, valid_loader, criterion, device
        )

        # Calculate metrics
        precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0)
        recall = recall_score(all_labels, all_preds, average="weighted", zero_division=0)
        f1 = f1_score(all_labels, all_preds, average="weighted", zero_division=0)

        history["valid_loss"].append(valid_loss)
        history["valid_precision"].append(precision)
        history["valid_recall"].append(recall)
        history["valid_f1"].append(f1)

        print(f"Valid Loss: {valid_loss:.4f} - Valid Accuracy: {valid_accuracy:.4f}")
        print(f"Valid Precision: {precision:.4f} - Valid Recall: {recall:.4f} - Valid F1: {f1:.4f}")

        # Early stopping logic
        if f1 > best_f1:
            best_f1 = f1
            epochs_without_improvement = 0
            if restore_best_weights:
                # Snapshot on CPU so the copy does not double GPU memory use.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break

        scheduler.step(valid_loss)

    if best_state is not None:
        # load_state_dict copies the CPU snapshot into the existing (device) tensors.
        model.load_state_dict(best_state)
        print(f"Restored weights from the best epoch (Valid F1: {best_f1:.4f}).")

    return history

from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from torchvision.transforms import Normalize


# Updated reshape_transform for ViTs
def reshape_transform(activations):
    """Turn ViT token activations into a CNN-style feature map for pytorch-grad-cam.

    Drops the leading CLS token and lays the remaining patch tokens out on a
    square grid: ``(batch, 1 + h*w, channels) -> (batch, channels, h, w)``.

    Args:
        activations: tensor of shape (batch, tokens, channels), or a tuple whose
            first element is that tensor (some transformer layers return tuples).

    Returns:
        Tensor of shape (batch, channels, side, side).

    Raises:
        ValueError: if the number of patch tokens is zero or not a perfect square
            (e.g. inputs not resized to the model's square resolution).
    """
    import math

    if isinstance(activations, tuple):
        activations = activations[0]  # Extract tensor from tuple

    patches = activations[:, 1:, :]  # Exclude CLS token
    num_patches = patches.shape[1]
    # Integer sqrt avoids float rounding; validate so a bad grid fails clearly
    # instead of with an opaque reshape error.
    side = math.isqrt(num_patches)
    if side == 0 or side * side != num_patches:
        raise ValueError(
            f"Cannot arrange {num_patches} patch tokens on a square grid"
        )
    return patches.transpose(1, 2).reshape(patches.shape[0], -1, side, side)

def generate_eigencam_vit(loader, model, dataset, n_samples=None):
    """
    Visualizes EigenCAM along with object name, subject name, and predicate information for the given number of samples.

    Each displayed figure is the CAM overlay titled with the subject/object names,
    the ground-truth predicate and the model's predicted predicate.

    Args:
        loader: DataLoader yielding ``(images, labels, subject_names, object_names)`` batches.
        model: The ViT-based model for prediction; must expose
            ``feature_extractor.vit.encoder.layer``.
        dataset: The dataset instance whose ``predicates`` list maps label indices to names.
        n_samples: Number of samples to visualize. If None, all samples will be processed.
    """
    model.eval()
    # Run on the model's own device instead of assuming CUDA, so this also works on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")
    target_layer = model.feature_extractor.vit.encoder.layer[-1]  # Use the last transformer block

    count = 0
    # The context manager releases the activation hooks EigenCAM registers on
    # `target_layer`; otherwise they stay attached to the model after this call.
    with EigenCAM(
        model=model,
        target_layers=[target_layer],
        reshape_transform=reshape_transform,
    ) as cam:
        for images, labels, subject_names, object_names in loader:
            if n_samples is not None:
                remaining = n_samples - count
                if remaining <= 0:
                    break
                # Only compute CAMs and predictions for samples that will be shown.
                images, labels = images[:remaining], labels[:remaining]

            images = images.to(device)
            labels = labels.to(device)

            # Define the target classes (ground-truth labels)
            targets = [ClassifierOutputTarget(label.item()) for label in labels]

            # Generate CAM for the batch
            grayscale_cams = cam(input_tensor=images, targets=targets)

            # Get predictions
            with torch.no_grad():
                preds = torch.argmax(model(images), dim=1)  # Predicted class

            for i in range(images.size(0)):
                # Extract metadata for the current sample
                subject = subject_names[i]
                object_name = object_names[i]
                ground_truth = dataset.predicates[labels[i].item()]  # Ground truth predicate
                prediction = dataset.predicates[preds[i].item()]  # Predicted predicate

                # Min-max rescale the normalized tensor to [0, 1] for display; a
                # constant image would otherwise divide by zero and yield NaNs.
                image_np = images[i].permute(1, 2, 0).cpu().numpy()
                low, high = image_np.min(), image_np.max()
                image_np = image_np - low
                if high > low:
                    image_np = image_np / (high - low)

                # Generate CAM overlay
                cam_image = show_cam_on_image(image_np, grayscale_cams[i], use_rgb=True)

                # Display the image, CAM, and metadata
                plt.imshow(cam_image)
                plt.title(f"Subject: {subject}, Object: {object_name}\n"
                          f"Ground Truth: {ground_truth}, Predicted: {prediction}")
                plt.axis("off")
                plt.show()

                count += 1

def compute_confusion_matrix(loader, model, dataset):
    """
    Computes the confusion matrix for the entire dataset based on model predictions.

    Args:
        loader: DataLoader yielding ``(images, labels, _, _)`` batches.
        model: The ViT-based model for prediction (run in eval mode on its own device).
        dataset: The dataset instance whose ``predicates`` define the class order.

    Returns:
        Confusion matrix (2D array) of shape (len(predicates), len(predicates)),
        rows = true labels, columns = predictions, in ``predicates`` order.
    """
    model.eval()
    # Use the model's device rather than assuming CUDA, so this also runs on CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels, _, _ in loader:
            outputs = model(images.to(device))  # Forward pass
            preds = torch.argmax(outputs, dim=1)  # Predicted class

            # Store predictions and true labels (labels never need to leave the host)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Fixing `labels` keeps the matrix square even if a class never appears.
    return confusion_matrix(all_labels, all_preds, labels=list(range(len(dataset.predicates))))

# --- Main Execution ---
# Initialize data and model
# NOTE(review): these are the ImageNet mean/std values. The Hugging Face preprocessor
# for google/vit-base-patch16-224-in21k may use mean=std=0.5 instead; confirm which
# normalization the pretrained backbone expects.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

annotations_path = "/kaggle/working/spatialsense_data/spatialsense/annots_spatialsenseplus.json"
extracted_images_path = "/kaggle/working/spatialsense_data/spatialsense/images/images"
predicates = ["above", "to the left of", "to the right of", "under"]

# Parse Annotations (parse_annotations filters on the module-level `predicates` above)
data = parse_annotations(annotations_path, extracted_images_path)

# Train/Validation Split
train_data = [item for item in data if item['split'] == 'train' and item['predicate'] in predicates]
valid_data = [item for item in data if item['split'] == 'valid' and item['predicate'] in predicates]
# Fail early with a clear message instead of a ZeroDivisionError deep inside training.
if not train_data or not valid_data:
    raise RuntimeError(
        f"No usable samples (train={len(train_data)}, valid={len(valid_data)}); "
        f"check annotations_path and extracted_images_path."
    )

train_dataset = SpatialSenseDataset(train_data, predicates, transform=transform)
valid_dataset = SpatialSenseDataset(valid_data, predicates, transform=transform)

# drop_last: MLPModel's BatchNorm1d layers raise in training mode on a one-sample
# batch, which the final batch would be whenever len(train_dataset) % 16 == 1.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, drop_last=True)
valid_loader = DataLoader(valid_dataset, batch_size=16, shuffle=False)

# Class Weights (inverse training frequency; 1.0 for predicates absent from train)
predicate_counts = Counter(item["predicate"] for item in train_data)
class_weights = [1.0 / predicate_counts[pred] if pred in predicate_counts else 1.0 for pred in predicates]
class_weights_tensor = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Initialize Models and Optimizer
mlp_model = MLPModel(input_dim=768, num_classes=len(predicates)).to(device)
full_model = FullModel(feature_extractor, mlp_model).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

optimizer = optim.Adam([
    {"params": feature_extractor.parameters(), "lr": 1e-4},
    {"params": mlp_model.parameters(), "lr": 1e-4}
], weight_decay=1e-4)

scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Train and evaluate the model
history = train_and_evaluate(
    train_loader=train_loader,
    valid_loader=valid_loader,
    model=full_model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    target_layer=full_model.feature_extractor.vit.encoder.layer[0].attention.output,
    predicates=predicates,
    device=device,
    num_epochs=100,
    patience=20
)

# Plotting Functions
def plot_loss(history):
    """Plot training vs. validation loss per epoch.

    Args:
        history: dict returned by ``train_and_evaluate`` with per-epoch
            ``train_loss`` and ``valid_loss`` lists.

    Epochs are numbered from 1 so the x-axis matches the "Epoch k/N" lines
    printed during training.
    """
    plt.figure(figsize=(10, 5))
    for key, label in (("train_loss", "Train Loss"), ("valid_loss", "Validation Loss")):
        values = history[key]
        plt.plot(range(1, len(values) + 1), values, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs. Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()

def plot_metrics(history):
    """Plot validation precision, recall and F1 per epoch.

    Args:
        history: dict returned by ``train_and_evaluate`` with per-epoch
            ``valid_precision``, ``valid_recall`` and ``valid_f1`` lists.

    Epochs are numbered from 1 so the x-axis matches the "Epoch k/N" lines
    printed during training.
    """
    plt.figure(figsize=(10, 5))
    series = (
        ("valid_precision", "Precision"),
        ("valid_recall", "Recall"),
        ("valid_f1", "F1 Score"),
    )
    for key, label in series:
        values = history[key]
        plt.plot(range(1, len(values) + 1), values, label=label)
    plt.xlabel("Epoch")
    plt.ylabel("Score")
    plt.title("Validation Metrics: Precision, Recall, F1-Score")
    plt.legend()
    plt.grid(True)
    plt.show()

# Display Plots
for _plot_fn in (plot_loss, plot_metrics):
    _plot_fn(history)

# Confusion matrix over the whole validation set
cm = compute_confusion_matrix(valid_loader, full_model, valid_dataset)

# Label rows (true) and columns (predicted) with the predicate names
class_names = valid_dataset.predicates
cm_df = pd.DataFrame(cm, index=class_names, columns=class_names)

# Draw it as an annotated seaborn heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(
    cm_df,
    annot=True,
    cmap="Blues",
    fmt="g",
    xticklabels=class_names,
    yticklabels=class_names,
)
plt.title("Confusion Matrix")
plt.xlabel("Predicted")
plt.ylabel("True")
plt.xticks(rotation=90)
plt.yticks(rotation=0)
plt.show()

# EigenCAM overlays for the first few validation samples
generate_eigencam_vit(valid_loader, full_model, valid_dataset, n_samples=5)
