In [1]:
from sklearn.decomposition import PCA
import cv2
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import os
from pathlib import Path
from tqdm import tqdm
from collections import defaultdict
from copy import deepcopy
import argparse
import plotly.express as px
import random

random.seed(42)

from PIL import Image
import torch
from torchvision.models import resnet18
from torchvision import transforms

from sklearn.cluster import DBSCAN, AgglomerativeClustering
from sklearn.metrics import accuracy_score


def get_classes(data_path):
    class_to_idx = {}
    for index, folder_name in enumerate(os.listdir(data_path)):
        class_to_idx[folder_name] = index
    return class_to_idx


def liniarize_images(dataset):
    for data in dataset:
        data['image'] = data['image'].flatten()
    return dataset


def extract_images(dataset):
    return [data['image'] for data in dataset]


def read_featured_data(data_path, model, transformer, device):
    all_data = []
    model.eval()
    for class_name in CLASS_NAMES.keys():
        folder_path = data_path / class_name
        for index, img_name in tqdm(enumerate(os.listdir(folder_path))):
            full_image_path = folder_path / img_name
            full_image = Image.open(full_image_path)
            processed_image = transformer(full_image).unsqueeze(0).to(device)
            features = model(processed_image).flatten()
            all_data.append({
                "label": CLASS_NAMES[class_name],
                "image": features.detach().cpu().numpy(),
                "path": str(full_image_path)
            })
    return all_data


def class_matcher(search_data, predictions, class_names):
    results = {}
    all_classes_counters = {}
    for class_name in class_names.keys():
        all_classes_counters[class_name] = dict(class_correct_counter(
            class_names[class_name], search_data, predictions))
    all_classes_counters_copy = deepcopy(all_classes_counters)
    while len(all_classes_counters) > 0:
        pair_found = get_pair_key_of_max(all_classes_counters)
        results[pair_found[0]] = pair_found[1]

    return results, all_classes_counters_copy


def get_pair_key_of_max(cls_counters):
    max_found = -1
    for class_name in cls_counters:
        for pred, freq in cls_counters[class_name].items():
            if max_found < freq:
                max_found = freq
                pair_keys = (class_name, pred)
    # delete the max found
    del cls_counters[pair_keys[0]]

    classes_to_delete = []
    for class_name in cls_counters:
        if pair_keys[1] in cls_counters[class_name]:
            del cls_counters[class_name][pair_keys[1]]
            if len(cls_counters[class_name]) == 0:
                classes_to_delete.append(class_name)
    for key in classes_to_delete:
        del cls_counters[key]
    return pair_keys


def class_correct_counter(class_verified, dataset, predictions):
    classes_counter = defaultdict(lambda: 0)
    for index in range(len(dataset)):
        if predictions[index] != -1 and class_verified == dataset[index]['label']:
            classes_counter[predictions[index]] += 1
    return classes_counter


def filter_predictions(predictions, actual_classes, class_to_id):
    filtered_preds = []
    pred_to_label = get_pred_to_label_dict(actual_classes, class_to_id)
    print(pred_to_label)
    for pred in predictions:
        if pred in class_to_id.values():
            filtered_preds.append(pred_to_label[pred])
        else:
            filtered_preds.append(-1)
    return filtered_preds


def get_pred_to_label_dict(actual_classes, class_to_id):
    pred_to_label = {}
    for class_name in actual_classes:
        if class_name in class_to_id.keys():
            pred_to_label[class_to_id[class_name]] = actual_classes[class_name]
    return pred_to_label


def calculate_accuracy(dataset, predictions, actual_classes, class_to_id):
    labels = [data['label'] for data in dataset]
    final_preds = filter_predictions(predictions, actual_classes, class_to_id)
    return accuracy_score(labels, final_preds)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--plot', action='store_true', help='Whether to plot stuffs')
    parser.add_argument('--verbose', action='store_true', help='Whether to print extra info')
    parser.add_argument('--type', type=str, default=DBSCAN, choices=['DBSCAN'], help='Type of model to train')
    args = parser.parse_args()
    return args


TRAIN_DATA_PATH = Path('./afhq/train')
VAL_DATA_PATH = Path('./afhq/val')
CLASS_NAMES = get_classes(TRAIN_DATA_PATH)

In [2]:

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html
base_model = resnet18(pretrained=True)
RESNET_EXTRACTOR = torch.nn.Sequential(
    *list(base_model.children())[:-1]).to(DEVICE)

IMAGE_EXTRACTOR = transforms.Compose([
    transforms.Resize(
        256, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(224),
    transforms.ToTensor(),  # converts to [0, 1]
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )])


In [3]:
# read the data
train_data = read_featured_data(
    TRAIN_DATA_PATH, RESNET_EXTRACTOR, IMAGE_EXTRACTOR, DEVICE)
val_data = read_featured_data(
    VAL_DATA_PATH, RESNET_EXTRACTOR, IMAGE_EXTRACTOR, DEVICE)

# delete the resnet model
del RESNET_EXTRACTOR
torch.cuda.empty_cache()

# extract images
X = extract_images(train_data)
X_val = extract_images(val_data)


4738it [00:45, 105.21it/s]
4739it [00:43, 109.70it/s]
5153it [00:48, 107.00it/s]
500it [00:04, 115.92it/s]
500it [00:04, 110.44it/s]
500it [00:04, 101.44it/s]


In [10]:
pca = PCA(n_components=2)
points = pca.fit_transform(X)
points_val = pca.transform(X_val)
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()

In [4]:
train_data_copy = random.sample(train_data, len(val_data))
X = extract_images(train_data_copy)
X_val = extract_images(val_data)

In [5]:
pca = PCA(n_components=2)
points = pca.fit_transform(X)
points_val = pca.transform(X_val)
fig = px.scatter(points, x=0, y=1)
fig.show()
fig = px.scatter(points_val, x=0, y=1)
fig.show()

In [8]:
def predict_dbscan(model, train, val, class_names):
    X = extract_images(train)
    X_val = extract_images(val)
    clustering = model.fit(X)
    train_preds = clustering.labels_
    val_preds = clustering.fit_predict(X_val)

    # get class matchers given predictions and labels
    class_to_preds, all_classes = class_matcher(train, train_preds, class_names)

    # if args.verbose:
    print("All classes:")
    print(all_classes)
    print("Matched classes:")
    print(class_to_preds)

    print("Predictions stats on train:")
    print(f'Different than -1: {np.count_nonzero(clustering.labels_ == -1)}')
    print(f'Max predicted class: {max(clustering.labels_)}')
    print(f'Length of x: {len(X)}')
    train_acc = calculate_accuracy(train, train_preds, class_names, class_to_preds)
    val_acc = calculate_accuracy(val, val_preds, class_names, class_to_preds)

    print(f'Final accuracy on train dataset:    {train_acc}')
    print(f'Final accuracy on val dataset:      {val_acc}')
    return model, train_acc, val_acc


In [5]:
train_data = train_data_copy


In [18]:
model = AgglomerativeClustering(n_clusters=3, linkage='ward')
predict_dbscan(model, train_data, val_data, CLASS_NAMES)

All classes:
{'wild': {0: 495, 2: 1, 1: 2}, 'dog': {2: 425, 0: 52}, 'cat': {1: 525}}
Matched classes:
{'cat': 1, 'wild': 0, 'dog': 2}
Predictions stats on train:
Different than -1: 0
Max predicted class: 2
Length of x: 1500
{0: 0, 2: 1, 1: 2}
{0: 0, 2: 1, 1: 2}
Final accuracy on train dataset:    0.9633333333333334
Final accuracy on val dataset:      0.39866666666666667


(AgglomerativeClustering(n_clusters=3),
 0.9633333333333334,
 0.39866666666666667)

In [36]:
model = AgglomerativeClustering(n_clusters=None, distance_threshold=300, compute_full_tree=True)
predict_dbscan(model, train_data, val_data, CLASS_NAMES)

All classes:
{'wild': {0: 496, 1: 2}, 'dog': {0: 477}, 'cat': {1: 525}}
Matched classes:
{'cat': 1, 'wild': 0}
Predictions stats on train:
Different than -1: 0
Max predicted class: 1
Length of x: 1500
{0: 0, 1: 2}
{0: 0, 1: 2}
Final accuracy on train dataset:    0.6806666666666666
Final accuracy on val dataset:      0.666


(AgglomerativeClustering(compute_full_tree=True, distance_threshold=300,
                         n_clusters=None),
 0.6806666666666666,
 0.666)

In [38]:
len(np.unique(model.labels_)) > 10
print(list(range(1, 201, 5)))

[1, 6, 11, 16, 21, 26, 31, 36, 41, 46, 51, 56, 61, 66, 71, 76, 81, 86, 91, 96, 101, 106, 111, 116, 121, 126, 131, 136, 141, 146, 151, 156, 161, 166, 171, 176, 181, 186, 191, 196]
