# Training

### Import Libraries

In [10]:
import os
import easydict
from PIL import Image as im
from PIL import ImageOps
import cv2
from tqdm import tqdm
import json
import sys
import shutil
from datetime import datetime
import pandas as pd
import random

op = os.path.join

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader
from torch.optim import lr_scheduler

In [12]:
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, roc_auc_score, roc_curve
from sklearn.metrics import precision_score, recall_score
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from tensorboardX import SummaryWriter

### Set Directory Path

In [13]:
DATA_DIR = "/hoeunlee228/Dataset/"

TRAIN_DATASET_DIR = op(DATA_DIR, "train_odata")
TRAIN_IMAGE_DIR = op(TRAIN_DATASET_DIR, "image")
TRAIN_JSON_DIR = op(TRAIN_DATASET_DIR, "json")

TEST_DATASET_DIR = op(DATA_DIR, "test_odata")
TEST_IMAGE_DIR = op(TEST_DATASET_DIR, "image")
TEST_JSON_DIR = op(TEST_DATASET_DIR, "json")

current_datetime = datetime.now()
formatted_datetime = current_datetime.strftime('%Y-%m-%d_%H_%M')
print(formatted_datetime)

os.makedirs(op("/hoeunlee228/logs", formatted_datetime), exist_ok=True)
os.makedirs(op("/hoeunlee228/weights", formatted_datetime), exist_ok=True)

2023-12-15_15_41


### Read File List

In [14]:
train_file_list = list(map(lambda x : x.split(".")[0], os.listdir(TRAIN_IMAGE_DIR)))
train_image_list = os.listdir(TRAIN_IMAGE_DIR)
train_json_list = os.listdir(TRAIN_JSON_DIR)
print("number of train file:", len(train_file_list))

test_file_list = list(map(lambda x : x.split(".")[0], os.listdir(TEST_IMAGE_DIR)))
test_image_list = os.listdir(TEST_IMAGE_DIR)
test_json_list = os.listdir(TEST_JSON_DIR)
print("number of train file:", len(test_file_list))

number of train file: 338719
number of train file: 37489


### Undersampling

In [15]:
train_df = pd.read_csv("/hoeunlee228/Dataset/train_df.csv")
# print(train_df.head())
train_df['is_decayed'].value_counts()

not_decayed_rows = train_df[train_df['is_decayed'] == False]
decayed_rows = train_df[train_df['is_decayed'] == True]
print("total not decayed rows:", len(not_decayed_rows), "total decayed rows", len(decayed_rows))

num_samples_not_decayed = 28136 * 3
num_samples_decayed = 28136

print("args num sampled not decayed:", num_samples_not_decayed, "args num sampled decayed:", num_samples_decayed)

random_samples_not_decayed = random.sample(range(len(not_decayed_rows)), num_samples_not_decayed)
not_decayed_sampled_df = not_decayed_rows.iloc[random_samples_not_decayed]
print("num of not decayed sampled list", len(not_decayed_sampled_df))

random_samples_decayed = random.sample(range(len(decayed_rows)), num_samples_decayed)
decayed_sampled_df = decayed_rows.iloc[random_samples_decayed]
print("num of decayed sampled list", len(decayed_sampled_df))

train_file_list_not_decayed_sampled = list(map(lambda x : f"{x[0]}_{x[1]}", not_decayed_sampled_df[['file', 'teeth_idx']].values.tolist()))
train_file_list_decayed_sampled = list(map(lambda x : f"{x[0]}_{x[1]}", decayed_sampled_df[['file', 'teeth_idx']].values.tolist()))
print("num of sampled file list (not decayed, decayed):", len(train_file_list_not_decayed_sampled), len(train_file_list_decayed_sampled))

sampled_train_list = train_file_list_not_decayed_sampled + train_file_list_decayed_sampled
print("total num of sampled train list:", len(sampled_train_list))

total not decayed rows: 310583 total decayed rows 28136
args num sampled not decayed: 84408 args num sampled decayed: 28136
num of not decayed sampled list 84408
num of decayed sampled list 28136
num of sampled file list (not decayed, decayed): 84408 28136
total num of sampled train list: 112544


### Custom Dataset

In [16]:
class ToothDataset(Dataset):
    def __init__(self, data_dir, file_list, transform=None, aug_transform=None):
        self.data_dir = data_dir
        self.file_list = file_list
        self.transform = transform
        self.aug_transform = aug_transform

    def __len__(self):
        return len(self.file_list)
    
    def _load_image(self, image_path):
        assert os.path.exists(op(self.data_dir, "image", image_path))
        # return cv2.cvtColor(cv2.imread(op(self.data_dir, "image", image_path)), cv2.COLOR_BGR2RGB)
        return im.open(op(self.data_dir, "image", image_path)).convert("RGB")
    
    def __getitem__(self, index):
        image_path = self.file_list[index] + ".png"
        json_path = self.file_list[index] + ".json"

        image = self._load_image(image_path)

        with open(op(self.data_dir, "json", json_path), 'r') as json_file:
            data = json.load(json_file)

        decayed = data["tooth"][0]["decayed"]
        target = 1 if decayed else 0

        if self.transform:
            image = self.transform(image)

        if self.aug_transform:
            image = self.aug_transform(image)

        return image, target

In [17]:
class SeparableConv2d(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
        super(SeparableConv2d,self).__init__()

        self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
        self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)

    def forward(self,x):
        x = self.conv1(x)
        x = self.pointwise(x)
        return x


class Block(nn.Module):
    def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True):
        super(Block, self).__init__()

        if out_filters != in_filters or strides!=1:
            self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip=None

        self.relu = nn.ReLU(inplace=True)
        rep=[]

        filters=in_filters
        if grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))
            filters = out_filters

        for i in range(reps-1):
            rep.append(self.relu)
            rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(filters))

        if not grow_first:
            rep.append(self.relu)
            rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=1,bias=False))
            rep.append(nn.BatchNorm2d(out_filters))

        if not start_with_relu:
            rep = rep[1:]
        else:
            rep[0] = nn.ReLU(inplace=False)

        if strides != 1:
            rep.append(nn.MaxPool2d(3,strides,1))
        self.rep = nn.Sequential(*rep)

    def forward(self,inp):
        x = self.rep(inp)

        if self.skip is not None:
            skip = self.skip(inp)
            skip = self.skipbn(skip)
        else:
            skip = inp

        x+=skip
        return x


class Xception(nn.Module):
    def __init__(self, num_classes=1000):
        super(Xception, self).__init__()
        self.num_classes = num_classes

        self.conv1 = nn.Conv2d(3,32,3,2,0,bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

    def features(self, input):
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def logits(self, features):
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1)) 
        x = x.view(x.size(0), -1)
        x = self.last_linear(x)
        return x

    def forward(self, input):
        x = self.features(input)
        x = self.logits(x)
        return x


## 기존 Xception에 Dropout만 추가
class xception(nn.Module):
    def __init__(self, num_out_classes=2, dropout=0.5):
        super(xception, self).__init__()

        self.model = Xception(num_classes=num_out_classes)
        self.model.last_linear = self.model.fc
        del self.model.fc

        num_ftrs = self.model.last_linear.in_features
        if not dropout:
            self.model.last_linear = nn.Linear(num_ftrs, num_out_classes)
        else:            
            self.model.last_linear = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(num_ftrs, num_out_classes)
            )

    def forward(self, x):
        x = self.model(x)
        return x

In [18]:
cudnn.benchmark = True

args = easydict.EasyDict({
    "gpu": 0,
    "num_workers": 4,
    "root": "/hoeunlee228/Dataset/",
    "learning_rate": 1e-4,
    "num_epochs": 50,
    "batch_size": 32,

    "save_fn": f"/hoeunlee228/weights/{formatted_datetime}/xception",
    "load_fn": "/hoeunlee228/weights/2023-12-15_00_43/xception_epoch0.pth",
    "scheduler": None,

    "scheduler_step": 1,
    "scheduler_gamma": 0.001
})

In [19]:
# 1. Zero padding to make the image square
def make_square(img):
    width, height = img.size
    max_side = max(width, height)
    left = (max_side - width) // 2
    top = (max_side - height) // 2
    right = (max_side - width) - left
    bottom = (max_side - height) - top
    padding = (left, top, right, bottom)
    return ImageOps.expand(img, padding)

In [20]:
mean = (0.57933619, 0.42688786, 0.33401168)
std = (0.35580848, 0.27125023, 0.22251333)

"""train_transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

test_transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

valid_transform = transforms.Compose([
    transforms.Resize(size=(299, 299)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])"""

train_aug_transform = transforms.Compose([
    # 1. 이미지를 정사각형으로 만들기
    transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),  # Exif 정보 처리
    transforms.Lambda(make_square),

    # 2. Resize
    transforms.Resize(size=(299, 299)),

    # 3. Augmentation
    
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1), shear=10),
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.1),

    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

test_aug_transform = transforms.Compose([
    # 1. 이미지를 정사각형으로 만들기
    transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),  # Exif 정보 처리
    transforms.Lambda(make_square),

    # 2. Resize
    transforms.Resize(size=(299, 299)),
    
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

valid_aug_transform = transforms.Compose([
    # 1. 이미지를 정사각형으로 만들기
    transforms.Lambda(lambda img: ImageOps.exif_transpose(img)),  # Exif 정보 처리
    transforms.Lambda(make_square),

    # 2. Resize
    transforms.Resize(size=(299, 299)),

    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

In [21]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    torch.save(state, filename)
    print("pth file saved at " + filename)
    if is_best:
        shutil.copyfile(filename, f"/hoeunlee228/weights/{formatted_datetime}/model_best.pth.tar")

def adjust_learning_rate(optimizer, epoch, args):
    """Sets the learning rate to the initial LR decayed by 10 every 30 epochs"""
    lr = args.lr * (0.1 ** (epoch // 30))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

In [22]:
# train / validate
def train(train_loader, model, criterion, optimizer, epoch, writer, step):   
    n = 0
    running_loss = 0.0
    running_corrects = 0

    all_targets = []
    all_preds = []

    error_count = {
        "precision": 0,
        "recall": 0,
        "auc": 0
    }

    model.train()

    with tqdm(train_loader, total=len(train_loader), desc="Train", file=sys.stdout) as iterator:
        for images, target in iterator:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            outputs = model(images)
            _, pred = torch.max(outputs.data, 1)

            loss = criterion(outputs, target)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            n += images.size(0)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(pred == target.data)

            epoch_loss = running_loss / float(n)
            epoch_acc = running_corrects / float(n)

            all_targets.extend(target.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())

            # Calculate F1-Score
            f1 = f1_score(all_targets, all_preds, average='weighted')

            # Calculate Precision and Recall
            try:
                precision = precision_score(all_targets, all_preds, average='weighted', zero_division=1)
            except Exception as e:
                precision = -1.0
                error_count["precision"] += 1

            try:
                recall = recall_score(all_targets, all_preds, average='weighted')
            except Exception as e:
                recall = -1.0
                error_count["recall"] += 1

            # Calculate AUC
            # fpr, tpr, _ = roc_curve(target.cpu(), outputs.data[:, 1].cpu())
            # auc_value = roc_auc_score(target.cpu(), outputs.data[:, 1].cpu())
            try:
                # fpr, tpr, _ = roc_curve(all_targets, all_preds)
                auc_value = roc_auc_score(all_targets, all_preds)
            except Exception as e:
                auc_value = -1.0
                error_count["auc"] += 1

            log = 'loss - {:.4f}, acc - {:.4f}, F1 - {:.4f}, Precision - {:.4f}, Recall - {:.4f}, AUC - {:.4f}'.format(epoch_loss, epoch_acc, f1, precision, recall, auc_value)
            iterator.set_postfix_str(log)

            if step % 10 == 0:
                writer.add_scalar("Train/Step/Loss", epoch_loss, step)
                writer.add_scalar("Train/Step/Accuracy", epoch_acc, step)
                writer.add_scalar("Train/Step/F1-Score", f1, step)
                writer.add_scalar("Train/Step/Precision", precision, step)
                writer.add_scalar("Train/Step/Recall", recall, step)
                writer.add_scalar("Train/Step/AUC", auc_value, step)

            step += 1

    writer.add_scalar("Train/Epoch/Loss", epoch_loss, epoch)
    writer.add_scalar("Train/Epoch/Accuracy", epoch_acc, epoch)
    writer.add_scalar("Train/Epoch/F1-Score", f1, epoch)
    writer.add_scalar("Train/Epoch/Precision", precision, epoch)
    writer.add_scalar("Train/Epoch/Recall", recall, epoch)
    writer.add_scalar("Train/Epoch/AUC", auc_value, epoch)

    print(error_count)

    # scheduler.step()

    return step


def validate(test_loader, model, criterion, epoch, writer):
    n = 0
    running_loss = 0.0
    running_corrects = 0

    all_targets = []
    all_preds = []

    error_count = {
        "precision": 0,
        "recall": 0,
        "auc": 0
    }

    model.eval()

    with tqdm(test_loader, total=len(test_loader), desc="Valid", file=sys.stdout) as iterator:
        for images, target in iterator:
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            with torch.no_grad():
                output = model(images)

            loss = criterion(output, target)
            _, pred = torch.max(output.data, 1)

            n += images.size(0)
            running_loss += loss.item() * images.size(0)
            running_corrects += torch.sum(pred == target.data)

            epoch_loss = running_loss / float(n)
            epoch_acc = running_corrects / float(n)

            all_targets.extend(target.cpu().numpy())
            all_preds.extend(pred.cpu().numpy())

            # Calculate F1-Score
            f1 = f1_score(all_targets, all_preds, average='weighted')

            try:
                precision = precision_score(all_targets, all_preds, average='weighted', zero_division=1)
            except Exception as e:
                precision = -1.0
                error_count["precision"] += 1
            
            try:
                recall = recall_score(all_targets, all_preds, average='weighted')
            except Exception as e:
                recall = -1.0
                error_count["recall"] += 1

            # Calculate AUC
            # fpr, tpr, _ = roc_curve(target.cpu(), output.data[:, 1].cpu())
            # auc_value = roc_auc_score(target.cpu(), output.data[:, 1].cpu())
            try:
                # fpr, tpr, _ = roc_curve(all_targets, all_preds)
                auc_value = roc_auc_score(all_targets, all_preds)
            except Exception as e:
                auc_value = -1.0
                error_count["auc"] += 1

            log = 'loss - {:.4f}, acc - {:.4f}, F1 - {:.4f}, Precision - {:.4f}, Recall - {:.4f}, AUC - {:.4f}'.format(epoch_loss, epoch_acc, f1, precision, recall, auc_value)
            iterator.set_postfix_str(log)

    writer.add_scalar("Validation/Epoch/Loss", epoch_loss, epoch)
    writer.add_scalar("Validation/Epoch/Accuracy", epoch_acc, epoch)
    writer.add_scalar("Validation/Epoch/F1-Score", f1, epoch)
    writer.add_scalar("Validation/Epoch/Precision", precision, epoch)
    writer.add_scalar("Validation/Epoch/Recall", recall, epoch)
    writer.add_scalar("Validation/Epoch/AUC", auc_value, epoch)

    print(error_count)

    return epoch_acc

In [23]:
# TensorboardX를 사용하여 summary writer 생성
train_writer = SummaryWriter(f"/hoeunlee228/logs/{formatted_datetime}/train")  # 'logs/train'는 로그가 저장될 디렉토리입니다.
test_writer = SummaryWriter(f"/hoeunlee228/logs/{formatted_datetime}/validation")  # 'logs/validation'은 로그가 저장될 디렉토리입니다.

In [24]:
model = xception(num_out_classes=2, dropout=0.5)
print("=> creating model '{}'".format('xception'))
model = model.cuda(args.gpu)

=> creating model 'xception'


In [25]:
if args.load_fn is not None:
    assert os.path.isfile(args.load_fn), 'wrong path'

    model.load_state_dict(torch.load(args.load_fn)['state_dict'])
    print("=> model weight '{}' is loaded".format(args.load_fn))

=> model weight '/hoeunlee228/weights/2023-12-15_00_43/xception_epoch0.pth' is loaded


In [26]:
criterion = nn.CrossEntropyLoss().cuda()
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate, betas=(0.9, 0.999), eps=1e-08)

In [27]:
if args.scheduler == "steplr":
    print("StepLR selected to Scheduler")
    scheduler = lr_scheduler.StepLR(optimizer, 
                                    step_size=args.scheduler_step,
                                    gamma=args.scheduler_gamma)
elif args.scheduler == 'exp':
    print("ExponentialLR selected to Scheduler")
    scheduler = lr_scheduler.ExponentialLR(optimizer, 
                                    gamma=args.scheduler_gamma)

In [28]:
train_file_list = list(map(lambda x : x.split(".")[0], os.listdir("/hoeunlee228/Dataset/train_odata/image")))
test_file_list = list(map(lambda x : x.split(".")[0], os.listdir("/hoeunlee228/Dataset/test_odata/image")))
print(len(train_file_list), len(test_file_list))

train_dataset = ToothDataset(data_dir="/hoeunlee228/Dataset/train_odata",
                             file_list=sampled_train_list,
                             transform=train_aug_transform)

valid_dataset = ToothDataset(data_dir="/hoeunlee228/Dataset/test_odata",
                             file_list=test_file_list,
                             transform=test_aug_transform)

338719 37489


In [29]:
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True,
                                           num_workers=args.num_workers,
                                           pin_memory=True,
                                           )

valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=False,
                                           num_workers=args.num_workers,
                                           pin_memory=False,
                                           )

In [30]:
print("checking dataset and dataloader is okay...")

try:
    _ = next(iter(train_loader))
    print("ok")
except Exception as e:
    print(e)
    print("dataset or dataloader is not ok")

checking dataset and dataloader is okay...
ok


In [24]:
print("start training...")
step = 0

for epoch in range(args.num_epochs):
    print('-' * 50)
    print('Epoch {}/{}'.format(epoch, args.num_epochs))
    step = train(train_loader, model, criterion, optimizer, epoch, train_writer, step)
    acc = validate(valid_loader, model, criterion, epoch, test_writer)

    save_checkpoint(state={'epoch': epoch,
                           'state_dict': model.state_dict(),
                           'best_acc1': acc,
                           'optimizer': optimizer.state_dict(),},
                        is_best=False,
                        filename=args.save_fn + f"_epoch{epoch}.pth",
                    )

start training...
--------------------------------------------------
Epoch 1/50
Train:   4%|▎         | 125/3517 [01:50<49:53,  1.13it/s, loss - 0.1950, acc - 0.916, F1 - 0.915, Precision - 0.915, Recall - 0.916, AUC - 0.877]


KeyboardInterrupt: 

In [31]:
train_writer.close()
test_writer.close()

# Inference

In [32]:
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score

In [50]:
args = easydict.EasyDict({    
    "gpu": 0,
    # 데이터 위치
    "root": "/hoeunlee228/Dataset/test_odata/",
    # pth
    "save_fn": "/hoeunlee228/weights/2023-12-15_01_57/xception_epoch5.pth"
})

assert os.path.isfile(args.save_fn), 'wrong path'

In [51]:
model = xception(num_out_classes=2, dropout=0.5)
print("=> creating model '{}'".format('xception'))
model = model.cuda(args.gpu)

assert os.path.isfile(args.save_fn), 'wrong path'

model.load_state_dict(torch.load(args.save_fn)['state_dict'])
print("=> model weight '{}' is loaded".format(args.save_fn))

model = model.eval()

=> creating model 'xception'
=> model weight '/hoeunlee228/weights/2023-12-15_01_57/xception_epoch5.pth' is loaded


In [52]:
test_df = pd.read_csv("/hoeunlee228/test_df.csv", index_col=0)
test_df.head()

Unnamed: 0,file,teeth_idx,num_seg,is_decayed,bound_x,bound_size_x,bound_y,bound_size_y,data_type
0,front_100,11,974,False,"(561, 925)",364,"(0, 396)",396,test
1,front_100,12,790,False,"(312, 573)",261,"(35, 368)",333,test
2,front_100,13,713,False,"(101, 329)",228,"(69, 411)",342,test
3,front_100,14,459,False,"(0, 167)",167,"(183, 413)",230,test
4,front_100,21,929,False,"(909, 1281)",372,"(48, 408)",360,test


In [53]:
# collect all images
print(len(test_file_list))

# prepare pred dict
test_predict_dict = {}
image_inf_list = []

for f in os.listdir("/hoeunlee228/Dataset/test_data/image/"):
    test_predict_dict[f] = []

# predict label
m = nn.Softmax()

with torch.no_grad():
    for tf in tqdm(test_file_list):
        image_path = op(TEST_DATASET_DIR, "image", tf + ".png")
        image = im.open(image_path)
        image = valid_aug_transform(image)
        image = torch.unsqueeze(image, dim=0)
        image = image.cuda(args.gpu, non_blocking=True)

        original_f = tf.split("_")[0] + "_" + tf.split("_")[1] + ".png"
        file_name = tf.split("_")[0] + "_" + tf.split("_")[1]
        teeth_idx = int(tf.split("_")[2])

        output = model(image)
        output = m(output)[0]  # apply softmax

        # 0 = not decayed
        # write to submission file
        if output[0] > output[1]:
            test_predict_dict[original_f].append(False)
            image_inf_list.append([file_name, teeth_idx, False])

        # 1 = decayed
        else:
            test_predict_dict[original_f].append(True)
            image_inf_list.append([file_name, teeth_idx, True])

print("predicted finished:", len(test_predict_dict.keys()))
print("len of preds:", len(image_inf_list))

37489


  output = m(output)[0]  # apply softmax
100%|██████████| 37489/37489 [20:09<00:00, 31.00it/s]

predicted finished: 3000
len of preds: 37489





In [54]:
correct_count_image = 0
inf_list_image = []

for row in tqdm(image_inf_list):
    # name, pred, answer, same
    file_name = row[0]
    teeth_idx = row[1]
    f_pred = row[2]
    # print(file_name, teeth_idx)
    f_answer = test_df[(test_df['file'] == file_name) & (test_df['teeth_idx'] == teeth_idx)]['is_decayed'].values[0]
    correct = (f_pred == f_answer)
    # print(file_name, f_pred, f_answer, correct)
    inf_list_image.append([file_name, f_pred, f_answer, correct])

    if correct:
        correct_count_image += 1

print(f"{correct_count_image}/{len(inf_list_image)}")

100%|██████████| 37489/37489 [02:21<00:00, 264.05it/s]

37272/37489





In [55]:
# Extract labels and predictions
labels_image = [item[2] for item in inf_list_image]
predictions_image = [item[1] for item in inf_list_image]

In [56]:
print(labels_image.count(0) + labels_image.count(1))
print(labels_image.count(0), labels_image.count(1))

print(predictions_image.count(0) + predictions_image.count(1))
print(predictions_image.count(0), predictions_image.count(1))

# Convert boolean values to integers (True -> 1, False -> 0)
labels_image = [int(label) for label in labels_image]
predictions_image = [int(prediction) for prediction in predictions_image]

37489
36489 1000
37489
36612 877


In [57]:
# Calculate metrics
accuracy_image = accuracy_score(labels_image, predictions_image)
precision_image = precision_score(labels_image, predictions_image)
recall_image = recall_score(labels_image, predictions_image)
auc_image = roc_auc_score(labels_image, predictions_image)
f1_image = f1_score(labels_image, predictions_image)

print("Accuracy:", accuracy_image)
print("Precision:", precision_image)
print("Recall:", recall_image)
print("AUC:", auc_image)
print("F1 Score:", f1_image)

Accuracy: 0.9942116354130545
Precision: 0.9464082098061574
Recall: 0.83
AUC: 0.9143559702924169
F1 Score: 0.8843899840170485


In [58]:
# Calculate confusion matrix
conf_matrix_image = confusion_matrix(labels_image, predictions_image)
print(conf_matrix_image)

# Extract TP, FP, TN, FN
tp = conf_matrix_image[1, 1]
fp = conf_matrix_image[0, 1]
tn = conf_matrix_image[0, 0]
fn = conf_matrix_image[1, 0]

print("True Positive (TP):", tp)
print("False Positive (FP):", fp)
print("True Negative (TN):", tn)
print("False Negative (FN):", fn)

[[36442    47]
 [  170   830]]
True Positive (TP): 830
False Positive (FP): 47
True Negative (TN): 36442
False Negative (FN): 170


In [59]:
test_image_df = pd.read_csv("/hoeunlee228/test_image_df.csv")
print("total num of test data:", len(test_image_df))
test_image_df['decayed'].value_counts()

total num of test data: 3000


False    2000
True     1000
Name: decayed, dtype: int64

In [60]:
correct_count = 0
inf_list = []

for index, row in tqdm(test_image_df.iterrows()):
    # name, pred, answer, same
    file_name = row['name']
    f_answer = row['decayed']
    f_pred = any(test_predict_dict[row['name'] + ".png"])
    correct = (f_pred == f_answer)
    # print(file_name, f_pred, f_answer, correct)
    inf_list.append([file_name, f_pred, f_answer, correct])

    if correct:
        correct_count += 1

print(f"{correct_count}/{len(test_image_df)}")

3000it [00:00, 18590.67it/s]

2833/3000





In [61]:
# Extract labels and predictions
labels = [item[2] for item in inf_list]
predictions = [item[1] for item in inf_list]

# Convert boolean values to integers (True -> 1, False -> 0)
labels = [int(label) for label in labels]
predictions = [int(prediction) for prediction in predictions]

# Calculate metrics
accuracy = accuracy_score(labels, predictions)
precision = precision_score(labels, predictions)
recall = recall_score(labels, predictions)
auc = roc_auc_score(labels, predictions)
f1 = f1_score(labels, predictions)

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("AUC:", auc)
print("F1 Score:", f1)

Accuracy: 0.9443333333333334
Precision: 0.997610513739546
Recall: 0.835
AUC: 0.9169999999999999
F1 Score: 0.9090909090909092


In [62]:
# Calculate confusion matrix
conf_matrix = confusion_matrix(labels, predictions)
print(conf_matrix)

# Extract TP, FP, TN, FN
tp = conf_matrix[1, 1]
fp = conf_matrix[0, 1]
tn = conf_matrix[0, 0]
fn = conf_matrix[1, 0]

print("True Positive (TP):", tp)
print("False Positive (FP):", fp)
print("True Negative (TN):", tn)
print("False Negative (FN):", fn)

[[1998    2]
 [ 165  835]]
True Positive (TP): 835
False Positive (FP): 2
True Negative (TN): 1998
False Negative (FN): 165
