In [1]:
import os, pandas as pd, numpy as np, random,gc
import copy, cv2
pd.options.mode.chained_assignment = None
import torch, torch.nn as nn
import timm
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AdamW
from transformers import get_cosine_schedule_with_warmup
from sklearn.metrics import f1_score,roc_auc_score
from sklearn.metrics import accuracy_score, confusion_matrix

In [2]:
def set_seed(seed=42):
    """Seed every random number generator the notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic,
    non-benchmarking mode.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

In [3]:
# # fix path
# df_224_3d = pd.read_csv("./csv_pickle/df_224_embed.csv")
# ct_path_stack = []
# for old_path in df_224_3d['ct_path'].tolist():
#     old_ = '/ssd2/COVID2023_data_embed_npy/'
#     new_ = '/ssd8/2023COVID19/2023_covid/COVID2023_data_embed_npy/'
#     new_path = old_path.replace(old_, new_)
#     ct_path_stack.append(new_path)
# df_224_3d['ct_path'] = ct_path_stack
# df_224_3d.to_csv("./csv_pickle/df_224_embed_146server.csv", index=False, encoding='utf-8-sig')

In [4]:
# The embeddings are no longer all loaded in __init__ (the .npy data does not fit
# in memory); they are read from disk when an item is indexed instead.
class COVID_Dataset(torch.utils.data.Dataset):
    """Scan-level dataset: one item per unique scan directory.

    Each item is built from ``ct_len_s`` embedding files of one scan. Scans with
    at least ``ct_len_s`` rows are sub-sampled without replacement; shorter scans
    are padded by re-drawing rows with replacement. The draw uses a fixed seed,
    so a given scan always yields the same rows.

    Args:
        csv_path: For ``data_split == 'test'`` a pickle file read with
            ``pd.read_pickle``; otherwise a CSV file with (at least) the columns
            ``split``, ``label``, ``ct_path`` and ``embed``.
        data_split: One of ``'train'``, ``'valid'`` or ``'test'``.
        ct_len_s: Number of rows (slices) returned per scan.
        transform: Stored on the instance; not applied anywhere in this class.

    Raises:
        ValueError: If ``data_split`` is not one of the supported values.
    """

    # Seed of the private generator used to pick rows in __getitem__.
    _SAMPLING_SEED = 4019
    _SPLITS = ('train', 'valid', 'test')

    def __init__(self, csv_path, data_split, ct_len_s=50 , transform = None):
        # The original test `data_split == 'train' or 'valid'` was always true, so
        # a misspelled split silently produced an empty dataset. Fail loudly instead.
        if data_split not in self._SPLITS:
            raise ValueError(
                "data_split must be one of {}, got {!r}".format(self._SPLITS, data_split)
            )

        if data_split == 'test':
            # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
            dataset = pd.read_pickle(csv_path)
            # Test scans have no ground truth; 999 is used as a placeholder label.
            classes = [999]
            # Column 2 is read as the scan directory — presumably 'ct_path'; confirm
            # against the pickle's layout.
            ct_path = np.unique(dataset.iloc[:, 2])
            imgs = [(i_scan_dir, 999) for i_scan_dir in tqdm(ct_path)]
        else:
            df = pd.read_csv(csv_path)
            dataset = df[df['split'] == data_split]
            classes = sorted(set(dataset['label']))
            # Column 4 is read as the scan directory and column 3 as the label —
            # presumably 'ct_path' and 'label'; confirm against the CSV layout.
            ct_path = np.unique(dataset.iloc[:, 4])
            imgs = []
            for i_scan_dir in ct_path:
                temp_df = dataset[dataset['ct_path'] == i_scan_dir]
                imgs.append((i_scan_dir, temp_df.iloc[0, 3]))

        self.embed_info = dataset
        self.classes = classes
        self.class_to_idx = {classes[i]: i for i in range(len(classes))}
        self.ct_len_s = ct_len_s
        self.imgs = imgs
        self.transform = transform

    def load_npy(self, npy_path, npy_file):
        """Load ``os.path.join(path, file)`` with ``np.load`` for each pair, in order."""
        return [np.load(os.path.join(path_, file_)) for path_, file_ in zip(npy_path, npy_file)]

    def __getitem__(self, index):
        """Return ``(img, label)`` for the scan at ``index``.

        ``img`` is the stacked array of the selected rows' embeddings; when the
        stacked array is 4-D it is reshaped to ``(1536, 8*8*ct_len_s)``.
        ``label`` is the class index from ``class_to_idx``.
        """
        img_scan_dir, label = self.imgs[index]
        label = self.class_to_idx[label]

        # Copy so the column assignment below never writes into a slice of embed_info.
        temp_df = self.embed_info[self.embed_info['ct_path'] == img_scan_dir].copy()
        temp_df['embed'] = self.load_npy(temp_df['ct_path'].values.tolist(), temp_df['embed'].values.tolist())

        # A private generator draws exactly what `random.seed(4019)` followed by the
        # module-level calls would, without resetting the process-wide RNG each item.
        rng = random.Random(self._SAMPLING_SEED)
        n_rows = len(temp_df)
        target_index = list(range(n_rows))
        if n_rows >= self.ct_len_s:
            target_index = rng.sample(target_index, k = self.ct_len_s)
        else:
            padding = rng.choices(target_index, k = self.ct_len_s - n_rows)
            target_index += padding
        target_index.sort()

        # Column 1 is read as the freshly loaded embeddings — presumably 'embed'; confirm.
        embed = temp_df.iloc[target_index, 1]
        img = np.array(list(embed))
        if len(img.shape)==4:
            img = img.reshape((1536, 8*8*self.ct_len_s))
        return img, label

    def __len__(self):
        """Number of unique scans in the selected split."""
        return len(self.imgs)

In [5]:
import torch,os
import torch.nn as nn
import timm
import torch.nn as nn
class MyModel(nn.Module):
    """1-D convolution over a stack of slice embeddings followed by a ResNet-18 head.

    The input's second dimension is treated as the Conv1d channel axis; the
    convolution output is replicated into three identical channels and passed to
    a timm ``resnet18`` with a single output logit.

    Args:
        ct_len: Number of Conv1d input and output channels. Previously this
            argument was ignored (input channels were hard-coded to 100 and the
            output channels were read from the global ``CONFIG``); the notebook's
            callers pass ``CONFIG.ct_len_get`` here, so their layer shapes are
            unchanged.
        kernal_size: Conv1d kernel size (parameter name kept as-is for callers).
        pre_train: Forwarded to timm as ``pretrained``.
    """

    def __init__(self, ct_len=224, kernal_size = 3, pre_train=True):
        super().__init__()
        # NOTE(review): a commented-out variant used in_channels=224, apparently for
        # convolving along the other axis — confirm before reusing it.
        self.conv1d = nn.Conv1d(in_channels=ct_len, out_channels=ct_len, kernel_size=kernal_size)
        self.backbone = timm.create_model('resnet18', pretrained=pre_train, num_classes=1)

    def forward(self, x):
        """Return the backbone output for input ``x`` (channel axis = dim 1)."""
        x = self.conv1d(x)
        # Insert a channel axis at dim 1 holding three copies of the feature map.
        x = torch.stack((x, x, x), dim=1)
        return self.backbone(x)

In [6]:
def loss_fn(outputs, labels):
    """Mean binary cross-entropy between raw logits ``outputs`` and targets ``labels``."""
    return nn.functional.binary_cross_entropy_with_logits(outputs, labels)

def train_loop(model, optimizer, scheduler, loader):
    """Run one training epoch.

    For every batch: forward pass on the GPU, loss via ``loss_fn``, backward
    pass, optimizer step, gradient reset and scheduler step.

    Returns:
        Tuple ``(mean_loss, mean_lr)``: the mean batch loss and the mean of the
        per-batch learning rates, each recorded before that batch's optimizer step.
    """
    batch_losses = []
    batch_lrs = []

    model.train()
    optimizer.zero_grad()

    for images, label in loader:
        logits = model(images.cuda())
        batch_loss = loss_fn(logits.view(-1), label.cuda().float())
        batch_losses.append(batch_loss.item())

        # Average the learning rate over all parameter groups for this step.
        batch_lrs.append(np.mean([group["lr"] for group in optimizer.param_groups]))

        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    return np.mean(batch_losses), np.mean(batch_lrs)


def valid_loop(model, loader):
    """Evaluate ``model`` on ``loader`` and print macro-F1 and ROC-AUC.

    Predictions are sigmoid probabilities; F1 is computed on the probabilities
    rounded with ``np.round``.

    Returns:
        Tuple ``(mean_loss, acc_f1, auc_roc)``.
    """
    losses = []
    true_y = []
    pred_y = []

    model.eval()
    with torch.no_grad():
        for images, label in loader:
            images = images.cuda().float()
            labels = label.cuda().float()
            out = model(images)
            losses.append(loss_fn(out.view(-1), labels).item())
            # Only NumPy copies are kept; the raw logits were previously also
            # accumulated in a list that nothing ever read.
            true_y.append(labels.cpu().numpy())
            pred_y.append(torch.sigmoid(out).cpu().numpy())

    true_y = np.concatenate(true_y).reshape(-1, 1)
    pred_y = np.concatenate(pred_y).reshape(-1, 1)

    gc.collect()

    acc_f1 = f1_score(true_y, np.round(pred_y), average='macro')
    auc_roc = roc_auc_score(true_y, pred_y)
    print("acc_f1 : ",round(acc_f1,4),"  auc_roc : ",round(auc_roc,4))

    return np.array(losses).mean(),acc_f1,auc_roc



In [7]:
class CONFIG:
    """Paths and hyper-parameters shared by the training and evaluation cells."""
    # Checkpoint written by training and read back by the evaluation cells.
    # model_path1 = "/ssd8/2023COVID19/CT-COVID19-Classification/train_code_fix/output/f1_best_model_k1_convslice_check_from144weight[384].bin"
    model_path1 = "/ssd8/2023COVID19/CT-COVID19-Classification/train_code_fix/output/f1_best_model_k1_convembed_check_from144weight[384].bin"
    # model_path1 = "/ssd8/2023COVID19/CT-COVID19-Classification/train_code_fix/output/f1_best_model_k1_convembed_check_from144weight[256].bin"
    pre_train = False        # passed to MyModel as pre_train
    N_EPOCHS = 100
    train_batch_size = 32
    valid_batch_size = 32
    ct_len_get = 100 #100    # slices sampled per scan (ct_len_s / ct_len)
    kernal_size = 1 #1       # Conv1d kernel size
    SEED = 42
    SEDD = SEED              # misspelled original name, kept as an alias for existing references
    LR = 3e-5 #3e-5
    WEIGHT_DECAY = 1e-3 #1e-3

In [8]:
def main():
    """Train MyModel and save the checkpoint with the best validation macro-F1.

    Builds the model, optimizer, datasets, loaders and a cosine schedule with
    warm-up (10% of all steps), then trains for ``CONFIG.N_EPOCHS`` epochs. The
    state dict is written to ``CONFIG.model_path1`` whenever the validation F1
    improves; the best F1 is printed at the end.
    """
    print("==========loading model==========")
    # Seed comes from the config (the attribute is spelled `SEDD` in CONFIG).
    set_seed(seed=CONFIG.SEDD)
    model = MyModel(ct_len = CONFIG.ct_len_get, kernal_size=CONFIG.kernal_size, pre_train=CONFIG.pre_train).cuda()
    optimizer = AdamW(model.parameters(), lr=CONFIG.LR, weight_decay=CONFIG.WEIGHT_DECAY)

    df_path = './csv_pickle/df_224_embed_146server.csv'
    # df_path = './csv_pickle/df_224_embed_sz_384_146server.csv'
    print("==========data loader==========")
    train_ds = COVID_Dataset(csv_path = df_path, data_split = 'train', ct_len_s = CONFIG.ct_len_get, transform = None)
    valid_ds = COVID_Dataset(csv_path = df_path,data_split = 'valid', ct_len_s = CONFIG.ct_len_get, transform = None)

    train_loader = DataLoader(train_ds, batch_size=CONFIG.train_batch_size, num_workers=15, shuffle=True, pin_memory=True)
    valid_loader = DataLoader(valid_ds, batch_size=CONFIG.valid_batch_size, num_workers=15, shuffle=False, pin_memory=True)
    # Sanity check: print the shape of one training batch.
    image, label = next(iter(train_loader))
    print(image.shape)

    # The scheduler is stepped once per batch, so it spans every batch of every epoch.
    num_train_steps = int(len(train_loader) * CONFIG.N_EPOCHS)
    num_warmup_steps = int(num_train_steps / 10)
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)

    print("==========start train==========")
    best_epoch_f1 = 0
    for epoch in tqdm(range(CONFIG.N_EPOCHS)):
        train_loss, lrs = train_loop(model, optimizer, scheduler, train_loader)
        valid_loss,acc_f1,auc_roc = valid_loop(model, valid_loader)

        if acc_f1 > best_epoch_f1:
            print(f"Validation f1 Improved ({best_epoch_f1} ---> {acc_f1})")
            best_epoch_f1 = acc_f1
            # The weights are persisted straight to disk; the extra in-memory deep
            # copy of the state dict that used to be taken here was never read.
            torch.save(model.state_dict(), CONFIG.model_path1)
            print("Model Saved")
    torch.cuda.empty_cache()
    gc.collect()
    print(best_epoch_f1)

In [9]:
# Expose only the GPU with index 4 to this process, then run the full training.
# NOTE(review): this is set after torch has been imported; presumably it still
# takes effect because no CUDA call has been made yet — confirm on a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
main()





torch.Size([32, 100, 224])


  1%|          | 1/100 [00:24<41:03, 24.89s/it]

acc_f1 :  0.3998   auc_roc :  0.8335
Validation f1 Improved (0 ---> 0.3998294970161978)
Model Saved


  2%|▏         | 2/100 [00:50<41:21, 25.32s/it]

acc_f1 :  0.7703   auc_roc :  0.8499
Validation f1 Improved (0.3998294970161978 ---> 0.7702842545891873)
Model Saved


  3%|▎         | 3/100 [01:14<39:44, 24.58s/it]

acc_f1 :  0.7824   auc_roc :  0.8608
Validation f1 Improved (0.7702842545891873 ---> 0.7823672254592823)
Model Saved


  4%|▍         | 4/100 [01:39<39:37, 24.77s/it]

acc_f1 :  0.789   auc_roc :  0.8719
Validation f1 Improved (0.7823672254592823 ---> 0.7889504289495786)
Model Saved


  5%|▌         | 5/100 [02:06<40:41, 25.70s/it]

acc_f1 :  0.8206   auc_roc :  0.8818
Validation f1 Improved (0.7889504289495786 ---> 0.8206017837890817)
Model Saved


  6%|▌         | 6/100 [02:34<41:15, 26.34s/it]

acc_f1 :  0.8143   auc_roc :  0.8891


  7%|▋         | 7/100 [03:03<42:27, 27.39s/it]

acc_f1 :  0.8393   auc_roc :  0.8955
Validation f1 Improved (0.8206017837890817 ---> 0.839253372349505)
Model Saved


  8%|▊         | 8/100 [03:35<44:13, 28.84s/it]

acc_f1 :  0.8415   auc_roc :  0.905
Validation f1 Improved (0.839253372349505 ---> 0.8414721967685335)
Model Saved
acc_f1 :  0.863   auc_roc :  0.9126
Validation f1 Improved (0.8414721967685335 ---> 0.8629581341445749)
Model Saved


 10%|█         | 10/100 [04:28<40:47, 27.19s/it]

acc_f1 :  0.8563   auc_roc :  0.9181


 11%|█         | 11/100 [04:53<39:16, 26.48s/it]

acc_f1 :  0.8618   auc_roc :  0.9193


 12%|█▏        | 12/100 [05:18<37:50, 25.81s/it]

acc_f1 :  0.8667   auc_roc :  0.9215
Validation f1 Improved (0.8629581341445749 ---> 0.8667442046089464)
Model Saved


 13%|█▎        | 13/100 [05:46<38:39, 26.66s/it]

acc_f1 :  0.8679   auc_roc :  0.9223
Validation f1 Improved (0.8667442046089464 ---> 0.867879807324719)
Model Saved


 14%|█▍        | 14/100 [06:14<38:52, 27.12s/it]

acc_f1 :  0.8732   auc_roc :  0.9261
Validation f1 Improved (0.867879807324719 ---> 0.8731661706518887)
Model Saved


 15%|█▌        | 15/100 [06:44<39:19, 27.75s/it]

acc_f1 :  0.8652   auc_roc :  0.9268


 16%|█▌        | 16/100 [07:17<41:12, 29.44s/it]

acc_f1 :  0.8744   auc_roc :  0.9299
Validation f1 Improved (0.8731661706518887 ---> 0.8743581999395953)
Model Saved


 17%|█▋        | 17/100 [07:43<39:26, 28.51s/it]

acc_f1 :  0.8803   auc_roc :  0.9309
Validation f1 Improved (0.8743581999395953 ---> 0.8802666053780741)
Model Saved


 18%|█▊        | 18/100 [08:07<37:04, 27.13s/it]

acc_f1 :  0.8658   auc_roc :  0.9295


 19%|█▉        | 19/100 [08:30<34:52, 25.83s/it]

acc_f1 :  0.87   auc_roc :  0.9274


 20%|██        | 20/100 [08:49<31:47, 23.85s/it]

acc_f1 :  0.8818   auc_roc :  0.9349
Validation f1 Improved (0.8802666053780741 ---> 0.8818197079066645)
Model Saved


 21%|██        | 21/100 [09:14<31:47, 24.15s/it]

acc_f1 :  0.877   auc_roc :  0.932


 22%|██▏       | 22/100 [09:37<30:47, 23.68s/it]

acc_f1 :  0.8738   auc_roc :  0.9343


 23%|██▎       | 23/100 [09:58<29:31, 23.00s/it]

acc_f1 :  0.8774   auc_roc :  0.9344


 24%|██▍       | 24/100 [10:22<29:18, 23.14s/it]

acc_f1 :  0.8779   auc_roc :  0.9361


 25%|██▌       | 25/100 [10:47<29:37, 23.70s/it]

acc_f1 :  0.8737   auc_roc :  0.9332


 26%|██▌       | 26/100 [11:07<27:54, 22.63s/it]

acc_f1 :  0.8682   auc_roc :  0.9317


 27%|██▋       | 27/100 [11:27<26:46, 22.01s/it]

acc_f1 :  0.8764   auc_roc :  0.9336


 28%|██▊       | 28/100 [11:48<25:57, 21.64s/it]

acc_f1 :  0.8779   auc_roc :  0.935


 29%|██▉       | 29/100 [12:08<25:05, 21.20s/it]

acc_f1 :  0.8796   auc_roc :  0.934


 30%|███       | 30/100 [12:30<25:01, 21.45s/it]

acc_f1 :  0.8742   auc_roc :  0.933


 31%|███       | 31/100 [12:52<24:55, 21.67s/it]

acc_f1 :  0.8875   auc_roc :  0.9342
Validation f1 Improved (0.8818197079066645 ---> 0.8874926943308006)
Model Saved


 32%|███▏      | 32/100 [13:11<23:27, 20.70s/it]

acc_f1 :  0.887   auc_roc :  0.9296


 33%|███▎      | 33/100 [13:34<23:48, 21.32s/it]

acc_f1 :  0.8694   auc_roc :  0.9295


 34%|███▍      | 34/100 [13:55<23:26, 21.31s/it]

acc_f1 :  0.8885   auc_roc :  0.9325
Validation f1 Improved (0.8874926943308006 ---> 0.8884825444051875)
Model Saved


 35%|███▌      | 35/100 [14:19<24:03, 22.21s/it]

acc_f1 :  0.8843   auc_roc :  0.9377


 36%|███▌      | 36/100 [14:39<23:01, 21.59s/it]

acc_f1 :  0.885   auc_roc :  0.9308


 37%|███▋      | 37/100 [15:01<22:34, 21.50s/it]

acc_f1 :  0.8865   auc_roc :  0.9336


 38%|███▊      | 38/100 [15:22<22:12, 21.50s/it]

acc_f1 :  0.887   auc_roc :  0.9358


 39%|███▉      | 39/100 [15:44<21:54, 21.55s/it]

acc_f1 :  0.8811   auc_roc :  0.9351


 40%|████      | 40/100 [16:06<21:35, 21.59s/it]

acc_f1 :  0.8826   auc_roc :  0.9361


 41%|████      | 41/100 [16:31<22:21, 22.73s/it]

acc_f1 :  0.887   auc_roc :  0.934


 42%|████▏     | 42/100 [16:53<21:46, 22.53s/it]

acc_f1 :  0.8855   auc_roc :  0.9351


 43%|████▎     | 43/100 [17:14<20:50, 21.93s/it]

acc_f1 :  0.885   auc_roc :  0.933


 44%|████▍     | 44/100 [17:33<19:40, 21.07s/it]

acc_f1 :  0.8825   auc_roc :  0.9324


 45%|████▌     | 45/100 [17:51<18:40, 20.37s/it]

acc_f1 :  0.8863   auc_roc :  0.9348


 46%|████▌     | 46/100 [18:10<17:49, 19.80s/it]

acc_f1 :  0.8863   auc_roc :  0.9344


 47%|████▋     | 47/100 [18:35<18:50, 21.32s/it]

acc_f1 :  0.8809   auc_roc :  0.9302


 48%|████▊     | 48/100 [18:58<19:05, 22.04s/it]

acc_f1 :  0.8823   auc_roc :  0.9352


 49%|████▉     | 49/100 [19:21<18:56, 22.29s/it]

acc_f1 :  0.885   auc_roc :  0.9329


 50%|█████     | 50/100 [19:41<18:01, 21.63s/it]

acc_f1 :  0.8836   auc_roc :  0.9349


 51%|█████     | 51/100 [20:07<18:40, 22.87s/it]

acc_f1 :  0.8818   auc_roc :  0.9295


 52%|█████▏    | 52/100 [20:29<17:58, 22.47s/it]

acc_f1 :  0.8789   auc_roc :  0.934


 53%|█████▎    | 53/100 [20:53<17:57, 22.92s/it]

acc_f1 :  0.888   auc_roc :  0.9335


 54%|█████▍    | 54/100 [21:18<18:02, 23.53s/it]

acc_f1 :  0.8848   auc_roc :  0.9332


 55%|█████▌    | 55/100 [21:40<17:22, 23.16s/it]

acc_f1 :  0.8858   auc_roc :  0.9304


 56%|█████▌    | 56/100 [22:08<18:00, 24.56s/it]

acc_f1 :  0.8831   auc_roc :  0.9338


 57%|█████▋    | 57/100 [22:40<19:21, 27.01s/it]

acc_f1 :  0.885   auc_roc :  0.934


 58%|█████▊    | 58/100 [23:10<19:24, 27.72s/it]

acc_f1 :  0.884   auc_roc :  0.9348


 59%|█████▉    | 59/100 [23:33<17:59, 26.32s/it]

acc_f1 :  0.8838   auc_roc :  0.9348


 60%|██████    | 60/100 [24:02<18:07, 27.19s/it]

acc_f1 :  0.8809   auc_roc :  0.935


 61%|██████    | 61/100 [24:29<17:39, 27.17s/it]

acc_f1 :  0.884   auc_roc :  0.934


 62%|██████▏   | 62/100 [24:58<17:29, 27.61s/it]

acc_f1 :  0.8877   auc_roc :  0.9348


 63%|██████▎   | 63/100 [25:28<17:25, 28.27s/it]

acc_f1 :  0.8836   auc_roc :  0.9345


 64%|██████▍   | 64/100 [25:55<16:45, 27.93s/it]

acc_f1 :  0.8845   auc_roc :  0.9343


 65%|██████▌   | 65/100 [26:22<16:12, 27.78s/it]

acc_f1 :  0.8811   auc_roc :  0.9336


 66%|██████▌   | 66/100 [26:48<15:22, 27.15s/it]

acc_f1 :  0.884   auc_roc :  0.9366


 67%|██████▋   | 67/100 [27:12<14:23, 26.17s/it]

acc_f1 :  0.8792   auc_roc :  0.9302


 68%|██████▊   | 68/100 [27:31<12:52, 24.14s/it]

acc_f1 :  0.8877   auc_roc :  0.9348


 69%|██████▉   | 69/100 [27:57<12:39, 24.51s/it]

acc_f1 :  0.8831   auc_roc :  0.9322


 70%|███████   | 70/100 [28:20<12:01, 24.06s/it]

acc_f1 :  0.8855   auc_roc :  0.9336


 71%|███████   | 71/100 [28:46<11:59, 24.81s/it]

acc_f1 :  0.8863   auc_roc :  0.9333


 72%|███████▏  | 72/100 [29:11<11:34, 24.82s/it]

acc_f1 :  0.885   auc_roc :  0.9354


 73%|███████▎  | 73/100 [29:34<10:52, 24.17s/it]

acc_f1 :  0.8812   auc_roc :  0.93


 74%|███████▍  | 74/100 [30:02<11:00, 25.41s/it]

acc_f1 :  0.8828   auc_roc :  0.934


 75%|███████▌  | 75/100 [30:31<11:00, 26.42s/it]

acc_f1 :  0.8831   auc_roc :  0.9319


 76%|███████▌  | 76/100 [30:55<10:21, 25.88s/it]

acc_f1 :  0.8836   auc_roc :  0.9359


 77%|███████▋  | 77/100 [31:18<09:36, 25.05s/it]

acc_f1 :  0.8826   auc_roc :  0.9348


 78%|███████▊  | 78/100 [31:39<08:44, 23.83s/it]

acc_f1 :  0.8813   auc_roc :  0.9339


 79%|███████▉  | 79/100 [32:00<07:59, 22.83s/it]

acc_f1 :  0.8882   auc_roc :  0.9336


 80%|████████  | 80/100 [32:23<07:39, 22.98s/it]

acc_f1 :  0.8793   auc_roc :  0.9338


 81%|████████  | 81/100 [32:46<07:12, 22.78s/it]

acc_f1 :  0.8836   auc_roc :  0.9339


 82%|████████▏ | 82/100 [33:10<07:01, 23.40s/it]

acc_f1 :  0.8779   auc_roc :  0.935


 83%|████████▎ | 83/100 [33:34<06:41, 23.61s/it]

acc_f1 :  0.8828   auc_roc :  0.9345


 84%|████████▍ | 84/100 [33:58<06:17, 23.59s/it]

acc_f1 :  0.8823   auc_roc :  0.9347


 85%|████████▌ | 85/100 [34:22<05:55, 23.73s/it]

acc_f1 :  0.8835   auc_roc :  0.933


 86%|████████▌ | 86/100 [34:49<05:44, 24.61s/it]

acc_f1 :  0.8808   auc_roc :  0.9327


 87%|████████▋ | 87/100 [35:15<05:24, 24.98s/it]

acc_f1 :  0.8833   auc_roc :  0.9332


 88%|████████▊ | 88/100 [35:41<05:05, 25.45s/it]

acc_f1 :  0.8835   auc_roc :  0.9331


 89%|████████▉ | 89/100 [36:10<04:52, 26.57s/it]

acc_f1 :  0.8867   auc_roc :  0.9338


 90%|█████████ | 90/100 [36:31<04:07, 24.77s/it]

acc_f1 :  0.8838   auc_roc :  0.9338


 91%|█████████ | 91/100 [36:50<03:28, 23.20s/it]

acc_f1 :  0.8831   auc_roc :  0.9339


 92%|█████████▏| 92/100 [37:14<03:07, 23.41s/it]

acc_f1 :  0.8828   auc_roc :  0.9339


 93%|█████████▎| 93/100 [37:43<02:54, 24.89s/it]

acc_f1 :  0.8816   auc_roc :  0.9337


 94%|█████████▍| 94/100 [38:07<02:27, 24.58s/it]

acc_f1 :  0.8848   auc_roc :  0.934


 95%|█████████▌| 95/100 [38:33<02:06, 25.21s/it]

acc_f1 :  0.8828   auc_roc :  0.934


 96%|█████████▌| 96/100 [38:59<01:42, 25.51s/it]

acc_f1 :  0.8867   auc_roc :  0.9338


 97%|█████████▋| 97/100 [39:24<01:15, 25.13s/it]

acc_f1 :  0.8838   auc_roc :  0.9342


 98%|█████████▊| 98/100 [39:51<00:51, 25.81s/it]

acc_f1 :  0.8833   auc_roc :  0.934


 99%|█████████▉| 99/100 [40:20<00:26, 26.89s/it]

acc_f1 :  0.8813   auc_roc :  0.9329


100%|██████████| 100/100 [40:42<00:00, 24.42s/it]

acc_f1 :  0.8833   auc_roc :  0.9336
0.8884825444051875





In [14]:
# One inference pass; the evaluation cell calls this 10 times and averages the results.
def valid_one(model, loader):
    """Run ``model`` over ``loader`` once and collect labels and probabilities.

    Returns:
        Tuple ``(true_y, pred_y)`` of NumPy arrays, both reshaped to ``(-1, 1)``:
        the labels cast to float, and the sigmoid of the model outputs.
    """
    true_y = []
    pred_y = []

    model.eval()
    with torch.no_grad():
        for images, label in loader:
            images = images.cuda().float()
            labels = label.cuda().float()
            out = model(images)
            # The per-batch loss and raw logits used to be computed and stored
            # here but were never returned, so they are no longer produced.
            true_y.append(labels.cpu().numpy())
            pred_y.append(torch.sigmoid(out).cpu().numpy())

    true_y = np.concatenate(true_y).reshape(-1, 1)
    pred_y = np.concatenate(pred_y).reshape(-1, 1)

    gc.collect()

    return true_y,pred_y

In [15]:
# Rebuild the model with the training configuration and restore the checkpoint
# saved at CONFIG.model_path1.
pred_path = CONFIG.model_path1
model = MyModel(ct_len = CONFIG.ct_len_get, kernal_size=CONFIG.kernal_size, pre_train=CONFIG.pre_train).cuda()
model.load_state_dict(torch.load(pred_path))
model.cuda()
# NOTE(review): this CSV differs from the one main() trains/validates on
# ('./csv_pickle/df_224_embed_146server.csv') — confirm which file is intended.
df_path = './csv_pickle/df_224_embed_sz_384_146server.csv'
print("==========data loader==========")
# Validation split only, unshuffled, so rows line up across repeated passes.
valid_ds = COVID_Dataset(csv_path = df_path,data_split = 'valid', ct_len_s = CONFIG.ct_len_get, transform = None)
valid_loader = DataLoader(valid_ds, batch_size=CONFIG.valid_batch_size, num_workers=15, shuffle=False, pin_memory=True)




In [16]:
# Run 10 validation passes, average the predicted probabilities, and report
# macro-F1 plus per-class accuracy of the rounded average.
# NOTE(review): COVID_Dataset.__getitem__ draws slices with a fixed seed (4019),
# so the 10 passes presumably produce the same predictions — confirm whether
# averaging is still meant to add anything.
total_pred = []
for i in range(10):
    true_y, pred_y = valid_one(model, valid_loader)
    total_pred.append(pred_y)
# for i in range(len(total_pred)):
#     print(f1_score(np.array(true_y),np.round(total_pred[i]),average='macro'))

# Averaged-and-rounded prediction, computed once and shared by both metrics.
y_true = np.array(true_y)
y_vote = np.round(np.mean(total_pred, axis=0))

tn, fp, fn, tp = confusion_matrix(y_true, y_vote).ravel()
mean_f1 = f1_score(y_true, y_vote, average='macro')

print("Mean F1-Score: {}".format(mean_f1))
print("Negative Accuracy: {}".format(tn/(tn+fp)))
print("Positive Accuracy: {}".format(tp/(tp+fn)))

Mean F1-Score: 0.8884825444051875
Negative Accuracy: 0.9381663113006397
Positive Accuracy: 0.8297872340425532


In [13]:
# ======================================
# model: 2dcnn(efficientnet_b3a) + 1dcnn + resnet18
# LR: (0.0001, decay:0.0005)
# batch size: 8
# image size: [384x384]
# datatype: challenge data
# (weight =  best f1 score checkpoint)
# ------------------------------------
# [conv axis=embedding]
# kernel=1; ct_len=100[pretrain=True]
# Mean F1-Score: 0.8964
# Negative Accuracy: 0.9744
# Positive Accuracy: 0.7872
# ------------------------------------
# [conv axis=embedding]
# kernel=3; ct_len=100[pretrain=True]
# Mean F1-Score: 0.8964
# Negative Accuracy: 0.9744
# Positive Accuracy: 0.7872
# ------------------------------------
# [conv axis=embedding]
# kernel=1; ct_len=100[pretrain=False]
# Mean F1-Score: 0.9221
# Negative Accuracy: 0.9637
# Positive Accuracy: 0.8680
# ------------------------------------
# [conv axis=slice]
# kernel=1; ct_len=100[pretrain=False]
# Mean F1-Score: 0.9133
# Negative Accuracy: 0.9658
# Positive Accuracy: 0.8425
# ------------------------------------
# [conv axis=embedding]
# kernel=3; ct_len=100[pretrain=False]
# Mean F1-Score: 0.9160
# Negative Accuracy: 0.9744
# Positive Accuracy: 0.8340
# ------------------------------------
# [conv axis=embedding]
# kernel=7; ct_len=100[pretrain=False]
# Mean F1-Score: 0.9150
# Negative Accuracy: 0.9658
# Positive Accuracy: 0.8468
# ------------------------------------
# [conv axis=embedding]
# kernel=1; ct_len=200[pretrain=False]
# Mean F1-Score: 0.9076
# Negative Accuracy: 0.9701
# Positive Accuracy: 0.8212
# ------------------------------------
# [conv axis=embedding]
# kernel=1; ct_len=224[pretrain=False]
# Mean F1-Score: 0.9015
# Negative Accuracy: 0.9616
# Positive Accuracy: 0.8212