In [2]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.mixture import GaussianMixture
from scipy.ndimage import maximum_filter
import matplotlib.pyplot as plt
from PIL import Image
import json
import collections.abc as container_abcs
from typing import List, Tuple, Dict
import warnings
warnings.filterwarnings('ignore')
from segment_anything import sam_model_registry, SamPredictor
from scipy import ndimage
import torchvision.transforms as transforms

In [3]:
SAM_CHECKPOINT_NAME = "sam_vit_h_4b8939.pth"
CHECKPOINTS_DIR = os.path.join(os.getcwd(), "checkpoints")
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
SAM_CHECKPOINT_PATH = os.path.join(CHECKPOINTS_DIR, SAM_CHECKPOINT_NAME)

In [4]:
class MockSeemPredictor:
    def __init__(self, *args, **kwargs):
        print("MockSeemPredictor initialized. This means SEEM is not being truly used.")

    def __call__(self, image):
        return [], [], []

In [5]:
class CrowdDataset(Dataset):
    def __init__(self, data_root, split='train', transform=None):
        self.data_root = data_root
        self.split = split
        self.transform = transform

        self.image_dir = os.path.join(data_root, split, 'images')
        self.gt_dir = os.path.join(data_root, split, 'ground_truth')

        if not os.path.exists(self.image_dir):
            raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
        if not os.path.exists(self.gt_dir):
             print(f"Warning: Ground truth directory not found: {self.gt_dir}. Assuming no ground truth data for this split.")

        self.image_files = sorted([f for f in os.listdir(self.image_dir)
                                     if f.endswith(('.jpg', '.jpeg', '.png'))])
        if not self.image_files:
            raise ValueError(f"No image files found in {self.image_dir}. Please check your dataset structure.")

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_name = self.image_files[idx]
        img_path = os.path.join(self.image_dir, img_name)

        try:
            img = Image.open(img_path)
        except Exception as e:
            print(f"[ERROR] Failed to open image {img_path}: {e}")
            raise RuntimeError(f"Failed to open image at {img_path}")

        if img.mode != "RGB":
            print(f"[Warning] Image {img_path} is in {img.mode}, converting to RGB.")
            img = img.convert("RGB")

        original_image_pil = img.copy()

        if self.transform:
            img = self.transform(img)

        if isinstance(img, torch.Tensor) and (img.ndim != 3 or img.shape[0] != 3):
            raise RuntimeError(f"Transformed image shape is invalid: {img.shape} from image {img_path}")

        gt_name = img_name.replace('.jpg', '.txt').replace('.png', '.txt').replace('.jpeg', '.txt')
        gt_path = os.path.join(self.gt_dir, gt_name)

        if os.path.exists(gt_path):
            try:
                gt_points = np.loadtxt(gt_path, delimiter=',', ndmin=2)
                gt_count = len(gt_points) if gt_points.size > 0 else 0
            except Exception as e:
                print(f"[Warning] Could not load ground truth from {gt_path}: {e}. Setting count to 0.")
                gt_points = np.array([])
                gt_count = 0
        else:
            gt_points = np.array([])
            gt_count = 0

        print(f"DEBUG: Type/Shape of 'image' being returned by CrowdDataset for {img_name}: {type(img)}, {img.shape if isinstance(img, torch.Tensor) else 'N/A'}")

        return {
            'image': img,
            'gt_count': gt_count,
            'gt_points': gt_points,
            'image_name': img_name,
            'original_image_pil': original_image_pil
        }


In [None]:
class AdaptiveSAM:
    def __init__(self, sam_checkpoint=SAM_CHECKPOINT_PATH, device="cuda"):
        self.device = device
        self.uncertain_threshold = 0.3

        print(f"Initializing SAM model from {sam_checkpoint}...")
        sam = sam_model_registry["vit_h"](checkpoint=sam_checkpoint)
        sam.to(device=device)
        self.sam_predictor = SamPredictor(sam)
        print("SAM Predictor initialized.")

        self.seem_predictor = MockSeemPredictor()
        print("SEEM Predictor initialized (Mocked).")

    def segment_with_seem(self, image_np):
        if isinstance(self.seem_predictor, MockSeemPredictor):
            return [], [], []

        return [], [], []

    def get_uncertain_regions(self, image, masks):
        h, w = image.shape[:2]
        covered_mask = np.zeros((h, w), dtype=bool)

        for mask in masks:
            if mask.shape == (h, w):
                covered_mask |= mask
            else:
                resized_mask = cv2.resize(mask.astype(np.uint8), (w, h), interpolation=cv2.INTER_NEAREST).astype(bool)
                covered_mask |= resized_mask

        uncertain_mask = ~covered_mask
        uncertain_ratio = np.sum(uncertain_mask) / (h * w) if (h * w) > 0 else 0.0

        return uncertain_mask, uncertain_ratio

    def adaptive_segmentation(self, image_pil):
        image_np = np.array(image_pil)
        h_orig, w_orig = image_np.shape[:2]

        if image_np.ndim == 2:  
            image_np = np.stack([image_np, image_np, image_np], axis=-1)
        elif image_np.ndim == 3 and image_np.shape[2] == 1: 
            image_np = np.concatenate([image_np, image_np, image_np], axis=-1)
        elif image_np.ndim == 3 and image_np.shape[2] != 3:
            image_pil = image_pil.convert("RGB")
            image_np = np.array(image_pil)


        all_masks = []
        all_scores = []

        masks_seem, labels_seem, scores_seem = self.segment_with_seem(image_np)
        all_masks.extend(masks_seem)
        all_scores.extend(scores_seem)

        self.sam_predictor.set_image(image_np)

        uncertain_mask, uncertain_ratio = self.get_uncertain_regions(image_np, all_masks)

        if uncertain_ratio > self.uncertain_threshold:
            y_coords, x_coords = np.where(uncertain_mask)
            if len(y_coords) > 0:
                num_points_to_sample = min(50, len(y_coords))
                indices = np.random.choice(len(y_coords), num_points_to_sample, replace=False)
                input_points = np.array([[x_coords[i], y_coords[i]] for i in indices])

                if input_points.size > 0:
                    masks_sam_points, scores_sam_points, _ = self.sam_predictor.predict(
                        point_coords=input_points,
                        point_labels=np.ones(input_points.shape[0]),
                        multimask_output=True
                    )
                    for i in range(masks_sam_points.shape[0]):
                        all_masks.append(masks_sam_points[i, np.argmax(scores_sam_points[i])])
                        all_scores.append(np.max(scores_sam_points[i]))

        if not all_scores and all_masks:
            all_scores = [1.0] * len(all_masks)

        filtered_masks = self.apply_nms(all_masks, all_scores)

        final_resized_masks = []
        for mask in filtered_masks:
            if mask.shape[:2] != (h_orig, w_orig):
                mask_resized = cv2.resize(mask.astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST).astype(bool)
                final_resized_masks.append(mask_resized)
            else:
                final_resized_masks.append(mask)

        return final_resized_masks

    def apply_nms(self, masks, scores, iou_threshold=0.5):
        if not masks:
            return []

        if not scores:
            scores = [1.0] * len(masks)

        sorted_indices = np.argsort(scores)[::-1]
        sorted_masks = [masks[i] for i in sorted_indices]
        sorted_scores = [scores[i] for i in sorted_indices]

        keep_indices = []
        suppressed = np.zeros(len(sorted_masks), dtype=bool)

        for i in range(len(sorted_masks)):
            if suppressed[i]:
                continue

            keep_indices.append(i)

            for j in range(i + 1, len(sorted_masks)):
                if suppressed[j]:
                    continue

                mask1 = sorted_masks[i]
                mask2 = sorted_masks[j]

                if mask1.shape != mask2.shape:
                    min_h = min(mask1.shape[0], mask2.shape[0])
                    min_w = min(mask1.shape[1], mask2.shape[1])
                    mask1_resized = cv2.resize(mask1.astype(np.uint8), (min_w, min_h)).astype(bool)
                    mask2_resized = cv2.resize(mask2.astype(np.uint8), (min_w, min_h)).astype(bool)
                else:
                    mask1_resized = mask1
                    mask2_resized = mask2

                intersection = np.sum(mask1_resized & mask2_resized)
                union = np.sum(mask1_resized | mask2_resized)

                iou = intersection / union if union > 0 else 0

                if iou > iou_threshold:
                    suppressed[j] = True

        final_kept_masks = [sorted_masks[i] for i in keep_indices]
        return final_kept_masks

In [7]:
class HeadLocalizer:
    def __init__(self, n_samples=10):
        self.n_samples = n_samples

    def localize_head(self, mask):
        if np.sum(mask) == 0:
            return None

        soft_mask = self.generate_soft_mask(mask)

        head_position = self.fit_gmm_and_extract_head(soft_mask)

        return head_position

    def generate_soft_mask(self, mask):
        print(f"DEBUG: Type of mask in generate_soft_mask: {type(mask)}")
        if isinstance(mask, np.ndarray):
            print(f"DEBUG: Shape of mask in generate_soft_mask: {mask.shape}, Dtype: {mask.dtype}")
        y_coords, x_coords = np.where(mask)

        if len(y_coords) == 0:
            return mask.astype(float)

        n_points = min(self.n_samples, len(y_coords))
        indices = np.random.choice(len(y_coords), n_points, replace=False)

        sample_points_for_sam = np.array([[x_coords[i], y_coords[i]] for i in indices])

        soft_mask = ndimage.gaussian_filter(mask.astype(float), sigma=2)
        return soft_mask

    def fit_gmm_and_extract_head(self, soft_mask):
        print(f"DEBUG: Type of soft_mask in fit_gmm_and_extract_head: {type(soft_mask)}")
        if isinstance(soft_mask, np.ndarray):
            print(f"DEBUG: Shape of soft_mask in fit_gmm_and_extract_head: {soft_mask.shape}, Dtype: {soft_mask.dtype}")
        y_coords, x_coords = np.where(soft_mask > 0.05)
        weights = soft_mask[y_coords, x_coords]

        if len(y_coords) < 2:
            if len(y_coords) == 0:
                return None
            head_x = np.sum(x_coords * weights) / np.sum(weights) if np.sum(weights) > 0 else np.mean(x_coords)
            head_y = np.sum(y_coords * weights) / np.sum(weights) if np.sum(weights) > 0 else np.mean(y_coords)
            return np.array([head_x, head_y])

        data = np.column_stack([x_coords, y_coords])

        try:
            gmm = GaussianMixture(n_components=2, covariance_type='full', random_state=42, n_init=10)
            gmm.fit(data, sample_weight=weights)

            means = gmm.means_

            head_mean = means[np.argmin(means[:, 1])]

            return head_mean

        except Exception as e:
            total_weight = np.sum(weights)
            if total_weight == 0:
                return np.array([np.mean(x_coords), np.mean(y_coords)])
            head_x = np.sum(x_coords * weights) / total_weight
            head_y = np.sum(y_coords * weights) / total_weight
            return np.array([head_x, head_y])

In [None]:
class CrowdCountingNetwork(nn.Module):
    def __init__(self):
        super(CrowdCountingNetwork, self).__init__()

        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(256, 512, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=1), nn.ReLU(inplace=True),
        )

        self.regressor = nn.Sequential(
            nn.Conv2d(512, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(256, 128, 3, padding=1), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(128, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 1, 1),
            nn.ReLU()
        )

    def forward(self, x):
        if x.ndim != 4:
            raise ValueError(f"CrowdCountingNetwork expected 4D input (N, C, H, W), but got {x.ndim}D input with shape {x.shape}")
        if x.shape[1] != 3: 
             raise ValueError(f"CrowdCountingNetwork expected 3 channels, but got {x.shape[1]} channels with shape {x.shape}")
        features = self.features(x)
        density_map = self.regressor(features)
        return density_map

In [9]:
class RobustLoss(nn.Module):
    def __init__(self, lambda_weight=0.01, background_weight=0.1):
        super(RobustLoss, self).__init__()
        self.lambda_weight = lambda_weight
        self.background_weight = background_weight
        self.mse_loss = nn.MSELoss(reduction='sum')

    def forward(self, pred_density, masks_batch, head_positions_batch, background_mask_batch, uncertain_mask_batch):
        total_loss = 0.0
        batch_size = pred_density.size(0)

        for b in range(batch_size):
            pred = pred_density[b, 0]

            H_pred, W_pred = pred.shape
            H_orig, W_orig = background_mask_batch[b].shape

            bg_mask_orig = background_mask_batch[b]
            uncertain_mask_orig = uncertain_mask_batch[b]

            bg_mask_resized = transforms.Resize((H_pred, W_pred), interpolation=transforms.InterpolationMode.NEAREST)(bg_mask_orig.unsqueeze(0).float()).squeeze(0).bool()
            uncertain_mask_resized = transforms.Resize((H_pred, W_pred), interpolation=transforms.InterpolationMode.NEAREST)(uncertain_mask_orig.unsqueeze(0).float()).squeeze(0).bool()


            bg_region_for_loss = bg_mask_resized & (~uncertain_mask_resized)
            if torch.sum(bg_region_for_loss) > 0:
                bg_loss = torch.mean((pred[bg_region_for_loss] - 0) ** 2)
                total_loss += self.background_weight * bg_loss

            image_masks = masks_batch[b]
            image_head_positions = head_positions_batch[b]

            if image_masks and len(image_masks) > 0:
                for i, mask_np in enumerate(image_masks):
                    mask_tensor = torch.from_numpy(mask_np).to(pred.device).bool()
                    mask_tensor_resized = transforms.Resize((H_pred, W_pred), interpolation=transforms.InterpolationMode.NEAREST)(mask_tensor.unsqueeze(0).float()).squeeze(0).bool()

                    mask_tensor_for_loss = mask_tensor_resized & (~uncertain_mask_resized)

                    if torch.sum(mask_tensor_for_loss) == 0:
                        continue

                    mask_density_sum = torch.sum(pred[mask_tensor_for_loss])
                    count_loss = (mask_density_sum - 1.0) ** 2

                    localization_loss = 0.0
                    if image_head_positions and i < len(image_head_positions):
                        head_pos = image_head_positions[i]
                        if head_pos is not None:
                            head_pos_scaled = torch.tensor([
                                head_pos[0] * (W_pred / W_orig),
                                head_pos[1] * (H_pred / H_orig)
                            ], device=pred.device)

                            y_grid, x_grid = torch.meshgrid(torch.arange(H_pred), torch.arange(W_pred), indexing='ij')
                            y_grid, x_grid = y_grid.to(pred.device), x_grid.to(pred.device)

                            dist_matrix = torch.sqrt((x_grid - head_pos_scaled[0])**2 + (y_grid - head_pos_scaled[1])**2)
                            weight_matrix = torch.exp(-dist_matrix / 5.0)

                            localization_loss = torch.mean((pred[mask_tensor_for_loss] - weight_matrix[mask_tensor_for_loss]) ** 2)

                    total_loss += self.lambda_weight * count_loss + localization_loss


        return total_loss / batch_size if batch_size > 0 else 0.0

In [None]:
class CrowdCountingTrainer:
    def __init__(self, data_root, device="cuda"):
        self.data_root = data_root
        self.device = device

        self.adaptive_sam = AdaptiveSAM(sam_checkpoint=SAM_CHECKPOINT_PATH, device=device)
        self.head_localizer = HeadLocalizer()
        self.model = CrowdCountingNetwork().to(device)
        self.criterion = RobustLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=1e-5)

        self.image_size = (512, 512)
        transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        self.train_dataset = CrowdDataset(data_root, 'train_data', transform=transform)
        self.test_dataset = CrowdDataset(data_root, 'test_data', transform=transform)

        def custom_collate_fn(batch):
            elem = batch[0]
            collated_batch = {}
            for key, value in elem.items():
                if isinstance(value, torch.Tensor):
                    collated_batch[key] = value
                elif isinstance(value, np.ndarray):
                    tensor_value = torch.from_numpy(value).float()
                    if tensor_value.ndim == 1 and tensor_value.numel() == 0:
                        collated_batch[key] = torch.empty((0, 2), dtype=torch.float32)
                    elif tensor_value.ndim == 1 and tensor_value.numel() > 0:
                        collated_batch[key] = tensor_value.unsqueeze(0)
                    else:
                        collated_batch[key] = tensor_value
                elif isinstance(value, (int, float, str)):
                    collated_batch[key] = value
                elif isinstance(value, Image.Image):
                    value = value.convert("RGB")
                    temp_transform = transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
                    ])
                    transformed_image_tensor = temp_transform(value)
                    
                    if transformed_image_tensor.ndim == 2:
                        transformed_image_tensor = transformed_image_tensor.unsqueeze(0) 
                    
                    if transformed_image_tensor.ndim != 3:
                        raise ValueError(f"Image tensor for key '{key}' after custom collate transform has unexpected dimensions: {transformed_image_tensor.ndim}. Expected 3 (C, H, W).")
                    
                    collated_batch[key] = transformed_image_tensor
                else:
                    raise TypeError(f"Unhandled type in custom_collate_fn for key '{key}': {type(value)}")
            
            if isinstance(value, Image.Image):
                print(f"Image mode: {value.mode}") 

            return collated_batch

        self.train_loader = DataLoader(self.train_dataset, batch_size=1, shuffle=True, collate_fn=custom_collate_fn)
        self.test_loader = DataLoader(self.test_dataset, batch_size=1, shuffle=False, collate_fn=custom_collate_fn)

        print(f"Train dataset size: {len(self.train_dataset)}")
        print(f"Test dataset size: {len(self.test_dataset)}")


    def generate_pseudo_labels(self, iteration=0):
        print(f"Generating pseudo-labels for iteration {iteration}...")

        pseudo_labels = {}

        self.adaptive_sam.sam_predictor.model.eval()
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        for idx, batch in enumerate(self.train_loader):
            if idx % 10 == 0:
                print(f"Processing image {idx+1}/{len(self.train_loader)}")
            image_input_from_batch = batch['original_image_pil'][0]
            image_name = batch['image_name'][0]

            print(f"DEBUG (generate_pseudo_labels): Initial Type of 'image_input_from_batch': {type(image_input_from_batch)}")

            image_for_sam_and_dims = None

            if isinstance(image_input_from_batch, Image.Image):
                image_for_sam_and_dims = image_input_from_batch
            elif isinstance(image_input_from_batch, torch.Tensor):
                print(f"DEBUG (generate_pseudo_labels): Detected torch.Tensor for original_image_pil. Converting to PIL.")
                
                denormalized_tensor = image_input_from_batch.cpu() * std + mean
                
                image_np_converted = (denormalized_tensor.clamp(0, 1).permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                
                image_for_sam_and_dims = Image.fromarray(image_np_converted)
            else:
                raise TypeError(
                    f"Unexpected type for original_image_pil in batch: {type(image_input_from_batch)}. "
                    "Expected PIL.Image.Image or torch.Tensor for conversion."
                )
            w_orig, h_orig = image_for_sam_and_dims.size 

            print(f"DEBUG (generate_pseudo_labels): Using original dimensions W={w_orig}, H={h_orig}")


            masks = self.adaptive_sam.adaptive_segmentation(image_for_sam_and_dims) 

            head_positions = []
            for mask in masks:
                head_pos = self.head_localizer.localize_head(mask)
                head_positions.append(head_pos)

            background_mask_orig = np.ones((h_orig, w_orig), dtype=bool)
            uncertain_mask_orig = np.zeros((h_orig, w_orig), dtype=bool)

            for mask in masks:
                if mask.size > 0:
                    mask_np_resized = cv2.resize(mask.astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST).astype(bool)
                    background_mask_orig &= ~mask_np_resized

            pseudo_labels[image_name] = {
                'masks': masks,
                'head_positions': head_positions,
                'original_image_shape': (h_orig, w_orig),
                'background_mask': background_mask_orig,
                'uncertain_mask': uncertain_mask_orig
            }

        if not pseudo_labels:
            print("Warning: No pseudo-labels were generated. This might indicate an issue with dataset loading or SAM segmentation.")

        print("Pseudo-label generation complete.")
        return pseudo_labels

    def train_iteration(self, pseudo_labels, num_epochs=100):
        print(f"Training crowd counting network for {num_epochs} epochs...")

        self.model.train()

        for epoch in range(num_epochs):
            total_loss = 0.0
            num_batches = 0

            for batch_idx, batch in enumerate(self.train_loader):
                image = batch['image'].to(self.device)
                if image.ndim == 3:
                    image = image.unsqueeze(0)
                elif image.ndim != 4:
                    raise ValueError(f"Input image tensor to model in train_iteration has unexpected dimensions: {image.ndim}. Expected 3 (C, H, W) or 4 (B, C, H, W).")
                image_name = batch['image_name']

                if image_name not in pseudo_labels:
                    continue

                labels_for_image = pseudo_labels[image_name]

                masks_for_loss = [labels_for_image['masks']]
                head_positions_for_loss = [labels_for_image['head_positions']]
                background_mask_for_loss = torch.from_numpy(labels_for_image['background_mask']).unsqueeze(0).to(self.device)
                uncertain_mask_for_loss = torch.from_numpy(labels_for_image['uncertain_mask']).unsqueeze(0).to(self.device)

                pred_density = self.model(image)

                loss = self.criterion(
                    pred_density,
                    masks_for_loss,
                    head_positions_for_loss,
                    background_mask_for_loss,
                    uncertain_mask_for_loss
                )

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            if num_batches > 0:
                avg_loss = total_loss / num_batches
                if (epoch + 1) % 10 == 0 or epoch == 0:
                    print(f"Epoch {epoch+1}/{num_epochs}, Avg Loss: {avg_loss:.4f}")
            else:
                print(f"Epoch {epoch+1}/{num_epochs}, No batches processed (check pseudo-labels or dataset).")

        print("Training iteration complete.")

    def predict_locations(self, threshold=0.1):
        print("Predicting locations for pseudo-label refinement...")
        self.model.eval()
        predictions = {}

        with torch.no_grad():
            for batch in self.train_loader:
                image = batch['image'].to(self.device)
                if image.ndim == 3: 
                    image = image.unsqueeze(0)
                elif image.ndim != 4: 
                    raise ValueError(f"Input image tensor to model in predict_locations has unexpected dimensions: {image.ndim}. Expected 3 (C, H, W) or 4 (B, C, H, W).")
                image_name = batch['image_name']
                
                img_path = os.path.join(self.train_dataset.image_dir, image_name)
                original_image_pil = Image.open(img_path).convert('RGB')
                original_image_shape = (original_image_pil.size[1], original_image_pil.size[0])

                pred_density = self.model(image)
                pred_density = pred_density[0, 0].cpu().numpy()

                H_orig, W_orig = original_image_shape
                pred_density_orig_size = cv2.resize(pred_density, (W_orig, H_orig), interpolation=cv2.INTER_LINEAR)

                smoothed_density = maximum_filter(pred_density_orig_size, size=1)
                local_maxima = maximum_filter(smoothed_density, size=3) == smoothed_density
                
                peak_coords = np.where((local_maxima) & (smoothed_density > threshold))
                pred_locations = list(zip(peak_coords[1], peak_coords[0])) 

                predictions[image_name] = pred_locations

        print(f"Predicted locations for {len(predictions)} images.")
        return predictions

    def iterative_pseudo_label_refinement(self, initial_pseudo_labels, predicted_locations):
        print("Refining pseudo-labels with predicted locations...")

        refined_labels = {}
        self.adaptive_sam.sam_predictor.model.eval()

        for image_name, current_predicted_locations in predicted_locations.items():
            if image_name not in initial_pseudo_labels:
                continue

            initial_labels = initial_pseudo_labels[image_name]
            
            img_path = os.path.join(self.train_dataset.image_dir, image_name)
            try:
                original_image_pil = Image.open(img_path).convert('RGB')
            except FileNotFoundError:
                print(f"Warning: Original image file not found at {img_path}. Skipping refinement for this image.")
                refined_labels[image_name] = initial_labels
                continue

            image_np_rgb = np.array(original_image_pil)

            all_masks_from_refinement = []
            all_head_positions_from_refinement = []

            if current_predicted_locations:
                input_points = np.array(current_predicted_locations)
                input_labels = np.ones(len(current_predicted_locations))

                self.adaptive_sam.sam_predictor.set_image(image_np_rgb)
                masks, scores, _ = self.adaptive_sam.sam_predictor.predict(
                    point_coords=input_points,
                    point_labels=input_labels,
                    multimask_output=True
                )

                for i in range(masks.shape[0]):
                    if masks[i].size > 0:
                        best_mask_for_point = masks[i, np.argmax(scores[i])]
                        all_masks_from_refinement.append(best_mask_for_point)
                        head_pos = self.head_localizer.localize_head(best_mask_for_point)
                        all_head_positions_from_refinement.append(head_pos)

            final_masks_for_image = self.adaptive_sam.apply_nms(all_masks_from_refinement, all_head_positions_from_refinement)

            final_head_positions_for_image = []
            for mask in final_masks_for_image:
                head_pos = self.head_localizer.localize_head(mask)
                final_head_positions_for_image.append(head_pos)

            h_orig, w_orig = initial_labels['original_image_shape']
            background_mask_new = np.ones((h_orig, w_orig), dtype=bool)
            uncertain_mask_new = np.zeros((h_orig, w_orig), dtype=bool)

            for mask in final_masks_for_image:
                if mask.size > 0:
                    mask_resized_for_bg = cv2.resize(mask.astype(np.uint8), (w_orig, h_orig), interpolation=cv2.INTER_NEAREST).astype(bool)
                    background_mask_new &= ~mask_resized_for_bg

            refined_labels[image_name] = {
                'masks': final_masks_for_image,
                'head_positions': final_head_positions_for_image,
                'original_image_shape': initial_labels['original_image_shape'],
                'background_mask': background_mask_new,
                'uncertain_mask': uncertain_mask_new
            }

        print("Pseudo-label refinement complete.")
        return refined_labels

    def evaluate(self):
        print("Evaluating model...")

        self.model.eval()
        mae_total = 0.0
        mse_total = 0.0
        num_images = 0

        with torch.no_grad():
            for batch in self.test_loader:
                image = batch['image'].to(self.device)
                if image.ndim == 3: 
                    image = image.unsqueeze(0)
                elif image.ndim != 4: 
                    raise ValueError(f"Input image tensor to model in evaluate has unexpected dimensions: {image.ndim}. Expected 3 (C, H, W) or 4 (B, C, H, W).")
                gt_count = batch['gt_count']

                pred_density = self.model(image)
                pred_count = torch.sum(pred_density).item()

                mae_total += abs(pred_count - gt_count)
                mse_total += (pred_count - gt_count) ** 2
                num_images += 1

        if num_images == 0:
            print("No test images found for evaluation.")
            return 0.0, 0.0

        mae = mae_total / num_images
        mse = (mse_total / num_images)**0.5

        print(f"Evaluation Results - MAE: {mae:.2f}, RMSE: {mse:.2f}")
        return mae, mse

    def train_full_pipeline(self, num_iterations=2):
        print("Starting full training pipeline...")

        pseudo_labels = self.generate_pseudo_labels(iteration=0)
        if not pseudo_labels:
            print("No pseudo-labels generated. Cannot proceed with training.")
            return None

        for iteration in range(num_iterations):
            print(f"\n--- Iteration {iteration+1}/{num_iterations} ---")

            self.train_iteration(pseudo_labels, num_epochs=100)

            mae, mse = self.evaluate()

            if iteration < num_iterations - 1:
                predicted_locations = self.predict_locations()

                pseudo_labels = self.iterative_pseudo_label_refinement(
                    pseudo_labels, predicted_locations)
                if not pseudo_labels:
                    print("Refined pseudo-labels are empty. Stopping pipeline.")
                    break

        print("\nTraining completed!")
        return self.model

In [None]:
def main():
    data_root = os.path.join(os.getcwd(), "crowd_wala_dataset")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    print(f"Using device: {device}")

    trainer = CrowdCountingTrainer(data_root, device)

    trained_model = trainer.train_full_pipeline(num_iterations=2)

    if trained_model:
        torch.save(trained_model.state_dict(), "crowd_counting_model_sam.pth")
        print("Model saved as 'crowd_counting_model_sam.pth'")

        final_mae, final_rmse = trainer.evaluate()
        print(f"Final Results - MAE: {final_mae:.2f}, MSE: {final_rmse:.2f}")
    else:
        print("Model training failed or did not produce a trained model.")

if __name__ == "__main__":
    main()

Using device: cuda
Initializing SAM model from c:\Users\Ishita\Downloads\crowd_counting\checkpoints\sam_vit_h_4b8939.pth...
SAM Predictor initialized.
MockSeemPredictor initialized. This means SEEM is not being truly used.
SEEM Predictor initialized (Mocked).
Train dataset size: 400
Test dataset size: 316
Starting full training pipeline...
Generating pseudo-labels for iteration 0...
DEBUG: Type/Shape of 'image' being returned by CrowdDataset for IMG_363.jpg: <class 'torch.Tensor'>, torch.Size([3, 512, 512])
Image mode: RGB
Processing image 1/400
DEBUG (generate_pseudo_labels): Initial Type of 'image_input_from_batch': <class 'torch.Tensor'>
DEBUG (generate_pseudo_labels): Detected torch.Tensor for original_image_pil. Converting to PIL.
DEBUG (generate_pseudo_labels): Using original dimensions W=1024, H=768
DEBUG: Type of mask in generate_soft_mask: <class 'numpy.ndarray'>
DEBUG: Shape of mask in generate_soft_mask: (768, 1024), Dtype: bool
DEBUG: Type of soft_mask in fit_gmm_and_extrac