We're going to build a scatter plot based on ResNeXt50 embeddings, so first let's add some code that will let us extract embeddings for images using ResNeXt.

In [1]:
import torch
import torchvision.models as models
import torchvision.transforms as transforms
import numpy as np


# Inference device, and the channel count of the tensor that get_vec copies out
# of the extraction layer.
DEVICE = torch.device('cpu')
OUTPUT_SIZE = 2048

# Preprocessing applied to every image by get_vec: resize, convert to a tensor,
# then normalize each channel.
scaler = transforms.Resize((224, 224))
to_tensor = transforms.ToTensor()
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Pretrained ResNeXt50, moved to DEVICE and switched to evaluation mode.
model = models.resnext50_32x4d(
    weights=models.ResNeXt50_32X4D_Weights.IMAGENET1K_V2,
)
model.to(DEVICE)
model.eval()

# The 'avgpool' submodule is the layer whose output is used as the embedding.
extraction_layer = model._modules.get('avgpool')

def get_vec(arg, model, extraction_layer):
    """Return the output of ``extraction_layer`` for a single image.

    The image is resized, converted to a tensor, normalized, given a leading
    batch dimension of 1 and moved to ``DEVICE`` before being run through
    ``model``. A forward hook registered on ``extraction_layer`` copies that
    layer's output into the tensor that is returned.

    Args:
        arg: image accepted by the module-level ``scaler`` transform
            (``embed`` passes an RGB PIL image).
        model: network on which the forward pass is run.
        extraction_layer: module whose forward output is captured.

    Returns:
        A tensor of shape ``(1, OUTPUT_SIZE, 1, 1)`` holding the captured
        output.
    """
    image = normalize(to_tensor(scaler(arg))).unsqueeze(0).to(DEVICE)
    result = torch.zeros(1, OUTPUT_SIZE, 1, 1)

    def copy_data(m, i, o):
        # Called by the hook with (module, inputs, output); keep only the output.
        result.copy_(o.data)

    hooked = extraction_layer.register_forward_hook(copy_data)
    try:
        with torch.no_grad():
            model(image)
    finally:
        # Always detach the hook: if the forward pass raises, a leftover hook
        # would stay on the layer and keep copying into this call's stale
        # buffer on every later forward pass.
        hooked.remove()
    return result

Downloading: "https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth" to /root/.cache/torch/hub/checkpoints/resnext50_32x4d-1a0047aa.pth
100%|██████████| 95.8M/95.8M [00:00<00:00, 175MB/s]


Now let's take a sample of our PNG files, then build a thumbnail and compute an embedding for each one.

In [2]:
import arrow
import base64
import pandas as pd
from glob import glob
from io import BytesIO
from os.path import basename
from PIL import Image

# Glob prefix for the dataset directories; '/*' is appended where it is used
# below so that the pattern matches the entries inside them.
DATA = '/kaggle/input/enhanced-poster-design-dataset/archive/*/*/*/*/*/'
# Passed to get_picture_from_glob as `stop` to bound how many paths are examined.
STOP = 1000
# Size passed to Image.resize when png() builds the inline hover thumbnails.
THUMBNAIL_SIZE = (96, 96)


def embed(model, filename: str):
    """Return the embedding of the image stored at ``filename``.

    The file is opened, converted to RGB and passed to ``get_vec`` together
    with the module-level ``extraction_layer``; the resulting tensor is
    converted to a NumPy array and reshaped to ``(OUTPUT_SIZE,)``.
    """
    with Image.open(fp=filename, mode='r') as image:
        rgb_image = image.convert('RGB')
        vector = get_vec(arg=rgb_image, model=model, extraction_layer=extraction_layer)
        return vector.numpy().reshape(OUTPUT_SIZE,)


# https://stackoverflow.com/a/952952
def flatten(arg):
    """Flatten one level of nesting: return the items of each element of ``arg`` in order."""
    flat = []
    for inner in arg:
        flat.extend(inner)
    return flat

def png(filename: str) -> str:
    """Return a base64 ``data:image/png`` URI holding a thumbnail of ``filename``.

    The image is resized to ``THUMBNAIL_SIZE``, converted to RGB and encoded
    as PNG in memory.
    """
    encoded = BytesIO()
    with Image.open(fp=filename, mode='r') as source:
        # The originals are large, so the hover images are shrunk to thumbnail size.
        thumbnail = source.resize(size=THUMBNAIL_SIZE)
        thumbnail.convert('RGB').save(encoded, format='png')
    payload = base64.b64encode(encoded.getvalue()).decode()
    return 'data:image/png;base64,' + payload

def get_picture_from_glob(arg: str, stop: int) -> list:
    """Build one ``pd.Series`` (name, embedding, thumbnail) per PNG matched by ``arg``.

    Only the first ``stop`` paths returned by ``glob`` are examined, and of
    those only the ones ending in ``'.png'`` are kept, so the result can hold
    fewer than ``stop`` rows. The row count and elapsed time are printed.

    Args:
        arg: glob pattern of candidate files.
        stop: number of leading glob results to examine.

    Returns:
        A list of Series indexed by ``'name'``, ``'value'`` and ``'png'``.
    """
    time_get = arrow.now()
    result = []
    for index, input_file in enumerate(glob(pathname=arg)):
        if index >= stop:
            # Every later index is also past the limit, so nothing more can qualify.
            break
        if not input_file.endswith('.png'):
            continue
        name = basename(input_file)
        value = embed(model=model, filename=input_file)
        thumbnail = png(filename=input_file)
        result.append(pd.Series(data=[name, value, thumbnail], index=['name', 'value', 'png']))
    print('encoded {} rows in {}'.format(len(result), arrow.now() - time_get))
    return result

# Build the DataFrame (one row per encoded PNG, with columns 'name', 'value'
# and 'png') and report how long the whole step took.
time_start = arrow.now()
df = pd.DataFrame(data=get_picture_from_glob(arg=DATA + '/*', stop=STOP))
print('done in {}'.format(arrow.now() - time_start))

encoded 520 rows in 0:01:22.295066
done in 0:01:22.336050


Now let's use t-SNE to project the embeddings down to the two coordinates we will use for our scatter plot.

In [3]:
import arrow
from sklearn.manifold import TSNE

# Time the t-SNE fit.
time_start = arrow.now()
# random_state is fixed so the same seed is used on every run.
reducer = TSNE(random_state=2025, verbose=True, n_jobs=1, perplexity=10.0, )
# df['value'] holds one embedding array per row; apply(func=pd.Series) expands
# each array into its own row of columns so TSNE receives a 2-D table. The two
# columns it returns are stored as the plot coordinates 'x' and 'y'.
df[['x', 'y']] = reducer.fit_transform(X=df['value'].apply(func=pd.Series))
print('done with TSNE in {}'.format(arrow.now() - time_start))

[t-SNE] Computing 31 nearest neighbors...
[t-SNE] Indexed 520 samples in 0.003s...
[t-SNE] Computed neighbors for 520 samples in 0.106s...
[t-SNE] Computed conditional probabilities for sample 520 / 520
[t-SNE] Mean sigma: 3.264663
[t-SNE] KL divergence after 250 iterations with early exaggeration: 77.599281
[t-SNE] KL divergence after 1000 iterations: 0.929716
done with TSNE in 0:00:02.269320


We have the data we need; let's build our interactive scatter plot.

In [4]:
from bokeh.models import ColumnDataSource
from bokeh.models import HoverTool

from bokeh.plotting import figure
from bokeh.plotting import output_notebook
from bokeh.plotting import show

# Send bokeh output to the notebook.
output_notebook()

# The plot only needs the coordinates and the inline thumbnail of each point.
datasource = ColumnDataSource(df[['x', 'y', 'png']])

plot_figure = figure(
    title='TSNE projection: posters',
    width=1000,
    height=800,
    tools='pan, wheel_zoom, reset',
)

# Hover tooltip: an <img> whose source is taken from the 'png' column.
plot_figure.add_tools(HoverTool(tooltips="""
<div>
    <div>
        <img src='@png' style='float: left; margin: 5px 5px 5px 5px'/>
    </div>
</div>
"""))

plot_figure.scatter(
    x='x',
    y='y',
    source=datasource,
    size=10,
    line_alpha=0.6,
    fill_alpha=0.6,
)
show(plot_figure)

What do we see? We have a lot of duplicate files; the files we have that aren't duplicates form a big cluster, and a lot of them don't look much like ads or posters. Once we get away from those, we don't see a lot of clustering, other than the duplicate pairs.