In [None]:
import os

import torch

class Config:
    """Notebook-wide settings: data paths, tensor sizes, training and diffusion hyperparameters.

    Every setting is a class attribute, so values are shared by all instances
    and ``Config().__dict__`` is empty.
    """

    # Image folders, resolved against the current working directory.
    train_img_dir = os.path.abspath(os.path.join('..', 'data', 'train'))
    val_img_dir = os.path.abspath(os.path.join('..', 'data', 'val'))
    test_img_dir = os.path.abspath(os.path.join('..', 'data', 'test'))

    # Fixation annotation files for the train / val splits.
    train_fixations_json = os.path.abspath(os.path.join('..', '..', 'salgan', 'data', 'fixations_train2014.json'))
    val_fixations_json = os.path.abspath(os.path.join('..', '..', 'salgan', 'data', 'fixations_val2014.json'))

    # Sizes are (width, height): consumers unpack them as ``W, H`` and hand
    # (height, width) to the resize transform.
    image_size = (256, 192)
    saliency_size = (256, 192)

    # DataLoader and optimiser settings.
    batch_size = 4
    num_workers = 4
    num_epochs = 5
    learning_rate = 1e-4
    weight_decay = 1e-5

    # Linear noise schedule: number of diffusion steps and the beta range.
    timesteps = 1000
    beta_start = 1e-4
    beta_end = 0.02

    # Where the best checkpoint and the sample figures are written.
    output_dir = os.path.abspath('./saliency_diffusion_outputs')
    checkpoint_path = os.path.join(output_dir, 'saliency_diffusion_unet.pt')
    sample_dir = os.path.join(output_dir, 'samples')

    # Use the GPU only when PyTorch reports a usable CUDA device. Having
    # CUDA_VISIBLE_DEVICES set (it may even be an empty string) is not proof
    # that one exists, so the environment variable is not consulted.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

cfg = Config()
# Create the output folders up front so later cells can write into them.
os.makedirs(cfg.output_dir, exist_ok=True)
os.makedirs(cfg.sample_dir, exist_ok=True)
# All settings are class attributes, so the instance __dict__ is empty;
# display the class-level values instead.
{name: value for name, value in vars(Config).items() if not name.startswith('_')}

In [None]:
import json
from typing import Dict, List, Tuple
import numpy as np
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.utils as vutils
import matplotlib.pyplot as plt

# Report the selected device and whether each expected input exists on disk.
print('Using device:', cfg.device)
for _label, _folder in (
    ('Train dir exists:', cfg.train_img_dir),
    ('Val dir exists:', cfg.val_img_dir),
    ('Test dir exists:', cfg.test_img_dir),
):
    print(_label, os.path.isdir(_folder))
for _label, _json_path in (
    ('Train fixations JSON exists:', cfg.train_fixations_json),
    ('Val fixations JSON exists:', cfg.val_fixations_json),
):
    print(_label, os.path.isfile(_json_path))


In [None]:
def load_json(path: str) -> Dict:
    """Read the JSON document stored at ``path`` and return the decoded object.

    The file is opened explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the content is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
    
def parse_fixations_json(fixations: Dict, saliency_size=None) -> Dict[str, np.ndarray]:
    """Turn a fixation annotation dictionary into one normalised count map per image.

    Args:
        fixations: Decoded JSON with an ``'images'`` list (entries carrying
            ``'id'`` and ``'file_name'``) and an ``'annotations'`` list (entries
            carrying ``'image_id'`` and a point list under ``'fixations'`` or
            ``'points'``). Either list may be missing.
        saliency_size: Optional ``(width, height)`` of the output maps.
            Defaults to ``cfg.saliency_size``.

    Returns:
        Mapping from image file name to a float32 array of shape
        ``(height, width)`` holding fixation counts divided by the maximum
        count, so the peak is 1. Images without any usable point are omitted.

    NOTE(review): the first two values of every point are used directly as
    (x, y) indices into the saliency grid and clipped to its bounds. If the
    JSON stores coordinates in original-image pixels or in (row, col) order
    they need rescaling / swapping first — confirm against the data.
    """
    if saliency_size is None:
        # Fall back to the notebook-wide configuration.
        saliency_size = cfg.saliency_size
    W, H = saliency_size

    # Image id -> file name lookup.
    img_id_to_name = {img['id']: img['file_name'] for img in (fixations.get('images') or [])}

    # Gather every fixation point of an image, across all its annotations.
    mapping: Dict[str, List[Tuple[int, int]]] = {}
    for ann in fixations.get('annotations') or []:
        img_id = ann.get('image_id')
        if img_id is None:
            continue
        fname = img_id_to_name.get(img_id)
        if fname is None:
            # Annotation refers to an image that is not listed; skip it.
            continue
        pts = ann.get('fixations') or ann.get('points') or []
        mapping.setdefault(fname, []).extend(pts)

    saliency_maps: Dict[str, np.ndarray] = {}
    for fname, pts in mapping.items():
        sal_map = np.zeros((H, W), dtype=np.float32)
        for p in pts:
            if len(p) < 2:
                continue
            ix = int(np.clip(p[0], 0, W - 1))
            iy = int(np.clip(p[1], 0, H - 1))
            sal_map[iy, ix] += 1.0

        # Normalise once, after all points have been accumulated; rescaling
        # inside the point loop would distort the relative counts.
        peak = sal_map.max()
        if peak <= 0:
            # No usable fixation for this image: leave it out.
            continue
        saliency_maps[fname] = sal_map / peak

    return saliency_maps
    
# Best-effort load of both fixation files: if either one fails, report the
# error and continue with both set to None so later cells can still run.
try:
    _loaded = tuple(
        load_json(_json_path)
        for _json_path in (cfg.train_fixations_json, cfg.val_fixations_json)
    )
except Exception as e:
    print('Error loading or parsing fixation JSONs. Please adapt parse_fixations_json to your format.')
    print(e)
    _loaded = (None, None)
else:
    print('Loaded fixation JSONs.')
train_fix_raw, val_fix_raw = _loaded

In [None]:
# A split whose JSON could not be loaded gets an empty mapping.
train_saliency_maps = {} if train_fix_raw is None else parse_fixations_json(train_fix_raw)
val_saliency_maps = {} if val_fix_raw is None else parse_fixations_json(val_fix_raw)

for _label, _maps in (
    ('Train saliency entries:', train_saliency_maps),
    ('Val saliency entries:', val_saliency_maps),
):
    print(_label, len(_maps))

In [None]:
class SaliconSaliencyDataset(Dataset):
    """Pairs every image in ``img_dir`` with its precomputed saliency map.

    Args:
        img_dir: Folder containing the image files.
        saliency_maps: Mapping from file name to a 2-D saliency array with
            values in [0, 1]. ``None`` is treated as an empty mapping.
        image_size: Target ``(width, height)`` for both image and saliency map.

    Only files that have an entry in ``saliency_maps`` are kept. A missing
    ``img_dir`` yields an empty dataset instead of raising, so callers can
    check ``len(dataset)`` and report the problem themselves.
    """

    def __init__(self, img_dir: str, saliency_maps: Dict[str, np.ndarray], image_size=(256, 192)):
        self.img_dir = img_dir
        self.saliency_maps = saliency_maps if saliency_maps is not None else {}
        self.image_size = image_size

        # Collect only images that have saliency info, in a stable order.
        candidates = os.listdir(img_dir) if os.path.isdir(img_dir) else []
        self.image_files = sorted(f for f in candidates if f in self.saliency_maps)

        # Resize expects (height, width) while image_size is (width, height).
        self.img_transform = transforms.Compose([
            transforms.Resize((image_size[1], image_size[0])),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of images that have a saliency map."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, saliency_tensor, file_name)`` for item ``idx``.

        The image goes through Resize + ToTensor; the saliency map is resized
        to the same size and returned as a float tensor of shape (1, H, W).
        """
        fname = self.image_files[idx]
        img_path = os.path.join(self.img_dir, fname)

        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        img = self.img_transform(img)

        sal_map = self.saliency_maps[fname]
        H, W = self.image_size[1], self.image_size[0]
        # Round-trip through an 8-bit PIL image to reuse PIL's bilinear resize.
        sal_img = Image.fromarray((sal_map * 255).astype(np.uint8))
        sal_img = sal_img.resize((W, H), resample=Image.BILINEAR)
        sal_tensor = torch.from_numpy(np.array(sal_img)).float() / 255.0
        sal_tensor = sal_tensor.unsqueeze(0)  # (1, H, W)

        return img, sal_tensor, fname
    
train_dataset = SaliconSaliencyDataset(cfg.train_img_dir, train_saliency_maps, image_size=cfg.image_size)
val_dataset = SaliconSaliencyDataset(cfg.val_img_dir, val_saliency_maps, image_size=cfg.image_size)

# Pinned host memory only speeds up host-to-GPU copies; skip it on CPU runs.
_pin_memory = str(cfg.device).startswith('cuda')
# shuffle=True builds a RandomSampler, which can reject an empty dataset with
# a ValueError. Only shuffle when there is data, so the "dataset is empty"
# diagnostics in later cells are reached instead of a crash here.
train_loader = DataLoader(
    train_dataset,
    batch_size=cfg.batch_size,
    shuffle=len(train_dataset) > 0,
    num_workers=cfg.num_workers,
    pin_memory=_pin_memory,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
    pin_memory=_pin_memory,
)

print('Train dataset size:', len(train_dataset))
print('Val dataset size:', len(val_dataset))

def _show_grid(title, picture, **imshow_kwargs):
    """Display ``picture`` in an 8x4 inch figure with a title and no axes."""
    plt.figure(figsize=(8, 4))
    plt.title(title)
    plt.imshow(picture, **imshow_kwargs)
    plt.axis('off')
    plt.show()


# Pull one training batch (when there is data) and plot it as a sanity check.
if len(train_dataset) == 0:
    batch = None
    print('Warning: Train dataset is empty. Check your paths and JSON mapping.')
else:
    batch = next(iter(train_loader))
    imgs, sal_maps, fnames = batch
    print('Batch shapes:', imgs.shape, sal_maps.shape)

    # Input images tiled in rows of at most four; channels moved last for imshow.
    grid = vutils.make_grid(imgs, nrow=min(4, imgs.size(0)))
    _show_grid('Sample input images', grid.permute(1, 2, 0))

    # Ground-truth maps: show the first channel of the tiled grid as a heat map.
    grid_sal = vutils.make_grid(sal_maps, nrow=min(4, sal_maps.size(0)))
    _show_grid('Sample ground truth saliency maps', grid_sal[0].cpu(), cmap='hot')

In [None]:
class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU, with optional time conditioning.

    When ``time_emb_dim`` is given, a Linear + ReLU projection of the time
    embedding is added to the output of the first convolution (broadcast over
    the spatial dimensions) before the first activation.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)
        self.act = nn.ReLU(inplace=True)
        if time_emb_dim is None:
            self.time_mlp = None
        else:
            self.time_mlp = nn.Sequential(
                nn.Linear(time_emb_dim, out_ch),
                nn.ReLU(inplace=True),
            )

    def forward(self, x, t_emb=None):
        """Apply the block to ``x``; ``t_emb`` is used only if a time MLP exists."""
        feat = self.conv1(x)
        skip_time = self.time_mlp is None or t_emb is None
        if not skip_time:
            # Per-channel shift, expanded to (B, C, 1, 1) for broadcasting.
            shift = self.time_mlp(t_emb)
            feat = feat + shift[:, :, None, None]
        return self.act(self.conv2(self.act(feat)))
    
class UNetSaliency(nn.Module):
    """Small two-level U-Net that maps (image, noisy saliency, timestep) to a saliency-shaped output.

    The image and the noisy saliency map are concatenated along the channel
    axis. Every ConvBlock receives a learned embedding of the normalised
    timestep. The DDPM wrapper in this notebook trains the output as the
    noise prediction.

    Args:
        img_channels: Channels of the conditioning image.
        saliency_channels: Channels of the saliency map (input and output).
        base_ch: Width of the first encoder level; deeper levels double it.
        time_emb_dim: Size of the timestep embedding.
        timesteps: Total number of diffusion steps used to scale timestep
            indices. ``None`` (the default) reads ``cfg.timesteps`` at call
            time, which is the previous behaviour.
    """

    def __init__(self, img_channels=3, saliency_channels=1, base_ch=64, time_emb_dim=128, timesteps=None):
        super().__init__()
        in_ch = img_channels + saliency_channels
        # Plain attribute, not a buffer: state_dict keys stay unchanged.
        self.timesteps = timesteps

        # Time embedding MLP
        self.time_mlp = nn.Sequential(
            nn.Linear(1, time_emb_dim),
            nn.ReLU(inplace=True),
            nn.Linear(time_emb_dim, time_emb_dim),
        )

        # Encoder
        self.enc1 = ConvBlock(in_ch, base_ch, time_emb_dim)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(base_ch, base_ch * 2, time_emb_dim)
        self.pool2 = nn.MaxPool2d(2)

        # Bottleneck
        self.bottleneck = ConvBlock(base_ch * 2, base_ch * 4, time_emb_dim)

        # Decoder (each level takes the upsampled features plus the skip connection)
        self.up2 = nn.ConvTranspose2d(base_ch * 4, base_ch * 2, 2, stride=2)
        self.dec2 = ConvBlock(base_ch * 4, base_ch * 2, time_emb_dim)
        self.up1 = nn.ConvTranspose2d(base_ch * 2, base_ch, 2, stride=2)
        self.dec1 = ConvBlock(base_ch * 2, base_ch, time_emb_dim)

        self.out_conv = nn.Conv2d(base_ch, saliency_channels, 1)

    def forward(self, x_img, x_sal_noisy, t):
        """Run the network.

        Args:
            x_img: (B, img_channels, H, W) conditioning image.
            x_sal_noisy: (B, saliency_channels, H, W) noisy saliency map.
            t: (B,) timestep indices.

        Returns:
            Tensor of shape (B, saliency_channels, H, W).

        Raises:
            ValueError: if H or W is not a multiple of 4. Two 2x poolings
                followed by two 2x upsamplings would otherwise produce
                feature maps that cannot be concatenated with the skips.
        """
        height, width = x_img.shape[-2:]
        if height % 4 != 0 or width % 4 != 0:
            raise ValueError(
                f'UNetSaliency needs spatial sizes divisible by 4, got {height}x{width}'
            )

        # Scale the timestep index by the total number of diffusion steps.
        total_steps = self.timesteps if self.timesteps is not None else cfg.timesteps
        t = t.float().unsqueeze(-1) / total_steps
        t_emb = self.time_mlp(t)

        x = torch.cat([x_img, x_sal_noisy], dim=1)

        e1 = self.enc1(x, t_emb)
        p1 = self.pool1(e1)
        e2 = self.enc2(p1, t_emb)
        p2 = self.pool2(e2)

        b = self.bottleneck(p2, t_emb)

        u2 = self.up2(b)
        d2 = self.dec2(torch.cat([u2, e2], dim=1), t_emb)
        u1 = self.up1(d2)
        d1 = self.dec1(torch.cat([u1, e1], dim=1), t_emb)

        out = self.out_conv(d1)
        return out

# Build the denoising network on the configured device and report its size.
model = UNetSaliency()
model = model.to(cfg.device)
_param_count = sum(p.numel() for p in model.parameters())
print('Model params:', _param_count / 1e6, 'M')

In [None]:
class SaliencyDDPM(nn.Module):
    """Image-conditioned DDPM over saliency maps with a linear beta schedule.

    ``model`` is called as ``model(x_img, x_noisy, t)`` and its output is
    treated as the predicted noise. The schedule tensors are registered as
    buffers, so they follow the module across devices and are part of its
    state dict.
    """

    def __init__(self, model: nn.Module, timesteps: int = 1000, beta_start: float = 1e-4, beta_end: float = 0.02):
        super().__init__()
        self.model = model
        self.timesteps = timesteps

        beta_schedule = torch.linspace(beta_start, beta_end, timesteps)
        alpha_schedule = 1.0 - beta_schedule
        alpha_bar = torch.cumprod(alpha_schedule, dim=0)

        schedule = {
            'betas': beta_schedule,
            'alphas': alpha_schedule,
            'alphas_cumprod': alpha_bar,
            'sqrt_alphas_cumprod': torch.sqrt(alpha_bar),
            'sqrt_one_minus_alphas_cumprod': torch.sqrt(1.0 - alpha_bar),
        }
        for buffer_name, tensor in schedule.items():
            self.register_buffer(buffer_name, tensor)

    @staticmethod
    def _per_sample(values, t):
        """Pick ``values[t]`` for every batch element, shaped (B, 1, 1, 1)."""
        return values[t].view(-1, 1, 1, 1)

    def q_sample(self, x_start, t, noise=None):
        """Forward process: noise ``x_start`` to step ``t`` (fresh Gaussian noise if none is given)."""
        eps = torch.randn_like(x_start) if noise is None else noise
        signal_scale = self._per_sample(self.sqrt_alphas_cumprod, t)
        noise_scale = self._per_sample(self.sqrt_one_minus_alphas_cumprod, t)
        return signal_scale * x_start + noise_scale * eps

    def p_losses(self, x_img, x_sal, t):
        """MSE between the noise added to ``x_sal`` at step ``t`` and the model's prediction."""
        target = torch.randn_like(x_sal)
        noisy_sal = self.q_sample(x_sal, t, target)
        predicted = self.model(x_img, noisy_sal, t)
        return nn.functional.mse_loss(predicted, target)

    @torch.no_grad()
    def p_sample(self, x_img, t, x):
        """One reverse step from ``x`` at step ``t``.

        Whether noise is added is decided from ``t[0]`` only, i.e. the whole
        batch is treated as being at the same step.
        """
        beta_t = self._per_sample(self.betas, t)
        noise_scale = self._per_sample(self.sqrt_one_minus_alphas_cumprod, t)
        inv_sqrt_alpha = torch.sqrt(1.0 / self.alphas[t]).view(-1, 1, 1, 1)

        eps_hat = self.model(x_img, x, t)
        mean = inv_sqrt_alpha * (x - beta_t / noise_scale * eps_hat)

        if t[0] == 0:
            # Final step: return the mean without adding noise.
            return mean
        fresh_noise = torch.randn_like(x)
        return mean + torch.sqrt(beta_t) * fresh_noise

    @torch.no_grad()
    def sample(self, x_img, shape):
        """Start from Gaussian noise of ``shape`` and run all reverse steps, conditioned on ``x_img``."""
        device = x_img.device
        x = torch.randn(shape, device=device)
        batch = x.shape[0]
        for step in range(self.timesteps - 1, -1, -1):
            t = torch.full((batch,), step, device=device, dtype=torch.long)
            x = self.p_sample(x_img, t, x)
        return x

# Wrap the U-Net in the diffusion process using the configured noise schedule.
_schedule_kwargs = dict(
    timesteps=cfg.timesteps,
    beta_start=cfg.beta_start,
    beta_end=cfg.beta_end,
)
ddpm = SaliencyDDPM(model, **_schedule_kwargs).to(cfg.device)
ddpm

In [None]:
# Adam over every trainable parameter of the diffusion wrapper.
optimizer = torch.optim.Adam(
    params=ddpm.parameters(),
    lr=cfg.learning_rate,
    weight_decay=cfg.weight_decay,
)

def train_epoch(epoch_idx: int):
    """Run one optimisation pass over ``train_loader``.

    A random timestep is drawn for every sample and the noise-prediction loss
    is minimised with the module-level ``optimizer``. ``epoch_idx`` is
    accepted for symmetry with the caller but not used.

    Returns:
        Mean loss over the batches seen, or 0.0 if the loader produced none.
    """
    ddpm.train()
    batch_losses = []
    for imgs, sal_maps, _ in train_loader:
        imgs, sal_maps = imgs.to(cfg.device), sal_maps.to(cfg.device)

        steps = torch.randint(0, cfg.timesteps, (imgs.size(0),), device=cfg.device).long()
        loss = ddpm.p_losses(imgs, sal_maps, steps)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_losses.append(loss.item())
    return sum(batch_losses) / max(1, len(batch_losses))

@torch.no_grad()
def validate_epoch(epoch_idx: int, seed: int = 0):
    """Compute the mean noise-prediction loss over ``val_loader``.

    Timesteps and noise are drawn from a temporarily forked RNG that is
    re-seeded with ``seed`` on every call. Each epoch is therefore scored on
    the same (timestep, noise) draws, which makes the losses comparable when
    they are used for checkpoint selection. The global RNG state is restored
    afterwards, so training randomness is unaffected.

    Args:
        epoch_idx: Accepted for symmetry with the caller; not used.
        seed: Seed for the validation draws.

    Returns:
        Mean loss over the batches seen, or 0.0 if the loader produced none.
    """
    ddpm.eval()
    total_loss = 0.0
    num_batches = 0
    # Only fork the CUDA generator when the run actually uses the GPU.
    cuda_devices = [torch.cuda.current_device()] if str(cfg.device).startswith('cuda') else []
    with torch.random.fork_rng(devices=cuda_devices):
        torch.manual_seed(seed)
        for imgs, sal_maps, _ in val_loader:
            imgs = imgs.to(cfg.device)
            sal_maps = sal_maps.to(cfg.device)
            b = imgs.size(0)
            t = torch.randint(0, cfg.timesteps, (b,), device=cfg.device).long()
            loss = ddpm.p_losses(imgs, sal_maps, t)
            total_loss += loss.item()
            num_batches += 1
    return total_loss / max(1, num_batches)

best_val_loss = float('inf')

# validate_epoch reports 0.0 when there is no validation data. Selecting on
# that value would save the first epoch and never improve on it, so fall back
# to the training loss in that case.
has_val_data = len(val_dataset) > 0
if not has_val_data:
    print('No validation data; selecting checkpoints by training loss instead.')

for epoch in range(cfg.num_epochs):
    train_loss = train_epoch(epoch)
    val_loss = validate_epoch(epoch)
    print(f"Epoch {epoch+1}/{cfg.num_epochs} | Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")
    selection_loss = val_loss if has_val_data else train_loss
    if selection_loss < best_val_loss:
        best_val_loss = selection_loss
        torch.save(ddpm.state_dict(), cfg.checkpoint_path)
        print('  Saved new best model to', cfg.checkpoint_path)
print('Training complete. Best val loss:', best_val_loss)

In [None]:
# Restore the best weights saved during training, if a checkpoint exists.
if not os.path.isfile(cfg.checkpoint_path):
    print('Checkpoint not found; using current model weights.')
else:
    _best_state = torch.load(cfg.checkpoint_path, map_location=cfg.device)
    ddpm.load_state_dict(_best_state)
    print('Loaded best checkpoint.')

In [None]:
@torch.no_grad()
def generate_samples_from_loader(loader, num_batches: int = 1, tag: str = 'val'):
    """Sample saliency maps for the first batches of ``loader`` and save comparison figures.

    For every image a three-panel figure (input image, ground-truth saliency,
    sampled saliency clamped to [0, 1]) is written to ``cfg.sample_dir`` as
    ``'<tag>_<file name>'``.

    Args:
        loader: Yields ``(images, saliency_maps, file_names)`` batches.
        num_batches: Number of batches to process; values <= 0 process none.
        tag: Prefix for the saved file names.
    """
    if num_batches <= 0:
        print('num_batches is not positive; no samples generated.')
        return

    ddpm.eval()
    # Make the function usable even if the folder was removed after setup.
    os.makedirs(cfg.sample_dir, exist_ok=True)

    batch_count = 0
    for imgs, sal_maps, fnames in loader:
        imgs = imgs.to(cfg.device)
        sal_maps = sal_maps.to(cfg.device)
        B = imgs.size(0)

        # Sample saliency maps via reverse diffusion
        samples = ddpm.sample(imgs, shape=sal_maps.shape)
        samples = samples.clamp(0.0, 1.0)

        for i in range(B):
            panels = (
                ('Image', imgs[i].cpu().permute(1, 2, 0), {}),
                ('GT saliency', sal_maps[i, 0].cpu(), {'cmap': 'hot'}),
                ('Pred saliency (diffusion)', samples[i, 0].cpu(), {'cmap': 'hot'}),
            )
            fig, axs = plt.subplots(1, 3, figsize=(9, 3))
            try:
                for ax, (title, picture, imshow_kwargs) in zip(axs, panels):
                    ax.imshow(picture, **imshow_kwargs)
                    ax.set_title(title)
                    ax.axis('off')
                plt.tight_layout()
                save_path = os.path.join(cfg.sample_dir, f'{tag}_{fnames[i]}')
                fig.savefig(save_path, dpi=150)
            finally:
                # Always release the figure, even if saving fails.
                plt.close(fig)

        batch_count += 1
        if batch_count >= num_batches:
            # Stop before fetching another batch from the loader.
            break

    print(f'Saved sample visualizations to {cfg.sample_dir} (tag={tag}).')
# Generate a few validation samples (skipped when there is no validation data).
if len(val_dataset) == 0:
    print('No validation data; skipping sample generation.')
else:
    generate_samples_from_loader(val_loader, num_batches=1, tag='val')

In [None]:
class TestImageDataset(Dataset):
    """Unlabelled images from ``img_dir`` for inference.

    Args:
        img_dir: Folder to scan for ``.jpg`` / ``.jpeg`` / ``.png`` files
            (extension match is case-insensitive).
        image_size: Target ``(width, height)`` of the returned tensors.

    A missing ``img_dir`` yields an empty dataset instead of raising, so
    callers can check ``len(dataset)`` and skip inference.
    """

    def __init__(self, img_dir: str, image_size=(256, 192)):
        self.img_dir = img_dir
        candidates = os.listdir(img_dir) if os.path.isdir(img_dir) else []
        self.files = sorted(
            f for f in candidates if f.lower().endswith(('.jpg', '.jpeg', '.png'))
        )
        # torchvision.transforms is imported as `transforms` in this notebook;
        # Resize expects (height, width) while image_size is (width, height).
        self.transform = transforms.Compose([
            transforms.Resize((image_size[1], image_size[0])),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Number of image files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image_tensor, file_name)`` for item ``idx``."""
        fname = self.files[idx]
        img_path = os.path.join(self.img_dir, fname)
        # convert() returns a new, fully loaded image, so the file handle can
        # be released right away.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)
        return img, fname

# Unlabelled test images, served in file-name order.
test_dataset = TestImageDataset(cfg.test_img_dir, image_size=cfg.image_size)
test_loader = DataLoader(
    test_dataset,
    batch_size=cfg.batch_size,
    shuffle=False,
    num_workers=cfg.num_workers,
    pin_memory=True,
)
print('Test dataset size:', len(test_dataset))

In [None]:
@torch.no_grad()
def run_test_inference():
    """Sample a saliency map for every test image and save it as an 8-bit PNG.

    Each prediction is clamped to [0, 1], scaled to 0-255 and written to
    ``cfg.output_dir`` as ``'<image stem>_saliency.png'``.
    """
    ddpm.eval()
    for imgs, fnames in test_loader:
        imgs = imgs.to(cfg.device)
        batch, _, height, width = imgs.shape
        preds = ddpm.sample(imgs, shape=(batch, 1, height, width))
        preds = preds.clamp(0.0, 1.0).cpu().numpy()

        for fname, pred in zip(fnames, preds):
            # pred is (1, H, W); keep the single channel and convert to 8-bit.
            pred_u8 = (pred[0] * 255).astype(np.uint8)
            stem = os.path.splitext(fname)[0]
            target = os.path.join(cfg.output_dir, stem + '_saliency.png')
            Image.fromarray(pred_u8).save(target)
    print('Saved predicted saliency maps for test images to', cfg.output_dir)

# Run inference only when there is something to predict on.
if len(test_dataset) == 0:
    print('No test images found; skipping test inference.')
else:
    run_test_inference()