In [None]:
%load_ext autoreload

import logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # NOTSET, DEBUG, INFO, WARN, ERROR, CRITICAL

from JPAS_DA import global_setup
from JPAS_DA.data import data_loaders
from JPAS_DA.data import generate_toy_data
from JPAS_DA.models import model_building_tools
from JPAS_DA.training import save_load_tools
from JPAS_DA.evaluation import evaluation_tools

import os
import torch
import numpy as np
from sklearn.manifold import TSNE

torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

from JPAS_DA.utils import plotting_utils
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.style.use('default')
plt.close('all')
font, rcnew = plotting_utils.matplotlib_default_config()
mpl.rc('font', **font)
plt.rcParams.update(rcnew)
plt.style.use('tableau-colorblind10')
%matplotlib inline

from JPAS_DA.utils import aux_tools
aux_tools.set_seed(42)

In [None]:
# =========================
# Shared Parameters
# =========================
# Number of classes and the fraction of samples drawn from each of them.
n_classes = 5
class_proportions = np.asarray([0.55, 0.05, 0.05, 0.25, 0.10])
assert np.isclose(np.sum(class_proportions), 1.0)

# Sample sizes: the three large splits share one size; the target-domain
# splits used for domain adaptation (DA) are much smaller.
n_samples_train = n_samples_val = n_samples_test = 16384
n_samples_train_DA = 64
n_samples_val_DA = 1024

# One seed per generated split, plus seeds kept for structure/transform.
seed_structure, seed_transform = 137, 3
seed_train, seed_val, seed_test = 42, 276, 0
seed_train_DA, seed_val_DA = 1, 2

# =========================
# Create specs
# =========================
# One spec per class, in class order. The target and source lists use the same
# families of shapes with different centres, widths and noise levels.
_toy = generate_toy_data
_TWO_PI = 2*np.pi


def _open_curve_points():
    """Return a fresh copy of the control points of the class-1 spline."""
    return [[3.8,-3.8], [6,-6], [-7,-7], [-5,-2], [0,0], [2,1], [7,8]]


def _hook_curve_points():
    """Return a fresh copy of the control points of the class-4 spline."""
    return [[-4,2], [-6, -1], [-8,-1], [-8,2], [-6, 6], [-2, 7], [1, 7]]


specs_target = [
    # class 0: single Gaussian blob
    _toy.spec_gaussian(center=[-1.0, 3.7], sigma=(1.1, 0.5)),
    # class 1: open spline
    _toy.spec_spline(control_points=_open_curve_points(),
                     thickness=0.4, jitter=0.05, closed=False),
    # class 2: spiral around the origin
    _toy.spec_spiral(center=[0,0], a=0.8, b=6.0, turns=1.5, theta0=_TWO_PI,
                     radial_noise=0.2, jitter=0.05),
    # class 3: ring + Gaussian mixture
    _toy.spec_mixture(
        [
            _toy.spec_ring(center=[5.5,6.6], radius=2., width=0.2, arc=(0.0, _TWO_PI), jitter=0.1),
            _toy.spec_gaussian(center=[-3,-4], sigma=(0.5, 0.5)),
        ],
        weights=[0.1, 0.9],
    ),
    # class 4: spline + Gaussian mixture
    _toy.spec_mixture(
        [
            _toy.spec_spline(control_points=_hook_curve_points(),
                             thickness=0.4, jitter=0.05, closed=False),
            _toy.spec_gaussian(center=[-2,-8], sigma=(1.0, 0.3)),
        ],
        weights=[0.5, 0.5],
    ),
]

specs_source = [
    # class 0: single Gaussian blob
    _toy.spec_gaussian(center=[1.7, 3.5], sigma=(0.4, 0.8)),
    # class 1: open spline
    _toy.spec_spline(control_points=_open_curve_points(),
                     thickness=0.2, jitter=0.05, closed=False),
    # class 2: spiral around the origin
    _toy.spec_spiral(center=[0,0], a=0.8, b=6.0, turns=1.5, theta0=_TWO_PI,
                     radial_noise=0.1, jitter=0.1),
    # class 3: ring + Gaussian mixture
    _toy.spec_mixture(
        [
            _toy.spec_ring(center=[5.5,6.5], radius=2.2, width=0.1, arc=(0.0, _TWO_PI), jitter=0.05),
            _toy.spec_gaussian(center=[-2,-5], sigma=(0.3, 0.5)),
        ],
        weights=[0.1, 0.9],
    ),
    # class 4: spline + Gaussian mixture
    _toy.spec_mixture(
        [
            _toy.spec_spline(control_points=_hook_curve_points(),
                             thickness=0.2, jitter=0.05, closed=False),
            _toy.spec_gaussian(center=[-4.5,-8.], sigma=(1.5, 1.2)),
        ],
        weights=[0.5, 0.5],
    ),
]

# =========================
# Source train/val are drawn from `specs_source`; the test split and both DA
# splits are drawn from the shifted `specs_target`.
# =========================
def _draw_split(n_samples, specs, seed):
    """Generate one split with the shared class proportions and the given seed."""
    return generate_toy_data.generate_dataset_from_specs(
        n_samples, specs, class_proportions, seed=seed
    )


# Source domain
xx_train, yy_train, train_counts = _draw_split(n_samples_train, specs_source, seed_train)
xx_val, yy_val, val_counts = _draw_split(n_samples_val, specs_source, seed_val)

# Target domain (the third return value is discarded for the DA splits)
xx_test, yy_test, test_counts = _draw_split(n_samples_test, specs_target, seed_test)
xx_train_DA, yy_train_DA, _ = _draw_split(n_samples_train_DA, specs_target, seed_train_DA)
xx_val_DA, yy_val_DA, _ = _draw_split(n_samples_val_DA, specs_target, seed_val_DA)

In [None]:
# Checkpoint folders of the three trained models, all under the project's model directory.
_models_root = global_setup.path_models
path_load_Fully_Supervised, path_load_no_DA, path_load_DA = (
    os.path.join(_models_root, _folder)
    for _folder in (
        "06_example_model_Fully_Supervised",
        "06_example_model",
        "07_example_model_DA",
    )
)

In [None]:
# Wrap each (features, labels) pair in the project's DataLoader.
# Construction order matches the assignment order on the left-hand side.
dset_train_DA, dset_val_no_DA, dset_val_DA, dset_test = (
    data_loaders.DataLoader(_features, _labels)
    for _features, _labels in (
        (xx_train_DA, yy_train_DA),  # few-shot target samples available for DA
        (xx_val, yy_val),            # source validation split
        (xx_val_DA, yy_val_DA),      # target validation split
        (xx_test, yy_test),          # target test split
    )
)

In [None]:
def _load_eval_pair(checkpoint_dir):
    """Load the encoder and downstream MLP stored in `checkpoint_dir`.

    Both checkpoints are loaded first (encoder, then downstream) and only then
    switched to evaluation mode. Returns ``(encoder, downstream)``.
    """
    networks = []
    for filename in ("model_encoder.pt", "model_downstream.pt"):
        # The first element returned by the loader is not needed here.
        _, network = save_load_tools.load_model_from_checkpoint(
            os.path.join(checkpoint_dir, filename), model_building_tools.create_mlp
        )
        networks.append(network)
    for network in networks:
        network.eval()
    return tuple(networks)


model_encoder_Fully_Supervised, model_downstream_Fully_Supervised = _load_eval_pair(path_load_Fully_Supervised)
model_encoder_no_DA, model_downstream_no_DA = _load_eval_pair(path_load_no_DA)
model_encoder_DA, model_downstream_DA = _load_eval_pair(path_load_DA)

# Compare the downstream heads of the no-DA and DA models within the given tolerances.
_ = evaluation_tools.compare_model_parameters(model_downstream_no_DA, model_downstream_DA, rtol=1e-2, atol=1e-2)

In [None]:
def _sample_everything(dset, seed=0):
    """Draw all `dset.NN_xx` samples of `dset` in a single batch as CPU torch objects."""
    return dset(batch_size=dset.NN_xx, seed=seed, sampling_strategy="true_random",
                to_torch=True, device="cpu")


def _forward_to_numpy(encoder, downstream, xx):
    """Run `encoder` -> `downstream` on `xx` without tracking gradients.

    Returns NumPy arrays ``(features, class_probabilities, predicted_class)``,
    where probabilities are the softmax of the logits along dim 1 and the
    predicted class is their argmax.
    """
    # Send the inputs to wherever the encoder lives, so a model loaded on an
    # accelerator does not fail on CPU inputs.
    # NOTE(review): assumes `xx` is a single tensor -- confirm against the loader.
    model_device = next(encoder.parameters()).device
    with torch.no_grad():
        features = encoder(xx.to(model_device))
        logits = downstream(features)
    probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()
    return features.cpu().numpy(), probs, np.argmax(probs, axis=1)


# ---- Target test set --------------------------------------------------------
# Sampled ONCE and shared by the three models, so `yy_true_test` is guaranteed
# to be row-aligned with every `*_test_*` prediction array below (the downstream
# cells pair them), independently of how the sampler treats its seed.
_xx_test, _yy_test = _sample_everything(dset_test)
yy_true_test = _yy_test.cpu().numpy()
yy_true_test_Fully_Supervised = yy_true_test.copy()

(features_test_Fully_Supervised,
 yy_pred_P_test_Fully_Supervised,
 yy_pred_test_Fully_Supervised) = _forward_to_numpy(
    model_encoder_Fully_Supervised, model_downstream_Fully_Supervised, _xx_test)

(features_test_no_DA,
 yy_pred_P_test_no_DA,
 yy_pred_test_no_DA) = _forward_to_numpy(model_encoder_no_DA, model_downstream_no_DA, _xx_test)

(features_test_DA,
 yy_pred_P_test_DA,
 yy_pred_test_DA) = _forward_to_numpy(model_encoder_DA, model_downstream_DA, _xx_test)

# ---- Source validation set, no-DA model -------------------------------------
_xx_val, _yy_val = _sample_everything(dset_val_no_DA)
yy_true_val_no_DA = _yy_val.cpu().numpy()
(features_val_no_DA,
 yy_pred_P_val_no_DA,
 yy_pred_val_no_DA) = _forward_to_numpy(model_encoder_no_DA, model_downstream_no_DA, _xx_val)

# ---- Target few-shot training set, DA model ---------------------------------
_xx_train_DA, _yy_train_DA = _sample_everything(dset_train_DA)
yy_true_train_DA = _yy_train_DA.cpu().numpy()
(features_train_DA,
 yy_pred_P_train_DA,
 yy_pred_train_DA) = _forward_to_numpy(model_encoder_DA, model_downstream_DA, _xx_train_DA)

# ---- Target validation set, DA model ----------------------------------------
_xx_val_DA, _yy_val_DA = _sample_everything(dset_val_DA)
yy_true_val_DA = _yy_val_DA.cpu().numpy()
(features_val_DA,
 yy_pred_P_val_DA,
 yy_pred_val_DA) = _forward_to_numpy(model_encoder_DA, model_downstream_DA, _xx_val_DA)

In [None]:
# =========================
# Font sizes used across the decision-boundary figure
# =========================
FS_TICK = FS_CBAR_TICK = 16  # tick labels (axes and colorbar)
FS_LABEL = 20                # axis labels
FS_BOX = 15                  # info box text
FS_LEGEND = 13               # legend text
FS_LEGEND_TITLE = 14         # legend title
FS_CBAR_LABEL = 18           # colorbar label

# Appearance of the annotation box drawn in each panel.
INFOBOX_BBOX = {
    'boxstyle': 'round,pad=0.35',
    'facecolor': 'white',
    'alpha': 0.9,
    'edgecolor': 'black',
    'linewidth': 0.8,
}

# =========================
# Plot extent and evaluation grid
# =========================
x_min = y_min = -10
x_max = y_max = 10
grid_res = 512  # number of grid nodes along each axis

xx_vals = np.linspace(x_min, x_max, num=grid_res)
yy_vals = np.linspace(y_min, y_max, num=grid_res)
xx_mesh, yy_mesh = np.meshgrid(xx_vals, yy_vals)

# Flatten the mesh into one (x, y) row per grid node.
grid_points = np.stack((xx_mesh.ravel(), yy_mesh.ravel()), axis=-1)

# The grid tensor lives on the CPU; it is moved to each model's device at prediction time.
xx_grid_cpu = torch.tensor(grid_points, device="cpu", dtype=torch.float32)

# =========================
# Colors (consistent across panels)
# =========================
# The previous version first tried a `color_dict` lookup, but no such name is
# defined in this notebook, so on a fresh kernel that branch always failed and a
# broad `except` silently fell through to the logic below. The fallback is now
# the only path, which makes the cell reproducible under Restart & Run All.
#
# Class labels are the distinct integer targets of the source validation set.
class_labels = np.unique(dset_val_no_DA.yy["SPECTYPE_int"])
n_classes = len(class_labels)

# `plt.get_cmap` replaces `plt.cm.get_cmap`, which newer Matplotlib releases no
# longer provide. Colours wrap around if there are more classes than entries.
cmap_base = plt.get_cmap("tab10")
class_colors = [cmap_base(i % cmap_base.N) for i in range(n_classes)]

# Discrete colormap with one bin per class index: boundaries sit at k - 0.5,
# so an integer value k falls in the middle of bin k.
cmap = mpl.colors.ListedColormap(class_colors)
norm = mpl.colors.BoundaryNorm(np.arange(n_classes + 1) - 0.5, ncolors=n_classes)

# Helper to annotate panels
def add_info_box(ax, scatter_desc, bg_desc):
    """Write a two-line caption in the top-left corner of `ax`.

    Parameters
    ----------
    ax : axes object the text is drawn on (positioned in axes coordinates).
    scatter_desc : description of the data shown as scatter points.
    bg_desc : description of the model behind the background predictions.
    """
    caption_lines = (
        f"Scatter-Points: {scatter_desc}",
        f"Background Predictions: {bg_desc}",
    )
    ax.text(
        0.02, 0.98, "\n".join(caption_lines),
        transform=ax.transAxes,
        ha='left', va='top',
        fontsize=FS_BOX,
        bbox=INFOBOX_BBOX,
    )

# =========================
# Helper: predict on grid
# =========================
def predict_grid(encoder, downstream, xx_grid_cpu, device, grid_shape=None):
    """Predict a class index for every grid point and return it as a 2-D float array.

    Parameters
    ----------
    encoder, downstream : callables applied in sequence (points -> features -> logits).
    xx_grid_cpu : tensor with one grid point per row; moved to `device` before the forward pass.
    device : device the models expect their input on.
    grid_shape : optional ``(rows, cols)`` of the output. When omitted, the grid
        is assumed to be square and its side is inferred from the number of points,
        so the result no longer depends on a module-level ``grid_res``.

    Returns
    -------
    numpy.ndarray of dtype float and shape `grid_shape` holding the arg-max class per point.

    Raises
    ------
    ValueError
        If `grid_shape` is omitted and the number of points is not a perfect square.
    """
    import math  # local import: only needed to infer the side of a square grid

    if grid_shape is None:
        n_points = int(xx_grid_cpu.shape[0])
        side = math.isqrt(n_points)
        if side * side != n_points:
            raise ValueError(
                f"Cannot infer a square grid from {n_points} points; pass grid_shape explicitly."
            )
        grid_shape = (side, side)

    with torch.no_grad():
        feats = encoder(xx_grid_cpu.to(device))
        logits = downstream(feats)
        yy_pred = torch.softmax(logits, dim=1).argmax(dim=1)
    return yy_pred.reshape(tuple(grid_shape)).cpu().numpy().astype(float)

# Devices: each model is evaluated on the device its own encoder lives on.
# (The fully supervised model previously borrowed the no-DA model's device,
# which only works while both happen to share one.)
dev_Fully_Supervised = next(model_encoder_Fully_Supervised.parameters()).device
dev_no_DA = next(model_encoder_no_DA.parameters()).device
dev_DA    = next(model_encoder_DA.parameters()).device

# =========================
# Compute predicted grids
# =========================
Z_Fully_Supervised = predict_grid(model_encoder_Fully_Supervised, model_downstream_Fully_Supervised,
                                  xx_grid_cpu, dev_Fully_Supervised)
Z_noDA_val  = predict_grid(model_encoder_no_DA, model_downstream_no_DA, xx_grid_cpu, dev_no_DA)
Z_noDA_test = Z_noDA_val  # same model and grid; the prediction is reused for the test overlay
Z_DA_test   = predict_grid(model_encoder_DA, model_downstream_DA, xx_grid_cpu, dev_DA)

# =========================
# Figure with FOUR VERTICAL PANELS (shared axes)
# =========================
from matplotlib.lines import Line2D


def _scatter_layers(ax, layers):
    """Overlay per-class scatter points from one or more datasets on `ax`.

    `layers` is a sequence of ``(dataset, scatter_kwargs)`` pairs. Classes form
    the outer loop and layers the inner one, so for every class the layers are
    drawn in the order given before moving on to the next class.
    """
    for color, label in zip(class_colors, class_labels):
        for dset, style in layers:
            mask = (dset.yy["SPECTYPE_int"] == label)
            pts = dset.xx["OBS"][mask]
            ax.scatter(pts[:, 0], pts[:, 1], color=color,
                       edgecolor='black', linewidth=0.1, **style)


def _draw_panel(ax, Z, layers, scatter_desc, bg_desc):
    """Draw one panel: predicted-class background, scatter overlays and caption box."""
    pcm = ax.pcolormesh(xx_mesh, yy_mesh, Z, cmap=cmap, norm=norm,
                        shading='nearest', alpha=0.5)
    _scatter_layers(ax, layers)
    add_info_box(ax, scatter_desc, bg_desc)
    return pcm


# Marker styles: small squares for the target test set, large triangles for the
# few target samples used during DA training, default marker for source validation.
_STYLE_TEST     = dict(marker='s', s=12, alpha=0.5)
_STYLE_TRAIN_DA = dict(marker='^', s=120, alpha=1.0)
_STYLE_VAL      = dict(s=12, alpha=0.5)

fig, axes = plt.subplots(4, 1, figsize=(8, 24), sharex=True, sharey=True)
ax_train, ax_val, ax_noDA_test, ax_DA_test = axes

# Common limits and tick size; y-label on every panel, x-label on the bottom one only.
for ax in axes:
    ax.set_xlim(x_min, x_max)
    ax.set_ylim(y_min, y_max)
    ax.tick_params(labelsize=FS_TICK)
    ax.set_ylabel(r"$\mathrm{Feature~2}$", fontsize=FS_LABEL)
axes[-1].set_xlabel(r"$\mathrm{Feature~1}$", fontsize=FS_LABEL)

# Panel 0 (TOP): fully supervised model + target test points + target train-DA points
pcm_train = _draw_panel(
    ax_train, Z_Fully_Supervised,
    [(dset_test, _STYLE_TEST), (dset_train_DA, _STYLE_TRAIN_DA)],
    "Target (Training) Set", "Fully_Supervised Model",
)

# Panel 1: no-DA model + source validation points
pcm_val = _draw_panel(
    ax_val, Z_noDA_val,
    [(dset_val_no_DA, _STYLE_VAL)],
    "Source (Validation) Set", "no-DA Model",
)

# Panel 2: no-DA model + target test points
pcm_noDA_test = _draw_panel(
    ax_noDA_test, Z_noDA_test,
    [(dset_test, _STYLE_TEST)],
    "Target (Test) Set", "no-DA Model",
)

# Panel 3: DA model + target test points + target train-DA points
pcm_DA_test = _draw_panel(
    ax_DA_test, Z_DA_test,
    [(dset_test, _STYLE_TEST), (dset_train_DA, _STYLE_TRAIN_DA)],
    "Target (Test) Set", "DA Model",
)

# =========================
# Legends
# =========================
# Class legend on the top panel. `ax.legend` already attaches the legend to the
# axes; re-adding it with `add_artist` would draw the semi-transparent frame and
# its shadow a second time, so that call has been dropped.
class_handles = [
    Line2D([0], [0], marker='o', linestyle='None', markerfacecolor=class_colors[i],
           markeredgecolor='black', markersize=8, label=str(class_labels[i]))
    for i in range(n_classes)
]
leg_classes = ax_train.legend(handles=class_handles, loc='lower left',
                              title="Class (points)", title_fontsize=FS_LEGEND_TITLE,
                              fontsize=FS_LEGEND, fancybox=True, shadow=True, framealpha=0.9)

# =========================
# Shared colorbar on the RIGHT
# =========================
plt.tight_layout(rect=[0.0, 0.0, 0.85, 1.0])  # leave room for colorbar
sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([])
cax = fig.add_axes([0.88, 0.10, 0.03, 0.80])  # [left, bottom, width, height] in figure coords
cbar = fig.colorbar(sm, cax=cax)
cbar.set_label("Predicted Class", fontsize=FS_CBAR_LABEL)
cbar.ax.tick_params(labelsize=FS_CBAR_TICK)
cbar.set_ticks(np.arange(n_classes))
try:
    cbar.set_ticklabels([str(lbl) for lbl in class_labels])
except ValueError as err:
    # Still best-effort (the figure keeps its numeric ticks), but no longer silent.
    logging.warning("Could not label colorbar ticks with class labels: %s", err)

fig.savefig(os.path.join(global_setup.path_saved_figures, "toy_distributions_quadpanel.png"),
            format="png", bbox_inches="tight")
plt.show()

In [None]:
# ---- Bigger font config ----
FS_TITLE = 24
FS_LABEL = 20
FS_TICKS = 18
FS_CELL  = 14          # off-diagonal cell text
FS_CELL_DIAG = 14      # diagonal cell text
FS_CBAR_LABEL = 20
FS_CBAR_TICKS = 16
TICK_ROT = 20

# Config
class_names = np.arange(n_classes)  # or your custom names
cmap = plt.cm.RdYlGn
threshold_color = 0.5  # cell text color threshold based on cm_percent (0..1)

# (panel title, true labels, predicted class probabilities).
# Each prediction array is paired with the labels obtained from the same
# sampling call, so rows are aligned by construction.
cases = [
    ("Fully_Supervised: Target", yy_true_test_Fully_Supervised, yy_pred_P_test_Fully_Supervised),
    ("no DA: Source", yy_true_val_no_DA, yy_pred_P_val_no_DA),
    ("no DA: Target", yy_true_test,     yy_pred_P_test_no_DA),
    ("DA: Target",    yy_true_test,     yy_pred_P_test_DA),
]

# Color normalization for row-normalized proportion in [0, 1]
norm = mpl.colors.Normalize(vmin=0.0, vmax=1.0)

# --- One VERTICAL subpanel per case, with shared axes ---
fig, axes = plt.subplots(len(cases), 1, figsize=(12, 30), sharex=True, sharey=True)
fig.dpi = 150

# For shared colorbar
mappable_for_cbar = mpl.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable_for_cbar.set_array([])

# Shared limits/ticks for confusion matrices (note reversed y-limits to keep row 0 at top)
xlim = (-0.5, n_classes - 0.5)
ylim = (n_classes - 0.5, -0.5)  # reverse to match origin='upper'
ticks = np.arange(n_classes)


def _safe_ratio(numerator, denominator):
    """Element-wise numerator / denominator, returning 0.0 wherever the denominator is 0.

    `out=` is required: with `where=` alone NumPy leaves the skipped entries
    uninitialised rather than zero.
    """
    numerator, denominator = np.broadcast_arrays(numerator, denominator)
    return np.divide(numerator, denominator,
                     out=np.zeros(numerator.shape, dtype=float),
                     where=denominator != 0)


for ax, (title, yy_true, yy_pred_P) in zip(axes, cases):
    yy_pred = np.argmax(yy_pred_P, axis=1)

    # Confusion matrix with all classes present (rows: true label, columns: prediction).
    # Samples whose true label lies outside [0, n_classes) are ignored.
    cm = np.zeros((n_classes, n_classes), dtype=int)
    valid = (yy_true >= 0) & (yy_true < n_classes)
    np.add.at(cm, (yy_true[valid].astype(int), yy_pred[valid].astype(int)), 1)

    # Row-normalized; rows without any true sample stay at exactly 0
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_percent = _safe_ratio(cm, row_sums)

    # Matrix image (keep row 0 at the TOP)
    im = ax.imshow(cm_percent, interpolation='nearest', cmap=cmap, norm=norm, origin='upper')

    # Shared axes styling
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xticklabels(class_names, fontsize=FS_TICKS)
    ax.set_yticklabels(class_names, fontsize=FS_TICKS)
    ax.set_aspect('equal', adjustable='box')  # square cells

    # Place the "title" on the RIGHT like a y-label
    ax.text(1.02, 0.5, title, transform=ax.transAxes,
            rotation=-90, va='center', ha='left', fontsize=FS_TITLE)

    # Per-class metrics (0.0 whenever the corresponding denominator is 0)
    tp = np.diag(cm)
    precision = _safe_ratio(tp, cm.sum(axis=0))   # tp / (tp + fp)
    recall    = _safe_ratio(tp, cm.sum(axis=1))   # tp / (tp + fn)
    f1        = _safe_ratio(2 * precision * recall, precision + recall)

    # Cell annotations (counts + %; diagonal shows TPR/PPV/F1)
    for i in range(n_classes):
        for j in range(n_classes):
            count   = cm[i, j]
            percent = cm_percent[i, j] * 100
            text_color = "white" if cm_percent[i, j] > threshold_color else "black"

            if i == j:
                text = (f"{count}\n"
                        f"TPR:{recall[i]*100:.1f}% "
                        f"\nPPV:{precision[i]*100:.1f}% "
                        f"\nF1:{f1[i]:.2f}")
                ax.text(j, i, text, ha="center", va="center",
                        color=text_color, fontsize=FS_CELL_DIAG, fontweight='bold', linespacing=1.2)
            else:
                text = f"{count}\n{percent:.1f}%"
                ax.text(j, i, text, ha="center", va="center",
                        color=text_color, fontsize=FS_CELL, linespacing=1.2)

# Labels: y-label on every panel; x tick labels and x-label only on the bottom panel
for ax in axes:
    ax.set_ylabel('True Label', fontsize=FS_LABEL, labelpad=10)
for ax in axes[:-1]:
    plt.setp(ax.get_xticklabels(), visible=False)
axes[-1].set_xlabel('Predicted Label', fontsize=FS_LABEL, labelpad=10)
plt.setp(axes[-1].get_xticklabels(), rotation=TICK_ROT, ha="right", rotation_mode="anchor")

# Shared colorbar on the RIGHT
plt.tight_layout(rect=[0.0, 0.0, 0.90, 1.0])  # leave room for cbar
cax = fig.add_axes([0.92, 0.12, 0.02, 0.76])
cbar = fig.colorbar(mappable_for_cbar, cax=cax)
cbar.set_label("True-label (row) normalized ratio", fontsize=FS_CBAR_LABEL)
cbar.ax.tick_params(labelsize=FS_CBAR_TICKS)
cbar.set_ticks([0.0, 0.25, 0.5, 0.75, 1.0])

# NOTE: the file name still says "tripanel" although the figure has four panels;
# it is kept unchanged so anything that references the saved file keeps working.
plt.savefig(os.path.join(global_setup.path_saved_figures, "toy_confusion_matrices_tripanel.pdf"),
            bbox_inches="tight")
plt.show()


In [None]:
# One confusion matrix per (labels, probabilities, title) case, in display order.
for _labels, _probs, _title in (
    (yy_true_val_no_DA, yy_pred_P_val_no_DA, "no DA: Source"),
    (yy_true_test,      yy_pred_P_test_no_DA, "no DA: Target"),
    (yy_true_val_DA,    yy_pred_P_val_DA,     "DA: Validation Set Target"),
    (yy_true_test,      yy_pred_P_test_DA,    "DA: Target"),
):
    evaluation_tools.plot_confusion_matrix(
        _labels, _probs,
        class_names=np.arange(n_classes),
        cmap=plt.cm.RdYlGn, title=_title
    )

# Source vs target, both evaluated with the no-DA model. The set names now say
# so: they used to read "no DA"/"DA", copied from the comparison below, which
# mislabelled the source and target sets.
evaluation_tools.compare_TPR_confusion_matrices(
    yy_true_val_no_DA,
    yy_pred_P_val_no_DA,
    yy_true_test,
    yy_pred_P_test_no_DA,
    class_names=np.arange(n_classes),
    figsize=(10, 7),
    cmap='seismic',
    title='TPR Comparison: Source vs Target (no-DA)',
    name_1 = "Source (no-DA)",
    name_2 = "Target (no-DA)"
)

# no-DA vs DA model, both evaluated on the target test set.
evaluation_tools.compare_TPR_confusion_matrices(
    yy_true_test,
    yy_pred_P_test_no_DA,
    yy_true_test,
    yy_pred_P_test_DA,
    class_names=np.arange(n_classes),
    figsize=(10, 7),
    cmap='seismic',
    title='TPR Comparison: DA vs no-DA (Target)',
    name_1 = "no DA",
    name_2 = "DA"
)

In [None]:
# Pairwise performance comparisons. Each row holds:
# (labels 1, probabilities 1, labels 2, probabilities 2, name 1, name 2, output file).
_pairwise_cases = (
    (yy_true_val_no_DA, yy_pred_P_val_no_DA, yy_true_test, yy_pred_P_test_no_DA,
     "Source (no-DA)", "Target", "toy_F1_comparison_source_vs_target_no_DA.pdf"),
    (yy_true_test, yy_pred_P_test_no_DA, yy_true_test, yy_pred_P_test_DA,
     "no-DA (Target)", "DA", "toy_F1_comparison_DA_vs_no_DA.pdf"),
    (yy_true_test, yy_pred_P_test_Fully_Supervised, yy_true_test, yy_pred_P_test_DA,
     "Fully_Supervised (Target)", "DA", "toy_F1_comparison_DA_vs_Fully_Supervised.pdf"),
)

# `metrics` is overwritten on every pass, so it ends up holding the result of the last comparison.
for _y1, _p1, _y2, _p2, _name_1, _name_2, _pdf_name in _pairwise_cases:
    metrics = evaluation_tools.compare_sets_performance(
        _y1, _p1,
        _y2, _p2,
        class_names=np.arange(n_classes),
        name_1=_name_1,
        name_2=_name_2,
        y_max_Delta_F1=1.0,
        y_min_Delta_F1=-1.0,
        f1_save_path=os.path.join(global_setup.path_saved_figures, _pdf_name)
    )

# Grouped Delta-F1 summary: one colour per comparison, in the order listed.
comparisons = [
    (yy_true_val_no_DA, yy_pred_P_val_no_DA, yy_true_test, yy_pred_P_test_no_DA, "Target vs. Source (no-DA)"),
    (yy_true_test, yy_pred_P_test_Fully_Supervised, yy_true_test, yy_pred_P_test_DA,   "DA vs Fully_Supervised (Target)"),
    (yy_true_test, yy_pred_P_test_no_DA, yy_true_test, yy_pred_P_test_DA,   "DA vs no-DA (Target)"),
]
_delta_f1_colors = ["crimson", "darkorange", "limegreen"]  # extend as needed
_delta_f1_legend = {"loc":"upper left", "frameon":True, "fontsize": 12}

fig, ax, deltas = evaluation_tools.plot_overall_deltaF1_grouped(
    comparisons,
    class_names=class_names,
    colors=_delta_f1_colors,
    title=None,
    figsize=(8, 6),
    legend_kwargs=_delta_f1_legend,
    save_dir=global_setup.path_saved_figures,
    save_format="pdf",
    filename="toy_delta_F1",
)

In [None]:
def _radar_case(label, y_true, y_pred, *, linestyle, color, marker):
    """Build one radar-plot entry; line width, marker size and fill alpha are shared by all cases."""
    return {
        "y_true": y_true,
        "y_pred": y_pred,
        "plot_kwargs": {
            "linestyle": linestyle, "linewidth": 2.0, "color": color,
            "marker": marker, "markersize": 10.0, "fill_alpha": 0.05,
            "label": label,
        },
    }


# (name, true labels, predicted probabilities, linestyle, colour, marker) per evaluation case.
# The name is used both as dictionary key and as legend label.
_radar_cases = (
    ("Fully-Supervised", yy_true_test_Fully_Supervised, yy_pred_P_test_Fully_Supervised, ":", "grey", "X"),
    ("Source no-DA", yy_true_val_no_DA, yy_pred_P_val_no_DA, "--", "royalblue", "s"),
    ("Target no-DA", yy_true_test, yy_pred_P_test_no_DA, "--", "firebrick", "v"),
    ("Target (Train) DA", yy_true_train_DA, yy_pred_P_train_DA, "-", "darkorange", "^"),
    ("Target DA", yy_true_test, yy_pred_P_test_DA, "-", "green", "o"),
)

dict_radar = {
    _name: _radar_case(_name, _y_true, _y_pred, linestyle=_ls, color=_color, marker=_marker)
    for _name, _y_true, _y_pred, _ls, _color, _marker in _radar_cases
}

_radar_legend = {
    "loc": "upper left", "bbox_to_anchor": (0.73, 1.0), "fontsize": 14, "ncol": 1,
    "title": "Evaluation Cases", "frameon": True, "fancybox": True, "shadow": True, "borderaxespad": 0.0,
}

fig, ax = evaluation_tools.radar_plot(
    dict_radar=dict_radar,
    class_names=class_names,
    title="F1 Radar Plot",
    figsize=(8, 8),
    theta_offset=np.pi / 2,  # first axis at 12 o'clock
    r_ticks=(0.1, 0.3, 0.5, 0.7, 0.9),
    r_lim=(0.0, 1.0),
    tick_labelsize=16,
    radial_labelsize=12,
    show_legend=True,
    legend_kwargs=_radar_legend,
)
fig.savefig(os.path.join(global_setup.path_saved_figures, "toy_F1_radar.pdf"), bbox_inches='tight')
plt.show()

# Explore latent space

In [None]:
# Encoder features to embed, keyed by the name the embedding is stored under.
feat_dict = dict(
    latents_no_DA_Source=features_val_no_DA,
    latents_no_DA_Target=features_test_no_DA,
    latents_DA_Target=features_test_DA,
)

# t-SNE settings passed through to the project helper.
_tsne_options = {"perplexity": 100}

latents_tSNE = evaluation_tools.tsne_per_key(
    feat_dict,
    standardize=False,
    subsample=None,
    random_state=137,
    tsne_kwargs=_tsne_options,
    return_all_key=None,
)

In [None]:
# Common axis limits for every latent-space plot in this cell.
xlim = (-100, 100)
ylim = (-100, 100)

# Keyword sets shared by the three kinds of plots below.
_overlay_style = dict(
    class_names=None,
    marker_val="o", marker_test="^",
    size_val=14, size_test=14, alpha_val=0.7, alpha_test=0.7,
    xlim=xlim, ylim=ylim,
    subsample=4000, seed=137,
    edgecolor=None, linewidths=0.0,
)
_class_scatter_style = dict(
    class_names=None,
    n_bins=128, sigma=2.0,
    scatter_size=1.0, scatter_alpha=1.0,
    xlim=xlim, ylim=ylim,
)
_density_style = dict(
    density_method="hist",  # or "kde"
    bins=256,
    sigma=2.0,  # ignored if density_method="kde"
    norm_mode="max",
    color_scale="linear",
    mask_zero_support=True,
    contour_fracs=(0.01, 0.1, 0.3, 0.6),
    contour_colors="k",
    contour_linewidths=0.4,
    contour_label_fontsize=7,
    contour_label_color="k",
    show_points=False,
    points_alpha=0.1,
    points_size=2,
    random_subsample=None,
    xlim=xlim,
    ylim=ylim,
)

# 2-D embeddings computed in the previous cell.
_emb_source       = latents_tSNE['latents_no_DA_Source_tSNE']
_emb_target_no_DA = latents_tSNE['latents_no_DA_Target_tSNE']
_emb_target_DA    = latents_tSNE['latents_DA_Target_tSNE']

# ---- no-DA model: source vs target ----
evaluation_tools.plot_latents_scatter_val_test(
    X_val=_emb_source, y_val=yy_true_val_no_DA,
    X_test=_emb_target_no_DA, y_test=yy_true_test,
    title="Latents no-DA: Source vs Target",
    legend_split_1="Source no-DA",
    legend_split_2="Target no-DA",
    **_overlay_style,
)

# Source embedding, passing the predicted labels (swap in yy_true_val_no_DA for the true ones)
evaluation_tools.plot_latents_scatter(
    _emb_source, yy_pred_val_no_DA,
    class_counts=dset_val_no_DA.class_counts,
    title="Latents no-DA: Source",
    **_class_scatter_style,
)
evaluation_tools.plot_latent_density_2d(
    _emb_source,
    title="Latents no-DA: Source",
    **_density_style,
)

# Target embedding, passing the predicted labels (swap in yy_true_test for the true ones)
evaluation_tools.plot_latents_scatter(
    _emb_target_no_DA, yy_pred_test_no_DA,
    class_counts=dset_test.class_counts,
    title="Latents no-DA: Target",
    **_class_scatter_style,
)
evaluation_tools.plot_latent_density_2d(
    _emb_target_no_DA,
    title="Latents no-DA: Target",
    **_density_style,
)

# ---- Target set: no-DA model vs DA model ----
evaluation_tools.plot_latents_scatter_val_test(
    X_val=_emb_target_no_DA, y_val=yy_true_test,
    X_test=_emb_target_DA, y_test=yy_true_test,
    title="Latents Target: no-DA vs DA",
    legend_split_1="Target no-DA",
    legend_split_2="Target DA",
    **_overlay_style,
)

# DA-model target embedding, passing the predicted labels (swap in yy_true_test for the true ones)
evaluation_tools.plot_latents_scatter(
    _emb_target_DA, yy_pred_test_DA,
    class_counts=dset_test.class_counts,
    title="Latents DA: Target",
    **_class_scatter_style,
)
evaluation_tools.plot_latent_density_2d(
    _emb_target_DA,
    title="Latents DA: Target",
    **_density_style,
)