In [1]:
# what kind of files do we have?
from glob import glob
from pathlib import Path
from collections import Counter

# Pattern matching every entry exactly two levels below the dataset's Chess/ folder.
GLOB = '/kaggle/input/chessman-image-dataset/Chessman-image-dataset/Chess/*/*'

# Tally the file extensions of everything the pattern matches, so that
# unexpected (non-image) files stand out in the printed counts.
matched_paths = glob(pathname=GLOB)
suffixes = Counter(map(lambda found: Path(found).suffix, matched_paths))

print(suffixes)

Counter({'.jpg': 465, '.png': 69, '.JPG': 10, '.jpeg': 7, '.fcgi': 2, '.php': 1, '.webp': 1, '.gif': 1})


Some of these files are not images (`.fcgi`, `.php`), so we need to be sure to ignore them; the single `.webp` file is left out as well.

Let's build a DataFrame of tags and file names.

In [2]:
import pandas as pd
from pathlib import Path

# File extensions that are kept; anything else matched by GLOB is skipped.
# The comparison is case-sensitive, which is why '.JPG' is listed explicitly.
SUFFIXES = {'.jpg', '.png', '.JPG', '.jpeg', '.gif'}

# One plain-dict record per kept file:
#   tag  - name of the directory that directly contains the file
#   name - the path string exactly as returned by glob
data = [
    {'tag': Path(input_file).parent.name, 'name': input_file}
    for input_file in glob(pathname=GLOB)
    if Path(input_file).suffix in SUFFIXES
]

# Naming the columns explicitly keeps `tag` and `name` present even when no
# file matched, so later cells fail with a clear empty frame rather than a KeyError.
df = pd.DataFrame(data=data, columns=['tag', 'name'])

We know already that we have unequal numbers of pictures of the different pieces; how unequal are they?

In [3]:
import warnings
from plotly import express

# Suppress FutureWarning messages for the rest of the session.
warnings.filterwarnings(action='ignore', category=FutureWarning)

# Build an explicit (tag, count) table. Naming both columns here, instead of
# relying on the names value_counts() happens to produce, keeps the
# `names=`/`values=` lookups below valid across pandas versions.
tag_counts = df['tag'].value_counts().rename_axis('tag').reset_index(name='count')

# Slices show the absolute count; hovering shows the label and percentage.
express.pie(data_frame=tag_counts, names='tag', values='count').update_traces(hoverinfo='label+percent', textinfo='value')

In [4]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image

# All tensors and the model live on the CPU.
device = torch.device('cpu')
# ResNet-34 with the IMAGENET1K_V1 pretrained weights, switched to evaluation mode.
model = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1).to(device=device)
model.eval()

# NOTE(review): `layer` and `layer_output_size` are never used below. The vectors
# stored in df['value'] are whatever `model(images)` returns (the network's final
# output), not the activations of the avgpool layer - confirm which was intended.
layer = model._modules.get('avgpool')

layer_output_size = 512

# Preprocessing pipeline: resize to 224x224, convert to a tensor, then normalise
# each channel with the given mean/std.
scaler = transforms.Resize(size=(224, 224))
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
to_tensor = transforms.ToTensor()


def load_rgb(name):
    """Open the image file `name`, return it converted to RGB, and close the file.

    The `with` block releases the underlying file handle as soon as the pixel
    data has been converted, instead of leaving it open until garbage collection.
    """
    with Image.open(fp=name, mode='r') as handle:
        return handle.convert('RGB')


# Preprocess every image listed in df['name'] and stack them into a single batch.
# Images are decoded on the fly, so no temporary column of PIL images is needed.
model_input = [normalize(to_tensor(scaler(load_rgb(name)))) for name in df['name'].tolist()]
images = torch.stack(model_input).to(device)

# Inference only: no_grad() stops autograd from recording the forward pass,
# which would otherwise keep intermediate activations alive for the whole batch.
with torch.no_grad():
    model_output = model(images)

# One numpy vector per image, in the same row order as df.
df['value'] = [item.detach().cpu().numpy() for item in model_output]

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth
100%|██████████| 83.3M/83.3M [00:00<00:00, 116MB/s]


In [5]:
import arrow
from umap import UMAP

# Wall-clock start, used only for the timing message printed at the end.
time_start = arrow.now()

# Seeded, single-job UMAP with progress logging switched on.
umap = UMAP(n_epochs=500, low_memory=False, n_jobs=1, verbose=True, random_state=2024)

# Expand the per-row vectors in df['value'] into one column per component,
# fit the embedding, and store its two coordinates as columns x and y.
feature_table = df['value'].apply(pd.Series)
embedding = umap.fit_transform(X=feature_table)
df[['x', 'y']] = embedding

elapsed = arrow.now() - time_start
print('done with UMAP in {}'.format(elapsed))

2024-07-29 15:32:59.498332: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-29 15:32:59.498451: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-29 15:32:59.643625: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


UMAP(low_memory=False, n_epochs=500, n_jobs=1, random_state=2024, verbose=True)
Mon Jul 29 15:33:10 2024 Construct fuzzy simplicial set
Mon Jul 29 15:33:11 2024 Finding Nearest Neighbors
Mon Jul 29 15:33:15 2024 Finished Nearest Neighbor Search
Mon Jul 29 15:33:19 2024 Construct embedding


Epochs completed:   0%|            0/500 [00:00]

	completed  0  /  500 epochs
	completed  50  /  500 epochs
	completed  100  /  500 epochs
	completed  150  /  500 epochs
	completed  200  /  500 epochs
	completed  250  /  500 epochs
	completed  300  /  500 epochs
	completed  350  /  500 epochs
	completed  400  /  500 epochs
	completed  450  /  500 epochs
Mon Jul 29 15:33:21 2024 Finished embedding
done with UMAP in 0:00:10.956037


In [6]:
import warnings
from plotly import express

# Suppress FutureWarning messages before plotting.
warnings.filterwarnings(category=FutureWarning, action='ignore')

# Scatter plot of the two embedding coordinates, one colour per tag.
scatter_options = {'x': 'x', 'y': 'y', 'color': 'tag'}
express.scatter(data_frame=df, **scatter_options)