In [1]:
import numpy as np
from pathlib import Path
from astropy.io import fits
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import Dataset, DataLoader

# -----------------------------
# 1. 경로 설정
# -----------------------------
CLASS_PATHS = {
    "slsim_lenses": "/caefs/data/IllustrisTNG/slchallenge/slsim_lenses/slsim_lenses",
    "slsim_nonlenses": "/caefs/data/IllustrisTNG/slchallenge/slsim_nonlenses/slsim_nonlenses",
    "hsc_lenses": "/caefs/data/IllustrisTNG/slchallenge/hsc_lenses/hsc_lenses",
    "hsc_nonlenses": "/caefs/data/IllustrisTNG/slchallenge/hsc_nonlenses/hsc_nonlenses",
}

# -----------------------------
# 2. 파일 리스트와 라벨 준비
# -----------------------------
def prepare_file_list():
    files, labels = [], []
    for cls, path in CLASS_PATHS.items():
        if "nonlenses" in cls:  # 먼저 비렌즈 확인
            label = 0
        else:                   # 그 외는 렌즈
            label = 1
        for f in Path(path).glob("*.fits"):
            files.append(str(f))
            labels.append(label)
    return np.array(files), np.array(labels)


files, labels = prepare_file_list()
print(f"총 파일 수: {len(files)}")
print(f"렌즈 수: {labels.sum()}")
print(f"비렌즈 수: {len(labels) - labels.sum()}")



총 파일 수: 1000000
렌즈 수: 500000
비렌즈 수: 500000


In [4]:

# -----------------------------
# 3. Train/Val/Test Split
# -----------------------------
# 첫 번째 split: test 15%
train_files, test_files, train_labels, test_labels = train_test_split(
    files, labels, test_size=0.15, stratify=labels, random_state=42
)

# 두 번째 split: 남은 데이터 중 15/85 ≈ 0.1765 → 15%를 유지하려면 0.1765 비율로 split
train_files, val_files, train_labels, val_labels = train_test_split(
    train_files, train_labels, test_size=0.1765, stratify=train_labels, random_state=42
)

# -----------------------------
# 4. PyTorch Dataset 정의
# -----------------------------
class LensDataset(Dataset):
    def __init__(self, file_list, labels, transform=None):
        self.file_list = file_list
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        # FITS 이미지 로드
        with fits.open(self.file_list[idx]) as hdul:
            img = np.array(hdul[1].data[0][1]).reshape(41, 41)

        # 정규화: Z-score
        img = (img - img.mean()) / (img.std() + 1e-8)

        # 채널 차원 추가 (1, H, W)
        img = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
        label = torch.tensor(self.labels[idx], dtype=torch.long)

        # 추가 변환 (증강 등)
        if self.transform:
            img = self.transform(img)

        return img, label


In [5]:
# -----------------------------
# 5. DataLoader 준비
# -----------------------------
train_dataset = LensDataset(train_files, train_labels)
val_dataset = LensDataset(val_files, val_labels)
test_dataset = LensDataset(test_files, test_labels)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")


Train: 699975, Val: 150025, Test: 150000
