In [None]:
# !pip install pytorch-pretrained-bert

### This notebook compares BEHRT against a bag-of-codes + logistic-regression (BoC+LR) baseline and adds some attention-based explanation.

In [None]:
import sys
sys.path.insert(0, '../')

from common.common import create_folder,load_obj
# from data import bert,dataframe,utils
from dataLoader.utils import seq_padding,code2index, position_idx, index_seg
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
from torch.utils.data.dataset import Dataset
import os
import torch
import torch.nn as nn
import pytorch_pretrained_bert as Bert
from model.utils import age_vocab
from model import optimiser
import sklearn.metrics as skm
import math
from torch.utils.data.dataset import Dataset
import random
import numpy as np
import torch
import time

# from data.utils import seq_padding, index_seg, position_idx, age_vocab, random_mask, code2index
# from sklearn.metrics import roc_auc_score

### Adding parameters

In [None]:
# Locations of the inputs: the vocabulary object and the train/test tables.
file_config = dict(
    vocab='../../outputs/vocab',
    train='../../outputs/nextvisit_train_idx.parquet',
    test='../../outputs/nextvisit_test_idx.parquet',
)

# Settings passed as `config` when the optimiser is created.
optim_config = dict(
    lr=3e-5,
    warmup_proportion=0.1,
    weight_decay=0.01,
)

# Run-wide settings read by the dataset, training and checkpointing cells.
global_params = dict(
    batch_size=128,
    gradient_accumulation_steps=1,
    # first GPU when CUDA is available, otherwise CPU
    device='cuda:0' if torch.cuda.is_available() else 'cpu',
    output_dir='../../outputs/ckpts',
    best_name='nextvisit_12m.pt',
    save_model=True,
    max_len_seq=100,
    max_age=110,
    month=1,
    age_symbol=None,
    min_visit=5,
)

# Checkpoint used to initialise the model before fine-tuning.
pretrainModel = "../../outputs/ckpts/mlm_bert.pt"


In [None]:
# Prepare the output directory configured in global_params before later cells write to it.
# (create_folder comes from common.common; presumably it creates the directory if missing.)
create_folder(global_params['output_dir'])

In [None]:
# Load the vocabulary object stored at file_config['vocab']; later cells read its 'token2idx' mapping.
BertVocab = load_obj(file_config['vocab'])
# Build the age vocabulary from the configured max age / month step / symbol.
# Only the first returned value is kept; it is used as the age -> index mapping by the dataset.
ageVocab, _ = age_vocab(max_age=global_params['max_age'], mon=global_params['month'], symbol=global_params['age_symbol'])

In [None]:
def ensure_special_tokens(token2idx):
    """Guarantee that every special token exists under both of its spellings.

    The five specials (UNK, PAD, CLS, SEP, MASK) may appear bracketed
    (``[UNK]``) or bare (``UNK``). After this call both spellings are present
    for each of them.

    Rules, applied per special token:
      * bracketed present            -> bare is added as an alias if missing
        (an existing bare entry is left untouched);
      * only bare present            -> bracketed becomes an alias of the bare
        id, so the two spellings never point at different ids;
      * neither present              -> both get one fresh id.

    Args:
        token2idx: dict mapping token -> integer id. Modified in place.

    Returns:
        The same ``token2idx`` object, for convenient re-assignment.
    """
    # Fresh ids start one past the largest id in use. len(token2idx) is not
    # safe here: with alias keys or non-contiguous ids it can skip ids or
    # land on an id that is already taken.
    next_id = max(token2idx.values(), default=-1) + 1

    for bare in ('UNK', 'PAD', 'CLS', 'SEP', 'MASK'):
        bracketed = '[' + bare + ']'
        if bracketed in token2idx:
            token2idx.setdefault(bare, token2idx[bracketed])
        elif bare in token2idx:
            token2idx[bracketed] = token2idx[bare]
        else:
            token2idx[bracketed] = next_id
            token2idx[bare] = next_id
            next_id += 1
    return token2idx

# Add any missing special tokens / aliases to the word vocabulary (the helper edits the dict in place
# and returns it, so this re-assignment keeps the same object).
BertVocab['token2idx'] = ensure_special_tokens(BertVocab['token2idx'])

def format_label_vocab(token2idx):
    """Derive the label vocabulary from the word vocabulary.

    Special tokens (both bare and bracketed spellings) are left out and the
    remaining tokens are numbered 0..n-1 in their original dict order.
    The input mapping is not modified.

    Args:
        token2idx: dict mapping token -> id (only the keys are used).

    Returns:
        dict mapping each non-special token to its label index.
    """
    specials = {
        'PAD', '[PAD]', 'SEP', '[SEP]', 'CLS', '[CLS]',
        'MASK', '[MASK]', 'UNK', '[UNK]',
    }
    # label tokens, keeping vocabulary order
    kept = [token for token in token2idx.keys() if token not in specials]
    return dict(zip(kept, range(len(kept))))

# Label space: every non-special token of the word vocabulary, re-indexed from 0.
Vocab_diag = format_label_vocab(BertVocab['token2idx'])

# Sanity printout: vocabulary sizes and the id that unknown tokens map to.
print("✅ word vocab size:", len(BertVocab['token2idx']))
print("✅ label vocab size:", len(Vocab_diag))
print("✅ UNK id:", BertVocab['token2idx']['UNK'])


In [None]:
# ===== Cell: MultiLabelBinarizer =====
from sklearn.preprocessing import MultiLabelBinarizer

# Binarizer that turns lists of label ids into multi-hot rows; one column per label id,
# in label-vocabulary order.
mlb = MultiLabelBinarizer(classes=[label_id for label_id in Vocab_diag.values()])
# Fit on one singleton sample per label id so every class is registered.
mlb.fit([[label_id] for label_id in Vocab_diag.values()])

print("✅ mlb fitted, n_labels =", len(mlb.classes_))

In [None]:
# Architecture settings consumed by BertConfig.
model_config = dict(
    # number of entries in the word vocabulary (disease codes + symbols); every key counts,
    # including alias spellings of the special tokens
    vocab_size=len(BertVocab['token2idx']),
    hidden_size=288,                 # word embedding and seg embedding hidden size
    seg_vocab_size=2,                # number of vocab for seg embedding
    age_vocab_size=len(ageVocab),    # number of vocab for age embedding
    max_position_embedding=global_params['max_len_seq'],  # maximum number of tokens
    hidden_dropout_prob=0.1,         # dropout rate
    num_hidden_layers=6,             # number of multi-head attention layers required
    num_attention_heads=12,          # number of attention heads
    attention_probs_dropout_prob=0.1,  # multi-head attention dropout rate
    intermediate_size=512,           # the size of the "intermediate" layer in the transformer encoder
    hidden_act='gelu',               # encoder/pooler non-linearity; "gelu", 'relu', 'swish' are supported
    initializer_range=0.02,          # parameter weight initializer range
)

# Switches naming the four input embeddings (word, segment, age, position).
feature_dict = dict(
    word=True,
    seg=True,
    age=True,
    position=True,
)

### Set up dataset and model classes

In [None]:
import numpy as np

def code2index_safe(tokens, token2idx):
    """Map tokens to ids without ever raising KeyError.

    Args:
        tokens: list / tuple / np.ndarray of tokens. Non-string tokens are
            looked up by their ``str()`` form.
        token2idx: dict token -> id. May contain "UNK", "[UNK]" or neither.

    Returns:
        (tokens_list, ids): the tokens as a plain Python list (unchanged
        values) and the list of ids. Tokens missing from ``token2idx`` get the
        id of "UNK", else of "[UNK]", else 0.
    """
    tokens_list = tokens.tolist() if isinstance(tokens, np.ndarray) else list(tokens)

    # Fallback id: first unknown-token spelling that exists, otherwise 0.
    unk_id = 0
    for unk_key in ("UNK", "[UNK]"):
        if unk_key in token2idx:
            unk_id = token2idx[unk_key]
            break

    ids = [
        # numpy scalars and other non-str tokens are looked up by their string form
        token2idx.get(token if isinstance(token, str) else str(token), unk_id)
        for token in tokens_list
    ]
    return tokens_list, ids


In [None]:
import numpy as np
import torch
from torch.utils.data.dataset import Dataset

class NextVisit(Dataset):
    """Next-visit prediction dataset: one patient history per item.

    Every row of ``dataframe`` provides a ``code`` sequence, a parallel ``age``
    sequence and a ``label`` collection (the codes to predict).

    Args:
        token2idx: word vocabulary (token -> id). Bare aliases "UNK"/"PAD" are
            added to this dict if only the bracketed spellings exist.
        diag2idx: label vocabulary (token -> label id).
        age2idx: age vocabulary, passed to ``seq_padding`` as ``token2idx``.
        dataframe: table with "code", "age", "label" and optionally
            "patid" / "subject_id" columns.
        max_len: fixed sequence length of every returned tensor.
        max_age, min_visit: accepted for signature compatibility; not used.

    Each item is a 9-tuple:
        age_ids, code_ids, position, segment, attMask, label_ids, patid,
        visit_ids (all tensors) and code_tokens (list of str, padded with "PAD").
    """

    def __init__(self, token2idx, diag2idx, age2idx, dataframe, max_len, max_age=110, min_visit=5):
        self.vocab = token2idx
        self.label_vocab = diag2idx
        self.age2idx = age2idx
        self.max_len = int(max_len)

        self.code = dataframe["code"]
        self.age = dataframe["age"]
        self.label = dataframe["label"]

        # patient identifier: explicit column if present, otherwise the frame index
        if "patid" in dataframe.columns:
            self.patid = dataframe["patid"]
        elif "subject_id" in dataframe.columns:
            self.patid = dataframe["subject_id"]
        else:
            self.patid = dataframe.index

        # pad id (try PAD/[PAD], else 0)
        self.pad_id = self.vocab.get("PAD", self.vocab.get("[PAD]", 0))

        # make sure the bare special aliases exist in THIS dict
        if "UNK" not in self.vocab and "[UNK]" in self.vocab:
            self.vocab["UNK"] = self.vocab["[UNK]"]
        if "PAD" not in self.vocab and "[PAD]" in self.vocab:
            self.vocab["PAD"] = self.vocab["[PAD]"]

    def __len__(self):
        return len(self.code)

    def __getitem__(self, index):
        codes = list(self.code[index])
        ages = list(self.age[index])
        label = list(self.label[index])
        patid = int(self.patid[index])

        # keep the most recent max_len-1 tokens so a leading CLS still fits
        codes = codes[-(self.max_len - 1):]
        ages = ages[-(self.max_len - 1):]

        if len(codes) == 0:
            codes = ["CLS"]
            ages = [0]
        elif codes[0] != "SEP":
            codes = ["CLS"] + codes
            ages = [ages[0] if len(ages) > 0 else 0] + ages
        else:
            # a leading separator is turned into the CLS token instead of prepending
            codes[0] = "CLS"

        # visit index of every token; the counter advances after each "SEP"
        visit_ids = []
        visit = 0
        for token in codes:
            visit_ids.append(visit)
            if token == "SEP":
                visit += 1

        # word ids (unknown codes fall back to the UNK id, never KeyError)
        _, code_ids = code2index_safe(np.array(codes, dtype=object), self.vocab)

        # Label ids: labels that are not in the label vocabulary are dropped.
        # The label vocabulary has no UNK entry, so a fallback id would be 0,
        # which is a real label and would create false positive targets.
        label_ids = []
        for lab in label:
            key = lab if isinstance(lab, str) else str(lab)
            if key in self.label_vocab:
                label_ids.append(self.label_vocab[key])

        # robust age -> vocabulary key (int key, then str key, then 0)
        mapped_ages = []
        for a in ages:
            try:
                ai = int(a)
            except Exception:
                ai = 0
            if ai in self.age2idx:
                mapped_ages.append(ai)
            elif str(ai) in self.age2idx:
                mapped_ages.append(str(ai))
            else:
                mapped_ages.append(0)

        age_ids = seq_padding(mapped_ages, self.max_len, token2idx=self.age2idx)

        # pad / cut everything to max_len
        code_ids = seq_padding(code_ids, self.max_len, symbol=self.pad_id)
        label_ids = seq_padding(label_ids, self.max_len, symbol=-1)

        n_pad = max(self.max_len - len(codes), 0)
        visit_ids = (visit_ids + [-1] * n_pad)[:self.max_len]
        code_tokens = (codes + ["PAD"] * n_pad)[:self.max_len]

        # Position / segment ids are derived from the token STRINGS, not the ids.
        # NOTE(review): position_idx and index_seg are imported helpers that,
        # in the BEHRT code base, compare each element with the "SEP" symbol;
        # fed integer ids they would never see a separator and return constant
        # sequences — confirm against dataLoader.utils.
        position = position_idx(code_tokens)
        segment = index_seg(code_tokens)

        # 1.0 for real tokens, 0.0 for padding
        attMask = (np.array(code_ids) != self.pad_id).astype(np.float32)

        return (
            torch.LongTensor(age_ids),
            torch.LongTensor(code_ids),
            torch.LongTensor(position),
            torch.LongTensor(segment),
            torch.FloatTensor(attMask),
            torch.LongTensor(label_ids),
            torch.LongTensor([patid]),
            torch.LongTensor(visit_ids),
            code_tokens
        )


In [None]:
class BertConfig(Bert.modeling.BertConfig):
    """BERT configuration built from the notebook's ``model_config`` dict.

    The standard fields are forwarded to the base config; the two extra
    vocabulary sizes (segment and age) are stored as attributes.
    ``hidden_size`` is required; all other keys are read with ``.get``.
    """

    def __init__(self, config):
        # keyword arguments understood by the base configuration class
        base_kwargs = {
            'vocab_size_or_config_json_file': config.get('vocab_size'),
            'hidden_size': config['hidden_size'],
            'num_hidden_layers': config.get('num_hidden_layers'),
            'num_attention_heads': config.get('num_attention_heads'),
            'intermediate_size': config.get('intermediate_size'),
            'hidden_act': config.get('hidden_act'),
            'hidden_dropout_prob': config.get('hidden_dropout_prob'),
            'attention_probs_dropout_prob': config.get('attention_probs_dropout_prob'),
            # note the singular key name used by model_config
            'max_position_embeddings': config.get('max_position_embedding'),
            'initializer_range': config.get('initializer_range'),
        }
        super().__init__(**base_kwargs)

        # sizes of the additional embedding tables
        self.seg_vocab_size = config.get('seg_vocab_size')
        self.age_vocab_size = config.get('age_vocab_size')

class BertEmbeddings(nn.Module):
    """Input embeddings: word + segment + age + position, then LayerNorm and dropout.

    Args:
        config: configuration exposing vocab_size, seg_vocab_size,
            age_vocab_size, max_position_embeddings, hidden_size and
            hidden_dropout_prob.
    """

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size)
        self.segment_embeddings = nn.Embedding(config.seg_vocab_size, config.hidden_size)
        self.age_embeddings = nn.Embedding(config.age_vocab_size, config.hidden_size)
        # from_pretrained is a classmethod: call it on the class directly
        # instead of first allocating a randomly initialised table that is
        # thrown away. With its default freeze=True the table is not trained.
        self.posi_embeddings = nn.Embedding.from_pretrained(
            embeddings=self._init_posi_embedding(config.max_position_embeddings, config.hidden_size)
        )

        self.LayerNorm = Bert.modeling.BertLayerNorm(config.hidden_size, eps=1e-12)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, word_ids, age_ids=None, seg_ids=None, posi_ids=None):
        """Return the summed, normalised embeddings for a batch of id tensors.

        Any of age_ids / seg_ids / posi_ids left as None is replaced by an
        all-zero tensor shaped like ``word_ids``.
        """
        if seg_ids is None:
            seg_ids = torch.zeros_like(word_ids)
        if age_ids is None:
            age_ids = torch.zeros_like(word_ids)
        if posi_ids is None:
            posi_ids = torch.zeros_like(word_ids)

        word_embed = self.word_embeddings(word_ids)
        segment_embed = self.segment_embeddings(seg_ids)
        age_embed = self.age_embeddings(age_ids)
        posi_embed = self.posi_embeddings(posi_ids)

        embeddings = word_embed + segment_embed + age_embed + posi_embed
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

    def _init_posi_embedding(self, max_position_embedding, hidden_size):
        """Build the (max_position_embedding, hidden_size) sinusoidal table.

        Even columns hold sin(pos / 10000**(2*i/hidden_size)), odd columns the
        cosine of the same expression (i is the column index).

        Returns:
            float32 torch tensor.
        """
        lookup_table = np.zeros((max_position_embedding, hidden_size), dtype=np.float32)
        for pos in range(max_position_embedding):
            for i in range(hidden_size):
                angle = pos / (10000 ** (2 * i / hidden_size))
                lookup_table[pos, i] = np.sin(angle) if i % 2 == 0 else np.cos(angle)
        return torch.tensor(lookup_table)

class BertModel(Bert.modeling.BertPreTrainedModel):
    """Encoder stack on top of the custom BertEmbeddings, plus the pooler."""

    def __init__(self, config):
        super().__init__(config)
        self.embeddings = BertEmbeddings(config)
        self.encoder = Bert.modeling.BertEncoder(config)
        self.pooler = Bert.modeling.BertPooler(config)
        # NOTE(review): init_bert_weights is applied to every sub-module, which
        # presumably includes the position table created in BertEmbeddings —
        # confirm whether re-initialising it is intended.
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None):
        """Encode a batch.

        Args:
            input_ids: word id tensor.
            age_ids, seg_ids, posi_ids: optional id tensors for the extra embeddings.
            attention_mask: 1 for tokens to attend to, 0 for padding;
                defaults to all ones.

        Returns:
            (sequence_output, pooled_output) from the last encoder layer.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Insert two singleton axes after the batch axis, then turn the 0/1 mask
        # into an additive one: 0 where attention is allowed, -10000 where it is not.
        param_dtype = next(self.parameters()).dtype
        broadcast_mask = attention_mask.unsqueeze(1).unsqueeze(2).to(dtype=param_dtype)
        additive_mask = (1.0 - broadcast_mask) * -10000.0

        embedded = self.embeddings(input_ids, age_ids, seg_ids, posi_ids)
        layers = self.encoder(embedded, additive_mask, output_all_encoded_layers=False)
        sequence_output = layers[-1]
        return sequence_output, self.pooler(sequence_output)

class BertForMultiLabelPrediction(Bert.modeling.BertPreTrainedModel):
    """BertModel with a dropout + linear head producing one logit per label."""

    def __init__(self, config, num_labels):
        super().__init__(config)
        self.num_labels = num_labels
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, num_labels)
        self.apply(self.init_bert_weights)

    def forward(self, input_ids, age_ids=None, seg_ids=None, posi_ids=None, attention_mask=None, labels=None):
        """Compute label logits, and the loss when targets are supplied.

        Returns:
            ``logits`` when ``labels`` is None, otherwise ``(loss, logits)``
            where the loss is nn.MultiLabelSoftMarginLoss over the logits.
        """
        pooled = self.bert(input_ids, age_ids, seg_ids, posi_ids, attention_mask)[1]
        logits = self.classifier(self.dropout(pooled))

        if labels is None:
            return logits

        criterion = nn.MultiLabelSoftMarginLoss()
        flat_logits = logits.view(-1, self.num_labels)
        flat_labels = labels.view(-1, self.num_labels)
        return criterion(flat_logits, flat_labels), logits


### Load data

In [None]:
# Shorthand for the word vocabulary (same dict object, not a copy).
tok = BertVocab["token2idx"]

def force_alias(tok):
    """Ensure each special token exists in both bracketed and bare form.

    For UNK, PAD, CLS, SEP and MASK:
      * if the bracketed spelling is missing it is created — as an alias of
        the bare spelling when that exists (so both spellings share one id),
        otherwise with a fresh id;
      * the bare spelling is then added as an alias of the bracketed one if it
        is missing. Existing entries are never overwritten.

    Args:
        tok: dict token -> integer id. Modified in place.

    Returns:
        The same dict.
    """
    # Fresh ids continue after the largest id in use; len(tok) could collide
    # with an existing id when the vocabulary has aliases or gaps.
    next_free = max(tok.values(), default=-1) + 1

    for bare in ("UNK", "PAD", "CLS", "SEP", "MASK"):
        bracketed = f"[{bare}]"
        if bracketed not in tok:
            if bare in tok:
                tok[bracketed] = tok[bare]
            else:
                tok[bracketed] = next_free
                next_free += 1
        # ensure the bare alias exists
        tok[bare] = tok.get(bare, tok[bracketed])
    return tok

# Store the (in-place updated) vocabulary back and confirm the bare aliases are present.
BertVocab["token2idx"] = force_alias(tok)

print("✅ has UNK:", "UNK" in BertVocab["token2idx"], " id=", BertVocab["token2idx"]["UNK"])
print("✅ has PAD:", "PAD" in BertVocab["token2idx"], " id=", BertVocab["token2idx"]["PAD"])


In [None]:
def load_pretrained_partial(model, ckpt_path, device='cpu'):
    """Copy every compatible tensor from a checkpoint into ``model``.

    Checkpoint entries are matched to the model by key:
      * keys absent from the model are ignored;
      * word / position embedding tables may have a different number of rows:
        the first ``min(rows)`` rows are copied and the model's remaining rows
        keep their current values — provided all trailing dimensions match;
      * every other tensor is copied only when its shape matches exactly.

    Args:
        model: torch module to update (modified in place).
        ckpt_path: checkpoint file; either a state dict or a dict holding one
            under "state_dict".
        device: ``map_location`` used when reading the file.

    Returns:
        The same ``model``.
    """
    # NOTE(security): torch.load unpickles the file — only load checkpoints
    # from a trusted source.
    ckpt = torch.load(ckpt_path, map_location=device)
    if isinstance(ckpt, dict) and "state_dict" in ckpt:
        ckpt = ckpt["state_dict"]

    model_state = model.state_dict()
    new_state = {}
    n_skipped = 0

    for k, v in ckpt.items():
        if k not in model_state:
            n_skipped += 1
            continue
        target = model_state[k]
        v = v.to(target.dtype)

        row_wise = "word_embeddings.weight" in k or "posi_embeddings.weight" in k
        if (row_wise and v.dim() >= 1 and v.dim() == target.dim()
                and v.shape[1:] == target.shape[1:]):
            # vocabulary / max length may differ: copy the shared leading rows only
            merged = target.clone()
            n_copy = min(v.shape[0], merged.shape[0])
            merged[:n_copy] = v[:n_copy]
            new_state[k] = merged
        elif target.shape == v.shape:
            new_state[k] = v
        else:
            # incompatible shape (e.g. different hidden size): keep the model's tensor
            n_skipped += 1

    model_state.update(new_state)
    model.load_state_dict(model_state)
    if not new_state:
        # a silent no-op usually means the key names of checkpoint and model differ
        print("⚠️ no checkpoint tensor matched the model; weights are unchanged.")
    print("✅ MLM partial load done.")
    print(f"   tensors loaded: {len(new_state)} | checkpoint entries skipped: {n_skipped}")
    return model

def format_label_vocab(token2idx):
    """Build the label vocabulary: non-special tokens numbered from 0.

    Tokens keep the order they have in ``token2idx``; the special tokens
    (PAD, SEP, CLS, MASK, UNK in bare or bracketed form) are excluded.
    ``token2idx`` itself is left unchanged.

    Note: this re-defines the function of the same name from an earlier cell.
    """
    excluded = ('PAD', '[PAD]', 'SEP', '[SEP]', 'CLS', '[CLS]', 'MASK', '[MASK]', 'UNK', '[UNK]')
    label_tokens = (name for name in token2idx.keys() if name not in excluded)
    return {name: position for position, name in enumerate(label_tokens)}

# Rebuild the label vocabulary from the (alias-complete) word vocabulary.
Vocab_diag = format_label_vocab(BertVocab["token2idx"])
print("✅ label vocab size:", len(Vocab_diag))



In [None]:
# ===== Cell: rebuild dfs + loaders =====
import pandas as pd
from torch.utils.data import DataLoader

# Try reading with fastparquet engine first (works around pyarrow extension type issues).
# Falls back to pandas default engine if fastparquet is not available or fails.
# Both tables are always read with the same engine.
try:
    train_df = pd.read_parquet(file_config['train'], engine='fastparquet').reset_index(drop=True)
    test_df  = pd.read_parquet(file_config['test'],  engine='fastparquet').reset_index(drop=True)
except Exception as e:
    print('fastparquet engine failed or not installed, falling back to default engine:', e)
    train_df = pd.read_parquet(file_config['train']).reset_index(drop=True)
    test_df  = pd.read_parquet(file_config['test']).reset_index(drop=True)

# De-duplicate each label list while keeping first-seen order.
# set() would also de-duplicate, but the iteration order of a set of strings
# changes between kernel runs, which makes label order (and any later
# truncation of long label lists) non-reproducible.
train_df["label"] = train_df["label"].apply(lambda x: list(dict.fromkeys(list(x))))
test_df["label"]  = test_df["label"].apply(lambda x: list(dict.fromkeys(list(x))))

# Guarantee a "patid" column; fall back to "subject_id" when it is absent.
if "patid" not in train_df.columns: train_df["patid"] = train_df["subject_id"]
if "patid" not in test_df.columns:  test_df["patid"]  = test_df["subject_id"]

trainset = NextVisit(BertVocab['token2idx'], Vocab_diag, ageVocab, train_df, global_params['max_len_seq'])
testset  = NextVisit(BertVocab['token2idx'], Vocab_diag, ageVocab, test_df,  global_params['max_len_seq'])

# Training batches are shuffled; test batches keep the row order of test_df.
trainload = DataLoader(trainset, batch_size=global_params['batch_size'], shuffle=True,  num_workers=0)
testload  = DataLoader(testset,  batch_size=global_params['batch_size'], shuffle=False, num_workers=0)

print("✅ loaders ready")
# Smoke test: fetching one item exercises the whole __getitem__ path.
x = trainset[0]
print("✅ trainset[0] ok, fields:", len(x))


### Set up model

In [None]:
from sklearn.preprocessing import MultiLabelBinarizer
import sklearn.metrics as skm
# Device chosen in the configuration cell.
device = global_params['device']
print(device)

# Build the multi-label model with one output per label and move it to the device.
conf = BertConfig(model_config)
model = BertForMultiLabelPrediction(conf, num_labels=len(Vocab_diag))
model = model.to(device)

# Initialise from the pre-trained checkpoint wherever tensors are compatible.
model = load_pretrained_partial(model, pretrainModel, device=device)
model = model.to(device)

# Multi-hot encoder over all label ids (one singleton sample per id registers every class).
mlb = MultiLabelBinarizer(classes=[label_id for label_id in Vocab_diag.values()])
mlb.fit([[label_id] for label_id in Vocab_diag.values()])

# Optimiser over all named parameters, configured by optim_config.
optim = optimiser.adam(params=list(model.named_parameters()), config=optim_config)
print("✅ model/optim ready.")


### Evaluation Metrics



In [None]:
def precision_samples(logits, label):
    """Sample-averaged average precision of sigmoid(logits) against ``label``.

    Args:
        logits: tensor of raw scores.
        label: tensor of binary targets with the same layout.

    Returns:
        Value of sklearn's average_precision_score with average='samples'.
    """
    probabilities = torch.sigmoid(logits).detach().cpu().numpy()
    targets = label.detach().cpu().numpy()
    return skm.average_precision_score(targets, probabilities, average='samples')

# ===== Cell: Train/Eval (9-field batch) =====
import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score

# Device used by the training / evaluation functions below (read as a global).
device = global_params["device"]

def train_one_epoch(e, log_every=200):
    """Run one training epoch over ``trainload``.

    Uses the globals ``model``, ``optim``, ``trainload``, ``mlb``, ``device``
    and ``global_params``.

    Args:
        e: epoch number, used only in the log line.
        log_every: print loss and micro APS of the current batch every
            ``log_every`` steps.

    Returns:
        Mean per-batch training loss (not scaled by the accumulation factor).
    """
    model.train()
    accum_steps = max(1, int(global_params.get("gradient_accumulation_steps", 1)))
    n_batches = len(trainload)
    total_loss = 0.0
    n_steps = 0

    # Start the epoch without gradients left over from a previous one.
    optim.zero_grad()

    for step, batch in enumerate(trainload):
        # each batch carries 9 fields
        age_ids, input_ids, posi_ids, segment_ids, attMask, label_ids, patid, visit_ids, code_tokens = batch

        # label_ids is a padded id sequence of shape [B, L] and is not used
        # for training directly. The dataset already mapped labels to ids, so
        # recover the valid ids of each sample (drop the -1 padding) and turn
        # them into a multi-hot target row.
        raw_labels = [
            list({int(x) for x in row if int(x) >= 0})
            for row in label_ids.numpy().tolist()
        ]
        targets = torch.tensor(mlb.transform(raw_labels), dtype=torch.float32).to(device)

        age_ids = age_ids.to(device)
        input_ids = input_ids.to(device)
        posi_ids = posi_ids.to(device)
        segment_ids = segment_ids.to(device)
        attMask = attMask.to(device)

        loss, logits = model(
            input_ids,
            age_ids,
            segment_ids,
            posi_ids,
            attention_mask=attMask,
            labels=targets
        )

        # keep the unscaled value for reporting
        batch_loss = loss.item()

        if accum_steps > 1:
            loss = loss / accum_steps

        loss.backward()

        # Step every accum_steps batches, and also on the last batch so that a
        # trailing partial group is applied instead of leaking into the next epoch.
        if (step + 1) % accum_steps == 0 or (step + 1) == n_batches:
            optim.step()
            optim.zero_grad()

        total_loss += batch_loss
        n_steps += 1

        if step % log_every == 0:
            with torch.no_grad():
                prob = torch.sigmoid(logits).detach().cpu().numpy()
                y = targets.detach().cpu().numpy()
                aps = average_precision_score(y, prob, average="micro")
            print(f"[Train] epoch {e} step {step} | loss={batch_loss:.4f} | APS(micro)={aps:.4f}")

    return total_loss / max(n_steps, 1)


@torch.no_grad()
def evaluation():
    """Evaluate the global ``model`` on ``testload``.

    The inference loop is shared with ``collect_test_probs`` (defined in this
    cell) rather than duplicated: that function puts the model in eval mode
    and returns the stacked multi-hot targets and sigmoid probabilities.

    Returns:
        (aps, auroc): micro-averaged average precision and ROC AUC.
    """
    y_true, y_prob = collect_test_probs()

    aps = average_precision_score(y_true, y_prob, average="micro")
    auroc = roc_auc_score(y_true, y_prob, average="micro")

    return aps, auroc


@torch.no_grad()
def collect_test_probs():
    """Run the global ``model`` over ``testload`` and gather targets and probabilities.

    Returns:
        (y_true, y_prob): two arrays stacked over all test batches — the
        multi-hot targets and the sigmoid of the logits.
    """
    model.eval()
    prob_chunks = []
    true_chunks = []

    for age_ids, input_ids, posi_ids, segment_ids, attMask, label_ids, patid, visit_ids, code_tokens in testload:
        # padded label ids -> per-sample list of valid ids -> multi-hot targets
        raw_labels = [
            list({int(x) for x in row if int(x) >= 0})
            for row in label_ids.numpy().tolist()
        ]
        targets = torch.tensor(mlb.transform(raw_labels), dtype=torch.float32).to(device)

        loss, logits = model(
            input_ids.to(device),
            age_ids.to(device),
            segment_ids.to(device),
            posi_ids.to(device),
            attention_mask=attMask.to(device),
            labels=targets,
        )

        prob_chunks.append(torch.sigmoid(logits).cpu().numpy())
        true_chunks.append(targets.cpu().numpy())

    y_prob = np.vstack(prob_chunks)
    y_true = np.vstack(true_chunks)
    return y_true, y_prob


In [None]:
# Path of the best fine-tuned checkpoint (written by the training loop further down).
ckpt_path = os.path.join(global_params["output_dir"], global_params["best_name"])

if os.path.isfile(ckpt_path):
    print("Loading:", ckpt_path)
    # NOTE(security): torch.load unpickles the file — only load trusted checkpoints.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state, strict=True)
    print("✅ Loaded trained model for explain/plotting")
else:
    # On a fresh run the checkpoint does not exist yet because training happens
    # in a later cell; do not abort "Run All" here.
    print("⚠️ No fine-tuned checkpoint found at", ckpt_path,
          "- keeping the current weights. Run the training cell to create it.")

model.eval()


In [None]:
class NextVisitExplain(torch.utils.data.Dataset):
    """Dataset variant for explanation / plotting.

    Builds the same token sequence as ``NextVisit`` (most recent max_len-1
    codes, with a CLS token in front) but returns only what the explanation
    code needs, including the raw tokens and raw labels.

    Args:
        token2idx: word vocabulary (token -> id).
        dataframe: table with "code", "age", "label" and "patid" columns.
        max_len: fixed sequence length.

    Each item is the tuple
        (code_ids, attMask, visit_ids, code_tokens, label, patid).
    """

    def __init__(self, token2idx, dataframe, max_len):
        self.vocab = token2idx
        self.code = dataframe["code"]
        self.age  = dataframe["age"]
        self.label = dataframe["label"]
        self.patid = dataframe["patid"]
        self.max_len = max_len
        # same lookup order as NextVisit: PAD, then [PAD], then 0
        self.pad_id = self.vocab.get("PAD", self.vocab.get("[PAD]", 0))

    def __getitem__(self, idx):
        codes = list(self.code[idx])[-(self.max_len - 1):]

        # Same CLS rule as NextVisit, so explanations see the sequence the
        # model was fed: a leading "SEP" is replaced by CLS, otherwise CLS is
        # prepended (an empty history becomes just ["CLS"]).
        if len(codes) > 0 and codes[0] == "SEP":
            codes[0] = "CLS"
        else:
            codes = ["CLS"] + codes

        # visit index per token; the counter advances after each "SEP"
        visit_ids = []
        v = 0
        for c in codes:
            visit_ids.append(v)
            if c == "SEP":
                v += 1

        code_tokens = (codes + ["PAD"] * self.max_len)[:self.max_len]
        # the notebook's safe mapper: unknown codes fall back to UNK instead of raising
        _, code_ids = code2index_safe(np.array(codes, dtype=object), self.vocab)
        code_ids = seq_padding(code_ids, self.max_len, symbol=self.pad_id)

        # 1.0 for real tokens, 0.0 for padding
        attMask = (np.array(code_ids) != self.pad_id).astype(np.float32)

        return (
            torch.LongTensor(code_ids),
            torch.FloatTensor(attMask),
            torch.LongTensor((visit_ids + [-1] * (self.max_len - len(visit_ids)))[:self.max_len]),
            code_tokens,
            self.label[idx],
            int(self.patid[idx])
        )

    def __len__(self):
        return len(self.code)

def collate_explain(batch):
    """Collate NextVisitExplain items into one batch.

    Args:
        batch: list of 6-tuples
            (code_ids, attMask, visit_ids, code_tokens, label, patid).

    Returns:
        Tuple of six entries: the three tensor fields stacked along a new
        first axis, followed by the three remaining fields as plain lists
        (they may have different lengths per sample, so they are not stacked).
    """
    code_ids, att, visit_ids, code_tokens, labels, patids = tuple(zip(*batch))
    stacked = [torch.stack(field) for field in (code_ids, att, visit_ids)]
    as_lists = [list(field) for field in (code_tokens, labels, patids)]
    return (*stacked, *as_lists)

# Explanation loader over the test table: small batches, original row order,
# and the custom collate function so token/label lists are kept as Python lists.
testset_explain = NextVisitExplain(BertVocab["token2idx"], test_df, global_params["max_len_seq"])
testload_explain = DataLoader(testset_explain, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_explain)

print("✅ testload_explain ready")


In [None]:
# Pull the first explanation batch and print a few fields as a sanity check.
batch = next(iter(testload_explain))
code_ids, attMask, visit_ids, code_tokens, labels, patids = batch

print("code_ids:", code_ids.shape)
print("visit_ids:", visit_ids.shape)
print("code_tokens[0][:20]:", code_tokens[0][:20])
# attSum = number of non-padding tokens in the first sample
print("patid[0]:", patids[0], "| attSum:", attMask[0].sum().item())


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from sklearn.metrics import average_precision_score, roc_auc_score

def create_folder(path):
    """Create directory ``path`` (and missing parents); do nothing if it already exists.

    Note: this definition replaces the ``create_folder`` imported from
    common.common at the top of the notebook.
    """
    os.makedirs(name=path, exist_ok=True)

# Directory for checkpoints and figures; falls back to ./outputs when not configured.
OUT_DIR = global_params.get("output_dir", "./outputs")
create_folder(OUT_DIR)

def sigmoid_np(x):
    """Numerically stable logistic sigmoid for NumPy inputs.

    The naive form 1 / (1 + exp(-x)) overflows in exp() for large negative
    ``x``. Here exp() is only ever evaluated at a non-positive argument:

        x >= 0:  1 / (1 + exp(-x))
        x <  0:  exp(x) / (1 + exp(x))

    Args:
        x: scalar or array-like of logits.

    Returns:
        Sigmoid of ``x``: an array of the same shape, or a NumPy scalar for
        scalar input.
    """
    x = np.asarray(x)
    decay = np.exp(-np.abs(x))  # in (0, 1], cannot overflow
    out = np.where(x >= 0, 1 / (1 + decay), decay / (1 + decay))
    # [()] turns a 0-d result back into a scalar and leaves arrays untouched
    return out[()]

def compute_micro_metrics(y_true, y_logits):
    """Micro-averaged APS and AUROC from raw logits.

    Args:
        y_true: [N, C] array of 0/1 targets.
        y_logits: [N, C] array of raw logits (sigmoid is applied here).

    Returns:
        (APS_micro, AUROC_micro)
    """
    probabilities = sigmoid_np(y_logits)
    return (
        average_precision_score(y_true, probabilities, average="micro"),
        roc_auc_score(y_true, probabilities, average="micro"),
    )


In [None]:
EPOCHS = 15

# Per-epoch history for the plots below.
aps_list, auroc_list, loss_list = [], [], []

best_aps = -1

for e in range(EPOCHS):
    tr_loss = train_one_epoch(e, log_every=200)  # training function defined above
    aps, auroc = evaluation()                    # evaluation function defined above (returns aps, auroc)

    loss_list.append(float(tr_loss))
    aps_list.append(float(aps))
    auroc_list.append(float(auroc))

    print(f"==> epoch {e:02d}: train_loss={tr_loss:.4f} APS(micro)={aps:.4f} AUROC(micro)={auroc:.4f}")

    # Save the best model so far (optional, controlled by save_model).
    if aps > best_aps and global_params.get("save_model", True):
        best_aps = aps
        ckpt_path = os.path.join(OUT_DIR, global_params.get("best_name", "best.pt"))
        # unwrap a wrapper exposing `.module` before saving
        torch.save(getattr(model, "module", model).state_dict(), ckpt_path)
        print("✅ Saved best ckpt:", ckpt_path)

# --- plot APS vs epoch ---
fig, ax = plt.subplots()
ax.plot(np.arange(EPOCHS), aps_list, marker='o')
ax.set_title("NextVisit-12m: APS vs Epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("APS (micro)")
ax.grid(True)

aps_png = os.path.join(OUT_DIR, "nextvisit_aps_vs_epoch.png")
fig.savefig(aps_png, dpi=200, bbox_inches="tight")
plt.show()
print("✅ Saved:", aps_png)

# --- plot AUROC vs epoch ---
fig, ax = plt.subplots()
ax.plot(np.arange(EPOCHS), auroc_list, marker='s')
ax.set_title("NextVisit-12m: AUROC vs Epoch")
ax.set_xlabel("Epoch")
ax.set_ylabel("AUROC (micro)")
ax.grid(True)

auroc_png = os.path.join(OUT_DIR, "nextvisit_auroc_vs_epoch.png")
fig.savefig(auroc_png, dpi=200, bbox_inches="tight")
plt.show()
print("✅ Saved:", auroc_png)



In [None]:
# The curve helpers used below are not imported by any earlier cell; import
# them here so the cell runs on a fresh kernel.
from sklearn.metrics import auc, average_precision_score, precision_recall_curve, roc_curve

# Stacked multi-hot targets and predicted probabilities for the whole test set.
y_true, y_prob = collect_test_probs()

# ===== micro flatten =====
# Micro-averaging: treat every (sample, label) pair as one binary decision.
y_true_micro = y_true.ravel()
y_prob_micro = y_prob.ravel()

# ===== PR curve =====
prec, rec, _ = precision_recall_curve(y_true_micro, y_prob_micro)
ap_micro = average_precision_score(y_true_micro, y_prob_micro)

plt.figure()
plt.plot(rec, prec)
plt.title(f"PR Curve (micro) | AP={ap_micro:.4f}")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.grid(True)

pr_png = os.path.join(OUT_DIR, "nextvisit_pr_curve_micro.png")
plt.savefig(pr_png, dpi=200, bbox_inches="tight")
plt.show()
print("✅ Saved:", pr_png)

# ===== ROC curve =====
fpr, tpr, _ = roc_curve(y_true_micro, y_prob_micro)
auc_micro = auc(fpr, tpr)  # or roc_auc_score(y_true_micro, y_prob_micro)

plt.figure()
plt.plot(fpr, tpr)
plt.title(f"ROC Curve (micro) | AUC={auc_micro:.4f}")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.grid(True)

roc_png = os.path.join(OUT_DIR, "nextvisit_roc_curve_micro.png")
plt.savefig(roc_png, dpi=200, bbox_inches="tight")
plt.show()
print("✅ Saved:", roc_png)

In [None]:
# --- PyTorch Baseline Logistic Regression (multi-label, GPU support) ---
# This cell keeps its own names (lr_model, lr_device, LR_EPOCHS, BASELINE_OUT_DIR,
# torch_optim) so that it does not overwrite the globals `model`, `device`,
# `optim`, `EPOCHS` and `OUT_DIR` that the BEHRT training / evaluation
# functions read.
import torch
import torch.nn as nn
import torch.optim as torch_optim
from sklearn.metrics import average_precision_score, roc_auc_score, precision_recall_curve, roc_curve
import matplotlib.pyplot as plt
import os

lr_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', lr_device)


def _bag_of_codes(frame, code_vocab):
    """Count, per row of ``frame``, how often each vocabulary code occurs in its "code" list.

    Codes that are not in ``code_vocab`` (e.g. "SEP") are ignored.
    Returns a float32 array of shape (len(frame), len(code_vocab)).
    """
    feats = np.zeros((len(frame), len(code_vocab)), dtype=np.float32)
    for row, codes in enumerate(frame["code"]):
        for code in codes:
            col = code_vocab.get(str(code))
            if col is not None:
                feats[row, col] += 1.0
    return feats


def _multi_hot_labels(frame, label_vocab):
    """Multi-hot encode the "label" list of every row of ``frame``.

    Returns a float32 array of shape (len(frame), len(label_vocab)).
    """
    hot = np.zeros((len(frame), len(label_vocab)), dtype=np.float32)
    for row, labels in enumerate(frame["label"]):
        for lab in labels:
            col = label_vocab.get(str(lab))
            if col is not None:
                hot[row, col] = 1.0
    return hot


# The feature / target matrices are not created by any cell of this notebook.
# Reuse them when they already exist in the kernel, otherwise build them.
try:
    X_train, X_test, Y_train, Y_test
except NameError:
    # NOTE(review): the original feature construction is not available; count-based
    # bag-of-codes over the label vocabulary is assumed here — confirm.
    X_train = _bag_of_codes(train_df, Vocab_diag)
    X_test  = _bag_of_codes(test_df, Vocab_diag)
    Y_train = _multi_hot_labels(train_df, Vocab_diag)
    Y_test  = _multi_hot_labels(test_df, Vocab_diag)

X_train_torch = torch.tensor(X_train, dtype=torch.float32).to(lr_device)
X_test_torch  = torch.tensor(X_test, dtype=torch.float32).to(lr_device)
Y_train_torch = torch.tensor(Y_train, dtype=torch.float32).to(lr_device)
Y_test_torch  = torch.tensor(Y_test, dtype=torch.float32).to(lr_device)

n_features = X_train.shape[1]
n_labels = Y_train.shape[1]

class TorchLogReg(nn.Module):
    """Multi-label logistic regression: a single linear layer producing one logit per label."""
    def __init__(self, n_features, n_labels):
        super().__init__()
        self.linear = nn.Linear(n_features, n_labels)
    def forward(self, x):
        return self.linear(x)

lr_model = TorchLogReg(n_features, n_labels).to(lr_device)
optimizer = torch_optim.Adam(lr_model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()

LR_EPOCHS = 15
batch_size = 128
for epoch in range(LR_EPOCHS):
    lr_model.train()
    # new random sample order every epoch
    perm = torch.randperm(X_train_torch.size(0), device=lr_device)
    total_loss = 0.0
    for i in range(0, X_train_torch.size(0), batch_size):
        idx = perm[i:i+batch_size]
        xb = X_train_torch[idx]
        yb = Y_train_torch[idx]
        optimizer.zero_grad()
        logits = lr_model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        # weight by batch size so the epoch average is per sample
        total_loss += loss.item() * xb.size(0)
    avg_loss = total_loss / X_train_torch.size(0)
    print(f"Epoch {epoch+1}/{LR_EPOCHS} | Loss: {avg_loss:.4f}")

# Evaluation
lr_model.eval()
with torch.no_grad():
    logits = lr_model(X_test_torch)
    # Named *_lr: these are the inputs the BEHRT-vs-baseline comparison cell reads.
    probs_lr = torch.sigmoid(logits).cpu().numpy()
    y_true_lr = Y_test_torch.cpu().numpy()

aps_micro = average_precision_score(y_true_lr, probs_lr, average='micro')
auroc_micro = roc_auc_score(y_true_lr, probs_lr, average='micro')
print(f"[PyTorch-LogReg] APS = {aps_micro:.4f}, AUROC = {auroc_micro:.4f}")

BASELINE_OUT_DIR = "../../outputs"
os.makedirs(BASELINE_OUT_DIR, exist_ok=True)
PR_CURVE_PNG  = os.path.join(BASELINE_OUT_DIR, "baseline_torchlogreg_pr.png")
ROC_CURVE_PNG = os.path.join(BASELINE_OUT_DIR, "baseline_torchlogreg_roc.png")

precision, recall, _ = precision_recall_curve(y_true_lr.ravel(), probs_lr.ravel())
fpr, tpr, _ = roc_curve(y_true_lr.ravel(), probs_lr.ravel())

plt.figure(figsize=(6,4))
plt.plot(recall, precision, label=f"AP={aps_micro:.3f}")
plt.xlabel("Recall"); plt.ylabel("Precision")
plt.title("PR Curve (PyTorch LogReg Baseline)")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(PR_CURVE_PNG, dpi=150)
plt.show()

plt.figure(figsize=(6,4))
plt.plot(fpr, tpr, label=f"AUC={auroc_micro:.3f}")
plt.plot([0,1], [0,1], linestyle="--", alpha=0.5)
plt.xlabel("FPR"); plt.ylabel("TPR")
plt.title("ROC Curve (PyTorch LogReg Baseline)")
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.savefig(ROC_CURVE_PNG, dpi=150)
plt.show()

print("PR →", PR_CURVE_PNG)
print("ROC →", ROC_CURVE_PNG)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    precision_recall_curve, roc_curve, auc,
    average_precision_score, roc_auc_score
)

def micro_flat(y_true, y_prob):
    """Flatten targets and scores to 1-D for micro-averaged curves.

    Args:
        y_true: array of targets.
        y_prob: array of scores.

    Returns:
        (y_true.ravel(), y_prob.ravel())
    """
    flat_true = y_true.ravel()
    flat_prob = y_prob.ravel()
    return flat_true, flat_prob

def pr_roc_stats(y_true_micro, y_prob_micro):
    """Precision-recall and ROC curve data for flattened targets and scores.

    Args:
        y_true_micro: 1-D binary targets.
        y_prob_micro: 1-D scores.

    Returns:
        ((precision, recall, ap), (fpr, tpr, roc_auc)) where ``ap`` is the
        average precision and ``roc_auc`` the area under the ROC curve.
    """
    # PR: curve points (thresholds dropped) plus average precision
    pr_points = precision_recall_curve(y_true_micro, y_prob_micro)[:2]
    pr_part = (*pr_points, average_precision_score(y_true_micro, y_prob_micro))

    # ROC: curve points (thresholds dropped) plus area under the curve
    false_pos, true_pos = roc_curve(y_true_micro, y_prob_micro)[:2]
    roc_part = (false_pos, true_pos, auc(false_pos, true_pos))

    return pr_part, roc_part

# =========================
# 1) Prepare inputs
# =========================
# --- BEHRT ---
# --- BEHRT: targets and probabilities on the test set, flattened for micro curves ---
y_true_behrt, y_prob_behrt = collect_test_probs()   # if already computed, reuse it
y_true_behrt_micro, y_prob_behrt_micro = micro_flat(y_true_behrt, y_prob_behrt)

# --- LogReg baseline ---
# Expects the baseline outputs under these names:
#   y_true_lr : baseline test targets
#   probs_lr  : baseline test probabilities
y_true_lr_micro, probs_lr_micro = micro_flat(y_true_lr, probs_lr)

# (Optional sanity check) both models must cover the same number of (sample, label) pairs.
assert y_true_behrt_micro.shape == y_true_lr_micro.shape, \
    f"Mismatch: BEHRT {y_true_behrt_micro.shape} vs LR {y_true_lr_micro.shape}"

# =========================
# 2) Compute curves
# =========================
behrt_pr, behrt_roc = pr_roc_stats(y_true_behrt_micro, y_prob_behrt_micro)
(behrt_prec, behrt_rec, behrt_ap) = behrt_pr
(behrt_fpr, behrt_tpr, behrt_auc) = behrt_roc

lr_pr, lr_roc = pr_roc_stats(y_true_lr_micro, probs_lr_micro)
(lr_prec, lr_rec, lr_ap) = lr_pr
(lr_fpr, lr_tpr, lr_auc) = lr_roc

# =========================
# 3) Plot comparison
# =========================
os.makedirs(OUT_DIR, exist_ok=True)
PR_CMP_PNG  = os.path.join(OUT_DIR, "compare_pr_micro_behrt_vs_logreg.png")
ROC_CMP_PNG = os.path.join(OUT_DIR, "compare_roc_micro_behrt_vs_logreg.png")

# ---- PR comparison ----
pr_fig, pr_ax = plt.subplots(figsize=(6,4))
pr_ax.plot(behrt_rec, behrt_prec, label=f"BEHRT (AP={behrt_ap:.4f})")
pr_ax.plot(lr_rec, lr_prec, label=f"BoC+LogReg (AP={lr_ap:.4f})")
pr_ax.set_xlabel("Recall")
pr_ax.set_ylabel("Precision")
pr_ax.set_title("PR Curve (micro) — BEHRT vs Baseline")
pr_ax.grid(True, alpha=0.3)
pr_ax.legend()
pr_fig.tight_layout()
pr_fig.savefig(PR_CMP_PNG, dpi=200)
plt.show()
print("✅ Saved:", PR_CMP_PNG)

# ---- ROC comparison ----
roc_fig, roc_ax = plt.subplots(figsize=(6,4))
roc_ax.plot(behrt_fpr, behrt_tpr, label=f"BEHRT (AUC={behrt_auc:.4f})")
roc_ax.plot(lr_fpr, lr_tpr, label=f"BoC+LogReg (AUC={lr_auc:.4f})")
# chance diagonal for reference
roc_ax.plot([0,1], [0,1], linestyle="--", alpha=0.5)
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curve (micro) — BEHRT vs Baseline")
roc_ax.grid(True, alpha=0.3)
roc_ax.legend()
roc_fig.tight_layout()
roc_fig.savefig(ROC_CMP_PNG, dpi=200)
plt.show()
print("✅ Saved:", ROC_CMP_PNG)
