In [None]:
import copy
import math
import os
import random
from collections import defaultdict
from io import BytesIO
from pathlib import Path
from typing import List

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from PIL import ImageFile
from ptflops import get_model_complexity_info
from pytorch_msssim import ssim
from scipy.stats import entropy
from scipy.stats import pearsonr
from scipy.stats import sem
from scipy.stats import spearmanr
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
ImageFile.LOAD_TRUNCATED_IMAGES = True

import matplotlib as mpl
import pylab
# ---- Global matplotlib styling ---------------------------------------------
# Line defaults are set before the style sheet is applied, as in the original
# cell order.
mpl.rcParams['lines.linewidth'] = 4
mpl.rcParams['lines.color'] = 'r'

# The seaborn style sheets were renamed to 'seaborn-v0_8-*' in Matplotlib 3.6
# and the old names were removed afterwards, so use whichever name this
# installation actually provides instead of hard-coding the legacy one.
for _style_name in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style_name in plt.style.available:
        plt.style.use(_style_name)
        break

plt.rc('figure', figsize=(15, 9))

# Overrides applied on top of the style sheet (font, tick padding, axes).
mpl.rcParams.update({
    'font.family': "serif",
    'font.weight': "semibold",
    'ytick.major.pad': '15',
    'xtick.major.pad': '15',
    'axes.labelsize': 25,
    'axes.linewidth': 4,
    'xtick.labelsize': 25,
    'ytick.labelsize': 25,
    'axes.edgecolor': 'black',
    'axes.titlesize': 25,
    'legend.fontsize': 20,
})

# Compute device shared by the cells that reference the upper-case DEVICE.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)
import torchvision.transforms.functional as TF
from fvcore.nn import FlopCountAnalysis, parameter_count

In [None]:
# Release cached, currently unused GPU memory held by PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
#!pip install pytorch-msssim
#!pip install ptflops
#!pip install fvcore torchinfo

In [None]:
# Single seed shared by every random number generator used in the notebook.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)


class CFG:
    """Container for the notebook's hyper-parameters (plain class attributes)."""

    IMG_SZ = 256        # side length frames / heat-maps are resized to
    BATCH = 8           # DataLoader batch size
    EPOCH_KD = 8
    EPOCH_BC = 8
    LAMBDA = 0.5
    LR = 4e-5
    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Report the selected device once (the original printed it from the class body).
print(CFG.DEVICE)

cfg = CFG()

In [None]:
# def load_data(frame_dir: str, hm_dir: str):
#     frames, hmaps = [], []
#     for fname in sorted(os.listdir(frame_dir)):
#         f_path, h_path = os.path.join(frame_dir, fname), os.path.join(hm_dir, fname)

#         frm  = cv2.imread(f_path)                      # BGR uint8
#         hmap = cv2.imread(h_path, cv2.IMREAD_GRAYSCALE)   # (H,W) uint8
#         if frm is None or hmap is None:
#             print(f'[WARN] Could not load {fname}')
#             continue

#         frm  = cv2.resize(frm,  (cfg.IMG_SZ, cfg.IMG_SZ))
#         hmap = cv2.resize(hmap,(cfg.IMG_SZ, cfg.IMG_SZ))

#         frames.append(frm[..., ::-1])            # to RGB
#         hmaps.append(hmap)

#     return np.stack(frames,0), np.stack(hmaps,0)

In [None]:
def load_data(frame_dir: str, hm_dir: str):
    """Load paired frames and gaze heat-maps that share the same filename.

    Every non-hidden entry of ``frame_dir`` is read as a colour image and the
    file with the same name in ``hm_dir`` is read as a grayscale heat-map.
    Both are resized to ``cfg.IMG_SZ`` x ``cfg.IMG_SZ``. Pairs where either
    file cannot be decoded are reported and skipped.

    Args:
        frame_dir: Directory containing the input frames.
        hm_dir: Directory containing the heat-maps (same filenames).

    Returns:
        Tuple ``(frames, hmaps)`` of stacked NumPy arrays; frames are converted
        from OpenCV's BGR channel order to RGB.

    Raises:
        ValueError: If no frame/heat-map pair could be loaded (previously this
            surfaced as an opaque error from ``np.stack`` on an empty list).
    """
    frames = []
    hmaps = []

    for fname in sorted(os.listdir(frame_dir)):
        # Skip hidden files such as .DS_Store / .ipynb_checkpoints.
        if fname.startswith('.'):
            continue

        frm = cv2.imread(os.path.join(frame_dir, fname))
        hmap = cv2.imread(os.path.join(hm_dir, fname), cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None instead of raising.
        if frm is None or hmap is None:
            print(f'[WARN] Could not load {fname}')
            continue

        frm = cv2.resize(frm, (cfg.IMG_SZ, cfg.IMG_SZ))
        hmap = cv2.resize(hmap, (cfg.IMG_SZ, cfg.IMG_SZ))

        frames.append(frm[..., ::-1])  # BGR to RGB
        hmaps.append(hmap)

    if not frames:
        raise ValueError(
            f"No frame/heat-map pairs could be loaded from "
            f"'{frame_dir}' and '{hm_dir}'"
        )

    return np.stack(frames), np.stack(hmaps)

In [None]:
# Input frames and their gaze heat-maps live in two parallel folders;
# load_data pairs them by identical filename.
frame_folder = 'Images/GT/extracted_frames'
heatmap_folder = 'Images/Gaze/extracted_frames'


# Load everything into memory and print the resulting array shapes as a check.
frames, hmaps = load_data(frame_folder, heatmap_folder)
print(frames.shape, hmaps.shape)

In [None]:
# class GazeDataset(Dataset):
#     def __init__(self, frames, hmaps):
#         self.f_t = transforms.Compose([
#             transforms.ToPILImage(),
#             transforms.ToTensor()
#         ])
#         self.h_t = transforms.Compose([
#             transforms.ToPILImage(mode='L'),
#             transforms.ToTensor()
#         ])
#         self.X, self.Y = frames, hmaps

#     def __len__(self):
#         return len(self.X)

#     def __getitem__(self, idx):
#         return self.f_t(self.X[idx]), self.h_t(self.Y[idx])

In [None]:
# class GazeDataset(Dataset):
#     def __init__(self, frames, hmaps, augment=True):
#         if augment:
#             self.f_t = transforms.Compose([
#                 transforms.ToPILImage(),
#                 transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
#                 transforms.RandomHorizontalFlip(),
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225])
#             ])
#         else:
#             self.f_t = transforms.Compose([
#                 transforms.ToPILImage(),
#                 transforms.ToTensor(),
#                 transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                                      std=[0.229, 0.224, 0.225])
#             ])

#         self.h_t = transforms.Compose([
#             transforms.ToPILImage(mode='L'),
#             transforms.ToTensor()  # Keeps heatmap in (1, 256, 256)
#         ])

#         self.X, self.Y = frames, hmaps

#     def __len__(self):
#         return len(self.X)

#     def __getitem__(self, idx):
#         return self.f_t(self.X[idx]), self.h_t(self.Y[idx])

class GazeDataset(Dataset):
    """In-memory dataset of (RGB frame, gaze heat-map) pairs.

    Args:
        frames: Indexable collection of H x W x 3 uint8 RGB images
            (as produced by ``load_data``).
        hmaps: Indexable collection of H x W uint8 heat-maps aligned with
            ``frames``.
        augment: When True, apply a random horizontal flip (to image and
            heat-map together) and colour jitter (image only).

    Each item is ``(img, hm)`` where ``img`` is an ImageNet-normalised float
    tensor and ``hm`` is a ``(1, H, W)`` float tensor scaled to ``[0, 1]``.

    Requires ``torchvision.transforms`` to be imported as ``transforms``.
    """

    def __init__(self, frames, hmaps, augment=True):
        self.augment = augment
        self.frames, self.hmaps = frames, hmaps

        # Photometric augmentation; only ever applied to the image.
        self.color_jitter = transforms.ColorJitter(
            brightness=0.2, contrast=0.2,
            saturation=0.2, hue=0.05)

        # ToTensor scales to [0, 1]; Normalize uses the ImageNet statistics.
        self.to_tensor_norm = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.frames)

    def __getitem__(self, idx):
        img = self.frames[idx]
        hm = self.hmaps[idx]

        if self.augment:
            # Geometric augmentation must hit image and heat-map identically.
            if random.random() < 0.5:
                # .copy() removes the negative strides np.fliplr introduces,
                # which torch.from_numpy cannot handle.
                img = np.fliplr(img).copy()
                hm = np.fliplr(hm).copy()

            img = self.color_jitter(TF.to_pil_image(img))

        # to tensor
        img = self.to_tensor_norm(img)
        hm = torch.from_numpy(hm).unsqueeze(0).float() / 255.0

        return img, hm

In [None]:
# Hold out 15% of the samples for validation (seeded, so the split is reproducible).
tr_X, te_X, tr_Y, te_Y = train_test_split(
    frames, hmaps, test_size=0.15, random_state=SEED, shuffle=True
)

# Augmentation is for training only. The validation set must be deterministic,
# otherwise validation losses change from run to run and are measured on
# randomly flipped / colour-jittered images.
train_loader = DataLoader(GazeDataset(tr_X, tr_Y, augment=True), batch_size=cfg.BATCH, shuffle=True, drop_last=True)
val_loader   = DataLoader(GazeDataset(te_X, te_Y, augment=False), batch_size=cfg.BATCH, shuffle=False)

In [None]:
# Peek at the first validation batch to confirm tensor shapes.
# The loop variables are deliberately NOT called `frames`: that name holds the
# full NumPy array returned by load_data, and rebinding it to a tensor batch
# breaks re-running the train/val split cell.
for idx, (frame_batch, heatmap_batch) in enumerate(val_loader):
    print(f"Batch {idx + 1}")
    print(f"Frames shape: {frame_batch.shape}")
    print(f"Heatmaps shape: {heatmap_batch.shape}")
    break

In [None]:
class SpatialChannelAttention(nn.Module):
    """Channel gating followed by spatial gating of a feature map.

    The channel gate pools each channel to a scalar, passes it through a
    1x1-conv bottleneck (reduction factor 8) and a sigmoid. The spatial gate
    runs a 7x7 conv + sigmoid over the per-pixel channel mean and channel max.
    The output has the same shape as the input.
    """

    def __init__(self, in_channels):
        super().__init__()
        reduced = in_channels // 8
        squeeze = nn.Conv2d(in_channels, reduced, 1)
        excite = nn.Conv2d(reduced, in_channels, 1)
        self.channel_fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(1), squeeze, nn.ReLU(), excite, nn.Sigmoid()
        )
        self.spatial_conv = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, padding=3), nn.Sigmoid()
        )

    def forward(self, x):
        # Channel attention: one gate value per channel.
        gated = x * self.channel_fc(x)

        # Spatial attention: one gate value per pixel, from mean/max over channels.
        channel_mean = torch.mean(gated, dim=1, keepdim=True)
        channel_max = torch.max(gated, dim=1, keepdim=True)[0]
        descriptors = torch.cat([channel_mean, channel_max], dim=1)
        return gated * self.spatial_conv(descriptors)


class TeacherNet(nn.Module):
    """ResNet-18 encoder with an attention-gated, skip-connected decoder.

    Args:
        pretrained: When True, initialise the ResNet-18 backbone with the
            ImageNet weights shipped by torchvision.

    ``forward`` returns ``(heatmap, att_map)``: the sigmoid-activated
    single-channel map produced by ``final_up`` and the output of ``att2``
    (192 channels) taken before the last two up-sampling stages.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # `pretrained=` is deprecated in torchvision in favour of `weights=`;
        # use the new API when available and fall back on older versions.
        if hasattr(models, "ResNet18_Weights"):
            weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
            base = models.resnet18(weights=weights)
        else:
            base = models.resnet18(pretrained=pretrained)

        self.layer1 = nn.Sequential(*list(base.children())[:5])   # conv1 to layer1
        self.layer2 = base.layer2
        self.layer3 = base.layer3
        self.layer4 = base.layer4

        self.att4 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            SpatialChannelAttention(256)
        )
        self.up3 = nn.ConvTranspose2d(256, 128, 4, 2, 1)
        self.att3 = SpatialChannelAttention(128 + 256)   # up3 output + layer3 skip
        self.up2 = nn.ConvTranspose2d(384, 64, 4, 2, 1)
        self.att2 = SpatialChannelAttention(64 + 128)    # up2 output + layer2 skip
        self.up1 = nn.ConvTranspose2d(192, 32, 4, 2, 1)
        self.final_up = nn.Sequential(
            nn.ConvTranspose2d(32, 16, 4, 4),   # kernel 4, stride 4: x4 up-sampling
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Encoder
        f1 = self.layer1(x)
        f2 = self.layer2(f1)
        f3 = self.layer3(f2)
        f4 = self.layer4(f3)

        # Decoder with skip connections from layer3 and layer2.
        x = self.att4(f4)
        x = self.up3(x)
        x = torch.cat([x, f3], dim=1)
        x = self.att3(x)
        x = self.up2(x)
        x = torch.cat([x, f2], dim=1)
        att_map = self.att2(x)
        x = self.up1(att_map)

        return self.final_up(x), att_map


# class StudentNet(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.layer1 = nn.Sequential(
#             nn.Conv2d(3, 64, 3, padding=1),
#             nn.ReLU(),
#             SpatialChannelAttention(64)
#         )
#         self.layer2 = nn.Sequential(
#             nn.Conv2d(64, 32, 3, padding=1),
#             nn.ReLU(),
#             SpatialChannelAttention(32)
#         )
#         self.layer3 = nn.Sequential(
#             nn.Conv2d(32, 16, 3, padding=1),
#             nn.ReLU(),
#             SpatialChannelAttention(16)
#         )
#         self.layer4 = nn.Sequential(
#             nn.Conv2d(16, 8, 3, padding=1),
#             nn.ReLU(),
#             SpatialChannelAttention(8)
#         )
#         self.heatmap = nn.Sequential(
#             nn.Conv2d(8, 1, 1),
#             nn.Sigmoid()
#         )

#     def forward(self, x):
#         x = self.layer1(x)
#         x = self.layer2(x)
#         x = self.layer3(x)
#         att_map = self.layer4(x)            
#         heatmap = self.heatmap(att_map)

#         return heatmap, att_map          


class StudentNet(nn.Module):
    """Small encoder / bottleneck / decoder network with attention per stage.

    Two stride-2 convolutions down-sample by 4, the bottleneck keeps that
    resolution, and two stride-2 transposed convolutions up-sample back.
    ``forward`` returns ``(heatmap, att_map)`` where ``heatmap`` is the
    sigmoid-activated single-channel output and ``att_map`` is the
    256-channel bottleneck feature map.
    """

    @staticmethod
    def _stage(conv, channels):
        """Wrap ``conv`` as conv -> ReLU -> SpatialChannelAttention(channels)."""
        return nn.Sequential(conv, nn.ReLU(), SpatialChannelAttention(channels))

    def __init__(self):
        super().__init__()

        # Encoder: each stage halves the spatial resolution.
        self.enc1 = self._stage(
            nn.Conv2d(3, 64, kernel_size=3, padding=1, stride=2), 64)   # 128x128
        self.enc2 = self._stage(
            nn.Conv2d(64, 128, kernel_size=3, padding=1, stride=2), 128)  # 64x64

        # Bottleneck: widens channels, resolution unchanged.
        self.bottleneck = self._stage(
            nn.Conv2d(128, 256, kernel_size=3, padding=1), 256)

        # Decoder: each stage doubles the spatial resolution.
        self.dec1 = self._stage(
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2), 128)
        self.dec2 = self._stage(
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2), 64)

        # Output head: 1x1 conv to one channel, squashed to (0, 1).
        self.heatmap = nn.Sequential(
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        att_map = self.bottleneck(self.enc2(self.enc1(x)))
        decoded = self.dec2(self.dec1(att_map))
        return self.heatmap(decoded), att_map
    
    


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Dummy batch used only for shape checks and FLOP counting.
x = torch.randn(2, 3, 256, 256).to(device)

teacher = TeacherNet(pretrained=False).to(device)
student = StudentNet().to(device)

with torch.no_grad():
    t_map, t_att = teacher(x)
    s_map, s_att = student(x)

for _label, _tensor in (
    ("Teacher Heatmap:", t_map),
    ("Teacher Attention Map:", t_att),
    ("Student Heatmap:", s_map),
    ("Student Attention Map:", s_att),
):
    print(_label, _tensor.shape)


# Computational cost of each network on the dummy batch.
flops_teacher = FlopCountAnalysis(teacher, x).total()
flops_student = FlopCountAnalysis(student, x).total()

params_teacher = parameter_count(teacher)[""]
params_student = parameter_count(student)[""]

flops_total = flops_teacher + flops_student
params_total = params_teacher + params_student


def to_m(x):
    """Format a raw count as millions with two decimals."""
    return "{:.2f}".format(x / 1e6)


_cost_rows = [
    ("TeacherNet", flops_teacher, params_teacher),
    ("StudentNet", flops_student, params_student),
    (r"\textbf{MemGaze (Total)}", flops_total, params_total),
]

df = pd.DataFrame({
    "Model"      : [row[0] for row in _cost_rows],
    "FLOPs (M)"  : [to_m(row[1]) for row in _cost_rows],
    "Params (M)" : [to_m(row[2]) for row in _cost_rows],
})

print(df.to_markdown(index=False, tablefmt="github"))

latex = df.to_latex(index=False, escape=False, column_format="lcc",
                    caption="Computational cost of the proposed pipeline.",
                    label="tab:flops_params")
print(latex)

In [None]:
def loss_teacher(pred_map, gt_map, kind="smooth_l1"):
    """Regression loss between a predicted and a ground-truth map.

    Args:
        pred_map: Predicted tensor.
        gt_map: Target tensor of the same shape.
        kind: One of ``"l1"``, ``"smooth_l1"`` (default) or ``"mse"``.

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    if kind == "smooth_l1":
        return F.smooth_l1_loss(pred_map, gt_map)
    if kind == "l1":
        return F.l1_loss(pred_map, gt_map)
    if kind == "mse":
        return F.mse_loss(pred_map, gt_map)
    raise ValueError("Unsupported loss type: choose from ['l1', 'smooth_l1', 'mse']")

In [None]:
def kd_loss_ssim_kl(student_map, teacher_map, alpha=0.5, T=0.2):
    """Distillation loss: weighted sum of an SSIM term and a KL term.

    Args:
        student_map: Student prediction, batch-first.
        teacher_map: Teacher prediction of the same shape.
        alpha: Weight of the KL term; the SSIM term gets ``1 - alpha``.
        T: Softmax temperature applied to the flattened maps.

    Returns:
        Scalar tensor ``(1 - alpha) * (1 - SSIM) + alpha * KL * T**2``.
    """
    batch = student_map.size(0)

    # SSIM needs the dynamic range of the data; guard against a constant map.
    span = (teacher_map.max() - teacher_map.min()).item()
    structural = 1 - ssim(student_map, teacher_map,
                          data_range=max(span, 1e-8), size_average=True)

    # Treat each flattened map as a distribution over pixels.
    student_logp = F.log_softmax(student_map.view(batch, -1) / T, dim=1)
    teacher_p = F.softmax(teacher_map.view(batch, -1) / T, dim=1)
    divergence = F.kl_div(student_logp, teacher_p, reduction="batchmean") * (T ** 2)

    return (1 - alpha) * structural + alpha * divergence

In [None]:
def train_teacher_epoch(teacher, loader, optimizer, device):
    """Run one training epoch of the teacher with an MSE objective.

    Args:
        teacher: Model whose forward pass returns ``(heatmap, att_map)``.
        loader: Iterable of ``(frames, heatmaps)`` batches.
        optimizer: Optimizer over the teacher's parameters.
        device: Device the batches are moved to.

    Returns:
        Mean per-batch loss over the epoch.
    """
    teacher.train()
    running = 0.0

    for batch_frames, batch_maps in tqdm(loader, desc="[Teacher Training]"):
        batch_frames = batch_frames.to(device).float()
        batch_maps = batch_maps.to(device).float()

        prediction, _unused_att = teacher(batch_frames)
        batch_loss = F.mse_loss(prediction, batch_maps)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(loader)

@torch.no_grad()
def eval_teacher(teacher, loader, device):
    """Return the teacher's mean per-batch MSE over ``loader`` (no gradients).

    The teacher's forward pass is expected to return ``(heatmap, att_map)``;
    only the heatmap is scored against the ground-truth maps.
    """
    teacher.eval()
    batch_losses = []

    for batch_frames, batch_maps in loader:
        batch_frames = batch_frames.to(device).float()
        batch_maps = batch_maps.to(device).float()

        prediction, _unused_att = teacher(batch_frames)
        batch_losses.append(F.mse_loss(prediction, batch_maps).item())

    return sum(batch_losses, 0.0) / len(loader)

In [None]:
def train_student_epoch(student, teacher, loader, optimizer, device, alpha=0.5, T=0.75):
    """Run one knowledge-distillation epoch for the student.

    The teacher is kept in eval mode and queried without gradients; the
    student is optimised to match the teacher's heatmap under
    ``kd_loss_ssim_kl``. Ground-truth maps from the loader are ignored.

    Args:
        student: Model being trained; forward returns ``(heatmap, att_map)``.
        teacher: Reference model; forward returns ``(heatmap, att_map)``.
        loader: Iterable of ``(frames, _)`` batches.
        optimizer: Optimizer over the student's parameters.
        device: Device the frames are moved to.
        alpha: KL weight forwarded to ``kd_loss_ssim_kl``.
        T: Temperature forwarded to ``kd_loss_ssim_kl``.

    Returns:
        Mean per-batch distillation loss over the epoch.
    """
    student.train()
    teacher.eval()
    running = 0.0

    for batch_frames, _unused_gt in tqdm(loader, desc="[Student KD Training]"):
        batch_frames = batch_frames.to(device).float()

        with torch.no_grad():
            teacher_map, _ = teacher(batch_frames)

        student_map, _ = student(batch_frames)
        batch_loss = kd_loss_ssim_kl(student_map, teacher_map, alpha=alpha, T=T)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(loader)

@torch.no_grad()
def eval_student(student, teacher, loader, device, alpha=0.4, T=1.5):
    """Mean per-batch MSE between student and teacher heatmaps (no gradients).

    ``alpha`` and ``T`` are accepted for signature compatibility but are not
    used: the score is a plain MSE, not ``kd_loss_ssim_kl``.
    """
    student.eval()
    teacher.eval()
    batch_losses = []

    for batch_frames, _unused_gt in loader:
        batch_frames = batch_frames.to(device).float()

        teacher_map, _ = teacher(batch_frames)
        student_map, _ = student(batch_frames)

        batch_losses.append(F.mse_loss(student_map, teacher_map).item())

    return sum(batch_losses, 0.0) / len(loader)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

teacher = TeacherNet(pretrained=True).to(device)
student = StudentNet().to(device)

opt_t = torch.optim.Adam(teacher.parameters(), lr=1e-4)
opt_s = torch.optim.Adam(student.parameters(), lr=1e-4)
# NOTE(review): alpha / T are defined here but not passed to
# train_student_epoch, which therefore runs with its own defaults
# (alpha=0.5, T=0.75). Confirm which temperature is intended.
alpha = 0.5
T = 1.5

train_losses, val_losses = [], []


# ---- Stage 1: supervised teacher training on ground-truth gaze maps --------
EPOCHS_T = 100
for ep in range(1, EPOCHS_T + 1):
    tr_loss = train_teacher_epoch(teacher, train_loader, opt_t, device)
    val_loss = eval_teacher(teacher, val_loader, device)
    print(f"[Teacher {ep:02d}/{EPOCHS_T}] Train={tr_loss:.4f} | Val={val_loss:.4f}")


# ---- Stage 2: distil the teacher into the student ---------------------------
# The student trains on the training split and is scored on the held-out
# split (the original had the two loaders swapped). The training figure is the
# SSIM+KL distillation loss, while eval_student reports plain MSE against the
# teacher, so the two curves are on different scales.
EPOCHS_S = 100
for ep in range(1, EPOCHS_S + 1):
    tr_loss_s = train_student_epoch(student, teacher, train_loader, opt_s, device)
    val_loss_s = eval_student(student, teacher, val_loader, device)
    train_losses.append(tr_loss_s)
    val_losses.append(val_loss_s)
    # Log the student's own losses (the original printed the teacher's
    # tr_loss / val_loss left over from stage 1).
    print(f"[Student {ep:02d}/{EPOCHS_S}] KD-Train={tr_loss_s:.4f} | KD-Val={val_loss_s:.4f}")

torch.save(student.state_dict(), "student_kd.pth")

In [None]:
# Student distillation curves: per-epoch training and validation losses.
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(train_losses, label='Train Loss', linestyle='dashdot')
ax.plot(val_losses, label='Val Loss', linestyle='dotted')
ax.set_xlabel('Epochs', weight="semibold")
ax.set_ylabel('Loss', weight="semibold")
ax.set_title("Student Knowledge Distillation Training Curve", weight="semibold")
ax.legend(prop=dict(weight='semibold'))
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
def imitation_loss(student_map, gaze_map, kind="smooth_l1"):
    """Loss between the student's map and the ground-truth gaze map.

    Args:
        student_map: Student prediction, batch-first.
        gaze_map: Ground-truth map of the same shape.
        kind: ``"l1"``, ``"mse"`` or ``"cosine"``; any other value (including
            the default ``"smooth_l1"``) selects the smooth-L1 loss.

    Returns:
        Scalar loss tensor. For ``"cosine"`` it is the batch mean of
        ``1 - cosine_similarity`` between the flattened maps.
    """
    if kind == "cosine":
        flat_student = student_map.view(student_map.size(0), -1)
        flat_gaze = gaze_map.view(gaze_map.size(0), -1)
        similarity = F.cosine_similarity(flat_student, flat_gaze, dim=1)
        return torch.mean(1 - similarity)  # Loss = 1 - cosine similarity

    if kind == "l1":
        criterion = F.l1_loss
    elif kind == "mse":
        criterion = F.mse_loss
    else:
        criterion = F.smooth_l1_loss
    return criterion(student_map, gaze_map)

In [None]:
def train_student_il_epoch(student, teacher, loader, optimizer, device, alpha_kd=0.7, beta_il=0.7, T=1.5):
    """Run one epoch of combined distillation + imitation training.

    The objective per batch is
    ``kd_loss_ssim_kl(student, teacher) + beta_il * imitation_loss(student, gaze)``
    with the imitation term using the cosine variant.

    Args:
        student: Model being trained; forward returns ``(heatmap, att_map)``.
        teacher: Reference model, queried in eval mode without gradients.
        loader: Iterable of ``(frames, gaze_maps)`` batches.
        optimizer: Optimizer over the student's parameters.
        device: Device the batches are moved to.
        alpha_kd: ``alpha`` forwarded to ``kd_loss_ssim_kl``.
        beta_il: Weight of the imitation term.
        T: Temperature forwarded to ``kd_loss_ssim_kl``.

    Returns:
        Mean per-batch combined loss over the epoch.
    """
    student.train()
    teacher.eval()
    running = 0.0

    for batch_frames, batch_gaze in tqdm(loader, desc="[Student KD+IL Training]"):
        batch_frames = batch_frames.to(device).float()
        batch_gaze = batch_gaze.to(device).float()

        with torch.no_grad():
            teacher_map, _ = teacher(batch_frames)

        student_map, _ = student(batch_frames)

        distill_term = kd_loss_ssim_kl(student_map, teacher_map, alpha=alpha_kd, T=T)
        imitate_term = imitation_loss(student_map, batch_gaze, kind="cosine")
        batch_loss = distill_term + beta_il * imitate_term

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(loader)

In [None]:
# Restart from the KD-only checkpoint. map_location keeps the load working
# when the checkpoint was written on a GPU and this session runs on a
# different device (e.g. CPU only).
student.load_state_dict(torch.load("student_kd.pth", map_location=device))

# Frozen snapshot of the KD-only student, kept for the later comparisons.
student_kd = copy.deepcopy(student)

EPOCHS_IL = 100
il_losses = []

# NOTE(review): opt_s is reused from the KD stage, so its Adam state carries
# over into this stage — confirm that is intended.
for ep in range(1, EPOCHS_IL + 1):
    il_loss = train_student_il_epoch(student, teacher, train_loader, opt_s, device)
    il_losses.append(il_loss)
    print(f"[Student (IL) {ep:02d}/{EPOCHS_IL}] Loss={il_loss:.4f}")

torch.save(student.state_dict(), "student_kd_il.pth")
# Snapshot of the student after KD + imitation fine-tuning.
student_kd_il = copy.deepcopy(student)

In [None]:
# Training-loss curve of the KD stage only; the KD+IL curve (il_losses) is
# intentionally left out of this figure.
fig, ax = plt.subplots(figsize=(7, 5))
ax.plot(train_losses, label='Student KD Train Loss', linestyle='solid')
ax.set_xlabel('Epochs', weight="semibold")
ax.set_ylabel("$T_{Loss}$")
ax.legend(prop=dict(weight='semibold'))
ax.grid(True)
fig.tight_layout()
plt.show()

In [None]:
def normalize_heatmap(hm):
    """Min-max scale ``hm`` to [0, 1]; a constant map becomes all zeros."""
    shifted = hm - hm.min()
    peak = shifted.max() + 1e-8   # epsilon avoids division by zero
    return shifted / peak

def denormalize_img(t):
    """Undo ImageNet normalisation on a (3, H, W) tensor.

    Returns an (H, W, 3) NumPy array clipped to [0, 1].
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    hwc = t.permute(1, 2, 0).cpu().numpy()  # (H,W,3)
    return np.clip(hwc * channel_std + channel_mean, 0, 1)

def plot_saliency_overlay_comparison(frames, gaze_maps, kd_maps, kd_il_maps, batch_idx=0, save_dir="compare_outputs", prefix="multi"):
    """Show up to six samples as rows of Input / GT / KD / KD+IL overlays.

    Args:
        frames: Batch of images, de-normalised with ``denormalize_img``.
        gaze_maps: Ground-truth maps; channel 0 of each sample is shown.
        kd_maps: Maps from the KD-only student; channel 0 is shown.
        kd_il_maps: Maps from the KD+IL student; channel 0 is shown.
        batch_idx: Only used by the (currently disabled) save path.
        save_dir: Output directory; it is created even though saving is
            currently commented out.
        prefix: Only used by the (currently disabled) save path.

    Every map is min-max normalised before being drawn over the image.
    """
    os.makedirs(save_dir, exist_ok=True)
    N = min(frames.size(0), 6)
    if N == 0:
        # plt.subplots rejects nrows=0; nothing to draw for an empty batch.
        return

    # squeeze=False keeps `axs` two-dimensional even when N == 1, so the
    # axs[i, j] indexing below also works for a single-sample batch.
    fig, axs = plt.subplots(nrows=N, ncols=4, figsize=(20, 4.5 * N), dpi=300, squeeze=False)

    # (title, source batch, colormap, overlay alpha) for columns 1..3.
    overlays = (
        ("GT Gaze", gaze_maps, 'inferno', 0.5),
        ("Student (KD)", kd_maps, 'jet', 0.4),
        ("Student (KD+IL)", kd_il_maps, 'jet', 0.4),
    )

    for i in range(N):
        img = denormalize_img(frames[i])

        axs[i, 0].imshow(img)
        axs[i, 0].set_title("Input", fontsize=25, weight = 'semibold')
        axs[i, 0].axis('off')

        for col, (title, maps, cmap, alpha) in enumerate(overlays, start=1):
            heat = normalize_heatmap(maps[i, 0].cpu().numpy())
            axs[i, col].imshow(img, alpha=1.0)
            axs[i, col].imshow(heat, cmap=cmap, alpha=alpha)
            axs[i, col].set_title(title, fontsize=25, weight = 'semibold')
            axs[i, col].axis('off')

    plt.tight_layout()
    #out_path = os.path.join(save_dir, f"{prefix}_batch{batch_idx}_overlay.png")
    #plt.savefig(out_path, dpi=300, bbox_inches='tight')
    #plt.close()
    plt.show()


In [None]:
student_kd.eval()
student_kd_il.eval()
teacher.eval()

MAX_BATCHES = 20

# The batch is bound to `frame_batch` so the module-level `frames` array from
# load_data is not overwritten.
for b_idx, (frame_batch, gaze_maps) in enumerate(train_loader):
    if b_idx >= MAX_BATCHES:
        break

    frame_batch = frame_batch.to(device)
    gaze_maps = gaze_maps.to(device)

    with torch.no_grad():
        # StudentNet.forward returns (heatmap, att_map). The saliency overlay
        # needs the full-resolution heatmap; the original took the second
        # output, i.e. channel 0 of the 64x64 bottleneck features.
        s_kd_maps, _    = student_kd(frame_batch)
        s_kd_il_maps, _ = student_kd_il(frame_batch)

    plot_saliency_overlay_comparison(frame_batch, gaze_maps, s_kd_maps, s_kd_il_maps,
                                  batch_idx=b_idx,
                                  save_dir="compare_outputs",
                                  prefix="train")

In [None]:
def normalize_heatmap(hm):
    """Min-max scale ``hm`` to [0, 1]; a constant map becomes all zeros."""
    offset = hm - hm.min()
    return offset / (offset.max() + 1e-8)

def plot_saliency_comparison(frames, gaze_maps, student_kd_maps, student_kd_il_maps, save_dir="compare_outputs", prefix="comp"):
    """Show one 1x4 figure per sample: Input, GT gaze, KD map, KD+IL map.

    Channel 0 of every map is min-max normalised and drawn with the default
    colormap on top of the (min-max rescaled) input image. ``save_dir`` is
    created, but saving is currently commented out; ``prefix`` is only used by
    that disabled save path.
    """
    os.makedirs(save_dir, exist_ok=True)

    n_samples = frames.size(0)
    for i in range(n_samples):
        img = frames[i].permute(1, 2, 0).cpu().numpy()
        lo, hi = img.min(), img.max()
        img = (img - lo) / (hi - lo + 1e-8)

        # (title, alpha of the underlying image, map drawn on top)
        panels = (
            ("Ground Truth (Gaze)", 0.5, normalize_heatmap(gaze_maps[i, 0].cpu().numpy())),
            ("Student (KD)", 0., normalize_heatmap(student_kd_maps[i, 0].cpu().numpy())),
            ("Student (KD+IL)", 0.8, normalize_heatmap(student_kd_il_maps[i, 0].cpu().numpy())),
        )

        fig, axs = plt.subplots(1, 4, figsize=(14, 3))
        axs[0].imshow(img)
        axs[0].set_title("Input")
        axs[0].axis('off')

        for ax, (title, img_alpha, heat) in zip(axs[1:], panels):
            ax.imshow(img, alpha=img_alpha)
            ax.imshow(heat)
            ax.set_title(title)
            ax.axis('off')

        plt.tight_layout()
        #plt.savefig(os.path.join(save_dir, f"{prefix}_{i}.png"), dpi=150, bbox_inches='tight')
        #plt.close()
        plt.show()

In [None]:
student_kd.eval()
student_kd_il.eval()

# First validation batch; bound to `frame_batch` so the module-level `frames`
# array from load_data is not overwritten.
frame_batch, gaze_maps = next(iter(val_loader))
frame_batch = frame_batch.to(device)
gaze_maps = gaze_maps.to(device)

with torch.no_grad():
    # StudentNet.forward returns (heatmap, att_map); the comparison needs the
    # predicted heatmap, not the bottleneck feature map the original unpacked.
    s_kd_maps, _ = student_kd(frame_batch)
    s_kd_il_maps, _ = student_kd_il(frame_batch)

plot_saliency_comparison(frame_batch, gaze_maps, s_kd_maps, s_kd_il_maps, prefix="val_batch")

In [None]:
def normalize_heatmap(hm):
    """Min-max scale ``hm`` to [0, 1]; a constant map becomes all zeros."""
    zero_based = hm - hm.min()
    scale = zero_based.max() + 1e-8
    return zero_based / scale

def plot_multiple_saliency_comparisons(frames, gaze_maps, kd_maps, kd_il_maps, N=6, save_dir="compare_outputs", prefix="multi"):
    """Show up to ``N`` samples in one figure (rows) with four panels each.

    Args:
        frames: Batch of images, batch-first, channels-first.
        gaze_maps: Ground-truth maps; channel 0 is shown.
        kd_maps: KD-only student maps; channel 0 is shown.
        kd_il_maps: KD+IL student maps; channel 0 is shown.
        N: Maximum number of rows; clamped to the batch size.
        save_dir: Output directory; created even though saving is disabled.
        prefix: Only used by the (currently disabled) save path.
    """
    os.makedirs(save_dir, exist_ok=True)
    N = min(N, frames.size(0))
    if N <= 0:
        # plt.subplots rejects nrows=0; nothing to draw.
        return

    # squeeze=False always yields a 2-D axes array, replacing the manual
    # reshape logic that was needed for the N == 1 case.
    fig, axs = plt.subplots(nrows=N, ncols=4, figsize=(14, 3 * N), squeeze=False)

    overlays = (
        ("GT Gaze", gaze_maps),
        ("Student (KD)", kd_maps),
        ("Student (KD+IL)", kd_il_maps),
    )

    for i in range(N):
        img = frames[i].permute(1, 2, 0).cpu().numpy()
        # Proper min-max scaling: divide by the range, not by the raw maximum
        # (the original produced values outside [0, 1] for normalised inputs).
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        axs[i, 0].imshow(img)
        axs[i, 0].set_title("Input")
        axs[i, 0].axis('off')

        for col, (title, maps) in enumerate(overlays, start=1):
            heat = normalize_heatmap(maps[i, 0].cpu().numpy())
            axs[i, col].imshow(img, alpha=0.8)
            axs[i, col].imshow(heat, alpha=0.5)
            axs[i, col].set_title(title)
            axs[i, col].axis('off')

    plt.tight_layout()
    #out_path = os.path.join(save_dir, f"{prefix}_samples.png")
    #plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
student_kd.eval()
student_kd_il.eval()
teacher.eval()

# One training batch; bound to `frame_batch` so the module-level `frames`
# array from load_data is not overwritten.
frame_batch, gaze_maps = next(iter(train_loader))

frame_batch = frame_batch.to(device)
gaze_maps = gaze_maps.to(device)

with torch.no_grad():
    # StudentNet.forward returns (heatmap, att_map); plot the heatmap rather
    # than the bottleneck feature map the original unpacked.
    s_kd_maps, _ = student_kd(frame_batch)
    s_kd_il_maps, _ = student_kd_il(frame_batch)

# Set N <= batch size
plot_multiple_saliency_comparisons(frame_batch, gaze_maps, s_kd_maps, s_kd_il_maps, N=5)

In [None]:
student_kd.eval()
student_kd_il.eval()
teacher.eval()

MAX_BATCHES = 20

# The batch is bound to `frame_batch` so the module-level `frames` array from
# load_data is not overwritten.
for b_idx, (frame_batch, gaze_maps) in enumerate(val_loader):
    if b_idx >= MAX_BATCHES:
        break

    frame_batch = frame_batch.to(device)
    gaze_maps = gaze_maps.to(device)

    with torch.no_grad():
        # StudentNet.forward returns (heatmap, att_map); the overlay needs the
        # full-resolution heatmap, not the bottleneck feature map.
        s_kd_maps, _    = student_kd(frame_batch)
        s_kd_il_maps, _ = student_kd_il(frame_batch)

    # These are validation batches, so tag them "val" (the original reused "train").
    plot_saliency_overlay_comparison(frame_batch, gaze_maps, s_kd_maps, s_kd_il_maps,
                                  batch_idx=b_idx,
                                  save_dir="compare_outputs",
                                  prefix="val")

# Rough code (scratch experiments)

In [None]:
def loss_teacher(pred_map, gt_map, kind="smooth_l1"):
    """Regression loss between predicted and ground-truth maps.

    ``kind`` may be ``"l1"`` or ``"smooth_l1"`` (default); any other value
    selects the MSE loss.
    """
    if kind == "smooth_l1":
        return F.smooth_l1_loss(pred_map, gt_map)
    if kind == "l1":
        return F.l1_loss(pred_map, gt_map)
    return F.mse_loss(pred_map, gt_map)


def kd_loss_ssim_kl(student_map, teacher_map, alpha=0.4, T=1.5):
    """Distillation loss: ``(1 - alpha) * (1 - SSIM) + alpha * KL * T**2``.

    Args:
        student_map: Student prediction, batch-first.
        teacher_map: Teacher prediction of the same shape.
        alpha: Weight of the KL term.
        T: Softmax temperature applied to the flattened maps.
    """
    batch = student_map.size(0)

    # Dynamic range for SSIM, floored so a constant teacher map stays valid.
    span = (teacher_map.max() - teacher_map.min()).item()
    structural = 1 - ssim(student_map, teacher_map,
                          data_range=max(span, 1e-8), size_average=True)

    student_logp = F.log_softmax(student_map.view(batch, -1) / T, dim=1)
    teacher_p = F.softmax(teacher_map.view(batch, -1) / T, dim=1)
    divergence = F.kl_div(student_logp, teacher_p, reduction="batchmean") * (T ** 2)

    return (1 - alpha) * structural + alpha * divergence

In [None]:
def train_teacher_epoch(teacher, loader, opt, device):
    """Run one teacher training epoch with a smooth-L1 objective.

    Args:
        teacher: Model returning either a heatmap tensor or a tuple whose
            first element is the heatmap (as ``TeacherNet`` does).
        loader: Iterable of ``(frames, heatmaps)`` batches.
        opt: Optimizer over the teacher's parameters.
        device: Device the batches are moved to.

    Returns:
        Mean per-batch loss over the epoch.
    """
    teacher.train()
    total_loss = 0.0

    for frames, heatmaps in tqdm(loader, desc="Teacher-Train"):
        frames = frames.to(device).float()
        heatmaps = heatmaps.to(device).float()

        # TeacherNet.forward returns (heatmap, att_map); passing that tuple
        # straight to the loss raised an error, so take the heatmap.
        output = teacher(frames)
        pred_heatmap = output[0] if isinstance(output, (tuple, list)) else output

        loss = F.smooth_l1_loss(pred_heatmap, heatmaps)

        opt.zero_grad()
        loss.backward()
        opt.step()

        total_loss += loss.item()

    return total_loss / len(loader)


@torch.no_grad()
def eval_teacher(teacher, loader, device):
    """Return the teacher's mean per-batch MSE over ``loader`` (no gradients).

    The teacher may return a heatmap tensor or a tuple whose first element is
    the heatmap (as ``TeacherNet`` does).
    """
    teacher.eval()
    total_loss = 0.0

    for frames, heatmaps in loader:
        frames = frames.to(device).float()
        heatmaps = heatmaps.to(device).float()

        # Accept the (heatmap, att_map) tuple returned by TeacherNet.
        output = teacher(frames)
        pred_heatmap = output[0] if isinstance(output, (tuple, list)) else output

        loss = F.mse_loss(pred_heatmap, heatmaps)
        total_loss += loss.item()

    return total_loss / len(loader)

In [None]:
# Scratch cell: trains a teacher from fresh models.
# NOTE(review): this rebinds `teacher`, `student`, `opt_t` and `opt_s`, so any
# models trained in earlier cells are no longer reachable under those names.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

teacher = TeacherNet(pretrained=True).to(device)
student = StudentNet().to(device)

opt_t = torch.optim.Adam(teacher.parameters(), lr=1e-4)
opt_s = torch.optim.Adam(student.parameters(), lr=1e-4)

EPOCHS_T = 50
# Per-epoch teacher training / validation losses.
t_train, t_val = [], []

for ep in range(1, EPOCHS_T + 1):
    tr_loss = train_teacher_epoch(teacher, train_loader, opt_t, device)
    val_loss = eval_teacher(teacher, val_loader, device)

    t_train.append(tr_loss)
    t_val.append(val_loss)

    print(f"[Teacher {ep:02d}/{EPOCHS_T}] TrainLoss={tr_loss:.4f} | ValLoss={val_loss:.4f}")

In [None]:
EPOCHS = 3
log_every = 20
opt_teacher = torch.optim.Adam(teacher.parameters(), lr=4e-5)
opt_student = torch.optim.Adam(student.parameters(), lr=4e-5)

# Joint scheme: on every batch the teacher takes one supervised step on the
# ground-truth gaze maps, then the student takes one step towards the teacher's
# (detached) prediction plus an imitation term on the ground truth.
for ep in range(1, EPOCHS + 1):
    teacher.train(); student.train()
    tot_T, tot_S, tot_KD, tot_IM = 0.0, 0.0, 0.0, 0.0

    # The batch is bound to `frame_batch` so the module-level `frames` array
    # from load_data is not overwritten.
    for b_idx, (frame_batch, gazes) in enumerate(train_loader, 1):
        frame_batch, gazes = frame_batch.to(DEVICE), gazes.to(DEVICE)

        # ---- teacher step ----
        # Both networks return (heatmap, att_map); only the heatmap is used.
        pred_teacher, _ = teacher(frame_batch)             # (B,1,256,256)
        t_loss = loss_teacher(pred_teacher, gazes)

        opt_teacher.zero_grad()
        t_loss.backward()
        opt_teacher.step()

        # Detach so the student step cannot back-propagate into the teacher.
        teacher_maps = pred_teacher.detach()

        # ---- student step ----
        pred_student, _ = student(frame_batch)
        # kd_loss_ssim_kl returns a single scalar and takes only the two maps
        # (the original passed `gazes` in the alpha slot and unpacked three
        # values). The imitation term is therefore computed separately.
        kd_term = kd_loss_ssim_kl(pred_student, teacher_maps)
        im_term = imitation_loss(pred_student, gazes)
        # NOTE(review): weighting the imitation term by cfg.LAMBDA is an
        # assumption — confirm the intended combination.
        s_loss = kd_term + cfg.LAMBDA * im_term
        kd_l, im_l = kd_term.item(), im_term.item()

        opt_student.zero_grad()
        s_loss.backward()
        opt_student.step()

        tot_T  += t_loss.item()
        tot_S  += s_loss.item()
        tot_KD += kd_l
        tot_IM += im_l

        if b_idx % log_every == 0:
            print(f"  [ep{ep} | step {b_idx}] "
                  f"T {t_loss.item():.4f} | S {s_loss.item():.4f} "
                  f"(KD {kd_l:.4f}, IM {im_l:.4f})")

    N = len(train_loader)
    print(f"Epoch {ep}/{EPOCHS}  "
          f"T {tot_T/N:.4f} | S {tot_S/N:.4f}  "
          f"KD {tot_KD/N:.4f} | IM {tot_IM/N:.4f}")

In [None]:
def show_heatmaps(X,T,S,G,count=2):
    """Show up to ``count`` samples as Input / Teacher / Student / GT panels.

    ``X`` is a batch of channels-first images; ``T``, ``S`` and ``G`` are
    batches of maps whose channel 0 is drawn with the 'hot' colormap.
    """
    n_shown = min(count, X.size(0))
    map_panels = ((T, "Teacher"), (S, "Student"), (G, "GT Gaze"))

    for i in range(n_shown):
        fig, ax = plt.subplots(1, 4, figsize=(14, 3))
        for panel in ax:
            panel.axis('off')

        ax[0].imshow(X[i].permute(1, 2, 0).cpu())
        ax[0].set_title("Input")

        for panel, (maps, title) in zip(ax[1:], map_panels):
            panel.imshow(maps[i, 0].cpu(), cmap='hot')
            panel.set_title(title)

        plt.tight_layout()
        plt.show()

In [None]:
@torch.no_grad()
def evaluate(model, loader, vis=True):
    """Print mean Spearman correlation and MSE of ``model`` against GT maps.

    Args:
        model: Network returning a heatmap tensor, or a tuple whose first
            element is the heatmap (as ``StudentNet`` / ``TeacherNet`` do).
        loader: Iterable of ``(frames, gaze_maps)`` batches.
        vis: When True, visualise the first batch with ``show_heatmaps``.

    Uses the module-level ``teacher`` and ``DEVICE``.
    """
    def _heatmap(output):
        # Both networks return (heatmap, att_map); keep only the heatmap.
        return output[0] if isinstance(output, (tuple, list)) else output

    model.eval(); teacher.eval()
    rhos, mses = [], []
    for X, G in loader:
        X, G = X.to(DEVICE), G.to(DEVICE)
        S = _heatmap(model(X))
        for i in range(X.size(0)):
            s = S[i, 0].cpu().numpy(); g = G[i, 0].cpu().numpy()
            rhos.append(spearmanr(s.flatten(), g.flatten()).correlation)
            mses.append(np.mean((s - g) ** 2))
        if vis:
            show_heatmaps(X, _heatmap(teacher(X)), S, G, count=2)
            vis = False   # only visualise the first batch
    print(f"Spearman ρ : {np.mean(rhos):.4f}")
    print(f"MSE        : {np.mean(mses):.4f}")

# `test_loader` is never defined in this notebook; the held-out split is
# exposed as `val_loader`.
evaluate(student, val_loader)

In [None]:
def visualize_single_sample(student, teacher, loader, device):
    """Plot the first sample of the first batch with teacher/student overlays.

    Args:
        student: Network returning a heatmap tensor or a tuple whose first
            element is the heatmap.
        teacher: Same contract as ``student``.
        loader: Iterable of ``(frames, gazes)`` batches; only the first batch
            is consumed and the gaze maps are not used.
        device: Device the frames are moved to.
    """
    def _heatmap(output):
        # StudentNet / TeacherNet return (heatmap, att_map); calling .squeeze
        # on that tuple failed, so select the heatmap first.
        return output[0] if isinstance(output, (tuple, list)) else output

    student.eval()
    teacher.eval()

    frames, _gazes = next(iter(loader))
    frames = frames.to(device).float()

    with torch.no_grad():
        teacher_out = _heatmap(teacher(frames)).squeeze(1).cpu().numpy()
        student_out = _heatmap(student(frames)).squeeze(1).cpu().numpy()

    # Rescale the image to [0, 1] for display; the epsilon guards against a
    # constant image (division by zero).
    rgb = frames[0].permute(1, 2, 0).cpu().numpy()
    rgb = (rgb - rgb.min()) / (rgb.max() - rgb.min() + 1e-8)

    t_map = normalize_heatmap(teacher_out[0])
    s_map = normalize_heatmap(student_out[0])

    fig, axs = plt.subplots(1, 3, figsize=(14, 4))
    axs[0].imshow(rgb)
    axs[0].set_title("Input Image")

    axs[1].imshow(rgb, alpha=0.5)
    axs[1].imshow(t_map, cmap='jet', alpha=0.5)
    axs[1].set_title("Teacher Heatmap")

    axs[2].imshow(rgb, alpha=0.5)
    axs[2].imshow(s_map, cmap='jet', alpha=0.5)
    axs[2].set_title("Student Heatmap")

    for ax in axs:
        ax.axis('off')

    plt.suptitle("Saliency Transfer - Sanity Check", fontsize=14)
    plt.tight_layout()
    plt.show()
