In [None]:
def _to_numpy(value):
    """Return ``value`` as a squeezed NumPy array.

    Tensor-like inputs (anything exposing ``detach``, e.g. a PyTorch tensor)
    are detached and moved to host memory first, so tensors that require
    grad or live on an accelerator convert cleanly.
    """
    if hasattr(value, "detach"):
        value = value.detach().cpu().numpy()
    return np.asarray(value).squeeze()


def plot_model(self, ax_title='Model Performance on Test Data', save_name=None):
    """Plot predicted vs. true test values and the residual distribution.

    Left panel: scatter of model predictions against the true test targets,
    with a y = x reference line and the MSE/MAE annotated. Right panel:
    histogram of residuals (truth - prediction) restricted to [-1, 1].

    Side effect: updates the global Matplotlib rcParams (serif font family
    with Times New Roman preferred, STIX math text).

    Args:
        ax_title: Title of the left (scatter) panel.
        save_name: If not None, the figure is written to this path with
            ``Figure.savefig``.

    Returns:
        Tuple ``(mse, mae)`` of the mean squared / mean absolute error over
        the test set.
    """
    import torch  # local import: only needed for no_grad during inference

    plt.rcParams.update({"font.family": "serif", "mathtext.fontset": "stix"})
    # Put Times New Roman first without re-prepending it on every call
    # (a plain prepend grows the list each time this function runs).
    plt.rcParams["font.serif"] = ["Times New Roman"] + [
        name for name in plt.rcParams["font.serif"] if name != "Times New Roman"
    ]

    # Inference only: no_grad avoids building an autograd graph for the
    # whole test set.
    # NOTE(review): the model is not put in eval mode here; if it contains
    # dropout/batch-norm layers, make sure self.model.eval() was called.
    with torch.no_grad():
        y_pred = _to_numpy(self.model(self.X_test))
    y_true = _to_numpy(self.y_test)

    residuals = y_true - y_pred
    mse = np.mean(residuals ** 2)
    mae = np.mean(np.abs(residuals))

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(10, 5), dpi=500)

    # --- Left panel: predicted vs. true scatter ---------------------------
    scatter_ax = ax[0]
    scatter_ax.set_aspect('equal', adjustable='box')
    scatter_ax.plot(y_true, y_pred, label='Predictions', color='cornflowerblue',
                    ls='', marker='.', markersize=3)
    # y = x reference: points on this line are perfect predictions.
    scatter_ax.plot([0, 1], [0, 1], color='indianred', linestyle='--', alpha=0.7)

    bbox_props = dict(boxstyle='round', facecolor='white', alpha=0.7, pad=0.5,
                      edgecolor='lightgrey')
    scatter_ax.text(0.31, 0.85, f'MSE: {mse:.4f}\nMAE: {mae:.4f}',
                    transform=scatter_ax.transAxes, fontsize=12,
                    verticalalignment='top', horizontalalignment='right',
                    bbox=bbox_props)

    scatter_ax.set_title(ax_title, fontsize=20)
    scatter_ax.set_xlabel('True Scaled Precip.', fontsize=15)
    scatter_ax.set_ylabel('Predicted Scaled Precip.', fontsize=15)
    scatter_ax.grid(True, linestyle='--', alpha=0.7, axis='both', color='white')
    scatter_ax.legend(fontsize=12, loc='upper left')
    scatter_ax.set_ylim(-0.1, 1.1)
    scatter_ax.set_xlim(-0.1, 1.1)

    # --- Right panel: residual histogram ----------------------------------
    hist_ax = ax[1]
    # Residuals outside [-1, 1] are not drawn (hist drops out-of-range values).
    hist_ax.hist(residuals, range=[-1, 1], label='Residuals (truth - pred.)',
                 bins=50, color='cornflowerblue', alpha=0.5,
                 histtype='stepfilled', edgecolor='dimgrey')
    hist_ax.set_xlabel('Residuals', fontsize=15)
    hist_ax.set_ylabel('Frequency', fontsize=15)
    hist_ax.set_title('Residual Distribution', fontsize=20)
    hist_ax.grid(True, linestyle='--', alpha=0.7, axis='y', color='white')
    hist_ax.legend(fontsize=12, loc='upper left')
    hist_ax.axvline(0, color='indianred', linestyle='--', alpha=0.7, linewidth=0.5)

    # Shared minimalist styling: grey background, no spines, no tick marks.
    for a in ax:
        a.set_facecolor('gainsboro')
        for spine in a.spines.values():
            spine.set_visible(False)
        a.tick_params(axis='both', which='both', length=0)

    # Act on this figure explicitly rather than whatever pyplot considers current.
    fig.tight_layout()
    if save_name is not None:
        fig.savefig(save_name)

    return mse, mae

