In [1]:
# ===============================
# Full-image heatmap regression (baseline)
# No YOLO / no ROI cropping
# Predicts the fovea heatmap directly on the whole image
# ===============================

import os, cv2, torch, numpy as np, pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# -------------------------------

In [2]:
# Path configuration (adjust for Kaggle or a local checkout as needed)
# -------------------------------
ROOT = '/kaggle/input/eye-data'
TRAIN_IMG_DIR = f'{ROOT}/detection/train'
TEST_IMG_DIR  = f'{ROOT}/detection/test'
GT_CSV = f'{ROOT}/detection/fovea_localization_train_GT.csv'

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

In [6]:
# -------------------------------
# Gaussian Heatmap
# -------------------------------
def draw_gaussian(hm, center, sigma=2):
    """Stamp an unnormalised 2-D Gaussian onto ``hm`` in place.

    The Gaussian peaks at 1.0 on ``center`` and is merged with the existing
    contents by element-wise maximum, so earlier blobs are never attenuated.

    Args:
        hm: 2-D array (rows, cols) that is modified in place.
        center: ``(x, y)`` peak position, i.e. (column, row).
        sigma: standard deviation of the Gaussian, in heatmap cells.

    Returns:
        None. The result is written into ``hm``.
    """
    cx, cy = center
    n_rows, n_cols = hm.shape
    col_idx = np.arange(n_cols)
    # Column vector so that the sum below broadcasts to (n_rows, n_cols).
    row_idx = np.arange(n_rows)[:, None]
    blob = np.exp(-((col_idx - cx)**2 + (row_idx - cy)**2) / (2 * sigma**2))
    hm[:] = np.maximum(hm, blob)

# -------------------------------
# Dataset
# -------------------------------
class FullImageHeatmapDataset(Dataset):
    """Whole-image dataset for heatmap regression.

    In training mode each item is ``(image, heatmap)``; in test mode each item
    is ``(image, filename)``. Images are converted to RGB, resized to
    ``IMG_SIZE`` x ``IMG_SIZE`` and turned into a tensor by ``ToTensor``.

    Args:
        img_dir: directory that contains the image files.
        csv_path: ground-truth CSV, required when ``train`` is True. The
            ``data`` column gives the image index (file ``{index:04d}.jpg``);
            the second and third columns are read as the target x and y.
        train: selects training mode (labels from ``csv_path``) or test mode
            (every entry of ``img_dir``, sorted by name).

    Raises:
        ValueError: if ``train`` is True and ``csv_path`` is None.
    """

    # Side length of the square network input, in pixels.
    IMG_SIZE = 512
    # Side length of the square target heatmap, in cells.
    HM_SIZE = 64

    def __init__(self, img_dir, csv_path=None, train=True):
        self.img_dir = img_dir
        self.train = train
        self.t = transforms.ToTensor()

        if train:
            if csv_path is None:
                raise ValueError('csv_path is required when train=True')
            df = pd.read_csv(csv_path)
            self.samples = []
            for _, r in df.iterrows():
                idx = int(r['data'])
                fname = f'{idx:04d}.jpg'
                self.samples.append((fname, int(r.iloc[1]), int(r.iloc[2])))
        else:
            self.samples = sorted(os.listdir(img_dir))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        """Load item ``i``.

        Returns:
            ``(image_tensor, heatmap_tensor)`` in training mode, where the
            heatmap has shape ``(1, HM_SIZE, HM_SIZE)``;
            ``(image_tensor, filename)`` in test mode.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        if self.train:
            fname, x, y = self.samples[i]
        else:
            fname = self.samples[i]

        path = os.path.join(self.img_dir, fname)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image: {path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Keep the ORIGINAL size: the labels are scaled with it below, so it
        # must be captured before the image is resized.
        h0, w0 = img.shape[:2]

        img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
        img_t = self.t(img)

        # Test items carry no label, so return before touching x / y.
        if not self.train:
            return img_t, fname

        hm = np.zeros((self.HM_SIZE, self.HM_SIZE), np.float32)
        # A label of (0, 0) yields an all-zero target
        # (presumably "no annotation available" - confirm against the data).
        if not (x == 0 and y == 0):
            last = self.HM_SIZE - 1
            # Map original-pixel coordinates to heatmap cells and clamp so a
            # label on the image border still lands inside the map.
            hx = min(max(int(x / w0 * self.HM_SIZE), 0), last)
            hy = min(max(int(y / h0 * self.HM_SIZE), 0), last)
            draw_gaussian(hm, (hx, hy))

        return img_t, torch.from_numpy(hm).unsqueeze(0)

In [None]:
# -------------------------------
# Heatmap Network
# -------------------------------
class HeatmapNet(nn.Module):
    """Small fully-convolutional network that regresses a 1-channel heatmap.

    Four 3x3 conv + ReLU stages (the first three followed by 2x2 max-pooling,
    i.e. 8x spatial downsampling) and a final 1x1 conv to a single channel.
    The output is always resized to 64x64 with bilinear interpolation.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, followed_by_pool) for each conv stage.
        stages = [
            (3, 32, True),
            (32, 64, True),
            (64, 128, True),
            (128, 256, False),
        ]
        layers = []
        for c_in, c_out, pooled in stages:
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.ReLU())
            if pooled:
                layers.append(nn.MaxPool2d(2))
        layers.append(nn.Conv2d(256, 1, 1))
        # Layer order (and therefore the state_dict keys) matches the
        # hand-written Sequential, so previously saved weights still load.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``(N, 1, 64, 64)`` heatmap for an ``(N, 3, H, W)`` batch."""
        features = self.net(x)
        return nn.functional.interpolate(features, size=(64, 64), mode='bilinear')

# -------------------------------
# Train
# -------------------------------
# Build the training pipeline: dataset -> shuffled mini-batches of 4.
train_ds = FullImageHeatmapDataset(TRAIN_IMG_DIR, GT_CSV, train=True)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)

model = HeatmapNet().to(DEVICE)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
lossf = nn.MSELoss()

for epoch in range(40):
    model.train()
    running_loss = 0.0
    for imgs, targets in train_dl:
        imgs, targets = imgs.to(DEVICE), targets.to(DEVICE)
        preds = model(imgs)
        loss = lossf(preds, targets)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()

    # The target is a small Gaussian on a mostly-zero 64x64 map, so the mean
    # squared error is tiny; scientific notation keeps it from printing as
    # 0.0000 every epoch.
    print(f'Epoch {epoch}: {running_loss / len(train_dl):.3e}')

# Persist the weights; the inference cell reloads them from this file.
torch.save(model.state_dict(), 'heatmap_full.pth')

Epoch 0: 0.0000


In [5]:
# -------------------------------
# Inference + Submission
# -------------------------------
def heatmap_to_coord(hm, width=None):
    """Return the ``(x, y)`` position of the maximum of a heatmap.

    Args:
        hm: heatmap array. Normally 2-D ``(rows, cols)``; it is flattened in
            row-major order to locate the maximum.
        width: row length used to turn the flat index back into ``(x, y)``.
            Defaults to the size of the last axis of ``hm``; for a 1-D input
            it defaults to 64, the heatmap width used by this notebook.

    Returns:
        ``(x, y)`` as Python ints: column index first, then row index.
    """
    hm = np.asarray(hm)
    if width is None:
        # Derive the row length from the data instead of assuming 64x64.
        width = hm.shape[-1] if hm.ndim >= 2 else 64
    y, x = divmod(int(hm.argmax()), width)
    return x, y

# Reload the trained weights and switch to evaluation mode.
model.load_state_dict(torch.load('heatmap_full.pth', map_location=DEVICE))
model.eval()

test_ds = FullImageHeatmapDataset(TEST_IMG_DIR, train=False)
test_dl = DataLoader(test_ds, batch_size=1)

rows = []
with torch.no_grad():
    for img, fname in test_dl:
        # batch_size=1, so the collated filename batch holds a single entry.
        name = fname[0]

        hm = model(img.to(DEVICE))[0, 0].cpu().numpy()
        hx, hy = heatmap_to_coord(hm)

        # The training labels are expressed in ORIGINAL image pixels (the
        # dataset divides them by the original width/height), so predictions
        # must be scaled back to that size, not to the 512x512 network input.
        src_path = os.path.join(TEST_IMG_DIR, name)
        src = cv2.imread(src_path)
        if src is None:
            raise FileNotFoundError(f'Could not read image: {src_path}')
        h0, w0 = src.shape[:2]

        # Decode at the cell centre (+0.5): each heatmap cell covers a range
        # of pixels, and using its top-left corner biases the result.
        x = int((hx + 0.5) / hm.shape[1] * w0)
        y = int((hy + 0.5) / hm.shape[0] * h0)

        # File names are the numeric sample index plus an extension.
        idx = int(os.path.splitext(name)[0])
        rows += [
            {'index': f'{idx}_Fovea_X', 'value': x},
            {'index': f'{idx}_Fovea_Y', 'value': y}
        ]

pd.DataFrame(rows).to_csv('submission.csv', index=False)
print('submission.csv saved')

submission.csv saved
