In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (GPT2Config, GPT2TokenizerFast,
                          GPT2LMHeadModel, PretrainedConfig, EncoderDecoderModel)
from transformers.modeling_outputs import BaseModelOutput

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import sys
sys.path.append("../utils")
import utils_dataset
import yaml


In [3]:
data_path = '/data/chenjian/ECG_MM/pretrain_data/'
dataset_name = 'mimic'
dataset = utils_dataset.ECG_TEXT_Dsataset(
        data_path=data_path, dataset_name=dataset_name)
train_dataset = dataset.get_dataset(train_test='train')
val_dataset = dataset.get_dataset(train_test='val')

Load mimic dataset!
train size: 756259
val size: 15434
total size: 771693
Apply Train-stage Transform!
train dataset length:  756259
Apply Val-stage Transform!
val dataset length:  15434


In [4]:
print(val_dataset[4]['raw_text'])
print(val_dataset[1]['raw_text'])
print(val_dataset[2]['raw_text'])

sinus rhythm. normal ecg.
sinus bradycardia. prolonged qt interval. borderline ecg.
sinus rhythm.. inferior t wave changes are nonspecific. borderline ecg.


## Load Model

In [5]:
from models.model import ECGCLIP

ckpt_path = '/home/chenjian/multi-modal_ECG/merl/MERL/zeroshot/78.72/resnet_mix_sep_bestZeroShotAll_ckpt.pth'
ckpt = torch.load(ckpt_path, map_location='cpu')
config = yaml.load(open("/home/chenjian/multi-modal_ECG/merl/MERL/finetune/config.yaml", "r"), Loader=yaml.FullLoader)
encoder = ECGCLIP(config['network'])
encoder.load_state_dict(ckpt, strict=True)

<All keys matched successfully>

In [6]:
encoder.lm_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

In [7]:
decoder_path = '/home/chenjian/multi-modal_ECG/distilgpt2'

In [8]:
from model_gpt2 import ERGPT2

model = ERGPT2(encoder=encoder, decoder_path=decoder_path)

Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at /home/chenjian/multi-modal_ECG/distilgpt2 and are newly initialized: ['transformer.h.0.crossattention.c_attn.bias', 'transformer.h.0.crossattention.c_attn.weight', 'transformer.h.0.crossattention.c_proj.bias', 'transformer.h.0.crossattention.c_proj.weight', 'transformer.h.0.crossattention.q_attn.bias', 'transformer.h.0.crossattention.q_attn.weight', 'transformer.h.0.ln_cross_attn.bias', 'transformer.h.0.ln_cross_attn.weight', 'transformer.h.1.crossattention.c_attn.bias', 'transformer.h.1.crossattention.c_attn.weight', 'transformer.h.1.crossattention.c_proj.bias', 'transformer.h.1.crossattention.c_proj.weight', 'transformer.h.1.crossattention.q_attn.bias', 'transformer.h.1.crossattention.q_attn.weight', 'transformer.h.1.ln_cross_attn.bias', 'transformer.h.1.ln_cross_attn.weight', 'transformer.h.2.crossattention.c_attn.bias', 'transformer.h.2.crossattention.c_attn.weight', 'transformer.h.2.crossattention.c_

Load pretrained distillgpt2
Description, Special token, Index
bos_token, [BOS], 50257
eos_token, <|endoftext|>, 50256
unk_token, <|endoftext|>, 50256
pad_token, [PAD], 50258


In [9]:
from torch.utils.data.dataloader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=16,
                                  num_workers=1,
                                  drop_last=True, shuffle=False,
                                  )
        
val_loader = DataLoader(val_dataset, batch_size=32,
                        num_workers=1,
                        drop_last=True, shuffle=False,
                        )

In [10]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [11]:
from tqdm import tqdm

In [12]:
model.to('cpu')
cpkt = torch.load('/data/chenjian/ECG_MM/report/ckpt/DisGPT2_0_ckpt.pth', map_location='cpu')
# cpkt = torch.load('/data/chenjian/ECG_MM/report/merl_disgpt2.pth', map_location='cpu')

model.load_state_dict(cpkt['model_state_dict'], strict=True)
model.to('cuda:0')

ERGPT2(
  (encoder): ECGCLIP(
    (downconv): Conv1d(512, 256, kernel_size=(1,), stride=(3,))
    (att_pool_head): AttentionPool2d(
      (mhsa): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (c_proj): Linear(in_features=256, out_features=256, bias=True)
    )
    (linear1): AttentionPool2d(
      (mhsa): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (c_proj): Linear(in_features=256, out_features=256, bias=True)
    )
    (linear2): AttentionPool2d(
      (mhsa): MultiheadAttention(
        (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
      )
      (c_proj): Linear(in_features=256, out_features=256, bias=True)
    )
    (decode_t): Transformer(
      (to_patch): Mlp(
        (fc1): Linear(in_features=256, out_features=256, bias=True)
        (act): SiLU()
        (fc2): Li

In [45]:
cpkt['model_state_dict'].keys()

odict_keys(['encoder.downconv.weight', 'encoder.downconv.bias', 'encoder.att_pool_head.positional_embedding', 'encoder.att_pool_head.sep_embedding', 'encoder.att_pool_head.mhsa.in_proj_weight', 'encoder.att_pool_head.mhsa.in_proj_bias', 'encoder.att_pool_head.mhsa.out_proj.weight', 'encoder.att_pool_head.mhsa.out_proj.bias', 'encoder.att_pool_head.c_proj.weight', 'encoder.att_pool_head.c_proj.bias', 'encoder.linear1.positional_embedding', 'encoder.linear1.sep_embedding', 'encoder.linear1.mhsa.in_proj_weight', 'encoder.linear1.mhsa.in_proj_bias', 'encoder.linear1.mhsa.out_proj.weight', 'encoder.linear1.mhsa.out_proj.bias', 'encoder.linear1.c_proj.weight', 'encoder.linear1.c_proj.bias', 'encoder.linear2.positional_embedding', 'encoder.linear2.sep_embedding', 'encoder.linear2.mhsa.in_proj_weight', 'encoder.linear2.mhsa.in_proj_bias', 'encoder.linear2.mhsa.out_proj.weight', 'encoder.linear2.mhsa.out_proj.bias', 'encoder.linear2.c_proj.weight', 'encoder.linear2.c_proj.bias', 'encoder.decode

In [13]:
import yaml as yaml
import sys
sys.path.append("../finetune/")
import torch
import torch.nn.functional as F
from tqdm import tqdm

prompt_type = 'CKEPE'
prompt_dict = '/home/chenjian/multi-modal_ECG/merl/MERL/zeroshot/CKEPE_prompt.json'
with open(prompt_dict, 'r') as f:
    prompt_dict = yaml.load(f, Loader=yaml.FullLoader)
target_class = [class_name for class_name in prompt_dict.values()]

def get_class_emd(model, class_name, device='cuda'):
    model.eval()
    with torch.no_grad(): # to(device=torch.device("cuda"iftorch.cuda.is_available()else"cpu")) 
        zeroshot_weights = []
        # compute embedding through model for each class
        for texts in tqdm(class_name):
            texts = texts.lower()
            texts = [texts] # convert to list
            texts = model._tokenize(texts) # tokenize
            class_embeddings, _ = model.get_text_emb(texts.input_ids.to(device=device)
                                                            , texts.attention_mask.to(device=device)
                                                            ) # embed with text encoder
            class_embeddings = model.proj_t(class_embeddings) # embed with text encoder

            # normalize class_embeddings
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            # average over templates 
            class_embedding = class_embeddings.mean(dim=0) 
            # norm over new averaged templates
            class_embedding /= class_embedding.norm() 
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1)
    return zeroshot_weights


import numpy as np
def get_ground_truth(model, reports, class_weight, device='cuda'):
    model.eval()
    y_pred = []
    with torch.no_grad():
        
        report_tokenize_output = model._tokenize(reports)
        input_ids = report_tokenize_output.input_ids.to(
            device).contiguous()
        attention_mask = report_tokenize_output.attention_mask.to(
            device).contiguous()
        class_embeddings, _ = model.get_text_emb(input_ids.to(device=device)
                                                        , attention_mask.to(device=device)
                                                        ) # embed with text encoder
        class_embeddings = model.proj_t(class_embeddings) # embed with text encoder

        # normalize class_embeddings
        class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
        logits = class_embeddings @ class_weight.to(device)
        logits = torch.squeeze(logits, 0) # (N, num_classes)
        # norm_logits = (logits - logits.mean()) / (logits.std())
        # logits = torch.sigmoid(norm_logits) 
            
        y_pred.append(logits.cpu().data.numpy())

    y_pred = np.concatenate(y_pred, axis=0)
    labels = np.array(y_pred)
    labels = np.argmax(labels, axis=1)
    return labels, y_pred

In [14]:
class_emb = get_class_emd(model=encoder.to('cuda:0'), class_name=target_class, device='cuda:0')

100%|██████████| 131/131 [00:00<00:00, 184.09it/s]


In [48]:
class_emb.shape

torch.Size([256, 131])

In [15]:
labels_all = []
reports_all = []
outputs_all = []
labels_pred_all = []
with torch.no_grad():
    for data in tqdm(val_loader):
        report = data['raw_text']#.to(device)
        # get ecg
        ecg = data['ecg'].to(torch.float32).contiguous().to('cuda:0')
        encoder_outputs = model.encoder_forward(ecg)
        encoder_outputs = BaseModelOutput(
                last_hidden_state=encoder_outputs)
        outputs = model.decoder.encoder_decoder.generate(
                max_length=256,
                bos_token_id=model.tokenizer.bos_token_id,
                eos_token_id=model.tokenizer.eos_token_id,
                pad_token_id=model.tokenizer.pad_token_id,
                num_beams=3,
                return_dict_in_generate=True,
                use_cache=True, 
                encoder_outputs=encoder_outputs,
            )
        output = model.tokenizer.batch_decode(
            outputs['sequences'], skip_special_tokens=True)
        reports_all.append(report)
        outputs_all.append(output)
        label, logits = get_ground_truth(model=model.encoder, reports=report, class_weight=class_emb, device='cuda:0')
        label_pred, logits_pred = get_ground_truth(model=model.encoder, reports=output, class_weight=class_emb, device='cuda:0')
        labels_all.append(label)
        labels_pred_all.append(label_pred)

100%|██████████| 482/482 [13:58<00:00,  1.74s/it]


In [16]:
reports_all = [i for item in reports_all for i in item]
outputs_all = [i for item in outputs_all for i in item]

In [17]:
reports_all[0]

'possible ectopic atrial rhythm.. left axis deviation. right bundle branch block. inferior infarct - age undetermined. abnormal ecg.'

In [18]:
outputs_all[0]

'sinus rhythm. left axis deviation. rbbb with left anterior fascicular block. inferior infarct - age undetermined. abnormal ecg.'

In [19]:
labels_a = np.hstack(labels_all)
labels_p = np.hstack(labels_pred_all)

In [20]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer



# 初始化 BLEU 和 ROUGE scorer
smooth = SmoothingFunction().method1
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

# 存储 BLEU 和 ROUGE 分数
bleu_scores = {'bleu1': [], 'bleu2': [], 'bleu3': [], 'bleu4': []}
rougeL_scores = []

# 逐一计算每个样本的 BLEU 和 ROUGE 分数
for ref, gen in zip(reports_all, outputs_all):
    # 计算 BLEU-1 分数
    bleu1 = sentence_bleu([ref.split()], gen.split(), weights=(1, 0, 0, 0), smoothing_function=smooth)
    bleu_scores['bleu1'].append(bleu1)
    
    # 计算 BLEU-2 分数
    bleu2 = sentence_bleu([ref.split()], gen.split(), weights=(0.5, 0.5, 0, 0), smoothing_function=smooth)
    bleu_scores['bleu2'].append(bleu2)
    
    # 计算 BLEU-3 分数
    bleu3 = sentence_bleu([ref.split()], gen.split(), weights=(0.33, 0.33, 0.33, 0), smoothing_function=smooth)
    bleu_scores['bleu3'].append(bleu3)
    
    # 计算 BLEU-4 分数
    bleu4 = sentence_bleu([ref.split()], gen.split(), weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smooth)
    bleu_scores['bleu4'].append(bleu4)
    
    # 计算 ROUGE-L 分数
    rougeL = scorer.score(ref, gen)['rougeL'].fmeasure
    rougeL_scores.append(rougeL)

# 计算平均值
average_bleu1 = sum(bleu_scores['bleu1']) / len(bleu_scores['bleu1'])
average_bleu2 = sum(bleu_scores['bleu2']) / len(bleu_scores['bleu2'])
average_bleu3 = sum(bleu_scores['bleu3']) / len(bleu_scores['bleu3'])
average_bleu4 = sum(bleu_scores['bleu4']) / len(bleu_scores['bleu4'])
average_rougeL = sum(rougeL_scores) / len(rougeL_scores)

# 打印平均结果
print(f"Average BLEU-1 Score: {average_bleu1:.4f}")
print(f"Average BLEU-2 Score: {average_bleu2:.4f}")
print(f"Average BLEU-3 Score: {average_bleu3:.4f}")
print(f"Average BLEU-4 Score: {average_bleu4:.4f}")
print(f"Average ROUGE-L F1 Score: {average_rougeL:.4f}")

Average BLEU-1 Score: 0.6133
Average BLEU-2 Score: 0.5598
Average BLEU-3 Score: 0.5157
Average BLEU-4 Score: 0.4850
Average ROUGE-L F1 Score: 0.7059


In [21]:
from sklearn.metrics import f1_score, precision_score, recall_score

f1 = f1_score(labels_a, labels_p, average='macro')
pre = precision_score(labels_a, labels_p, average='macro')
rec = recall_score(labels_a, labels_p, average='macro')

print('f1:', f1)
print('pre:', pre)
print('rec:', rec)

f1: 0.23329257874351925
pre: 0.2495743946337152
rec: 0.23694279461925052


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [22]:
metrics = {'BLEU-1':average_bleu1,
        'BLEU-2':average_bleu2,
        'BLEU-3':average_bleu3,
        'BLEU-4':average_bleu4,
        'ROUGE-L F1 Score':average_rougeL,
        'CE F1 Score': f1,
        'CE Precision Score': pre,
        'CE Recall Score': rec,
        'report':reports_all,
        'generated report': outputs_all
        }

torch.save({
    'metrics': metrics},
    f'/data/chenjian/ECG_MM/report/ckpt/disGPT2_align_metrics.pth'     
    )