In [1]:
import os
import pickle
from tqdm import tqdm
import torch
from token_utils_rep import EHRTokenizer
from dataset_utils_rep import HBERTPretrainEHRDataset, batcher
from torch.utils.data import DataLoader
from HEART_rep import HBERT_Pretrain
from set_seed_utils import set_random_seed

Disabling PyTorch because PyTorch >= 2.1 is required but found 1.13.1
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.


In [2]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print(device)

cuda


In [3]:
# Pretraining hyperparameters. Individual entries are read in later cells via args[...],
# and the whole dict is passed to HBERT_Pretrain.
args = dict(
    seed=0,
    dataset="MIMIC-IV",  # "MIMIC-III" or "MIMIC-IV"
    batch_size=32,
    lr=2e-5,
    epochs=30,
    encoder="hi_edge",
    mask_rate=0.7,
    anomaly_rate=0.05,
    anomaly_loss_weight=1,
    num_hidden_layers=5,
    num_attention_heads=6,
    attention_probs_dropout_prob=0.2,
    hidden_dropout_prob=0.2,
    edge_hidden_size=32,
    hidden_size=288,  # must be divisible by num_attention_heads
    intermediate_size=288,
    gnn_n_heads=1,
    gnn_temp=1,
    gat="dotattn",  # "dotattn" or None
    diag_med_emb="tree",  # "simple" or "tree"
)

In [4]:
# Dataset-specific settings; this notebook only targets the MIMIC datasets.
args.update({
    "max_visit_size": 15,
    "predicted_token_type": ["diag", "med", "pro", "lab"],
    # {token_type: masked_id}. 3..6 are the positions of "[MASK0]".."[MASK3]" in
    # special_tokens below; assumes the tokenizer assigns special-token ids in tuple order (TODO confirm).
    "mask_token_id": {"diag": 3, "med": 4, "pro": 5, "lab": 6},
    # "[MASK0]".."[MASK3]" are used only by the masked-code pretraining task;
    # codes that are actually masked are not in the input sequence.
    "special_tokens": ("[PAD]", "[CLS]", "[SEP]",
                       "[MASK0]", "[MASK1]", "[MASK2]", "[MASK3]"),
})

In [5]:
# Root of the preprocessed data. Defaults to the original author's location; set the
# HETEROGT_DATA_ROOT environment variable to run the notebook on another machine.
data_root = os.environ.get("HETEROGT_DATA_ROOT", "/home/lideyi/HeteroGT-cuda/data_process")
processed_dir = os.path.join(data_root, f"{args['dataset']}-processed")
full_data_path = os.path.join(processed_dir, "mimic.pkl")  # full data, used to build the tokenizer
pretrain_data_path = os.path.join(processed_dir, "mimic_pretrain.pkl")  # for pretraining

In [6]:
# Load the full processed dataset; `with` closes the file handle (the bare open() leaked it).
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open(full_data_path, "rb") as full_data_file:
    ehr_data = pickle.load(full_data_file)

In [7]:
# Per-row code lists from the full dataset (one entry per row of ehr_data); these feed the tokenizer.
diag_sentences = ehr_data["ICD9_CODE"].values.tolist()
med_sentences = ehr_data["NDC"].values.tolist()
# lab values were discretized into 5 percentile bins upstream
lab_sentences = ehr_data["LAB_TEST"].values.tolist()
pro_sentences = ehr_data["PRO_CODE"].values.tolist()
gender_set = [["M"], ["F"]]
# sorted(): iterating a bare set of strings yields a per-process order (hash randomization), so the
# order in which age tokens reach the tokenizer could differ between runs / notebooks. If the tokenizer
# assigns ids in input order, that would silently change token ids. Computed once and reused.
# Assumes AGE values are mutually comparable (all int or all str) — TODO confirm.
unique_ages = sorted(set(ehr_data["AGE"].values.tolist()))
age_set = [[c] for c in unique_ages]
age_gender_set = [[str(c) + "_" + gender] for c in unique_ages for gender in ["M", "F"]]

In [8]:
# The tokenizer is built from the FULL dataset (not only the pretraining split),
# so its vocabularies cover every code that appears anywhere in the data.
tokenizer = EHRTokenizer(
    diag_sentences,
    med_sentences,
    lab_sentences,
    pro_sentences,
    gender_set,
    age_set,
    age_gender_set,
    special_tokens=args["special_tokens"],
)
# Presumably builds the code hierarchy used by diag_med_emb="tree" — confirm in token_utils_rep.
tokenizer.build_tree()

In [9]:
# Load the pretraining split; `with` closes the file handle (the bare open() leaked it).
# NOTE: pickle.load can execute arbitrary code — only load files from a trusted source.
with open(pretrain_data_path, "rb") as pretrain_data_file:
    ehr_pretrain_data = pickle.load(pretrain_data_file)

# mask_rate / anomaly_rate are forwarded to the dataset (presumably the probability of masking
# a code and of injecting an anomalous code, respectively — confirm in dataset_utils_rep).
pretrain_dataset = HBERTPretrainEHRDataset(ehr_pretrain_data, tokenizer,
                                           token_type=args['predicted_token_type'],
                                           mask_rate=args['mask_rate'],
                                           anomaly_rate=args['anomaly_rate'])

In [10]:
# Shuffled training loader. The collate function receives the [PAD] token id and the
# number of predicted token types.
pad_token_id = tokenizer.vocab.word2id["[PAD]"]
pretrain_collate_fn = batcher(pad_id=pad_token_id,
                              n_token_type=len(args["predicted_token_type"]),
                              is_train=True)
pretrain_dataloader = DataLoader(pretrain_dataset,
                                 batch_size=args["batch_size"],
                                 shuffle=True,
                                 collate_fn=pretrain_collate_fn)

In [11]:
# Seed the RNGs before the model is constructed (presumably python/numpy/torch — see set_seed_utils).
# NOTE(review): this runs after the dataset and dataloader are built, so any randomness inside
# their constructors is not covered by this seed; consider seeding in the config cell.
random_seed = args["seed"]
set_random_seed(random_seed)

[INFO] Random seed set to 0


In [12]:
# Experiment name = "Pretrain-HBERT" followed by these hyperparameters, "-"-separated, in this order.
# NOTE(review): lr, batch_size, epochs and seed are not part of the name, so runs that differ
# only in those settings write to the same directory and overwrite each other's checkpoint.
exp_name_keys = (
    "dataset", "encoder", "mask_rate", "anomaly_rate", "anomaly_loss_weight",
    "hidden_size", "edge_hidden_size", "num_hidden_layers", "num_attention_heads",
    "attention_probs_dropout_prob", "hidden_dropout_prob", "intermediate_size",
    "gat", "gnn_n_heads", "gnn_temp", "diag_med_emb",
)
exp_name = "-".join(["Pretrain-HBERT"] + [str(args[key]) for key in exp_name_keys])
print(exp_name)

save_path = "./pretrained_models/" + exp_name
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(save_path, exist_ok=True)

Pretrain-HBERT-MIMIC-IV-hi_edge-0.7-0.05-1-288-32-5-6-0.2-0.2-288-dotattn-1-1-tree


In [13]:
# Size of the shared embedding table: the special tokens plus every per-type vocabulary.
# len(args["special_tokens"]) replaces the former magic number 7 (the tuple has 7 entries),
# so adding a special token keeps this count correct.
component_vocabs = (
    tokenizer.diag_voc,
    tokenizer.pro_voc,
    tokenizer.med_voc,
    tokenizer.lab_voc,
    tokenizer.age_voc,
    tokenizer.gender_voc,
    tokenizer.age_gender_voc,
)
args["vocab_size"] = len(args["special_tokens"]) + sum(len(voc.id2word) for voc in component_vocabs)

# Output-label vocabulary size for each predicted token type.
args["label_vocab_size"] = {"diag": len(tokenizer.diag_voc.id2word),
                            "pro": len(tokenizer.pro_voc.id2word),
                            "med": len(tokenizer.med_voc.id2word),
                            "lab": len(tokenizer.lab_voc.id2word)}  # {token_type: vocab_size}

In [14]:
# Keys of the per-component losses tracked in the training loop: one per predicted token type,
# plus the anomaly-detection loss. Derived from the config so the two lists cannot drift apart.
loss_entity = list(args["predicted_token_type"]) + ["anomaly"]

In [15]:
# Build the pretraining model from the config and tokenizer, then place it on the selected device.
model = HBERT_Pretrain(args, tokenizer)
model = model.to(device)

In [16]:
# AdamW over all model parameters, using the configured learning rate.
learning_rate = args["lr"]
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [17]:
# Pretraining loop: one optimizer step per batch; per-epoch mean losses are printed at the end of each epoch.
for epoch in range(1, 1 + args["epochs"]):
    train_iter = tqdm(pretrain_dataloader, ncols=140)
    model.train()
    ave_loss, ave_loss_dict = 0., {token_type: 0. for token_type in loss_entity}
    n_steps = 0
    # Metrics returned by the model for the MOST RECENT batch only (not an epoch aggregate).
    # Initialised so the epoch summary still works if the dataloader yields no batches.
    perf_dict = {}

    for step, batch in enumerate(train_iter):

        batch = [x.to(device) if isinstance(x, torch.Tensor) else x for x in batch]
        loss, loss_dict, perf_dict = model(*batch)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()  # single device sync per step (was called twice)
        # refresh=False: let tqdm's own rate limiting decide when to redraw instead of forcing
        # a redraw on every step, which reduces output traffic to the notebook front-end.
        train_iter.set_description(
            f"Epoch:{epoch: 03d}, Step:{step: 03d}, loss:{loss_value:.4f}, "
            f"diag:{loss_dict['diag']:.4f}, med:{loss_dict['med']:.4f}, "
            f"pro:{loss_dict['pro']:.4f}, lab:{loss_dict['lab']:.4f}, "
            f"anomaly:{loss_dict['anomaly']:.4f}",
            refresh=False,
        )

        ave_loss += loss_value
        for token_type in loss_entity:
            ave_loss_dict[token_type] += loss_dict[token_type]
        n_steps += 1

    # The old `step + 1` raised NameError on an empty dataloader; guard the division instead.
    denom = max(n_steps, 1)
    ave_loss /= denom
    ave_loss_dict = {token_type: ave_loss_dict[token_type] / denom for token_type in loss_entity}
    print(f"Epoch {epoch} finished, ave_loss: {ave_loss:.4f}, ave_loss_dict: {ave_loss_dict}, perf_dict: {perf_dict}")

Epoch: 01, Step: 1327, loss:0.0712, diag:0.0130, med:0.0539, pro:0.0118, lab:0.0328, anomaly:0.2446: 100%|█| 1328/1328 [02:14<00:00,  9.85it


Epoch 1 finished, ave_loss: 0.1208, ave_loss_dict: {'diag': 0.0689730423464754, 'med': 0.08907866326205612, 'pro': 0.062458091660217575, 'lab': 0.08675876539340804, 'anomaly': 0.29685477887740336}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 02, Step: 1327, loss:0.0643, diag:0.0140, med:0.0458, pro:0.0117, lab:0.0260, anomaly:0.2241: 100%|█| 1328/1328 [02:12<00:00, 10.05it


Epoch 2 finished, ave_loss: 0.0703, ave_loss_dict: {'diag': 0.01486016594010955, 'med': 0.05149918885239546, 'pro': 0.012639937157243356, 'lab': 0.0327056296960821, 'anomaly': 0.23996228897517705}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 03, Step: 1327, loss:0.0681, diag:0.0144, med:0.0537, pro:0.0136, lab:0.0278, anomaly:0.2308: 100%|█| 1328/1328 [02:11<00:00, 10.13it


Epoch 3 finished, ave_loss: 0.0676, ave_loss_dict: {'diag': 0.014820895272767148, 'med': 0.05085801701510258, 'pro': 0.012592550754232937, 'lab': 0.03165016316315047, 'anomaly': 0.22825778009796358}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 04, Step: 1327, loss:0.0681, diag:0.0150, med:0.0572, pro:0.0128, lab:0.0394, anomaly:0.2162: 100%|█| 1328/1328 [02:11<00:00, 10.07it


Epoch 4 finished, ave_loss: 0.0660, ave_loss_dict: {'diag': 0.014742796603837106, 'med': 0.05043055549129305, 'pro': 0.012515659864812371, 'lab': 0.03134061354324671, 'anomaly': 0.22098061995558352}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 05, Step: 1327, loss:0.0638, diag:0.0168, med:0.0462, pro:0.0136, lab:0.0328, anomaly:0.2095: 100%|█| 1328/1328 [02:12<00:00, 10.06it


Epoch 5 finished, ave_loss: 0.0654, ave_loss_dict: {'diag': 0.014695725636556745, 'med': 0.05021210568003267, 'pro': 0.012393992302338985, 'lab': 0.031166204073219503, 'anomaly': 0.21828199432691536}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 06, Step: 1327, loss:0.0636, diag:0.0145, med:0.0516, pro:0.0110, lab:0.0276, anomaly:0.2133: 100%|█| 1328/1328 [02:10<00:00, 10.16it


Epoch 6 finished, ave_loss: 0.0647, ave_loss_dict: {'diag': 0.014654107433912086, 'med': 0.04984376500977809, 'pro': 0.012302459376101124, 'lab': 0.031091347416716975, 'anomaly': 0.215538613767509}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 07, Step: 1327, loss:0.0617, diag:0.0128, med:0.0599, pro:0.0139, lab:0.0344, anomaly:0.1873: 100%|█| 1328/1328 [02:11<00:00, 10.11it


Epoch 7 finished, ave_loss: 0.0639, ave_loss_dict: {'diag': 0.014563119715691197, 'med': 0.04965191310080868, 'pro': 0.012219100365841604, 'lab': 0.030972728738561273, 'anomaly': 0.21186289778091463}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 08, Step: 1327, loss:0.0598, diag:0.0129, med:0.0419, pro:0.0117, lab:0.0405, anomaly:0.1917: 100%|█| 1328/1328 [02:11<00:00, 10.07it


Epoch 8 finished, ave_loss: 0.0633, ave_loss_dict: {'diag': 0.014493896189027926, 'med': 0.049241545220208634, 'pro': 0.012110452050573182, 'lab': 0.03085183493052441, 'anomaly': 0.20974087268562921}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 09, Step: 1327, loss:0.0626, diag:0.0149, med:0.0491, pro:0.0118, lab:0.0342, anomaly:0.2032: 100%|█| 1328/1328 [02:11<00:00, 10.07it


Epoch 9 finished, ave_loss: 0.0628, ave_loss_dict: {'diag': 0.01443162233658494, 'med': 0.048918638684147274, 'pro': 0.012050173026565686, 'lab': 0.03069994691478829, 'anomaly': 0.20789585948125067}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 10, Step: 1327, loss:0.0655, diag:0.0156, med:0.0474, pro:0.0133, lab:0.0287, anomaly:0.2226: 100%|█| 1328/1328 [02:11<00:00, 10.09it


Epoch 10 finished, ave_loss: 0.0623, ave_loss_dict: {'diag': 0.014345637683936182, 'med': 0.04859110452964094, 'pro': 0.011891032079605004, 'lab': 0.030513176348071592, 'anomaly': 0.20612592263573623}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 11, Step: 1327, loss:0.0633, diag:0.0144, med:0.0440, pro:0.0117, lab:0.0302, anomaly:0.2161: 100%|█| 1328/1328 [02:11<00:00, 10.07it


Epoch 11 finished, ave_loss: 0.0618, ave_loss_dict: {'diag': 0.014221808962883001, 'med': 0.048382088314764295, 'pro': 0.01176713235011459, 'lab': 0.03027299555658127, 'anomaly': 0.20416103507650185}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 12, Step: 1327, loss:0.0649, diag:0.0129, med:0.0517, pro:0.0118, lab:0.0330, anomaly:0.2152: 100%|█| 1328/1328 [02:12<00:00, 10.02it


Epoch 12 finished, ave_loss: 0.0614, ave_loss_dict: {'diag': 0.01412578816076808, 'med': 0.0482404395511531, 'pro': 0.011690864326146218, 'lab': 0.030013604166760414, 'anomaly': 0.20269584837537932}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 13, Step: 1327, loss:0.0634, diag:0.0126, med:0.0484, pro:0.0095, lab:0.0316, anomaly:0.2151: 100%|█| 1328/1328 [02:12<00:00, 10.05it


Epoch 13 finished, ave_loss: 0.0610, ave_loss_dict: {'diag': 0.014041992065676275, 'med': 0.04802701954943049, 'pro': 0.011573861304709845, 'lab': 0.02981112036203225, 'anomaly': 0.20136033975889525}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 14, Step: 1327, loss:0.0570, diag:0.0132, med:0.0408, pro:0.0108, lab:0.0290, anomaly:0.1910: 100%|█| 1328/1328 [02:13<00:00,  9.97it


Epoch 14 finished, ave_loss: 0.0604, ave_loss_dict: {'diag': 0.013984889670763254, 'med': 0.04777520284875778, 'pro': 0.011487996860699033, 'lab': 0.02962792443296681, 'anomaly': 0.19911793546459403}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 15, Step: 1327, loss:0.0598, diag:0.0158, med:0.0522, pro:0.0109, lab:0.0290, anomaly:0.1911: 100%|█| 1328/1328 [02:12<00:00, 10.00it


Epoch 15 finished, ave_loss: 0.0599, ave_loss_dict: {'diag': 0.013921655928407493, 'med': 0.04773969448687711, 'pro': 0.011389627878088504, 'lab': 0.029435372468173862, 'anomaly': 0.19706468101231808}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0012484394506866417, 'f1': 0.0012484394506866417}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 16, Step: 1327, loss:0.0657, diag:0.0144, med:0.0521, pro:0.0106, lab:0.0324, anomaly:0.2192: 100%|█| 1328/1328 [02:13<00:00,  9.97it


Epoch 16 finished, ave_loss: 0.0595, ave_loss_dict: {'diag': 0.013895390847002167, 'med': 0.04768928275611925, 'pro': 0.011351751209887484, 'lab': 0.029276869771020286, 'anomaly': 0.19503965080682054}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 17, Step: 1327, loss:0.0579, diag:0.0143, med:0.0448, pro:0.0118, lab:0.0266, anomaly:0.1919: 100%|█| 1328/1328 [02:13<00:00,  9.98it


Epoch 17 finished, ave_loss: 0.0590, ave_loss_dict: {'diag': 0.013857781923521894, 'med': 0.04756830111202077, 'pro': 0.011252506442525121, 'lab': 0.02911116382533527, 'anomaly': 0.19335649759475007}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 19, Step: 1327, loss:0.0605, diag:0.0153, med:0.0480, pro:0.0123, lab:0.0298, anomaly:0.1973: 100%|█| 1328/1328 [02:13<00:00,  9.97it


Epoch 19 finished, ave_loss: 0.0583, ave_loss_dict: {'diag': 0.013834418052897215, 'med': 0.04745322664503951, 'pro': 0.011158047287433172, 'lab': 0.028919226706661957, 'anomaly': 0.1903143968567791}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 20, Step: 1238, loss:0.0545, diag:0.0124, med:0.0488, pro:0.0118, lab:0.0287, anomaly:0.1708:  93%|▉| 1237/1328 [02:03<00:08, 10.11itIOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

Epoch: 21, Step: 1327, loss:0.0581, diag:0.0148, med:0.0480, pro:0.0100, lab:0.0256, anomaly:0.1921: 100%|█| 1328/1328 [02:13<00:00,  9.98it


Epoch 21 finished, ave_loss: 0.0572, ave_loss_dict: {'diag': 0.013797474277475063, 'med': 0.047337651586564014, 'pro': 0.011043722727279886, 'lab': 0.028708335348820113, 'anomaly': 0.18528252699221653}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.0012484394506866417, 'f1': 0.0012484394506866417}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 22, Step: 1327, loss:0.0546, diag:0.0142, med:0.0487, pro:0.0102, lab:0.0264, anomaly:0.1733: 100%|█| 1328/1328 [02:12<00:00, 10.02it


Epoch 22 finished, ave_loss: 0.0568, ave_loss_dict: {'diag': 0.013762240645234439, 'med': 0.04721400058121387, 'pro': 0.01100338060697765, 'lab': 0.028640210140406727, 'anomaly': 0.1834716462833156}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.007142857142857143, 'recall': 0.0014285714285714286, 'f1': 0.0023809523809523807}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 23, Step: 1327, loss:0.0543, diag:0.0118, med:0.0482, pro:0.0108, lab:0.0254, anomaly:0.1750: 100%|█| 1328/1328 [02:12<00:00, 10.04it


Epoch 23 finished, ave_loss: 0.0564, ave_loss_dict: {'diag': 0.013711753294309488, 'med': 0.04735630620775901, 'pro': 0.010969839341439172, 'lab': 0.028510649005478495, 'anomaly': 0.1815800489577274}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.014285714285714285, 'recall': 0.008571428571428572, 'f1': 0.009523809523809523}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 24, Step: 1327, loss:0.0474, diag:0.0136, med:0.0432, pro:0.0116, lab:0.0307, anomaly:0.1377: 100%|█| 1328/1328 [02:13<00:00,  9.98it


Epoch 24 finished, ave_loss: 0.0561, ave_loss_dict: {'diag': 0.013692347588107348, 'med': 0.04725267799818013, 'pro': 0.010926439460501614, 'lab': 0.028453356666344565, 'anomaly': 0.1803667917454907}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 25, Step: 1327, loss:0.0549, diag:0.0138, med:0.0406, pro:0.0119, lab:0.0279, anomaly:0.1802: 100%|█| 1328/1328 [02:11<00:00, 10.08it


Epoch 25 finished, ave_loss: 0.0557, ave_loss_dict: {'diag': 0.013706982415590256, 'med': 0.04720100506309824, 'pro': 0.010882470375964844, 'lab': 0.02837076169404998, 'anomaly': 0.17858665299613075}, perf_dict: {'diag': {'precision': 0.0005042864346949068, 'recall': 0.0002521432173474534, 'f1': 0.0003361909564632711}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 26, Step: 1327, loss:0.0589, diag:0.0122, med:0.0462, pro:0.0122, lab:0.0285, anomaly:0.1953: 100%|█| 1328/1328 [02:13<00:00,  9.93it


Epoch 26 finished, ave_loss: 0.0553, ave_loss_dict: {'diag': 0.013688759923160794, 'med': 0.04706388638994421, 'pro': 0.010826237900154552, 'lab': 0.028353406373041134, 'anomaly': 0.17635403409690023}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 27, Step: 1327, loss:0.0520, diag:0.0149, med:0.0484, pro:0.0111, lab:0.0267, anomaly:0.1591: 100%|█| 1328/1328 [02:13<00:00,  9.97it


Epoch 27 finished, ave_loss: 0.0551, ave_loss_dict: {'diag': 0.01366958532381399, 'med': 0.04708934000536171, 'pro': 0.010796801793369386, 'lab': 0.028248529704512065, 'anomaly': 0.1755931661632585}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0012484394506866417, 'recall': 0.00041614648356221387, 'f1': 0.0006242197253433209}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 28, Step: 1327, loss:0.0481, diag:0.0145, med:0.0355, pro:0.0094, lab:0.0280, anomaly:0.1530: 100%|█| 1328/1328 [02:13<00:00,  9.94it


Epoch 28 finished, ave_loss: 0.0548, ave_loss_dict: {'diag': 0.0136431014615515, 'med': 0.047002285962969244, 'pro': 0.010769678655652756, 'lab': 0.028191513064210808, 'anomaly': 0.17457573689767217}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 29, Step: 1327, loss:0.0512, diag:0.0142, med:0.0490, pro:0.0103, lab:0.0361, anomaly:0.1462: 100%|█| 1328/1328 [02:13<00:00,  9.97it


Epoch 29 finished, ave_loss: 0.0548, ave_loss_dict: {'diag': 0.01361300334045439, 'med': 0.04704930072268808, 'pro': 0.010711967795722591, 'lab': 0.028101788693477948, 'anomaly': 0.17440529254612974}, perf_dict: {'diag': {'precision': 0.0005042864346949068, 'recall': 0.00010085728693898135, 'f1': 0.00016809547823163555}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}


Epoch: 30, Step: 1327, loss:0.0550, diag:0.0141, med:0.0522, pro:0.0101, lab:0.0266, anomaly:0.1721: 100%|█| 1328/1328 [02:13<00:00,  9.92it

Epoch 30 finished, ave_loss: 0.0544, ave_loss_dict: {'diag': 0.013597680361666548, 'med': 0.046992460204041506, 'pro': 0.010665677047084949, 'lab': 0.028043391355537777, 'anomaly': 0.17259993098384466}, perf_dict: {'diag': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'med': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'pro': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}, 'lab': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0}}





In [18]:
# Save CPU copies of the weights without moving the live model off `device`:
# model.cpu() moved it in place, silently breaking any later cell that reuses `model` on the GPU.
cpu_state_dict = model.state_dict()  # a fresh OrderedDict per call, so rebinding its values is safe
for key in list(cpu_state_dict):
    cpu_state_dict[key] = cpu_state_dict[key].cpu()
torch.save(cpu_state_dict, os.path.join(save_path, "pretrained_model.pt"))