# setup

In [1]:
import torch
import yaml
import random
import numpy as np
import sys
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data.dataloader import DataLoader

sys.path.append('../utils')
# from dataset import ECG_TEXT_Dsataset
from dataset import ECG_test_Dsataset
from builder import ECGCLIP
from utils import find_best_thresholds, metrics_table
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [3]:
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)
model_checkpoints_folder = '../checkpoints/'
model_checkpoints_folder + config['wandb_name'] + f'_bestZeroShotAll_ckpt.pth'

'../checkpoints/vit_tiny_150_bestZeroShotAll_ckpt.pth'

In [4]:
torch.manual_seed(42)
random.seed(0)
np.random.seed(0)

# dataset

In [5]:
config['dataset']['dataset_name']

'CODEmel'

In [6]:
# data_path = config['dataset']['data_path']
data_path = 'D:\datasets\codetest-mel'
dataset = ECG_test_Dsataset(
    data_path=data_path, dataset_name='CODEtestmel')
tst_dataset = dataset.get_dataset()
batch_size = 4
tst_loader = DataLoader(tst_dataset, batch_size = batch_size, shuffle=False, )

Load CODEtestmel dataset!
Apply Val-stage Transform!
code test dataset length:  827


# builder

In [7]:
model = ECGCLIP(config['network'])

In [8]:
model = model.to(device).eval()
ckpt = torch.load(model_checkpoints_folder + config['wandb_name'] + f'_bestZeroShotAll_ckpt.pth', map_location='cpu')
model.load_state_dict(ckpt)

<All keys matched successfully>

# zeroshot

## subfeatures com normal

In [56]:
text_1davb = 'bloqueio atrioventricular de primeiro grau'
text_rbbb = 'bloqueio de ramo direito'
text_lbbb = 'bloqueio de ramo esquerdo'
text_sb = 'bradicardia'
text_af = 'fibrilacao atrial'
text_st = 'taquicardia'
text_normal = 'dentro dos limites da normalidade'

texts = [text_1davb, text_rbbb, text_lbbb, text_sb, text_af, text_st, text_normal]
num_classes = len(texts)

In [57]:
with torch.no_grad():
    # text synthesis
    zeroshot_weights = []
    for text in tqdm(texts):
        text = model._tokenize([text.lower()])

        class_embeddings = model.get_text_emb(text.input_ids.to(device=device), text.attention_mask.to(device=device)) # embed with text encoder
        class_embeddings = model.proj_t(class_embeddings) # embed with text encoder

        # normalize class_embeddings
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        # average over templates 
        class_embedding = class_embeddings.mean(dim=0) 
        # norm over new averaged templates
        class_embedding /= class_embedding.norm()

        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1)

100%|██████████| 7/7 [00:01<00:00,  6.12it/s]


In [58]:
y_pred = []
all_binary_results = []
all_true_labels = []
with torch.no_grad():
    for data in tqdm(tst_loader):
        # read
        # report = data['raw_text']
        ecg = data['ecg'].to(torch.float32).to(device).contiguous()
        label = data['label'].float().to(device)
        exam_id = data['exam_id'].to(device)

        # predict
        ecg_emb = model.ext_ecg_emb(ecg)
        ecg_emb /= ecg_emb.norm(dim=-1, keepdim=True)

        # obtain logits (cos similarity)
        logits = ecg_emb @ zeroshot_weights
        logits = torch.squeeze(logits, 0) # (N, num_classes)
        # logits = torch.stack([torch.max(logits[:, s], axis = 1).values for s in slices]).T

        # norm_logits = (logits - logits.mean()) / (logits.std())
        norm_logits = (logits - logits.mean(axis = 1).unsqueeze(1)) / (logits.std(axis = 1).unsqueeze(1))
        probs = torch.sigmoid(norm_logits)
        # probs = torch.sigmoid(logits)
        # break

        # binary_result = torch.zeros_like(probs)
        # binary_result[np.arange(logits.shape[0]), torch.argmax(probs, axis = 1)] = 1

        y_pred.append(probs)
        # all_binary_results.append(binary_result)
        all_true_labels.append(label)

100%|██████████| 207/207 [00:05<00:00, 36.08it/s]


In [59]:
y_pred = torch.cat(y_pred, dim=0)
all_true_labels = torch.cat(all_true_labels, dim=0)

In [60]:
y_pred, all_true_labels

(tensor([[0.4923, 0.6153, 0.5409,  ..., 0.7455, 0.7085, 0.1429],
         [0.7836, 0.6682, 0.6465,  ..., 0.5407, 0.4025, 0.1554],
         [0.3964, 0.2589, 0.2234,  ..., 0.7774, 0.7741, 0.4875],
         ...,
         [0.4070, 0.2288, 0.1957,  ..., 0.7310, 0.6712, 0.6879],
         [0.3867, 0.2291, 0.2027,  ..., 0.6376, 0.6541, 0.6963],
         [0.5459, 0.2012, 0.1851,  ..., 0.6903, 0.6770, 0.6082]]),
 tensor([[0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         ...,
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0.]]))

In [61]:
log = []
for i in range(6):
    log.append(roc_auc_score(all_true_labels[:, i], y_pred[:, i]))
np.mean(log), log

(np.float64(0.8006971951127358),
 [np.float64(0.6259610227069551),
  np.float64(0.8560195831169795),
  np.float64(0.8749477206189878),
  np.float64(0.7997842170160295),
  np.float64(0.7418257418257419),
  np.float64(0.9056448853917209)])

## subfeatures sem normal

In [62]:
text_1davb = 'bloqueio atrioventricular de primeiro grau'
text_rbbb = 'bloqueio de ramo direito'
text_lbbb = 'bloqueio de ramo esquerdo'
text_sb = 'bradicardia'
text_af = 'fibrilacao atrial'
text_st = 'taquicardia'

texts = [text_1davb, text_rbbb, text_lbbb, text_sb, text_af, text_st, ]
num_classes = len(texts)

In [63]:
with torch.no_grad():
    # text synthesis
    zeroshot_weights = []
    for text in tqdm(texts):
        text = model._tokenize([text.lower()])

        class_embeddings = model.get_text_emb(text.input_ids.to(device=device), text.attention_mask.to(device=device)) # embed with text encoder
        class_embeddings = model.proj_t(class_embeddings) # embed with text encoder

        # normalize class_embeddings
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        # average over templates 
        class_embedding = class_embeddings.mean(dim=0) 
        # norm over new averaged templates
        class_embedding /= class_embedding.norm()

        zeroshot_weights.append(class_embedding)
    zeroshot_weights = torch.stack(zeroshot_weights, dim=1)

100%|██████████| 6/6 [00:00<00:00,  6.01it/s]


In [64]:
y_pred = []
all_binary_results = []
all_true_labels = []
with torch.no_grad():
    for data in tqdm(tst_loader):
        # read
        # report = data['raw_text']
        ecg = data['ecg'].to(torch.float32).to(device).contiguous()
        label = data['label'].float().to(device)
        exam_id = data['exam_id'].to(device)

        # predict
        ecg_emb = model.ext_ecg_emb(ecg)
        ecg_emb /= ecg_emb.norm(dim=-1, keepdim=True)

        # obtain logits (cos similarity)
        logits = ecg_emb @ zeroshot_weights
        logits = torch.squeeze(logits, 0) # (N, num_classes)
        # logits = torch.stack([torch.max(logits[:, s], axis = 1).values for s in slices]).T

        # norm_logits = (logits - logits.mean()) / (logits.std())
        norm_logits = (logits - logits.mean(axis = 1).unsqueeze(1)) / (logits.std(axis = 1).unsqueeze(1))
        probs = torch.sigmoid(norm_logits)
        # probs = torch.sigmoid(logits)
        # break

        # binary_result = torch.zeros_like(probs)
        # binary_result[np.arange(logits.shape[0]), torch.argmax(probs, axis = 1)] = 1

        y_pred.append(probs)
        # all_binary_results.append(binary_result)
        all_true_labels.append(label)

100%|██████████| 207/207 [00:05<00:00, 34.97it/s]


In [65]:
y_pred = torch.cat(y_pred, dim=0)
all_true_labels = torch.cat(all_true_labels, dim=0)

In [66]:
log = []
for i in range(6):
    log.append(roc_auc_score(all_true_labels[:, i], y_pred[:, i]))
np.mean(log), log

(np.float64(0.7833614639218197),
 [np.float64(0.618138744859646),
  np.float64(0.8715970625324532),
  np.float64(0.8800920117105814),
  np.float64(0.7452990135635018),
  np.float64(0.7067662067662068),
  np.float64(0.878275744098529)])

# label vs 6x

In [9]:
text_1davb = 'bloqueio atrioventricular de primeiro grau'
text_rbbb = 'bloqueio de ramo direito'
text_lbbb = 'bloqueio de ramo esquerdo'
text_sb = 'bradicardia'
text_af = 'fibrilacao atrial'
text_st = 'taquicardia'

texts = [text_1davb, text_rbbb, text_lbbb, text_sb, text_af, text_st, ]

In [12]:
def run6x(text_label, ):
    text_normal = 'ecg dentro dos limites da normalidade'
    texts = [text_label, text_normal]
    num_classes = len(texts)
    
    with torch.no_grad():
        # text synthesis
        zeroshot_weights = []
        for text in tqdm(texts):
            text = model._tokenize([text.lower()])

            class_embeddings = model.get_text_emb(text.input_ids.to(device=device), text.attention_mask.to(device=device)) # embed with text encoder
            class_embeddings = model.proj_t(class_embeddings) # embed with text encoder

            # normalize class_embeddings
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            # average over templates 
            class_embedding = class_embeddings.mean(dim=0) 
            # norm over new averaged templates
            class_embedding /= class_embedding.norm()

            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1)
    
    y_pred = []
    all_true_labels = []
    with torch.no_grad():
        for data in tqdm(tst_loader):
            # read
            # report = data['raw_text']
            ecg = data['ecg'].to(torch.float32).to(device).contiguous()
            label = data['label'].float().to(device)
            exam_id = data['exam_id'].to(device)

            # predict
            ecg_emb = model.ext_ecg_emb(ecg)
            ecg_emb /= ecg_emb.norm(dim=-1, keepdim=True)

            # obtain logits (cos similarity)
            logits = ecg_emb @ zeroshot_weights
            logits = torch.squeeze(logits, 0) # (N, num_classes)
            # logits = torch.stack([torch.max(logits[:, s], axis = 1).values for s in slices]).T

            # norm_logits = (logits - logits.mean()) / (logits.std())
            norm_logits = (logits - logits.mean(axis = 1).unsqueeze(1)) / (logits.std(axis = 1).unsqueeze(1))
            probs = torch.sigmoid(norm_logits)
            # probs = torch.sigmoid(logits)
            # break

            # binary_result = torch.zeros_like(probs)
            # binary_result[np.arange(logits.shape[0]), torch.argmax(probs, axis = 1)] = 1

            y_pred.append(probs)
            # all_binary_results.append(binary_result)
            all_true_labels.append(label)
    
    y_pred = torch.cat(y_pred, dim=0)
    all_true_labels = torch.cat(all_true_labels, dim=0)
    
    return roc_auc_score(all_true_labels[:, 0], y_pred[:, 0])

In [13]:
log = []
for text_label in texts:
    auc = run6x(text_label)
    log.append(auc)

100%|██████████| 2/2 [00:00<00:00,  5.72it/s]
100%|██████████| 207/207 [00:05<00:00, 35.81it/s]
100%|██████████| 2/2 [00:00<00:00,  6.54it/s]
100%|██████████| 207/207 [00:05<00:00, 34.94it/s]
100%|██████████| 2/2 [00:00<00:00,  6.10it/s]
100%|██████████| 207/207 [00:06<00:00, 32.96it/s]
100%|██████████| 2/2 [00:00<00:00,  6.51it/s]
100%|██████████| 207/207 [00:05<00:00, 35.75it/s]
100%|██████████| 2/2 [00:00<00:00,  6.45it/s]
100%|██████████| 207/207 [00:05<00:00, 35.75it/s]
100%|██████████| 2/2 [00:00<00:00,  6.25it/s]
100%|██████████| 207/207 [00:05<00:00, 35.12it/s]


In [14]:
np.mean(log), log

(np.float64(0.5671374932951905),
 [np.float64(0.5427319864115858),
  np.float64(0.5818210262828535),
  np.float64(0.5807482567495085),
  np.float64(0.6614071160379045),
  np.float64(0.547939388521366),
  np.float64(0.4881771857679242)])