In [None]:
# # Param
# test = "True"

In [None]:
# `test` is expected to be injected by a parameterisation step (see the commented
# "# Param" cell above). When it is absent, e.g. on a plain Restart & Run All,
# the notebook falls back to quick test mode instead of raising NameError.
# False, 0, or any casing/whitespace variant of the string "False" disables test mode.
_test_param = globals().get("test", True)
is_test = not (_test_param == False or str(_test_param).strip().lower() == "false")  # noqa: E712

In [4]:
# Dataset name: passed to DataLoader below and used as the first path component
# of the scorer checkpoint paths.
dataset = "data"

In [5]:
from utils.data_loader import DataLoader
from utils.data_io import join_path

In [6]:
# Project loader used below for the splits, the dataset description and the concept lists.
data_loader = DataLoader(dataset)

In [7]:
# Train / test splits from the project DataLoader. Presumably pandas DataFrames, since
# later cells index them by column name and call .groupby / .to_list on them.
train_df = data_loader.get_data_train()
test_df = data_loader.get_data_test()

In [8]:
data_desc = data_loader.get_data_desc()

# Names of the label and text columns, as reported by the dataset description.
label_column, text_column = (data_desc[key] for key in ('label_column', 'text_column'))

In [9]:
# Flatten the {group: [keyword, ...]} mapping into one keyword list, in dict order.
keyword_concepts = data_loader.get_keyword_concepts()
keywords = [keyword for group in keyword_concepts.values() for keyword in group]

print(keywords)

['coronary', 'myocardial', 'hypertension', 'cardiac', 'systolic', 'colitis', 'esophageal', 'gastrointestinal', 'bowel', 'duodenal', 'defect', 'loss', 'airway', 'graft', 'respiratory', 'cancer', 'carcinoma', 'sarcoma', 'malignancy', 'chemotherapy', 'brain', 'cerebral', 'neuronal', 'motor', 'cord']


In [10]:
# Keep only the display name of each abstract-concept record.
abstract_concepts = [
    record['abstract_concept_name'] for record in data_loader.get_abstract_concepts()
]
abstract_concepts

['Cardiac Function and Disorders',
 'Heart Muscle and Blood Pressure',
 'Coronary Artery Issues',
 'Intestinal and Esophageal Conditions',
 'Gastrointestinal Tract Ailments',
 'Inflammatory Bowel Diseases',
 'General Pathological States',
 'Respiratory System Impairments',
 'Tissue and Graft Issues',
 'Malignant Tumors and Growths',
 'Cancer Treatment and Types',
 'Oncological Malignancies',
 'Central and Peripheral Nervous System Disorders',
 'Brain and Cerebral Conditions',
 'Spinal Cord and Motor Function Impairment']

In [11]:
from sklearn.preprocessing import LabelEncoder

In [12]:
# Fit the encoder on the training labels only, then replace the label column of both
# splits with integer ids (in place). Unseen test labels make transform() raise.
# NOTE(review): this cell is not idempotent. Re-running it fits the encoder on the
# already-encoded ids, so le.classes_ (and the `labels` cell below) become 0..n-1
# instead of the class names. Re-run from the data-loading cell instead.
le = LabelEncoder().fit(train_df[label_column])
for split_df in (train_df, test_df):
    split_df[label_column] = le.transform(split_df[label_column])

In [13]:
# Class names in encoder order: labels[i] is the name of integer label i.
labels = list(le.classes_)
labels

['cardiovascular diseases',
 'digestive system diseases',
 'general pathological conditions',
 'neoplasms',
 'nervous system diseases']

In [14]:
# In test mode keep a single row per class so the whole pipeline runs quickly.
# A fixed random_state (same seed as the train/val split) makes the subset reproducible.
if is_test:
    train_df = train_df.groupby(label_column).sample(n=1, random_state=42)
    test_df = test_df.groupby(label_column).sample(n=1, random_state=42)

In [15]:
# Materialise the text and label columns of each split as plain Python lists.
train_texts, train_labels = train_df[text_column].to_list(), train_df[label_column].to_list()
test_texts, test_labels = test_df[text_column].to_list(), test_df[label_column].to_list()

In [16]:
from sklearn.model_selection import train_test_split

In [17]:
# Hold out 20% of the training split for validation / early stopping (fixed seed).
train_texts, val_texts, train_labels, val_labels = train_test_split(
    train_texts,
    train_labels,
    test_size=0.2,
    random_state=42,
)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertForSequenceClassification
import numpy as np
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
from tqdm import tqdm
import os
import networkx as nx
from matplotlib import cm
from sentence_transformers import SentenceTransformer

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizer, BertForSequenceClassification
from sentence_transformers import SentenceTransformer


def load_model(path):
    """Load a BERT sequence-classification checkpoint and its tokenizer.

    Args:
        path: checkpoint directory accepted by ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``.
    """
    return (BertForSequenceClassification.from_pretrained(path),
            BertTokenizer.from_pretrained(path))


class ConceptNetwork(nn.Module):
    """Concept-bottleneck classifier: keywords -> abstract concepts -> labels.

    Two NLI-style sequence classifiers (loaded with ``load_model``) score every
    (text, concept name) pair. Softmax column ``entailment_idx`` is used as the
    concept activation. The scorers and the SBERT encoder are frozen: they run under
    ``torch.no_grad()``, have ``requires_grad=False``, and are kept in eval mode by
    :meth:`train`.

    Learnable parts:
      * one embedding per keyword / abstract concept / label. Dot products between
        adjacent layers, softmax-normalised over the source layer, give the
        keyword->abstract and abstract->label attention matrices;
      * ``abstract_betas``: per-abstract gate (after sigmoid) that mixes the direct NLI
        score with the attention-propagated keyword scores;
      * ``semantic_predictor``: shared linear map from concept embeddings to SBERT space.
    """

    def __init__(self, concept_names, embedding_dim, keyword_nli_model_path, abstract_nli_model_path,
                 semantic_dim=384, entailment_idx=1):
        """
        Args:
            concept_names: ``[keywords, abstract_concepts, labels]``, three lists of names.
            embedding_dim: size of every learnable concept embedding.
            keyword_nli_model_path: checkpoint directory of the keyword scorer.
            abstract_nli_model_path: checkpoint directory of the abstract-concept scorer.
            semantic_dim: output size of ``semantic_predictor``. The default 384 is the SBERT
                (all-MiniLM-L6-v2) embedding size used as the alignment target.
            entailment_idx: softmax column of the scorers read as the concept score.
                NOTE(review): 1 is assumed to be the "entailment" class of both checkpoints;
                confirm against their ``config.id2label``.
        """
        super().__init__()
        self.embedding_dim = embedding_dim
        self.concept_names = concept_names  # [[keywords], [abstracts], [labels]]
        self.entailment_idx = entailment_idx
        # Not used by forward(); kept so state_dict keys match existing checkpoints.
        self.sbert = SentenceTransformer("all-MiniLM-L6-v2")

        # === Concept embeddings (one row per concept) ===
        self.keyword_embeddings = nn.Parameter(torch.randn(len(concept_names[0]), embedding_dim))
        self.abstract_embeddings = nn.Parameter(torch.randn(len(concept_names[1]), embedding_dim))
        self.label_embeddings = nn.Parameter(torch.randn(len(concept_names[2]), embedding_dim))

        # === Beta parameters (learnable reliability weights, passed through a sigmoid) ===
        # keyword_betas is not used by forward(), so it never receives a gradient; it is kept
        # for checkpoint compatibility.
        self.keyword_betas = nn.Parameter(torch.zeros(len(concept_names[0])))
        self.abstract_betas = nn.Parameter(torch.zeros(len(concept_names[1])))

        # === Semantic predictor (shared across the three layers) ===
        self.semantic_predictor = nn.Linear(embedding_dim, semantic_dim)

        # === NLI scorers ===
        self.keyword_scorer, self.keyword_tokenizer = load_model(keyword_nli_model_path)
        self.abstract_scorer, self.abstract_tokenizer = load_model(abstract_nli_model_path)

        # These sub-modules are never optimised. Freeze them so optimisers and
        # requires_grad-based filters skip their (large) parameter sets.
        for frozen in self._frozen_modules():
            frozen.requires_grad_(False)

    def _frozen_modules(self):
        """Sub-modules that are never trained: the SBERT encoder and both NLI scorers."""
        return (self.sbert, self.keyword_scorer, self.abstract_scorer)

    def train(self, mode=True):
        """Set train/eval mode for the learnable parts while keeping frozen sub-modules in eval.

        The NLI scorers run under ``torch.no_grad()`` and are never updated. Leaving them in
        train mode would enable their dropout and make the concept scores stochastic.
        """
        super().train(mode)
        for frozen in self._frozen_modules():
            frozen.eval()
        return self

    def _entailment_scores(self, scorer, tokenizer, texts, concepts, device):
        """Score every (text, concept) pair; returns a [len(texts), len(concepts)] tensor."""
        columns = []
        for concept in concepts:
            inputs = tokenizer(texts, [concept] * len(texts),
                               return_tensors='pt', padding=True, truncation=True).to(device)
            with torch.no_grad():
                logits = scorer(**inputs).logits
                columns.append(torch.softmax(logits, dim=-1)[:, self.entailment_idx])
        return torch.stack(columns, dim=1)

    def forward(self, texts, device):
        """
        Args:
            texts: list of B input strings.
            device: anything accepted by ``torch.device``.
        Returns:
            predictions: [B, n_labels] label scores (train_model feeds them to CrossEntropyLoss).
            keyword_semantic, abstract_semantic, label_semantic: [n_i, semantic_dim].
            keyword_scores: [B, n_keywords] NLI scores.
            abstract_scores: [B, n_abstracts] gated abstract scores.
        """
        device = torch.device(device)

        # === Direct NLI scores for both concept layers ===
        keyword_scores = self._entailment_scores(
            self.keyword_scorer, self.keyword_tokenizer, texts, self.concept_names[0], device)
        abstract_direct = self._entailment_scores(
            self.abstract_scorer, self.abstract_tokenizer, texts, self.concept_names[1], device)

        # === keyword -> abstract relation ===
        # Softmax over keywords (dim=0): each abstract receives a weighted average of keyword scores.
        attn_kw_abs = torch.softmax(self.keyword_embeddings @ self.abstract_embeddings.T, dim=0)
        relation_kw_abs = keyword_scores @ attn_kw_abs  # [B, n_abs]

        # === Abstract scores: gate between direct score and propagated keyword evidence ===
        gate = torch.sigmoid(self.abstract_betas)  # [n_abs]
        abstract_scores = abstract_direct * gate + relation_kw_abs * (1 - gate)  # [B, n_abs]

        # === Label layer: relation only (abstract -> label), softmax over abstracts ===
        attn_abs_lbl = torch.softmax(self.abstract_embeddings @ self.label_embeddings.T, dim=0)
        predictions = abstract_scores @ attn_abs_lbl  # [B, n_lbl]
        # NOTE(review): predictions are convex combinations of values in [0, 1]. Used directly
        # as CrossEntropy logits, their softmax stays close to uniform; a learnable temperature
        # or scale may train better.

        # === Projections of concept embeddings into SBERT space ===
        keyword_semantic = self.semantic_predictor(self.keyword_embeddings)     # [n_kw, semantic_dim]
        abstract_semantic = self.semantic_predictor(self.abstract_embeddings)   # [n_abs, semantic_dim]
        label_semantic = self.semantic_predictor(self.label_embeddings)         # [n_lbl, semantic_dim]

        return predictions, keyword_semantic, abstract_semantic, label_semantic, keyword_scores, abstract_scores


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import math
from tqdm import tqdm


def train_model(model, train_texts, train_labels, val_texts, val_labels, concept_names, sbert_embeddings,
                batch_size=16, num_epochs=100, patience=5, lambda_semantic=1.0, device='cuda',
                lr=1e-3, checkpoint_path='best_model.pt', plot_path='loss_plot.png'):
    """Train ``model`` with cross-entropy on its label scores plus an SBERT alignment loss.

    Per-batch loss:
        CE(predictions, labels) + lambda_semantic * sum_i MSE(semantic_i, sbert_embeddings[i])
    over the keyword / abstract / label layers. The weights with the lowest average
    validation loss are saved to ``checkpoint_path``. Training stops after ``patience``
    epochs without improvement. A train/val loss curve is saved to ``plot_path`` and shown.

    Args:
        model: module whose ``forward(texts, device)`` returns
            ``(predictions, keyword_sem, abstract_sem, label_sem, _, _)``.
        train_texts, train_labels, val_texts, val_labels: parallel lists (labels are int ids).
        concept_names: unused; kept for call compatibility.
        sbert_embeddings: target embeddings for the keyword, abstract and label layers.
        batch_size, num_epochs, patience, lambda_semantic, device: training settings.
        lr: Adam learning rate.
        checkpoint_path: file the best ``state_dict`` is written to.
        plot_path: file the loss plot is written to.

    Returns:
        ``(train_losses, val_losses)``: per-epoch average losses.

    Raises:
        ValueError: if a split is empty or its texts/labels lengths differ. An empty
            validation split would otherwise crash with ZeroDivisionError after a full epoch.
    """
    for split, texts, labels in (("train", train_texts, train_labels), ("val", val_texts, val_labels)):
        if len(texts) == 0:
            raise ValueError(f"train_model: {split}_texts is empty")
        if len(texts) != len(labels):
            raise ValueError(f"train_model: {split}_texts and {split}_labels differ in length "
                             f"({len(texts)} vs {len(labels)})")

    model = model.to(device)
    # Frozen sub-modules (requires_grad=False) never receive gradients; keep them out of Adam.
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)
    criterion = nn.CrossEntropyLoss()
    mse_loss = nn.MSELoss()

    # Convert the SBERT target embeddings to tensors once, up front.
    sbert_tensors = [torch.tensor(arr, device=device, dtype=torch.float32) for arr in sbert_embeddings]

    def batch_loss(texts, labels):
        """Classification loss plus the weighted semantic-alignment loss for one batch."""
        targets = torch.tensor(labels, dtype=torch.long, device=device)
        predictions, keyword_sem, abstract_sem, label_sem, _, _ = model(texts, device)
        semantic_loss = sum(mse_loss(pred, target) for pred, target
                            in zip((keyword_sem, abstract_sem, label_sem), sbert_tensors))
        return criterion(predictions, targets) + lambda_semantic * semantic_loss

    n_train_batches = math.ceil(len(train_texts) / batch_size)
    n_val_batches = math.ceil(len(val_texts) / batch_size)

    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        epoch_train_loss = 0.0
        for start in tqdm(range(0, len(train_texts), batch_size), desc=f"Epoch {epoch+1}"):
            optimizer.zero_grad()
            loss = batch_loss(train_texts[start:start + batch_size], train_labels[start:start + batch_size])
            loss.backward()
            optimizer.step()
            epoch_train_loss += loss.item()

        # === Validation ===
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for start in range(0, len(val_texts), batch_size):
                val_loss += batch_loss(val_texts[start:start + batch_size],
                                       val_labels[start:start + batch_size]).item()

        avg_train_loss = epoch_train_loss / n_train_batches
        avg_val_loss = val_loss / n_val_batches
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)

        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}")

        # Early stopping on the average validation loss.
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            print("✅ Saved new best model.")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("⏹️ Early stopping triggered.")
                break

    # === Plot losses ===
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(train_losses, label='Train Loss')
    ax.plot(val_losses, label='Val Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training vs Validation Loss')
    ax.legend()
    ax.grid(True)
    fig.tight_layout()
    fig.savefig(plot_path)
    plt.show()
    plt.close(fig)

    return train_losses, val_losses

In [None]:
def test_model(model, test_texts, test_labels, concept_names, batch_size=16, device='cuda'):
    model = model.to(device)
    model.eval()
    predictions = []
    true_labels = []
    
    with torch.no_grad():
        for i in tqdm(range(0, len(test_texts), batch_size), desc="Testing"):
            batch_texts = test_texts[i:i+batch_size]
            batch_labels = test_labels[i:i+batch_size]
            
            outputs, _, _, _, _, _ = model(batch_texts, device)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            predictions.extend(preds)
            true_labels.extend(batch_labels)
    
    print("\nClassification Report:")
    print(classification_report(true_labels, predictions, target_names=concept_names[2]))

In [22]:
def interpretable_prediction(model, test_text, test_label, concept_names, device='cuda'):
    """Run the model on one text and print the activations of every concept layer.

    Args:
        model: module whose ``forward(texts, device)`` returns
            ``(predictions, _, _, _, keyword_scores, abstract_scores)``.
        test_text: the input string.
        test_label: integer index of the true label into ``concept_names[2]``.
        concept_names: ``[keywords, abstract_concepts, labels]``.
        device: device to run on.

    Returns:
        The sample's keyword scores, abstract scores and label probabilities as NumPy
        arrays, plus the predicted label index (int).
    """
    model = model.to(device)
    model.eval()

    keyword_names, abstract_names, label_names = concept_names[0], concept_names[1], concept_names[2]

    print("\nInterpretable Prediction for Sample:")
    print(f"Input Text: {test_text}")
    print(f"True Label: {label_names[test_label]}")

    with torch.no_grad():
        logits, _, _, _, keyword_batch, abstract_batch = model([test_text], device)
        keyword_row = keyword_batch[0]
        abstract_row = abstract_batch[0]
        label_probs = torch.softmax(logits, dim=-1)[0]

        sections = (
            ("\nKeyword Layer Activations:", keyword_names, keyword_row),
            ("\nAbstract Layer Activations:", abstract_names, abstract_row),
            ("\nLabel Layer Probabilities:", label_names, label_probs),
        )
        for header, names, values in sections:
            print(header)
            for name, value in zip(names, values):
                print(f"  {name}: {value:.4f}")

        predicted_idx = torch.argmax(logits, dim=-1).item()
        print(f"\nPredicted Label: {label_names[predicted_idx]}")

    return keyword_row.cpu().numpy(), abstract_row.cpu().numpy(), label_probs.cpu().numpy(), predicted_idx

In [None]:
import matplotlib.pyplot as plt
import networkx as nx
from matplotlib import cm

def visualize_network(model, concept_names, keyword_scores, abstract_scores, label_probs, predicted_label_idx,
                      output_file='network_visualization.png', max_edge_width=3.0):
    """Draw the concept network for one prediction as a three-column directed graph.

    Nodes: keywords (x=0), abstract concepts (x=1), labels (x=2). Colour and size encode
    the sample's activation or label probability; the predicted label is redrawn in red.
    Edges use the same attention the model applies in ``forward``:
    softmax over keywords of keyword·abstract embeddings, and softmax over abstracts of
    abstract·label embeddings. (The previous version weighted keyword->abstract edges by
    ``keyword_betas``, which forward never uses and so stay at 0, making every such edge
    zero-width.) Edge widths are scaled so the strongest edge is ``max_edge_width`` wide.

    Args:
        model: a trained ConceptNetwork.
        concept_names: ``[keywords, abstract_concepts, labels]``.
        keyword_scores, abstract_scores, label_probs: per-node values for this sample.
        predicted_label_idx: index of the predicted label.
        output_file: path the figure is saved to.
        max_edge_width: width of the strongest edge, in points.
    """
    with torch.no_grad():
        kw_emb = model.keyword_embeddings.detach().cpu()
        abs_emb = model.abstract_embeddings.detach().cpu()
        lbl_emb = model.label_embeddings.detach().cpu()
        attn_kw_abs = torch.softmax(kw_emb @ abs_emb.T, dim=0)    # [n_kw, n_abs]
        attn_abs_lbl = torch.softmax(abs_emb @ lbl_emb.T, dim=0)  # [n_abs, n_lbl]

    G = nx.DiGraph()
    layers = (
        ('K', 'keyword', concept_names[0], keyword_scores),
        ('A', 'abstract', concept_names[1], abstract_scores),
        ('L', 'label', concept_names[2], label_probs),
    )

    # Nodes, laid out in one column per layer, top to bottom.
    max_nodes = max(len(names) for _, _, names, _ in layers)
    pos = {}
    for x, (prefix, layer, names, scores) in enumerate(layers):
        for i, name in enumerate(names):
            node = f"{prefix}_{i}"
            G.add_node(node, label=name, layer=layer, score=float(scores[i]))
            pos[node] = (x, max_nodes - i)

    # Edges weighted by the model's attention between adjacent layers.
    for src, dst, attn in (('K', 'A', attn_kw_abs), ('A', 'L', attn_abs_lbl)):
        n_src, n_dst = attn.shape
        for i in range(n_src):
            for j in range(n_dst):
                G.add_edge(f"{src}_{i}", f"{dst}_{j}", weight=attn[i, j].item())

    nodes = list(G.nodes(data=True))
    edges = list(G.edges())
    strongest = max((G[u][v]['weight'] for u, v in edges), default=0.0) or 1.0
    edge_widths = [max_edge_width * G[u][v]['weight'] / strongest for u, v in edges]

    fig, ax = plt.subplots(figsize=(12, 8))
    nx.draw_networkx_nodes(G, pos, nodelist=[n for n, _ in nodes],
                           node_color=[d['score'] for _, d in nodes],
                           node_size=[500 + d['score'] * 2000 for _, d in nodes],
                           cmap=cm.viridis, ax=ax)
    nx.draw_networkx_edges(G, pos, edgelist=edges, width=edge_widths, alpha=0.5, ax=ax)
    node_labels = {n: f"{d['label']}\n{d['score']:.2f}" for n, d in nodes}
    nx.draw_networkx_labels(G, pos, node_labels, font_size=8, ax=ax)

    # Highlight the predicted label node.
    pred_node = f"L_{predicted_label_idx}"
    nx.draw_networkx_nodes(G, pos, nodelist=[pred_node], node_color='red',
                           node_size=500 + G.nodes[pred_node]['score'] * 2000, ax=ax)

    ax.set_title("Concept Network Visualization")
    fig.savefig(output_file)
    plt.show()
    plt.close(fig)


In [24]:
def load_model(path):
    """Return ``(model, tokenizer)`` loaded from the checkpoint directory ``path``.

    NOTE(review): this duplicates the ``load_model`` defined in the ConceptNetwork cell,
    with the same behaviour; one of the two definitions can be dropped.
    """
    loaded = (
        BertForSequenceClassification.from_pretrained(path),
        BertTokenizer.from_pretrained(path),
    )
    return loaded

In [25]:
# Checkpoint locations of the two NLI scorers; ConceptNetwork loads them via load_model.
keyword_nli_model_path = join_path(dataset, 'scorer_model', 'keyword_scorer')
abstract_nli_model_path = join_path(dataset, 'scorer_model', 'abstract_scorer')

In [26]:
# SBERT embeddings of every concept name, computed once. train_model uses them as the
# MSE targets for the model's semantic_predictor outputs.
from sentence_transformers import SentenceTransformer
sbert_model = SentenceTransformer('all-MiniLM-L6-v2')
sbert_embeddings = [
    sbert_model.encode(names, convert_to_numpy=True)
    for names in (keywords, abstract_concepts, labels)
]

In [None]:
# Concept hierarchy in the order ConceptNetwork expects: [keywords, abstracts, labels].
concept_names = [keywords, abstract_concepts, labels]
model = ConceptNetwork(
    concept_names=concept_names,
    embedding_dim=64,
    keyword_nli_model_path=keyword_nli_model_path,
    abstract_nli_model_path=abstract_nli_model_path,
)

In [28]:
from transformers.utils import logging

# From here on, show only error-level messages from the transformers logger.
logging.set_verbosity_error()

In [None]:
# Train
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_losses, val_losses = train_model(
    model, train_texts, train_labels, val_texts, val_labels,
    concept_names, sbert_embeddings, batch_size=16, num_epochs=20,
    patience=5, lambda_semantic=0.1, device=device
)

Epoch 1: 100%|██████████| 4/4 [00:35<00:00,  8.76s/it]


Epoch 1, Train Loss: 2.1116, Val Loss: 1.3244


Epoch 2: 100%|██████████| 4/4 [00:38<00:00,  9.57s/it]


Epoch 2, Train Loss: 2.0388, Val Loss: 1.3118


In [31]:
# Restore the lowest-validation-loss weights saved by train_model. map_location lets a
# checkpoint saved on GPU load on a CPU-only kernel (and vice versa).
model.load_state_dict(torch.load('best_model.pt', map_location=device))

<All keys matched successfully>

In [None]:
# Evaluate on the held-out test split and print the per-class report.
test_model(model, test_texts, test_labels, concept_names, batch_size=16, device=device)


Classification Report:
                                 precision    recall  f1-score   support

        cardiovascular diseases       0.00      0.00      0.00         1
      digestive system diseases       0.00      0.00      0.00         1
general pathological conditions       0.20      1.00      0.33         1
                      neoplasms       0.00      0.00      0.00         1
        nervous system diseases       0.00      0.00      0.00         1

                       accuracy                           0.20         5
                      macro avg       0.04      0.20      0.07         5
                   weighted avg       0.04      0.20      0.07         5



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [34]:
# Explain the first test example: print its layer activations, then draw the concept graph.
if len(test_texts) > 0 and len(test_labels) > 0:
    sample_text, sample_label = test_texts[0], test_labels[0]
    keyword_scores, abstract_scores, label_probs, predicted_label_idx = interpretable_prediction(
        model, sample_text, sample_label, concept_names, device=device
    )
    visualize_network(model, concept_names, keyword_scores, abstract_scores, label_probs, predicted_label_idx)


Interpretable Prediction for Sample:
Input Text: Plasma concentrations of epinephrine during CPR in the dog. STUDY OBJECTIVE: The purpose of this study was to evaluate whether the marked increase in the plasma concentrations of epinephrine during cardiopulmonary arrest and basic life support (BLS) could be due in part to decreased distribution and/or elimination. DESIGN AND INTERVENTIONS: Dogs were randomly assigned to undergo adrenalectomy or sham-operation. Some adrenalectomized animals received an epinephrine infusion. MEASUREMENTS AND MAIN RESULTS: In the seven sham-operated dogs, the plasma epinephrine concentrations increased markedly during BLS as expected. In the seven adrenalectomized dogs receiving a constant infusion of epinephrine, cardiopulmonary arrest and BLS induced a three to sixfold increase in plasma epinephrine concentrations, with an increase in the mean plasma epinephrine concentrations (calculated from the area under the curve) of 1.21 +/- 0.12 ng/mL (P less tha