In [7]:
from sklearn.decomposition import PCA
import cv2
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import os
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
import argparse
import plotly.express as px
import random
import cv2

random.seed(42)

from PIL import Image
import torch
from torchvision.models import resnet18
from torchvision import transforms

from sklearn.cluster import DBSCAN, AgglomerativeClustering
from sklearn.metrics import accuracy_score


def get_classes(data_path):
    class_to_idx = {}
    for index, folder_name in enumerate(os.listdir(data_path)):
        class_to_idx[folder_name] = index
    return class_to_idx


def liniarize_images(dataset):
    for data in dataset:
        data['image'] = data['image'].flatten()
    return dataset


def extract_images(dataset):
    return [data['image'] for data in dataset]


def read_featured_data(data_path, orb):
    all_data = []
    backup_orb = cv2.ORB_create(nfeatures=orb.getMaxFeatures(), fastThreshold=0, edgeThreshold=0)
    for class_name in CLASS_NAMES.keys():
        folder_path = data_path / class_name
        for index, img_name in tqdm(enumerate(os.listdir(folder_path))):
            full_image_path = folder_path / img_name
            full_image = cv2.imread(str(full_image_path), cv2.IMREAD_COLOR)
            _, features = orb.detectAndCompute(full_image, None)
            if features is None:
                _, features = backup_orb.detectAndCompute(full_image, None)
            features = features.flatten()
            features = np.pad(features, (0, orb.getMaxFeatures()*32 - features.shape[0]), 'constant', constant_values=(0, features[-1]))
            features = features / 255.
            all_data.append({
                "label": CLASS_NAMES[class_name],
                "image": features,
                "path": str(full_image_path)
            })
    return all_data


def class_matcher(search_data, predictions, class_names):
    results = {}
    all_classes_counters = {}
    for class_name in class_names.keys():
        all_classes_counters[class_name] = dict(class_correct_counter(
            class_names[class_name], search_data, predictions))
    all_classes_counters_copy = deepcopy(all_classes_counters)
    while len(all_classes_counters) > 0:
        delete_check = []
        for key in all_classes_counters.keys():
            if len(all_classes_counters[key]) == 0:
                delete_check.append(key)
        for key in delete_check:
            del all_classes_counters[key]
        if len(all_classes_counters) == 0:
            continue
        pair_found = get_pair_key_of_max(all_classes_counters)
        results[pair_found[0]] = pair_found[1]

    return results, all_classes_counters_copy


def get_pair_key_of_max(cls_counters):
    max_found = -1
    for class_name in cls_counters:
        for pred, freq in cls_counters[class_name].items():
            if max_found < freq:
                max_found = freq
                pair_keys = (class_name, pred)
    # delete the max found
    del cls_counters[pair_keys[0]]

    classes_to_delete = []
    for class_name in cls_counters:
        if pair_keys[1] in cls_counters[class_name]:
            del cls_counters[class_name][pair_keys[1]]
            if len(cls_counters[class_name]) == 0:
                classes_to_delete.append(class_name)
    for key in classes_to_delete:
        del cls_counters[key]
    return pair_keys


def class_correct_counter(class_verified, dataset, predictions):
    classes_counter = defaultdict(lambda: 0)
    for index in range(len(dataset)):
        if predictions[index] != -1 and class_verified == dataset[index]['label']:
            classes_counter[predictions[index]] += 1
    return classes_counter


def filter_predictions(predictions, actual_classes, class_to_id):
    filtered_preds = []
    pred_to_label = get_pred_to_label_dict(actual_classes, class_to_id)
    print(pred_to_label)
    for pred in predictions:
        if pred in class_to_id.values():
            filtered_preds.append(pred_to_label[pred])
        else:
            filtered_preds.append(-1)
    return filtered_preds


def get_pred_to_label_dict(actual_classes, class_to_id):
    pred_to_label = {}
    for class_name in actual_classes:
        if class_name in class_to_id.keys():
            pred_to_label[class_to_id[class_name]] = actual_classes[class_name]
    return pred_to_label


def calculate_accuracy(dataset, predictions, actual_classes, class_to_id):
    labels = [data['label'] for data in dataset]
    final_preds = filter_predictions(predictions, actual_classes, class_to_id)
    return accuracy_score(labels, final_preds)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--plot', action='store_true', help='Whether to plot stuffs')
    parser.add_argument('--verbose', action='store_true', help='Whether to print extra info')
    parser.add_argument('--type', type=str, default=DBSCAN, choices=['DBSCAN'], help='Type of model to train')
    args = parser.parse_args()
    return args


TRAIN_DATA_PATH = Path('./afhq/train')
VAL_DATA_PATH = Path('./afhq/val')
CLASS_NAMES = get_classes(TRAIN_DATA_PATH)

In [2]:
# read the data
orb = cv2.ORB_create(nfeatures=100)
train_data = read_featured_data(TRAIN_DATA_PATH, orb)
val_data = read_featured_data(VAL_DATA_PATH, orb)

# delete the resnet model

# extract images
X = extract_images(train_data)
X_val = extract_images(val_data)


4738it [00:37, 127.58it/s]
4739it [00:31, 150.96it/s]
5153it [00:36, 139.73it/s]
500it [00:03, 132.44it/s]
500it [00:03, 153.39it/s]
500it [00:03, 145.41it/s]


In [3]:
from sklearn.preprocessing import StandardScaler
train_data = random.sample(train_data, len(val_data))
X = extract_images(train_data)
X_val = extract_images(val_data)
scaler = StandardScaler()
scaler.fit(X)
X = scaler.transform(X)
X_val = scaler.transform(X_val)

In [4]:
for index, sample in enumerate(train_data):
    sample['image'] = X[index]
for index, sample in enumerate(val_data):
    sample['image'] = X_val[index]

In [114]:
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=3)
points = tsne.fit_transform(np.array(X))
points_val = tsne.fit_transform(np.array(X_val))
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()


In [None]:
from sklearn.manifold import TSNE
tsne = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=3)
points = tsne.fit_transform(np.array(X))
points_val = tsne.fit_transform(np.array(X_val))
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()


In [71]:
pca = PCA(n_components=2)
points = pca.fit_transform(X)
points_val = pca.transform(X_val)
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()

In [75]:
train_data_copy = random.sample(train_data, len(val_data))
X = extract_images(train_data_copy)
X_val = extract_images(val_data)

In [55]:
pca = PCA(n_components=2)
points = pca.fit_transform(X)
points_val = pca.transform(X_val)
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()

In [76]:
tsne = TSNE(n_components=2, learning_rate='auto',init='random', perplexity=3)
points = tsne.fit_transform(np.array(X))
points_val = tsne.fit_transform(np.array(X_val))
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()

In [5]:
def predict_dbscan(model, train, val, class_names):
    X = extract_images(train)
    X_val = extract_images(val)
    clustering = model.fit(X)
    train_preds = clustering.labels_
    val_preds = clustering.fit_predict(X_val)

    # get class matchers given predictions and labels
    class_to_preds, all_classes = class_matcher(train, train_preds, class_names)

    # if args.verbose:
    print("All classes:")
    print(all_classes)
    print("Matched classes:")
    print(class_to_preds)

    print("Predictions stats on train:")
    print(f'Different than -1: {np.count_nonzero(clustering.labels_ == -1)}')
    print(f'Max predicted class: {max(clustering.labels_)}')
    print(f'Length of x: {len(X)}')
    train_acc = calculate_accuracy(train, train_preds, class_names, class_to_preds)
    val_acc = calculate_accuracy(val, val_preds, class_names, class_to_preds)

    print(f'Final accuracy on train dataset:    {train_acc}')
    print(f'Final accuracy on val dataset:      {val_acc}')
    return model, train_acc, val_acc


In [1]:
dbscan = DBSCAN(eps=200.3, min_samples=2)
print(predict_dbscan(dbscan, train_data, val_data, CLASS_NAMES))

NameError: name 'DBSCAN' is not defined

In [89]:
model = AgglomerativeClustering(n_clusters=3, linkage='ward')
predict_dbscan(model, train_data, val_data, CLASS_NAMES)

All classes:
{'wild': {1: 295, 2: 66, 0: 137}, 'dog': {2: 167, 0: 228, 1: 82}, 'cat': {2: 171, 0: 196, 1: 158}}
Matched classes:
{'wild': 1, 'dog': 0, 'cat': 2}
Predictions stats on train:
Different than -1: 0
Max predicted class: 2
Length of x: 1500
{1: 0, 0: 1, 2: 2}
{1: 0, 0: 1, 2: 2}
Final accuracy on train dataset:    0.46266666666666667
Final accuracy on val dataset:      0.4826666666666667


(AgglomerativeClustering(n_clusters=3),
 0.46266666666666667,
 0.4826666666666667)

In [105]:
model = AgglomerativeClustering(n_clusters=None, distance_threshold=60, compute_full_tree=True)
predict_dbscan(model, train_data, val_data, CLASS_NAMES)

All classes:
{'wild': {1: 295, 2: 66, 0: 137}, 'dog': {2: 167, 0: 228, 1: 82}, 'cat': {2: 171, 0: 196, 1: 158}}
Matched classes:
{'wild': 1, 'dog': 0, 'cat': 2}
Predictions stats on train:
Different than -1: 0
Max predicted class: 2
Length of x: 1500
{1: 0, 0: 1, 2: 2}
{1: 0, 0: 1, 2: 2}
Final accuracy on train dataset:    0.46266666666666667
Final accuracy on val dataset:      0.4826666666666667


(AgglomerativeClustering(compute_full_tree=True, distance_threshold=60,
                         n_clusters=None),
 0.46266666666666667,
 0.4826666666666667)

In [38]:
len(np.unique(model.labels_)) > 10
print(list(range(1, 201, 5)))

[1, 6, 11, 16, 21, 26, 31, 36, 41, 46, 51, 56, 61, 66, 71, 76, 81, 86, 91, 96, 101, 106, 111, 116, 121, 126, 131, 136, 141, 146, 151, 156, 161, 166, 171, 176, 181, 186, 191, 196]
