In [2]:

import csv
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import os
from pathlib import Path
import cv2

import torch
from torchvision.models import resnet18

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.cluster import DBSCAN
from sklearn.metrics import accuracy_score

# Presumably the random seed for reproducible runs; not referenced in the cells visible here.
SEED = 42
# Split folders; read_all_data looks for one sub-folder per class name inside each.
TRAIN_DATA_PATH = Path('./afhq/train')
VAL_DATA_PATH = Path('./afhq/val')
# Target size handed to cv2.resize as `dsize` in read_all_data.
IMG_SHAPE = [64, 64]
# Presumably the number of training epochs; not referenced in the cells visible here.
EPOCHS = 50

def get_classes(data_path):
    """Map every class sub-folder of ``data_path`` to an integer index.

    Args:
        data_path: Directory (``str`` or ``pathlib.Path``) that contains one
            sub-folder per class.

    Returns:
        dict: ``{folder_name: index}`` where indices are assigned in sorted
        folder-name order. Empty if ``data_path`` has no sub-folders.
    """
    # os.listdir returns entries in arbitrary order, so sort to keep the
    # name -> index mapping stable across runs and machines. Plain files
    # (e.g. .DS_Store) are skipped: only folders can be classes, and a file
    # entry would make the later per-class os.listdir call fail.
    folder_names = sorted(
        entry for entry in os.listdir(data_path)
        if os.path.isdir(os.path.join(data_path, entry))
    )
    return {folder_name: index for index, folder_name in enumerate(folder_names)}


def liniarize_images(dataset):
    """Flatten the ``'image'`` array of every record to 1-D, in place.

    Each record's ``'image'`` entry is replaced by the result of calling
    ``.flatten()`` on it. The records are mutated and the same ``dataset``
    object that was passed in is returned.
    """
    for record in dataset:
        flat_pixels = record['image'].flatten()
        record['image'] = flat_pixels
    return dataset

def extract_images(dataset):
    """Collect the ``'image'`` entry of each record into a new list, in order."""
    images = []
    for record in dataset:
        images.append(record['image'])
    return images

def read_all_data(data_path):
    """Load every image under ``data_path`` as a resized grayscale array.

    For each class name in the module-level ``CLASS_NAMES`` mapping, the
    files in ``data_path / class_name`` are read in grayscale, resized to
    ``IMG_SHAPE`` with area interpolation and divided by 255.

    Args:
        data_path: ``pathlib.Path`` of a split folder that holds one
            sub-folder per class name in ``CLASS_NAMES``.

    Returns:
        list[dict]: One record per readable image with keys ``"label"``
        (the class index from ``CLASS_NAMES``), ``"image"`` (the resized,
        rescaled array) and ``"path"`` (the file path as ``str``).
        Files that OpenCV cannot decode are skipped with a warning.
    """
    import warnings  # local import: only needed for the unreadable-file warning

    all_data = []
    for class_name, label in CLASS_NAMES.items():
        folder_path = data_path / class_name
        # os.listdir order is arbitrary; sort so record order is reproducible.
        for img_name in sorted(os.listdir(folder_path)):
            full_image_path = folder_path / img_name
            full_image = cv2.imread(str(full_image_path), cv2.IMREAD_GRAYSCALE)
            if full_image is None:
                # cv2.imread reports failure by returning None instead of
                # raising; without this guard cv2.resize fails with an
                # error that does not name the offending file.
                warnings.warn(f"Skipping unreadable image: {full_image_path}")
                continue
            # dsize is passed as a tuple, the form cv2.resize documents.
            reshaped_image = cv2.resize(
                full_image, tuple(IMG_SHAPE), interpolation=cv2.INTER_AREA
            ) / 255.
            all_data.append({
                "label": label,
                "image": reshaped_image,
                "path": str(full_image_path)
            })
    return all_data

# Class-folder name -> integer label, built from the training split only;
# read_all_data reads this global for whichever split it is given.
CLASS_NAMES = get_classes(TRAIN_DATA_PATH)



In [3]:
# Load both splits through the same read -> flatten pipeline (train first,
# then val), then pull the flattened training images out as the feature list.
train_data, val_data = (
    liniarize_images(read_all_data(split_path))
    for split_path in (TRAIN_DATA_PATH, VAL_DATA_PATH)
)
X = extract_images(train_data)

In [5]:
# Load an ImageNet-pretrained ResNet-18. `pretrained=True` is deprecated since
# torchvision 0.13; `weights="IMAGENET1K_V1"` selects the same checkpoint.
base_model = resnet18(weights="IMAGENET1K_V1")
# Keep every child module except the last one (the final fully connected
# layer in torchvision's ResNet), leaving a feature extractor.
base_model = torch.nn.Sequential(*list(base_model.children())[:-1])
print(base_model)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU(inplace=True)
  (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Con

In [6]:
# Debug introspection: dump every attribute name available on base_model.
# NOTE(review): looks like leftover exploration; consider removing before sharing.
print(dir(base_model))

['T_destination', '__add__', '__annotations__', '__call__', '__class__', '__delattr__', '__delitem__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__iadd__', '__imul__', '__init__', '__init_subclass__', '__iter__', '__le__', '__len__', '__lt__', '__module__', '__mul__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__rmul__', '__setattr__', '__setitem__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_buffers', '_call_impl', '_forward_hooks', '_forward_pre_hooks', '_get_backward_hooks', '_get_item_by_idx', '_get_name', '_is_full_backward_hook', '_load_from_state_dict', '_load_state_dict_post_hooks', '_load_state_dict_pre_hooks', '_maybe_warn_non_full_backward_hook', '_modules', '_named_members', '_non_persistent_buffers_set', '_parameters', '_register_load_state_dict_pre_hook', '_register_state_dict_hook', '_repl

In [36]:
# Density-based clustering of the flattened training images.
# NOTE(review): the recorded outputs show all 14630 labels equal to -1, i.e.
# every sample is treated as noise with eps=2 / min_samples=5 -- the
# neighbourhood radius looks too small for these features; TODO tune.
clustering = DBSCAN(eps=2, min_samples=5)
clustering.fit(X)
clustering.labels_



array([-1, -1, -1, ..., -1, -1, -1])

In [37]:
# Number of samples DBSCAN labelled as noise (-1) ...
print(np.count_nonzero(clustering.labels_ == -1))
# A bare `print` is a no-op in Python 3; call it to emit the intended blank line.
print()
# ... compared with the total number of samples.
print(len(X))

14630
14630


In [3]:
# Largest cluster label found; DBSCAN labels are 0-based with -1 for noise,
# so this is the number of clusters minus one.
clustering.labels_.max()

def match_classes(dataset, predictions):
    



3