In [None]:
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from sklearn.model_selection import train_test_split as tts
import os

from datasets import DatasetGenerator, PairedMNISTDataset
from helpers import EarlyStopper, classification_run, contrastive_run, unet_run, run_dswd
from models import TinyCNN, TinyCNN_Headless, TinyCNN_Head, WrapperModelTrainHead

In [None]:
# Run on the GPU when PyTorch reports that CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

In [None]:
# Fetch the MNIST train and test splits (downloaded into ./data when missing).
mnist_train, mnist_test = (
    torchvision.datasets.MNIST(root='./data', train=is_train, download=True)
    for is_train in (True, False)
)

# Pull the raw image / label tensors out as NumPy arrays.
x_train, y_train = mnist_train.data.numpy(), mnist_train.targets.numpy()
x_test, y_test = mnist_test.data.numpy(), mnist_test.targets.numpy()

# Split the official test set 50/50: one half is kept as the test set, the
# other half becomes the validation set. The seed is fixed so the split is stable.
x_test, x_val, y_test, y_val = tts(x_test, y_test, test_size=.5, random_state=71)

## Base - None / Aux - Skip

In [None]:
# Names of the base and auxiliary dataset variants. They are passed to
# DatasetGenerator and embedded in the checkpoint and figure file names below.
base, aux = "none", "skip"

In [None]:
def _build_split(images, labels, subset_ratio):
    """Create a DatasetGenerator for one split and build its base and aux datasets.

    The generator is configured with the notebook-level ``base`` / ``aux``
    variant names. The base dataset is built before the aux dataset.

    Returns:
        ``(generator, base_result, aux_result)`` where the two results are
        whatever ``build_base_dataset()`` / ``build_aux_dataset()`` return
        (unpacked by the callers below as ``(images, labels)``).
    """
    generator = DatasetGenerator(
        images=images,
        labels=labels,
        subset_ratio=subset_ratio,
        base_ds=base,
        aux_ds=aux,
    )
    base_result = generator.build_base_dataset()
    aux_result = generator.build_aux_dataset()
    return generator, base_result, aux_result


# Training split (subset_ratio=.2).
(
    train_ds_gen,
    (base_images_train, base_labels_train),
    (aux_images_train, aux_labels_train),
) = _build_split(x_train, y_train, .2)

print(f"Train Dataset Base Size: {len(base_images_train)}")
print(f"Train Dataset Aux Size: {len(aux_images_train)}")

# Test split (subset_ratio=.5).
(
    test_ds_gen,
    (base_images_test, base_labels_test),
    (aux_images_test, aux_labels_test),
) = _build_split(x_test, y_test, .5)

print(f"Test Dataset Base Size: {len(base_images_test)}")
print(f"Test Dataset Aux Size: {len(aux_images_test)}")

# Validation split (subset_ratio=.5).
(
    val_ds_gen,
    (base_images_val, base_labels_val),
    (aux_images_val, aux_labels_val),
) = _build_split(x_val, y_val, .5)

print(f"Validation Dataset Base Size: {len(base_images_val)}")
print(f"Validation Dataset Aux Size: {len(aux_images_val)}")


In [None]:
# Pair each split's base arrays with its auxiliary arrays.
train_dataset: PairedMNISTDataset = PairedMNISTDataset(
    base_images=base_images_train, base_labels=base_labels_train,
    aux_images=aux_images_train, aux_labels=aux_labels_train,
)
test_dataset: PairedMNISTDataset = PairedMNISTDataset(
    base_images=base_images_test, base_labels=base_labels_test,
    aux_images=aux_images_test, aux_labels=aux_labels_test,
)
val_dataset: PairedMNISTDataset = PairedMNISTDataset(
    base_images=base_images_val, base_labels=base_labels_val,
    aux_images=aux_images_val, aux_labels=aux_labels_val,
)

# All loaders use batches of 16; only the training loader shuffles.
train_loader: DataLoader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader: DataLoader = DataLoader(test_dataset, batch_size=16)
val_loader: DataLoader = DataLoader(val_dataset, batch_size=16)

In [None]:
# Directory that receives every checkpoint written by this notebook.
# exist_ok=True makes the cell safe to re-run and avoids the check-then-create race.
os.makedirs("models_unet", exist_ok=True)

### Base Model Training

In [None]:
# The epoch budget is defined outside the cache guard so that it exists even
# when the checkpoint is already on disk (the mixed-model cell reads it too).
num_base_epochs = 20
base_checkpoint_path = f"models_unet/base_classifier_{base}.pt"

# Train the base-only classifier once; skip when a checkpoint already exists.
if not os.path.exists(base_checkpoint_path):
    model = TinyCNN()
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    base_early_stopper = EarlyStopper(
        patience=5,
        min_delta=0
    )

    # Best validation metrics seen so far. Infinity guarantees that the first
    # finite validation loss is always recorded and checkpointed.
    base_best = {
        "val_loss": float("inf"),
        "val_acc": 0
    }

    for epoch in range(num_base_epochs):
        train_loss, train_acc = classification_run(
            model=model,
            optimizer=optimizer,
            dataloader=train_loader,
            mode="base_only",
            device=DEVICE,
        )

        val_loss, val_acc = classification_run(
            model=model,
            optimizer=optimizer,
            dataloader=val_loader,
            mode="base_only",
            device=DEVICE,
            train=False
        )

        # Columns: train loss, train accuracy (%), validation loss, validation accuracy (%).
        print(f"Epoch {epoch+1}:", round(train_loss, 4), round(train_acc*100, 2), round(val_loss, 4), round(val_acc*100, 2))

        # Keep only the weights with the lowest validation loss.
        if val_loss < base_best["val_loss"]:
            base_best["val_loss"] = val_loss
            base_best["val_acc"] = val_acc
            torch.save(model.state_dict(), base_checkpoint_path)

        # NOTE(review): EarlyStopper is assumed to return a truthy value once
        # `patience` is exhausted; confirm against helpers.EarlyStopper.
        if base_early_stopper(val_loss):
            print("Stopped")
            break

### Mixed Model Training

In [None]:
mixed_checkpoint_path = f"models_unet/mixed_classifier_{base}+{aux}.pt"

# Train the classifier on base + auxiliary data once; skip when a checkpoint
# already exists.
if not os.path.exists(mixed_checkpoint_path):
    # This cell owns its epoch budget. It previously read `num_base_epochs`,
    # which is only assigned when the base model is (re)trained, so a cached
    # base checkpoint made this cell fail with NameError.
    num_mixed_epochs = 20

    model = TinyCNN()
    model.to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
    mixed_early_stopper = EarlyStopper(
        patience=5,
        min_delta=0
    )

    # Best validation metrics seen so far. Infinity guarantees that the first
    # finite validation loss is always recorded and checkpointed.
    base_aux_best = {
        "val_loss": float("inf"),
        "val_acc": 0
    }

    for epoch in range(num_mixed_epochs):
        # Training uses both sources ...
        train_loss, train_acc = classification_run(
            model=model,
            optimizer=optimizer,
            dataloader=train_loader,
            mode="base_and_aux",
            device=DEVICE,
        )

        # ... while validation is run in "base_only" mode.
        val_loss, val_acc = classification_run(
            model=model,
            optimizer=optimizer,
            dataloader=val_loader,
            device=DEVICE,
            mode="base_only",
            train=False
        )

        # Columns: train loss, train accuracy (%), validation loss, validation accuracy (%).
        print(f"Epoch {epoch+1}:", round(train_loss, 4), round(train_acc*100, 2), round(val_loss, 4), round(val_acc*100, 2))

        # Keep only the weights with the lowest validation loss.
        if val_loss < base_aux_best["val_loss"]:
            base_aux_best["val_loss"] = val_loss
            base_aux_best["val_acc"] = val_acc
            torch.save(model.state_dict(), mixed_checkpoint_path)

        if mixed_early_stopper(val_loss):
            print("Stopped")
            break


### Supervised Contrastive Learning Model

#### Body Training + Temperature Hyperparameter Selection

In [None]:
num_contrast_epochs = 200
temp_range = np.linspace(0.05, .15, 3)
# Lowest validation loss reached for each candidate temperature (1000 = "nothing recorded yet").
best_val_loss = np.full(len(temp_range), 1000.0)

for i, temp in enumerate(temp_range):
    # Fresh encoder and projection head for every temperature.
    model = TinyCNN_Headless()
    proj_head = torch.nn.Linear(32, 128)

    model.to(DEVICE)
    proj_head.to(DEVICE)

    # A single optimizer updates the encoder and the projection head together.
    contrast_optimizer = optim.Adam(
        [*model.parameters(), *proj_head.parameters()],
        lr=0.001,
        weight_decay=1e-5,
    )
    contrast_early_stopper = EarlyStopper(patience=10, min_delta=0)

    # Arguments shared by the training and validation passes.
    shared_run_args = dict(
        model=model,
        proj_head=proj_head,
        optimizer=contrast_optimizer,
        device=DEVICE,
        temperature=temp,
    )

    for epoch in range(num_contrast_epochs):
        train_loss = contrastive_run(dataloader=train_loader, **shared_run_args)
        val_loss = contrastive_run(dataloader=val_loader, train=False, **shared_run_args)

        # Checkpoint the encoder and projection head whenever validation improves.
        improved = val_loss < best_val_loss[i]
        if improved:
            best_val_loss[i] = val_loss
            torch.save(model.state_dict(), f"models_unet/contrast_body_{base}+{aux}_{round(temp, 2)}.pt")
            torch.save(proj_head.state_dict(), f"models_unet/contrast_proj_{base}+{aux}_{round(temp, 2)}.pt")

        if contrast_early_stopper(val_loss):
            break

    # Reported exactly once per temperature, whether the loop stopped early or
    # ran through its full epoch budget.
    print("\n")
    print(f"Best Val Loss ({round(temp, 2)}):", best_val_loss[i])

#### Supervised Contrastive Learning Head Training

In [None]:
# Pick the temperature whose run reached the lowest validation loss. It is
# rounded to two decimals, the same rounding used in the checkpoint file names.
best_temp_index = np.argmin(best_val_loss)
best_temp = round(temp_range[best_temp_index], 2)
print(f"Best temp: {best_temp}")

In [None]:
num_class_epochs = 20

# Rebuild the encoder and restore the weights saved for the selected temperature.
contrast_body = TinyCNN_Headless()
body_state = torch.load(f"models_unet/contrast_body_{base}+{aux}_{best_temp}.pt", weights_only=True)
contrast_body.load_state_dict(body_state)

class_head = TinyCNN_Head()

wrapped_model = WrapperModelTrainHead(body=contrast_body, head=class_head)
wrapped_model.to(DEVICE)

# Only the head's parameters are handed to the optimizer.
optimizer = optim.Adam(wrapped_model.head.parameters(), lr=0.001, weight_decay=1e-5)

contrast_early_stopper = EarlyStopper(patience=5, min_delta=0)

# Best validation metrics seen so far (1000 = "nothing recorded yet").
contrast_best = {"val_loss": 1000, "val_acc": 0}

for epoch in range(num_class_epochs):
    # Train on base + auxiliary samples ...
    train_loss, train_acc = classification_run(
        model=wrapped_model,
        optimizer=optimizer,
        dataloader=train_loader,
        mode="base_and_aux",
        device=DEVICE,
    )

    # ... and validate in "base_only" mode.
    val_loss, val_acc = classification_run(
        model=wrapped_model,
        optimizer=optimizer,
        dataloader=val_loader,
        device=DEVICE,
        mode="base_only",
        train=False,
    )

    # Columns: train loss, train accuracy (%), validation loss, validation accuracy (%).
    print(f"Epoch {epoch+1}:", round(train_loss, 4), round(train_acc*100, 2), round(val_loss, 4), round(val_acc*100, 2))

    # Checkpoint the full wrapper (body + head) whenever validation improves.
    improved = val_loss < contrast_best["val_loss"]
    if improved:
        contrast_best["val_loss"] = val_loss
        contrast_best["val_acc"] = val_acc
        torch.save(wrapped_model.state_dict(), f"models_unet/contrast_classifier_{base}+{aux}.pt")

    if contrast_early_stopper(val_loss):
        break

# Summary printed once, whether training stopped early or used every epoch.
print("\n")
print("Best Val Loss:", contrast_best["val_loss"])
print("Best Val Acc:", round(contrast_best["val_acc"]*100, 2))

### Compare Accuracies Between Models

In [None]:
def _evaluate_on_val(model):
    """Move ``model`` to DEVICE and run ``classification_run`` on ``val_loader``.

    The run uses ``mode="base_only"`` and ``train=False``. The return value of
    ``classification_run`` is passed through unchanged (callers unpack it as
    ``(loss, accuracy)``).
    """
    model.to(DEVICE)
    # NOTE(review): `optimizer` is whatever the last training cell left in the
    # namespace. It is presumably unused when train=False; confirm in
    # helpers.classification_run.
    return classification_run(
        model=model,
        optimizer=optimizer,
        dataloader=val_loader,
        device=DEVICE,
        mode="base_only",
        train=False
    )


# Classifier trained on base data only.
base_model = TinyCNN()
base_model.load_state_dict(torch.load(f"models_unet/base_classifier_{base}.pt", weights_only=True))
base_loss, base_acc = _evaluate_on_val(base_model)

# Classifier trained on base + auxiliary data.
mixed_model = TinyCNN()
mixed_model.load_state_dict(torch.load(f"models_unet/mixed_classifier_{base}+{aux}.pt", weights_only=True))
mixed_loss, mixed_acc = _evaluate_on_val(mixed_model)

# Contrastively pre-trained body wrapped together with its classification head.
contrast_body = TinyCNN_Headless()
contrast_head = TinyCNN_Head()

contrast_model = WrapperModelTrainHead(
    body=contrast_body,
    head=contrast_head
)
contrast_model.load_state_dict(torch.load(f"models_unet/contrast_classifier_{base}+{aux}.pt", weights_only=True))
contrast_loss, contrast_acc = _evaluate_on_val(contrast_model)

# Each line shows: validation loss, validation accuracy (%).
print(f"Base: {round(base_loss, 4)}, {round(base_acc*100, 2)}")
print(f"Base + Aux: {round(mixed_loss, 4)}, {round(mixed_acc*100, 2)}")
# This line was previously labelled "Base + Aux" as well, which made the
# contrastive result indistinguishable from the mixed one.
print(f"Contrastive: {round(contrast_loss, 4)}, {round(contrast_acc*100, 2)}")

### Plot Latent Space Representations

In [None]:
# Reload the three trained classifiers and switch them to evaluation mode.
base_model = TinyCNN()
base_model.load_state_dict(torch.load(f"models_unet/base_classifier_{base}.pt", weights_only=True))
base_model.eval()

mixed_model = TinyCNN()
mixed_model.load_state_dict(torch.load(f"models_unet/mixed_classifier_{base}+{aux}.pt", weights_only=True))
mixed_model.eval()

contrast_body = TinyCNN_Headless()
contrast_head = TinyCNN_Head()

contrast_model = WrapperModelTrainHead(
    body=contrast_body,
    head=contrast_head
)

contrast_model.load_state_dict(torch.load(f"models_unet/contrast_classifier_{base}+{aux}.pt", weights_only=True))
contrast_model.eval()

contrast_model.to(DEVICE)
base_model.to(DEVICE)
mixed_model.to(DEVICE)

# One row per test sample.
# NOTE(review): assumes output index 9 of each model is a (batch, 32) embedding; confirm in models.py.
base_embeds = np.zeros((len(test_dataset), 32))
mixed_embeds_base = np.zeros((len(test_dataset), 32))
mixed_embeds_aux = np.zeros((len(test_dataset), 32))
contrast_embeds_base = np.zeros((len(test_dataset), 32))
contrast_embeds_aux = np.zeros((len(test_dataset), 32))
labels = np.zeros(len(test_dataset))

# NOTE(review): this flag persists on the shared dataset object, so later cells
# that iterate test_loader also see it; confirm that is intended.
test_loader.dataset.unique_sources = True

# Running write position. The previous code assumed 8 samples per batch while
# test_loader uses batch_size=16, so rows were written to the wrong slices
# (and the assignment failed on the shape mismatch). Tracking the actual batch
# length also handles a shorter final batch.
offset = 0
for x, y, z in test_loader:
    x = x.to(DEVICE)  # base images
    y = y.to(DEVICE)  # auxiliary images
    with torch.no_grad():
        base_outputs = base_model(x)[9].cpu().numpy()
        mixed_outputs_base = mixed_model(x)[9].cpu().numpy()
        mixed_outputs_aux = mixed_model(y)[9].cpu().numpy()
        contrast_outputs_base = contrast_model(x)[9].cpu().numpy()
        contrast_outputs_aux = contrast_model(y)[9].cpu().numpy()

    batch_len = base_outputs.shape[0]
    rows = slice(offset, offset + batch_len)

    base_embeds[rows] = base_outputs
    mixed_embeds_base[rows] = mixed_outputs_base
    mixed_embeds_aux[rows] = mixed_outputs_aux
    contrast_embeds_base[rows] = contrast_outputs_base
    contrast_embeds_aux[rows] = contrast_outputs_aux
    labels[rows] = z.cpu().numpy()

    offset += batch_len


In [None]:
import numpy as np
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import matplotlib.colors as mcolors

plt.rcParams['font.size'] = 16

# Give every (digit, source) pair its own class id: even ids (2 * digit) are
# base samples, odd ids (2 * digit + 1) are auxiliary samples.
mixed_labels = np.concatenate((labels*2, labels*2+1))
mixed_embeds = np.concatenate((mixed_embeds_base, mixed_embeds_aux))

contrast_labels = np.concatenate((labels*2, labels*2+1))
contrast_embeds = np.concatenate((contrast_embeds_base, contrast_embeds_aux))

# Perform t-SNE for all three embedding sets (each call refits from scratch).
tsne = TSNE(n_components=2, random_state=42)
base_tsne = tsne.fit_transform(base_embeds)
mixed_tsne = tsne.fit_transform(mixed_embeds)
contrast_tsne = tsne.fit_transform(contrast_embeds)

# 20-colour lookup shared by the two paired panels. plt.get_cmap replaces
# plt.cm.get_cmap, which was removed in Matplotlib 3.9.
tab20 = plt.get_cmap('tab20')
color_dict = {i: tab20(i/20) for i in range(20)}
colors = [color_dict[i] for i in range(20)]
cmap = mcolors.ListedColormap(colors)
bounds = np.arange(21)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Colour-bar tick labels: "0 - Base", "1 - Base", ... and "0 - Base", "0 - Aux", "1 - Base", ...
base_tick_labels = [f"{digit} - Base" for digit in range(10)]
paired_tick_labels = [f"{digit} - {source}" for digit in range(10) for source in ("Base", "Aux")]


def _plot_paired_panel(fig, ax, points, point_labels, title):
    """Draw one base-vs-aux t-SNE panel with a 20-class colour bar.

    Args:
        fig: Figure that owns ``ax``; used to attach the colour bar.
        ax: Axes to draw on.
        points: 2-D t-SNE coordinates. The first half of the rows are base
            samples and the second half are auxiliary samples (the order in
            which the embeddings were concatenated above).
        point_labels: Class ids in 0..19 aligned with ``points``.
        title: Panel title.
    """
    half = len(points) // 2
    # Base samples use the default marker, auxiliary samples use "x".
    ax.scatter(
        points[:half, 0], points[:half, 1],
        c=[color_dict[int(label)] for label in point_labels[:half]],
    )
    ax.scatter(
        points[half:, 0], points[half:, 1],
        c=[color_dict[int(label)] for label in point_labels[half:]],
        marker="x", s=30,
    )
    ax.set_title(title)
    ax.set_xlabel('t-SNE feature 1')
    ax.set_ylabel('t-SNE feature 2')

    # The scatter colours are given explicitly, so the colour bar is built from
    # a stand-alone mappable. Ticks sit at the centre of each colour band.
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax, ticks=np.arange(0.5, 20))
    cbar.set_ticklabels(paired_tick_labels)
    cbar.set_label("Classes")


# Create three subplots side by side
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(30, 8))

# Plot for base embeddings (10 classes, base samples only)
scatter1 = ax1.scatter(base_tsne[:, 0], base_tsne[:, 1], c=labels, cmap='tab10', alpha=.7)
ax1.set_title(f'Base Model Embeddings (Accuracy: {round(base_acc*100, 2)}%)')
ax1.set_xlabel('t-SNE feature 1')
ax1.set_ylabel('t-SNE feature 2')
cbar = fig.colorbar(scatter1, ax=ax1)
cbar.set_ticks(np.arange(0, 10))
cbar.set_ticklabels(base_tick_labels)
cbar.set_label("Classes")

# Plots for mixed and contrastive embeddings (base + aux samples, 20 classes)
_plot_paired_panel(
    fig, ax2, mixed_tsne, mixed_labels,
    f'Mixed Model Embeddings (Accuracy: {round(mixed_acc*100, 2)}%)',
)
_plot_paired_panel(
    fig, ax3, contrast_tsne, contrast_labels,
    f'Contrastive Model Embeddings (Accuracy: {round(contrast_acc*100, 2)}%)',
)

fig.suptitle(f"Base - {base.capitalize()} / Auxiliary - {aux.capitalize()}")

plt.tight_layout()
plt.savefig(f'tsne_2d_{base}+{aux}.png', dpi=300)
plt.show()

### Calculate Divergence Between Layers of the Networks

In [None]:
# Classifier trained on base data only.
base_model = TinyCNN()
base_state = torch.load(f"models_unet/base_classifier_{base}.pt", weights_only=True)
base_model.load_state_dict(base_state)

# Classifier trained on base + auxiliary data.
mixed_model = TinyCNN()
mixed_state = torch.load(f"models_unet/mixed_classifier_{base}+{aux}.pt", weights_only=True)
mixed_model.load_state_dict(mixed_state)

# Contrastive body and classification head, restored as a single wrapper.
contrast_body = TinyCNN_Headless()
contrast_head = TinyCNN_Head()
contrast_model = WrapperModelTrainHead(body=contrast_body, head=contrast_head)
contrast_state = torch.load(f"models_unet/contrast_classifier_{base}+{aux}.pt", weights_only=True)
contrast_model.load_state_dict(contrast_state)

# Evaluation mode for all three networks.
for loaded_model in (base_model, mixed_model, contrast_model):
    loaded_model.eval()

base_model.to(DEVICE)
mixed_model.to(DEVICE)
contrast_model.to(DEVICE)

In [None]:
networks = {
    "base": base_model,
    "mixed": mixed_model,
    "contrast": contrast_model,
}

# One result array per network, shaped (2, number of test batches, 9).
# Axis 0: index 0 holds the base_only=True run, index 1 the base_only=False run.
dswd_loss_test = {
    network_name: np.zeros((2, len(test_loader), 9))
    for network_name in networks
}

for i, (network, network_model) in enumerate(networks.items()):
    print(f"Starting network {network} : {i}/{len(networks)}")

    # Arguments shared by both runs for this network.
    dswd_args = dict(
        model=network_model,
        dataloader=test_loader,
        layers=8,
        device=DEVICE,
    )

    dswd_loss_base = run_dswd(base_only=True, **dswd_args)
    dswd_loss_test[network][0] = dswd_loss_base

    dswd_loss_mixed = run_dswd(base_only=False, **dswd_args)
    dswd_loss_test[network][1] = dswd_loss_mixed

In [None]:

# NOTE(review): `aux_model`, `ProjNet` and `DSW` are not defined or imported
# anywhere in the visible notebook, so this cell raises NameError on a fresh
# kernel. `aux_model` is presumably the mixed classifier; confirm, and import
# ProjNet / DSW from wherever they live before relying on this cell.

# One row per validation batch, one column per compared layer. This was sized
# with len(test_loader) although the loop below iterates val_loader.
dswd_loss_val = np.zeros((len(val_loader), 9))
val_loader.dataset.unique_sources = True

# The loop variables are deliberately NOT named `base` / `aux`: those names hold
# the notebook-level dataset-variant strings used to build checkpoint paths, and
# overwriting them with tensors breaks any earlier cell that is re-run afterwards.
for i, (base_batch, _aux_batch, _label_batch) in enumerate(val_loader):
    base_batch = base_batch.to(DEVICE)

    # First nine outputs of each network for the same base batch.
    aux_outputs = aux_model(base_batch)[:9]
    contrast_outputs = contrast_model(base_batch)[:9]
    print(i)
    for j, (aux_layer, contrast_layer) in enumerate(zip(aux_outputs, contrast_outputs)):

        # Flatten each layer's activations to (batch, features).
        aux_layer_flat = aux_layer.view(aux_layer.size(0), -1)
        contrast_layer_flat = contrast_layer.view(contrast_layer.size(0), -1)

        # A fresh projection network and optimizer are created for every layer,
        # sized to that layer's feature width.
        projnet = ProjNet(size=aux_layer_flat.size(1)).to(DEVICE)
        op_projnet = optim.Adam(
            projnet.parameters(),
            lr=0.001,
            weight_decay=1e-5
        )

        dsw_loss = DSW(
            encoder=None,
            embedding_norm=1.0,
            num_projections=1000,
            projnet=projnet,
            op_projnet=op_projnet
        )

        # Divergence between the two networks at layer j, divided by the batch size.
        dswd_loss_val[i, j] += dsw_loss(
            aux_layer_flat,
            contrast_layer_flat
        ) / aux_layer_flat.size(0)