<a href="https://colab.research.google.com/github/dansojo/Medical_CV/blob/main/ViT_model_training_%EB%8F%99%EB%AC%BC_%EA%B7%BC%EA%B3%A8%EA%B2%A9%EA%B3%84_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
import torch.optim as optim

**Custom Dataset 클래스 정의**

In [None]:
class CustomDataset(Dataset):
    """Dataset of preprocessed samples stored as one ``.npy`` file per sample.

    Each file is loaded with ``np.load(..., allow_pickle=True).item()`` and is
    expected to yield a dict with an ``'image'`` array and a ``'label'`` entry.
    Images are returned as ``float32`` tensors in channel-first layout.

    Args:
        root_dir: Directory scanned (non-recursively) for ``*.npy`` files.
        transform: Optional callable applied to the image tensor right before
            it is returned.
        apply_clahe: Stored on the instance only; this class performs no
            CLAHE step itself.
    """

    def __init__(self, root_dir, transform=None, apply_clahe=False):
        self.root_dir = root_dir
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping reproducible across runs and machines.
        self.files = sorted(
            os.path.join(root_dir, f)
            for f in os.listdir(root_dir)
            if f.endswith('.npy')
        )
        self.transform = transform
        self.apply_clahe = apply_clahe

    def __len__(self):
        """Return the number of ``.npy`` samples found in ``root_dir``."""
        return len(self.files)

    @staticmethod
    def _to_channel_first_rgb(image_np):
        """Return ``image_np`` in (C, H, W) layout, expanding one channel to three.

        Handles (H, W), (1, H, W), (3, H, W), (H, W, 1) and (H, W, 3) inputs.
        Any other 3-D array whose first axis is neither 1 nor 3 is treated as
        channel-last and simply moved to channel-first.
        """
        image_np = np.asarray(image_np)

        if image_np.ndim == 2:
            # Plain grayscale image without a channel axis.
            image_np = image_np[np.newaxis, ...]
        if image_np.ndim != 3:
            raise ValueError(
                f"expected a 2-D or 3-D image array, got shape {image_np.shape}"
            )

        if image_np.shape[0] == 1:
            # Single channel, channel-first: replicate it three times.
            return np.repeat(image_np, 3, axis=0)
        if image_np.shape[0] == 3:
            return image_np

        # Otherwise assume channel-last.
        if image_np.shape[-1] == 1:
            image_np = np.repeat(image_np, 3, axis=-1)
        return np.ascontiguousarray(np.transpose(image_np, (2, 0, 1)))

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label)``."""
        # NOTE(security): allow_pickle=True executes pickled payloads; only
        # point this dataset at files from a trusted source.
        data = np.load(self.files[idx], allow_pickle=True).item()
        image_np, label = data['image'], data['label']

        image_tensor = torch.tensor(
            self._to_channel_first_rgb(image_np), dtype=torch.float32
        )

        # Honour the transform the caller supplied (it used to be ignored).
        if self.transform is not None:
            image_tensor = self.transform(image_tensor)

        return image_tensor, label

In [None]:
# Locations of the preprocessed train / validation splits on Google Drive.
train_data_path = '/content/drive/MyDrive/Medical_CV/3/전처리_데이터/train'
val_data_path = '/content/drive/MyDrive/Medical_CV/3/전처리_데이터/val'

# One dataset per split, built with the default constructor options.
train_dataset, val_dataset = (
    CustomDataset(root_dir=split_dir)
    for split_dir in (train_data_path, val_data_path)
)

# Batch the samples; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

In [None]:
# Customised ViT classifier
class ViTModel2(nn.Module):
    """Pretrained torchvision ViT-B/16 with a dropout + linear classification head.

    Args:
        num_classes: Number of output logits produced by the replacement head.

    Note:
        The replacement head is installed as ``model.heads``. An earlier
        revision attached it as ``model.classifier``, an attribute the
        torchvision ViT forward pass never calls, so state dicts saved from
        that revision do not match this module's parameter names/shapes.
    """

    def __init__(self, num_classes=4):
        super().__init__()
        # Load the pretrained ViT backbone.
        self.model = models.vit_b_16(pretrained=True)

        # Width of the features entering the stock classification head.
        in_features = self.model.heads.head.in_features

        # torchvision's VisionTransformer.forward() ends with `self.heads(x)`,
        # so the new head has to replace `heads`; assigning it to any other
        # attribute leaves the original pretrained head in the forward path.
        self.model.heads = nn.Sequential(
            nn.Dropout(p=0.5),  # regularisation against overfitting
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x):
        """Return the wrapped model's output for the image batch ``x``."""
        return self.model(x)

In [None]:
# Pick the compute device, then build the model directly on it.
# Four target classes: rib fracture, patellar luxation,
# cruciate ligament rupture, intervertebral disc disease.
num_classes = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ViT_model2 = ViTModel2(num_classes=num_classes).to(device)



In [None]:
# Loss function and optimiser used for fine-tuning.
criterion = nn.CrossEntropyLoss()
learning_rate = 1e-6
optimizer = optim.Adam(params=ViT_model2.parameters(), lr=learning_rate)

In [None]:
def train_model(model, optimizer, criterion, train_loader, val_loader, num_epochs=15, device=None):
    """Train ``model`` and report validation accuracy after every epoch.

    Args:
        model: Module to optimise; it is left in eval mode when the function
            returns (unless ``num_epochs`` is 0).
        optimizer: Optimiser already bound to ``model``'s parameters.
        criterion: Callable ``criterion(outputs, labels)`` returning the loss.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        device: Device the batches are moved to. When ``None`` it is taken
            from the model's first parameter (CPU if the model has none), so
            the function no longer relies on a module-level ``device``.

    Returns:
        None. One summary line per epoch is printed.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        num_batches = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1

        # Mean of the per-batch losses; NaN instead of ZeroDivisionError when
        # the loader yielded nothing.
        avg_train_loss = running_loss / num_batches if num_batches else float("nan")

        # ---- validation pass ----
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        # Same guard for an empty validation set.
        val_accuracy = 100 * correct / total if total else float("nan")
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {avg_train_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

# Kick off training for 15 epochs.
train_model(
    model=ViT_model2,
    optimizer=optimizer,
    criterion=criterion,
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=15,
)

Epoch [1/15], Loss: 3.6629, Validation Accuracy: 75.97%
Epoch [2/15], Loss: 0.4944, Validation Accuracy: 85.02%
Epoch [3/15], Loss: 0.2816, Validation Accuracy: 84.32%
Epoch [4/15], Loss: 0.2085, Validation Accuracy: 89.42%
Epoch [5/15], Loss: 0.1710, Validation Accuracy: 86.30%
Epoch [6/15], Loss: 0.1470, Validation Accuracy: 87.44%
Epoch [7/15], Loss: 0.1299, Validation Accuracy: 89.48%
Epoch [8/15], Loss: 0.1161, Validation Accuracy: 87.64%
Epoch [9/15], Loss: 0.1042, Validation Accuracy: 91.46%
Epoch [10/15], Loss: 0.0959, Validation Accuracy: 89.93%
Epoch [11/15], Loss: 0.0882, Validation Accuracy: 94.65%
Epoch [12/15], Loss: 0.0775, Validation Accuracy: 91.84%
Epoch [13/15], Loss: 0.0697, Validation Accuracy: 91.65%
Epoch [14/15], Loss: 0.0633, Validation Accuracy: 91.27%
Epoch [15/15], Loss: 0.0580, Validation Accuracy: 91.78%


In [None]:
# Destination for the trained weights on Google Drive.
save_dir = '/content/drive/MyDrive/Medical_CV/models'
save_path = os.path.join(save_dir, 'vit_weights_V2.pth')

# Create the directory if it is not there yet.
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Persist only the parameters (state dict), not the full module object.
torch.save(obj=ViT_model2.state_dict(), f=save_path)