# Contest Objective

치과 구강이미지 합성 데이터 분야의 헬스케어 AI 경진대회이며, 구강이미지 내 충치 유무 판별 모델 개발을 목표로 함.

# Competition Schedule

- 참 가  신 청 : 2023년 11월 30일(금) ~ 12월 05일(화)
- 대 회  참 가 : 2023년 12월 11일(월) ~ 12월 15일(금) 18:00 까지
- 레포트 제출 : 2023년 12월 17일(일) 낮 12:00 까지
- 심 사  기 간 : 2023년 12월 17일(일)
- 결 과  발 표 : 2023년 12월 18일(월) 
- 시    상   식 : 2023년 12월 21일(목)

# Summary

1. 본 경진대회는 동일한 환경(ssh & vim editor)과 동일한 GPU를 참가자에게 분배함.
2. 주최측에서 Resnet기반 모델 사용을 권장하였으며, `resnet50_binary` 모델 사용
3. 일부 참가자의 GPU 독점 이슈로, 배치 사이즈(12) 및 에폭 수(6)를 낮게 설정함.
4. 손실함수는 CrossEntropy, 옵티마이저는 Adam(lr=0.001).

# Resnet

- 충치가 하나라도 있으면 True로 출력하는 Resnet기반 모델을 개발함.
- GPU 이슈로 light한 `resnet50_binary` 모델 사용

In [None]:
#resnet.py

import torch
import torch.nn as nn
from typing import List, Tuple

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    The main path is conv -> batch-norm -> ReLU -> conv -> batch-norm. The
    shortcut passes the input through unchanged unless the stride or the
    channel count differs, in which case a 1x1 convolution followed by
    batch-norm projects the input to the output shape.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both 3x3 convolutions.
        stride: Stride of the first convolution (and of the projection).
    """

    # Output channels are ``expansion_factor * out_channels``.
    expansion_factor = 1

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
        super().__init__()
        projected_channels = self.expansion_factor * out_channels

        # Main path.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut path: an empty Sequential acts as the identity.
        projection = []
        if stride != 1 or in_channels != projected_channels:
            projection = [
                nn.Conv2d(in_channels, projected_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(projected_channels),
            ]
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(main_path(x) + shortcut(x))``."""
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return self.relu(out)

class ResNet(nn.Module):
    """Residual CNN: 7x7 stem, four residual stages, global pooling, linear head.

    Args:
        block: Residual block class. It is instantiated as
            ``block(in_channels, out_channels, stride)`` and must expose an
            ``expansion_factor`` attribute.
        num_blocks: Number of blocks in each of the four stages.
        num_classes: Width of the final linear layer.

    NOTE(review): ``forward`` returns sigmoid-activated values rather than
    raw logits. ``nn.CrossEntropyLoss`` (used by train.py) applies its own
    log-softmax and is normally fed raw logits - confirm this is intended.
    """

    def __init__(self, block: nn.Module, num_blocks: List[int], num_classes: int = 1):
        super().__init__()
        # Running channel count; _make_layer advances it after every block.
        self.in_channels = 64

        stem_layers = [
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.conv1 = nn.Sequential(*stem_layers)

        # (width, stride of the first block) for stages conv2 .. conv5.
        stage_plan = ((64, 1), (128, 2), (256, 2), (512, 2))
        for offset, (width, first_stride) in enumerate(stage_plan):
            stage = self._make_layer(block, width, num_blocks[offset], stride=first_stride)
            setattr(self, f"conv{offset + 2}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(512 * block.expansion_factor, num_classes)

        self._init_layer()

    def _make_layer(self, block: nn.Module, out_channels: int, num_blocks: int, stride: int):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        # The first entry is unconditional, so a stage always holds at least one block.
        per_block_strides = [stride]
        per_block_strides.extend(1 for _ in range(num_blocks - 1))

        stage = []
        for block_stride in per_block_strides:
            stage.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion_factor
        return nn.Sequential(*stage)

    def _init_layer(self):
        """Kaiming-initialise convolutions; reset norm layers to weight 1 and bias 0."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network and return sigmoid-activated scores, one row per sample."""
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = stage(x)
        pooled = self.avgpool(x)
        features = torch.flatten(pooled, 1)
        return torch.sigmoid(self.fc(features))

# ResNet-18
def resnet18_binary(num_classes: int = 2) -> ResNet:
    """Build a ResNet whose four stages hold 2, 2, 2 and 2 BasicBlocks."""
    stage_depths = [2, 2, 2, 2]
    return ResNet(block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes)

# ResNet-34
def resnet34_binary(num_classes: int = 2) -> ResNet:
    """Build a ResNet whose four stages hold 3, 4, 6 and 3 BasicBlocks."""
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes)

# ResNet-50
def resnet50_binary(num_classes: int = 2) -> ResNet:
    """Build a ResNet whose four stages hold 3, 4, 6 and 3 BasicBlocks.

    NOTE(review): this is the same configuration as ``resnet34_binary``
    (BasicBlock, [3, 4, 6, 3]). The usual ResNet-50 uses bottleneck blocks;
    existing checkpoints were saved with this layout, so switching blocks
    would make them unloadable - confirm before changing.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes)

# ResNet-101
def resnet101_binary(num_classes: int = 2) -> ResNet:
    """Build a ResNet whose four stages hold 3, 4, 23 and 3 BasicBlocks.

    NOTE(review): built from BasicBlock; the usual ResNet-101 uses
    bottleneck blocks - confirm the naming is intentional.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(block=BasicBlock, num_blocks=stage_depths, num_classes=num_classes)

# Train

In [None]:
##train.py

import argparse
import os
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image
from PIL import Image
from resnet import resnet18_binary, resnet34_binary, resnet50_binary, resnet101_binary
import albumentations as A
import numpy as np

#arguments 지정
def parse_arguments():
    """Parse the training command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``num_epochs``, ``lr``, ``image_dir``,
        ``json_dir``, ``model_dir``, ``batch_size``, ``epoch_interval`` and
        ``pretrained_model_path``.
    """
    # (flag, type, default, help text) for every supported option.
    option_table = (
        ('--num_epochs', int, 100, '에폭 설정'),
        ('--lr', float, 0.001, 'lr 설정'),
        ('--image_dir', str, '/qorskawls12/Dataset/train_data/image/', '이미지 경로'),
        ('--json_dir', str, '/qorskawls12/Dataset/train_data/json/', 'json 경로'),
        ('--model_dir', str, '/qorskawls12/model_1216/', 'pt파일 저장 경로'),
        ('--batch_size', int, 16, '배치 사이즈'),
        ('--epoch_interval', int, 1, '모델 저장 간격 (에폭)'),
        ('--pretrained_model_path', str, '/qorskawls12/model_1215/resnet_lr0.001_ep6_bs12.pt', '미리 학습된 모델 경로'),
    )

    cli = argparse.ArgumentParser(description='Tooth Decay Model Training')
    for flag, value_type, fallback, help_text in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=help_text)

    return cli.parse_args()

class CustomDataset(Dataset):
    """Pairs every ``*.png`` in ``image_dir`` with the same-named JSON in ``json_dir``.

    Each item is ``(image, label)``: ``image`` is a float32 CHW tensor and
    ``label`` is a scalar int64 tensor that is 1 when any entry of the
    JSON's ``"tooth"`` list has a truthy ``"decayed"`` value, else 0.

    Defined at module level (not inside ``main``) so DataLoader worker
    processes can pickle it.

    Args:
        image_dir: Directory that holds the PNG images.
        json_dir: Directory that holds the annotation JSON files.
        transform: Optional callable used as ``transform(image=ndarray)['image']``
            (the albumentations calling convention).
    """

    def __init__(self, image_dir, json_dir, transform=None):
        self.image_dir = image_dir
        self.json_dir = json_dir
        self.transform = transform

        self.samples = []
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        for image_file in sorted(os.listdir(image_dir)):
            if image_file.endswith('.png'):
                stem = os.path.splitext(image_file)[0]
                image_path = os.path.join(image_dir, image_file)
                json_path = os.path.join(json_dir, stem + '.json')
                self.samples.append((image_path, json_path))

    def pil_to_numpy(self, image):
        """Convert a PIL image into a NumPy array."""
        return np.array(image)

    @staticmethod
    def to_tensor(image):
        """Convert an HWC image array into a CHW float32 tensor.

        uint8 input is scaled to [0, 1], which matches what
        ``torchvision.transforms.ToTensor`` produces in the inference script.
        Other dtypes are only cast to float32. A tensor is returned
        unchanged, so a transform that already yields tensors keeps working.
        """
        if torch.is_tensor(image):
            return image
        array = np.ascontiguousarray(image)
        tensor = torch.from_numpy(array).permute(2, 0, 1).contiguous()
        if array.dtype == np.uint8:
            return tensor.float().div(255.0)
        return tensor.float()

    @staticmethod
    def load_label(json_path):
        """Return a scalar int64 tensor: 1 if any tooth is marked decayed, else 0."""
        with open(json_path, 'r', encoding='utf-8') as json_file:
            json_data = json.load(json_file)
        decayed_values = [tooth["decayed"] for tooth in json_data['tooth']]
        return torch.tensor(int(any(decayed_values)), dtype=torch.long)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        image_path, json_path = self.samples[idx]

        # Load the image; the context manager closes the file handle.
        with Image.open(image_path) as pil_image:
            image_np = self.pil_to_numpy(pil_image.convert("RGB"))

        if self.transform:
            image_np = self.transform(image=image_np)['image']

        # The model needs float CHW tensors; the augmentation pipeline built in
        # main() has no tensor-conversion step, so convert here.
        return self.to_tensor(image_np), self.load_label(json_path)


def main(args):
    """Fine-tune ``resnet50_binary`` on the tooth-decay dataset.

    Args:
        args: Namespace from ``parse_arguments``. Reads ``image_dir``,
            ``json_dir``, ``batch_size``, ``pretrained_model_path``, ``lr``,
            ``num_epochs``, ``epoch_interval`` and ``model_dir``.

    Raises:
        ValueError: If ``image_dir`` contains no ``.png`` file.
    """
    # Augmentation / preprocessing pipeline.
    data_transform = A.Compose([
        A.Resize(900, 1500),
        A.Rotate(limit=20, p=0.3),
        A.CLAHE(clip_limit=3.0, tile_grid_size=(8, 8), p=0.4),
        A.RandomBrightnessContrast(p=0.4, brightness_limit=-0.3, contrast_limit=-0.3)
    ])

    # Dataset and data loader.
    dataset = CustomDataset(args.image_dir, args.json_dir, transform=data_transform)
    if len(dataset) == 0:
        # Fail with a clear message instead of the sampler's generic error.
        raise ValueError(f'No .png images found in {args.image_dir}')
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)

    # Device selection.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)

    # Model definition.
    model = resnet50_binary().to(device)
    model_name = model.__class__.__name__.lower()

    # Load the pretrained weights straight into the model. map_location lets a
    # checkpoint saved on GPU load on a CPU-only machine. Passing an empty
    # --pretrained_model_path trains from the random initialisation.
    if args.pretrained_model_path:
        state_dict = torch.load(args.pretrained_model_path, map_location=device)
        model.load_state_dict(state_dict)

    # Loss function and optimizer.
    # NOTE(review): ResNet.forward in resnet.py ends with a sigmoid, so the
    # loss receives values in (0, 1) rather than raw logits - confirm intended.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Create the checkpoint directory up front so a save cannot fail after an
    # epoch has already been trained.
    os.makedirs(args.model_dir, exist_ok=True)

    # Training loop.
    model.train()
    for epoch in range(args.num_epochs):
        for batch_index, (inputs, labels) in enumerate(dataloader):
            inputs, labels = inputs.to(device), labels.to(device)
            print(f"Processing Batch {batch_index+1}/{len(dataloader)}")
            optimizer.zero_grad()
            outputs = model(inputs)
            # outputs is (batch, classes); squeezing it would drop the batch
            # dimension whenever the final batch holds a single sample.
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

        # Reports the loss of the last batch of the epoch.
        print(f'Epoch {epoch+1}/{args.num_epochs}, Loss: {loss.item()}')
        # Save an intermediate checkpoint at the requested interval.
        if (epoch + 1) % args.epoch_interval == 0:
            model_path = os.path.join(args.model_dir, f'{model_name}_ep{epoch+1}_bs{args.batch_size}.pt')
            torch.save(model.state_dict(), model_path)
            print(f'[~ing] SUCCESS! Saved to {model_path} (Epoch {epoch+1})')
    # Save the final model.
    model_path = os.path.join(args.model_dir, f'final_{model_name}_l_ep{args.num_epochs}_train.pt')
    torch.save(model.state_dict(), model_path)
    print(f'[finish] SUCCESS! Saved to {model_path}')

# Script entry point: parse the command-line options and run training with them.
if __name__ == "__main__":
    arguments = parse_arguments()
    main(arguments)

# Inference & Evaluation

In [None]:
#Library
import os
import json
from PIL import Image
import torch
from torchvision import transforms
from resnet import resnet50_binary
from sklearn.metrics import f1_score
import argparse
import time
from sklearn.metrics import precision_recall_fscore_support
import torch.nn.functional as F

# Wall-clock reference taken when the script starts; the __main__ section
# subtracts it from the finish time to report the elapsed minutes.
start_time = time.time()

#Load Model
def load_model(model_path, model_class, num_classes=2):
    """Instantiate ``model_class`` and restore its weights for inference.

    Args:
        model_path: Path to a saved ``state_dict`` (.pt) file.
        model_class: Callable invoked as ``model_class(num_classes=...)``.
        num_classes: Forwarded to ``model_class``.

    Returns:
        The model, moved to CUDA when available (CPU otherwise) and switched
        to eval mode.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    network = model_class(num_classes=num_classes)
    # map_location remaps the stored tensors onto the device chosen above.
    weights = torch.load(model_path, map_location=target_device)
    network.load_state_dict(weights)

    network.to(target_device)
    network.eval()
    return network

# Inference & evaluation
def perform_inference_and_evaluate(model, test_json_dir, image_path, transform, threshold, true_labels, predicted_labels):
    """Classify one image and record the prediction next to its ground truth.

    Args:
        model: Two-class classifier; its parameters determine the device.
        test_json_dir: Directory holding ``<image stem>.json`` annotations.
        image_path: Path of the image to classify.
        transform: Callable turning a PIL RGB image into a CHW tensor.
        threshold: The image is predicted positive when the softmax
            probability of class 1 is strictly greater than this value.
            ``None`` falls back to picking the most probable class.
        true_labels: List that receives the ground-truth label (int).
        predicted_labels: List that receives the prediction (bool).

    Returns:
        ``(prediction, true_label)`` where ``prediction`` is a bool and
        ``true_label`` is a scalar float32 tensor holding 0 or 1.
    """
    device = next(model.parameters()).device

    # The context manager closes the image file once it has been converted.
    with Image.open(image_path) as pil_image:
        image = transform(pil_image.convert("RGB")).unsqueeze(0).to(device)

    # Swap only the extension; str.replace would rewrite every '.png' in the name.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]
    new_json_path = os.path.join(test_json_dir, image_stem + '.json')

    with torch.no_grad():
        output = model(image)

    # Softmax turns the two outputs into class probabilities.
    # NOTE(review): if the model already applies a sigmoid (ResNet.forward in
    # resnet.py does), these probabilities stay within roughly [0.27, 0.73],
    # so thresholds outside that range give a constant prediction.
    probabilities = F.softmax(output, dim=1)
    if threshold is None:
        _, prediction_class = torch.max(probabilities, 1)
        prediction = bool(prediction_class.item())
    else:
        # At the default threshold of 0.5 this matches the arg-max decision.
        prediction = probabilities[0, 1].item() > threshold

    # Ground truth: positive when any tooth is marked as decayed.
    with open(new_json_path, 'r', encoding='utf-8') as json_file:
        json_data = json.load(json_file)
    true_decayed_values = [tooth["decayed"] for tooth in json_data['tooth']]
    true_label = torch.tensor(int(any(true_decayed_values)), dtype=torch.float32)

    # Append both entries together so the two lists stay aligned even if
    # reading the annotation raises.
    predicted_labels.append(prediction)
    true_labels.append(int(true_label))

    return prediction, true_label
def create_predictions_list_and_evaluate(model, test_json_dir, image_dir, transform, threshold):
    """Run inference on every ``.png`` in ``image_dir`` and collect the results.

    Args:
        model: Classifier handed to ``perform_inference_and_evaluate``.
        test_json_dir: Directory with the ground-truth JSON files.
        image_dir: Directory with the test images.
        transform: Image-to-tensor transform applied before inference.
        threshold: Decision threshold forwarded to the per-image helper.

    Returns:
        ``(predictions, true_labels, predicted_labels)``: predictions as
        ints, ground-truth labels as ints, and predictions as bools, all in
        sorted file-name order.
    """
    predictions = []
    true_labels = []
    predicted_labels = []

    # List the directory once. os.listdir returns entries in arbitrary order,
    # so sort them to make the position of each prediction reproducible.
    entries = sorted(os.listdir(image_dir))
    total = len(entries)

    for index, image_file in enumerate(entries):
        if not image_file.endswith('.png'):
            continue

        image_path = os.path.join(image_dir, image_file)
        pred_label, true_label = perform_inference_and_evaluate(
            model, test_json_dir, image_path, transform, threshold, true_labels, predicted_labels
        )
        predictions.append(int(pred_label))

        print(f"Index: {index}/{total}, true label: {int(true_label)}, pred_label:{int(pred_label)}")

    return predictions, true_labels, predicted_labels

# Script entry point: run inference on the test set, then store the predictions
# (JSON) and the precision / recall / F1 report (text).
if __name__ == "__main__":
    print("start!")
    parser = argparse.ArgumentParser(description='Tooth Decay Model Inferencne')
    parser.add_argument('--model_path', type=str, required=True, help='모델 파일(.pt) 경로')
    parser.add_argument('--threshold', type=float, default=0.5, help='임계값')
    parser.add_argument('--output_json_dir', type=str, default='/qorskawls12/predict/', help='JSON 파일 결과를 저장할 디렉토리')
    parser.add_argument('--test_image_dir', type=str, default='/qorskawls12/Dataset/test_data/image/', help='testdata/image/* 경로')
    parser.add_argument('--test_json_dir', type=str, default='/qorskawls12/Dataset/test_data/json/', help='testdata/json/* 경로')
    parser.add_argument('--output_txt_dir', type=str, default='/qorskawls12/output_txt/', help='텍스트 파일 경로')
    args = parser.parse_args()

    print("Inference start")
    # Load the model.
    model = load_model(args.model_path, resnet50_binary)

    # Preprocessing applied to every test image.
    inference_transform = transforms.Compose([
        transforms.Resize((900, 1500)),
        transforms.ToTensor(),
    ])

    # Run inference and collect predictions and ground truth.
    predictions, true_labels, predicted_labels = create_predictions_list_and_evaluate(
        model, args.test_json_dir, args.test_image_dir, inference_transform, args.threshold
    )

    # Output file names are derived from the checkpoint file name.
    model_filename = os.path.basename(args.model_path)
    model_stem = os.path.splitext(model_filename)[0]

    # Create the output directories so the writes below cannot fail on a
    # missing directory after inference has already finished.
    os.makedirs(args.output_json_dir, exist_ok=True)
    os.makedirs(args.output_txt_dir, exist_ok=True)

    # Save the predictions as JSON.
    output_json_path = os.path.join(args.output_json_dir, model_stem + '_predictions.json')
    with open(output_json_path, 'w') as json_file:
        json.dump({"predict": predictions}, json_file, indent=2)

    # With average='binary' the metrics come back as plain scalars for the
    # positive class, so they must not be indexed. The F1 value is bound to
    # `f1` so the imported f1_score function is not shadowed.
    precision, recall, f1, _ = precision_recall_fscore_support(
        true_labels, [int(pred) for pred in predicted_labels], pos_label=1, average='binary'
    )

    precision_line = f'Precision for class 1 (decayed 양성)): {precision:.2f}'
    recall_line = f'Recall for class 1 (decayed 양성): {recall:.2f}'
    f1_line = f'F1 Score for class 1 (decayed 양성): {f1:.2f}'
    print(precision_line)
    print(recall_line)
    print(f1_line)

    # The report is named after the checkpoint file. The explicit encoding
    # keeps the Korean text intact regardless of the platform default.
    output_txt_path = os.path.join(args.output_txt_dir, model_stem + '_output.txt')
    with open(output_txt_path, 'w', encoding='utf-8') as txt_file:
        txt_file.write(precision_line + '\n')
        txt_file.write(recall_line + '\n')
        txt_file.write(f1_line + '\n')

        end_time = time.time()
        elapsed_time = (end_time - start_time) / 60
        finish_line = f"Finish inference & f1-score(성능 : {f1:.2f})! \n 걸린 시간 : {elapsed_time:.2f} 분"
        print(finish_line)
        txt_file.write(finish_line + "\n")
        txt_file.write("\n\n\n--------------예측값과 실제값 레이블 출력 결과----------------\n\n\n")
        for index, (true_label, pred_label) in enumerate(zip(true_labels, predicted_labels)):
            txt_file.write(f"Index: {index}, true label: {true_label}, pred_label: {int(pred_label)}\n")

* Precision for class 1 (decayed 양성)): 1.0000
* Recall for class 1 (decayed 양성): 0.9990
* F1 Score for class 1 (decayed 양성): 0.9995
* Finish inference & f1-score(성능 : 0.9995)!
    - 걸린 시간 : 6.51 분