In [1]:
import os, torch, PIL, random, timm
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torchvision.transforms as transforms
import data

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

  from .autonotebook import tqdm as notebook_tqdm


cuda


In [2]:
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED']=str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic=True
    torch.backends.cudnn.benchmark=True

In [3]:
control_group = "NONSJS"

CFG={'SEED' : 46,  # 42~46
     'IMG_SIZE' : 224,
     'TEST_PORTION' : 0.5,  # Test set 비율
     'pg_res_path' : f"E:/model_save_path/SJS/{control_group}_save_path/{control_group}_5x_test50(PG_Res)_seed46.pt",
     'pg_vgg_path' : f"E:/model_save_path/SJS/{control_group}_save_path/{control_group}_5x_test50(PG_VGG)_seed46.pt",
     'pg_inc_path' : f"E:/model_save_path/SJS/{control_group}_save_path/{control_group}_5x_test50(PG_Inception)_seed46.pt",
     'sg_res_path' : f"E:/model_save_path/SJS/{control_group}_save_path/{control_group}_5x_test50(SG_Res)_seed46.pt",
     'sg_vgg_path' : f"E:/model_save_path/SJS/{control_group}_save_path/{control_group}_5x_test50(SG_VGG)_seed46.pt",
     'sg_inc_path' : f"E:/model_save_path/SJS/{control_group}_save_path/{control_group}_5x_test50(SG_Inception)_seed46.pt",
     'EPOCHS' : 15,
     'BATCH_SIZE' : 4,
     'LR' : 1e-4}

In [4]:
data_root = "D:\\Datasets\\SJS\\Processed"

SJS_path = data.path_by_diagnosis(data_root, "SJS")
CTR_path = data.path_by_diagnosis(data_root, control_group)
SJS_ID = data.ID_summary(SJS_path)
CTR_ID = data.ID_summary(CTR_path)

seed_everything(CFG["SEED"])

test_SJS_idx = random.sample(range(len(SJS_ID)), len(SJS_ID)//2)
test_CTR_idx = random.sample(range(len(CTR_ID)), len(CTR_ID)//2)
test_SJS_ID = [SJS_ID[i] for i in test_SJS_idx]
test_CTR_ID = [CTR_ID[i] for i in test_CTR_idx]

test_SJS_path = data.path_by_IDs(test_SJS_ID, SJS_path)
test_CTR_path = data.path_by_IDs(test_CTR_ID, CTR_path)

total_ID = test_SJS_ID + test_CTR_ID
total_path = test_SJS_path + test_CTR_path

print(len(SJS_path)+len(CTR_path))
print(len(total_path))

4851
232


In [6]:
resize = transforms.Compose([
    transforms.Resize((CFG['IMG_SIZE'], CFG['IMG_SIZE'])),
    transforms.ToTensor()
])

In [7]:
class ResNet(nn.Module):
    def __init__(self, classes=2):
        super().__init__()
        self.model = timm.create_model("resnet50", pretrained=True)
        self.model.fc = nn.Linear(in_features=2048, out_features=classes, bias=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.model(x)
        x = self.softmax(x)
        
        return x

class VGG(nn.Module):
    def __init__(self, classes=2):
        super().__init__()
        self.model = timm.create_model("vgg16_bn", pretrained=True)
        self.model.head.fc = nn.Linear(in_features=4096, out_features=classes, bias=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.model(x)
        x = self.softmax(x)
        
        return x

class Inception(nn.Module):
    def __init__(self, classes=2):
        super().__init__()
        self.model = timm.create_model("inception_v3", pretrained=True)
        self.model.fc = nn.Linear(in_features=2048, out_features=classes, bias=True)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        x = self.model(x)
        x = self.softmax(x)
        
        return x

In [8]:
PG_ResNet = ResNet().to(device)
PG_VGG = VGG().to(device)
PG_Inception = Inception().to(device)
SG_ResNet = ResNet().to(device)
SG_VGG = VGG().to(device)
SG_Inception = Inception().to(device)

In [9]:
PG_ResNet.load_state_dict(torch.load(CFG["pg_res_path"], map_location="cuda"))
PG_VGG.load_state_dict(torch.load(CFG["pg_vgg_path"], map_location="cuda"))
PG_Inception.load_state_dict(torch.load(CFG["pg_inc_path"], map_location="cuda"))
SG_ResNet.load_state_dict(torch.load(CFG["sg_res_path"], map_location="cuda"))
SG_VGG.load_state_dict(torch.load(CFG["sg_vgg_path"], map_location="cuda"))
SG_Inception.load_state_dict(torch.load(CFG["sg_inc_path"], map_location="cuda"))
PG_ResNet.eval()
PG_VGG.eval()
PG_Inception.eval()
SG_ResNet.eval()
SG_VGG.eval()
SG_Inception.eval()

# total_preds, total_labels = [], []
# for idx, ID in enumerate(total_ID):
#     curr_datalist = total_path[idx]
#     curr_label = 0 if ID in CTR_ID else 1
#     output_list, model_list = [], []
#     for datapath in curr_datalist:
#         curr_data = resize(PIL.Image.open(datapath)).to(torch.float32).unsqueeze(0).to(device)
#         data_gland = datapath.split("\\")[-1].split("_")[0]
#         with torch.no_grad():
#             if data_gland == "PTG":
#                 res_out = PG_ResNet(curr_data)[0][0].detach().cpu().item()
#                 vgg_out = PG_VGG(curr_data)[0][0].detach().cpu().item()
#                 inc_out = PG_Inception(curr_data)[0][0].detach().cpu().item()
#                 max_output = np.max([res_out, vgg_out, inc_out])
#                 max_model = np.argmax([res_out, vgg_out, inc_out])
#                 output_list.append(max_output)
#                 model_list.append(max_model)
#             else:
#                 res_out = SG_ResNet(curr_data)[0][0].detach().cpu().item()
#                 vgg_out = SG_VGG(curr_data)[0][0].detach().cpu().item()
#                 inc_out = SG_Inception(curr_data)[0][0].detach().cpu().item()
#                 max_output = np.max([res_out, vgg_out, inc_out])
#                 max_model = np.argmax([res_out, vgg_out, inc_out])
#                 output_list.append(max_output)
#                 model_list.append(max_model)
#     curr_max_idx = np.argmax(output_list)
#     total_preds.append(output_list[curr_max_idx])
#     total_labels.append(curr_label)

Inception(
  (model): InceptionV3(
    (Conv2d_1a_3x3): ConvNormAct(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
      (bn): BatchNormAct2d(
        32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): ReLU(inplace=True)
      )
    )
    (Conv2d_2a_3x3): ConvNormAct(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), bias=False)
      (bn): BatchNormAct2d(
        32, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): ReLU(inplace=True)
      )
    )
    (Conv2d_2b_3x3): ConvNormAct(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNormAct2d(
        64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True
        (drop): Identity()
        (act): ReLU(inplace=True)
      )
    )
    (Pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False

In [10]:
correct_list = []
for idx, ID in enumerate(total_ID):
    curr_datalist = total_path[idx]
    curr_label = 0 if ID in CTR_ID else 1
    output_list = []
    for datapath in curr_datalist:
        curr_data = resize(PIL.Image.open(datapath)).to(torch.float32).unsqueeze(0).to(device)
        data_gland = datapath.split("\\")[-1].split("_")[0]
        with torch.no_grad():
            if data_gland == "PTG":
                res_out = PG_ResNet(curr_data)[0][0].detach().cpu().item()
                vgg_out = PG_VGG(curr_data)[0][0].detach().cpu().item()
                inc_out = PG_Inception(curr_data)[0][0].detach().cpu().item()
                max_output = np.max([res_out, vgg_out, inc_out])
                output_list.append(max_output)
            else:
                res_out = SG_ResNet(curr_data)[0][0].detach().cpu().item()
                vgg_out = SG_VGG(curr_data)[0][0].detach().cpu().item()
                inc_out = SG_Inception(curr_data)[0][0].detach().cpu().item()
                max_output = np.max([res_out, vgg_out, inc_out])
                output_list.append(max_output)
    curr_max_idx = np.argmax(output_list)
    curr_pred = output_list[curr_max_idx].item()
    curr_datapath = curr_datalist[curr_max_idx]
    if (curr_pred == 1.0) and (curr_label == 1): correct_list.append(curr_datapath)
    elif (curr_pred != 1.0) and (curr_label == 0): correct_list.append(curr_datapath)

In [11]:
import pandas as pd
df = dict()
class_list = []
for datapath in correct_list:
    ID = os.path.basename(datapath).split("_")[1].split("-")[0]
    if ID in SJS_ID: class_list.append("SJS")
    else: class_list.append(control_group)

df["Path"] = correct_list
df["Class"] = class_list
pd.DataFrame(df).to_excel(f"Correct_IDs({control_group}).xlsx")

In [12]:
# from sklearn.metrics import roc_auc_score, roc_curve

# fpr, tpr, thresholds = roc_curve(np.array(total_labels), total_preds)
# J=tpr-fpr
# idx = np.argmax(J)

# best_thresh = thresholds[idx]
# sens, spec = tpr[idx], 1-fpr[idx]
# print(best_thresh)

# acc = (sens*len(test_SJS_ID) + spec*len(test_CTR_ID)) / len(total_ID)
# auc = roc_auc_score(np.array(total_labels), total_preds)

# plt.title("Roc Curve")
# plt.plot([0,1], [0,1], linestyle='--', markersize=0.01, color='black')
# plt.plot(fpr, tpr, marker='.', color='black', markersize=0.05)
# plt.scatter(fpr[idx], tpr[idx], marker='o', s=200, color='r', label = 'Sensitivity : %.3f (%d / %d), \nSpecificity = %.3f (%d / %d), \nAUC = %.3f , \nACC = %.3f (%d / %d)' % (sens, (sens*len(test_SJS_ID)), len(test_SJS_ID), spec, (spec*len(test_CTR_ID)), len(test_CTR_ID), auc, acc, sens*len(test_SJS_ID)+spec*len(test_CTR_ID), len(total_ID)))
# plt.legend()