In [None]:
import numpy as np
from data_preprocessing import load_data
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
from collections import Counter
from sklearn.metrics import ndcg_score

In [None]:


# Read the test labels and the test inputs through the project's ``load_data``
# helper. The labels file is read first, then the inputs file.
y_test, test_input = [
    load_data(csv_path)
    for csv_path in ('models_test/y_test.csv', 'models_test/test_input.csv')
]

In [None]:

def filter_minor_classes(test_input, y):
    """Drop every sample whose class occurs fewer than two times in ``y``.

    Parameters
    ----------
    test_input : numpy.ndarray or sequence of array-like
        Either a single array whose first axis indexes samples, or a
        sequence of such arrays (for example one array per model input).
    y : array-like of shape (n_samples,)
        Class label of each sample.

    Returns
    -------
    X : numpy.ndarray or list of numpy.ndarray
        ``test_input`` restricted to the kept samples. A single array is
        returned for a single sample-major array, a list otherwise.
    y : numpy.ndarray
        Labels of the kept samples.
    """
    y = np.asarray(y)

    # Count the number of instances in each class.
    class_counts = Counter(y.tolist())

    # Classes represented by fewer than two instances.
    small_classes = [label for label, count in class_counts.items() if count < 2]

    # True for every sample that belongs to a class we keep.
    mask = ~np.isin(y, small_classes)

    if (
        isinstance(test_input, np.ndarray)
        and test_input.ndim >= 1
        and test_input.shape[0] == mask.shape[0]
    ):
        # A single array with one row per sample: mask along the sample axis.
        # Iterating over it instead would apply the per-sample mask to each
        # individual row (i.e. to the feature axis), which is wrong.
        X = test_input[mask]
    else:
        # A sequence of per-input arrays (or an array stacked input-major):
        # mask the samples of each one independently.
        X = [np.asarray(array)[mask] for array in test_input]

    return X, y[mask]

def _roc_auc_on_present_classes(y_true, y_prob):
    """One-vs-rest ROC AUC restricted to the classes that occur in ``y_true``.

    ``y_true`` holds integer class indices that double as column indices of
    ``y_prob``. Returns ``nan`` when fewer than two classes are present,
    because ROC AUC is undefined in that case.
    """
    present = np.unique(y_true)
    if present.size < 2:
        return float("nan")

    scores = np.asarray(y_prob, dtype=float)[:, present]
    if present.size < y_prob.shape[1]:
        # Some classes were filtered out. scikit-learn requires one score
        # column per class in y_true and rows that sum to 1, so keep only the
        # present classes and renormalise (uniform if a row sums to zero).
        row_sums = scores.sum(axis=1, keepdims=True)
        scores = np.divide(
            scores,
            row_sums,
            out=np.full_like(scores, 1.0 / present.size),
            where=row_sums > 0,
        )

    if present.size == 2:
        # Binary targets need a 1-D score for the greater label, which is the
        # second of the two remaining columns.
        return roc_auc_score(y_true, scores[:, 1])
    return roc_auc_score(y_true, scores, multi_class='ovr', labels=present)


def evaluate_model(model, test_input, y_test):
    """Evaluate ``model`` on test data and print/return classification metrics.

    Loss and accuracy come from ``model.evaluate`` on the full test set. The
    remaining metrics are computed only on samples whose class has at least
    two instances (see ``filter_minor_classes``).

    Parameters
    ----------
    model : object
        Must provide ``evaluate(x, y)`` returning ``(loss, accuracy)`` and
        ``predict(x)`` returning per-class scores of shape
        (n_samples, n_classes). Presumably a compiled Keras model with a
        single accuracy metric - TODO confirm.
    test_input : array-like or list/tuple of array-like
        Model input(s); the first axis of every array indexes samples.
    y_test : array-like of shape (n_samples, n_classes)
        Per-class targets; the true class is taken as the arg-max per row.

    Returns
    -------
    dict
        Keys ``loss``, ``accuracy``, ``precision``, ``recall``, ``f1``,
        ``roc_auc`` and ``ndcg``.

    Raises
    ------
    ValueError
        If no sample is left after removing the single-instance classes.
    """
    # Performance on the complete, unfiltered test set.
    loss, accuracy = model.evaluate(test_input, y_test)
    print(f"Test Loss: {loss}")
    print(f"Test Accuracy: {accuracy}")

    y_onehot = np.asarray(y_test)
    y_classes = np.argmax(y_onehot, axis=1)

    # Keep the inputs in the structure model.evaluate accepted, so that
    # model.predict below receives the same kind of input.
    single_input = not isinstance(test_input, (list, tuple))
    if single_input:
        inputs = [np.asarray(test_input)]
    else:
        inputs = [np.asarray(arr) for arr in test_input]

    # Filter the inputs and the per-class targets together so they stay
    # aligned; the targets travel as the last array of the list.
    filtered, y_true = filter_minor_classes(inputs + [y_onehot], y_classes)
    *inputs_filtered, y_onehot_filtered = filtered

    if y_true.size == 0:
        raise ValueError("No samples left after filtering classes with fewer than 2 instances")
    print(f"Samples kept after filtering: {y_true.size} of {y_classes.size}")

    # Predicted class is the one with the highest score.
    model_input = inputs_filtered[0] if single_input else inputs_filtered
    y_pred = np.asarray(model.predict(model_input))
    y_pred_class = np.argmax(y_pred, axis=1)

    # Calculate metrics.
    precision = precision_score(y_true, y_pred_class, average='weighted')
    recall = recall_score(y_true, y_pred_class, average='weighted')
    f1 = f1_score(y_true, y_pred_class, average='weighted')
    roc_auc = _roc_auc_on_present_classes(y_true, y_pred)
    # ndcg_score needs 2-D relevance targets with the same shape as the
    # scores, so it gets the per-class targets rather than class indices.
    ndcg = ndcg_score(y_onehot_filtered, y_pred)

    print("Precision: ", precision)
    print("Recall: ", recall)
    print("F1 Score: ", f1)
    print("ROC AUC Score: ", roc_auc)
    print("NDCG: ", ndcg)

    return {
        "loss": loss,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "roc_auc": roc_auc,
        "ndcg": ndcg,
    }