In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score


def plot_true_vs_pred_and_violin(file_path, feed_classes, feed_type=None, ax=None, expand_lims=False):
    """Draw one violin of predicted time per true time label and print R².

    The ``.npz`` archive at ``file_path`` must hold two pickled dicts:
    ``time`` (keys ``'labels'`` and ``'output'``) and ``feed`` (key
    ``'labels'``).  Predictions are grouped by their true time label, one
    grey violin is drawn per label, and a dashed ``y = x`` reference line is
    overlaid.  The R² score of predictions vs. true labels is printed.

    Parameters
    ----------
    file_path : str or path-like
        Location of the ``.npz`` prediction archive.
    feed_classes : sequence or mapping
        Kept for interface compatibility; the function body does not read it.
    feed_type : optional
        When not ``None``, only samples whose feed label equals this value
        are plotted.  When ``None`` every sample is used.
    ax : matplotlib Axes, optional
        Target axes; defaults to the current axes (``plt.gca()``).
    expand_lims : bool, optional
        If true, use narrower violins and start both axes 20 below the
        smallest label instead of at 0.

    Returns
    -------
    None
        The function returns early (after printing a message) when no sample
        matches ``feed_type`` or when there is nothing to plot.
    """
    if ax is None:
        ax = plt.gca()

    # The context manager closes the archive as soon as the pickled dicts
    # have been extracted, so no file handle is left open.
    with np.load(file_path, allow_pickle=True) as npz_file:
        time_info = npz_file['time'].item()
        feed_info = npz_file['feed'].item()

    true_labels = time_info.get('labels')
    pred_values = time_info.get('output')
    feed = feed_info.get('labels')

    # NOTE: the class-name lookup ``feed_classes[feed_type]`` used to happen
    # here unconditionally; it failed for ``feed_type=None`` and its result
    # was unused, so it has been removed.
    if feed_type is not None:
        feed_mask = feed == feed_type

        if not np.any(feed_mask):
            print(f"Warning: No matching entries found for feed '{feed_type}'. Skipping plot.")
            return

        true_labels = true_labels[feed_mask]
        pred_values = pred_values[feed_mask]

    unique_labels = np.unique(true_labels)

    grouped_pred_values = [pred_values[true_labels == label] for label in unique_labels]

    non_empty_groups = False
    for label, group in zip(unique_labels, grouped_pred_values):
        if len(group) > 0:
            non_empty_groups = True
        else:
            print(f"Warning: No predicted values found for True Label {label}")

    if not non_empty_groups:
        print(f"Skipping plot for {feed_type} as there are no non-empty groups.")
        return

    # Plot the violins, positioned on the x-axis at their true label value.
    violin_parts = ax.violinplot(
        grouped_pred_values,
        positions=unique_labels,
        widths=10 if expand_lims else 30,
        showmeans=False,
        showmedians=False,
        showextrema=False,
        bw_method='silverman',
    )

    # Equal aspect keeps the y = x reference line at 45 degrees.
    ax.set_aspect('equal', adjustable='box', anchor='C')
    ax.set_xticks(unique_labels)  # Label x-axis with unique true labels
    ax.set_xticklabels([str(_) for _ in unique_labels])
    ax.set_xlabel('Time (hrs)', fontsize='xx-large')
    ax.set_ylabel('Predicted time (hrs)', fontsize='xx-large')
    ax.tick_params('both', labelsize='x-large')

    for vp in violin_parts['bodies']:
        vp.set_facecolor('grey')
        vp.set_edgecolor('none')
        vp.set_linewidth(1.5)

    # Both axes share the same limits so the plot stays square.
    upper = max(unique_labels) + 20
    lower = min(unique_labels) - 20 if expand_lims else 0
    x_min = y_min = lower
    x_max = y_max = upper

    ax.plot([x_min, x_max], [y_min, y_max], color='k', linestyle='--', linewidth=1, label="y = x")

    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)

    r2 = r2_score(true_labels, pred_values)
    print(f"R² score: {r2:.4f}")

# Prediction archive produced by the model run to be visualised.
file_path = '/pscratch/sd/n/niranjan/output/prediction_s1_ambr01_1e-3_40e.npz'

npz_file = np.load(file_path, allow_pickle=True)
feed_data = npz_file['feed']

# Distinct feed label values present in the archive; one figure is drawn for each.
labels = feed_data.item().get('labels')
feeds_to_plot = np.unique(labels)

# Lookup from feed label value to a display name, used for figure titles.
feed_classes = feed_data.item().get('classes')

for feed in feeds_to_plot:
    fig, ax = plt.subplots(figsize=(7, 5.5))
    # Pass the freshly created axes explicitly instead of relying on plt.gca().
    plot_true_vs_pred_and_violin(file_path, feed_classes, feed, ax=ax)
    feed_label = feed_classes[feed]
    ax.set_title(f'Feed: {feed_label}')
    #plt.savefig(f"ambr03_violin_{feed_label}.png", dpi=100)


# Combined figure over every feed (feed_type left at its default of None).
fig, ax = plt.subplots(figsize=(7, 5.5))
plot_true_vs_pred_and_violin(file_path, feed_classes, ax=ax)
ax.set_title('Violin Plot: ')
#plt.savefig(f"ambr03_violin.png", dpi=100)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, precision_score, recall_score, accuracy_score
import os

def plot_confusion_matrix(npz_file_path, time_type=None, save_name=None):
    """Plot a row-normalised (percentage) confusion matrix of feed predictions.

    The ``.npz`` archive at ``npz_file_path`` must hold two pickled dicts:
    ``feed`` (keys ``'labels'`` and ``'pred'``) and ``time`` (key
    ``'labels'``).  A new 7 x 5.5 figure is created, the matrix is drawn as a
    greyscale seaborn heatmap, and the overall accuracy is printed.

    Tick labels are taken from the module-level name ``custom_labels``, which
    must be defined before this function is called.

    Parameters
    ----------
    npz_file_path : str or path-like
        Location of the ``.npz`` prediction archive.
    time_type : optional
        When not ``None``, only samples whose time label equals this value
        are included.  When ``None`` every sample is used.
    save_name : str, optional
        If given, the figure is saved under this name in the current working
        directory (300 dpi, tight bounding box).

    Returns
    -------
    None
        Returns early, after printing a warning, when no sample matches
        ``time_type``.
    """
    # The context manager closes the archive once the dicts are extracted.
    with np.load(npz_file_path, allow_pickle=True) as npz_file:
        feed_info = npz_file['feed'].item()
        time_info = npz_file['time'].item()

    true_labels = feed_info.get('labels')
    pred_labels = feed_info.get('pred')
    time_labels = time_info.get('labels')

    if time_type is not None:
        time_mask = time_labels == time_type
        if not np.any(time_mask):
            # Previously this message referenced the undefined name
            # ``feed_type`` and raised NameError instead of warning.
            print(f"Warning: No matching entries found for time '{time_type}'. Skipping plot.")
            return
        true_labels = true_labels[time_mask]
        pred_labels = pred_labels[time_mask]

    cm = confusion_matrix(true_labels, pred_labels)

    # Normalise each row (true class) to percentages.
    # NOTE(review): a class that occurs only among the predictions gives an
    # all-zero row and therefore NaN percentages here - confirm this is the
    # desired rendering.
    cm_percentage = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] * 100

    plt.figure(figsize=(7, 5.5))
    sns.heatmap(cm_percentage, annot=True, fmt=".2f", cmap="Greys",
                xticklabels=custom_labels, yticklabels=custom_labels)

    plt.ylabel('True Labels', fontsize=12)
    plt.xlabel('Predicted Labels', fontsize=12)

    plt.xticks(rotation=0)
    plt.yticks(rotation=0)

    if save_name:
        save_path = os.path.join(os.getcwd(), save_name)
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to: {save_path}")

    accuracy = accuracy_score(true_labels, pred_labels)

    print(f'Overall Accuracy: {accuracy:.4f}')



# Prediction archive produced by the model run to be visualised.
npz_file_path = '/pscratch/sd/n/niranjan/output/prediction_s1_ambr01_1e-3_40e.npz'
npz_file = np.load(npz_file_path, allow_pickle=True)

# Class names from the archive become the heatmap tick labels;
# plot_confusion_matrix reads them through the global `custom_labels`.
feed_data = npz_file['feed']
feed_to_plot = feed_data.item().get('classes')
custom_labels = feed_to_plot

# Distinct time labels present in the archive; one matrix is drawn for each.
time_data = npz_file['time']
labels = time_data.item().get('labels')
time_to_plot = np.unique(labels)

for time in time_to_plot:
    plot_confusion_matrix(npz_file_path, time)
    plt.title(f'Time: {time}')
    #plt.savefig(f"ambr03_cm_{time}.png", dpi=100)

# Aggregate matrix over all time points.
plot_confusion_matrix(npz_file_path)
#plt.savefig(f"ambr03_cm.png", dpi=100)
