In [1]:
import os
import pickle
from tqdm import tqdm
import torch
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTPretrainEHRDataset, batcher
from torch.utils.data import DataLoader
from HEART_rep import HBERT_Pretrain
from set_seed_utils import set_random_seed

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [None]:
# Hyper-parameters for the pretraining run. Later cells add derived entries
# (vocabulary sizes, special tokens, ...) to this same dictionary.
args = dict(
    seed=0,
    dataset="MIMIC-III",  # MIMIC-III, MIMIC-IV
    batch_size=32,
    lr=2e-5,
    epochs=30,
    encoder="hi_edge",
    mask_rate=0.7,
    anomaly_rate=0.05,
    anomaly_loss_weight=1,
    num_hidden_layers=5,
    num_attention_heads=6,
    attention_probs_dropout_prob=0.2,
    hidden_dropout_prob=0.2,
    edge_hidden_size=32,
    hidden_size=288,  # must be divisible by num_attention_heads
    intermediate_size=288,
    gnn_n_heads=1,
    gnn_temp=1,
    gat="dotattn",  # dotattn, None
    diag_med_emb="tree",  # simple, tree
)

In [4]:
# Settings specific to the MIMIC datasets (the only datasets used here).
args.update({
    "max_visit_size": 15,
    "predicted_token_type": ["diag", "med", "pro", "lab"],
    # {token_type: masked_id}
    "mask_token_id": {"diag": 3, "med": 4, "pro": 5, "lab": 6},
    # "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]" are used for the masking
    # pretraining task; codes that are actually masked are not in the input
    # sequence.
    # NOTE(review): ids 3-6 above match the positions of the four [MASK*]
    # entries in this tuple — confirm the tokenizer assigns ids in this order.
    "special_tokens": ("[PAD]", "[CLS]", "[SEP]",
                       "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]"),
})

In [5]:
# Both pickles live in the per-dataset "<dataset>-processed" directory.
processed_dir = f"/home/lideyi/HeteroGT-cuda/data_process/{args['dataset']}-processed"
full_data_path = f"{processed_dir}/mimic.pkl"
pretrain_data_path = f"{processed_dir}/mimic_pretrain.pkl"  # for pretraining

In [6]:
# Load the full EHR table. The context manager closes the file handle as soon
# as the pickle has been read (the bare open() call left it to the GC).
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(full_data_path, "rb") as full_data_file:
    ehr_data = pickle.load(full_data_file)

In [7]:
# Per-row code lists for each entity type, pulled from the full EHR table.
# (Original author's note: lab was cut in 5 percentiles.)
diag_sentences, med_sentences, lab_sentences, pro_sentences = (
    ehr_data[column].values.tolist()
    for column in ("ICD9_CODE", "NDC", "LAB_TEST", "PRO_CODE")
)

# Demographic "sentences": one single-token list per distinct value.
unique_ages = set(ehr_data["AGE"].values.tolist())
gender_set = [["M"], ["F"]]
age_set = [[age] for age in unique_ages]
age_gender_set = [
    [str(age) + "_" + gender]
    for age in unique_ages
    for gender in ["M", "F"]
]

In [8]:
# The tokenizer is built from the sentences of the full dataset
# (not only the pretraining split).
tokenizer = EHRTokenizer(
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    gender_set,
    age_set,
    age_gender_set,
    special_tokens=args["special_tokens"],
)
# NOTE(review): build_tree() is defined in token_utils_rep; presumably it
# prepares the code hierarchy used by the "tree" embedding — confirm there.
tokenizer.build_tree()

In [9]:
# Load the pretraining split. The context manager closes the file handle as
# soon as the pickle has been read (the bare open() call left it to the GC).
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(pretrain_data_path, "rb") as pretrain_data_file:
    ehr_pretrain_data = pickle.load(pretrain_data_file)

# Wrap the split in the pretraining dataset, configured with the token types
# to predict and the masking / anomaly rates from args.
pretrain_dataset = HBERTPretrainEHRDataset(
    ehr_pretrain_data,
    tokenizer,
    token_type=args["predicted_token_type"],
    mask_rate=args["mask_rate"],
    anomaly_rate=args["anomaly_rate"],
)

In [10]:
# Collate function that batches samples, using the id of "[PAD]" as the
# padding id and one slot per predicted token type.
pretrain_collate_fn = batcher(
    pad_id=tokenizer.vocab.word2id["[PAD]"],
    n_token_type=len(args["predicted_token_type"]),
    is_train=True,
)

# Shuffled mini-batch loader over the pretraining dataset.
pretrain_dataloader = DataLoader(
    pretrain_dataset,
    batch_size=args["batch_size"],
    collate_fn=pretrain_collate_fn,
    shuffle=True,
)

In [11]:
# Seed the random number generators before the model is created and trained.
# NOTE(review): this runs after the dataset and DataLoader cells; if their
# constructors draw random numbers, those draws are not covered by this seed —
# consider seeding right after the config cell.
seed_value = args["seed"]
set_random_seed(seed_value)

[INFO] Random seed set to 0


In [12]:
# Experiment name: "Pretrain-HBERT" followed by the listed hyper-parameters,
# in this exact order, joined with "-". It is also the checkpoint folder name.
exp_name_fields = [
    "dataset",
    "encoder",
    "mask_rate",
    "anomaly_rate",
    "anomaly_loss_weight",
    "hidden_size",
    "edge_hidden_size",
    "num_hidden_layers",
    "num_attention_heads",
    "attention_probs_dropout_prob",
    "hidden_dropout_prob",
    "intermediate_size",
    "gat",
    "gnn_n_heads",
    "gnn_temp",
    "diag_med_emb",
]
exp_name = "-".join(["Pretrain-HBERT"] + [str(args[field]) for field in exp_name_fields])
print(exp_name)

# exist_ok=True makes directory creation idempotent and avoids the
# check-then-create race of `if not os.path.exists(...): os.makedirs(...)`.
save_path = "./pretrained_models/" + exp_name
os.makedirs(save_path, exist_ok=True)

Pretrain-HBERT-MIMIC-III-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [13]:
# Total vocabulary size = special tokens + every per-entity vocabulary.
# The special-token count is derived from args["special_tokens"] instead of a
# hard-coded 7, so the two cannot drift apart if the token list changes.
args["vocab_size"] = len(args["special_tokens"]) + sum(
    len(voc.id2word)
    for voc in (
        tokenizer.diag_voc,
        tokenizer.pro_voc,
        tokenizer.med_voc,
        tokenizer.lab_voc,
        tokenizer.age_voc,
        tokenizer.gender_voc,
        tokenizer.age_gender_voc,
    )
)

# {token_type: vocab_size} for the four predicted entity types.
args["label_vocab_size"] = {
    token_type: len(getattr(tokenizer, token_type + "_voc").id2word)
    for token_type in ("diag", "pro", "med", "lab")
}

In [14]:
# Loss components tracked during training: the four masked-prediction token
# types plus the anomaly loss.
loss_entity = ["diag", "med", "pro", "lab"] + ["anomaly"]

In [15]:
# Build the pretraining model from the config and tokenizer, then place it on
# the selected device.
model = HBERT_Pretrain(args, tokenizer)
model = model.to(device)

In [16]:
# AdamW over all model parameters, using the learning rate from args.
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=args["lr"],
)

In [17]:
# Pretraining loop: one pass over pretrain_dataloader per epoch, logging the
# per-step losses on the progress bar and the per-epoch averages afterwards.
for epoch in range(1, 1 + args["epochs"]):
    train_iter = tqdm(pretrain_dataloader, ncols=140)
    model.train()
    ave_loss, ave_loss_dict = 0., {token_type: 0. for token_type in loss_entity}
    # Count batches explicitly instead of relying on the loop variable `step`
    # surviving the loop (it is undefined when the loader yields nothing).
    n_batches = 0
    perf_dict = None

    for step, batch in enumerate(train_iter):

        # Move tensor inputs to the training device; leave other items as-is.
        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        loss, loss_dict, perf_dict = model(*batch)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_iter.set_description(f"Epoch:{epoch: 03d}, Step:{step: 03d}, loss:{loss.item():.4f}, diag:{loss_dict['diag']:.4f}, med:{loss_dict['med']:.4f}, pro:{loss_dict['pro']:.4f}, lab:{loss_dict['lab']:.4f}, anomaly:{loss_dict['anomaly']:.4f}")

        # Accumulate running sums for the epoch averages.
        ave_loss += loss.item()
        ave_loss_dict = {token_type: ave_loss_dict[token_type] + loss_dict[token_type] for token_type in loss_entity}
        n_batches += 1

    if n_batches == 0:
        # Fail with a clear message rather than a NameError on `step`.
        raise RuntimeError("pretrain_dataloader yielded no batches; nothing to train on")

    ave_loss /= n_batches
    ave_loss_dict = {token_type: ave_loss_dict[token_type] / n_batches for token_type in loss_entity}
    # Note: perf_dict is the value returned for the LAST batch of the epoch
    # only; it is not averaged over the epoch.
    print(f"Epoch {epoch} finished, ave_loss: {ave_loss:.4f}, ave_loss_dict: {ave_loss_dict}, perf_dict: {perf_dict}")

Epoch: 01, Step: 723, loss:0.0757, diag:0.0160, med:0.0942, pro:0.0134, lab:0.0577, anomaly:0.1970: 100%|█| 724/724 [01:34<00:00,  7.70it/s]


Epoch 1 finished, ave_loss: 0.1581, ave_loss_dict: {'diag': 0.11248204618620362, 'med': 0.15466383776694373, 'pro': 0.11071781535730671, 'lab': 0.1528678211172559, 'anomaly': 0.25975676049679025}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 02, Step: 723, loss:0.0650, diag:0.0144, med:0.0704, pro:0.0136, lab:0.0670, anomaly:0.1594: 100%|█| 724/724 [01:30<00:00,  7.97it/s]


Epoch 2 finished, ave_loss: 0.0764, ave_loss_dict: {'diag': 0.015267550339407765, 'med': 0.09004634326498812, 'pro': 0.015465260615819314, 'lab': 0.06097104582029976, 'anomaly': 0.20001955103973}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 03, Step: 723, loss:0.0782, diag:0.0168, med:0.0986, pro:0.0150, lab:0.0858, anomaly:0.1750: 100%|█| 724/724 [01:32<00:00,  7.82it/s]


Epoch 3 finished, ave_loss: 0.0731, ave_loss_dict: {'diag': 0.014956107828766108, 'med': 0.08687735711178068, 'pro': 0.015141081811026644, 'lab': 0.05972440521237929, 'anomaly': 0.18855774653715324}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.006896551724137931, 'recall': 0.0013793103448275863, 'f1': 0.0022988505747126436}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 04, Step: 723, loss:0.0660, diag:0.0107, med:0.0660, pro:0.0123, lab:0.0490, anomaly:0.1923: 100%|█| 724/724 [01:31<00:00,  7.87it/s]


Epoch 4 finished, ave_loss: 0.0715, ave_loss_dict: {'diag': 0.014798660742026336, 'med': 0.08606460803697781, 'pro': 0.014856304800992018, 'lab': 0.05886952901639826, 'anomaly': 0.18271210421267794}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.014942528735632184, 'recall': 0.01206896551724138, 'f1': 0.01129720853858785}, 'pro': {'precision': 0.0006242197253433209, 'recall': 0.0012484394506866417, 'f1': 0.0008322929671244277}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 05, Step: 723, loss:0.0661, diag:0.0142, med:0.0939, pro:0.0111, lab:0.0542, anomaly:0.1571: 100%|█| 724/724 [01:33<00:00,  7.78it/s]


Epoch 5 finished, ave_loss: 0.0703, ave_loss_dict: {'diag': 0.014651653603830719, 'med': 0.08542342252453058, 'pro': 0.01469512845425563, 'lab': 0.058307130913666926, 'anomaly': 0.17861798315087735}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.005747126436781609, 'f1': 0.008045977011494251}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0012484394506866417, 'f1': 0.0012484394506866417}, 'lab': {'precision': 0.0006666666666666666, 'recall': 0.0006666666666666666, 'f1': 0.0006666666666666666}}


Epoch: 06, Step: 723, loss:0.0767, diag:0.0199, med:0.1215, pro:0.0196, lab:0.0517, anomaly:0.1709: 100%|█| 724/724 [01:33<00:00,  7.76it/s]


Epoch 6 finished, ave_loss: 0.0697, ave_loss_dict: {'diag': 0.014537819310024545, 'med': 0.08512509018388571, 'pro': 0.014550609213109668, 'lab': 0.05792662611961859, 'anomaly': 0.17614563850902062}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.0040229885057471255, 'f1': 0.006206896551724138}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 07, Step: 723, loss:0.0643, diag:0.0159, med:0.0752, pro:0.0132, lab:0.0593, anomaly:0.1579: 100%|█| 724/724 [01:33<00:00,  7.76it/s]


Epoch 7 finished, ave_loss: 0.0689, ave_loss_dict: {'diag': 0.014443706252833427, 'med': 0.08459998958471401, 'pro': 0.014415764473687384, 'lab': 0.0576234877582907, 'anomaly': 0.17360461438434888}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.004597701149425287, 'recall': 0.006896551724137931, 'f1': 0.005517241379310345}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.004, 'recall': 0.0022222222222222222, 'f1': 0.0027777777777777775}}


Epoch: 08, Step: 723, loss:0.0770, diag:0.0148, med:0.0699, pro:0.0145, lab:0.0551, anomaly:0.2307: 100%|█| 724/724 [01:33<00:00,  7.73it/s]


Epoch 8 finished, ave_loss: 0.0684, ave_loss_dict: {'diag': 0.014323662898312862, 'med': 0.08411637022679681, 'pro': 0.014273283022934395, 'lab': 0.05736295423516746, 'anomaly': 0.17190327192785332}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.017241379310344827, 'recall': 0.007126436781609195, 'f1': 0.009195402298850575}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 09, Step: 723, loss:0.0578, diag:0.0127, med:0.0562, pro:0.0190, lab:0.0543, anomaly:0.1468: 100%|█| 724/724 [01:33<00:00,  7.77it/s]


Epoch 9 finished, ave_loss: 0.0680, ave_loss_dict: {'diag': 0.014240992695443044, 'med': 0.0837791592644229, 'pro': 0.014087697154024194, 'lab': 0.057205762971418994, 'anomaly': 0.1707328200052127}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0013333333333333333, 'recall': 0.001, 'f1': 0.0011111111111111111}}


Epoch: 10, Step: 723, loss:0.0727, diag:0.0151, med:0.0900, pro:0.0164, lab:0.0705, anomaly:0.1714: 100%|█| 724/724 [01:32<00:00,  7.81it/s]


Epoch 10 finished, ave_loss: 0.0676, ave_loss_dict: {'diag': 0.014183250299618883, 'med': 0.08347073438685408, 'pro': 0.014034020253945945, 'lab': 0.05702906120980148, 'anomaly': 0.16943914690235043}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 11, Step: 723, loss:0.0596, diag:0.0124, med:0.0769, pro:0.0172, lab:0.0587, anomaly:0.1327: 100%|█| 724/724 [01:34<00:00,  7.69it/s]


Epoch 11 finished, ave_loss: 0.0673, ave_loss_dict: {'diag': 0.014120061426469798, 'med': 0.08293099327846457, 'pro': 0.01388064899625905, 'lab': 0.05684087433703038, 'anomaly': 0.16861048159678338}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.020689655172413793, 'recall': 0.007471264367816092, 'f1': 0.010804597701149426}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0006242197253433209, 'f1': 0.0008322929671244277}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 12, Step: 723, loss:0.0753, diag:0.0156, med:0.0967, pro:0.0219, lab:0.0634, anomaly:0.1788: 100%|█| 724/724 [01:33<00:00,  7.73it/s]


Epoch 12 finished, ave_loss: 0.0673, ave_loss_dict: {'diag': 0.014073552840353293, 'med': 0.08305190461919287, 'pro': 0.013837181627287308, 'lab': 0.05678412560579362, 'anomaly': 0.168591708088942}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.006896551724137931, 'recall': 0.0022988505747126436, 'f1': 0.0034482758620689655}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 13, Step: 723, loss:0.0761, diag:0.0127, med:0.0906, pro:0.0187, lab:0.0734, anomaly:0.1851: 100%|█| 724/724 [01:34<00:00,  7.65it/s]


Epoch 13 finished, ave_loss: 0.0671, ave_loss_dict: {'diag': 0.013970363645901832, 'med': 0.08268850101209477, 'pro': 0.013693866393276752, 'lab': 0.056655174409404645, 'anomaly': 0.16830610860530185}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.006896551724137931, 'recall': 0.0019704433497536944, 'f1': 0.003065134099616858}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 14, Step: 723, loss:0.0523, diag:0.0111, med:0.0592, pro:0.0101, lab:0.0571, anomaly:0.1242: 100%|█| 724/724 [01:34<00:00,  7.64it/s]


Epoch 14 finished, ave_loss: 0.0665, ave_loss_dict: {'diag': 0.013927109650135534, 'med': 0.08242943171329425, 'pro': 0.013612905148035996, 'lab': 0.05652021236964681, 'anomaly': 0.16605757898265155}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.017241379310344827, 'recall': 0.017241379310344827, 'f1': 0.016091954022988502}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0013333333333333333, 'recall': 0.0006111111111111111, 'f1': 0.0008000000000000001}}


Epoch: 15, Step: 723, loss:0.0756, diag:0.0140, med:0.0861, pro:0.0156, lab:0.0705, anomaly:0.1916: 100%|█| 724/724 [01:33<00:00,  7.78it/s]


Epoch 15 finished, ave_loss: 0.0665, ave_loss_dict: {'diag': 0.013884375214021015, 'med': 0.08203961447045947, 'pro': 0.013583794811184812, 'lab': 0.056450249755868265, 'anomaly': 0.16648342099184818}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.006896551724137931, 'recall': 0.004597701149425287, 'f1': 0.005517241379310345}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0012484394506866417, 'f1': 0.0012484394506866417}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 16, Step: 723, loss:0.0695, diag:0.0153, med:0.0741, pro:0.0112, lab:0.0624, anomaly:0.1843: 100%|█| 724/724 [01:32<00:00,  7.79it/s]


Epoch 16 finished, ave_loss: 0.0662, ave_loss_dict: {'diag': 0.013877364958621026, 'med': 0.082084017852064, 'pro': 0.013507677200996415, 'lab': 0.05638200839086132, 'anomaly': 0.16536668187923195}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.006896551724137931, 'recall': 0.0034482758620689655, 'f1': 0.004597701149425287}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 17, Step: 723, loss:0.0720, diag:0.0130, med:0.0877, pro:0.0107, lab:0.0358, anomaly:0.2130: 100%|█| 724/724 [01:33<00:00,  7.73it/s]


Epoch 17 finished, ave_loss: 0.0663, ave_loss_dict: {'diag': 0.013829181595957576, 'med': 0.08204156271495872, 'pro': 0.013429241620251157, 'lab': 0.05623178724302442, 'anomaly': 0.1657239043292749}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.004597701149425287, 'f1': 0.006896551724137931}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0006666666666666666, 'recall': 0.0002222222222222222, 'f1': 0.0003333333333333333}}


Epoch: 18, Step: 723, loss:0.0667, diag:0.0169, med:0.0930, pro:0.0134, lab:0.0528, anomaly:0.1575: 100%|█| 724/724 [01:32<00:00,  7.79it/s]


Epoch 18 finished, ave_loss: 0.0660, ave_loss_dict: {'diag': 0.01378227830144129, 'med': 0.08205449011032753, 'pro': 0.013358786070554707, 'lab': 0.05619133794834601, 'anomaly': 0.16437488608062267}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.010344827586206896, 'f1': 0.011494252873563218}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0013333333333333333, 'recall': 0.001, 'f1': 0.0011111111111111111}}


Epoch: 19, Step: 723, loss:0.0595, diag:0.0129, med:0.0742, pro:0.0100, lab:0.0486, anomaly:0.1517: 100%|█| 724/724 [01:33<00:00,  7.75it/s]


Epoch 19 finished, ave_loss: 0.0658, ave_loss_dict: {'diag': 0.013744977477793865, 'med': 0.0817814152687788, 'pro': 0.013364041253708261, 'lab': 0.05614563344221418, 'anomaly': 0.16388561394456672}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.006896551724137931, 'recall': 0.0034482758620689655, 'f1': 0.004597701149425287}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0012484394506866417, 'f1': 0.0012484394506866417}, 'lab': {'precision': 0.002, 'recall': 0.0016666666666666668, 'f1': 0.0017777777777777776}}


Epoch: 20, Step: 723, loss:0.0654, diag:0.0141, med:0.0548, pro:0.0146, lab:0.0688, anomaly:0.1747: 100%|█| 724/724 [01:33<00:00,  7.73it/s]


Epoch 20 finished, ave_loss: 0.0656, ave_loss_dict: {'diag': 0.013743727770248088, 'med': 0.0817843642631206, 'pro': 0.013293492318341992, 'lab': 0.05605707431304685, 'anomaly': 0.16318421989926318}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0006666666666666666, 'recall': 0.0003333333333333333, 'f1': 0.0004444444444444444}}


Epoch: 21, Step: 723, loss:0.0510, diag:0.0100, med:0.0690, pro:0.0075, lab:0.0427, anomaly:0.1257: 100%|█| 724/724 [01:32<00:00,  7.82it/s]


Epoch 21 finished, ave_loss: 0.0653, ave_loss_dict: {'diag': 0.01369192408917482, 'med': 0.08160109739256663, 'pro': 0.01327363658858494, 'lab': 0.05593640772492023, 'anomaly': 0.16216709116957465}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0016666666666666668, 'recall': 0.0012222222222222222, 'f1': 0.001222222222222222}}


Epoch: 22, Step: 723, loss:0.0668, diag:0.0164, med:0.1005, pro:0.0129, lab:0.0594, anomaly:0.1447: 100%|█| 724/724 [01:33<00:00,  7.76it/s]


Epoch 22 finished, ave_loss: 0.0653, ave_loss_dict: {'diag': 0.013669426236507477, 'med': 0.08161199142775648, 'pro': 0.01318143863354434, 'lab': 0.055933688906180924, 'anomaly': 0.1618883259985493}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.020689655172413793, 'recall': 0.010344827586206895, 'f1': 0.013563218390804595}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 23, Step: 723, loss:0.0715, diag:0.0171, med:0.0895, pro:0.0125, lab:0.0579, anomaly:0.1805: 100%|█| 724/724 [01:33<00:00,  7.76it/s]


Epoch 23 finished, ave_loss: 0.0651, ave_loss_dict: {'diag': 0.01365501968478218, 'med': 0.08154712281915365, 'pro': 0.0131157285132502, 'lab': 0.055838160978331754, 'anomaly': 0.16125824988835094}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.04827586206896552, 'recall': 0.020114942528735632, 'f1': 0.02656814449917898}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 24, Step: 723, loss:0.0644, diag:0.0108, med:0.0938, pro:0.0116, lab:0.0615, anomaly:0.1445: 100%|█| 724/724 [01:32<00:00,  7.82it/s]


Epoch 24 finished, ave_loss: 0.0650, ave_loss_dict: {'diag': 0.013634868480717939, 'med': 0.08116044255374874, 'pro': 0.013097511801240936, 'lab': 0.055811851568396575, 'anomaly': 0.1611512534279191}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.006896551724137931, 'f1': 0.009195402298850575}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 25, Step: 723, loss:0.0634, diag:0.0129, med:0.0650, pro:0.0161, lab:0.0619, anomaly:0.1613: 100%|█| 724/724 [01:33<00:00,  7.75it/s]


Epoch 25 finished, ave_loss: 0.0648, ave_loss_dict: {'diag': 0.013619318792277443, 'med': 0.08126194585238208, 'pro': 0.013020116002273164, 'lab': 0.05566224116235148, 'anomaly': 0.16050875009917423}, perf_dict: {'diag': {'precision': 0.0005005005005005005, 'recall': 0.0005005005005005005, 'f1': 0.0005005005005005005}, 'med': {'precision': 0.006896551724137931, 'recall': 0.004597701149425287, 'f1': 0.005517241379310345}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0013333333333333333, 'recall': 0.0005555555555555556, 'f1': 0.0007777777777777776}}


Epoch: 26, Step: 723, loss:0.0590, diag:0.0133, med:0.0520, pro:0.0160, lab:0.0476, anomaly:0.1661: 100%|█| 724/724 [01:33<00:00,  7.75it/s]


Epoch 26 finished, ave_loss: 0.0646, ave_loss_dict: {'diag': 0.013584923489245674, 'med': 0.08106517479420532, 'pro': 0.01297143340954629, 'lab': 0.05555787550810292, 'anomaly': 0.1598972563853086}, perf_dict: {'diag': {'precision': 0.0005005005005005005, 'recall': 0.0005005005005005005, 'f1': 0.0005005005005005005}, 'med': {'precision': 0.020689655172413793, 'recall': 0.015517241379310345, 'f1': 0.016551724137931035}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 27, Step: 723, loss:0.0632, diag:0.0116, med:0.0780, pro:0.0075, lab:0.0476, anomaly:0.1713: 100%|█| 724/724 [01:33<00:00,  7.72it/s]


Epoch 27 finished, ave_loss: 0.0644, ave_loss_dict: {'diag': 0.01357589142000848, 'med': 0.08082825436323716, 'pro': 0.012906093952420955, 'lab': 0.055494267858811834, 'anomaly': 0.15928392754404586}, perf_dict: {'diag': {'precision': 0.0005005005005005005, 'recall': 0.00025025025025025025, 'f1': 0.00033366700033366696}, 'med': {'precision': 0.011494252873563218, 'recall': 0.008045977011494251, 'f1': 0.009195402298850575}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0006666666666666666, 'recall': 0.0002222222222222222, 'f1': 0.0003333333333333333}}


Epoch: 28, Step: 723, loss:0.0703, diag:0.0145, med:0.0878, pro:0.0150, lab:0.0495, anomaly:0.1847: 100%|█| 724/724 [01:33<00:00,  7.75it/s]


Epoch 28 finished, ave_loss: 0.0641, ave_loss_dict: {'diag': 0.013556429381785884, 'med': 0.08094061708890767, 'pro': 0.012829831653423813, 'lab': 0.05536891870941576, 'anomaly': 0.15801280502291673}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.02413793103448276, 'recall': 0.01103448275862069, 'f1': 0.014285714285714284}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 29, Step: 723, loss:0.0642, diag:0.0129, med:0.0994, pro:0.0094, lab:0.0497, anomaly:0.1496: 100%|█| 724/724 [01:34<00:00,  7.69it/s]


Epoch 29 finished, ave_loss: 0.0639, ave_loss_dict: {'diag': 0.01352108440758115, 'med': 0.08093592428823532, 'pro': 0.012790497396275071, 'lab': 0.05535173315287295, 'anomaly': 0.157081653595547}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.01206896551724138, 'f1': 0.01264367816091954}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0008322929671244277, 'f1': 0.0009987515605493133}, 'lab': {'precision': 0.002, 'recall': 0.0011666666666666668, 'f1': 0.0013777777777777777}}


Epoch: 30, Step: 723, loss:0.0564, diag:0.0109, med:0.0559, pro:0.0104, lab:0.0542, anomaly:0.1509: 100%|█| 724/724 [01:34<00:00,  7.68it/s]


Epoch 30 finished, ave_loss: 0.0639, ave_loss_dict: {'diag': 0.013477555690950698, 'med': 0.08081204358755883, 'pro': 0.012749420431344796, 'lab': 0.055323722406974814, 'anomaly': 0.1570588107946334}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.013793103448275862, 'recall': 0.006896551724137931, 'f1': 0.009195402298850575}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.001, 'recall': 0.0006666666666666666, 'f1': 0.0007777777777777776}}


Epoch: 31, Step: 723, loss:0.0678, diag:0.0165, med:0.0894, pro:0.0093, lab:0.0534, anomaly:0.1705: 100%|█| 724/724 [01:33<00:00,  7.70it/s]


Epoch 31 finished, ave_loss: 0.0637, ave_loss_dict: {'diag': 0.013457522592481575, 'med': 0.08086530952225568, 'pro': 0.012684423473884881, 'lab': 0.055262570192224414, 'anomaly': 0.15613399985251505}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.017241379310344827, 'recall': 0.005977011494252874, 'f1': 0.008505747126436782}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0006242197253433209, 'f1': 0.0008322929671244277}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 32, Step: 723, loss:0.0652, diag:0.0136, med:0.0922, pro:0.0120, lab:0.0570, anomaly:0.1512: 100%|█| 724/724 [01:34<00:00,  7.67it/s]


Epoch 32 finished, ave_loss: 0.0635, ave_loss_dict: {'diag': 0.013455058489180072, 'med': 0.08076611652105882, 'pro': 0.012703499805635017, 'lab': 0.05523606141751313, 'anomaly': 0.155122441275673}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.027586206896551724, 'recall': 0.009195402298850575, 'f1': 0.013563218390804599}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 33, Step: 601, loss:0.0639, diag:0.0149, med:0.0702, pro:0.0152, lab:0.0599, anomaly:0.1593:  83%|▊| 602/724 [01:18<00:15,  7.72it/s]


KeyboardInterrupt: 

In [18]:
# Save the model weights as a CPU state dict inside the experiment folder.
# NOTE(review): assuming HBERT_Pretrain is a torch.nn.Module, .cpu() moves the
# model itself to the CPU in place — move it back with .to(device) before
# running any further training cells.
checkpoint_file = f"{save_path}/pretrained_model.pt"
cpu_state_dict = model.cpu().state_dict()
torch.save(cpu_state_dict, checkpoint_file)