# MIL 모델: WSI 하나 전체를 입력 (Bag 단위)

In [ ]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import requests
from io import BytesIO
import pandas as pd
import torchvision
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import torch.nn as nn
import torch.optim as optim
import random
from collections import defaultdict, Counter


In [ ]:
def make_mil_data_index(repo_id, label_csv_path, revision="main"):
    """Build ``(url, label)`` pairs for bag-level MIL training.

    Reads a CSV with ``filename`` and ``label`` columns and maps every
    filename to the Hugging Face dataset URL of its ``.npz`` bag file.

    Args:
        repo_id: Hugging Face dataset repository id, e.g. ``"user/name"``.
        label_csv_path: Path or file-like object accepted by
            ``pandas.read_csv``.
        revision: Branch, tag, or commit the files are resolved from.

    Returns:
        A list of ``(url, label)`` tuples, one per unique filename, in order
        of first appearance. If a filename occurs on several rows, the label
        from its last row wins.

    Raises:
        ValueError: If a row has an empty ``filename``.
    """
    from urllib.parse import quote

    # Force the filename column to str: otherwise names such as "007" are
    # parsed as the integer 7 (losing zeros / breaking ``.endswith``).
    df = pd.read_csv(label_csv_path, dtype={'filename': str})
    # Going through a dict de-duplicates filenames (last label wins).
    filename_to_label = dict(zip(df['filename'], df['label']))
    base_url = f"https://huggingface.co/datasets/{repo_id}/resolve/{revision}"
    data_index = []
    for fname, label in filename_to_label.items():
        if pd.isna(fname):
            raise ValueError(f"Missing filename in {label_csv_path!r} (label={label!r})")
        fname_with_ext = fname if fname.endswith(".npz") else f"{fname}.npz"
        # Percent-encode spaces, '#', '?', etc. so they are not read as URL
        # syntax; keep '/' literal so files in sub-directories still resolve.
        url = f"{base_url}/{quote(fname_with_ext, safe='/')}"
        data_index.append((url, label))
    return data_index


In [ ]:
class MILDataset(Dataset):
    """Bag-level dataset for multiple-instance learning (MIL).

    Each item is one bag: every array stored in one remote ``.npz`` file is
    turned into a patch tensor, and the patches are stacked together with the
    bag's integer label.

    Args:
        data_index: Sequence of ``(url, label)`` pairs, e.g. the output of
            ``make_mil_data_index``.
        transform: Optional callable applied to each patch (a PIL image). It
            must return a tensor so the patches can be stacked. When ``None``,
            patches are converted with ``transforms.ToTensor()``.
        timeout: Forwarded to ``requests.get``; bounds how long connecting or
            a stalled read may block. ``None`` disables the limit.

    NOTE(review): bags may hold different numbers of patches, so the default
    DataLoader collation presumably needs ``batch_size=1`` — confirm at caller.
    """

    def __init__(self, data_index, transform=None, timeout=60):
        self.data_index = data_index
        self.transform = transform
        self.timeout = timeout

    def __len__(self):
        return len(self.data_index)

    def __getitem__(self, idx):
        """Return ``(patches, label)`` for bag *idx*; patches stack on dim 0."""
        url, label = self.data_index[idx]
        patch_tensor = self._decode_bag(self._download(url), url)
        return patch_tensor, int(label)

    def _download(self, url):
        """Fetch the raw bytes at *url*, raising on HTTP errors."""
        response = requests.get(url, timeout=self.timeout)
        # Without this an error page (e.g. a 404 HTML body) would reach
        # np.load and fail with a misleading "cannot load file" message.
        response.raise_for_status()
        return response.content

    def _decode_bag(self, raw, source="<bytes>"):
        """Decode an ``.npz`` payload into one stacked patch tensor.

        Raises:
            ValueError: If the archive contains no arrays.
        """
        to_tensor = self.transform if self.transform is not None else transforms.ToTensor()
        patches = []
        # NpzFile holds an open zip handle; the context manager closes it.
        with np.load(BytesIO(raw)) as npz:
            for key in npz.files:
                # Pillow infers 'L' for 2-D uint8 and 'RGB' for (H, W, 3)
                # uint8 arrays, so no explicit mode is needed.
                # NOTE(review): astype(np.uint8) truncates/wraps non-uint8
                # data; assumes patches are already 0-255 — confirm.
                image = Image.fromarray(npz[key].astype(np.uint8))
                patches.append(to_tensor(image))
        if not patches:
            # torch.stack([]) would fail with an obscure RuntimeError.
            raise ValueError(f"Bag at {source} contains no patches")
        return torch.stack(patches)  # Shape: (N, C, H, W)
