In [4]:
import torch
import utils
import argparse
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader

from sklearn.metrics import accuracy_score, recall_score, f1_score, balanced_accuracy_score, roc_auc_score, roc_curve, confusion_matrix, auc

import matplotlib.pyplot as plt

def _str2bool(value):
    """Convert a command-line flag value to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    (any non-empty string is truthy), so boolean flags need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised boolean word.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Build the experiment argument parser and parse *argv*.

    Args:
        argv: sequence of command-line tokens to parse. ``None`` (the default)
            parses an empty list, so only the defaults below are used. This keeps
            notebook kernels' own command-line flags from being parsed as ours.

    Returns:
        argparse.Namespace with the experiment configuration.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_name', help='experiment name', type=str, default='CNN_B_CMSC')
    parser.add_argument('--model', help='Model', type=str, default='CNN_B', choices=['CNN_B', 'CNN_Bg', 'CNN_single-B', 'vit_B', 'vit_B-single'])
    parser.add_argument('--dataset', help='Dataset', type=str, default='cad', choices=['angio', 'cad', 'whole', 'full'])
    parser.add_argument('--trainset', type=str, default='cad', choices=['angio', 'cad', 'whole', 'full'])
    parser.add_argument('--testset', type=str, default='cad', choices=['angio', 'cad', 'whole', 'full'])
    parser.add_argument('--phase', help='Phase', type=str, default='randominit', choices=['finetune', 'linear', 'randominit', 'SimCLR', 'BYOL', 'CMSC'])
    parser.add_argument('--loss', help='Loss function', type=str, default='CrossEntropyLoss')
    parser.add_argument('--optimizer', help='Optimizer', type=str, default='AdamW')
    parser.add_argument('--lr', help='Learning rate', type=float, default=0.001)
    parser.add_argument('--decay', help='Weight decay', type=float, default=0.001)
    parser.add_argument('--batch_size', help='Batch size', type=int, default=128)
    parser.add_argument('--epochs', help='Epochs', type=int, default=100)
    parser.add_argument('--ckpt_freq', type=int, default=20)
    parser.add_argument('--seed', type=int, default=0)

    # Self-supervised pre-training hyper-parameters.
    parser.add_argument('--t', help='temperature for SimCLR', type=float, default=0.5)
    parser.add_argument('--ma_decay', help='Moving average decay', type=float, default=0.9)

    parser.add_argument('--datapath', type=str, default='./dataset')
    parser.add_argument('--test_batch', type=int, default=2048)
    parser.add_argument('--ckpt_path', type=str, default='SimCLR_4096_aug_pretrain')
    parser.add_argument('--ckpt_epoch', type=int, default=3)
    # type=bool would turn the string "False" into True; use an explicit parser.
    parser.add_argument('--use_tb', type=_str2bool, default=False)

    args = parser.parse_args([] if argv is None else list(argv))
    return args

def load_backbone(ckpt, model):
    """Copy into *model* every entry of *ckpt* whose key exists in ``model.state_dict()``.

    Keys present only in *ckpt* are ignored. Parameters/buffers of *model* that
    are missing from *ckpt* keep their current values. *model* is modified in
    place.

    Args:
        ckpt: mapping of state-dict key -> tensor (e.g. what ``torch.load`` returns
            for a saved ``state_dict``).
        model: the ``nn.Module`` to load into.

    Returns:
        list[str]: sorted keys that were actually copied. An empty list means
        nothing matched (e.g. the checkpoint's key names use a different prefix),
        which callers can check instead of silently evaluating an unloaded model.
    """
    model_dict = model.state_dict()
    pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    # A full state dict is passed, so strict loading succeeds; matched keys with
    # mismatched tensor shapes still raise from load_state_dict.
    model.load_state_dict(model_dict)
    return sorted(pretrained_dict)

if __name__ == '__main__':
    # Cross-evaluation: every saved ViT checkpoint is tested on every test set.
    args = parse_args()
    args.model = 'vit_B'
    torch.manual_seed(args.seed)

    ckpt_epoch = 30  # epoch number of the checkpoint file to load from each run
    # Alternative run: baseline_linear_angio / baseline_linear_cad / baseline_linear_whole.
    ckpts = ['vit_angio', 'vit_cad', 'vit_whole']
    testsets = ['angio', 'cad', 'whole']

    for data in testsets:
        args.testset = data
        # Re-set each iteration in case build_trainer mutates args.phase — TODO confirm.
        args.phase = 'randominit'
        trainer = utils.build_trainer(args)
        # Snapshot the freshly built weights (on CPU, to spare GPU memory).
        # load_backbone only overwrites keys present in a checkpoint, so without
        # a reset any key missing from one checkpoint would keep the value left
        # by the previous checkpoint evaluated on this same model.
        init_state = {k: v.detach().cpu().clone() for k, v in trainer.model.state_dict().items()}
        for ckpt in ckpts:
            trainer.model.load_state_dict(init_state)
            # load_state_dict copies values onto the model's own device, so
            # loading to CPU works regardless of where the checkpoint was saved.
            checkpoint = torch.load(f'./checkpoints/{ckpt}/{ckpt_epoch}.pth', map_location='cpu')
            load_backbone(checkpoint, trainer.model)
            trainer.test()
            print(f'{data} {ckpt}')
            trainer.print_train_info()

  0%|          | 0/524 [00:00<?, ?it/s]

100%|██████████| 524/524 [00:00<00:00, 897.75it/s]
100%|██████████| 780/780 [00:00<00:00, 894.97it/s]


Train data: torch.Size([1304, 6, 12, 2500]) target: torch.Size([1304])
Weighted Loss) Negative: 0.60 Positive: 0.40


100%|██████████| 132/132 [00:00<00:00, 947.34it/s]

Test data: torch.Size([132, 6, 12, 2500]) target: torch.Size([132])





TRAINING FROM SCRATCH!
model name:vit_B
dataset:cad
device:cuda
Tensorboard:False
Total parameter:66,778,882
angio vit_angio
(001/100) Train Loss:-1.0000 Test Loss:0.9553 Test Accuracy:0.5152% Balanced Test Accuracy:0.5936% Sensitivity:0.4673 specificity:0.7200 f1:0.6098 AUROC:0.5781
angio vit_cad
(001/100) Train Loss:-1.0000 Test Loss:0.9818 Test Accuracy:0.7197% Balanced Test Accuracy:0.5819% Sensitivity:0.8037 specificity:0.3600 f1:0.8230 AUROC:0.5854
angio vit_whole
(001/100) Train Loss:-1.0000 Test Loss:4.7380 Test Accuracy:0.4545% Balanced Test Accuracy:0.4796% Sensitivity:0.4393 specificity:0.5200 f1:0.5663 AUROC:0.4499


100%|██████████| 524/524 [00:00<00:00, 923.21it/s]
100%|██████████| 780/780 [00:00<00:00, 904.33it/s]


Train data: torch.Size([1304, 6, 12, 2500]) target: torch.Size([1304])
Weighted Loss) Negative: 0.60 Positive: 0.40


100%|██████████| 132/132 [00:00<00:00, 927.49it/s]
100%|██████████| 195/195 [00:00<00:00, 911.07it/s]


Test data: torch.Size([327, 6, 12, 2500]) target: torch.Size([327])
TRAINING FROM SCRATCH!
model name:vit_B
dataset:cad
device:cuda
Tensorboard:False
Total parameter:66,778,882
cad vit_angio
(001/100) Train Loss:-1.0000 Test Loss:2.1145 Test Accuracy:0.5719% Balanced Test Accuracy:0.5292% Sensitivity:0.2446 specificity:0.8138 f1:0.3269 AUROC:0.5132
cad vit_cad
(001/100) Train Loss:-1.0000 Test Loss:0.6371 Test Accuracy:0.5780% Balanced Test Accuracy:0.5542% Sensitivity:0.3957 specificity:0.7128 f1:0.4435 AUROC:0.5634
cad vit_whole
(001/100) Train Loss:-1.0000 Test Loss:2.2077 Test Accuracy:0.6330% Balanced Test Accuracy:0.6171% Sensitivity:0.5108 specificity:0.7234 f1:0.5420 AUROC:0.6396


100%|██████████| 524/524 [00:00<00:00, 898.69it/s]
100%|██████████| 780/780 [00:00<00:00, 903.56it/s]


Train data: torch.Size([1304, 6, 12, 2500]) target: torch.Size([1304])
Weighted Loss) Negative: 0.60 Positive: 0.40


100%|██████████| 132/132 [00:00<00:00, 913.41it/s]
100%|██████████| 195/195 [00:00<00:00, 940.81it/s]
100%|██████████| 3189/3189 [00:03<00:00, 901.15it/s]


Test data: torch.Size([3516, 6, 12, 2500]) target: torch.Size([3516])
TRAINING FROM SCRATCH!
model name:vit_B
dataset:cad
device:cuda
Tensorboard:False
Total parameter:66,778,882
whole vit_angio
(001/100) Train Loss:-1.0000 Test Loss:2.9621 Test Accuracy:0.5731% Balanced Test Accuracy:0.5431% Sensitivity:0.5083 specificity:0.5779 f1:0.1408 AUROC:0.5318
whole vit_cad
(001/100) Train Loss:-1.0000 Test Loss:0.3839 Test Accuracy:0.7290% Balanced Test Accuracy:0.5732% Sensitivity:0.3926 specificity:0.7538 f1:0.1662 AUROC:0.5906
whole vit_whole
(001/100) Train Loss:-1.0000 Test Loss:0.3574 Test Accuracy:0.3805% Balanced Test Accuracy:0.5392% Sensitivity:0.7231 specificity:0.3552 f1:0.1384 AUROC:0.5197
