In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import numpy as np
import cv2
import os
import glob
import scipy.io as sio
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
import clip
import warnings
import math
warnings.filterwarnings('ignore')

In [None]:
class EBCHead(nn.Module):
    """Enhanced-blockwise-classification (EBC) head for crowd counting.

    Two parallel branches read the same feature map:

    * ``classifier``: per-location logits over ``num_bins`` count bins. Their
      softmax expectation against the bin centers ``0, 1, ..., num_bins - 1``
      yields an expected-count map.
    * ``regressor``: a single-channel, ReLU-clamped (non-negative) density map.

    Args:
        input_dim: Number of channels of the incoming feature map.
        num_bins: Number of count bins; bin ``k`` stands for a count of ``k``.
        reduction: Stored on the module for reference; not used by this head.
    """

    def __init__(self, input_dim, num_bins=100, reduction=8):
        super(EBCHead, self).__init__()
        self.num_bins = num_bins
        self.reduction = reduction

        # 3x3 convs use padding=1, so the spatial size of the input is preserved.
        self.classifier = nn.Sequential(
            nn.Conv2d(input_dim, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_bins, kernel_size=1),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(input_dim, 512, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 1, kernel_size=1),
            nn.ReLU(inplace=True)  # densities must be non-negative
        )

        # Bin k represents a count of k; a buffer so it follows .to(device).
        self.register_buffer('bin_centers', torch.arange(0, num_bins, dtype=torch.float32))

    def forward(self, x):
        """Run both branches on a ``(B, input_dim, H, W)`` feature map.

        Returns:
            dict with
            ``'cls_logits'`` (B, num_bins, H, W) raw bin logits,
            ``'cls_probs'`` (B, num_bins, H, W) softmax over bins,
            ``'density_map'`` (B, 1, H, W) non-negative density,
            ``'count_map'`` (B, 1, H, W) expected count per location.
        """
        cls_logits = self.classifier(x)
        cls_probs = F.softmax(cls_logits, dim=1)
        density_map = self.regressor(x)
        # Expectation of the bin distribution: sum_k p_k * k at every location.
        count_map = torch.sum(cls_probs * self.bin_centers.view(1, -1, 1, 1), dim=1, keepdim=True)

        return {
            'cls_logits': cls_logits,
            'cls_probs': cls_probs,
            'density_map': density_map,
            'count_map': count_map
        }

In [None]:
class CLIPEncoder(nn.Module):
    """Spatial feature extractor built on an OpenAI CLIP visual backbone.

    For ViT backbones, the patch tokens after the last transformer block (and
    ``ln_post`` when present) are reshaped into a ``(B, width, g, g)`` grid.
    A two-layer 3x3 convolutional adapter then maps them to 512 channels.

    Args:
        model_name: Name passed to ``clip.load`` (e.g. ``'ViT-B/16'``).
        freeze_clip: If True, all CLIP parameters get ``requires_grad=False`` and
            the backbone always runs without autograd; only the adapter trains.
            If False, the CLIP model is cast to fp32 so it can be fine-tuned.
        input_size: Side length of the square input images. For ViTs, the
            positional embedding is bicubically resized to a grid of
            ``input_size // patch_size`` per side.

    Raises:
        NotImplementedError: for CLIP ResNet backbones with attention pooling,
            ResNets without ``layer4``, and unrecognised visual encoder types.
    """

    def __init__(self, model_name='ViT-B/16', freeze_clip=True, input_size=384):
        super(CLIPEncoder, self).__init__()
        self.input_size = input_size
        self.model_name = model_name
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.clip_model, self.preprocess = clip.load(model_name, device=device)

        if not freeze_clip:
            # On CUDA clip.load returns fp16 weights; optimising fp16 parameters
            # directly with Adam is numerically unstable, so fine-tune in fp32.
            self.clip_model = self.clip_model.float()

        visual = self.clip_model.visual
        if hasattr(visual, 'conv1') and hasattr(visual.conv1, 'weight'):
            self.clip_conv1_dtype = visual.conv1.weight.dtype
        else:
            self.clip_conv1_dtype = torch.float32

        visual_encoder_type = type(visual).__name__

        if 'VisionTransformer' in visual_encoder_type or 'ViT' in visual_encoder_type:
            self.clip_patch_size = visual.conv1.kernel_size[0]
            # Patch tokens live in the transformer width (= conv1 output channels),
            # not in the projected ``output_dim`` space (768 for ViT-B/16).
            self.feature_dim = visual.conv1.out_channels
            if hasattr(visual, 'positional_embedding'):
                self._resize_positional_embedding(visual, input_size, device)

        elif 'ResNet' in visual_encoder_type or 'RN' in visual_encoder_type:
            self.clip_patch_size = 32
            if hasattr(visual, 'attnpool'):
                raise NotImplementedError(f"CLIPEncoder does not currently support ResNet models with attention pooling like {model_name} for spatial features.")
            elif hasattr(visual, 'layer4'):
                if hasattr(visual.layer4[-1], 'conv3'):
                    self.feature_dim = visual.layer4[-1].conv3.out_channels
                elif hasattr(visual.layer4[-1], 'conv2'):
                    self.feature_dim = visual.layer4[-1].conv2.out_channels
                else:
                    raise AttributeError("Could not find output channels in CLIP ResNet layer4.")
            else:
                raise NotImplementedError(f"CLIPEncoder does not currently support ResNet models like {model_name} for spatial features.")
        else:
            raise NotImplementedError(f"Unsupported CLIP visual encoder type: {visual_encoder_type}")

        # Freeze only after the positional embedding was replaced, so the new
        # Parameter is frozen too.
        if freeze_clip:
            for param in self.clip_model.parameters():
                param.requires_grad = False
        self.freeze_clip = freeze_clip

        self.output_spatial_size = input_size // self.clip_patch_size
        self.adapter_conv1 = nn.Conv2d(self.feature_dim, 512, kernel_size=3, padding=1).to(device)
        self.adapter_relu1 = nn.ReLU(inplace=True).to(device)
        self.adapter_conv2 = nn.Conv2d(512, 512, kernel_size=3, padding=1).to(device)
        self.adapter_relu2 = nn.ReLU(inplace=True).to(device)

    def _resize_positional_embedding(self, visual, input_size, device):
        """Bicubically resize the ViT patch positional embeddings to the input grid.

        The class-token embedding (row 0) is kept as-is. The result replaces
        ``visual.positional_embedding`` as a Parameter in ``self.clip_conv1_dtype``.
        """
        original_pos_embed = visual.positional_embedding.float()
        new_grid_size = input_size // self.clip_patch_size
        target_seq_len = new_grid_size * new_grid_size + 1

        if original_pos_embed.shape[0] != target_seq_len:
            cls_pos_embed = original_pos_embed[:1, :]
            patch_pos_embed = original_pos_embed[1:, :]
            original_grid_size = int(math.sqrt(patch_pos_embed.shape[0]))
            # (g*g, D) -> (1, D, g, g) so F.interpolate can treat D as channels.
            patch_pos_embed = patch_pos_embed.reshape(original_grid_size, original_grid_size, -1)
            patch_pos_embed = patch_pos_embed.permute(2, 0, 1).unsqueeze(0)
            patch_pos_embed = F.interpolate(patch_pos_embed,
                                            size=(new_grid_size, new_grid_size),
                                            mode='bicubic',
                                            align_corners=False)
            patch_pos_embed = patch_pos_embed.squeeze(0).flatten(1, 2).permute(1, 0)
            new_pos_embed = torch.cat([cls_pos_embed, patch_pos_embed], dim=0)
        else:
            new_pos_embed = original_pos_embed

        visual.positional_embedding = nn.Parameter(
            new_pos_embed.detach().to(device).to(self.clip_conv1_dtype))

    def _vit_spatial_features(self, visual, x, device):
        """Run the CLIP ViT on ``x`` and return its patch tokens as a (B, width, g, g) grid."""
        x = x.to(visual.conv1.weight.dtype)
        x = visual.conv1(x)
        x = x.flatten(2).permute(0, 2, 1)  # (B, tokens, width)
        class_embedding = visual.class_embedding.to(x.dtype).to(device)
        pos_embedding = visual.positional_embedding.to(x.dtype).to(device)
        x = torch.cat([class_embedding.unsqueeze(0).expand(x.shape[0], -1, -1), x], dim=1)
        x = x + pos_embedding.unsqueeze(0)
        x = visual.ln_pre(x)

        # CLIP's residual blocks use nn.MultiheadAttention in its default
        # (seq, batch, dim) layout: switch NLD -> LND and back. Feeding NLD
        # would make attention mix tokens across different images in the batch.
        x = x.permute(1, 0, 2)
        for resblock in visual.transformer.resblocks:
            x = resblock(x)
        x = x.permute(1, 0, 2)

        if hasattr(visual, 'ln_post'):
            x = visual.ln_post(x)

        patch_tokens = x[:, 1:, :]  # drop the class token
        batch, n_patches, dim = patch_tokens.shape
        side = int(math.sqrt(n_patches))
        if side * side != n_patches:
            raise ValueError(f"Could not reshape {n_patches} patches into a square grid. Input size might not be suitable.")
        return patch_tokens.permute(0, 2, 1).reshape(batch, dim, side, side)

    def _resnet_spatial_features(self, visual, x):
        """Return the ``layer4`` feature map of a CLIP ResNet backbone."""
        x_rn = x.to(visual.conv1.weight.dtype)
        try:
            x_rn = visual.conv1(x_rn)
            x_rn = visual.bn1(x_rn)
            x_rn = visual.relu(x_rn)
            x_rn = visual.avgpool(x_rn)
            x_rn = visual.layer1(x_rn)
            x_rn = visual.layer2(x_rn)
            x_rn = visual.layer3(x_rn)
            spatial_features = visual.layer4(x_rn)

            _, _, h_feat, w_feat = spatial_features.shape
            expected_h = expected_w = self.output_spatial_size
            if h_feat != expected_h or w_feat != expected_w:
                print(f"Warning: ResNet spatial feature size mismatch. Expected {expected_h}x{expected_w}, got {h_feat}x{w_feat}.")
            if spatial_features.shape[1] != self.feature_dim:
                print(f"Warning: ResNet feature dimension mismatch. Expected {self.feature_dim}, got {spatial_features.shape[1]}.")
        except Exception as e:
            raise NotImplementedError(
                f"Failed to extract spatial features from CLIP ResNet model: {self.model_name}.") from e
        return spatial_features

    def forward(self, x):
        """Encode an image batch into a ``(B, 512, g, g)`` float32 feature map.

        Args:
            x: Image batch fed to CLIP's ``conv1``; presumably (B, 3, input_size,
                input_size) normalised floats. TODO: confirm against the caller.

        Returns:
            Adapter output with 512 channels. For ViTs, ``g`` is the patch-grid
            side (expected ``input_size // clip_patch_size``).
        """
        visual = self.clip_model.visual
        current_device = x.device
        visual_encoder_type = type(visual).__name__

        # Keep an outer torch.no_grad() (e.g. validation) in force; additionally
        # disable autograd through the backbone when it is frozen.
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.freeze_clip):
            if 'VisionTransformer' in visual_encoder_type:
                spatial_features = self._vit_spatial_features(visual, x, current_device)
            elif 'ResNet' in visual_encoder_type:
                spatial_features = self._resnet_spatial_features(visual, x)
            else:
                raise NotImplementedError(f"Unsupported CLIP visual encoder type in forward pass: {visual_encoder_type}")

        features = spatial_features.float()
        features = self.adapter_conv1(features)
        features = self.adapter_relu1(features)

        if features.shape[1] != 512:
            raise ValueError(f"Channel mismatch: Expected 512 channels after adapter_conv1 and relu1, but got {features.shape[1]}.")

        features = self.adapter_conv2(features)
        features = self.adapter_relu2(features)

        return features

In [None]:
class CLIPEBC(nn.Module):
    """Crowd counter: a CLIP spatial encoder followed by an EBC prediction head.

    The encoder yields 512-channel features on a grid of
    ``input_size // clip_patch_size`` cells per side. The head turns them into
    bin logits/probabilities, an expected-count map and a density map.

    Args:
        clip_model: CLIP model name forwarded to ``CLIPEncoder``.
        num_bins: Number of count bins of the EBC head.
        reduction: Down-sampling factor the dataset targets use. A warning is
            printed when ``input_size // reduction`` differs from the encoder's
            real output grid.
        freeze_clip: Whether the CLIP backbone is frozen.
        input_size: Side length of the square model input.
    """

    def __init__(self, clip_model='ViT-B/16', num_bins=100, reduction=8, freeze_clip=True, input_size=384):
        super(CLIPEBC, self).__init__()

        self.reduction = reduction
        self.input_size = input_size

        self.encoder = CLIPEncoder(clip_model, freeze_clip, input_size=input_size)
        patch_size = self.encoder.clip_patch_size
        self.actual_encoder_output_spatial_size = input_size // patch_size
        self.model_output_size = input_size // reduction

        sizes_agree = self.model_output_size == self.actual_encoder_output_spatial_size
        if not sizes_agree:
            warning_lines = (
                f"WARNING (CLIPEBC init): Dataset target GT size ({self.model_output_size}x{self.model_output_size}) based on reduction={self.reduction}",
                f"         does NOT match actual encoder output size ({self.actual_encoder_output_spatial_size}x{self.actual_encoder_output_spatial_size}) based on input_size={self.input_size} and CLIP patch_size={patch_size}.",
                f"         The 'run_clip_ebc_pipeline' should have corrected the 'reduction' parameter to {patch_size}.",
            )
            for line in warning_lines:
                print(line)

        self.ebc_head = EBCHead(512, num_bins, reduction)

    def forward(self, x):
        """Encode ``x`` and return the EBC head's output dictionary."""
        return self.ebc_head(self.encoder(x))

In [None]:
class DMCountLoss(nn.Module):
    """Weighted sum of a density, a total-count and a block-count loss.

    ``total = alpha * density_loss + beta * total_count_loss + gamma * block_count_loss``.
    Each term is a summed error divided by the batch size:

    * ``density_loss``: squared error between predicted and GT density maps;
    * ``total_count_loss``: absolute error between the per-image sum of
      ``count_map`` and the GT image count;
    * ``block_count_loss``: absolute error between ``count_map`` (channel axis
      squeezed) and the GT per-block counts.

    Raises:
        ValueError: if the last two (spatial) dims of the predicted and GT
            density maps differ.
    """

    def __init__(self, alpha=1.0, beta=1.0, gamma=1.0, num_bins=100):
        super(DMCountLoss, self).__init__()
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

        self.mse_loss = nn.MSELoss(reduction='sum')
        self.l1_loss = nn.L1Loss(reduction='sum')

        # Registered as a buffer (appears in state_dict); forward() does not read it.
        self.register_buffer('bin_centers', torch.arange(0, num_bins, dtype=torch.float32))

    def forward(self, pred_outputs, gt_density_map, gt_total_count, gt_block_counts):
        """Return ``(total_loss, density_loss, total_count_loss, block_count_loss)``."""
        density_pred = pred_outputs['density_map']
        count_map_pred = pred_outputs['count_map']

        # Per-image totals; atleast_1d restores the batch axis that squeeze()
        # removes when the batch holds a single image.
        per_image_count = torch.atleast_1d(count_map_pred.sum(dim=(2, 3)).squeeze())
        total_count_loss = self.l1_loss(per_image_count, gt_total_count.float()) / per_image_count.shape[0]

        block_count_loss = self.l1_loss(count_map_pred.squeeze(1), gt_block_counts.float()) / count_map_pred.shape[0]

        target_density = gt_density_map
        if target_density.dtype != density_pred.dtype:
            target_density = target_density.to(density_pred.dtype)
        if density_pred.shape[-2:] != target_density.shape[-2:]:
            raise ValueError("Spatial shape mismatch between predicted and ground truth density maps. Adjust 'reduction' in pipeline configuration.")
        density_loss = self.mse_loss(density_pred, target_density) / density_pred.shape[0]

        weighted_density = self.alpha * density_loss
        weighted_total = self.beta * total_count_loss
        weighted_block = self.gamma * block_count_loss
        total_loss = weighted_density + weighted_total + weighted_block

        return total_loss, density_loss, total_count_loss, block_count_loss


In [None]:
class CrowdDataset(Dataset):
    """Crowd-counting dataset of images with point annotations in ``.mat`` files.

    Ground truth for ``<images_dir>/<stem>.<ext>`` is read from
    ``<gt_dir>/<stem>.mat``. If that file does not exist, the ShanghaiTech-style
    ``<gt_dir>/GT_<stem>.mat`` is tried. Missing or unreadable files yield zero
    points (unreadable ones print a warning).

    Each item is a dict with:

    * ``'image'``: ``(3, input_size, input_size)`` normalised float tensor
      (training images additionally get colour jitter);
    * ``'density_map'``: ``(1, S, S)`` float tensor with ``S = input_size // reduction``,
      rescaled so its sum matches the full-resolution Gaussian density map;
    * ``'gt_count'``: float32 scalar tensor, the number of annotated points;
    * ``'gt_block_counts'``: ``(S, S)`` int64 tensor of points per output cell;
    * ``'points'``: the raw annotation array (x in column 0, y in column 1, in
      original-image pixels). Its length varies per image, so batches larger
      than one need a collate function that does not stack it.

    Args:
        images_dir: Directory whose entries (every ``glob('*')`` match) are images.
        gt_dir: Directory holding the ``.mat`` annotation files.
        input_size: Side length images are resized to.
        is_train: Enables colour-jitter augmentation.
        reduction: Down-sampling factor from the input to the target grid.
        num_bins: Stored for reference; not used by this class.
    """

    def __init__(self, images_dir, gt_dir, input_size=384, is_train=True, reduction=8, num_bins=100):
        super().__init__()
        self.images_dir = images_dir
        self.gt_dir = gt_dir
        self.input_size = input_size
        self.is_train = is_train
        self.reduction = reduction
        self.model_output_size = self.input_size // self.reduction
        self.num_bins = num_bins

        self.image_files = sorted(glob.glob(os.path.join(images_dir, '*')))

        self.transform = transforms.Compose([
            transforms.Resize((input_size, input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        if is_train:
            self.transform_aug = transforms.Compose([
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            ])
        else:
            self.transform_aug = None

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = self.image_files[idx]
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        original_width, original_height = image.size

        points = self._load_points(self._resolve_gt_path(img_path))

        density_map_orig_res = self.create_density_map(points, (original_width, original_height))
        gt_block_counts_tensor = self._block_counts(points, original_width, original_height)

        if self.is_train and self.transform_aug:
            image = self.transform_aug(image)
        image_tensor = self.transform(image)

        density_map_resized = self._resize_density(density_map_orig_res)
        gt_total_count = torch.tensor(len(points), dtype=torch.float32)

        return {
            'image': image_tensor,
            'density_map': density_map_resized,
            'gt_count': gt_total_count,
            'gt_block_counts': gt_block_counts_tensor,
            'points': points
        }

    def _resolve_gt_path(self, img_path):
        """Return ``<stem>.mat`` or, if only that exists, ``GT_<stem>.mat`` inside ``gt_dir``."""
        stem = os.path.splitext(os.path.basename(img_path))[0]
        candidates = [os.path.join(self.gt_dir, stem + '.mat'),
                      os.path.join(self.gt_dir, 'GT_' + stem + '.mat')]
        for candidate in candidates:
            if os.path.exists(candidate):
                return candidate
        return candidates[0]

    @staticmethod
    def _load_points(gt_path):
        """Load the annotation points from ``gt_path``; an empty (0, 2) array if unavailable.

        Checks ``image_info`` (ShanghaiTech layout), then ``annPoints``, then the
        first non-metadata 2-D array with at least two columns (first two kept).
        """
        empty = np.zeros((0, 2))
        if not os.path.exists(gt_path):
            return empty
        try:
            gt_data = sio.loadmat(gt_path)
            if 'image_info' in gt_data:
                return gt_data['image_info'][0, 0][0, 0][0]
            if 'annPoints' in gt_data:
                return gt_data['annPoints']
            for key, value in gt_data.items():
                # loadmat adds '__header__', '__version__' and '__globals__' entries.
                if key.startswith('__'):
                    continue
                if isinstance(value, np.ndarray) and value.ndim == 2 and value.shape[1] >= 2:
                    return value[:, :2]
        except Exception as exc:
            print(f"Warning: could not read ground truth from {gt_path}: {exc}")
        return empty

    def _block_counts(self, points, width, height):
        """Count points per cell of the ``model_output_size`` x ``model_output_size`` grid."""
        size = self.model_output_size
        counts = np.zeros((size, size), dtype=np.int32)
        if len(points) > 0:
            # Scale to grid units; clip then truncate (== floor, since values are >= 0).
            block_x = np.clip(points[:, 0] * (size / width), 0, size - 1).astype(np.int64)
            block_y = np.clip(points[:, 1] * (size / height), 0, size - 1).astype(np.int64)
            np.add.at(counts, (block_y, block_x), 1)  # handles repeated cells
        return torch.from_numpy(counts).long()

    def _resize_density(self, density_map):
        """Resize an (H, W) density map to ``(1, S, S)`` while preserving its total mass."""
        density_tensor = torch.from_numpy(density_map).float().unsqueeze(0)
        target_sum = density_tensor.sum().item()
        # 'area' averages over each output cell, so every Gaussian contributes;
        # bilinear sampling at large downscale factors skips most of them.
        resized = F.interpolate(density_tensor.unsqueeze(0),
                                size=(self.model_output_size, self.model_output_size),
                                mode='area').squeeze(0)
        current_sum = resized.sum().item()
        if target_sum > 0 and current_sum > 1e-6:
            return resized * (target_sum / current_sum)
        return torch.zeros_like(resized)

    def create_density_map(self, points, img_size):
        """Render a Gaussian density map of shape (height, width) for ``points``.

        Each in-bounds point adds a truncated Gaussian with
        ``sigma = max(1, min(width, height) // 80)`` and radius ``3 * sigma``.
        The map is then rescaled so its total equals ``len(points)``.
        """
        width, height = img_size
        density_map = np.zeros((height, width), dtype=np.float32)
        if len(points) == 0:
            return density_map

        sigma = max(1, min(width, height) // 80)
        kernel_radius = int(round(3 * sigma))
        kernel_size = 2 * kernel_radius + 1
        kernel_1d = cv2.getGaussianKernel(kernel_size, sigma)
        gaussian_kernel = np.outer(kernel_1d, kernel_1d)  # same for every point

        for point in points:
            x, y = int(round(point[0])), int(round(point[1]))
            if 0 <= x < width and 0 <= y < height:
                # Clip the kernel window at the image borders.
                x_start_map = max(0, x - kernel_radius)
                y_start_map = max(0, y - kernel_radius)
                x_end_map = min(width, x + kernel_radius + 1)
                y_end_map = min(height, y + kernel_radius + 1)
                x_start_kernel = kernel_radius - (x - x_start_map)
                y_start_kernel = kernel_radius - (y - y_start_map)
                x_end_kernel = kernel_radius + (x_end_map - x)
                y_end_kernel = kernel_radius + (y_end_map - y)
                density_map[y_start_map:y_end_map, x_start_map:x_end_map] += \
                    gaussian_kernel[y_start_kernel:y_end_kernel, x_start_kernel:x_end_kernel]

        current_sum = density_map.sum()
        if current_sum > 1e-6:
            density_map = density_map / (current_sum / len(points))
        return density_map

In [None]:
def _validation_count_errors(model, loader, device, desc):
    """Return (sum |pred - gt|, sum (pred - gt)^2, n_images) of per-image counts over ``loader``."""
    abs_err_sum = 0.0
    sq_err_sum = 0.0
    n_images = 0
    pbar = tqdm(loader, desc=desc)
    for batch in pbar:
        images = batch['image'].to(device)
        gt_counts = batch['gt_count'].reshape(-1).float().cpu()
        outputs = model(images)
        # Per-image predicted count: sum the count map over every non-batch dim.
        pred_counts = outputs['count_map'].flatten(1).sum(dim=1).float().cpu()
        errors = pred_counts - gt_counts
        abs_err_sum += errors.abs().sum().item()
        sq_err_sum += errors.pow(2).sum().item()
        n_images += gt_counts.numel()
        pbar.set_postfix({'MAE': f'{errors.abs().mean().item():.2f}'})
    return abs_err_sum, sq_err_sum, n_images


def train_model(model, train_loader, val_loader, num_epochs=100, lr=1e-4, device='cuda'):
    """Train a CLIP-EBC model with DMCountLoss, keeping the weights with the best validation MAE.

    Args:
        model: Module with ``ebc_head.num_bins`` and
            ``actual_encoder_output_spatial_size`` attributes, returning a dict
            with ``'density_map'`` and ``'count_map'``.
        train_loader: Yields dict batches with ``'image'``, ``'density_map'``,
            ``'gt_count'`` and ``'gt_block_counts'``. Its ``dataset`` must expose
            ``model_output_size``.
        val_loader: Yields dict batches with ``'image'`` and ``'gt_count'``.
        num_epochs: Number of epochs (also the cosine-annealing period).
        lr: Adam learning rate.
        device: Device to train on.

    Returns:
        ``(model, train_losses, val_maes)``. ``model`` holds the best-MAE
        weights; the lists hold per-epoch mean training loss and validation MAE.

    Raises:
        ValueError: if the dataset's target grid differs from the model output
            grid, or if the validation loader yields no samples.
    """
    model.to(device)

    criterion = DMCountLoss(alpha=1.0, beta=1.0, gamma=1.0, num_bins=model.ebc_head.num_bins).to(device)

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    best_mae = float('inf')
    best_model_state = None

    train_losses = []
    val_maes = []

    actual_model_output_spatial_size = model.actual_encoder_output_spatial_size
    print(f"Model's actual spatial output size from encoder: {actual_model_output_spatial_size}x{actual_model_output_spatial_size}")
    dataset_target_output_size = train_loader.dataset.model_output_size
    if dataset_target_output_size != actual_model_output_spatial_size:
        print(f"ERROR: Dataset expects GT at {dataset_target_output_size}x{dataset_target_output_size}, but model outputs {actual_model_output_spatial_size}x{actual_model_output_spatial_size}.")
        print("         This indicates a mismatch in the 'reduction' parameter. The run_clip_ebc_pipeline should have already corrected this.")
        raise ValueError(f"Dataset/Model spatial size mismatch. Dataset target: {dataset_target_output_size}, Model actual: {actual_model_output_spatial_size}. Check 'reduction' parameter setup.")

    for epoch in range(num_epochs):
        model.train()
        total_train_loss = 0.0
        train_density_loss_sum = 0.0
        train_total_count_loss_sum = 0.0
        train_block_count_loss_sum = 0.0

        train_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs} - Training')

        for batch in train_pbar:
            images = batch['image'].to(device)
            gt_density_map = batch['density_map'].to(device).float()
            gt_total_counts = batch['gt_count'].to(device).float()
            gt_block_counts = batch['gt_block_counts'].to(device).long()

            optimizer.zero_grad()

            outputs = model(images)

            current_total_loss, density_loss, total_count_loss, block_count_loss = criterion(
                outputs, gt_density_map, gt_total_counts, gt_block_counts
            )

            current_total_loss.backward()
            optimizer.step()

            total_train_loss += current_total_loss.item()
            train_density_loss_sum += density_loss.item()
            train_total_count_loss_sum += total_count_loss.item()
            train_block_count_loss_sum += block_count_loss.item()

            train_pbar.set_postfix({
                'T_Loss': f'{current_total_loss.item():.4f}',
                'D_Loss': f'{density_loss.item():.4f}',
                'C_Loss': f'{total_count_loss.item():.4f}',
                'B_Loss': f'{block_count_loss.item():.4f}'
            })

        scheduler.step()

        n_train_batches = max(len(train_loader), 1)
        avg_train_total_loss = total_train_loss / n_train_batches
        avg_train_density_loss = train_density_loss_sum / n_train_batches
        avg_train_total_count_loss = train_total_count_loss_sum / n_train_batches
        avg_train_block_count_loss = train_block_count_loss_sum / n_train_batches

        train_losses.append(avg_train_total_loss)

        model.eval()
        with torch.no_grad():
            abs_err_sum, sq_err_sum, val_samples = _validation_count_errors(
                model, val_loader, device, f'Epoch {epoch+1}/{num_epochs} - Validation')

        if val_samples == 0:
            raise ValueError("Validation loader yielded no samples; cannot compute MAE/RMSE.")

        avg_val_mae = abs_err_sum / val_samples
        # RMSE over all images (not an average of per-batch RMSEs).
        avg_val_rmse = math.sqrt(sq_err_sum / val_samples)

        val_maes.append(avg_val_mae)

        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'  Train Losses: Total={avg_train_total_loss:.4f}, Density={avg_train_density_loss:.4f}, '\
              f'TotalCount={avg_train_total_count_loss:.4f}, BlockCount={avg_train_block_count_loss:.4f}')
        print(f'  Val MAE: {avg_val_mae:.2f}')
        print(f'  Val RMSE: {avg_val_rmse:.2f}')
        print('-' * 50)

        if avg_val_mae < best_mae:
            best_mae = avg_val_mae
            # Clone tensors: state_dict() returns references to the live weights,
            # so a shallow dict copy would keep changing as training continues.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            print(f'New best MAE: {best_mae:.2f} - Model state saved.')

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f'Loaded best model state with MAE: {best_mae:.2f}')

    return model, train_losses, val_maes


In [None]:
def test_model(model, test_loader, device='cuda'):
    model.eval()
    model.to(device)

    all_pred_counts = []
    all_gt_counts = []

    with torch.no_grad():
        test_pbar = tqdm(test_loader, desc='Testing')

        for batch in test_pbar:
            images = batch['image'].to(device)
            gt_counts_batch = batch['gt_count'].cpu().numpy()

            outputs = model(images)
            pred_counts_batch = outputs['count_map'].sum(dim=(2, 3)).squeeze().cpu().numpy()

            if pred_counts_batch.ndim == 0:
                pred_counts_batch = np.array([pred_counts_batch.item()])
            if gt_counts_batch.ndim == 0:
                gt_counts_batch = np.array([gt_counts_batch.item()])

            all_pred_counts.extend(pred_counts_batch)
            all_gt_counts.extend(gt_counts_batch)

    mae = mean_absolute_error(all_gt_counts, all_pred_counts)
    rmse = np.sqrt(mean_squared_error(all_gt_counts, all_pred_counts))

    print(f'--- Test Results ---')
    print(f'Final MAE: {mae:.2f}')
    print(f'Final RMSE: {rmse:.2f}')

    return mae, rmse, all_pred_counts, all_gt_counts


In [None]:
def visualize_results(model, dataset, device='cuda', num_samples=5):
    """Plot image, GT density and predicted density for the first samples of ``dataset``.

    One figure row is drawn per shown sample, ``min(num_samples, len(dataset))``
    rows in total. Nothing is plotted when that number is zero.

    Args:
        model: Module returning a dict with ``'density_map'`` and ``'count_map'``.
        dataset: Dataset of dicts with ``'image'`` (normalised with ImageNet
            mean/std), ``'density_map'`` and ``'gt_count'``.
        device: Device to run inference on.
        num_samples: Maximum number of samples to show.
    """
    num_to_show = min(num_samples, len(dataset))
    print(f"\nVisualizing {num_to_show} sample results...")
    if num_to_show <= 0:
        return

    model.eval()
    model.to(device)

    # squeeze=False keeps ``axes`` 2-D even for a single row.
    fig, axes = plt.subplots(num_to_show, 3, figsize=(15, 5 * num_to_show), squeeze=False)

    # Inverse of the ImageNet normalisation applied by the dataset transform.
    inv_normalize = transforms.Normalize(
        mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225],
        std=[1/0.229, 1/0.224, 1/0.225]
    )

    visual_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)

    for i, batch in enumerate(visual_loader):
        if i >= num_to_show:
            break

        image_tensor = batch['image'].to(device)
        gt_density = batch['density_map'].squeeze().cpu().numpy()
        gt_count = batch['gt_count'].item()

        with torch.no_grad():
            outputs = model(image_tensor)
            pred_density = outputs['density_map'].squeeze().cpu().numpy()
            pred_count = outputs['count_map'].sum().item()

        img_denorm = inv_normalize(image_tensor.squeeze(0).cpu())
        img_denorm = torch.clamp(img_denorm, 0, 1)

        axes[i, 0].imshow(img_denorm.permute(1, 2, 0))
        axes[i, 0].set_title(f'Original (GT: {gt_count:.0f})')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(gt_density, cmap='jet')
        axes[i, 1].set_title(f'GT Density (Sum: {gt_density.sum():.0f})')
        axes[i, 1].axis('off')

        axes[i, 2].imshow(pred_density, cmap='jet')
        axes[i, 2].set_title(f'Pred Density (Sum: {pred_density.sum():.1f}, Count: {pred_count:.1f})')
        axes[i, 2].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
def _collate_crowd_batch(samples):
    """Batch dataset dicts, keeping the variable-length ``'points'`` field as a per-sample list.

    PyTorch's default collate stacks every field. That fails as soon as two
    images in a batch have a different number of annotated points.
    """
    batch = {}
    for key in samples[0]:
        values = [sample[key] for sample in samples]
        if key == 'points':
            batch[key] = values
        else:
            batch[key] = torch.utils.data.dataloader.default_collate(values)
    return batch


def run_clip_ebc_pipeline(train_images_dir, train_gt_dir, test_images_dir, test_gt_dir,
                          input_size=384, batch_size=8, num_epochs=50, lr=1e-4,
                          clip_model_name='ViT-B/16', num_bins=100, freeze_clip=False):
    """Build datasets and a CLIP-EBC model, train, test, and plot results.

    The dataset ``reduction`` is set to the CLIP patch size, so the GT grid
    (``input_size // reduction``) matches the encoder's output grid.
    The test set doubles as the validation set during training.

    Returns:
        ``(model, mae, rmse)`` on success, or ``(None, None, None)`` if the
        encoder cannot be built or the sizes cannot be aligned.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device: {device}')

    print("Determining CLIP encoder output size and correct reduction factor...")
    try:
        dummy_encoder = CLIPEncoder(clip_model_name, freeze_clip=True, input_size=input_size)
        actual_encoder_spatial_output_size = input_size // dummy_encoder.clip_patch_size
        print(f"CLIP model '{clip_model_name}' with input size {input_size} has patch size {dummy_encoder.clip_patch_size}, resulting in an actual spatial output size of {actual_encoder_spatial_output_size}x{actual_encoder_spatial_output_size}.")
        correct_reduction = dummy_encoder.clip_patch_size
        print(f"Setting pipeline's 'reduction' parameter to {correct_reduction} to ensure spatial alignment between dataset GT and model output.")
        reduction_for_pipeline = correct_reduction
        del dummy_encoder
        if device.type == 'cuda':
            torch.cuda.empty_cache()
    except NotImplementedError as e:
        print(f"Error determining CLIP encoder output size: {e}")
        print("Cannot proceed as dataset GT size cannot be aligned with model output size.")
        return None, None, None
    except Exception as e:
        print(f"An unexpected error occurred while determining CLIP encoder output size: {e}")
        print("Cannot proceed.")
        return None, None, None

    print('Creating datasets...')
    train_dataset = CrowdDataset(train_images_dir, train_gt_dir, input_size,
                                 is_train=True, reduction=reduction_for_pipeline, num_bins=num_bins)
    test_dataset = CrowdDataset(test_images_dir, test_gt_dir, input_size,
                                is_train=False, reduction=reduction_for_pipeline, num_bins=num_bins)

    print(f'Number of training samples: {len(train_dataset)}')
    print(f'Number of test/validation samples: {len(test_dataset)}')
    print(f"Dataset GT target size (input_size // reduction): {train_dataset.model_output_size}x{train_dataset.model_output_size}")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0,
                              pin_memory=torch.cuda.is_available(), collate_fn=_collate_crowd_batch)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0,
                             pin_memory=torch.cuda.is_available(), collate_fn=_collate_crowd_batch)

    print('Creating CLIP-EBC model...')
    model = CLIPEBC(clip_model=clip_model_name, num_bins=num_bins, reduction=reduction_for_pipeline,
                    freeze_clip=freeze_clip, input_size=input_size)

    if model.actual_encoder_output_spatial_size != train_dataset.model_output_size:
        print(f"FATAL ERROR: Model's actual encoder output size ({model.actual_encoder_output_spatial_size}) still doesn't match Dataset target size ({train_dataset.model_output_size}).")
        print("There is likely a deep-seated issue in how the model's output size is determined or a fundamental incompatibility.")
        return None, None, None

    print(f"Model successfully created. Actual encoder output spatial size: {model.actual_encoder_output_spatial_size}x{model.actual_encoder_output_spatial_size}")

    print('Starting training...')
    model, train_losses, val_maes = train_model(
        model, train_loader, test_loader, num_epochs=num_epochs, lr=lr, device=device
    )

    if model is None:
        print("Training did not complete successfully.")
        return None, None, None

    print('\nTesting model on the full test set...')
    mae, rmse, _, _ = test_model(model, test_loader, device)

    print('\nVisualizing sample predictions...')
    visualize_results(model, test_dataset, device, num_samples=3)

    print('\nPlotting training curves...')
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(train_losses)
    plt.title('Training Total Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(val_maes)
    plt.title('Validation MAE Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    return model, mae, rmse


In [None]:
if __name__ == '__main__':
    # os.path.join keeps the dataset paths portable across Windows and POSIX;
    # backslash-separated literals only resolve on Windows.
    DATA_ROOT = 'crowd_wala_dataset'
    TRAIN_IMAGES_DIR = os.path.join(DATA_ROOT, 'train_data', 'images')
    TRAIN_GT_DIR = os.path.join(DATA_ROOT, 'train_data', 'ground_truth')
    TEST_IMAGES_DIR = os.path.join(DATA_ROOT, 'test_data', 'images')
    TEST_GT_DIR = os.path.join(DATA_ROOT, 'test_data', 'ground_truth')

    if not all(os.path.exists(d) for d in [TRAIN_IMAGES_DIR, TRAIN_GT_DIR, TEST_IMAGES_DIR, TEST_GT_DIR]):
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print("! WARNING: Dataset paths are not set or do not exist.              !")
        print("! Please update TRAIN_IMAGES_DIR, TRAIN_GT_DIR, TEST_IMAGES_DIR,   !")
        print("! and TEST_GT_DIR in the example usage section with your actual    !")
        print("! dataset paths. The pipeline will not run without valid data.     !")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    else:
        INPUT_SIZE = 384
        BATCH_SIZE = 8
        NUM_EPOCHS = 5
        LEARNING_RATE = 1e-4
        CLIP_MODEL = 'ViT-B/16'
        NUM_BINS = 100
        FREEZE_CLIP = False
        SEED = 42

        # Seed the torch and numpy RNGs so shuffling/augmentation are repeatable.
        torch.manual_seed(SEED)
        np.random.seed(SEED)

        print(f"Starting CLIP-EBC pipeline with CLIP model: {CLIP_MODEL}, Input Size: {INPUT_SIZE}, Freeze CLIP: {FREEZE_CLIP}")
        trained_model, final_mae, final_rmse = run_clip_ebc_pipeline(
            train_images_dir=TRAIN_IMAGES_DIR,
            train_gt_dir=TRAIN_GT_DIR,
            test_images_dir=TEST_IMAGES_DIR,
            test_gt_dir=TEST_GT_DIR,
            input_size=INPUT_SIZE,
            batch_size=BATCH_SIZE,
            num_epochs=NUM_EPOCHS,
            lr=LEARNING_RATE,
            clip_model_name=CLIP_MODEL,
            num_bins=NUM_BINS,
            freeze_clip=FREEZE_CLIP
        )

        if trained_model is not None:
            print(f'\nPipeline Finished: Final Test MAE: {final_mae:.2f}, Final Test RMSE: {final_rmse:.2f}')
        else:
            print("\nPipeline aborted due to configuration or runtime issues.")
