## My GraphLLM Model for Multimodal Graph–Text Interaction

In [None]:
import torch

# Environment check: report library versions and whether a GPU is usable.
for caption, value in (
    ("PyTorch version:", torch.__version__),
    ("CUDA available:", torch.cuda.is_available()),
    ("CUDA version:", torch.version.cuda),
):
    print(caption, value)
import dgl
print("DGL version:", dgl.__version__)


In [None]:
import os
import json
import pandas as pd
import numpy as np
from transformers import BertTokenizer, BertModel
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import BertTokenizer, BertModel
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import precision_score, recall_score, f1_score, confusion_matrix, accuracy_score
import torch.nn as nn
from dgl.data.utils import save_graphs, load_graphs
import pickle
import os


# Processed artefacts (features, graph, texts, labels) are cached under
# SAVE_DIR/latest_run so repeated runs can skip the expensive steps.
import shutil

SAVE_DIR = 'processed_data'
FORCE_NEW_RUN = False  # If True, discard any cached artefacts and recompute everything

CURRENT_RUN_DIR = os.path.join(SAVE_DIR, "latest_run")
if FORCE_NEW_RUN or not os.path.exists(CURRENT_RUN_DIR):
    print("Starting a new run...")
    if FORCE_NEW_RUN and os.path.isdir(CURRENT_RUN_DIR):
        # Every later cache check only tests whether the file exists, so a
        # forced new run must delete the old files for them to be recomputed.
        shutil.rmtree(CURRENT_RUN_DIR)
    # makedirs also creates SAVE_DIR when it is missing.
    os.makedirs(CURRENT_RUN_DIR, exist_ok=True)
else:
    print("Reusing existing saved data...")


# Pretrained BERT tokenizer and encoder shared by the embedding steps below.
# The encoder is only used for inference here, so it is switched to eval
# mode (no dropout); .eval() returns the module itself.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
bert_model = BertModel.from_pretrained('bert-base-uncased').eval()


# Optional caps for quick experiments on a subset of the data (None = no cap).
NODE_LIMIT = None
EDGE_LIMIT = None

DATASET_PATH = '/home/reza/ML PJ/Model1/implementation/my_twiBot20_1'

# node.json holds a list of node dicts; below, ids starting with 'u' are
# treated as users and ids starting with 't' as tweets.
# JSON is UTF-8 by definition, so don't depend on the platform's default encoding.
with open(os.path.join(DATASET_PATH, 'node.json'), 'r', encoding='utf-8') as f:
    nodes = json.load(f)

if NODE_LIMIT:
    nodes = nodes[:NODE_LIMIT]
# Reported after the optional cap so the count matches what is processed.
print(f"Number of nodes processed: {len(nodes)}")

# Show some samples of nodes
print("First 3 entries in node.json:")
for node in nodes[:3]:
    print(json.dumps(node, indent=2))
    print('-' * 40)

# Show up to three user nodes and three tweet nodes.
user_count, tweet_count = 0, 0
for node in nodes:
    node_id = node['id']
    if node_id.startswith('u') and user_count < 3:
        print("User node:")
        print(json.dumps(node, indent=2))
        print('-' * 40)
        user_count += 1
    elif node_id.startswith('t') and tweet_count < 3:
        print("Tweet node:")
        print(json.dumps(node, indent=2))
        print('-' * 40)
        tweet_count += 1

    if user_count >= 3 and tweet_count >= 3:
        break

# edge.csv: one row per edge with source_id, relation and target_id columns.
edges_df = pd.read_csv(os.path.join(DATASET_PATH, 'edge.csv'))
# EDGE_LIMIT used to be declared but never applied.
if EDGE_LIMIT:
    edges_df = edges_df.head(EDGE_LIMIT)
print("\nFirst 5 entries in edge.csv:")
print(edges_df.head())

# ----------------------------------------------------------------------
# Labels and node filtering
# ----------------------------------------------------------------------
# Keep only users that appear in label.csv ('bot' -> 1, 'human' -> 0).
# Every tweet node is kept so that 'post' edges can be resolved.
label_mapping = {'bot': 1, 'human': 0}
labels_df = pd.read_csv(os.path.join(DATASET_PATH, 'label.csv'))
print("\nFirst 5 entries in label.csv:")
print(labels_df.head())

# Zipping the two columns avoids iterrows() overhead. An unknown label
# string still raises KeyError, as before.
labels_dict = {
    user_id: label_mapping[label]
    for user_id, label in zip(labels_df['id'], labels_df['label'])
}
labeled_user_ids = set(labels_dict)
print(f"Number of labeled users: {len(labeled_user_ids)}")

node_types = {'user': [], 'tweet': []}
edge_types = {'friend': [], 'follow': [], 'post': []}  # not used below; kept for compatibility

for node in nodes:
    node_id = node['id']
    if node_id.startswith('u') and node_id in labeled_user_ids:
        node_types['user'].append(node_id)
    elif node_id.startswith('t'):
        node_types['tweet'].append(node_id)

print(f"Filtered number of user nodes with labels: {len(node_types['user'])}")
print(f"Total tweet nodes: {len(node_types['tweet'])}")

# DGL numbers the nodes of each node type independently, starting at 0, so
# users and tweets get separate index ranges. Previously tweets were offset
# by the number of users, which inflated the tweet node count and
# misaligned tweet features. User ids ('u...') and tweet ids ('t...') never
# collide, so a single lookup dict can still hold both.
user_id_to_index = {node_id: idx for idx, node_id in enumerate(node_types['user'])}
tweet_id_to_index = {node_id: idx for idx, node_id in enumerate(node_types['tweet'])}
node_id_to_index = {**user_id_to_index, **tweet_id_to_index}

# Sources must be users that actually exist as nodes. With NODE_LIMIT set, a
# labeled user can be missing from `nodes`, and its source index would be NaN.
# .copy() gives an independent frame, so adding columns is not a write to a
# filtered view.
filtered_edges_df = edges_df[
    edges_df['source_id'].isin(user_id_to_index.keys())
    & edges_df['target_id'].isin(node_id_to_index.keys())
].copy()
print(f"Number of edges after filtering: {len(filtered_edges_df)}")

print("Mapping source and target IDs to indices...")
filtered_edges_df['src_index'] = filtered_edges_df['source_id'].map(node_id_to_index)
filtered_edges_df['dst_index'] = filtered_edges_df['target_id'].map(node_id_to_index)
print("Mapping complete.")

# Each relation's target must be of the node type its canonical edge type
# declares. Otherwise its index would come from the wrong ID space.
CANONICAL_ETYPES = {
    'friend': ('user', 'friend', 'user'),
    'follow': ('user', 'follow', 'user'),
    'post': ('user', 'post', 'tweet'),
}
type_to_index = {'user': user_id_to_index, 'tweet': tweet_id_to_index}

filtered_edges = {}
for rel, (_, _, dst_type) in CANONICAL_ETYPES.items():
    edge_type_df = filtered_edges_df[
        (filtered_edges_df['relation'] == rel)
        & filtered_edges_df['target_id'].isin(type_to_index[dst_type].keys())
    ]
    filtered_edges[rel] = list(zip(edge_type_df['src_index'], edge_type_df['dst_index']))
    print(f"Processed {rel} edges: {len(filtered_edges[rel])} added.")

# Show counts of each edge type
print("\nFiltered Edge type counts:")
for edge_type, edges in filtered_edges.items():
    print(f"{edge_type}: {len(edges)} edges")

print("Edge separation by type completed.")

has_edges = any(filtered_edges.values())
if has_edges:
    # The dtype is explicit because torch.tensor([]) defaults to float32,
    # which DGL rejects for a relation that ended up with no edges.
    data_dict = {
        canonical: (
            torch.tensor([src for src, _ in filtered_edges[rel]], dtype=torch.int64),
            torch.tensor([dst for _, dst in filtered_edges[rel]], dtype=torch.int64),
        )
        for rel, canonical in CANONICAL_ETYPES.items()
    }
    print("Filtered edge tensors created.")
else:
    print("No edges were found. Check the filtering criteria and data consistency.")


print("Strat counting connected tweet ids")

# Tweets reached by a 'post' edge from a labeled user. Only these get BERT
# embeddings; all other tweet feature rows stay zero.
filtered_post_edges = edges_df[(edges_df['relation'] == 'post') & (edges_df['source_id'].isin(labeled_user_ids))]
connected_tweet_ids = set(filtered_post_edges['target_id'])

#checkpoints
print(f"Number of connected tweet IDs: {len(connected_tweet_ids)}")
print("Sample connected tweet IDs:", list(connected_tweet_ids)[:2])
connected_tweet_ids_in_index = connected_tweet_ids.intersection(node_id_to_index.keys())
print(f"Number of connected tweet IDs actually in node_id_to_index: {len(connected_tweet_ids_in_index)}")


def _bert_mean_pool(text_list, batch_size=20, desc="Embedding texts"):
    """Embed texts with the frozen BERT encoder, in batches.

    Returns a float tensor of shape (len(text_list), hidden_size). Each row is
    the mean of BERT's last hidden states over the real (attention_mask == 1)
    tokens. Masking matters because batches are padded: without it, padding
    positions would be averaged into shorter texts' embeddings.
    """
    pooled = []
    for start in tqdm(range(0, len(text_list), batch_size), desc=desc):
        batch = text_list[start:start + batch_size]
        inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True, max_length=128)
        with torch.no_grad():
            hidden = bert_model(**inputs).last_hidden_state
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
        pooled.append((hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1))
    if not pooled:
        return torch.zeros((0, bert_model.config.hidden_size))
    return torch.cat(pooled, dim=0)


# Step 2: build graph (or load it from the cache)
graph_data_path = os.path.join(CURRENT_RUN_DIR, "graph_data.bin")
user_features_path = os.path.join(CURRENT_RUN_DIR, "user_features.pt")
tweet_features_path = os.path.join(CURRENT_RUN_DIR, "tweet_features.pt")

if has_edges:
    # The caches are checked BEFORE any work is done. Previously all BERT
    # embeddings were recomputed, only to be replaced by the saved copies.
    if os.path.exists(graph_data_path):
        print("Loading precomputed graph...")
        graphs, _ = load_graphs(graph_data_path)
        G = graphs[0]
    else:
        print("\nBuilding the DGL heterogeneous graph with filtered data...")
        # num_nodes_dict keeps users/tweets that have no edges. Without it DGL
        # infers each count from the largest index seen in data_dict, and the
        # per-user feature rows would no longer match the node count.
        G = dgl.heterograph(
            data_dict,
            num_nodes_dict={ntype: len(ids) for ntype, ids in node_types.items()},
        )
        print("Filtered DGL heterogeneous graph created.")

        for ntype in G.ntypes:
            num = G.num_nodes(ntype)
            G.nodes[ntype].data[dgl.NID] = torch.arange(num)
            G.nodes[ntype].data[dgl.NTYPE] = torch.full((num,), G.get_ntype_id(ntype), dtype=torch.int32)

        # Step 3: use BERT to embed text and generate node features.
        # User rows follow node_types['user'] order (same filter, same iteration).
        if os.path.exists(user_features_path):
            print("Loading precomputed user features...")
            user_feat = torch.load(user_features_path)
        else:
            print("\nGenerating BERT embeddings for user nodes...")
            # `or ''` also covers a description stored as JSON null.
            descriptions = [
                node.get('description') or ''
                for node in nodes
                if node['id'].startswith('u') and node['id'] in labeled_user_ids
            ]
            user_feat = _bert_mean_pool(descriptions, desc="Processing user batches")
            torch.save(user_feat, user_features_path)
            print("User features saved.")
        G.nodes['user'].data['feat'] = user_feat
        print("BERT embeddings assigned to user nodes.")

        if os.path.exists(tweet_features_path):
            print("Loading precomputed tweet features...")
            tweet_feat = torch.load(tweet_features_path)
        else:
            print("\nGenerating embeddings for connected tweet nodes present in node_id_to_index...")
            connected_tweets = [
                (tweet_id_to_index[node['id']], node.get('text') or '')
                for node in nodes
                if node['id'] in connected_tweet_ids_in_index and node['id'].startswith('t')
            ]
            tweet_feat = torch.zeros((G.num_nodes('tweet'), bert_model.config.hidden_size))
            if connected_tweets:
                rows = torch.tensor([row for row, _ in connected_tweets], dtype=torch.int64)
                tweet_feat[rows] = _bert_mean_pool(
                    [text for _, text in connected_tweets], desc="Processing tweet batches"
                )
            print(f'\nFinal tweet_features tensor shape: {tweet_feat.shape}')
            torch.save(tweet_feat, tweet_features_path)
            print("Tweet features saved.")
        G.nodes['tweet'].data['feat'] = tweet_feat
        print("BERT embeddings assigned to connected tweet nodes in the graph.")

        save_graphs(graph_data_path, [G])
        print("Graph saved.")

    print("\n=== Graph Data Inspection ===")
    print("Graph G node data:", G.ndata.keys())
    for ntype in G.ntypes:
        print(f"Number of nodes for type '{ntype}':", G.num_nodes(ntype))
        print(f"Node data for '{ntype}':", G.nodes[ntype].data.keys())

    print("\n=== Edge Types and Counts ===")
    for etype in G.etypes:
        print(f"Edge type '{etype}' has {G.num_edges(etype)} edges")
        print(f"Edge data for '{etype}':", G.edges[etype].data.keys())

else:
    print("Graph creation aborted due to no available edges.")


# One training text per labeled user (description + concatenated tweets),
# plus its label. Users are taken in `nodes` order with the same filter the
# graph uses for user nodes. Results are cached in CURRENT_RUN_DIR.
texts_path = os.path.join(CURRENT_RUN_DIR, "texts.pkl")
labels_path = os.path.join(CURRENT_RUN_DIR, "labels.pt")

if os.path.exists(texts_path) and os.path.exists(labels_path):
    print("Loading precomputed texts and labels...")
    # NOTE: pickle.load can execute code stored in the file; only load caches
    # that this notebook wrote itself.
    with open(texts_path, "rb") as f:
        texts = pickle.load(f)
    labels = torch.load(labels_path)
else:
    print("\nPreparing texts and labels for labeled users...")
    texts, label_list = [], []
    for node in nodes:
        node_id = node['id']
        if node_id.startswith('u') and node_id in labeled_user_ids:
            # `or` guards against either field being stored as JSON null.
            description = node.get('description') or ''
            tweets = ' '.join(node.get('tweet') or [])
            texts.append(description + ' ' + tweets)
            label_list.append(labels_dict[node_id])
    labels = torch.tensor(label_list)

    print("Computing texts and labels...")
    with open(texts_path, "wb") as f:
        pickle.dump(texts, f)
    torch.save(labels, labels_path)
    print("Texts and labels saved.")

print("Labels tensor created with shape:", labels.shape)
print("Sample labels:", labels[:5].tolist())



In [None]:
# After Building Graph, Learning phase starts

import torch
import torch.nn as nn
import dgl
from transformers import BertModel

# ================================
# Step 1: Graph Transformer Module
# ================================

class GraphTransformer(nn.Module):
    """Attention encoder for a heterogeneous graph that returns one pooled embedding.

    In every layer, for every relation (src_type -> dst_type), each node of the
    destination type attends over all nodes of the source type. The
    per-relation outputs for a node type are averaged and projected, and
    become that type's features for the next layer. The final embedding is
    the mean over all nodes of all types.

    NOTE(review): attention is dense over all (dst, src) node pairs and does
    not use which edges actually exist. Memory grows with
    num_dst_nodes * num_src_nodes.

    Args:
        hidden_dim: feature size of every node type; it must match
            G.nodes[ntype].data['feat'] and be divisible by num_heads.
        num_heads: attention heads per layer.
        num_layers: number of stacked attention layers.
        edge_types: relation names to attend over (resolved with
            G.to_canonical_etype).
        node_types: node types that get an output projection.
    """

    def __init__(self, hidden_dim, num_heads, num_layers, edge_types, node_types):
        super(GraphTransformer, self).__init__()
        self.edge_types = edge_types
        self.layers = nn.ModuleList([
            nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(hidden_dim)
        self.projection = nn.ModuleDict({
            ntype: nn.Linear(hidden_dim, hidden_dim) for ntype in node_types
        })

    def forward(self, G):
        """Return a (1, hidden_dim) embedding pooled over all nodes of G."""
        h = {ntype: G.nodes[ntype].data['feat'] for ntype in G.ntypes}

        for layer in self.layers:
            updated_h = {ntype: [] for ntype in G.ntypes}

            for etype in self.edge_types:
                src_type, _, dst_type = G.to_canonical_etype(etype)

                # Destination nodes are the queries, so the output has one row
                # per destination node. Previously source nodes were the
                # queries, which gave e.g. the 'tweet' type one row per user.
                # The layer is fed the current features `h`, so stacked layers
                # compose instead of all rereading the raw features.
                # Layout is (nodes, batch=1, dim), since nn.MultiheadAttention
                # defaults to batch_first=False.
                query = h[dst_type].unsqueeze(1)
                key_value = h[src_type].unsqueeze(1)
                attn_output, _ = layer(query, key_value, key_value)
                updated_h[dst_type].append(self.layer_norm(attn_output.squeeze(1)))

            # h is updated only after every relation in this layer has read it.
            for ntype, outputs in updated_h.items():
                if outputs:
                    aggregated_features = torch.stack(outputs, dim=0).mean(dim=0)
                    h[ntype] = self.projection[ntype](aggregated_features)

        # mean pooling over every node of every type
        return torch.cat([h[ntype] for ntype in h], dim=0).mean(dim=0, keepdim=True)


# ============================
# Step 2: Text Transformer Module
# ============================
class TextTransformer(nn.Module):
    """Frozen BERT encoder that returns one vector per token sequence.

    The output has shape (batch, hidden_size). It is the mean of BERT's last
    hidden states over the real (attention_mask == 1) tokens.
    """

    def __init__(self):
        super(TextTransformer, self).__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased")
        self.hidden_dim = self.bert.config.hidden_size

    def train(self, mode=True):
        """Set the training mode as usual, but keep the frozen BERT in eval mode.

        BERT is never updated (forward runs under no_grad), so in train mode
        its dropout would only add noise to the text features.
        """
        super().train(mode)
        self.bert.eval()
        return self

    def forward(self, input_ids, attention_mask):
        with torch.no_grad():  # Freeze LLM model for efficiency
            hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Average over real tokens only. An unmasked mean also averages the
        # [PAD] positions, which dominate short texts in a padded batch.
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

# ============================
# Step 3: Fusion Layer
# ============================
class FusionLayer(nn.Module):
    """Fuse one graph embedding with a batch of text embeddings.

    The (1, graph_dim) graph embedding is repeated for every row of
    text_emb, concatenated with it, and projected to fusion_dim. The result
    then goes through multi-head attention and layer normalisation.
    Output shape: (batch, fusion_dim).
    """

    def __init__(self, graph_dim, text_dim, fusion_dim):
        super(FusionLayer, self).__init__()
        self.fc = nn.Linear(graph_dim + text_dim, fusion_dim)
        self.cross_attention = nn.MultiheadAttention(embed_dim=fusion_dim, num_heads=8)
        self.layer_norm = nn.LayerNorm(fusion_dim)

    def forward(self, graph_emb, text_emb):
        graph_emb = graph_emb.repeat(text_emb.size(0), 1)
        combined = self.fc(torch.cat((graph_emb, text_emb), dim=-1))
        # nn.MultiheadAttention expects (seq, batch, embed) by default. The
        # batch goes on dim 1 with a length-1 sequence, so each sample attends
        # only to itself. The previous unsqueeze(1) made the mini-batch the
        # sequence, so a user's prediction depended on its batch-mates.
        tokens = combined.unsqueeze(0)
        fused_emb, _ = self.cross_attention(tokens, tokens, tokens)
        return self.layer_norm(fused_emb.squeeze(0))

# ============================
# Step 4: Classification head
# ============================
class ClassificationHead(nn.Module):
    """Single linear layer that maps fused features to per-class logits."""

    def __init__(self, hidden_dim, num_classes):
        super().__init__()
        # The attribute name `fc` is part of the saved state_dict keys.
        self.fc = nn.Linear(in_features=hidden_dim, out_features=num_classes)

    def forward(self, x):
        """Return unnormalised logits with last dimension num_classes."""
        logits = self.fc(x)
        return logits

# ============================
# Step 5: Build Multimodal Transformer
# ============================

class MultimodalTransformer(nn.Module):
    """Graph + text classifier.

    Pipeline: GraphTransformer (one embedding for the whole graph) ->
    TextTransformer (one embedding per input sequence) -> FusionLayer ->
    ClassificationHead, which returns logits of shape (batch, num_classes).
    The submodule attribute names are the state_dict prefixes of saved
    checkpoints.
    """

    def __init__(self, graph_dim, text_dim, fusion_dim, num_classes, edge_types, node_types):
        super().__init__()
        graph_settings = dict(
            hidden_dim=graph_dim,
            num_heads=8,
            num_layers=2,
            edge_types=edge_types,
            node_types=node_types,
        )
        self.graph_transformer = GraphTransformer(**graph_settings)
        self.text_transformer = TextTransformer()
        self.fusion_layer = FusionLayer(graph_dim, text_dim, fusion_dim)
        self.classifier = ClassificationHead(fusion_dim, num_classes)

    def forward(self, G, input_ids, attention_mask):
        # Arguments are evaluated left to right: graph first, then text.
        fused = self.fusion_layer(
            self.graph_transformer(G),
            self.text_transformer(input_ids, attention_mask),
        )
        return self.classifier(fused)

# ============================
# Data
# ============================
# Tokenise all user texts in one call: padded to the longest sequence,
# truncated to 128 tokens, and returned as PyTorch tensors.
encoded_inputs = tokenizer(
    texts,
    max_length=128,
    truncation=True,
    padding=True,
    add_special_tokens=True,
    return_tensors='pt',
)
input_ids, attention_mask = encoded_inputs['input_ids'], encoded_inputs['attention_mask']

# Node features of the heterograph. With several node types this is a
# mapping {node type: tensor}, which is consumed via .items() below.
graph_feats = G.ndata['feat']

def preprocess_graph_features(graph_feats):
    """Concatenate per-node-type feature tensors along dimension 0.

    Args:
        graph_feats: mapping {node type: tensor whose first dim is the node count}.

    Returns:
        (combined, offsets): the row-wise concatenation in the mapping's
        iteration order, and {node type: first row of that type in combined}.
    """
    offsets = {}
    next_row = 0
    for ntype, feats in graph_feats.items():
        offsets[ntype] = next_row
        next_row += feats.size(0)
    combined = torch.cat(list(graph_feats.values()), dim=0)
    return combined, offsets

# Flatten the per-type node features into one tensor, plus per-type row offsets.
graph_feats_tensor, node_type_offsets = preprocess_graph_features(graph_feats)

# Independent, graph-detached copy of the label tensor for the dataset.
labels_tensor = labels.detach().clone()


class MultimodalDataset(torch.utils.data.Dataset):
    """Per-user samples of (graph features, token ids, attention mask, label).

    input_ids, attention_mask and labels are indexed per sample;
    graph_feats_tensor is shared and returned unchanged with every item.

    NOTE(review): because each item carries the full graph_feats_tensor,
    the default DataLoader collate stacks one copy per sample in every
    batch. The training loops discard this element.
    """

    def __init__(self, graph_feats_tensor, input_ids, attention_mask, labels):
        self.graph_feats_tensor, self.input_ids = graph_feats_tensor, input_ids
        self.attention_mask, self.labels = attention_mask, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        per_sample = (self.input_ids[idx], self.attention_mask[idx], self.labels[idx])
        return (self.graph_feats_tensor, *per_sample)

# dataset creation
dataset = MultimodalDataset(graph_feats_tensor, input_ids, attention_mask, labels_tensor)
print(f"Total dataset size: {len(dataset)}")

# 80/20 train/test split. A fixed-seed generator gives the same split on
# every run. A checkpoint reloaded from disk (best_model.pth) is then
# evaluated on the same held-out users, not on users it was trained on.
SPLIT_SEED = 42
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Create DataLoaders
train_dataloader = DataLoader(train_dataset, sampler=RandomSampler(train_dataset), batch_size=8)
test_dataloader = DataLoader(test_dataset, sampler=SequentialSampler(test_dataset), batch_size=8)

print(f"Training dataset size: {len(train_dataset)}")
print(f"Testing dataset size: {len(test_dataset)}")
  
# ============================
# Training Setup
# ============================
# Define hyperparameters and initialize the model
graph_dim = 768
text_dim = 768
fusion_dim = 512
num_classes = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The GraphTransformer reads node features directly from G, so the graph must
# be on the same device as the model. Previously only the model was moved,
# which fails with a device mismatch on a CUDA machine.
G = G.to(device)

model = MultimodalTransformer(
    graph_dim=graph_dim,
    text_dim=text_dim,
    fusion_dim=fusion_dim,
    num_classes=num_classes,
    edge_types=G.etypes,
    node_types=G.ntypes
).to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# ============================
# Training Pipeline
# ============================

from sklearn.metrics import roc_auc_score, roc_curve, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Per-epoch histories and last-epoch predictions, used by the plotting cell
train_losses = []
eval_losses = []
train_accuracies = []
eval_accuracies = []
all_labels_roc = []
all_preds_roc = []

# Training Loop
epochs = 100
best_eval_loss = float('inf')  # saving the best model
# NOTE(review): the checkpoint is selected on the same split that is reported
# as the test result, so the reported metrics are optimistic. A separate
# validation split would avoid this.
for epoch in range(epochs):
    model.train()
    total_train_loss = 0
    correct_train = 0
    total_train = 0
    # batch_* names keep the loop from overwriting the module-level `labels`,
    # `input_ids` and `attention_mask`. Otherwise, re-running the data cells
    # after training would build the dataset from the last batch's labels.
    for _, batch_input_ids, batch_attention_mask, batch_labels in train_dataloader:
        batch_input_ids = batch_input_ids.to(device)
        batch_attention_mask = batch_attention_mask.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(G, batch_input_ids, batch_attention_mask)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        # Accuracy tracking
        preds = torch.argmax(logits, dim=1)
        correct_train += (preds == batch_labels).sum().item()
        total_train += batch_labels.size(0)

    avg_train_loss = total_train_loss / len(train_dataloader)
    train_accuracy = correct_train / total_train

    train_losses.append(avg_train_loss)
    train_accuracies.append(train_accuracy)

    print(f"Epoch {epoch + 1}: Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")

    model.eval()
    total_eval_loss = 0
    correct_eval = 0
    total_eval = 0
    all_labels_roc_epoch = []
    all_preds_roc_epoch = []
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for _, batch_input_ids, batch_attention_mask, batch_labels in test_dataloader:
            batch_input_ids = batch_input_ids.to(device)
            batch_attention_mask = batch_attention_mask.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(G, batch_input_ids, batch_attention_mask)
            loss = criterion(logits, batch_labels)
            total_eval_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            correct_eval += (preds == batch_labels).sum().item()
            total_eval += batch_labels.size(0)

            # class-1 probability for ROC / thresholded metrics
            all_labels_roc_epoch.extend(batch_labels.cpu().numpy())
            all_preds_roc_epoch.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(batch_labels.cpu().numpy())

    all_labels_roc = all_labels_roc_epoch
    all_preds_roc = all_preds_roc_epoch

    avg_eval_loss = total_eval_loss / len(test_dataloader)
    eval_accuracy = correct_eval / total_eval

    eval_losses.append(avg_eval_loss)
    eval_accuracies.append(eval_accuracy)

    print(f"Epoch {epoch + 1}: Evaluation Loss: {avg_eval_loss:.4f}, Evaluation Accuracy: {eval_accuracy:.4f}")

    # I want to save the best model for further training or pack it for inference
    if avg_eval_loss < best_eval_loss:
        best_eval_loss = avg_eval_loss
        torch.save(model.state_dict(), "best_model.pth")
        print("Best model saved.")

# These metrics use the LAST epoch's predictions, not the saved best checkpoint.
predicted_classes = [1 if prob >= 0.5 else 0 for prob in all_preds_roc]
precision = precision_score(all_labels_roc, predicted_classes, average='weighted')
recall = recall_score(all_labels_roc, predicted_classes, average='weighted')
f1 = f1_score(all_labels_roc, predicted_classes, average='weighted')
roc_auc = roc_auc_score(all_labels_roc, all_preds_roc)

print(f"Final Evaluation Results - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, AUC-ROC: {roc_auc:.4f}")

#----------------------------------------------

#---------------------------------------------

In [None]:

# ============================
# Visualization: Loss and Accuracy
# ============================

def _plot_epoch_curve(values, label, title, ylabel, color, marker, linestyle='-'):
    """Plot one per-epoch metric against epoch numbers starting at 1."""
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(values) + 1), values, marker=marker, linestyle=linestyle, label=label, color=color)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(alpha=0.5, linestyle='--')
    plt.show()


def _plot_confusion(cm, title):
    """Draw an annotated confusion-matrix heatmap."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
    plt.xlabel("Predicted Label", fontsize=12)
    plt.ylabel("True Label", fontsize=12)
    plt.title(title, fontsize=14)
    plt.show()


_plot_epoch_curve(train_losses, "Training Loss", "Training Loss Over Epochs", "Loss", 'blue', 'o')
_plot_epoch_curve(eval_losses, "Evaluation Loss", "Evaluation Loss Over Epochs", "Loss", 'red', 'x', '--')
_plot_epoch_curve(train_accuracies, "Training Accuracy", "Training Accuracy Over Epochs", "Accuracy", 'green', 'o')
_plot_epoch_curve(eval_accuracies, "Evaluation Accuracy", "Evaluation Accuracy Over Epochs", "Accuracy", 'purple', 'x', '--')


# ============================
# Confusion Matrix
# ============================

# Class-1 probability thresholded at 0.5
cm = confusion_matrix(all_labels_roc, predicted_classes)
_plot_confusion(cm, "Confusion Matrix (Predicted Classes)")

# argmax predictions from the last evaluation epoch. This previously reused
# predicted_classes, which made it an exact copy of the matrix above.
cm2 = confusion_matrix(all_labels, all_preds)
_plot_confusion(cm2, "Confusion Matrix (Binary Predictions)")


# ============================
# AUC-ROC Curve
# ============================

fpr, tpr, thresholds = roc_curve(all_labels_roc, all_preds_roc)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f"AUC-ROC = {roc_auc:.4f}", color='tab:blue')
plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
plt.xlabel("False Positive Rate", fontsize=12)
plt.ylabel("True Positive Rate", fontsize=12)
plt.title("ROC Curve", fontsize=14)
plt.legend(fontsize=12)
plt.grid(alpha=0.5, linestyle='--')
plt.show()


# ============================
# Metrics Visualization
# ============================

metrics = ["Accuracy", "Precision", "Recall", "F1-Score", "AUC-ROC"]
scores = [eval_accuracies[-1], precision, recall, f1, roc_auc]

plt.figure(figsize=(8, 5))
sns.barplot(x=metrics, y=scores, palette="viridis")
plt.ylim(0, 1)
plt.ylabel("Score", fontsize=12)
plt.title("Evaluation Metrics", fontsize=14)
plt.show()


In [None]:
# ---------------------------------
# Calling saved model for further train
#----------------------------------
# Reload the best checkpoint from the first run and keep training with a
# smaller learning rate. This run's best model goes to best_model_updated.pth.

import torch
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score

# Load the saved model. map_location lets a checkpoint saved on another
# device load on this one.
model.load_state_dict(torch.load("best_model.pth", map_location=device))
model.to(device)
# The graph must be on the model's device; this is a no-op if it already is.
G = G.to(device)
print("Loaded the best model.")

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
criterion = torch.nn.CrossEntropyLoss()

new_train_losses = []
new_eval_losses = []
new_train_accuracies = []
new_eval_accuracies = []

# Training for additional epochs
additional_epochs = 40
best_new_eval_loss = float('inf')

for epoch in range(additional_epochs):
    model.train()
    total_train_loss = 0
    correct_train = 0
    total_train = 0
    # batch_* names keep the module-level `labels`, `input_ids` and
    # `attention_mask` tensors intact.
    for _, batch_input_ids, batch_attention_mask, batch_labels in train_dataloader:
        batch_input_ids = batch_input_ids.to(device)
        batch_attention_mask = batch_attention_mask.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(G, batch_input_ids, batch_attention_mask)
        loss = criterion(logits, batch_labels)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()

        preds = torch.argmax(logits, dim=1)
        correct_train += (preds == batch_labels).sum().item()
        total_train += batch_labels.size(0)

    avg_train_loss = total_train_loss / len(train_dataloader)
    train_accuracy = correct_train / total_train

    new_train_losses.append(avg_train_loss)
    new_train_accuracies.append(train_accuracy)

    print(f"Epoch {epoch + 1}: Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")

    model.eval()
    total_eval_loss = 0
    correct_eval = 0
    total_eval = 0
    all_labels_roc_epoch = []
    all_preds_roc_epoch = []
    with torch.no_grad():
        for _, batch_input_ids, batch_attention_mask, batch_labels in test_dataloader:
            batch_input_ids = batch_input_ids.to(device)
            batch_attention_mask = batch_attention_mask.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(G, batch_input_ids, batch_attention_mask)
            loss = criterion(logits, batch_labels)
            total_eval_loss += loss.item()

            preds = torch.argmax(logits, dim=1)
            correct_eval += (preds == batch_labels).sum().item()
            total_eval += batch_labels.size(0)

            # class-1 probability for ROC / thresholded metrics
            all_labels_roc_epoch.extend(batch_labels.cpu().numpy())
            all_preds_roc_epoch.extend(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())

    avg_eval_loss = total_eval_loss / len(test_dataloader)
    eval_accuracy = correct_eval / total_eval

    new_eval_losses.append(avg_eval_loss)
    new_eval_accuracies.append(eval_accuracy)

    print(f"Epoch {epoch + 1}: Evaluation Loss: {avg_eval_loss:.4f}, Evaluation Accuracy: {eval_accuracy:.4f}")

    if avg_eval_loss < best_new_eval_loss:
        best_new_eval_loss = avg_eval_loss
        torch.save(model.state_dict(), "best_model_updated.pth")
        print("New best model saved.")

# Metrics from the last additional epoch's predictions
predicted_classes = [1 if prob >= 0.5 else 0 for prob in all_preds_roc_epoch]
precision = precision_score(all_labels_roc_epoch, predicted_classes, average='weighted')
recall = recall_score(all_labels_roc_epoch, predicted_classes, average='weighted')
f1 = f1_score(all_labels_roc_epoch, predicted_classes, average='weighted')
roc_auc = roc_auc_score(all_labels_roc_epoch, all_preds_roc_epoch)

print(f"Final Evaluation Results After Additional Training - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1:.4f}, AUC-ROC: {roc_auc:.4f}")


In [None]:
# ============================
# Second Visualization
# ============================
# Plots for the additional-training run: the new_* histories and the last
# additional epoch's test predictions. Previously this cell re-plotted the
# first run's histories and ROC data.

def _plot_epoch_curve(values, label, title, ylabel, color, marker, linestyle='-'):
    """Plot one per-epoch metric against epoch numbers starting at 1."""
    plt.figure(figsize=(8, 5))
    plt.plot(range(1, len(values) + 1), values, marker=marker, linestyle=linestyle, label=label, color=color)
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel(ylabel, fontsize=12)
    plt.title(title, fontsize=14)
    plt.legend(fontsize=12)
    plt.grid(alpha=0.5, linestyle='--')
    plt.show()


def _plot_confusion(cm, title):
    """Draw an annotated confusion-matrix heatmap."""
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar=False)
    plt.xlabel("Predicted Label", fontsize=12)
    plt.ylabel("True Label", fontsize=12)
    plt.title(title, fontsize=14)
    plt.show()


_plot_epoch_curve(new_train_losses, "Training Loss", "Training Loss Over Epochs", "Loss", 'blue', 'o')
_plot_epoch_curve(new_eval_losses, "Evaluation Loss", "Evaluation Loss Over Epochs", "Loss", 'red', 'x', '--')
_plot_epoch_curve(new_train_accuracies, "Training Accuracy", "Training Accuracy Over Epochs", "Accuracy", 'green', 'o')
_plot_epoch_curve(new_eval_accuracies, "Evaluation Accuracy", "Evaluation Accuracy Over Epochs", "Accuracy", 'purple', 'x', '--')


# ============================
# Confusion Matrix
# ============================

# Class-1 probability thresholded at 0.5
cm = confusion_matrix(all_labels_roc_epoch, predicted_classes)
_plot_confusion(cm, "Confusion Matrix (Predicted Classes)")

# The additional-training cell keeps only class-1 probabilities. For two
# classes, argmax of the softmax matches "p1 > 0.5" (ties go to class 0),
# up to floating-point rounding.
binary_preds = [1 if prob > 0.5 else 0 for prob in all_preds_roc_epoch]
cm2 = confusion_matrix(all_labels_roc_epoch, binary_preds)
_plot_confusion(cm2, "Confusion Matrix (Binary Predictions)")


# ============================
# AUC-ROC Curve
# ============================

fpr, tpr, thresholds = roc_curve(all_labels_roc_epoch, all_preds_roc_epoch)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, label=f"AUC-ROC = {roc_auc:.4f}", color='tab:blue')
plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
plt.xlabel("False Positive Rate", fontsize=12)
plt.ylabel("True Positive Rate", fontsize=12)
plt.title("ROC Curve", fontsize=14)
plt.legend(fontsize=12)
plt.grid(alpha=0.5, linestyle='--')
plt.show()


# ============================
# Metrics Visualization
# ============================

metrics = ["Accuracy", "Precision", "Recall", "F1-Score", "AUC-ROC"]
scores = [new_eval_accuracies[-1], precision, recall, f1, roc_auc]

plt.figure(figsize=(8, 5))
sns.barplot(x=metrics, y=scores, palette="viridis")
plt.ylim(0, 1)
plt.ylabel("Score", fontsize=12)
plt.title("Evaluation Metrics", fontsize=14)
plt.show()


In [None]:
# Finished