# In [None]:
# Import necessary libraries
import os
import json
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from sklearn.metrics import r2_score

# --- Paths ---------------------------------------------------------------
# Root of the COCO data; per-split image folders and annotation JSONs live below it.
BASE_DIR = "coco"
SPLIT_DIR, ANNOTATION_SPLITS_DIR = (
    os.path.join(BASE_DIR, "splits"),
    os.path.join(BASE_DIR, "annotation_splits"),
)

# --- Data loading --------------------------------------------------------
IMAGE_SIZE = (224, 224)  # every image is resized to 224x224
BATCH_SIZE = 128
NUM_WORKERS = 4

# --- Compute device ------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# --- Normalization statistics (ImageNet per-channel mean / std) ----------
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

class COCOCustomDataset(Dataset):
    """COCO-style dataset yielding one (image, bounding box) pair per annotation.

    Every entry of the split's ``annotations`` list becomes its own sample, so an
    image with several annotated objects appears once per box.
    """

    def __init__(self, split, transform=None):
        """
        Custom COCO Dataset for bounding box regression.

        Args:
            split (str): One of 'train', 'val', or 'test'. Images are read from
                ``SPLIT_DIR/<split>`` and annotations from
                ``ANNOTATION_SPLITS_DIR/instances_<split>2017.json``.
            transform (callable, optional): Called as ``transform(image, bbox)``;
                must return the transformed ``(image, bbox)`` pair.
        """
        self.image_dir = os.path.join(SPLIT_DIR, split)
        self.annotation_file = os.path.join(ANNOTATION_SPLITS_DIR, f"instances_{split}2017.json")
        self.transform = transform

        # Load annotations; be explicit about the encoding instead of relying on
        # the platform default.
        with open(self.annotation_file, "r", encoding="utf-8") as f:
            self.annotations = json.load(f)

        # Map image IDs to file names
        self.image_id_to_filename = {img["id"]: img["file_name"] for img in self.annotations["images"]}

        # One record per annotation
        self.data = self._prepare_bounding_boxes()

    @staticmethod
    def _xywh_to_xyxy(bbox):
        """Convert a COCO ``[x_min, y_min, width, height]`` box to ``[x1, y1, x2, y2]``."""
        x_min, y_min, box_w, box_h = bbox[0], bbox[1], bbox[2], bbox[3]
        return [x_min, y_min, x_min + box_w, y_min + box_h]

    def _prepare_bounding_boxes(self):
        """Return one ``{"image_id": ..., "bbox": [x1, y1, x2, y2]}`` dict per annotation."""
        return [
            {"image_id": ann["image_id"], "bbox": self._xywh_to_xyxy(ann["bbox"])}
            for ann in self.annotations["annotations"]
        ]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, bbox)``.

        Without a transform, ``image`` is an RGB ``PIL.Image`` and ``bbox`` is a
        fresh ``[x1, y1, x2, y2]`` list in original pixel coordinates.
        """
        data_item = self.data[idx]
        image_file = os.path.join(self.image_dir, self.image_id_to_filename[data_item["image_id"]])

        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image that no longer needs it.
        with Image.open(image_file) as img:
            image = img.convert("RGB")

        # Copy so callers cannot mutate the cached annotation in place.
        bbox = list(data_item["bbox"])  # [x1, y1, x2, y2]

        if self.transform:
            image, bbox = self.transform(image, bbox)

        return image, bbox

class ResizeNormalizeAndToTensor:
    """Resize a PIL image, rescale its box to match, and return normalized tensors.

    Called as ``transform(image, bbox)`` with ``bbox`` given as ``[x1, y1, x2, y2]``
    in pixel coordinates of the original image.
    """

    def __init__(self, size, mean, std):
        """
        Args:
            size: Target size passed to ``PIL.Image.resize``; ``size[0]`` scales
                x coordinates and ``size[1]`` scales y coordinates.
            mean: Per-channel mean used for normalization.
            std: Per-channel standard deviation used for normalization.
        """
        self.size = size
        self.mean = mean
        self.std = std
        # Build the torchvision transforms once rather than on every sample.
        self._to_tensor = transforms.ToTensor()
        self._normalize = transforms.Normalize(mean, std)

    @staticmethod
    def _scale_bbox(bbox, scale_x, scale_y):
        """Scale an ``[x1, y1, x2, y2]`` box by independent x / y factors."""
        return [
            bbox[0] * scale_x,
            bbox[1] * scale_y,
            bbox[2] * scale_x,
            bbox[3] * scale_y,
        ]

    def __call__(self, image, bbox):
        """Return ``(normalized image tensor, float32 bbox tensor)``."""
        original_width, original_height = image.size
        image = image.resize(self.size)

        # x coordinates follow the width ratio, y coordinates the height ratio.
        scale_x = self.size[0] / original_width
        scale_y = self.size[1] / original_height
        bbox = self._scale_bbox(bbox, scale_x, scale_y)

        image = self._normalize(self._to_tensor(image))
        bbox = torch.tensor(bbox, dtype=torch.float32)

        return image, bbox

# One shared transform for every split: resize to IMAGE_SIZE, rescale the box,
# convert to a tensor and normalize.
transform = ResizeNormalizeAndToTensor(IMAGE_SIZE, mean, std)

# Build the three splits (constructed in train -> val -> test order).
train_dataset, val_dataset, test_dataset = [
    COCOCustomDataset(split=split_name, transform=transform)
    for split_name in ("train", "val", "test")
]

# Only the training loader shuffles; validation and test keep a fixed order.
_common_loader_args = {"batch_size": BATCH_SIZE, "num_workers": NUM_WORKERS}
train_loader = DataLoader(train_dataset, shuffle=True, **_common_loader_args)
val_loader = DataLoader(val_dataset, shuffle=False, **_common_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **_common_loader_args)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Validation dataset: {len(val_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

def unnormalize_image(image_tensor, mean, std):
    """
    Undo channel-wise normalization for visualization: ``x * std + mean``.

    Args:
        image_tensor: Normalized image tensor whose channel axis is third from the
            end, e.g. ``(C, H, W)`` or a batch ``(N, C, H, W)``.
        mean: Per-channel means used during normalization.
        std: Per-channel standard deviations used during normalization.

    Returns:
        A new tensor on the same device as ``image_tensor``. Floating-point inputs
        keep their dtype; other inputs are promoted to float32.
    """
    # Create the statistics on the input's device (and float dtype) so GPU or
    # half-precision tensors can be un-normalized without a device/dtype mismatch.
    stat_dtype = image_tensor.dtype if image_tensor.is_floating_point() else torch.float32
    mean = torch.as_tensor(mean, dtype=stat_dtype, device=image_tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=stat_dtype, device=image_tensor.device).view(-1, 1, 1)
    return image_tensor * std + mean

def visualize_sample_with_annotation(dataset, index=0):
    """
    Display one dataset sample with its ground-truth box and its coordinates.

    Args:
        dataset: Dataset whose items are ``(normalized image tensor (C, H, W),
            bbox tensor [x1, y1, x2, y2])``, e.g. train_dataset / val_dataset.
        index: Position of the sample to display.
    """
    sample_image, box = dataset[index]

    # Undo normalization, reorder CxHxW -> HxWxC and scale to uint8 for imshow.
    pixels = unnormalize_image(sample_image, mean, std).permute(1, 2, 0).numpy()
    pixels = (pixels * 255).astype("uint8")

    x_min, y_min, x_max, y_max = (box[i].item() for i in range(4))

    _, ax = plt.subplots(1, figsize=(8, 8))
    ax.imshow(pixels)
    ax.add_patch(
        patches.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            linewidth=2,
            edgecolor="r",
            facecolor="none",
        )
    )

    # Write the raw coordinates near the image's top-left corner (data coords 10, 10).
    ax.text(
        10,
        10,
        f"Bounding Box: {box.tolist()}",
        color="yellow",
        fontsize=12,
        bbox=dict(facecolor="black", alpha=0.7),
    )

    ax.set_title(f"Ground Truth Bounding Box (Index {index})")
    plt.axis("off")
    plt.show()

# Example: show training sample #4 together with its ground-truth box.
visualize_sample_with_annotation(dataset=train_dataset, index=4)

def print_input_and_target_shapes(dataloader):
    """
    Report tensor shapes for the first batch produced by ``dataloader``.

    Prints the image batch shape, the target batch shape and the first target.
    Prints nothing when the loader yields no batches.

    Args:
        dataloader: Iterable of ``(images, targets)`` batches, e.g. a DataLoader.
    """
    no_batch = object()
    first_batch = next(iter(dataloader), no_batch)
    if first_batch is no_batch:
        return

    images, targets = first_batch
    print(f"Input Shape (Images): {images.shape}")
    print(f"Target Shape (Bounding Boxes): {targets.shape}")
    print(f"Example Target Bounding Box: {targets[0]}")

# Sanity-check the batch shapes coming out of the training loader.
print_input_and_target_shapes(dataloader=train_loader)

class MLPModel(nn.Module):
    """Fully connected baseline: flatten the image and regress the box coordinates."""

    def __init__(self, input_size=(3, 224, 224), output_size=4):
        """
        Args:
            input_size: ``(channels, height, width)`` of the input images.
            output_size: Number of regression outputs (4 box coordinates by default).
        """
        super().__init__()
        channels, height, width = input_size[0], input_size[1], input_size[2]
        self.input_dim = channels * height * width  # length of the flattened image

        # Layer order (and therefore state_dict keys) must stay fixed so that
        # saved checkpoints keep loading.
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(self.input_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, output_size),
        )

    def forward(self, x):
        """Map an image batch to ``(batch, output_size)`` predictions."""
        return self.model(x)

class CNNModel(nn.Module):
    """Small convolutional encoder followed by an MLP box regressor."""

    def __init__(self, output_size=4):
        """
        Args:
            output_size: Number of regression outputs (4 box coordinates by default).
        """
        super().__init__()

        def conv_bn_relu(in_channels, out_channels):
            # 3x3 same-padding convolution -> batch norm -> ReLU
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ]

        # Module indices match the original flat listing so checkpoints still load.
        self.encoder = nn.Sequential(
            *conv_bn_relu(3, 64),
            nn.MaxPool2d(2, 2),  # 224x224 -> 112x112
            *conv_bn_relu(64, 128),
            nn.MaxPool2d(2, 2),  # 112x112 -> 56x56
            *conv_bn_relu(128, 256),
            nn.AdaptiveAvgPool2d(1),  # global average -> 1x1
        )
        self.regressor = nn.Sequential(
            nn.Flatten(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, output_size),
        )

    def forward(self, x):
        """Encode an image batch and regress ``(batch, output_size)`` values."""
        return self.regressor(self.encoder(x))

class ViTModel(nn.Module):
    """Minimal Vision Transformer that regresses a bounding box from the [CLS] token."""

    def __init__(self, image_size=224, patch_size=16, dim=768, depth=6, heads=8, mlp_dim=1024, output_size=4):
        """
        Args:
            image_size: Side length of the square RGB input images.
            patch_size: Side length of each square patch; must divide ``image_size``.
            dim: Token embedding width.
            depth: Number of transformer encoder layers.
            heads: Attention heads per encoder layer.
            mlp_dim: Hidden width of each encoder layer's feed-forward block.
            output_size: Number of regression outputs.
        """
        super().__init__()
        assert image_size % patch_size == 0, "Image size must be divisible by patch size"
        self.patch_size = patch_size
        self.num_patches = (image_size // patch_size) ** 2
        self.patch_dim = 3 * patch_size * patch_size  # RGB pixels per patch

        self.patch_embedding = nn.Linear(self.patch_dim, dim)
        # One learned position per patch plus one for the prepended [CLS] token.
        self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim),
            num_layers=depth
        )

        self.regressor = nn.Sequential(
            nn.Linear(dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, output_size)  # Output 4 bounding box coordinates
        )

    def _patchify(self, x):
        """Split ``(B, C, H, W)`` images into ``(B, num_patches, C*p*p)`` patch vectors.

        Patches are ordered row-major over the patch grid; each vector holds one
        spatial patch for all channels (channel, then row, then column).
        """
        p = self.patch_size
        batch_size = x.size(0)
        # (B, C, H/p, W/p, p, p): element [b, c, i, j, r, s] == x[b, c, i*p + r, j*p + s]
        grid = x.unfold(2, p, p).unfold(3, p, p)
        # Bring the patch-grid axes in front of the channel axis before flattening;
        # otherwise a flattened chunk would mix several patches of a single channel.
        grid = grid.permute(0, 2, 3, 1, 4, 5)  # (B, H/p, W/p, C, p, p)
        return grid.reshape(batch_size, -1, self.patch_dim)

    def forward(self, x):
        """Map an image batch ``(B, 3, image_size, image_size)`` to ``(B, output_size)``."""
        batch_size = x.size(0)
        tokens = self.patch_embedding(self._patchify(x))
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1) + self.positional_embedding
        # The encoder layer is built without batch_first, i.e. it expects
        # (seq, batch, dim); permute in and back out.
        tokens = self.transformer(tokens.permute(1, 0, 2)).permute(1, 0, 2)
        return self.regressor(tokens[:, 0])

class TransformerEncoderModel(nn.Module):
    """Embed the whole flattened image as a single token and run a transformer encoder on it.

    With a one-token sequence, self-attention can only attend to that same token.
    """

    def __init__(self, image_size=224, dim=768, depth=6, heads=8, mlp_dim=1024, output_size=4):
        """
        Args:
            image_size: Side length of the square RGB input images.
            dim: Embedding width of the single image token.
            depth: Number of transformer encoder layers.
            heads: Attention heads per encoder layer.
            mlp_dim: Hidden width of each encoder layer's feed-forward block.
            output_size: Number of regression outputs.
        """
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_embedding = nn.Linear(3 * image_size * image_size, dim)

        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=dim, nhead=heads, dim_feedforward=mlp_dim),
            num_layers=depth
        )

        self.regressor = nn.Sequential(
            nn.Linear(dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, output_size)  # Output 4 bounding box coordinates
        )

    def forward(self, x):
        """Map an image batch to ``(batch, output_size)`` predictions."""
        x = self.flatten(x)
        x = self.linear_embedding(x)  # (B, dim)
        # The encoder layer is built without batch_first, so it expects
        # (seq, batch, dim). Add the sequence axis at dim 0 so each sample is its
        # own length-1 sequence; unsqueeze(1) would make the *batch* the sequence
        # and let samples attend to one another.
        x = self.transformer(x.unsqueeze(0)).squeeze(0)
        return self.regressor(x)

# Directory that receives the best checkpoint of each trained model.
output_dir = "./regressors"
os.makedirs(name=output_dir, exist_ok=True)  # no error if it already exists

def calculate_regression_metrics(preds, targets):
    """
    Compare predictions with targets and return ``(mae, mse, r2)``.

    Both tensors are detached and copied to CPU NumPy arrays first. MAE and MSE
    are averaged over every element; R2 is ``sklearn.metrics.r2_score`` with its
    default settings.
    """
    pred_arr, target_arr = (t.detach().cpu().numpy() for t in (preds, targets))
    residual = pred_arr - target_arr

    mae = abs(residual).mean()
    mse = (residual ** 2).mean()
    r2 = r2_score(target_arr, pred_arr)

    return mae, mse, r2

def train_model(model, train_loader, val_loader, model_name, num_epochs=10, learning_rate=1e-4):
    """Train ``model`` and checkpoint the weights with the lowest validation loss.

    Uses Smooth L1 loss and Adam. After every epoch the mean training loss, mean
    validation loss and element-wise validation MAE/MSE are printed. Whenever the
    validation loss improves, the state dict is saved to
    ``<output_dir>/<model_name>.pth``.

    Args:
        model: ``nn.Module`` mapping an image batch to box predictions.
        train_loader: DataLoader yielding ``(images, targets)`` for training.
        val_loader: DataLoader yielding ``(images, targets)`` for validation.
        model_name: Name used in log messages and for the checkpoint filename.
        num_epochs: Number of passes over ``train_loader``.
        learning_rate: Adam learning rate.

    Returns:
        The lowest validation loss observed.
    """
    print(f"Training {model_name}...")
    model.to(device)
    criterion = nn.SmoothL1Loss()  # Smooth L1 Loss for bounding box regression
    # `optim` is never imported in this file; go through `torch` to avoid a NameError.
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    best_val_loss = float("inf")
    checkpoint_path = os.path.join(output_dir, f"{model_name}.pth")

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss_sum, train_samples = 0.0, 0
        for images, targets in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} (Train)"):
            images, targets = images.to(device), targets.to(device)

            # Forward pass
            preds = model(images)
            loss = criterion(preds, targets)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # The criterion returns a per-batch mean; weight it by batch size so a
            # short final batch does not skew the epoch average.
            batch_size = images.size(0)
            train_loss_sum += loss.item() * batch_size
            train_samples += batch_size

        train_loss = train_loss_sum / train_samples

        # Validation phase
        model.eval()
        val_loss_sum, val_samples = 0.0, 0
        abs_err_sum, sq_err_sum, num_elements = 0.0, 0.0, 0
        with torch.no_grad():
            for images, targets in tqdm(val_loader, desc=f"Epoch {epoch + 1}/{num_epochs} (Validation)"):
                images, targets = images.to(device), targets.to(device)

                preds = model(images)
                loss = criterion(preds, targets)

                batch_size = images.size(0)
                val_loss_sum += loss.item() * batch_size
                val_samples += batch_size

                # Accumulate error sums on-device; no per-batch sklearn/NumPy round trip.
                diff = preds - targets
                abs_err_sum += diff.abs().sum().item()
                sq_err_sum += diff.pow(2).sum().item()
                num_elements += diff.numel()

        val_loss = val_loss_sum / val_samples
        mean_mae = abs_err_sum / num_elements
        mean_mse = sq_err_sum / num_elements

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}, MAE: {mean_mae:.4f}, MSE: {mean_mse:.4f}")

        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Best model saved for {model_name}!")

    return best_val_loss

# Re-create the train/val loaders for training with a batch size of 16
# (smaller than BATCH_SIZE used for the initial loaders).
TRAIN_BATCH_SIZE = 16
train_loader = DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Models (MLPModel is currently disabled)
# mlp_model = MLPModel(input_size=(3, 224, 224), output_size=4)
cnn_model = CNNModel(output_size=4)
vit_model = ViTModel(image_size=224, patch_size=16, dim=768, depth=6, heads=8, mlp_dim=1024, output_size=4)
transformer_encoder_model = TransformerEncoderModel(image_size=224, dim=768, depth=6, heads=8, mlp_dim=1024, output_size=4)

# Train each model in turn for 30 epochs.
# (MLPModel would be trained the same way: train_model(mlp_model, ..., "MLPModel", num_epochs=30))
for _model, _name in (
    (cnn_model, "CNNModel"),
    (vit_model, "ViTModel"),
    (transformer_encoder_model, "TransformerEncoderModel"),
):
    train_model(_model, train_loader, val_loader, _name, num_epochs=30)

print("Training complete. Models are saved in ./regressors")

def visualize_predictions(model_name, dataset, index, prediction):
    """
    Overlay the ground-truth (green) and predicted (red) boxes on one sample.

    Args:
        model_name: Label used in the plot title.
        dataset: Dataset yielding ``(normalized image tensor (C, H, W), bbox tensor)``.
        index: Which sample of ``dataset`` to draw.
        prediction: Predicted ``[x1, y1, x2, y2]`` box whose elements support ``.item()``.
    """
    sample_image, gt_bbox = dataset[index]

    # Undo normalization, reorder CxHxW -> HxWxC and scale to uint8 for imshow.
    pixels = unnormalize_image(sample_image, mean, std).permute(1, 2, 0).numpy()
    pixels = (pixels * 255).astype("uint8")

    _, ax = plt.subplots(1, figsize=(8, 8))
    ax.imshow(pixels)

    # (box, edge colour, legend label), drawn ground truth first, then prediction.
    overlays = (
        (gt_bbox, "g", "Ground Truth"),
        (prediction, "r", "Prediction"),
    )
    for box, colour, label in overlays:
        x_min, y_min = box[0].item(), box[1].item()
        ax.add_patch(
            patches.Rectangle(
                (x_min, y_min),
                box[2].item() - x_min,
                box[3].item() - y_min,
                linewidth=2,
                edgecolor=colour,
                facecolor="none",
                label=label,
            )
        )

    ax.legend(loc="upper left")
    ax.set_title(f"{model_name}: Ground Truth vs Prediction")
    plt.axis("off")
    plt.show()

def test_model(model, test_loader, model_name, dataset):
    """
    Evaluate ``model`` on ``test_loader``, print MAE/MSE/R2 and plot one prediction.

    Metrics are computed once over the concatenated predictions for the whole
    test set, not averaged per batch. Averaging per-batch R2 is not the R2 of the
    set, it over-weights a short final batch, and a one-sample final batch would
    make R2 NaN.

    Args:
        model: Trained regressor; moved to ``device`` and put in eval mode.
        test_loader: DataLoader yielding ``(images, targets)``.
        model_name: Label used for progress output, printed metrics and the plot.
        dataset: Dataset from which sample 0 is drawn with its prediction.

    Returns:
        ``(mae, mse, r2)`` over the full test set.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    model.to(device)
    model.eval()

    all_preds, all_targets = [], []
    with torch.no_grad():
        for images, targets in tqdm(test_loader, desc=f"Testing {model_name}"):
            images, targets = images.to(device), targets.to(device)
            all_preds.append(model(images).cpu())
            all_targets.append(targets.cpu())

    if not all_preds:
        raise ValueError(f"test_loader for {model_name} yielded no batches")

    mean_mae, mean_mse, mean_r2 = calculate_regression_metrics(
        torch.cat(all_preds), torch.cat(all_targets)
    )

    print(f"{model_name} Test Metrics:")
    print(f"Mean Absolute Error (MAE): {mean_mae:.4f}")
    print(f"Mean Squared Error (MSE): {mean_mse:.4f}")
    print(f"R2 Score: {mean_r2:.4f}")

    # Visualize one sample prediction
    sample_index = 0  # Change this to visualize different samples
    image, _ = dataset[sample_index]
    with torch.no_grad():  # inference only; don't build an autograd graph
        prediction = model(image.unsqueeze(0).to(device)).squeeze(0).cpu()
    visualize_predictions(model_name, dataset, sample_index, prediction)

    return mean_mae, mean_mse, mean_r2

# Load models: rebuild each architecture with the hyper-parameters used for
# training, then restore the best checkpoint that train_model saved in output_dir.
#mlp_model = MLPModel(input_size=(3, 224, 224), output_size=4)
cnn_model = CNNModel(output_size=4)
vit_model = ViTModel(image_size=224, patch_size=16, dim=768, depth=6, heads=8, mlp_dim=1024, output_size=4)
transformer_encoder_model = TransformerEncoderModel(image_size=224, dim=768, depth=6, heads=8, mlp_dim=1024, output_size=4)

# map_location lets checkpoints saved from a GPU run load on a CPU-only machine
# (and vice versa).
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
#mlp_model.load_state_dict(torch.load(os.path.join(output_dir, "MLPModel.pth"), map_location=device))
cnn_model.load_state_dict(torch.load(os.path.join(output_dir, "CNNModel.pth"), map_location=device))
vit_model.load_state_dict(torch.load(os.path.join(output_dir, "ViTModel.pth"), map_location=device))
transformer_encoder_model.load_state_dict(
    torch.load(os.path.join(output_dir, "TransformerEncoderModel.pth"), map_location=device)
)

# Test each model
#test_model(mlp_model, test_loader, "MLPModel", test_dataset)
test_model(cnn_model, test_loader, "CNNModel", test_dataset)
test_model(vit_model, test_loader, "ViTModel", test_dataset)
test_model(transformer_encoder_model, test_loader, "TransformerEncoderModel", test_dataset)
