## Setup

In [None]:
!pip install transformers datasets scikit-learn sentence-transformers tqdm lightgbm torch accelerate -q

In [None]:
import os
import pandas as pd
from google.colab import drive
from datasets import load_dataset
import json
from sentence_transformers import SentenceTransformer, util, InputExample, losses, models
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from lightgbm import LGBMClassifier
from sklearn.model_selection import RandomizedSearchCV
from scipy.stats import uniform, randint, loguniform
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tqdm.notebook import tqdm
import numpy as np
import gc
import torch
from torch.utils.data import DataLoader
import random
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import (
    confusion_matrix,
    roc_curve,
    auc,
    precision_recall_curve,
    average_precision_score
)
import joblib

In [None]:
# Mount Google Drive and derive the project folders used throughout the notebook.
drive.mount('/content/drive', force_remount=True)
base_drive_path = '/content/drive/MyDrive/Bert&Ernie_shared_folder/'
data_path, models_path, results_path = (
    os.path.join(base_drive_path, subdir) for subdir in ('data', 'models', 'results')
)
# Directory the triplet-fine-tuned SBERT model is written to and reloaded from.
sbert_finetuned_path = os.path.join(models_path, "sbert_triplet_finetuned")

In [None]:
import os

# Switch off Weights & Biases via its environment variable for this session.
os.environ.update({"WANDB_DISABLED": "true"})

In [None]:
!huggingface-cli login --token #INSERT TOKEN HERE

## Dataset

In [None]:
# Load the 'train' split of the homework dataset.
dataset = load_dataset("sapienzanlp/nlp2025_hw1_cultural_dataset")['train']

# CSV files expected inside the Drive data folder.
wiki_text_train_path = os.path.join(data_path, "wikipedia_texts_train.csv")
wiki_text_val_path = os.path.join(data_path, "wikipedia_texts_val.csv")
test_unlabeled_path = os.path.join(data_path, "test_unlabeled.csv")
test_texts_path = os.path.join(data_path, "wikipedia_texts_test.csv")


wiki_text_df = pd.read_csv(wiki_text_train_path)

# Lookup table: value of the 'id' column -> value of the 'english_text' column.
id_to_text = dict(zip(wiki_text_df['id'], wiki_text_df['english_text']))

# pandas reads empty CSV cells as NaN (a float), never as None, so an
# `is None` test can never fire; pd.isna() catches both NaN and None.
none_count = sum(1 for text in id_to_text.values() if pd.isna(text))
if none_count > 0:
    print(f"Warning: Found {none_count} entries with None text in train mapping.")

def extract_entity_id(url):
    """Return the last "/"-separated segment of ``url`` after trimming whitespace."""
    trimmed = url.strip()
    return trimmed.rsplit("/", 1)[-1]

def add_text(example):
    """Return ``{"text": ...}`` for one dataset example.

    The text is looked up in ``id_to_text`` under the id extracted from
    ``example["item"]``; a missing or non-string entry becomes ``""``.
    """
    candidate = id_to_text.get(extract_entity_id(example["item"]), "")
    return {"text": candidate if isinstance(candidate, str) else ""}

print("Processing training data...")
# Add the 'text' column using 4 worker processes, then convert to a DataFrame.
dataset = dataset.map(add_text, num_proc=4)
train_df = dataset.to_pandas()
print(f"Training data processed. Shape: {train_df.shape}")
print(f"Number of empty texts in training data: {train_df['text'].eq('').sum()}")

In [None]:
print("\nProcessing validation data...")
# Load the 'validation' split and the CSV holding its Wikipedia texts.
val_dataset = load_dataset("sapienzanlp/nlp2025_hw1_cultural_dataset")['validation']
wiki_text_df_val = pd.read_csv(wiki_text_val_path)
# Lookup table: value of the 'id' column -> value of the 'english_text' column.
id_to_text_val = dict(zip(wiki_text_df_val['id'], wiki_text_df_val['english_text']))
# pandas reads empty CSV cells as NaN (a float), never as None, so an
# `is None` test can never fire; pd.isna() catches both NaN and None.
none_count_val = sum(1 for text in id_to_text_val.values() if pd.isna(text))
if none_count_val > 0:
    print(f"Warning: Found {none_count_val} entries with None text in validation mapping.")

def add_text_val(example):
    """Return ``{"text": ...}`` for one validation example.

    The text is looked up in ``id_to_text_val`` under the id extracted from
    ``example["item"]``; a missing or non-string entry becomes ``""``.
    """
    candidate = id_to_text_val.get(extract_entity_id(example["item"]), "")
    return {"text": candidate if isinstance(candidate, str) else ""}


# Add the 'text' column using 4 worker processes, then convert to a DataFrame.
val_dataset = val_dataset.map(add_text_val, num_proc=4)
val_df = val_dataset.to_pandas()
print(f"Validation data processed. Shape: {val_df.shape}")
print(f"Number of empty texts in validation data: {val_df['text'].eq('').sum()}")

### Embedding Generation

#### Embedding Model Fine-tuning

In [None]:
print("\nGenerating initial embeddings...")
model = SentenceTransformer('all-mpnet-base-v2')

# Encode the train split first, then the validation split, with identical settings.
initial_embeddings_train, initial_embeddings_val = (
    model.encode(frame['text'].tolist(), show_progress_bar=True, convert_to_numpy=True)
    for frame in (train_df, val_df)
)
print("Initial embeddings generated.")
# Drop the base model reference and run the garbage collector.
del model
gc.collect()

In [None]:
def _sample_index_pairs(n, k):
    """Return up to ``k`` distinct index pairs ``(i, j)`` with ``i < j < n``, in random order."""
    total = n * (n - 1) // 2
    k = min(k, total)
    if k <= 0:
        return []
    if total <= 4 * k:
        # Few candidates: enumerate them all and draw without replacement.
        every_pair = [(i, j) for i in range(n) for j in range(i + 1, n)]
        return random.sample(every_pair, k)
    # Many candidates: draw pairs directly instead of materialising all
    # n*(n-1)/2 of them. At most a quarter of the pairs is ever requested
    # here, so duplicate draws are rare and the loop finishes quickly.
    seen = set()
    picked = []
    while len(picked) < k:
        i, j = random.sample(range(n), 2)
        pair = (i, j) if i < j else (j, i)
        if pair not in seen:
            seen.add(pair)
            picked.append(pair)
    return picked


def create_triplets(df, label_col='label', text_col='text', n_samples_per_label=200, max_triplets=30000):
    """Build (anchor, positive, negative) triplets for triplet-loss fine-tuning.

    Anchor and positive are two different rows sharing a label; the negative
    is a random text from a randomly chosen other label. Sampling uses the
    global ``random`` state.

    Args:
        df: DataFrame with one text and one label per row.
        label_col: Name of the label column.
        text_col: Name of the text column.
        n_samples_per_label: Maximum number of triplets generated per label.
        max_triplets: Cap on the total number of triplets returned.

    Returns:
        A list of ``InputExample`` whose ``texts`` are
        ``[anchor, positive, negative]``. The list is empty when fewer than
        two labels have usable text.
    """
    print(f"Creating triplets...")
    triplets = []

    # Drop missing/blank texts up front so they neither consume the per-label
    # quota nor end up inside a triplet (NaN is truthy and would otherwise
    # slip through a plain truthiness check).
    label_to_texts = {}
    for label, texts in df.groupby(label_col)[text_col]:
        usable = [t for t in texts if isinstance(t, str) and t.strip()]
        if usable:
            label_to_texts[label] = usable
    labels = list(label_to_texts)

    if len(labels) < 2:
        print("Warning: Need at least 2 labels to create negative samples for triplets.")
        return []

    for label in tqdm(labels, desc="Generating triplets"):
        remaining = max_triplets - len(triplets)
        if remaining <= 0:
            break

        positives = label_to_texts[label]
        negative_labels = [lbl for lbl in labels if lbl != label]
        quota = min(n_samples_per_label, remaining)

        for i, j in _sample_index_pairs(len(positives), quota):
            negative = random.choice(label_to_texts[random.choice(negative_labels)])
            triplets.append(InputExample(texts=[positives[i], positives[j], negative]))

    print(f"Created {len(triplets)} triplets (capped at {max_triplets}).")
    return triplets

In [None]:
# Build the fine-tuning triplets from the training frame: at most 500 per label, 50000 overall.
triplet_examples = create_triplets(train_df, n_samples_per_label=500, max_triplets=50000)


In [None]:
# Base checkpoint to fine-tune and the Drive directory the fine-tuned model is saved to.
model_name = "all-mpnet-base-v2"
sbert_finetuned_path = os.path.join(models_path, "sbert_triplet_finetuned")

print(f"\nLoading base model {model_name} for fine-tuning...")
triplet_model = SentenceTransformer(model_name)

In [None]:
# Shuffle the triplets into batches of 32 and pair them with a triplet loss on the model.
train_dataloader = DataLoader(triplet_examples, shuffle=True, batch_size=32)
train_loss = losses.TripletLoss(model=triplet_model)

In [None]:
print("Starting fine-tuning with TripletLoss...")
# One epoch over the triplets with 100 warm-up steps; the result is saved to output_path.
triplet_model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1, warmup_steps=100,
    output_path=sbert_finetuned_path, show_progress_bar=True,
)
print(f"Fine-tuning finished. Model saved to: {sbert_finetuned_path}")
# Drop the references to everything that was only needed for fine-tuning.
del (
    triplet_model,
    initial_embeddings_train,
    initial_embeddings_val,
    triplet_examples,
    train_dataloader,
)

#### Generate embeddings

In [None]:
print("\nGenerating embeddings using the fine-tuned model...")
finetuned_model = SentenceTransformer(sbert_finetuned_path)

# Encode the train split first, then the validation split, with identical settings.
X_train, X_val = (
    finetuned_model.encode(frame['text'].tolist(), show_progress_bar=True, convert_to_numpy=True)
    for frame in (train_df, val_df)
)
print("Fine-tuned embeddings generated for train and validation sets.")

# Targets for the classifiers trained below.
y_train, y_val = train_df['label'].values, val_df['label'].values


del finetuned_model

gc.collect()

In [None]:
# Sorted list of the distinct training labels.
class_labels = sorted(np.unique(y_train))
print(f"Class labels found for plotting: {class_labels}")

## Model

### Training

In [None]:
print("Training kNN...")
# Baseline: 7-nearest-neighbour classifier fitted on the training embeddings.
knn = KNeighborsClassifier(n_neighbors=7)
knn.fit(X_train, y_train)
knn_preds = knn.predict(X_val)
print(
    "\nkNN Classification Report:",
    classification_report(y_val, knn_preds, zero_division=0),
    sep="\n",
)

In [None]:
print("\nTraining Logistic Regression...")
# Baseline: logistic regression with C=0.5 and up to 1500 solver iterations.
logreg = LogisticRegression(C=0.5, max_iter=1500, random_state=42)
logreg.fit(X_train, y_train)
logreg_preds = logreg.predict(X_val)
print(
    "\nLogistic Regression Classification Report:",
    classification_report(y_val, logreg_preds, zero_division=0),
    sep="\n",
)
print("-" * 40)

In [None]:
print("\nTraining SVM (RBF Kernel)...")
# Baseline: RBF-kernel SVM with C=1.0 and gamma='scale'.
svm_clf = SVC(C=1.0, kernel='rbf', gamma='scale', cache_size=500, random_state=42)
svm_clf.fit(X_train, y_train)
svm_preds = svm_clf.predict(X_val)
print(
    "\nSVM (RBF Kernel) Classification Report:",
    classification_report(y_val, svm_preds, zero_division=0),
    sep="\n",
)

### Hyperparameter Tuning (Use this to re-train)

In [None]:
# Number of parameter settings sampled per randomized search (the SVM search uses half of it).
N_ITER_SEARCH = 20
# Number of cross-validation folds used by each randomized search.
CV_FOLDS = 3

#### KNN

In [None]:
print("\n--- Tuning kNN ---")
# Search space: neighbour count drawn from randint(3, 15), two weighting schemes, three metrics.
param_dist_knn = dict(
    n_neighbors=randint(3, 15),
    weights=['uniform', 'distance'],
    metric=['euclidean', 'manhattan', 'cosine'],
)
knn_clf = KNeighborsClassifier()
random_search_knn = RandomizedSearchCV(
    estimator=knn_clf,
    param_distributions=param_dist_knn,
    n_iter=N_ITER_SEARCH,
    scoring='f1_macro',
    cv=CV_FOLDS,
    n_jobs=-1,
    random_state=42,
    verbose=1,
)
random_search_knn.fit(X_train, y_train)

print("\nBest kNN Parameters:", random_search_knn.best_params_)
print("Best kNN Cross-validation F1-Macro Score:", random_search_knn.best_score_)
# Score the best estimator found by the search on the validation set.
knn_tuned_preds = random_search_knn.best_estimator_.predict(X_val)
print(
    "\nkNN Classification Report (Tuned):",
    classification_report(y_val, knn_tuned_preds, zero_division=0),
    sep="\n",
)
del knn_clf
gc.collect()

#### Logistic Regression

In [None]:
# Search space: log-uniform C, L1 or L2 penalty, 'saga' solver, three iteration budgets.
param_dist_logreg = dict(
    C=loguniform(1e-3, 1e2),
    penalty=['l1', 'l2'],
    solver=['saga'],
    max_iter=[2000, 3000, 4000],
)
logreg_clf = LogisticRegression(random_state=42, n_jobs=-1)
random_search_logreg = RandomizedSearchCV(
    estimator=logreg_clf,
    param_distributions=param_dist_logreg,
    n_iter=N_ITER_SEARCH,
    scoring='f1_macro',
    cv=CV_FOLDS,
    n_jobs=-1,
    random_state=42,
    verbose=1,
)
random_search_logreg.fit(X_train, y_train)

print("\nBest Logistic Regression Parameters:", random_search_logreg.best_params_)
print("Best Logistic Regression Cross-validation F1-Macro Score:", random_search_logreg.best_score_)
# Score the best estimator found by the search on the validation set.
logreg_tuned_preds = random_search_logreg.best_estimator_.predict(X_val)
print(
    "\nLogistic Regression Classification Report (Tuned):",
    classification_report(y_val, logreg_tuned_preds, zero_division=0),
    sep="\n",
)
del logreg_clf
gc.collect()

#### SVM

In [None]:
# Search space: log-uniform C, two kernels, and gamma taken from the two
# built-in heuristics plus five log-uniform draws. The draws are seeded so the
# candidate list (and therefore the whole search) is the same on every run.
param_dist_svm = {
    'C': loguniform(1e-1, 1e2),
    'kernel': ['rbf', 'linear'],
    'gamma': ['scale', 'auto'] + loguniform(1e-4, 1e-1).rvs(5, random_state=42).tolist(),
}
# probability=True is set so the fitted estimator can produce class probabilities.
svm_clf_tune = SVC(random_state=42, cache_size=700, probability=True)
# Only half as many candidates are sampled as in the other searches.
random_search_svm = RandomizedSearchCV(svm_clf_tune, param_distributions=param_dist_svm,
                                       n_iter=N_ITER_SEARCH // 2, cv=CV_FOLDS, scoring='f1_macro',
                                       n_jobs=-1, random_state=42, verbose=1)
random_search_svm.fit(X_train, y_train)

print("\nBest SVM Parameters:", random_search_svm.best_params_)
print("Best SVM Cross-validation F1-Macro Score:", random_search_svm.best_score_)
# Score the best estimator found by the search on the validation set.
svm_tuned_preds = random_search_svm.best_estimator_.predict(X_val)
print("\nSVM Classification Report (Tuned):")
print(classification_report(y_val, svm_tuned_preds, zero_division=0))
del svm_clf_tune
gc.collect()

#### Model Save

In [None]:
print("\n--- Saving Tuned Models ---")
# Tuned classifiers are written directly into the models folder.
model_save_dir = models_path

def save_model(estimator, model_name, search_object_name):
    """Dump the best estimator of a finished search to ``model_save_dir``.

    Args:
        estimator: Not used by this function; kept for the existing call sites.
        model_name: Short name used in the file name ``best_<model_name>_model.joblib``.
        search_object_name: Name of the module-level search object whose
            ``best_estimator_`` is saved.

    Nothing is raised: a missing search object, a missing ``best_estimator_``
    and any other failure are reported with ``print``.
    """
    namespace = globals()
    if search_object_name not in namespace:
        print(f"'{search_object_name}' not found. Skipping saving {model_name} model.")
        return

    try:
        winner = namespace[search_object_name].best_estimator_
        target = os.path.join(model_save_dir, f"best_{model_name}_model.joblib")
        joblib.dump(winner, target)
        print(f"Saved best {model_name} model to {target}")
    except AttributeError:
        print(f"'{search_object_name}' does not have 'best_estimator_'. Was tuning run?")
    except Exception as err:
        print(f"Error saving {model_name} model: {err}")

# Persist the best estimator of each search; save_model looks the search object
# up by name (third argument) and does not use the class passed as first argument.
save_model(KNeighborsClassifier, "knn", "random_search_knn")
save_model(LogisticRegression, "logreg", "random_search_logreg")
save_model(SVC, "svm", "random_search_svm")


print("\nModel saving process finished.")
gc.collect()

### Post-training Evaluation

The following code evaluates the tuned SVM model, which is our model of choice given its performance. To evaluate one of the other models instead, modify the variable `loaded_svm_model` so that it points to the best estimator of the corresponding search (e.g. `random_search_knn.best_estimator_`).

In [None]:
print("--- Starting Evaluation on Unlabeled Test Set ---")


# Where the predictions are written and where the fine-tuned encoder lives.
output_filename = "test_unlabelled_with_svm_predictions.csv"
output_file_path = os.path.join(results_path, output_filename)

model_save_dir = models_path

finetuned_sbert_path = os.path.join(models_path, "sbert_triplet_finetuned")

# Both input CSVs are required: stop immediately if either one is missing.
print(f"Loading unlabeled test data from: {test_unlabeled_path}")
if not os.path.exists(test_unlabeled_path):
    print(f"Error: File not found - {test_unlabeled_path}")
    raise FileNotFoundError(f"Required file not found: {test_unlabeled_path}")
test_df = pd.read_csv(test_unlabeled_path)
print(f"Loaded {len(test_df)} rows from test data.")

print(f"Loading test texts from: {test_texts_path}")
if not os.path.exists(test_texts_path):
    print(f"Error: File not found - {test_texts_path}")
    raise FileNotFoundError(f"Required file not found: {test_texts_path}")
wiki_text_df_test = pd.read_csv(test_texts_path)
print(f"Loaded {len(wiki_text_df_test)} rows from test texts.")

print("Creating text mapping for test data...")
# Lookup table: value of the 'id' column -> value of the 'english_text' column.
id_to_text_test = dict(zip(wiki_text_df_test['id'], wiki_text_df_test['english_text']))

# pandas reads empty CSV cells as NaN (a float), never as None, so an
# `is None` test can never fire; pd.isna() catches both NaN and None.
none_count_test = sum(1 for text_val in id_to_text_test.values() if pd.isna(text_val))
if none_count_test > 0:
    print(f"Warning: Found {none_count_test} entries with None text in the test mapping.")

def extract_entity_id(url):
    """Return the text after the last "/" of ``url`` (whitespace-trimmed).

    ``None`` is returned when ``url`` is not a string or contains no "/".
    """
    if not isinstance(url, str) or "/" not in url:
        return None
    return url.strip().rsplit("/", 1)[-1]

def add_text_test(row):
    """Return the text mapped to ``row["item"]`` in ``id_to_text_test``.

    An empty string is returned when no id can be extracted, the id is not
    in the mapping, or the mapped value is not a string.
    """
    entity_id = extract_entity_id(row["item"])
    if not entity_id:
        return ""
    mapped = id_to_text_test.get(entity_id, "")
    return mapped if isinstance(mapped, str) else ""

# Attach the mapped text to every test row; tqdm.pandas() enables progress_apply.
print("Mapping texts to the test dataframe...")
tqdm.pandas(desc="Adding text column")
test_df['text'] = test_df.progress_apply(add_text_test, axis=1)

# Report how many rows ended up with an empty string as text.
num_empty_texts_test = (test_df['text'] == '').sum()
if num_empty_texts_test > 0:
     print(f"Warning: {num_empty_texts_test} rows in the test set have missing or empty text after mapping.")
print("Text mapping complete.")

# Load the fine-tuned encoder from Drive; a missing directory or a load failure aborts the cell.
print(f"Loading fine-tuned Sentence Transformer from: {finetuned_sbert_path}")
if not os.path.exists(finetuned_sbert_path):
     print(f"Error: Fine-tuned Sentence Transformer directory not found at {finetuned_sbert_path}")
     raise FileNotFoundError(f"Directory not found: {finetuned_sbert_path}. Ensure the fine-tuning step completed successfully and saved to the correct Drive path.")
try:
    eval_sbert_model = SentenceTransformer(finetuned_sbert_path)
    print("Sentence Transformer loaded successfully.")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    raise

print("Generating embeddings for the test data...")
test_texts = test_df['text'].tolist()
# Defensive normalisation: non-null values become str, null values become "".
test_texts_cleaned = [str(text) if pd.notna(text) else "" for text in test_texts]

X_test_unlabelled = eval_sbert_model.encode(
    test_texts_cleaned,
    show_progress_bar=True,
    convert_to_numpy=True,
    batch_size=128
)
print(f"Generated {X_test_unlabelled.shape[0]} embeddings with dimension {X_test_unlabelled.shape[1]}.")

# The encoder and the text lists are no longer needed once the embeddings exist.
del eval_sbert_model
del test_texts
del test_texts_cleaned
gc.collect()
torch.cuda.empty_cache()

# Use the in-memory search result, so the SVM tuning cell must have run in this session.
loaded_svm_model = random_search_svm.best_estimator_

print("Making predictions on the test data...")
test_predictions = loaded_svm_model.predict(X_test_unlabelled)
print(f"Generated {len(test_predictions)} predictions.")

print("Adding predictions to the test dataframe...")
output_df = test_df.copy()
output_df['label'] = test_predictions


# NOTE(review): a failed save is only printed, not re-raised, and output_df is
# deleted right below, so the predictions would then have to be recomputed.
print(f"Saving results with predictions to: {output_file_path}")
try:
    output_df.to_csv(output_file_path, index=False)
    print("Output file saved successfully.")
except Exception as e:
    print(f"Error saving output file: {e}")

# Remove the evaluation-only objects from the notebook namespace.
del test_df
del wiki_text_df_test
del id_to_text_test
del X_test_unlabelled
del loaded_svm_model
del test_predictions
del output_df
gc.collect()

print("--- Evaluation script finished ---")

## Standalone Evaluation

In [None]:
print("--- Starting Evaluation on Unlabeled Test Set ---")


# Where the predictions are written.
output_filename = "test_unlabelled_with_svm_predictions.csv"
output_file_path = os.path.join(results_path, output_filename)


model_save_dir = models_path

# Saved artefacts this cell depends on: the tuned SVM and the fine-tuned encoder.
svm_model_filename = "best_svm_model.joblib"
svm_model_path = os.path.join(model_save_dir, svm_model_filename)

finetuned_sbert_path = os.path.join(models_path, "sbert_triplet_finetuned")

# Both input CSVs are required: stop immediately if either one is missing.
print(f"Loading unlabeled test data from: {test_unlabeled_path}")
if not os.path.exists(test_unlabeled_path):
    print(f"Error: File not found - {test_unlabeled_path}")
    raise FileNotFoundError(f"Required file not found: {test_unlabeled_path}")
test_df = pd.read_csv(test_unlabeled_path)
print(f"Loaded {len(test_df)} rows from test data.")

print(f"Loading test texts from: {test_texts_path}")
if not os.path.exists(test_texts_path):
    print(f"Error: File not found - {test_texts_path}")
    raise FileNotFoundError(f"Required file not found: {test_texts_path}")
wiki_text_df_test = pd.read_csv(test_texts_path)
print(f"Loaded {len(wiki_text_df_test)} rows from test texts.")

print("Creating text mapping for test data...")
# Lookup table: value of the 'id' column -> value of the 'english_text' column.
id_to_text_test = dict(zip(wiki_text_df_test['id'], wiki_text_df_test['english_text']))

# pandas reads empty CSV cells as NaN (a float), never as None, so an
# `is None` test can never fire; pd.isna() catches both NaN and None.
none_count_test = sum(1 for text_val in id_to_text_test.values() if pd.isna(text_val))
if none_count_test > 0:
    print(f"Warning: Found {none_count_test} entries with None text in the test mapping.")

def extract_entity_id(url):
    """Return the text after the last "/" of ``url`` (whitespace-trimmed).

    ``None`` is returned when ``url`` is not a string or contains no "/".
    """
    if not isinstance(url, str) or "/" not in url:
        return None
    return url.strip().rsplit("/", 1)[-1]

def add_text_test(row):
    """Return the text mapped to ``row["item"]`` in ``id_to_text_test``.

    An empty string is returned when no id can be extracted, the id is not
    in the mapping, or the mapped value is not a string.
    """
    entity_id = extract_entity_id(row["item"])
    if not entity_id:
        return ""
    mapped = id_to_text_test.get(entity_id, "")
    return mapped if isinstance(mapped, str) else ""

# Attach the mapped text to every test row; tqdm.pandas() enables progress_apply.
print("Mapping texts to the test dataframe...")
tqdm.pandas(desc="Adding text column")
test_df['text'] = test_df.progress_apply(add_text_test, axis=1)

# Report how many rows ended up with an empty string as text.
num_empty_texts_test = (test_df['text'] == '').sum()
if num_empty_texts_test > 0:
     print(f"Warning: {num_empty_texts_test} rows in the test set have missing or empty text after mapping.")
print("Text mapping complete.")

# Load the fine-tuned encoder from Drive; a missing directory or a load failure aborts the cell.
print(f"Loading fine-tuned Sentence Transformer from: {finetuned_sbert_path}")
if not os.path.exists(finetuned_sbert_path):
     print(f"Error: Fine-tuned Sentence Transformer directory not found at {finetuned_sbert_path}")
     raise FileNotFoundError(f"Directory not found: {finetuned_sbert_path}. Ensure the fine-tuning step completed successfully and saved to the correct Drive path.")
try:
    eval_sbert_model = SentenceTransformer(finetuned_sbert_path)
    print("Sentence Transformer loaded successfully.")
except Exception as e:
    print(f"Error loading Sentence Transformer model: {e}")
    raise

print("Generating embeddings for the test data...")
test_texts = test_df['text'].tolist()
# Defensive normalisation: non-null values become str, null values become "".
test_texts_cleaned = [str(text) if pd.notna(text) else "" for text in test_texts]

X_test_unlabelled = eval_sbert_model.encode(
    test_texts_cleaned,
    show_progress_bar=True,
    convert_to_numpy=True,
    batch_size=128
)
print(f"Generated {X_test_unlabelled.shape[0]} embeddings with dimension {X_test_unlabelled.shape[1]}.")

# The encoder and the text lists are no longer needed once the embeddings exist.
del eval_sbert_model
del test_texts
del test_texts_cleaned
gc.collect()
torch.cuda.empty_cache()

# Load the tuned SVM from the joblib file, so this cell does not need the search objects in memory.
print(f"Loading the fine-tuned SVM model from: {svm_model_path}")
if not os.path.exists(svm_model_path):
    print(f"Error: SVM model file not found at {svm_model_path}")
    print("Please ensure the model saving block executed correctly and saved the file to the correct Drive path.")
    raise FileNotFoundError(f"Model file not found: {svm_model_path}")

try:
    loaded_svm_model = joblib.load(svm_model_path)
    print("Fine-tuned SVM model loaded successfully.")
except Exception as e:
    print(f"Error loading SVM model: {e}")
    raise

print("Making predictions on the test data...")
test_predictions = loaded_svm_model.predict(X_test_unlabelled)
print(f"Generated {len(test_predictions)} predictions.")

print("Adding predictions to the test dataframe...")
output_df = test_df.copy()
output_df['label'] = test_predictions


# NOTE(review): a failed save is only printed, not re-raised, and output_df is
# deleted right below, so the predictions would then have to be recomputed.
print(f"Saving results with predictions to: {output_file_path}")
try:
    output_df.to_csv(output_file_path, index=False)
    print("Output file saved successfully.")
except Exception as e:
    print(f"Error saving output file: {e}")

# Remove the evaluation-only objects from the notebook namespace.
del test_df
del wiki_text_df_test
del id_to_text_test
del X_test_unlabelled
del loaded_svm_model
del test_predictions
del output_df
gc.collect()

print("--- Evaluation script finished ---")

## Additional Blocks

### Plotting

In [None]:
# Folder for all evaluation figures; created if it does not exist yet.
PLOT_DIR = os.path.join(results_path, "evaluation_plots")
os.makedirs(PLOT_DIR, exist_ok=True)
print(f"Plots will be saved to: {PLOT_DIR}")

In [None]:
def plot_confusion_matrix(y_true, y_pred, classes, model_name, save_dir=PLOT_DIR):
    """Draw the confusion matrix as an annotated heatmap and save it as a PNG.

    Args:
        y_true: True labels.
        y_pred: Predicted labels.
        classes: Labels, in the order used for the matrix rows/columns and ticks.
        model_name: Used in the plot title and in the output file name.
        save_dir: Directory the image is written to.

    Any error is printed instead of raised; the current figure is always closed.
    """
    try:
        matrix = confusion_matrix(y_true, y_pred, labels=classes)
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            matrix,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=classes,
            yticklabels=classes,
        )
        plt.title(f'Confusion Matrix - {model_name}')
        plt.xlabel('Predicted Label')
        plt.ylabel('Actual Label')
        out_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
        plt.savefig(out_path, bbox_inches='tight')
        print(f"Saved confusion matrix to {out_path}")
    except Exception as err:
        print(f"Error plotting confusion matrix for {model_name}: {err}")
    finally:
        plt.close()

def _binarize_labels_for_roc(y_true, classes):
    """Return an (n_samples, n_classes) 0/1 matrix whose column ``i`` marks ``classes[i]``."""
    labels = np.asarray(y_true)
    indicator = np.zeros((labels.shape[0], len(classes)), dtype=int)
    for column, cls in enumerate(classes):
        indicator[:, column] = labels == cls
    return indicator


def plot_multiclass_roc_curve(y_true, y_scores, classes, model_name, score_type="probability", save_dir=PLOT_DIR):
    """Plot one-vs-rest ROC curves with micro and macro averages and save a PNG.

    Args:
        y_true: True labels, one per sample.
        y_scores: NumPy array of scores. A 2-D array is read as one column per
            entry of ``classes`` (same order); a 1-D array with two classes is
            read as the score of ``classes[1]``.
        classes: Class labels in the column order of ``y_scores``.
        model_name: Used in the plot title and in the output file name.
        score_type: Description of the scores, shown in the title.
        save_dir: Directory the image is written to.

    Any error is printed instead of raised; the current figure is always closed.
    """
    plot_title = f'Multi-class ROC ({score_type.replace("_"," ").title()}) - {model_name}'
    print(f"Generating ROC curve for {model_name} using {score_type} scores...")
    try:
        n_classes = len(classes)
        scores_are_2d = len(y_scores.shape) > 1

        if scores_are_2d and y_scores.shape[1] != n_classes and n_classes > 1:
            print(f"Warning: Score shape {y_scores.shape} mismatch with n_classes {n_classes} for ROC. Attempting alignment.")
            unique_true = sorted(np.unique(y_true))
            if y_scores.shape[1] != len(unique_true):
                # The score columns match neither `classes` nor the labels present
                # in y_true, so no indicator matrix of the same size exists and
                # not even a micro-average curve can be computed.
                print("Cannot align score shapes for ROC. Skipping per-class curves.")
                print("Cannot compute Micro-Average ROC due to size mismatch.")
                return
            classes = unique_true
            n_classes = len(classes)
            print(f"Aligned based on unique true labels. New n_classes: {n_classes}")

        # Binarize against `classes` explicitly so that column i always refers to
        # classes[i], even when a class has no sample in y_true.
        y_true_bin = _binarize_labels_for_roc(y_true, classes)

        fpr, tpr, roc_auc = {}, {}, {}
        valid_classes = []
        for i in range(n_classes):
            # Placeholder curve; it is replaced below when a real curve can be computed.
            fpr[i], tpr[i], roc_auc[i] = np.array([0]), np.array([0]), 0.0

            if not scores_are_2d and n_classes == 2 and i == 0:
                # 1-D binary scores describe classes[1] only.
                continue
            current_scores = y_scores[:, i] if scores_are_2d else y_scores

            positives = int(np.sum(y_true_bin[:, i]))
            if positives == 0 or positives == y_true_bin.shape[0]:
                # A ROC curve needs both positive and negative samples.
                print(f"Skipping ROC for class {classes[i]} - class not present/only one class in y_true subset or index issue.")
                continue

            try:
                class_fpr, class_tpr, _ = roc_curve(y_true_bin[:, i], current_scores)
                class_auc = auc(class_fpr, class_tpr)
            except ValueError as roc_err:
                print(f"Skipping ROC for class {classes[i]} due to error: {roc_err}")
                continue
            fpr[i], tpr[i], roc_auc[i] = class_fpr, class_tpr, class_auc
            valid_classes.append(i)

        # Micro-average: pool every (sample, class) decision into one binary problem.
        try:
            y_true_bin_flat = y_true_bin.ravel()
            y_scores_flat = y_scores.ravel()
            if y_true_bin_flat.size == y_scores_flat.size:
                fpr["micro"], tpr["micro"], _ = roc_curve(y_true_bin_flat, y_scores_flat)
                roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
            else:
                print("Skipping Micro-Average ROC due to size mismatch after binarization.")
                roc_auc["micro"] = 0.0
                fpr["micro"], tpr["micro"] = np.array([0]), np.array([0])
        except Exception as e:
            print(f"Error calculating Micro-Average ROC: {e}")
            roc_auc["micro"] = 0.0
            fpr["micro"], tpr["micro"] = np.array([0]), np.array([0])

        # Macro-average: mean of the valid per-class curves on a shared FPR grid.
        if valid_classes:
            all_fpr = np.unique(np.concatenate([fpr[i] for i in valid_classes]))
            mean_tpr = np.zeros_like(all_fpr)
            for i in valid_classes:
                mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
            mean_tpr /= len(valid_classes)
            fpr["macro"] = all_fpr
            tpr["macro"] = mean_tpr
            roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])
        else:
            print("Skipping Macro-Average ROC: No valid per-class curves computed.")
            roc_auc["macro"] = 0.0
            fpr["macro"], tpr["macro"] = np.array([0]), np.array([0])

        plt.figure(figsize=(10, 8))
        if roc_auc["micro"] > 0:
            plt.plot(fpr["micro"], tpr["micro"],
                     label=f'Micro-average ROC (area = {roc_auc["micro"]:0.2f})',
                     color='deeppink', linestyle=':', linewidth=4)
        if roc_auc["macro"] > 0:
            plt.plot(fpr["macro"], tpr["macro"],
                     label=f'Macro-average ROC (area = {roc_auc["macro"]:0.2f})',
                     color='navy', linestyle=':', linewidth=4)

        # plt.cm.get_cmap is no longer available in newer Matplotlib releases;
        # pyplot.get_cmap accepts the same (name, lut) arguments.
        palette = plt.get_cmap('tab10', n_classes)
        for i, color in zip(range(n_classes), palette(np.arange(n_classes))):
            if roc_auc[i] > 0:
                plt.plot(fpr[i], tpr[i], color=color, lw=2,
                         label=f'ROC Class {classes[i]} (area = {roc_auc[i]:0.2f})')

        plt.plot([0, 1], [0, 1], 'k--', lw=2)
        plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate'); plt.ylabel('True Positive Rate')
        plt.title(plot_title)
        plt.legend(loc="lower right", fontsize='small')
        filename = os.path.join(save_dir, f"{model_name}_roc_curve.png")
        plt.savefig(filename, bbox_inches='tight')
        print(f"Saved ROC curve plot to {filename}")

    except Exception as e:
        print(f"Error plotting ROC curve for {model_name}: {e}")
    finally:
        plt.close()


def _pr_placeholder():
    """Return the dummy (precision, recall, AP) triple used when no real curve is available."""
    return np.array([0]), np.array([1]), 0.0


def _binarize_labels_for_pr(y_true, classes):
    """Return an (n_samples, n_columns) 0/1 indicator matrix for ``y_true``.

    When every label in ``y_true`` is listed in ``classes``, column ``j`` marks the
    samples labelled ``classes[j]``; a class that never occurs in ``y_true`` still
    gets its own all-zero column, so the columns stay aligned with the score columns.
    Otherwise (labels and ``classes`` use different vocabularies) the function falls
    back to ``pd.get_dummies(y_true)``: one column per distinct label, sorted.
    """
    labels = np.asarray(y_true, dtype=object).ravel()
    class_values = np.asarray(list(classes), dtype=object)
    indicator = np.asarray(labels[:, None] == class_values[None, :], dtype=bool)
    if (indicator.shape == (labels.size, class_values.size)
            and indicator.size
            and indicator.any(axis=1).all()):
        return indicator.astype(int)
    return pd.get_dummies(y_true).values.astype(int)


def _micro_average_pr(y_true_bin, y_scores):
    """Return micro-averaged (precision, recall, AP), or None if label/score sizes differ."""
    flat_true = y_true_bin.ravel()
    flat_scores = y_scores.ravel()
    if flat_true.size != flat_scores.size:
        return None
    micro_precision, micro_recall, _ = precision_recall_curve(flat_true, flat_scores)
    if y_scores.ndim > 1 and y_true_bin.shape[1] == y_scores.shape[1]:
        micro_ap = average_precision_score(y_true_bin, y_scores, average="micro")
    else:
        micro_ap = average_precision_score(flat_true, flat_scores)
    return micro_precision, micro_recall, micro_ap


def plot_multiclass_precision_recall(y_true, y_scores, classes, model_name, score_type="probability", save_dir=PLOT_DIR):
    """Plot per-class and micro-averaged precision-recall curves and save them as a PNG.

    Args:
        y_true: Ground-truth label for every sample.
        y_scores: Scores per sample; either 2-D (one column per class) or 1-D
            (treated as the score of ``classes[1]`` when there are two classes).
        classes: Class labels, in the same order as the score columns.
        model_name: Used in the plot title, log messages and output file name.
        score_type: Name of the score kind; only used for the title and log text.
        save_dir: Directory the PNG is written to.

    Returns:
        None. All exceptions are caught and reported with ``print``; the current
        figure is always closed before returning.
    """
    plot_title = f'Multi-class Precision-Recall ({score_type.replace("_"," ").title()}) - {model_name}'
    print(f"Generating Precision-Recall curve for {model_name} using {score_type} scores...")
    try:
        y_scores = np.asarray(y_scores)
        n_classes = len(classes)
        precision = {}
        recall = {}
        average_precision = {}

        if y_scores.ndim > 1 and y_scores.shape[1] != n_classes and n_classes > 1:
            print(f"Warning: Score shape {y_scores.shape} mismatch with n_classes {n_classes} for PR. Attempting alignment.")
            unique_true = sorted(np.unique(y_true))
            if y_scores.shape[1] == len(unique_true):
                classes = unique_true
                n_classes = len(classes)
                print(f"Aligned based on unique true labels. New n_classes: {n_classes}")
            else:
                # Per-class curves are impossible here; try a micro-average-only plot.
                print("Cannot align score shapes for PR. Skipping per-class curves.")
                try:
                    micro = _micro_average_pr(pd.get_dummies(y_true).values.astype(int), y_scores)
                    if micro is None:
                        print("Cannot compute Micro-Average PR due to size mismatch.")
                    else:
                        micro_precision, micro_recall, micro_ap = micro
                        plt.figure(figsize=(8, 6))
                        plt.plot(micro_recall, micro_precision, label=f'Micro-average PR (AP = {micro_ap:0.2f})', color='navy', linestyle=':', linewidth=4)
                        plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05])
                        plt.xlabel('Recall'); plt.ylabel('Precision')
                        plt.title(f'Micro-Average Precision-Recall - {model_name}')
                        plt.legend(loc="best")
                        filename = os.path.join(save_dir, f"{model_name}_precision_recall_curve_micro_only.png")
                        plt.savefig(filename, bbox_inches='tight')
                        print(f"Saved Micro-Average PR curve plot to {filename}")
                        plt.close()
                except Exception as micro_e:
                    print(f"Error computing Micro-Average PR: {micro_e}")
                return

        # Explicit binarization: pd.get_dummies ignores `columns=` for non-DataFrame
        # input, which shifted columns whenever a class was missing from y_true.
        y_true_bin = _binarize_labels_for_pr(y_true, classes)
        scores_are_1d = y_scores.ndim == 1

        for i in range(n_classes):
            # Start from the placeholder; it is overwritten only by a real curve.
            precision[i], recall[i], average_precision[i] = _pr_placeholder()
            if i >= y_true_bin.shape[1]:
                continue
            current_scores = y_scores if scores_are_1d else y_scores[:, i]
            if scores_are_1d and n_classes == 2 and i == 0:
                # 1-D binary scores describe class 1 only; class 0 gets no curve.
                continue
            if np.sum(y_true_bin[:, i]) == 0:
                print(f"Skipping PR for class {classes[i]} - no positive samples or index issue.")
                continue
            try:
                precision[i], recall[i], _ = precision_recall_curve(y_true_bin[:, i], current_scores)
                average_precision[i] = average_precision_score(y_true_bin[:, i], current_scores)
            except ValueError as pr_err:
                print(f"Skipping PR for class {classes[i]} due to error: {pr_err}")
                precision[i], recall[i], average_precision[i] = _pr_placeholder()

        try:
            micro = _micro_average_pr(y_true_bin, y_scores)
            if micro is None:
                print("Skipping Micro-Average PR due to size mismatch after binarization.")
                micro = _pr_placeholder()
        except Exception as e:
            print(f"Error calculating Micro-Average PR: {e}")
            micro = _pr_placeholder()
        precision["micro"], recall["micro"], average_precision["micro"] = micro

        plt.figure(figsize=(10, 8))
        if average_precision["micro"] > 0:
            plt.step(recall['micro'], precision['micro'], where='post', label=f'Micro-average PR (AP = {average_precision["micro"]:0.2f})', color='navy', linestyle=':', linewidth=4)

        # plt.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap is the
        # equivalent call that exists in both old and new releases.
        colors = plt.get_cmap('tab10', n_classes)
        for i, color in zip(range(n_classes), colors(range(n_classes))):
            if average_precision.get(i, 0.0) > 0:
                plt.step(recall[i], precision[i], where='post', color=color, lw=2, label=f'PR Class {classes[i]} (AP = {average_precision[i]:0.2f})')

        plt.xlim([0.0, 1.0]); plt.ylim([0.0, 1.05])
        plt.xlabel('Recall'); plt.ylabel('Precision')
        plt.title(plot_title)
        plt.legend(loc="best", fontsize='small')
        filename = os.path.join(save_dir, f"{model_name}_precision_recall_curve.png")
        plt.savefig(filename, bbox_inches='tight')
        print(f"Saved Precision-Recall curve plot to {filename}")

    except Exception as e:
        print(f"Error plotting Precision-Recall curve for {model_name}: {e}")
    finally:
        plt.close()


def plot_score_distribution(y_scores, classes, model_name, score_type="probability", save_dir=PLOT_DIR):
    """Save a two-column grid of score histograms (with KDE), one panel per score column.

    Args:
        y_scores: 1-D scores (only plotted when ``classes`` has two entries, as the
            score of ``classes[1]``) or 2-D scores with one column per panel.
        classes: Class labels; used for panel titles when their count matches the
            number of score columns.
        model_name: Used in the title, log messages and output file name.
        score_type: Label for the x axis and title; ``'probability'`` also fixes
            the x range to [0, 1].
        save_dir: Directory the PNG is written to.

    Returns:
        None. Errors are caught and printed; the current figure is always closed.
    """
    plot_title = f'Prediction Score Distributions ({score_type.replace("_"," ").title()}) - {model_name}'
    xlabel = score_type.replace("_"," ").title()

    print(f"Generating score distribution plot for {model_name} using {score_type} scores...")
    try:
        score_rank = len(y_scores.shape)

        if score_rank == 1:
            if len(classes) != 2:
                print(f"Warning: Received 1D scores for {len(classes)} classes. Cannot plot distributions.")
                return
            print("Plotting distribution for binary scores (assuming score for class 1).")
            panel_count = 1
            panel_labels = [f"Score ({classes[1]})"]
            panel_scores = [y_scores]
        elif score_rank > 1:
            panel_count = y_scores.shape[1]
            if panel_count == len(classes):
                panel_labels = [f"Score ({c})" for c in classes]
            else:
                print(f"Warning: Number of score columns ({panel_count}) differs from classes ({len(classes)}). Using column index for labels.")
                panel_labels = [f"Score Col {i}" for i in range(panel_count)]
            panel_scores = [y_scores[:, i] for i in range(panel_count)]
        else:
            print(f"Unexpected shape for y_scores: {y_scores.shape}. Cannot plot.")
            return

        # Nothing to draw, or too many panels to be readable.
        if panel_count == 0:
            print(f"No scores provided for {model_name}. Skipping distribution plot.")
            return
        if panel_count > 15:
            print(f"Skipping score distribution plot for {model_name} due to high number of classes/columns ({panel_count}).")
            return

        # Two panels per row, rounding the row count up.
        n_rows = (panel_count + 1) // 2
        fig, axes = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)
        axes_flat = axes.flatten()

        for ax, scores, label in zip(axes_flat, panel_scores, panel_labels):
            sns.histplot(scores, bins=30, kde=True, ax=ax)
            ax.set_title(f'{label} Distribution')
            ax.set_xlabel(xlabel)
            ax.set_ylabel('Frequency')
            if score_type == 'probability':
                ax.set_xlim(0, 1)

        # Drop the unused trailing axis when the panel count is odd.
        for spare_ax in axes_flat[panel_count:]:
            fig.delaxes(spare_ax)

        plt.suptitle(plot_title, y=1.03)
        plt.tight_layout(rect=[0, 0.03, 1, 0.95])
        filename = os.path.join(save_dir, f"{model_name}_score_distribution.png")
        plt.savefig(filename, bbox_inches='tight')
        print(f"Saved score distribution plot to {filename}")

    except Exception as e:
        print(f"Error plotting score distribution for {model_name}: {e}")
    finally:
        plt.close()

def save_misclassified_examples(y_true, y_pred, texts, model_name, save_dir=PLOT_DIR, n_examples=50):
    """Write a CSV with (a sample of) the examples whose prediction differs from the label.

    Args:
        y_true: Ground-truth labels, one per sample.
        y_pred: Predicted labels, same shape as ``y_true``.
        texts: Texts aligned positionally with the labels. Accepts a pandas
            Series, a pandas DataFrame (a single column is used as the text; with
            several columns each row is stored as a list), a list or a NumPy array.
        model_name: Used in log messages and the output file name.
        save_dir: Directory the CSV is written to.
        n_examples: Maximum number of rows saved; larger sets are sampled with a
            fixed random seed.

    Returns:
        None. Errors are caught and reported with ``print``.
    """
    try:
        y_true = np.asarray(y_true)
        y_pred = np.asarray(y_pred)

        # An element-wise comparison is only meaningful for identically shaped arrays;
        # otherwise NumPy either broadcasts or collapses the result to a scalar.
        if y_true.shape != y_pred.shape:
            print(f"Cannot compare labels for {model_name}: y_true has shape {y_true.shape} but y_pred has shape {y_pred.shape}. Skipping.")
            return

        misclassified_indices = np.where(y_true != y_pred)[0]

        if len(misclassified_indices) == 0:
            print(f"No misclassifications found for {model_name}.")
            return

        # Labels and texts are aligned by position, so always index positionally.
        # Label-based .loc lookups return extra rows when the index has duplicates.
        if isinstance(texts, pd.Series):
            misclassified_texts = texts.iloc[misclassified_indices].tolist()
        elif isinstance(texts, pd.DataFrame):
            selected_rows = texts.iloc[misclassified_indices]
            if selected_rows.shape[1] == 1:
                misclassified_texts = selected_rows.iloc[:, 0].tolist()
            else:
                misclassified_texts = selected_rows.values.tolist()
        elif isinstance(texts, (list, np.ndarray)):
            misclassified_texts = [texts[i] for i in misclassified_indices]
        else:
            print(f"Unsupported text data type ({type(texts)}) for misclassified examples. Skipping.")
            return

        misclassified_df = pd.DataFrame({
            'Actual Label': y_true[misclassified_indices],
            'Predicted Label': y_pred[misclassified_indices],
            'Text': misclassified_texts
        })

        if len(misclassified_df) > n_examples:
            misclassified_df = misclassified_df.sample(n=n_examples, random_state=42)

        filename = os.path.join(save_dir, f"{model_name}_misclassified_examples.csv")
        misclassified_df.to_csv(filename, index=False, encoding='utf-8')
        print(f"Saved {len(misclassified_df)} misclassified examples to {filename}")

    except Exception as e:
        print(f"Error saving misclassified examples for {model_name}: {e}")

In [None]:
# Diagnostic plots for the initial kNN model, evaluated on X_val / y_val.
print("\nGenerating plots for initial kNN...")
try:
    if not hasattr(knn, 'predict_proba'):
        print("kNN model does not support predict_proba. Skipping ROC, PR, Prob Dist plots.")
    else:
        # The same probability matrix feeds all three score-based plots.
        knn_probs = knn.predict_proba(X_val)
        plot_multiclass_roc_curve(y_val, knn_probs, class_labels, "kNN_Initial")
        plot_multiclass_precision_recall(y_val, knn_probs, class_labels, "kNN_Initial")
        plot_score_distribution(knn_probs, class_labels, "kNN_Initial")

    # These two outputs only need the hard predictions.
    plot_confusion_matrix(y_val, knn_preds, class_labels, "kNN_Initial")
    save_misclassified_examples(y_val, knn_preds, val_df['text'], "kNN_Initial")
except Exception as err:
    print(f"An error occurred during kNN_Initial plotting: {err}")
finally:
    # Release any figures left open and reclaim memory regardless of outcome.
    plt.close('all')
    gc.collect()

In [None]:
# Diagnostic plots for the initial Logistic Regression model on X_val / y_val.
# `logreg`, `logreg_preds`, `class_labels` and `val_df` are expected to come from
# earlier notebook cells (not visible in this section).
print("\nGenerating plots for initial Logistic Regression...")
try:
    # Class probabilities drive the ROC, precision-recall and distribution plots.
    logreg_probs = logreg.predict_proba(X_val)
    plot_multiclass_roc_curve(y_val, logreg_probs, class_labels, "LogReg_Initial")
    plot_multiclass_precision_recall(y_val, logreg_probs, class_labels, "LogReg_Initial")
    plot_score_distribution(logreg_probs, class_labels, "LogReg_Initial")
    # The remaining outputs use the hard predictions only.
    plot_confusion_matrix(y_val, logreg_preds, class_labels, "LogReg_Initial")
    save_misclassified_examples(y_val, logreg_preds, val_df['text'], "LogReg_Initial")

except Exception as e:
    print(f"An error occurred during LogReg_Initial plotting: {e}")
finally:
    # Always close open figures and trigger garbage collection.
    plt.close('all')
    gc.collect()

In [None]:
# Diagnostic plots for the initial SVM, evaluated on X_val / y_val.
print("\nGenerating plots for initial SVM...")
y_scores = None
score_type = None
try:
    # Prefer calibrated probabilities; fall back to raw decision-function margins.
    if callable(getattr(svm_clf, 'predict_proba', None)) and getattr(svm_clf, 'probability', False):
        y_scores = svm_clf.predict_proba(X_val)
        score_type = "probability"
        print("Using predict_proba for SVM plots.")
    elif callable(getattr(svm_clf, 'decision_function', None)):
        y_scores = svm_clf.decision_function(X_val)
        score_type = "decision_function"
        print("Using decision_function for SVM plots.")
    else:
        print("SVM model supports neither predict_proba nor decision_function. Skipping score-based plots.")

    # Prediction-based outputs do not depend on the scores above.
    plot_confusion_matrix(y_val, svm_preds, class_labels, "SVM_Initial")
    save_misclassified_examples(y_val, svm_preds, val_df['text'], "SVM_Initial")

    have_scores = y_scores is not None and score_type is not None
    if have_scores:
        plot_multiclass_roc_curve(y_val, y_scores, class_labels, "SVM_Initial", score_type=score_type)
        plot_multiclass_precision_recall(y_val, y_scores, class_labels, "SVM_Initial", score_type=score_type)
        plot_score_distribution(y_scores, class_labels, "SVM_Initial", score_type=score_type)
except Exception as err:
    print(f"An error occurred during SVM_Initial plotting: {err}")
finally:
    # Release any figures left open and reclaim memory regardless of outcome.
    plt.close('all')
    gc.collect()

In [None]:
# Diagnostic plots for the best kNN estimator found by `random_search_knn`.
print("\nGenerating plots for Tuned kNN...")
try:
    best_knn = random_search_knn.best_estimator_
    if not hasattr(best_knn, 'predict_proba'):
        print("Tuned kNN model does not support predict_proba. Skipping ROC, PR, Prob Dist plots.")
    else:
        # The same probability matrix feeds all three score-based plots.
        knn_tuned_probs = best_knn.predict_proba(X_val)
        plot_multiclass_roc_curve(y_val, knn_tuned_probs, class_labels, "kNN_Tuned")
        plot_multiclass_precision_recall(y_val, knn_tuned_probs, class_labels, "kNN_Tuned")
        plot_score_distribution(knn_tuned_probs, class_labels, "kNN_Tuned")

    # These two outputs only need the hard predictions.
    plot_confusion_matrix(y_val, knn_tuned_preds, class_labels, "kNN_Tuned")
    save_misclassified_examples(y_val, knn_tuned_preds, val_df['text'], "kNN_Tuned")
except NameError:
    # Raised when the tuning cell was not run, so its variables do not exist.
    print("Skipping Tuned kNN plots - 'random_search_knn' or 'knn_tuned_preds' not defined.")
except Exception as err:
    print(f"An error occurred during kNN_Tuned plotting: {err}")
finally:
    # Release any figures left open and reclaim memory regardless of outcome.
    plt.close('all')
    gc.collect()

In [None]:
# Diagnostic plots for the best Logistic Regression estimator found by
# `random_search_logreg`, evaluated on X_val / y_val.
print("\nGenerating plots for Tuned LogReg...")
try:
    best_logreg = random_search_logreg.best_estimator_
    # Class probabilities drive the ROC, precision-recall and distribution plots.
    logreg_tuned_probs = best_logreg.predict_proba(X_val)
    plot_multiclass_roc_curve(y_val, logreg_tuned_probs, class_labels, "LogReg_Tuned")
    plot_multiclass_precision_recall(y_val, logreg_tuned_probs, class_labels, "LogReg_Tuned")
    plot_score_distribution(logreg_tuned_probs, class_labels, "LogReg_Tuned")
    # The remaining outputs use the hard predictions only.
    plot_confusion_matrix(y_val, logreg_tuned_preds, class_labels, "LogReg_Tuned")
    save_misclassified_examples(y_val, logreg_tuned_preds, val_df['text'], "LogReg_Tuned")

except NameError:
     # Any undefined name lands here, e.g. when the tuning cell was not run.
     print("Skipping Tuned LogReg plots - 'random_search_logreg' or 'logreg_tuned_preds' not defined.")
except Exception as e:
    print(f"An error occurred during LogReg_Tuned plotting: {e}")
finally:
    # Always close open figures and trigger garbage collection.
    plt.close('all')
    gc.collect()

In [None]:
# Diagnostic plots for the best SVM estimator found by `random_search_svm`.
print("\nGenerating plots for Tuned SVM...")
y_scores_tuned = None
score_type_tuned = None
try:
    best_svm = random_search_svm.best_estimator_

    # Prefer calibrated probabilities; fall back to raw decision-function margins.
    if callable(getattr(best_svm, 'predict_proba', None)) and getattr(best_svm, 'probability', False):
        y_scores_tuned = best_svm.predict_proba(X_val)
        score_type_tuned = "probability"
        print("Using predict_proba for Tuned SVM plots.")
    elif callable(getattr(best_svm, 'decision_function', None)):
        y_scores_tuned = best_svm.decision_function(X_val)
        score_type_tuned = "decision_function"
        print("Using decision_function for Tuned SVM plots.")
    else:
        print("Tuned SVM model supports neither predict_proba nor decision_function. Skipping score-based plots.")

    # Prediction-based outputs do not depend on the scores above.
    plot_confusion_matrix(y_val, svm_tuned_preds, class_labels, "SVM_Tuned")
    save_misclassified_examples(y_val, svm_tuned_preds, val_df['text'], "SVM_Tuned")

    have_scores = y_scores_tuned is not None and score_type_tuned is not None
    if have_scores:
        plot_multiclass_roc_curve(y_val, y_scores_tuned, class_labels, "SVM_Tuned", score_type=score_type_tuned)
        plot_multiclass_precision_recall(y_val, y_scores_tuned, class_labels, "SVM_Tuned", score_type=score_type_tuned)
        plot_score_distribution(y_scores_tuned, class_labels, "SVM_Tuned", score_type=score_type_tuned)
except NameError:
    # Raised when the tuning cell was not run, so its variables do not exist.
    print("Skipping Tuned SVM plots - 'random_search_svm' or 'svm_tuned_preds' not defined.")
except Exception as err:
    print(f"An error occurred during SVM_Tuned plotting: {err}")
finally:
    # Release any figures left open and reclaim memory regardless of outcome.
    plt.close('all')
    gc.collect()

### Test Labels Generation

In [None]:
# Sanity-check the predictions CSV at `output_file_path`: report its shape, count
# missing predictions, print the predicted-class distribution and save a count plot.
print("\n--- Analyzing the Generated Predictions File ---")
prediction_file_path = output_file_path
prediction_column = 'label'

if not os.path.exists(prediction_file_path):
    print(f"Error: Prediction file not found at '{prediction_file_path}'.")
    print("Please ensure the previous script block ran successfully and created the file.")
else:
    print(f"Loading prediction file: {prediction_file_path}")
    try:
        predictions_df = pd.read_csv(prediction_file_path)
        print(f"Successfully loaded {len(predictions_df)} rows.")

        print("\n--- Basic File Information ---")
        print(f"Columns in the file: {predictions_df.columns.tolist()}")
        print(f"Shape of the dataframe: {predictions_df.shape}")

        if prediction_column not in predictions_df.columns:
            print(f"\nError: The specified prediction column '{prediction_column}' was not found in the file.")
            print("Please check the column name used when saving the predictions.")
        else:
            print(f"\n--- Analyzing Predictions ('{prediction_column}' column) ---")

            missing_predictions = predictions_df[prediction_column].isnull().sum()
            total_predictions = len(predictions_df)
            print(f"Total predictions generated: {total_predictions}")
            if missing_predictions > 0:
                print(f"Warning: Found {missing_predictions} missing predictions ({missing_predictions / total_predictions * 100:.2f}%).")
            else:
                print("No missing predictions found in the prediction column.")

            print("\nPredicted Class Distribution:")
            # Count once; value_counts skips NaN, so the shares are relative to
            # the non-missing predictions.
            class_counts = predictions_df[prediction_column].value_counts().sort_index()
            class_percentages = class_counts / class_counts.sum() * 100

            distribution_summary = pd.DataFrame({
                'Count': class_counts,
                'Percentage': class_percentages
            })
            distribution_summary['Percentage'] = distribution_summary['Percentage'].map('{:.2f}%'.format)

            print(distribution_summary)

            unique_classes = predictions_df[prediction_column].nunique()
            print(f"\nNumber of unique predicted classes: {unique_classes}")
            # Drop NaN before sorting: string labels mixed with a float NaN cannot
            # be ordered and would abort the rest of the analysis with a TypeError.
            print(f"Predicted classes: {sorted(predictions_df[prediction_column].dropna().unique().tolist())}")

            try:
                if 0 < total_predictions < 500000:
                    print("\nGenerating plot for predicted class distribution...")
                    plt.figure(figsize=(12, 6))
                    try:
                        sns.countplot(data=predictions_df, x=prediction_column, order=class_counts.index)
                        plt.title('Distribution of Predicted Classes')
                        plt.xlabel('Predicted Class Label')
                        plt.ylabel('Frequency Count')
                        plt.xticks(rotation=45, ha='right')
                        plt.tight_layout()
                        plot_filename = os.path.join(PLOT_DIR, "predicted_class_distribution.png")
                        plt.savefig(plot_filename, bbox_inches='tight')
                        print(f"Saved predicted class distribution plot to: {plot_filename}")
                    finally:
                        # Close the figure even when drawing or saving fails.
                        plt.close()
                else:
                    print("\nSkipping plot generation due to large dataset size or no predictions.")
            except Exception as plot_err:
                print(f"\nCould not generate plot: {plot_err}")

        del predictions_df
        gc.collect()
        print("\nCleaned up predictions DataFrame from memory.")

    except Exception as e:
        print(f"\nAn error occurred while reading or processing the file '{prediction_file_path}': {e}")

print("\n--- Analysis script finished ---")