In [1]:
from google.colab import drive
drive.mount('/content/drive')
!pwd

Mounted at /content/drive
/content


In [None]:
import os
import numpy as np
from scipy.io import loadmat
from scipy.ndimage import gaussian_filter
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.optim.lr_scheduler import CosineAnnealingLR
from collections import OrderedDict
import matplotlib.pyplot as plt
import matplotlib.cm as cm
plt.switch_backend('Agg')

def _find_points_in_mat(obj, current_path=""):
    if isinstance(obj, np.ndarray) and obj.ndim == 2 and obj.shape[1] == 2:
        return obj
    if isinstance(obj, dict):
        for key, value in obj.items():
            result = _find_points_in_mat(value, f"{current_path}.{key}" if current_path else key)
            if result is not None:
                return result
    if isinstance(obj, np.void) and obj.dtype.fields is not None:
        for field_name in obj.dtype.fields:
            field_value = obj[field_name]
            if field_value.ndim > 0 and field_value.size > 0 and isinstance(field_value.item(0), np.void):
                for item in field_value.flatten():
                    result = _find_points_in_mat(item, f"{current_path}.{field_name}")
                    if result is not None:
                        return result
            else:
                result = _find_points_in_mat(field_value, f"{current_path}.{field_name}")
                if result is not None:
                    return result
    if isinstance(obj, np.ndarray) and obj.dtype == object:
        for item in obj.flatten():
            result = _find_points_in_mat(item, current_path)
            if result is not None:
                return result
    return None

def generate_density_map(img_path, mat_path, sigma=4):
    img = Image.open(img_path).convert('RGB')
    w, h = img.size
    density = np.zeros((h, w), dtype=np.float32)
    try:
        mat = loadmat(mat_path)
        points = None
        if 'annPoints' in mat:
            points = mat['annPoints']
        elif 'image_info' in mat:
            info = mat['image_info']
            if info.ndim > 0 and info.shape[0] > 0 and info.shape[1] > 0:
                first_info_elem = info[0, 0]
                if 'location' in first_info_elem.dtype.names:
                    points = first_info_elem['location']
                elif 'number' in first_info_elem.dtype.names and 'location' in first_info_elem.dtype.names:
                    points = first_info_elem['location']
        if points is None:
            points = _find_points_in_mat(mat)
        if not (isinstance(points, np.ndarray) and points.ndim == 2 and points.shape[1] == 2):
            raise ValueError("")
        if points.size == 0:
            return gaussian_filter(density, sigma=sigma)
        for x_raw, y_raw in points:
            x, y = int(round(x_raw)), int(round(y_raw))
            if 0 <= x < w and 0 <= y < h:
                density[y, x] += 1
        return gaussian_filter(density, sigma=sigma)
    except Exception as e:
        raise ValueError(f" {mat_path}: {e}")

def preprocess_dataset(img_dir, mat_dir, out_dir, sigma=4):
    os.makedirs(out_dir, exist_ok=True)
    image_files = sorted([f for f in os.listdir(img_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))])
    processed_count, skipped_count = 0, 0
    skipped_file_details = []
    for fname in image_files:
        img_path = os.path.join(img_dir, fname)
        mat_fname = 'GT_' + os.path.splitext(fname)[0] + '.mat'
        mat_path = os.path.join(mat_dir, mat_fname)
        out_path = os.path.join(out_dir, os.path.splitext(fname)[0] + '.npy')
        if not os.path.isfile(mat_path):
            skipped_file_details.append(f"{fname}: ")
            skipped_count += 1
            continue
        try:
            dm = generate_density_map(img_path, mat_path, sigma)
            np.save(out_path, dm)
            processed_count += 1
        except ValueError as e:
            skipped_file_details.append(f"{fname}: {e}")
            skipped_count += 1
    print(f"Finished preprocessing for {img_dir}. skipped: {processed_count}, Success: {skipped_count}")
    if skipped_count > 0:
        for detail in skipped_file_details:
            print(f"- {detail}")

class CrowdDataset(Dataset):
    def __init__(self, img_dir, dm_dir, transform=None):
        self.img_dir, self.dm_dir, self.transform = img_dir, dm_dir, transform
        self.files = [f for f in sorted(os.listdir(img_dir)) if os.path.exists(os.path.join(dm_dir, os.path.splitext(f)[0] + '.npy'))]
        if not self.files:
            raise FileNotFoundError(f"No matching files found in {img_dir} and {dm_dir}.")
    def __len__(self):
        return len(self.files)
    def __getitem__(self, idx):
        name = self.files[idx]
        img = Image.open(os.path.join(self.img_dir, name)).convert('RGB')
        dm = np.load(os.path.join(self.dm_dir, os.path.splitext(name)[0] + '.npy'))
        original_img = img.copy()
        original_dm_sum = np.sum(dm)
        if self.transform:
            img = self.transform(img)
            dm_tensor = torch.from_numpy(dm).unsqueeze(0).unsqueeze(0)
            dm_resized = F.interpolate(dm_tensor, size=img.shape[1:], mode='bilinear', align_corners=False).squeeze(0)
            if original_dm_sum > 0:
                resized_sum = torch.sum(dm_resized)
                if resized_sum > 0:
                    dm_resized *= (original_dm_sum / resized_sum)
            dm = dm_resized
        else:
            dm = torch.from_numpy(dm).unsqueeze(0)
        return img, dm, original_img, name

def custom_collate_fn(batch):
    imgs = torch.stack([item[0] for item in batch], 0)
    dms = torch.stack([item[1] for item in batch], 0)
    original_imgs = [item[2] for item in batch]
    names = [item[3] for item in batch]
    return imgs, dms, original_imgs, names

class CCTrans(nn.Module):
    def __init__(self, img_size=(384, 384), patch_size=16, vit_out_channels=768):
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
        self.vit.heads = nn.Identity()
        self.grid_h, self.grid_w = img_size[0] // patch_size, img_size[1] // patch_size
        self.features = OrderedDict()
        self.vit.conv_proj.register_forward_hook(self._get_features_hook('patch_embed'))
        self.vit.encoder.layers[5].register_forward_hook(self._get_features_hook('block_6'))
        self.vit.encoder.layers[-1].register_forward_hook(self._get_features_hook('block_last'))
        self.conv_adapt_high = nn.Conv2d(vit_out_channels, 256, 1)
        self.conv_adapt_mid = nn.Conv2d(vit_out_channels, 256, 1)
        self.conv_adapt_low = nn.Conv2d(vit_out_channels, 256, 1)
        self.decoder_up1 = nn.Sequential(nn.ConvTranspose2d(256, 128, 2, 2), nn.ReLU(True), nn.Conv2d(128, 128, 3, padding=1), nn.ReLU(True))
        self.fuse1 = nn.Sequential(nn.Conv2d(128 + 256, 128, 1), nn.ReLU(True), nn.Conv2d(128, 128, 3, padding=2, dilation=2), nn.ReLU(True))
        self.decoder_up2 = nn.Sequential(nn.ConvTranspose2d(128, 64, 2, 2), nn.ReLU(True), nn.Conv2d(64, 64, 3, padding=1), nn.ReLU(True))
        self.fuse2 = nn.Sequential(nn.Conv2d(64 + 256, 64, 1), nn.ReLU(True), nn.Conv2d(64, 64, 3, padding=3, dilation=3), nn.ReLU(True))
        self.decoder_up3 = nn.Sequential(nn.ConvTranspose2d(64, 32, 2, 2), nn.ReLU(True), nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(True))
        self.decoder_up4 = nn.Sequential(nn.ConvTranspose2d(32, 16, 2, 2), nn.ReLU(True), nn.Conv2d(16, 16, 3, padding=1), nn.ReLU(True))
        self.output_layer = nn.Conv2d(16, 1, kernel_size=1)
        self._initialize_decoder_weights()
    def _initialize_decoder_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)) and m not in self.vit.modules():
                nn.init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
    def _get_features_hook(self, name):
        def hook(_, __, output):
            if name == 'patch_embed':
                self.features[name] = output
            elif output.ndim == 3:
                self.features[name] = output[:, 1:].permute(0, 2, 1).reshape(output.size(0), -1, self.grid_h, self.grid_w)
        return hook
    def forward(self, x):
        original_size = x.shape[2:]
        self.features.clear()
        _ = self.vit(x)
        f_last = self.conv_adapt_low(self.features['block_last'])
        d1 = self.decoder_up1(f_last)
        f_mid = F.interpolate(self.features['block_6'], size=d1.shape[2:], mode='bilinear')
        d1_fused = self.fuse1(torch.cat([d1, self.conv_adapt_mid(f_mid)], dim=1))
        d2 = self.decoder_up2(d1_fused)
        f_patch = F.interpolate(self.features['patch_embed'], size=d2.shape[2:], mode='bilinear')
        d2_fused = self.fuse2(torch.cat([d2, self.conv_adapt_high(f_patch)], dim=1))
        d3 = self.decoder_up3(d2_fused)
        d4 = self.decoder_up4(d3)
        density_map = self.output_layer(d4)
        return F.relu(F.interpolate(density_map, size=original_size, mode='bilinear'))

def train_model(loader, model, criterion, optimizer, scheduler, device, epochs):
    model.to(device)
    model.train()
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (imgs, dms, _, _) in enumerate(loader):
            imgs, dms = imgs.to(device), dms.to(device)
            optimizer.zero_grad()
            preds = model(imgs)
            loss = criterion(preds, dms)
            if torch.isnan(loss) or torch.isinf(loss):
                continue
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        scheduler.step()

def evaluate_model(loader, model, device, output_vis_dir=None, max_vis_images=20):
    model.to(device)
    model.eval()
    mae, mse, vis_count = 0.0, 0.0, 0
    if output_vis_dir:
        os.makedirs(output_vis_dir, exist_ok=True)
    with torch.no_grad():
        for imgs, dms, original_imgs, names in loader:
            preds = F.relu(model(imgs.to(device)))
            pred_counts = preds.sum(dim=[1, 2, 3]).cpu()
            gt_counts = dms.sum(dim=[1, 2, 3])
            mae += (pred_counts - gt_counts).abs().sum().item()
            mse += ((pred_counts - gt_counts) ** 2).sum().item()
    total = len(loader.dataset)
    print(f"MAE: {mae/total:.2f}, RMSE: {np.sqrt(mse/total):.2f}")

def save_density_map_visualization(img, gt, pred, path, name):
    fig, axes = plt.subplots(1, 3, figsize=(20, 7))
    vmax = max(gt.max(), pred.max(), 0.1)
    axes[0].imshow(img); axes[0].set_title(f'Original: {name}'); axes[0].axis('off')
    axes[1].imshow(gt, cmap='jet', vmin=0, vmax=vmax); axes[1].set_title(f'GT Count: {gt.sum():.2f}'); axes[1].axis('off')
    axes[2].imshow(pred, cmap='jet', vmin=0, vmax=vmax); axes[2].set_title(f'Pred Count: {pred.sum():.2f}'); axes[2].axis('off')
    plt.tight_layout()
    plt.savefig(path)
    plt.close(fig)

if __name__ == '__main__':
    train_images = '/content/drive/MyDrive/images'
    train_gt = '/content/drive/MyDrive/ground_truth'
    train_dm = '/content/drive/MyDrive/density_maps'
    test_images = '/content/drive/MyDrive/images1'
    test_gt = '/content/drive/MyDrive/ground_truth1'
    test_dm = '/content/drive/MyDrive/density_maps1'
    visualization_output_dir = '/content/drive/MyDrive/cctrans3_visualizations'
    PRETRAINED_WEIGHTS_PATH = '/content/drive/MyDrive/alt_gvt_large.pth'
    IMG_SIZE, PATCH_SIZE, EPOCHS, BATCH_SIZE = (384, 384), 16, 10, 4
    preprocess_dataset(train_images, train_gt, train_dm)
    preprocess_dataset(test_images, test_gt, test_dm)
    transform = {
        'train': transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'test': transforms.Compose([
            transforms.Resize(IMG_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
    }
    train_dataset = CrowdDataset(train_images, train_dm, transform['train'])
    test_dataset = CrowdDataset(test_images, test_dm, transform['test'])
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, collate_fn=custom_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2, collate_fn=custom_collate_fn)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CCTrans(img_size=IMG_SIZE)
    if os.path.exists(PRETRAINED_WEIGHTS_PATH):
        model.vit.load_state_dict(torch.load(PRETRAINED_WEIGHTS_PATH, map_location=device), strict=False)
    decoder_params = [p for n, p in model.named_parameters() if 'vit' not in n and p.requires_grad]
    backbone_params = [p for n, p in model.named_parameters() if 'vit' in n and p.requires_grad]
    optimizer = torch.optim.AdamW([
        {'params': backbone_params, 'lr': 1e-6},
        {'params': decoder_params, 'lr': 1e-5}
    ], weight_decay=1e-4)
    criterion = nn.MSELoss()
    scheduler = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-7)
    train_model(train_loader, model, criterion, optimizer, scheduler, device, epochs=EPOCHS)
    evaluate_model(test_loader, model, device, output_vis_dir=visualization_output_dir)



Finished preprocessing for /content/drive/MyDrive/images. skipped: 0, Success: 400
- IMG_1.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_1.mat: 
- IMG_10.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_10.mat: 
- IMG_100.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_100.mat: 
- IMG_101.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_101.mat: 
- IMG_102.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_102.mat: 
- IMG_103.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_103.mat: 
- IMG_104.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_104.mat: 
- IMG_105.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_105.mat: 
- IMG_106.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_106.mat: 
- IMG_107.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_107.mat: 
- IMG_108.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_108.mat: 
- IMG_109.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_109.mat: 
- IMG_11.jpg:  /content/drive/MyDrive/ground_truth/GT_IMG_11.mat: 
- IMG_110.jpg:  /content/dri

Downloading: "https://download.pytorch.org/models/vit_b_16_swag-9ac1b537.pth" to /root/.cache/torch/hub/checkpoints/vit_b_16_swag-9ac1b537.pth
100%|██████████| 331M/331M [00:05<00:00, 65.2MB/s]
