In [4]:
import click
import os
import numpy as np
import torch
from torch.utils.data import DataLoader

from tqdm import tqdm
from util import configure_device, get_dataset
from models.split_vae import VAE
import matplotlib
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns

@click.group()
def cli():
    # Empty click command group (kept as a comment rather than a docstring so the
    # CLI help text is unchanged). No sub-commands are registered in these cells.
    ...

def _resolve_device(device):
    """Translate a ``"gpu[:N]"`` device spec into torch's ``"cuda[:N]"``.

    ``None`` maps to ``"cuda"``; any other value is returned as a string unchanged.
    """
    if device is None:
        return "cuda"
    device = str(device)
    if device.startswith("gpu"):
        return "cuda" + device[len("gpu"):]
    return device


def _encode_batch(vae, batch, latent):
    """Return the deterministic latent selected by ``latent`` for one batch.

    Only the encoder passes needed for the requested feature are run.
    """
    if latent == "mu":
        mu, _ = vae.encode(batch)
        return mu
    # NOTE(review): vae.scramble may be stochastic - confirm if exact reproducibility matters.
    batch_aux, _ = vae.scramble(batch)
    mu_aux, _ = vae.encode_aux(batch_aux)
    if latent == "mu_aux":
        return mu_aux
    # latent == "concat"
    mu, _ = vae.encode(batch)
    return torch.cat((mu, mu_aux), dim=1)


def extract(
    orig_img_id="00084",
    root=None,
    vae_chkpt_path="/home/bias-team/Mo_Projects/DiffuseVAE/logs/vae_carla_day/checkpoints/vae-carla_day-epoch=1499-train_loss=0.0000.ckpt",
    device="gpu:0",
    dataset_name="carla",
    image_size=128,
    save_path=None,
    latent="mu_aux",
):
    """Encode every image in ``root/orig_img_id`` with the split VAE and save the latents.

    Args:
        orig_img_id: Sub-folder of ``root`` holding the images to encode. It is also
            used as the label of every returned row and in the output filename.
        root: Parent directory containing one sub-folder per image id (required).
        vae_chkpt_path: Checkpoint passed to ``VAE.load_from_checkpoint``.
        device: Torch device string. A ``"gpu"`` prefix is rewritten to ``"cuda"``,
            so the default ``"gpu:0"`` becomes ``"cuda:0"``.
        dataset_name: Forwarded to ``get_dataset``.
        image_size: Forwarded to ``get_dataset``, and as ``input_res`` to the VAE.
        save_path: Directory for ``{orig_img_id}_split.npy``. ``None`` (the default)
            saves inside ``root/orig_img_id``, which is where this function has always
            saved. Previously this argument was accepted but ignored.
        latent: Which per-image feature to keep:
            ``"mu_aux"`` (default) is the first output of ``vae.encode_aux`` on
            ``vae.scramble(batch)``;
            ``"mu"`` is the first output of ``vae.encode(batch)``;
            ``"concat"`` is both of the above concatenated along dim 1.

    Returns:
        ``(z_arr, image_labels)``:
            ``z_arr`` is ``np.array`` of the per-batch latents, one entry per loader
            batch of size 1.
            ``image_labels`` is ``[orig_img_id] * len(z_arr)``.

    Raises:
        ValueError: If ``root`` is None or ``latent`` is not a supported name.
    """
    if root is None:
        raise ValueError("extract() requires `root` (the directory holding the image-id folders)")
    if latent not in ("mu_aux", "mu", "concat"):
        raise ValueError(f"unknown latent {latent!r}; expected 'mu_aux', 'mu' or 'concat'")

    dev = _resolve_device(device)
    img_dir = os.path.join(root, orig_img_id)

    # Dataset / loader: batch size 1, no shuffling, so rows follow dataset order.
    dataset = get_dataset(dataset_name, img_dir, image_size, norm=False, flip=False)
    loader = DataLoader(
        dataset,
        1,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        drop_last=False,
    )

    vae = VAE.load_from_checkpoint(vae_chkpt_path, input_res=image_size).to(dev)
    vae.eval()

    z_list = []
    for batch in tqdm(loader):
        batch = batch.to(dev)
        with torch.no_grad():
            z = _encode_batch(vae, batch, latent)
        # Not transferring to CPU leads to memory overflow in GPU!
        z_list.append(z.cpu().numpy())

    z_arr = np.array(z_list)
    image_labels = len(z_list) * [orig_img_id]
    out_dir = img_dir if save_path is None else save_path
    # NOTE(review): by default the .npy lands inside the image folder itself; confirm
    # get_dataset ignores non-image files there on re-runs.
    np.save(os.path.join(out_dir, f"{orig_img_id}_split.npy"), z_arr)
    return z_arr, image_labels

def get_latent(data_path, vae_chkpt_path):
    """Encode every image-id folder under ``data_path``; return stacked latents and labels.

    Each sub-directory of ``data_path`` is one class. Its class index is its
    position in the *sorted* list of sub-directory names. ``os.listdir`` order is
    arbitrary and may differ between directories. Sorting therefore makes a train
    and a test directory that contain the same image ids produce identical integer
    labels for those ids.

    Args:
        data_path: Directory whose sub-directories are image ids.
        vae_chkpt_path: Checkpoint forwarded to ``extract``.

    Returns:
        ``(z_total, labels)``:
            ``z_total`` is a 2-D array with one row per encoded image.
            ``labels`` is the integer class-index array from ``get_class_map``,
            aligned with the rows of ``z_total``.

    Raises:
        ValueError: If ``data_path`` has no sub-directories, or no images were encoded.
    """
    img_id_list = sorted(
        name for name in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, name))
    )
    if not img_id_list:
        raise ValueError(f"no image-id sub-directories found in {data_path!r}")

    z_parts = []
    image_labels_total = []
    for img_id in img_id_list:
        z, image_labels = extract(orig_img_id=img_id, root=data_path, vae_chkpt_path=vae_chkpt_path)
        if len(image_labels) == 0:
            continue  # empty folder: contributes no rows
        # Flatten everything after the sample axis, e.g. (N, 1, D) -> (N, D).
        # Unlike squeeze + expand_dims, this stays 2-D when N == 1 or D == 1.
        z_parts.append(np.asarray(z).reshape(len(image_labels), -1))
        image_labels_total += image_labels

    if not z_parts:
        raise ValueError(f"no images were encoded under {data_path!r}")

    # Concatenate once instead of growing the array inside the loop.
    z_total = np.concatenate(z_parts, axis=0)
    labels, _colors = get_class_map(image_labels_total, img_id_list)
    return z_total, labels


def get_class_map(image_labels_total, img_id_list):
    """Translate image-id labels into integer class indices.

    Every id in ``img_id_list`` is assigned its position in that list; if an id is
    repeated, its last position wins. Each entry of ``image_labels_total`` is then
    looked up in this mapping.

    Returns:
        ``(labels, colors)``:
            ``labels`` is a numpy array of class indices, one per entry of
            ``image_labels_total``.
            ``colors`` is ``[0, 1, ..., len(img_id_list) - 1]``.

    Raises:
        KeyError: If a label in ``image_labels_total`` does not appear in ``img_id_list``.
    """
    # Materialise once so img_id_list may be any iterable, not just a list.
    indexed_ids = list(enumerate(img_id_list))
    class_of = {img_id: idx for idx, img_id in indexed_ids}
    colors = [idx for idx, _ in indexed_ids]
    labels = np.array([class_of[name] for name in image_labels_total])
    return labels, colors



        

In [5]:
from sklearn.svm import LinearSVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import datasets
from sklearn import svm
from sklearn.metrics import accuracy_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.linear_model import LogisticRegression
## CARLA
# For each test set under `test_dir` and each of 10 VAE seeds:
#   - fit a 1-NN classifier (after standard scaling) on latents from the train folder;
#   - score it on latents from that test set.
CARLA_OOD_DIR = "/home/bias-team/Mo_Projects/DiffuseVAE/ood_experiment/carla2_ood"
CKPT_TEMPLATE = (
    "/home/bias-team/Mo_Projects/DiffuseVAE/logs/vae_carla_day{seed}/checkpoints/"
    "vae-carla_day-epoch=1499-train_loss=0.0000.ckpt"
)
SEEDS = range(1, 11)
train_dir = os.path.join(CARLA_OOD_DIR, "train")
# test_dir = os.path.join(CARLA_OOD_DIR, "tests")
test_dir = os.path.join(CARLA_OOD_DIR, "tests2")

# Sorted, directory-only listing so runs are deterministic and stray files are skipped.
test_name_list = sorted(
    name for name in os.listdir(test_dir)
    if os.path.isdir(os.path.join(test_dir, name))
)
train_score_dict = {}
test_score_dict = {}
# Train latents depend only on the seed, so encode the train folder once per seed
# and reuse the result for every test set.
train_latents_by_seed = {}
for test_name in test_name_list:
    train_score_list = []
    test_score_list = []
    test_path = os.path.join(test_dir, test_name)
    for seed in SEEDS:
        vae_ckpt_path = CKPT_TEMPLATE.format(seed=seed)
        if seed not in train_latents_by_seed:
            train_latents_by_seed[seed] = get_latent(train_dir, vae_ckpt_path)
        x_train, y_train = train_latents_by_seed[seed]
        # NOTE: integer labels are assigned per directory by get_latent/get_class_map.
        # They are comparable across train/test only if both directories enumerate
        # the same image ids in the same order.
        x_test, y_test = get_latent(test_path, vae_ckpt_path)

        clf = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=1))
        clf.fit(x_train, y_train)
        # A 1-NN scored on its own training points is ~1.0 by construction (each
        # point is its own nearest neighbour); the test score is the informative one.
        train_score = accuracy_score(y_train, clf.predict(x_train))
        test_score = accuracy_score(y_test, clf.predict(x_test))
        train_score_list.append(train_score)
        test_score_list.append(test_score)
        print(f"{test_name} seed {seed}")
        print("train:", train_score)
        print("test:", test_score)
    train_score_dict[test_name] = train_score_list
    test_score_dict[test_name] = test_score_list


1it [00:00, 16.39it/s]
1it [00:00, 17.95it/s]
1it [00:00, 18.10it/s]
1it [00:00, 12.54it/s]
1it [00:00, 16.31it/s]
1it [00:00, 16.85it/s]
1it [00:00, 13.14it/s]
1it [00:00, 13.07it/s]
1it [00:00, 14.13it/s]
1it [00:00, 17.41it/s]
10it [00:00, 59.93it/s]
10it [00:00, 70.83it/s]
10it [00:00, 64.56it/s]
10it [00:00, 64.83it/s]
10it [00:00, 58.52it/s]
10it [00:00, 61.79it/s]
10it [00:00, 64.27it/s]
10it [00:00, 67.45it/s]
10it [00:00, 65.26it/s]
10it [00:00, 65.72it/s]


test_flare2 seed 1
train: 1.0
test: 0.08


1it [00:00, 17.53it/s]
1it [00:00, 16.29it/s]
1it [00:00, 16.88it/s]
1it [00:00, 17.35it/s]
1it [00:00, 12.67it/s]
1it [00:00, 18.54it/s]
1it [00:00, 13.06it/s]
1it [00:00, 16.58it/s]
1it [00:00, 17.42it/s]
1it [00:00, 16.43it/s]
10it [00:00, 70.04it/s]
10it [00:00, 59.50it/s]
10it [00:00, 64.62it/s]
10it [00:00, 62.24it/s]
10it [00:00, 64.50it/s]
10it [00:00, 58.87it/s]
10it [00:00, 61.17it/s]
10it [00:00, 59.02it/s]
10it [00:00, 60.08it/s]
10it [00:00, 62.31it/s]


test_flare2 seed 2
train: 1.0
test: 0.18


1it [00:00, 17.53it/s]
1it [00:00, 14.72it/s]
1it [00:00, 16.37it/s]
1it [00:00, 15.28it/s]
1it [00:00, 13.34it/s]
1it [00:00, 15.34it/s]
1it [00:00, 13.68it/s]
1it [00:00, 13.31it/s]
1it [00:00, 17.46it/s]
1it [00:00, 17.45it/s]
10it [00:00, 60.36it/s]
10it [00:00, 61.11it/s]
10it [00:00, 63.58it/s]
10it [00:00, 65.09it/s]
10it [00:00, 67.30it/s]
10it [00:00, 66.33it/s]
10it [00:00, 66.68it/s]
10it [00:00, 70.16it/s]
10it [00:00, 69.45it/s]
10it [00:00, 61.08it/s]


test_flare2 seed 3
train: 1.0
test: 0.14


1it [00:00, 12.86it/s]
1it [00:00, 16.21it/s]
1it [00:00, 13.24it/s]
1it [00:00, 14.66it/s]
1it [00:00, 16.71it/s]
1it [00:00, 15.73it/s]
1it [00:00, 15.57it/s]
1it [00:00, 15.32it/s]
1it [00:00, 13.64it/s]
1it [00:00, 12.30it/s]
10it [00:00, 56.89it/s]
10it [00:00, 55.14it/s]
10it [00:00, 50.49it/s]
10it [00:00, 59.06it/s]
10it [00:00, 56.31it/s]
10it [00:00, 56.94it/s]
10it [00:00, 59.36it/s]
10it [00:00, 61.47it/s]
10it [00:00, 64.01it/s]
10it [00:00, 54.82it/s]


test_flare2 seed 4
train: 1.0
test: 0.11


1it [00:00, 11.53it/s]
1it [00:00, 15.46it/s]
1it [00:00, 16.58it/s]
1it [00:00, 16.08it/s]
1it [00:00, 16.39it/s]
1it [00:00, 14.89it/s]
1it [00:00, 14.87it/s]
1it [00:00, 15.06it/s]
1it [00:00, 15.65it/s]
1it [00:00, 12.98it/s]
10it [00:00, 66.13it/s]
10it [00:00, 63.34it/s]
10it [00:00, 65.75it/s]
10it [00:00, 62.85it/s]
10it [00:00, 60.62it/s]
10it [00:00, 56.53it/s]
10it [00:00, 65.65it/s]
10it [00:00, 60.25it/s]
10it [00:00, 50.70it/s]
10it [00:00, 50.90it/s]


test_flare2 seed 5
train: 1.0
test: 0.1


1it [00:00, 15.24it/s]
1it [00:00, 14.46it/s]
1it [00:00, 14.35it/s]
1it [00:00, 15.86it/s]
1it [00:00, 14.58it/s]
1it [00:00, 14.71it/s]
1it [00:00, 16.62it/s]
1it [00:00, 14.82it/s]
1it [00:00, 16.46it/s]
1it [00:00, 14.30it/s]
10it [00:00, 61.55it/s]
10it [00:00, 54.42it/s]
10it [00:00, 57.85it/s]
10it [00:00, 57.88it/s]
10it [00:00, 58.29it/s]
10it [00:00, 58.06it/s]
10it [00:00, 52.86it/s]
10it [00:00, 56.00it/s]
10it [00:00, 63.21it/s]
10it [00:00, 54.07it/s]


test_flare2 seed 6
train: 1.0
test: 0.26


1it [00:00, 12.72it/s]
1it [00:00, 15.34it/s]
1it [00:00, 15.92it/s]
1it [00:00, 13.85it/s]
1it [00:00, 13.91it/s]
1it [00:00, 13.55it/s]
1it [00:00, 11.68it/s]
1it [00:00, 14.18it/s]
1it [00:00, 16.91it/s]
1it [00:00, 14.05it/s]
10it [00:00, 68.33it/s]
10it [00:00, 63.33it/s]
10it [00:00, 52.70it/s]
10it [00:00, 64.02it/s]
10it [00:00, 50.73it/s]
10it [00:00, 60.21it/s]
10it [00:00, 49.94it/s]
10it [00:00, 55.33it/s]
10it [00:00, 58.14it/s]
10it [00:00, 57.85it/s]


test_flare2 seed 7
train: 1.0
test: 0.1


1it [00:00, 13.89it/s]
1it [00:00, 15.34it/s]
1it [00:00, 13.94it/s]
1it [00:00, 15.31it/s]
1it [00:00, 16.73it/s]
1it [00:00, 10.74it/s]
1it [00:00, 15.45it/s]
1it [00:00, 12.47it/s]
1it [00:00, 15.83it/s]
1it [00:00, 13.77it/s]
10it [00:00, 62.24it/s]
10it [00:00, 57.84it/s]
10it [00:00, 54.97it/s]
10it [00:00, 65.38it/s]
10it [00:00, 59.43it/s]
10it [00:00, 54.72it/s]
10it [00:00, 55.50it/s]
10it [00:00, 54.94it/s]
10it [00:00, 57.67it/s]
10it [00:00, 53.32it/s]


test_flare2 seed 8
train: 1.0
test: 0.18


1it [00:00, 13.84it/s]
1it [00:00, 15.86it/s]
1it [00:00, 18.20it/s]
1it [00:00, 16.59it/s]
1it [00:00, 15.10it/s]
1it [00:00, 18.95it/s]
1it [00:00, 18.55it/s]
1it [00:00, 18.44it/s]
1it [00:00, 17.79it/s]
1it [00:00, 15.74it/s]
10it [00:00, 65.96it/s]
10it [00:00, 65.22it/s]
10it [00:00, 66.50it/s]
10it [00:00, 70.78it/s]
10it [00:00, 69.35it/s]
10it [00:00, 57.54it/s]
10it [00:00, 62.57it/s]
10it [00:00, 71.48it/s]
10it [00:00, 68.24it/s]
10it [00:00, 64.01it/s]


test_flare2 seed 9
train: 1.0
test: 0.15


1it [00:00, 16.82it/s]
1it [00:00, 16.79it/s]
1it [00:00, 15.99it/s]
1it [00:00, 16.38it/s]
1it [00:00, 17.06it/s]
1it [00:00, 18.03it/s]
1it [00:00, 18.23it/s]
1it [00:00, 16.98it/s]
1it [00:00, 18.80it/s]
1it [00:00, 16.26it/s]
10it [00:00, 69.66it/s]
10it [00:00, 68.22it/s]
10it [00:00, 69.40it/s]
10it [00:00, 71.72it/s]
10it [00:00, 69.87it/s]
10it [00:00, 69.95it/s]
10it [00:00, 71.63it/s]
10it [00:00, 68.93it/s]
10it [00:00, 72.23it/s]
10it [00:00, 73.95it/s]

test_flare2 seed 10
train: 1.0
test: 0.1





In [7]:
# Report mean +- std of test accuracy across seeds for every evaluated test set.
# np.std uses the population std (ddof=0), as in earlier versions of this cell.
# Keys look like "test_<name>" and print as "<Name>", e.g. "test_flare2" -> "Flare2".
for test_name, scores in test_score_dict.items():
    display_name = test_name[len("test_"):] if test_name.startswith("test_") else test_name
    acc, sd = np.mean(scores), np.std(scores)
    print(f"{display_name.capitalize()}: {acc} +- {sd}")

Flare2: 0.13999999999999999 +- 0.05196152422706632
