<a href="https://colab.research.google.com/github/dansojo/Medical_CV/blob/main/CheXNet_(DenseNet_base)_model_training_%EB%8F%99%EB%AC%BC_%EA%B7%BC%EA%B3%A8%EA%B2%A9%EA%B3%84.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.optim as optim

흑백 / 3채널 이미지 수

In [None]:
# # 이미지 경로 설정
# root_dir = '/content/drive/MyDrive/Medical_CV/3/전처리_데이터'

# # 각 split 폴더 (train, val, test) 확인
# splits = ['train', 'val', 'test']

# single_channel_count = 0
# three_channel_count = 0

# for split in splits:
#     split_dir = os.path.join(root_dir, split)
#     files = [f for f in os.listdir(split_dir) if f.endswith('.npy')]

#     for file in files:
#         file_path = os.path.join(split_dir, file)

#         # 넘파이 파일 불러오기
#         data = np.load(file_path, allow_pickle=True).item()
#         image_np = data['image']

#         # 채널 수 확인
#         if (len(image_np.shape) == 2) or (len(image_np.shape) == 3 and image_np.shape[0] == 1):  # 흑백 이미지 (H, W)
#             single_channel_count += 1
#         elif len(image_np.shape) == 3 and image_np.shape[2] == 3:  # 3채널 이미지 (H, W, 3)
#             three_channel_count += 1

# # 결과 출력
# print(f"Single-channel (Grayscale) images: {single_channel_count}")
# print(f"Three-channel (RGB) images: {three_channel_count}")

**Custom Dataset 클래스 정의**

In [None]:
class CustomDataset(Dataset):
    """Dataset over preprocessed ``.npy`` sample files stored in one directory.

    Each ``.npy`` file is loaded with ``np.load(..., allow_pickle=True).item()``
    and must yield a dict with the keys ``'image'`` and ``'label'``.
    Images are returned as float32 tensors in channel-first ``(3, H, W)``
    layout. Single-channel inputs, given as ``(H, W)``, ``(1, H, W)`` or
    ``(H, W, 1)``, are replicated across three channels. Channel-last
    ``(H, W, 3)`` inputs are permuted to channel-first.

    Args:
        root_dir: Directory scanned (non-recursively) for ``.npy`` files.
        transform: Optional callable applied to the ``(3, H, W)`` float32
            tensor before it is returned.
        apply_clahe: Stored on the instance; no CLAHE step is applied here.

    Note:
        ``allow_pickle=True`` executes pickled data; only load trusted files.
    """

    def __init__(self, root_dir, transform=None, apply_clahe=False):
        self.root_dir = root_dir
        # os.listdir order is arbitrary. Sort so that index -> file is
        # reproducible across runs and machines.
        self.files = sorted(
            os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith('.npy')
        )
        self.transform = transform
        # NOTE(review): apply_clahe is kept for API compatibility, but no CLAHE
        # step is implemented. Confirm whether preprocessing already applied it.
        self.apply_clahe = apply_clahe

    def __len__(self):
        """Return the number of ``.npy`` files found in ``root_dir``."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label)``."""
        data = np.load(self.files[idx], allow_pickle=True).item()
        image_np, label = data['image'], data['label']

        # Normalise single-channel layouts to (1, H, W) first, so they are
        # replicated below instead of crashing (2-D) or staying 1-channel.
        if image_np.ndim == 2:  # (H, W)
            image_np = image_np[np.newaxis, ...]
        elif image_np.ndim == 3 and image_np.shape[-1] == 1 and image_np.shape[0] != 1:  # (H, W, 1)
            image_np = np.moveaxis(image_np, -1, 0)

        # (1, H, W) -> (3, H, W): the backbone expects three input channels.
        if image_np.shape[0] == 1:
            image_np = np.repeat(image_np, 3, axis=0)

        image_tensor = torch.tensor(image_np, dtype=torch.float32)

        # Channel-last (H, W, 3) -> channel-first (3, H, W).
        if image_tensor.shape[0] != 3:
            image_tensor = image_tensor.permute(2, 0, 1)

        if self.transform is not None:
            image_tensor = self.transform(image_tensor)

        return image_tensor, label

In [None]:
# CheXNet-style classifier built on a DenseNet-121 backbone
class CheXNetModel(nn.Module):
    """DenseNet-121 backbone whose classifier is replaced by Dropout + Linear.

    Args:
        num_classes: Number of output logits produced by the new head.
        pretrained: If True (default), initialise the backbone with torchvision's
            ``DenseNet121_Weights.IMAGENET1K_V1``. That is the checkpoint the
            deprecated ``pretrained=True`` flag resolved to. Pass False when the
            weights will be overwritten anyway (e.g. right before
            ``load_state_dict``) to skip the download.
        dropout: Dropout probability applied before the final linear layer.
    """

    def __init__(self, num_classes=4, pretrained=True, dropout=0.5):
        super().__init__()
        # `weights=` replaces the `pretrained=` flag deprecated in torchvision 0.13.
        weights = models.DenseNet121_Weights.IMAGENET1K_V1 if pretrained else None
        self.model = models.densenet121(weights=weights)

        # Swap in a new head sized for num_classes. Keeping the same Sequential
        # layout (0 = Dropout, 1 = Linear) keeps state_dict keys compatible with
        # previously saved checkpoints.
        in_features = self.model.classifier.in_features
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=dropout),  # regularisation against overfitting
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Return raw (un-normalised) class logits for the image batch ``x``."""
        return self.model(x)

In [None]:
# Pick the compute device, then build the classifier directly on it
num_classes = 4  # classes: rib fracture, patellar luxation, anterior cruciate ligament rupture, intervertebral disc disease
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
chexnet_model = CheXNetModel(num_classes=num_classes).to(device)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 123MB/s]


In [None]:
# Loss function, optimizer and learning rate used for fine-tuning
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-5  # small initial learning rate chosen for fine-tuning CheXNet
optimizer = optim.Adam(params=chexnet_model.parameters(), lr=learning_rate)

In [None]:
# Directories holding the preprocessed train / validation splits
train_data_path = '/content/drive/MyDrive/Medical_CV/3/전처리_데이터/train'
val_data_path = '/content/drive/MyDrive/Medical_CV/3/전처리_데이터/val'

# One dataset per split directory
train_dataset, val_dataset = (
    CustomDataset(root_dir=split_dir) for split_dir in (train_data_path, val_data_path)
)

# Mini-batches of 32; only the training split is shuffled
train_loader, val_loader = (
    DataLoader(split_ds, batch_size=32, shuffle=is_train)
    for split_ds, is_train in ((train_dataset, True), (val_dataset, False))
)

In [None]:
# Training loop definition
def train_model(model, optimizer, criterion, train_loader, val_loader, num_epochs=10, device=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Each epoch prints the mean per-batch training loss and the validation
    accuracy (percent of correctly classified validation samples).

    Args:
        model: Module mapping an image batch to class logits (batch, classes).
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss taking ``(outputs, labels)``.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        device: Device each batch is moved to. Defaults to the device of the
            model's first parameter (CPU if the model has no parameters).

    Returns:
        A list with one dict per epoch holding ``'epoch'`` (1-based),
        ``'train_loss'`` and ``'val_accuracy'``. A metric is ``nan`` when its
        loader produced no batches/samples, instead of raising
        ZeroDivisionError.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    history = []
    for epoch in range(num_epochs):
        # --- training phase ---
        model.train()
        running_loss = 0.0
        num_batches = 0  # counted directly, so loaders without len() work too
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            num_batches += 1

        # Mean of the per-batch losses for this epoch
        avg_train_loss = running_loss / num_batches if num_batches else float('nan')
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_train_loss:.4f}")

        # --- validation phase: eval mode, no gradient tracking ---
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_accuracy = 100 * correct / total if total else float('nan')
        print(f"Validation Accuracy: {val_accuracy:.2f}%")

        history.append({'epoch': epoch + 1, 'train_loss': avg_train_loss, 'val_accuracy': val_accuracy})

    return history

# Fine-tune the model for 10 epochs on the prepared loaders
train_model(
    model=chexnet_model,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=10,
)

Epoch [1/10], Loss: 0.8257
Validation Accuracy: 77.25%
Epoch [2/10], Loss: 0.3369
Validation Accuracy: 79.99%
Epoch [3/10], Loss: 0.2258
Validation Accuracy: 81.13%
Epoch [4/10], Loss: 0.1688
Validation Accuracy: 83.56%
Epoch [5/10], Loss: 0.1376
Validation Accuracy: 90.38%
Epoch [6/10], Loss: 0.1181
Validation Accuracy: 95.35%
Epoch [7/10], Loss: 0.0954
Validation Accuracy: 92.99%
Epoch [8/10], Loss: 0.0828
Validation Accuracy: 90.50%
Epoch [9/10], Loss: 0.0651
Validation Accuracy: 93.31%
Epoch [10/10], Loss: 0.0608
Validation Accuracy: 93.44%


In [None]:
# Destination folder and checkpoint file for the trained weights (Google Drive)
save_dir = '/content/drive/MyDrive/Medical_CV/models'
save_path = os.path.join(save_dir, 'chexnet_weights_V2.pth')

# Create the folder on first use; no error if it already exists
os.makedirs(name=save_dir, exist_ok=True)

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module
torch.save(obj=chexnet_model.state_dict(), f=save_path)

In [None]:
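# # 모델 클래스를 다시 정의하고 초기화 후, 가중치 불러오기
# loaded_model = CheXNetModel(num_classes=4)  # num_classes는 모델 클래스 정의 시 지정했던 값과 일치해야 합니다
# loaded_model.load_state_dict(torch.load(save_path, map_location=device))  # load the file saved above (chexnet_weights_V2.pth); map_location lets GPU-saved weights load on CPU-only runtimes
# loaded_model = loaded_model.to(device)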
