# setup

In [1]:
import random
# from torch.nn.parallel import DistributedDataParallel as DDP
# import torch.multiprocessing as mp
# import torch.distributed as dist
import tempfile
import os
from torch import optim
import torch.nn as nn
import pandas as pd
import numpy as np
import torch
import yaml
import sys
sys.path.append("../utils")
# from utils_trainer import trainer_wBert
from trainer import trainer_wBert
# import utils_dataset
from dataset import ECG_TEXT_Dsataset
# import utils_builder
from builder import ECGCLIP

# import wandb

In [2]:
# os.environ["TOKENIZERS_PARALLELISM"] = "true"

# main

In [3]:
# dist.init_process_group("nccl")
torch.cuda.empty_cache()
# rank = dist.get_rank()

# device_id = torch.cuda.device_count()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
config = yaml.load(open("config.yaml", "r"), Loader=yaml.FullLoader)

## wandb

In [5]:
project="MERL_ICML",
name = config['wandb_name'],
# Track hyperparameters and run metadata
config_param={
        "learning_rate": config['optimizer']['params']['lr'],
        "total_epochs": config['trainer']['max_epochs'],
        'weight_decay': config['optimizer']['params']['weight_decay'],
        'ecg_model': config['network']['ecg_model'],
        'text_model': config['network']['text_model'],
        'batch_size': config['trainer']['batch_size'],
        'val_zeroshot': 'all_sets',
        'prompt_type': config['zeroshot']['prompt_type'],
}

In [6]:
# run = wandb.init(project = project, name = name, config = config_param)

In [7]:
name, config_param['learning_rate'], config_param['total_epochs'], config_param['weight_decay'], config_param['ecg_model']

(('vit_tiny_demo',), 0.001, 20, 1e-08, 'resnet18')

In [8]:
config_param['text_model'], config_param['batch_size'], config_param['val_zeroshot'], config_param['prompt_type']

('pucpr/biobertpt-all', 128, 'all_sets', 'CKEPE')

In [9]:
torch.manual_seed(42)
random.seed(0)
np.random.seed(0)

## utils dataset

In [10]:
data_path = config['dataset']['data_path']
data_path

'\\Users\\katri\\Downloads\\git\\lesaude\\code\\CODEmel'

In [11]:
dataset = ECG_TEXT_Dsataset(
    data_path=data_path, dataset_name=config['dataset']['dataset_name'])
train_dataset = dataset.get_dataset(train_test='train')
val_dataset = dataset.get_dataset(train_test='val')

Load CODEmel dataset!
train size: 35995
val size: 2006
tst size: 1999
total size: 40000
Apply Train-stage Transform!
train dataset length:  35995
Apply Val-stage Transform!
val dataset length:  2006


## utils builder

In [12]:
model = ECGCLIP(config['network'])

In [13]:
if config['network']['free_layers'] is not None:
    for layer_idx in range(int(config['network']['free_layers'])):
        for param in list(model.lm_model.encoder.layer[layer_idx].parameters()):
            param.requires_grad = False

In [14]:
# model = model.to(device_id)
# model = DDP(model, device_ids=[device_id], find_unused_parameters=True)

## utils trainer

In [15]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    **config['optimizer']['params'],
    betas=(0.9, 0.999)
)

In [16]:
trainer = trainer_wBert(model=model, 
                        optimizer=optimizer,
                        device=device,
                        model_name=config['wandb_name'],
                        **config['trainer'])

In [17]:
# trainer.train_w_TextEmb(train_dataset, val_dataset, config['zeroshot'])

# trainer

In [18]:
from torch.utils.data.dataloader import DataLoader
from torch.cuda.amp import autocast as autocast
from torch.cuda.amp import GradScaler as GradScaler
from tqdm import tqdm

from utils_loss import clip_loss

In [19]:
train_loader = DataLoader(train_dataset, batch_size=trainer.train_batch_size,
                            # num_workers=trainer.num_workers,
                            drop_last=True, shuffle=False,
                        #   sampler=DistributedSampler(train_dataset), 
                            )

val_loader = DataLoader(val_dataset, batch_size=trainer.val_batch_size,
                        # num_workers=trainer.num_workers,
                        drop_last=True, shuffle=False,
                        # sampler=DistributedSampler(val_dataset), 
                        )

In [20]:
model_checkpoints_folder = os.path.join('../checkpoints/')
# if self.device == 0:
if not os.path.exists(model_checkpoints_folder):
    print('create directory "{}" for save checkpoint!'.format(
        model_checkpoints_folder))
    print('---------------------------')
    os.makedirs(model_checkpoints_folder)
else:
    print('directory "{}" existing for save checkpoint!'.format(
        model_checkpoints_folder))

directory "../checkpoints/" existing for save checkpoint!


In [21]:
# automatically resume from checkpoint if it exists
print('#########################################')
print('Be patient..., checking checkpoint now...')
if os.path.exists(model_checkpoints_folder + trainer.model_name+'_checkpoint.pth'):
    ckpt = torch.load(model_checkpoints_folder + trainer.model_name+'_checkpoint.pth',
                        map_location='cpu')
    start_epoch = ckpt['epoch']
    trainer.model.load_state_dict(ckpt['model_state_dict'])
    trainer.optimizer.load_state_dict(ckpt['optimizer_state_dict'])
    print('continue training successful!')
else:
    start_epoch = 0
    print('Start training from 0 epoch')

print('#########################################')
print('training start!')

#########################################
Be patient..., checking checkpoint now...
Start training from 0 epoch
#########################################
training start!


In [22]:
# scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    trainer.optimizer,
    T_0=5000,
    T_mult=1,
    eta_min=1e-8,
)
niter = 1

In [23]:
skip_scheduler = False
scaler = GradScaler()

  scaler = GradScaler()


In [24]:
f1_total = []
acc_total = []
auc_total = []
zeroshot_csv = pd.DataFrame()
best_auc = 0
for epoch_counter in (range(start_epoch, trainer.max_epochs)):
    break

In [25]:
epoch_loss = 0
epoch_acc1 = []
epoch_acc5 = []
# self.model.train()
for data in tqdm(train_loader):
    break

  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/281 [00:00<?, ?it/s]


In [26]:
trainer.model.train()
# get raw text
report = data['raw_text']
# get ecg
ecg = data['ecg'].to(torch.float32).to(
    trainer.device).contiguous()
trainer.optimizer.zero_grad()

In [28]:
report_tokenize_output = trainer.model._tokenize(report)

input_ids = report_tokenize_output.input_ids.to(
    trainer.device).contiguous()
attention_mask = report_tokenize_output.attention_mask.to(
    trainer.device).contiguous()

output_dict = trainer.model(ecg, input_ids, attention_mask)

In [44]:
agg_proj_img_emb = output_dict['proj_ecg_emb']
agg_proj_text_emb = output_dict['proj_text_emb']
agg_proj_ecg_emb1, agg_proj_ecg_emb2 = output_dict['ecg_emb']

In [51]:
agg_proj_img_emb = torch.cat(agg_proj_img_emb, dim=0)
agg_proj_text_emb = torch.cat(agg_proj_text_emb, dim=0)

In [52]:
cma_loss, acc1, acc5 = clip_loss(agg_proj_img_emb, agg_proj_text_emb, device=trainer.device)
uma_loss, _, _ = clip_loss(agg_proj_ecg_emb1, agg_proj_ecg_emb2, device=trainer.device)
loss = cma_loss + uma_loss

In [56]:
print(f'loss is {loss.item()}, acc1 is {acc1.item()}, acc5 is {acc5.item()}, cma_loss is {cma_loss.item()}, uma_loss is {uma_loss.item()}')

loss is 20.105636596679688, acc1 is 0.78125, acc5 is 3.515625, cma_loss is 10.219921112060547, uma_loss is 9.885714530944824


In [58]:
epoch_loss += loss.item()
epoch_acc1.append(acc1.item())
epoch_acc5.append(acc5.item())

scaler.scale(loss).backward()
scaler.step(trainer.optimizer)
scaler.update()

In [59]:
if not skip_scheduler:
    scheduler.step()
niter += 1

In [None]:
# val_log = trainer.val(val_loader)