In [4]:
import click
import os
import numpy as np
import torch
from torch.utils.data import DataLoader

from tqdm import tqdm
from util import configure_device, get_dataset
from models.split_vae import VAE
import matplotlib
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

@click.group()
def cli():
    """Root ``click`` command group; no sub-commands are attached to it in the visible cells."""

def _resolve_device(device):
    """Map the ``"gpu[:N]"`` spelling used by ``extract``'s default onto a torch device string."""
    if device is None:
        return "cuda"
    device = str(device)
    if device.startswith("gpu"):
        # "gpu:0" -> "cuda:0", "gpu" -> "cuda"; "cpu" / "cuda:N" pass through unchanged.
        return "cuda" + device[len("gpu"):]
    return device


def extract(
    orig_img_id = "00084",
    root=None,
    vae_chkpt_path="/home/bias-team/Mo_Projects/DiffuseVAE/logs/vae_carla_day/checkpoints/vae-carla_day-epoch=1499-train_loss=0.0000.ckpt",
    device="gpu:0",
    dataset_name="carla",
    image_size=128,
    save_path=os.getcwd(),
    latent="mu",
    vae=None,
):
    """Encode every image under ``root/orig_img_id`` with the split VAE and save the latents.

    Args:
        orig_img_id: Name of the sub-directory of ``root`` holding the images. It is also
            the label attached to every returned latent and part of the saved file name.
        root: Parent directory containing one sub-directory per image id. Required.
        vae_chkpt_path: Checkpoint passed to ``VAE.load_from_checkpoint``. Ignored when
            ``vae`` is supplied.
        device: Torch device to run on. ``"gpu[:N]"`` is translated to ``"cuda[:N]"``;
            anything else (``"cpu"``, ``"cuda:1"``) is used as-is.
        dataset_name: Forwarded to ``get_dataset``.
        image_size: Forwarded to ``get_dataset`` and to the VAE as ``input_res``.
        save_path: Currently unused; kept only for interface compatibility. The ``.npy``
            file is written next to the images in ``root/orig_img_id``.
        latent: Which encoding to return for each image:
            ``"mu"`` (mean of the main encoder, the default),
            ``"mu_aux"`` (mean of the auxiliary encoder applied to ``vae.scramble(batch)``),
            or ``"concat"`` (both means concatenated along dim 1).
        vae: Optional already-loaded VAE, so callers can avoid reloading the checkpoint for
            every image id.

    Returns:
        (z_arr, image_labels):
            z_arr: ``np.array`` of the per-step latents. The loader uses batch size 1, so
                this is presumably ``(n_images, 1, latent_dim)``; TODO confirm against
                ``VAE.encode``.
            image_labels: ``[orig_img_id] * len(z_arr)``.

    Raises:
        ValueError: if ``root`` is missing or ``latent`` is not a supported choice.
    """
    if root is None:
        raise ValueError("extract() requires `root`, the directory containing the image-id folders")
    if latent not in ("mu", "mu_aux", "concat"):
        raise ValueError(f"latent must be 'mu', 'mu_aux' or 'concat', got {latent!r}")

    dev = _resolve_device(device)
    img_dir = os.path.join(root, orig_img_id)

    dataset = get_dataset(dataset_name, img_dir, image_size, norm=False, flip=False)
    # One image per step, in dataset order, so latents line up with the dataset.
    loader = DataLoader(
        dataset,
        1,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    if vae is None:
        vae = VAE.load_from_checkpoint(vae_chkpt_path, input_res=image_size)
    vae = vae.to(dev)
    vae.eval()

    z_list = []
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(dev)
            # Only the posterior means are used; sampling via reparameterize() is not needed.
            mu, _ = vae.encode(batch)
            if latent == "mu":
                z = mu
            else:
                batch_aux, _ = vae.scramble(batch)
                mu_aux, _ = vae.encode_aux(batch_aux)
                z = mu_aux if latent == "mu_aux" else torch.cat((mu, mu_aux), dim=1)
            # Move each result to CPU immediately; accumulating on the GPU overflows memory.
            z_list.append(z.cpu().numpy())

    z_arr = np.array(z_list)
    image_labels = len(z_list) * [orig_img_id]
    # NOTE(review): the .npy lands inside the image directory itself. If get_dataset
    # enumerates every file in that directory, a re-run may try to load it as an image.
    # Verify this, or switch to `save_path`.
    np.save(os.path.join(img_dir, f"{orig_img_id}_split.npy"), z_arr)
    return z_arr, image_labels


def get_latent(data_path, vae_chkpt_path):
    """Encode every image-id sub-directory of ``data_path``; return (latents, int class labels).

    Each sub-directory name is an image id and forms one class. The ids are sorted, so the
    id -> class-index mapping is deterministic. It is also identical for any two directories
    holding the same set of ids, such as the train and test splits scored against each
    other. With raw ``os.listdir`` order (arbitrary per directory), the two splits could be
    labelled inconsistently.

    Args:
        data_path: Directory whose sub-directories are image ids; stray files are ignored.
        vae_chkpt_path: VAE checkpoint forwarded to ``extract``.

    Returns:
        z_total: 2-D array with one row per encoded image (each latent flattened).
        labels: 1-D int array of class indices aligned with the rows of ``z_total``.

    Raises:
        ValueError: if ``data_path`` contains no sub-directories.
    """
    img_id_list = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )
    if not img_id_list:
        raise ValueError(f"No image-id sub-directories found in {data_path!r}")

    z_parts = []
    image_labels_total = []
    for img_id in img_id_list:
        z, image_labels = extract(orig_img_id=img_id, root=data_path, vae_chkpt_path=vae_chkpt_path)
        z = np.asarray(z)
        if len(z) == 0:
            # Empty image directory: contributes no samples.
            continue
        # extract() stacks one array per loader step; flatten to (n_images, n_features)
        # so sklearn always receives a 2-D matrix, even for a single image.
        z_parts.append(z.reshape(len(z), -1))
        image_labels_total += image_labels
    if not z_parts:
        raise ValueError(f"No images could be encoded under {data_path!r}")

    # Concatenate once instead of growing the array inside the loop.
    z_total = np.concatenate(z_parts, axis=0)
    labels, _ = get_class_map(image_labels_total, img_id_list)
    return z_total, labels

def get_class_map(image_labels_total, img_id_list):
    """Translate image-id labels into integer class indices.

    The class index of an id is its position in ``img_id_list``. If an id appears more
    than once, its last position wins.

    Args:
        image_labels_total: Sequence of image ids, one per sample.
        img_id_list: Ordered image ids defining the classes.

    Returns:
        (labels, colors): ``labels`` is a numpy array holding the class index of every
        entry in ``image_labels_total``. ``colors`` is the list ``[0, 1, ...]`` with one
        entry per element of ``img_id_list``.

    Raises:
        KeyError: if a label in ``image_labels_total`` is not in ``img_id_list``.
    """
    indexed_ids = list(enumerate(img_id_list))
    class_index = {img_id: idx for idx, img_id in indexed_ids}
    colors = [idx for idx, _ in indexed_ids]
    labels = np.array([class_index[name] for name in image_labels_total])
    return labels, colors



        

In [5]:
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
## CARLA
# 1-NN image-id retrieval: fit on latents of the training images, then score each test
# split, repeated for one VAE checkpoint per seed.
# Previous test split: "/home/bias-team/Mo_Projects/DiffuseVAE/ood_experiment/carla2_ood/tests"
train_dir = "/home/bias-team/Mo_Projects/DiffuseVAE/ood_experiment/carla2_ood/train"
test_dir = "/home/bias-team/Mo_Projects/DiffuseVAE/ood_experiment/carla2_ood/tests2"
SEEDS = range(1, 11)  # one independently trained VAE checkpoint per seed

# Sorted so the run order (and the printed log) is reproducible; stray files are skipped.
test_name_list = sorted(
    name for name in os.listdir(test_dir)
    if os.path.isdir(os.path.join(test_dir, name))
)
train_score_dict = {}
test_score_dict = {}
# The training latents depend only on the checkpoint (seed), not on the test split, so
# encode them once per seed instead of once per (test split, seed) pair.
train_latent_cache = {}
for test_name in test_name_list:
    train_score_list = []
    test_score_list = []
    test_path = os.path.join(test_dir, test_name)
    for seed in SEEDS:
        vae_ckpt_path = f"/home/bias-team/Mo_Projects/DiffuseVAE/logs/vae_carla_day{seed}/checkpoints/vae-carla_day-epoch=1499-train_loss=0.0000.ckpt"
        if seed not in train_latent_cache:
            train_latent_cache[seed] = get_latent(train_dir, vae_ckpt_path)
        x_train, y_train = train_latent_cache[seed]
        x_test, y_test = get_latent(test_path, vae_ckpt_path)

        clf = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=1))
        clf.fit(x_train, y_train)
        y_pred_train = clf.predict(x_train)
        y_pred_test = clf.predict(x_test)
        # With n_neighbors=1 every training point is its own nearest neighbour, so the
        # train accuracy is 1.0 by construction (barring identical latents with different
        # labels). Only the test accuracy is informative.
        train_score = accuracy_score(y_train, y_pred_train)
        test_score = accuracy_score(y_test, y_pred_test)
        train_score_list.append(train_score)
        test_score_list.append(test_score)
        print(f"{test_name} seed {seed}")
        print("train:", train_score)
        print("test:", test_score)
    train_score_dict[test_name] = train_score_list
    test_score_dict[test_name] = test_score_list


1it [00:00, 17.72it/s]
1it [00:00, 18.40it/s]
1it [00:00, 18.86it/s]
1it [00:00, 19.21it/s]
1it [00:00, 20.18it/s]
1it [00:00, 20.35it/s]
1it [00:00, 20.14it/s]
1it [00:00, 21.01it/s]
1it [00:00, 20.48it/s]
1it [00:00, 18.70it/s]
10it [00:00, 75.28it/s]
10it [00:00, 75.31it/s]
10it [00:00, 75.53it/s]
10it [00:00, 61.90it/s]
10it [00:00, 73.32it/s]
10it [00:00, 71.63it/s]
10it [00:00, 68.70it/s]
10it [00:00, 72.31it/s]
10it [00:00, 59.88it/s]
10it [00:00, 67.55it/s]


test_flare2 seed 1
train: 1.0
test: 0.69


1it [00:00, 12.90it/s]
1it [00:00, 15.75it/s]
1it [00:00, 17.70it/s]
1it [00:00, 19.01it/s]
1it [00:00, 20.30it/s]
1it [00:00, 20.00it/s]
1it [00:00, 16.68it/s]
1it [00:00, 19.50it/s]
1it [00:00, 19.94it/s]
1it [00:00, 19.22it/s]
10it [00:00, 72.41it/s]
10it [00:00, 70.36it/s]
10it [00:00, 70.05it/s]
10it [00:00, 70.89it/s]
10it [00:00, 73.68it/s]
10it [00:00, 67.38it/s]
10it [00:00, 67.14it/s]
10it [00:00, 69.95it/s]
10it [00:00, 71.80it/s]
10it [00:00, 69.28it/s]


test_flare2 seed 2
train: 1.0
test: 0.76


1it [00:00, 16.13it/s]
1it [00:00, 19.74it/s]
1it [00:00, 18.86it/s]
1it [00:00, 18.50it/s]
1it [00:00, 12.22it/s]
1it [00:00, 16.86it/s]
1it [00:00, 16.27it/s]
1it [00:00, 16.99it/s]
1it [00:00, 17.37it/s]
1it [00:00, 17.36it/s]
10it [00:00, 60.99it/s]
10it [00:00, 72.43it/s]
10it [00:00, 73.90it/s]
10it [00:00, 66.55it/s]
10it [00:00, 66.29it/s]
10it [00:00, 63.31it/s]
10it [00:00, 64.94it/s]
10it [00:00, 68.66it/s]
10it [00:00, 67.87it/s]
10it [00:00, 67.99it/s]


test_flare2 seed 3
train: 1.0
test: 0.52


1it [00:00, 17.56it/s]
1it [00:00, 18.43it/s]
1it [00:00, 19.94it/s]
1it [00:00, 19.45it/s]
1it [00:00, 17.91it/s]
1it [00:00, 15.68it/s]
1it [00:00, 16.12it/s]
1it [00:00, 13.47it/s]
1it [00:00, 16.78it/s]
1it [00:00, 18.09it/s]
10it [00:00, 63.63it/s]
10it [00:00, 70.34it/s]
10it [00:00, 66.83it/s]
10it [00:00, 66.43it/s]
10it [00:00, 60.69it/s]
10it [00:00, 62.75it/s]
10it [00:00, 59.04it/s]
10it [00:00, 67.87it/s]
10it [00:00, 66.96it/s]
10it [00:00, 59.51it/s]


test_flare2 seed 4
train: 1.0
test: 0.47


1it [00:00, 16.60it/s]
1it [00:00, 15.33it/s]
1it [00:00, 11.28it/s]
1it [00:00, 16.92it/s]
1it [00:00, 15.22it/s]
1it [00:00, 16.81it/s]
1it [00:00, 16.68it/s]
1it [00:00, 19.09it/s]
1it [00:00, 17.79it/s]
1it [00:00, 18.45it/s]
10it [00:00, 57.16it/s]
10it [00:00, 56.06it/s]
10it [00:00, 69.74it/s]
10it [00:00, 60.96it/s]
10it [00:00, 53.02it/s]
10it [00:00, 60.31it/s]
10it [00:00, 63.38it/s]
10it [00:00, 63.09it/s]
10it [00:00, 58.60it/s]
10it [00:00, 53.56it/s]


test_flare2 seed 5
train: 1.0
test: 0.71


1it [00:00, 13.74it/s]
1it [00:00, 14.32it/s]
1it [00:00, 14.42it/s]
1it [00:00, 10.18it/s]
1it [00:00, 18.87it/s]
1it [00:00, 11.42it/s]
1it [00:00, 13.09it/s]
1it [00:00, 16.34it/s]
1it [00:00, 12.80it/s]
1it [00:00, 16.75it/s]
10it [00:00, 53.13it/s]
10it [00:00, 54.29it/s]
10it [00:00, 58.38it/s]
10it [00:00, 58.07it/s]
10it [00:00, 59.77it/s]
10it [00:00, 66.81it/s]
10it [00:00, 57.23it/s]
10it [00:00, 52.34it/s]
10it [00:00, 58.14it/s]
10it [00:00, 62.10it/s]


test_flare2 seed 6
train: 1.0
test: 0.59


1it [00:00, 18.37it/s]
1it [00:00, 18.06it/s]
1it [00:00, 17.11it/s]
1it [00:00, 16.44it/s]
1it [00:00, 15.05it/s]
1it [00:00, 13.83it/s]
1it [00:00, 15.37it/s]
1it [00:00, 11.69it/s]
1it [00:00, 12.40it/s]
1it [00:00, 14.20it/s]
10it [00:00, 60.90it/s]
10it [00:00, 68.81it/s]
10it [00:00, 56.29it/s]
10it [00:00, 57.59it/s]
10it [00:00, 58.97it/s]
10it [00:00, 52.51it/s]
10it [00:00, 58.54it/s]
10it [00:00, 66.39it/s]
10it [00:00, 51.80it/s]
10it [00:00, 59.58it/s]


test_flare2 seed 7
train: 1.0
test: 0.62


1it [00:00, 14.56it/s]
1it [00:00, 15.86it/s]
1it [00:00, 15.67it/s]
1it [00:00, 16.81it/s]
1it [00:00, 17.63it/s]
1it [00:00, 16.23it/s]
1it [00:00, 16.09it/s]
1it [00:00, 15.07it/s]
1it [00:00, 16.74it/s]
1it [00:00, 16.35it/s]
10it [00:00, 59.67it/s]
10it [00:00, 51.44it/s]
10it [00:00, 56.78it/s]
10it [00:00, 68.90it/s]
10it [00:00, 57.82it/s]
10it [00:00, 59.17it/s]
10it [00:00, 65.23it/s]
10it [00:00, 61.59it/s]
10it [00:00, 69.01it/s]
10it [00:00, 61.44it/s]


test_flare2 seed 8
train: 1.0
test: 0.71


1it [00:00, 14.65it/s]
1it [00:00, 17.94it/s]
1it [00:00, 16.96it/s]
1it [00:00, 17.38it/s]
1it [00:00, 11.89it/s]
1it [00:00, 15.41it/s]
1it [00:00, 14.63it/s]
1it [00:00, 17.44it/s]
1it [00:00, 17.80it/s]
1it [00:00, 16.28it/s]
10it [00:00, 58.99it/s]
10it [00:00, 65.83it/s]
10it [00:00, 58.62it/s]
10it [00:00, 64.04it/s]
10it [00:00, 58.30it/s]
10it [00:00, 55.91it/s]
10it [00:00, 53.11it/s]
10it [00:00, 61.94it/s]
10it [00:00, 69.05it/s]
10it [00:00, 71.42it/s]


test_flare2 seed 9
train: 1.0
test: 0.51


1it [00:00, 17.59it/s]
1it [00:00, 17.21it/s]
1it [00:00, 16.98it/s]
1it [00:00, 18.35it/s]
1it [00:00, 14.50it/s]
1it [00:00, 16.45it/s]
1it [00:00, 16.51it/s]
1it [00:00, 17.56it/s]
1it [00:00, 19.49it/s]
1it [00:00, 18.25it/s]
10it [00:00, 71.27it/s]
10it [00:00, 64.66it/s]
10it [00:00, 60.82it/s]
10it [00:00, 69.51it/s]
10it [00:00, 66.21it/s]
10it [00:00, 63.39it/s]
10it [00:00, 67.46it/s]
10it [00:00, 56.58it/s]
10it [00:00, 66.21it/s]
10it [00:00, 68.63it/s]

test_flare2 seed 10
train: 1.0
test: 0.61





In [7]:
# Mean +- population std (np.std, ddof=0) of the 1-NN test accuracy across VAE seeds,
# for every test split scored above ("test_rain" -> "Rain", "test_flare2" -> "Flare2").
# This replaces hard-coded per-split lookups, which raise KeyError when a split is absent.
score_summary = {}
for split_name in sorted(test_score_dict):
    split_scores = test_score_dict[split_name]
    split_acc, split_sd = np.mean(split_scores), np.std(split_scores)
    score_summary[split_name] = (split_acc, split_sd)
    display_name = split_name[len("test_"):] if split_name.startswith("test_") else split_name
    print(f"{display_name.capitalize()}: {split_acc} +- {split_sd}")

Flare2: 0.619 +- 0.09289241088485108
