<a href="https://colab.research.google.com/github/jeong1suk/CT_Classification_segmentation/blob/main/ImageClassification/Classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 흉부 X-ray 이미지로 정상/코로나/바이러스성 폐렴을 분류하는 Image Classification

## 1. 라이브러리 불러오기

In [None]:
import os
import copy
import random

import cv2
import torch
import numpy as np
from torch import nn
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from ipywidgets import interact

# Single seed shared by every random number generator used in this notebook.
random_seed = 2000

# Seed Python's `random`, NumPy and PyTorch (including its CUDA generator).
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)
# cuDNN flags for reproducible runs: deterministic on, benchmark off.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## 2. 이미지 파일경로 불러오기

In [None]:
def list_image_files(data_dir, sub_dir):
    """Return the image files found directly inside ``data_dir/sub_dir``.

    Args:
        data_dir: Root directory of a dataset split.
        sub_dir: Name of the class folder inside ``data_dir``.

    Returns:
        A sorted list of paths relative to ``data_dir`` (``sub_dir/file_name``)
        for every entry whose extension is jpeg, jpg or png, in any letter case.

    Raises:
        OSError: If ``data_dir/sub_dir`` cannot be listed.
    """
    image_extensions = {".jpeg", ".jpg", ".png"}
    images_dir = os.path.join(data_dir, sub_dir)
    # os.listdir returns entries in arbitrary order; sorting keeps sample
    # indices stable between runs and machines.
    return [
        os.path.join(sub_dir, file_name)
        for file_name in sorted(os.listdir(images_dir))
        # splitext only matches a real extension, so a file literally named
        # "png" is no longer mistaken for an image.
        if os.path.splitext(file_name)[1].lower() in image_extensions
    ]

In [None]:
# Root folder of the training split; each class has its own sub-folder.
data_dir = "/content/drive/MyDrive/DATASET/Classification/train/"

# Collect the relative image paths of the three classes, in this order.
normals_list, covids_list, pneumonias_list = (
    list_image_files(data_dir, class_dir)
    for class_dir in ("Normal", "Covid", "Viral Pneumonia")
)

In [None]:
# Number of training images found per class: Normal, Covid, Viral Pneumonia.
print(len(normals_list), len(covids_list), len(pneumonias_list))

## 3. 이미지파일을 RGB 3차원 배열로 불러오기

In [None]:
def get_RGB_image(data_dir, file_name):
    """Load an image from disk and return it in RGB channel order.

    Args:
        data_dir: Directory the relative ``file_name`` is resolved against.
        file_name: Path of the image relative to ``data_dir``.

    Returns:
        The image returned by ``cv2.imread`` after BGR -> RGB conversion.

    Raises:
        FileNotFoundError: If the file is missing or cannot be decoded.
    """
    image_file = os.path.join(data_dir, file_name)
    image = cv2.imread(image_file)
    if image is None:
        # cv2.imread reports failure by returning None; without this check the
        # problem only shows up as an opaque error inside cvtColor.
        raise FileNotFoundError(f"Could not read image file: {image_file}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

## 4. 이미지데이터 확인하기

In [None]:
# The slider can only go as far as the smallest class has images.
min_num_files = min(len(normals_list), len(covids_list), len(pneumonias_list))

@interact(index=(0, min_num_files-1))
def show_samples(index=0):
    """Show the index-th Normal, Covid and Viral Pneumonia images side by side.

    Args:
        index: Position in each per-class file list; driven by the slider.
    """
    samples = [
        ("Normal", normals_list),
        ("Covid", covids_list),
        ("Viral Pneumonia", pneumonias_list),
    ]

    plt.figure(figsize=(12,8))
    for position, (title, files) in enumerate(samples, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.imshow(get_RGB_image(data_dir, files[index]))
    plt.tight_layout()
    # Render inside the interact callback so the widget output owns the figure
    # and replaces it on the next slider move instead of stacking figures.
    plt.show()

~인덱스별로 하나씩 확인하려고 interact 쓴건데 계속 쌓이네??~

## 5. 학습데이터셋 클래스 구축

In [None]:
# Training split root, and the class folder names.
# A name's position in class_list is the integer label Chest_dataset returns.
train_data_dir = "/content/drive/MyDrive/DATASET/Classification/train/"
class_list = ["Normal", "Covid", "Viral Pneumonia"]

In [None]:
class Chest_dataset(Dataset):
    """Chest X-ray images stored as one folder per class under ``data_dir``.

    The label of a sample is the position of its class folder name in the
    module-level ``class_list``.
    """

    def __init__(self, data_dir, transformer=None):
        """Collect the image paths of all three classes.

        Args:
            data_dir: Root directory containing the "Normal", "Covid" and
                "Viral Pneumonia" folders.
            transformer: Optional callable applied to each loaded RGB image.
        """
        self.data_dir = data_dir
        normals = list_image_files(data_dir, "Normal")
        covids = list_image_files(data_dir, "Covid")
        pneumonias = list_image_files(data_dir, "Viral Pneumonia")

        self.files_path = normals + covids + pneumonias
        self.transform = transformer

    def __len__(self):
        """Return the total number of images across the three classes."""
        return len(self.files_path)

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with "image" and "target". With a transformer, "image" is
            the transformer's output and "target" is a long tensor of shape
            (1,); without one, "image" is the RGB image as loaded and
            "target" is a plain int.

        Raises:
            FileNotFoundError: If the image cannot be read.
        """
        relative_path = self.files_path[index]
        image_file = os.path.join(self.data_dir, relative_path)
        image = cv2.imread(image_file)
        if image is None:
            # cv2.imread returns None instead of raising on a bad file.
            raise FileNotFoundError(f"Could not read image file: {image_file}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Stored paths look like "<class folder>/<file>", so the first
        # component is the class name.
        target = class_list.index(relative_path.split(os.sep)[0])

        if self.transform is not None:
            image = self.transform(image)
            target = torch.tensor([target], dtype=torch.long)

        return {"image": image, "target": target}

In [None]:
# Rebuild the combined file list the same way Chest_dataset does, to inspect it.
normals, covids, pneumonias = (
    list_image_files(data_dir, class_dir)
    for class_dir in ("Normal", "Covid", "Viral Pneumonia")
)
files_path = normals + covids + pneumonias
print(files_path)

In [None]:
# The first path component of a stored path is the class folder name.
print(files_path[200].split(os.sep)[0])

In [None]:
# Dataset without a transformer, used only for visual inspection below.
dset = Chest_dataset(train_data_dir)

In [None]:
# Sanity check: show one sample with its class name as the title.
# dset has no transformer, so "target" is a plain int usable as a list index.
index = 200
plt.title(class_list[dset[index]["target"]])
plt.imshow(dset[index]["image"])

## 6. 데이터로더 구현하기

In [None]:
# Preprocessing used for both the train and validation datasets:
# convert to a tensor, resize to 224x224, then normalize each of the three
# channels with mean 0.5 and std 0.5.
transformer = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((224,224)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [None]:
def build_dataloader(train_data_dir, val_data_dir, train_batch_size=4, val_batch_size=1):
    """Create the train and validation dataloaders.

    Both datasets use the module-level ``transformer``.

    Args:
        train_data_dir: Root directory of the training split.
        val_data_dir: Root directory of the validation split.
        train_batch_size: Batch size of the training loader (default 4).
        val_batch_size: Batch size of the validation loader (default 1).

    Returns:
        A dict with "train" and "val" DataLoader entries. The training loader
        shuffles and drops the last incomplete batch; the validation loader
        keeps the original order and every sample.
    """
    train_dset = Chest_dataset(train_data_dir, transformer)
    val_dset = Chest_dataset(val_data_dir, transformer)
    return {
        "train": DataLoader(
            train_dset, batch_size=train_batch_size, shuffle=True, drop_last=True
        ),
        "val": DataLoader(
            val_dset, batch_size=val_batch_size, shuffle=False, drop_last=False
        ),
    }

In [None]:
# The "test" folder is used as the validation split.
train_data_dir = "/content/drive/MyDrive/DATASET/Classification/train/"
val_data_dir = "/content/drive/MyDrive/DATASET/Classification/test/"
dataloaders = build_dataloader(train_data_dir, val_data_dir)

In [None]:
# Show the dict holding the "train" and "val" loaders.
print(dataloaders)

In [None]:
# Fetch just the first training batch for inspection in the next cells.
d = next(iter(dataloaders["train"]))

In [None]:
# Shape of the batched targets as produced by the dataloader.
d["target"].shape

In [None]:
# The same targets with size-1 dimensions removed.
d["target"].squeeze()

## 7. VGG19 모델 불러오기

In [None]:
# Stock torchvision VGG19 with pretrained weights.
model = models.vgg19(pretrained=True)

In [None]:
# Print a layer summary of the stock VGG19 for a 3x224x224 input.
# torchsummary is a separate package, imported here where it is first needed.
from torchsummary import summary
summary(model, (3, 224, 224), batch_size=1, device="cpu")

## 8. 데이터에 맞게 Head 부분 변경하기

In [None]:
# Swap the pooling and classifier for a small head with one output per class:
# pool each of the 512 feature maps to 1x1, then 512 -> 256 -> len(class_list).
model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
model.classifier = nn.Sequential(
    nn.Flatten(),
    nn.Linear(512, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    # No final activation: nn.CrossEntropyLoss expects raw logits, so a Sigmoid
    # here would squash the scores before the loss applies its own softmax.
    nn.Linear(256, len(class_list)),
)

In [None]:
def build_vgg19_based_model(device="cpu"):
    """Build a pretrained VGG19 whose head predicts the classes in ``class_list``.

    Args:
        device: Device (name or ``torch.device``) the model is moved to.

    Returns:
        The model on ``device``. Its output is one raw logit per class, which
        is the input ``nn.CrossEntropyLoss`` expects. The predicted class is
        the argmax, the same as with the former Softmax layer.
    """
    device = torch.device(device)
    model = models.vgg19(pretrained=True)
    model.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
    model.classifier = nn.Sequential(
        nn.Flatten(),
        nn.Linear(512, 256),
        nn.ReLU(),
        # No trailing Softmax: CrossEntropyLoss applies log-softmax itself, so
        # a Softmax here would be applied twice and flatten the gradients.
        # Softmax has no parameters, so the state_dict keys are unchanged and
        # existing checkpoints still load.
        nn.Linear(256, len(class_list)),
    )
    model.to(device)
    return model

In [None]:
# Replace the hand-edited model with one built by the helper, on the CPU.
model = build_vgg19_based_model(device='cpu')

In [None]:
# Inspect the layers, including the replaced head.
print(model)

## 9. 손실함수

In [None]:
# Multi-class classification loss, averaged over the batch.
loss_func = nn.CrossEntropyLoss(reduction="mean")

## 10. Gradient 최적화 함수

In [None]:
# SGD with momentum over all parameters of the current model.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

## 11. 모델 검증을 위한 Accuracy 생성하기

In [None]:
@torch.no_grad()
def get_accuracy(image, target, model):
    """Return the fraction of samples in a batch that the model classifies correctly.

    Args:
        image: Input batch; its first dimension is the batch size.
        target: Class indices for the batch, shaped (B,) or (B, 1).
        model: Callable mapping ``image`` to per-class scores of shape (B, C).

    Returns:
        Correct predictions divided by the batch size, as a float.
    """
    batch_size = image.shape[0]
    prediction = model(image)
    pred_label = prediction.argmax(dim=1)
    # Flatten the target: a (B, 1) target would broadcast against the (B,)
    # predictions into a (B, B) comparison and inflate the count.
    accuracy = (pred_label == target.reshape(-1)).sum().item() / batch_size
    return accuracy

## 12. 모델 학습

In [None]:
# Use the CPU for the quick check in the next cell.
device = torch.device("cpu")

In [None]:
# Print the targets of the first training batch, with dimension 1 squeezed
# away, then stop.
for index, batch in enumerate(dataloaders["train"]):
    print(batch["target"].squeeze(dim=1).to(device))
    break

In [None]:
def train_one_epoch(dataloaders, model, optimizer, loss_func, device):
    """Run one training pass followed by one validation pass.

    Args:
        dataloaders: Mapping with "train" and "val" entries. Each entry yields
            batches of the form {"image": ..., "target": ...}, with targets
            shaped (B, 1), and supports ``len()``.
        model: Network called as ``model(image)``.
        optimizer: Optimizer stepped once per training batch.
        loss_func: Callable ``(prediction, target) -> scalar loss tensor``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(losses, accuracies)``, two dicts keyed by phase. Loss is the mean
        of the per-batch losses. Accuracy is correct predictions divided by
        samples seen.
    """
    losses = {}
    accuracies = {}

    for phase in ["train", "val"]:
        is_train = phase == "train"
        running_loss = 0.0
        num_correct = 0
        num_samples = 0

        if is_train:
            model.train()
        else:
            model.eval()

        for index, batch in enumerate(dataloaders[phase]):
            image = batch["image"].to(device)
            target = batch["target"].squeeze(dim=1).to(device)

            with torch.set_grad_enabled(is_train):
                prediction = model(image)
                loss = loss_func(prediction, target)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item()
            # Reuse the prediction already computed instead of running a
            # second forward pass just to measure accuracy. Counting samples
            # keeps the result right when batches differ in size.
            num_correct += (prediction.argmax(dim=1) == target).sum().item()
            num_samples += target.shape[0]

            if is_train and index % 10 == 0:
                print(f"{index}/{len(dataloaders[phase])} - Running Loss : {loss.item()}")

        losses[phase] = running_loss / len(dataloaders[phase])
        accuracies[phase] = num_correct / num_samples

    return losses, accuracies

In [None]:
def save_best_model(model_state, model_name, save_dir="./trained_model"):
    """Write ``model_state`` to ``save_dir/model_name``, creating the folder if needed."""
    os.makedirs(save_dir, exist_ok=True)
    checkpoint_path = os.path.join(save_dir, model_name)
    torch.save(model_state, checkpoint_path)

## 13. 모델 학습 수행하기

In [None]:
# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

train_data_dir = "/content/drive/MyDrive/DATASET/Classification/train/"
val_data_dir = "/content/drive/MyDrive/DATASET/Classification/test/"

# Fresh dataloaders, model, loss and optimizer for the actual training run.
dataloaders = build_dataloader(train_data_dir, val_data_dir)
model = build_vgg19_based_model(device=device)
loss_func = nn.CrossEntropyLoss(reduction="mean")
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

In [None]:
def train_one_epoch(dataloaders, model, optimizer, loss_function, device):
    """Run one training pass followed by one validation pass.

    Args:
        dataloaders: Mapping with "train" and "val" entries. Each entry yields
            batches of the form {"image": ..., "target": ...}, with targets
            shaped (B, 1), and supports ``len()``.
        model: Network called as ``model(image)``.
        optimizer: Optimizer stepped once per training batch.
        loss_function: Callable ``(prediction, label) -> scalar loss tensor``.
        device: Device the batch tensors are moved to.

    Returns:
        ``(losses, accuracies)``, two dicts keyed by phase. Loss is the mean
        of the per-batch losses. Accuracy is correct predictions divided by
        samples seen.
    """
    losses = {}
    accuracies = {}

    for phase in ["train", "val"]:
        is_train = phase == "train"
        running_loss = 0.0
        # Correct predictions and samples are accumulated separately per phase.
        num_correct = 0
        num_samples = 0

        if is_train:
            model.train()
        else:
            model.eval()

        for index, batch in enumerate(dataloaders[phase]):
            image = batch["image"].to(device)  # first field: the images
            label = batch["target"].squeeze(dim=1).to(device)  # second field: class ids

            with torch.set_grad_enabled(is_train):
                prediction = model(image)
                # Use the loss passed in; the previous code silently used the
                # global loss_func and ignored this parameter.
                loss = loss_function(prediction, label)

                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item()
            # Reuse the prediction already computed instead of running a
            # second forward pass just to measure accuracy. Counting samples
            # keeps the result right when batches differ in size.
            num_correct += (prediction.argmax(dim=1) == label).sum().item()
            num_samples += label.shape[0]

            if is_train and index % 10 == 0:
                print(f"{index}/{len(dataloaders['train'])} - Running loss: {loss.item()}")

        losses[phase] = running_loss / len(dataloaders[phase])
        accuracies[phase] = num_correct / num_samples

    return losses, accuracies

In [None]:
num_epochs = 10

# Best validation accuracy so far, and the per-epoch history of both phases.
best_acc = 0.0
train_loss, train_acc = [], []
val_loss, val_acc = [], []

for epoch in range(num_epochs):
    # Each call runs one training pass followed by one validation pass.
    losses, accuracies = train_one_epoch(dataloaders, model, optimizer, loss_func, device)
    train_loss.append(losses['train'])
    train_acc.append(accuracies['train'])
    val_loss.append(losses['val'])
    val_acc.append(accuracies['val'])

    # epoch is zero-based, so the first line printed reads "0/10".
    print(f"{epoch}/{num_epochs}-Tr loss:{losses['train']}, Val loss {losses['val']}")
    print(f"{epoch}/{num_epochs}-Tr acc:{accuracies['train']}, Val acc {accuracies['val']}")

    # Save a checkpoint whenever validation accuracy improves. The file name
    # carries the epoch number, so earlier checkpoints are kept.
    # NOTE(review): the evaluation cell loads a fixed "model_5.pth"; confirm
    # that epoch actually produced a checkpoint.
    if accuracies["val"] > best_acc:
        best_acc = accuracies['val']
        torch.save(model.state_dict(), f"model_{epoch}.pth")

## 14. 테스트 이미지를 통한 학습모델 분류 성능 검증하기

In [None]:
# Point data_dir at the test split and list its images per class.
data_dir = "/content/drive/MyDrive/DATASET/Classification/test/"
class_list = ["Normal", "Covid", "Viral Pneumonia"]

test_normals_list, test_covids_list, test_pneumonias_list = (
    list_image_files(data_dir, class_dir)
    for class_dir in ("Normal", "Covid", "Viral Pneumonia")
)

In [None]:
def preprocess_image(image):
    """Turn one RGB image into a normalized, batched tensor for the model.

    Applies the same steps as the training transform: to tensor, resize to
    224x224, normalize with mean 0.5 / std 0.5 per channel.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # (C, H, W) from the pipeline -> (B, C, H, W) with B == 1.
    return pipeline(image).unsqueeze(dim=0)

In [None]:
@torch.no_grad()
def model_predict(image, model):
    """Predict the class index of a single RGB image.

    Args:
        image: RGB image accepted by ``preprocess_image``.
        model: Classifier returning per-class scores of shape (1, NUM_CLASSES).

    Returns:
        The index of the highest-scoring class, as a Python int.
    """
    tensor_image = preprocess_image(image)
    # Run the input on the same device as the model's weights; otherwise a
    # model that lives on the GPU rejects the CPU tensor.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        tensor_image = tensor_image.to(first_param.device)
    prediction = model(tensor_image)  # (B, NUM_CLASSES) with B == 1
    # .item() extracts the plain Python number from the one-element tensor.
    return prediction.argmax(dim=1).item()

In [None]:
# Load the saved weights onto the CPU regardless of the device they were
# saved from. Without map_location, a checkpoint written during GPU training
# fails to load on a CPU-only runtime.
ckpt = torch.load("/content/model_5.pth", map_location="cpu")

model = build_vgg19_based_model(device='cpu')
model.load_state_dict(ckpt)
# Switch to inference mode before predicting.
model.eval()

In [None]:
# The slider can only go as far as the smallest test class has images.
min_num_files = min(len(test_normals_list), len(test_covids_list), len(test_pneumonias_list))

@interact(index=(0, min_num_files-1))
def show_result(index=0):
    """Show the index-th test image of each class with the model's prediction.

    Args:
        index: Position in each per-class test file list; driven by the slider.
    """
    samples = [
        ("Normal", test_normals_list),
        ("Covid", test_covids_list),
        ("Viral Pneumonia", test_pneumonias_list),
    ]

    plt.figure(figsize=(12, 8))
    for position, (ground_truth, files) in enumerate(samples, start=1):
        image = get_RGB_image(data_dir, files[index])
        predicted = model_predict(image, model)

        plt.subplot(1, 3, position)
        plt.title(f"Pred:{class_list[predicted]} | GT:{ground_truth}")
        plt.imshow(image)
    # Render inside the interact callback so the widget output owns the figure
    # and replaces it on the next slider move instead of stacking figures.
    plt.show()