In [1]:
import matplotlib.pyplot as plt
import pandas as pd
import torch
from image_classification_simulation.data.office31_loader import Office31Loader
from image_classification_simulation.models.clustering_tools import show_grid_images
from image_classification_simulation.image_search import ImageSimilaritySearch
from image_classification_simulation.utils.visualization_utils import show_grid_images
from image_classification_simulation.models.clustering_tools import get_clustering_metrics

In [2]:
# Data-loading hyper-parameters for Office31Loader.
hparams = {
    "num_workers": 2,
    "batch_size": 32,
    "image_size": 224,
    # NOTE(review): presumably -1 disables the train/val split — confirm in Office31Loader.
    "train_test_split": -1,
}
# Training images come from the Office-31 "amazon" domain and evaluation images
# from the "dslr" domain; naming the roots once avoids repeating the literals.
data_dir = "../examples/data/domain_adaptation_images/amazon/images/"
eval_dir = "../examples/data/domain_adaptation_images/dslr/images/"
office_loader = Office31Loader(
    data_dir=data_dir,
    eval_dir=eval_dir,
    hyper_params=hparams,
)
office_loader.setup('eval')
eval_loader = office_loader.eval_dataloader()
office_loader.setup('infer')
train_loader = office_loader.train_dataloader()

In [8]:

# Root directory of the checkpoints trained on the cluster. Change this one
# constant to point the notebook at a different machine or checkout.
MODEL_ROOT = "/network/projects/aia/img_classif_sim"
# NOTE(review): every configuration shares this path. If ImageSimilaritySearch
# writes cluster ids there, running several architectures will overwrite each
# other's file — confirm and consider a per-architecture path.
CLUSTER_IDS_PATH = "../debug/dataset_cluster_ids.csv"


def _checkpoint_path(model_dir):
    """Return the best-model checkpoint path for ``model_dir`` under MODEL_ROOT."""
    return "{}/{}/output/best_model/model.ckpt".format(MODEL_ROOT, model_dir)


hparams_resnet = {
    "clustering_alg": "nn",
    "num_neighbors": 20,
    "radius": 0.5,
    "n_jobs": 2,
    "loss": "CrossEntropyLoss",
    "batch_size": 100,
    "pretrained": True,
    "num_classes": 31,
    # The ResNet checkpoint lives in the repository, not under MODEL_ROOT.
    "path_to_model": "../examples/resnet/output/best_model/model.ckpt",
    "architecture": "resnet",
    "num_clusters": 31,
    "random_state": 0,
    "clustering_batch_size": 100,
    "size": 256,
    "reassignment_ratio": 0.05,
    "path_cluster_ids": CLUSTER_IDS_PATH,
}
hparams_vit = {
    "clustering_alg": "MiniBatchKMeans",
    "loss": "CrossEntropyLoss",
    "pretrained": True,
    "batch_size": 100,
    "num_classes": 31,
    "path_to_model": _checkpoint_path("vit"),
    "architecture": "vit",
    "num_clusters": 100,
    "random_state": 0,
    "clustering_batch_size": 1024,
    "reassignment_ratio": 0.01,
    "init": 'random',
    "path_cluster_ids": CLUSTER_IDS_PATH,
}
hparams_ae = {
    "clustering_alg": "MiniBatchKMeans",
    "loss": "CrossEntropyLoss",
    "pretrained": True,
    "batch_size": 100,
    "num_channels": 3,
    "num_classes": 31,
    "path_to_model": _checkpoint_path("conv_ae"),
    "architecture": "conv_ae",
    "num_clusters": 32,
    "random_state": 0,
    "clustering_batch_size": 100,
    "reassignment_ratio": 0.05,
    "path_cluster_ids": CLUSTER_IDS_PATH,
}
hparams_cnn = {
    "clustering_alg": "nn",
    "num_neighbors": 20,
    "radius": 0.5,
    "n_jobs": 2,
    "loss": "CrossEntropyLoss",
    "batch_size": 124,
    "num_channels": 3,
    "pretrained": True,
    "num_classes": 31,
    "img_size": 224,
    "path_to_model": _checkpoint_path("classic_cnn"),
    "architecture": "classic-cnn",
    "num_clusters": 31,
    "random_state": 0,
    "clustering_batch_size": 124,
    "reassignment_ratio": 0.01,
    "path_cluster_ids": CLUSTER_IDS_PATH,
}

# Architectures the search loop iterates over; uncomment entries to add more.
archs = {
    "resnet": hparams_resnet,
    # "vit": hparams_vit,
    # "ae": hparams_ae,
    # "cnn": hparams_cnn,
}


In [6]:
# Ground-truth label of every sample in the evaluation set (each sample is an
# (image, label) pair; only the label is kept).
labels_true = [target for _image, target in office_loader.eval_set]
len(labels_true)


498

In [None]:
# Build the similarity-search object here instead of relying on `image_search`
# leaking from the per-architecture loop further down, which has not run yet
# on a fresh kernel (Restart & Run All would otherwise raise NameError).
# The first architecture enabled in `archs` is the one evaluated.
metrics_arch, metrics_hparams = next(iter(archs.items()))
image_search = ImageSimilaritySearch(metrics_hparams, office_loader)
image_search.setup()
labels_pred = image_search.predict(office_loader.eval_dataloader())
len(labels_pred)

Note: the clustering metrics below are not meaningful when the search uses nearest neighbors (`"clustering_alg": "nn"`).

In [7]:
# Compare the ground-truth labels with the predicted assignments.
# NOTE(review): in the recorded outputs rand_score and adjusted_rand_score are
# identical, which suggests get_clustering_metrics reports the same score under
# both keys — worth checking in clustering_tools.
clustering_metrics = get_clustering_metrics(labels_true, labels_pred)
m = clustering_metrics  # short alias kept for cells that still refer to `m`
clustering_metrics

{'rand_score': 0.024261940161551716,
 'adjusted_rand_score': 0.024261940161551716,
 'mutual_info_score': 0.5971366373315987}

In [21]:
# NOTE(review): this cell only re-displays `m`. Its recorded output (In [21])
# differs from the In [7] result, so it reflects a different kernel run —
# re-run or remove this cell.
m

{'rand_score': 0.03455624029458484,
 'adjusted_rand_score': 0.03455624029458484,
 'mutual_info_score': 0.6542970012509423}

You can use query images from either the evaluation set or the training set.

In [8]:
# Training-set image root (Office-31 "amazon" domain); the same path that is
# passed to Office31Loader as data_dir. Query images may be taken from here
# instead of from eval_dir.
data_dir = "../examples/data/domain_adaptation_images/amazon/images/"

In [15]:
# Evaluation-set image root (Office-31 "dslr" domain); the search loop reads its
# query images from <eval_dir>/<class_name>/frame_0001.jpg.
eval_dir = '../examples/data/domain_adaptation_images/dslr/images/'

In [None]:
import os
import matplotlib.pyplot as plt

# Output root for the per-architecture grids of retrieved images.
RESULTS_DIR = "./results"
# The first frame of each class folder is used as the query image.
QUERY_FRAME = "frame_0001.jpg"

for arch, arch_hparams in archs.items():
    image_search = ImageSimilaritySearch(arch_hparams, office_loader)
    image_search.setup()
    # savefig does not create missing directories, so create the per-arch folder first.
    arch_out_dir = os.path.join(RESULTS_DIR, arch)
    os.makedirs(arch_out_dir, exist_ok=True)
    for class_name in office_loader.dataset.class_to_idx:
        print(class_name)
        query_path = os.path.join(eval_dir, class_name, QUERY_FRAME)
        query_res = image_search.find_similar_images(query_path, None)
        fig, _ = show_grid_images(
            query_res['image_path'].tolist(),
            num_rows=5,
            num_cols=5,
        )
        fig.savefig(os.path.join(arch_out_dir, class_name + '.png'), format='png')
        # One figure is produced per class; close it so figures (presumably
        # pyplot-managed) do not accumulate in memory across the loop.
        plt.close(fig)