In [None]:
# Install the required libraries (kaggle CLI, PyTorch, torchvision, albumentations)
!pip install kaggle --quiet
!pip install torch torchvision albumentations --quiet

# Upload the Kaggle API key (after running this cell, a file upload is required).
# The steps below expect the uploaded file to be named kaggle.json and to be
# present in the current working directory.
from google.colab import files
files.upload()

# Download the Kaggle dataset and unzip it
!mkdir -p ~/.kaggle
# Move the uploaded key into ~/.kaggle/
!mv kaggle.json ~/.kaggle/
# Restrict the key file to owner read/write only
!chmod 600 ~/.kaggle/kaggle.json
!kaggle datasets download -d mateuszbuda/lgg-mri-segmentation
# Extract quietly into ./data
!unzip -q lgg-mri-segmentation.zip -d data

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.7/41.7 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.3/80.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.0/66.0 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m290.6/290.6 kB[0m [31m10.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.0/50.0 MB[0m [31m25.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m632.7/632.7 kB[0m [31m40.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m307.6/307.6 kB[0m [31m22.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [3]:
import os

IMAGE_DIR = "data/lgg-mri-segmentation/kaggle_3m"

# Walk every sub-folder of IMAGE_DIR and pair each image with its own mask.
#
# The image list and the mask list must NOT be sorted independently: with names
# such as "<id>_1.tif" / "<id>_10.tif" and "<id>_1_mask.tif" / "<id>_10_mask.tif",
# "." sorts before "0" but "_" sorts after it, so the two sorted lists end up in
# different orders and index i of one no longer matches index i of the other.
# Instead, the mask path is derived from the image path ("<stem>_mask<ext>").
pairs = []
unpaired_images = []

for folder in os.listdir(IMAGE_DIR):
    folder_path = os.path.join(IMAGE_DIR, folder)

    if not os.path.isdir(folder_path):  # only descend into directories
        continue

    # All .tif files of this folder (case-insensitive extension match)
    tif_files = {name for name in os.listdir(folder_path) if name.lower().endswith(".tif")}

    for file in tif_files:
        if "mask" in file.lower():  # masks are attached to their image below
            continue

        stem, ext = os.path.splitext(file)
        mask_file = f"{stem}_mask{ext}"
        file_path = os.path.join(folder_path, file)

        if mask_file in tif_files:
            pairs.append((file_path, os.path.join(folder_path, mask_file)))
        else:
            unpaired_images.append(file_path)

# Sort by image path; every mask stays attached to its image
pairs.sort()
image_paths = [image_path for image_path, _ in pairs]
mask_paths = [mask_path for _, mask_path in pairs]

print(f"✅ 총 이미지 개수: {len(image_paths)}")
print(f"✅ 총 마스크 개수: {len(mask_paths)}")

# Images without a matching mask file are left out; report them so they are not lost silently
if unpaired_images:
    print(f"Skipped {len(unpaired_images)} image(s) without a matching mask file")

# Fail early if nothing was found
if len(image_paths) == 0 or len(mask_paths) == 0:
    raise ValueError("❌ 데이터셋이 비어 있습니다. 경로 및 파일 확장자를 확인하세요.")


✅ 총 이미지 개수: 3929
✅ 총 마스크 개수: 3929


In [6]:
# Upgrade torch, torchvision and torchaudio to the newest versions pip can find.
# NOTE(review): packages already imported in this kernel are presumably not
# reloaded by the upgrade; a runtime restart may be needed — confirm.
!pip install --upgrade torch torchvision torchaudio

Collecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cubla

In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import cv2

# Dataset locations.
# Images and masks live side by side in the same per-folder directories.
IMAGE_DIR = "data/lgg-mri-segmentation/kaggle_3m"
MASK_DIR = "data/lgg-mri-segmentation/kaggle_3m"

# Walk every sub-folder of IMAGE_DIR and pair each image with its own mask.
#
# The image list and the mask list must NOT be sorted independently: with names
# such as "<id>_1.tif" / "<id>_10.tif" and "<id>_1_mask.tif" / "<id>_10_mask.tif",
# "." sorts before "0" but "_" sorts after it, so the two sorted lists end up in
# different orders and index i of one no longer matches index i of the other.
# Instead, the mask path is derived from the image path ("<stem>_mask<ext>").
pairs = []
unpaired_images = []

for folder in os.listdir(IMAGE_DIR):
    folder_path = os.path.join(IMAGE_DIR, folder)

    if not os.path.isdir(folder_path):  # only descend into directories
        continue

    # All .tif files of this folder (case-insensitive extension match)
    tif_files = {name for name in os.listdir(folder_path) if name.lower().endswith(".tif")}

    for file in tif_files:
        if "mask" in file.lower():  # masks are attached to their image below
            continue

        stem, ext = os.path.splitext(file)
        mask_file = f"{stem}_mask{ext}"
        file_path = os.path.join(folder_path, file)

        if mask_file in tif_files:
            pairs.append((file_path, os.path.join(folder_path, mask_file)))
        else:
            unpaired_images.append(file_path)

# Images without a matching mask file are left out; report them so they are not lost silently
if unpaired_images:
    print(f"Skipped {len(unpaired_images)} image(s) without a matching mask file")

# Sort by image path; every mask stays attached to its image
pairs.sort()
image_paths = [image_path for image_path, _ in pairs]
mask_paths = [mask_path for _, mask_path in pairs]

# Dataset class definition
class BrainMRIDataset(Dataset):
    """Index-aligned (image, mask) pairs loaded from disk as grayscale.

    Args:
        image_paths: sequence of image file paths.
        mask_paths: sequence of mask file paths; element ``i`` must be the mask
            belonging to ``image_paths[i]``.
        transform: optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``"image"`` and ``"mask"`` entries.
        image_size: ``(width, height)`` every image and mask is resized to.

    Raises:
        ValueError: if the two path sequences differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None, image_size=(256, 256)):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and mask_paths ({len(mask_paths)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform
        self.image_size = tuple(image_size)

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, resize and (optionally) transform the pair at ``idx``.

        Returns:
            ``(image, mask)`` — the transform's outputs when a transform is set,
            otherwise the resized arrays as returned by OpenCV.

        Raises:
            FileNotFoundError: if either file cannot be read.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # Load both files as single-channel grayscale
        image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

        # cv2.imread signals failure by returning None rather than raising
        if image is None:
            raise FileNotFoundError(f"Could not read image file: {image_path}")
        if mask is None:
            raise FileNotFoundError(f"Could not read mask file: {mask_path}")

        # Resize to the common size. The mask uses nearest-neighbour so that no
        # intermediate (non-label) values are invented along region borders.
        image = cv2.resize(image, self.image_size)
        mask = cv2.resize(mask, self.image_size, interpolation=cv2.INTER_NEAREST)

        # Normalisation / tensor conversion
        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        return image, mask

# Preprocessing pipeline: normalise with mean 0.5 / std 0.5, then convert to tensors
transform = A.Compose(
    [
        A.Normalize(mean=0.5, std=0.5),
        ToTensorV2(),
    ]
)

# Split the path lists 80 / 20 by position: the head is used for training,
# the tail for validation
train_size = int(0.8 * len(image_paths))
train_dataset = BrainMRIDataset(
    image_paths[:train_size],
    mask_paths[:train_size],
    transform,
)
val_dataset = BrainMRIDataset(
    image_paths[train_size:],
    mask_paths[train_size:],
    transform,
)

# Batches of 8; only the training loader shuffles
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=8, shuffle=False)

In [2]:
# U-Net model definition
class UNet(nn.Module):
    """Small convolutional encoder-decoder for segmentation.

    The encoder halves the spatial resolution twice (two 2x2 max-pools) and the
    decoder doubles it twice (two stride-2 transposed convolutions), so the
    output has the same height/width as the input when both are divisible by 4.
    There are no skip connections between encoder and decoder.

    Args:
        in_channels: number of channels of the input images.
        out_channels: number of channels of the output map.

    ``forward`` returns raw logits (no sigmoid). Pair it with
    ``BCEWithLogitsLoss`` for training and apply ``torch.sigmoid`` to the output
    when probabilities are needed.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()

        def conv_block(block_in, block_out):
            """Two 3x3 same-padding convolutions, each followed by ReLU."""
            return nn.Sequential(
                nn.Conv2d(block_in, block_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(block_out, block_out, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
            )

        self.encoder = nn.Sequential(
            conv_block(in_channels, 64),
            nn.MaxPool2d(2),
            conv_block(64, 128),
            nn.MaxPool2d(2),
        )

        self.decoder = nn.Sequential(
            conv_block(128, 64),
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),  # upsample x2
            conv_block(64, 32),
            nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2),  # upsample x2, back to input size
            nn.Conv2d(32, out_channels, kernel_size=1),  # 1x1 projection to the output channels
        )

    def forward(self, x):
        """Return per-pixel logits with shape (N, out_channels, H, W)."""
        # No sigmoid here: the loss applies it internally during training and the
        # evaluation code applies it explicitly, so adding one would squash twice.
        return self.decoder(self.encoder(x))

# Instantiate the model on the GPU when CUDA is available, otherwise on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = UNet()
model = model.to(device)

In [1]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu

# Acquire the XLA (TPU) device and report which one was selected
device = xm.xla_device()
print("TPU 디바이스:", device)

# Place the model on that device
model = model.to(device)

# Loss and optimiser. BCEWithLogitsLoss applies the sigmoid itself, so it is
# meant to be fed the model's raw outputs.
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)

# Rebuild both loaders with a smaller batch size (4) for this run
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False)

# Helper: move a batch onto the training device as float32
def to_xla(tensor):
    """Return ``tensor`` placed on the module-level ``device`` and cast to float32."""
    on_device = tensor.to(device=device, dtype=torch.float32)
    return on_device

# Model training on the XLA device
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0  # sum of the per-batch losses of this epoch

    for images, masks in train_loader:
        # Move the batch to the device as float32
        images = to_xla(images)
        masks = to_xla(masks) / 255.0  # scale 0..255 mask values to 0..1

        # The loss requires the target to have exactly the model output's shape.
        # Masks are expected to arrive as (B, H, W) — 2-D grayscale masks batched
        # by the DataLoader — while the model emits (B, 1, H, W), so add the
        # channel axis. This is a no-op if the masks already carry one.
        if masks.dim() == 3:
            masks = masks.unsqueeze(1)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()

        # XLA-aware optimizer step; barrier=True also triggers execution of the
        # pending graph for this step
        xm.optimizer_step(optimizer, barrier=True)

        epoch_loss += loss.item()

    # Explicit synchronisation point at the end of the epoch
    xm.mark_step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

# Evaluation: run the model on the first validation batch
model.eval()
with torch.no_grad():
    sample_img, sample_mask = next(iter(val_loader))

    # Move the images to the device as float32 before the forward pass
    sample_img = to_xla(sample_img)
    # Sigmoid is applied to the model output, then the result is copied to the
    # host as a NumPy array
    pred_mask = torch.sigmoid(model(sample_img))
    pred_mask = pred_mask.cpu().numpy()

# Visualisation: a 3 x 4 grid with one column per sample.
# Row 1 = input image, row 2 = ground-truth mask, row 3 = predicted mask.
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 6))
for i in range(4):
    panels = (
        ("MRI Image", sample_img[i][0].cpu()),
        ("Ground Truth", sample_mask[i].cpu()),
        ("Predicted Mask", pred_mask[i][0] > 0.5),  # binarise at 0.5
    )
    for row, (title, picture) in enumerate(panels):
        # subplot indices are 1-based and row-major: row r of column i is i + 1 + 4*r
        plt.subplot(3, 4, i + 1 + 4 * row)
        plt.imshow(picture, cmap="gray")
        plt.title(title)
        plt.axis("off")

plt.show()


KeyboardInterrupt: 