In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.font_manager import FontProperties
from sklearn.manifold import TSNE

In [None]:
# Plot palette: RGB triples given in 0-255, rescaled to the 0-1 floats matplotlib
# accepts as colors. Index 0 is used for image points, index 1 for text points.
colors = [
    tuple(np.array(rgb) / 255.0)
    for rgb in (
        (23, 131, 232),
        (2, 38, 110),
        (185, 208, 241),
        (118, 205, 3),
        (162, 0, 0),
    )
]

In [None]:
def tsne_embeddings(embeddings_path, sample_size=None, sample_labels=None,
                    feature_dim=512, random_state=None):
    """Embed a random subset of paired image/text features into 2-D with t-SNE.

    Each row of the array stored at ``embeddings_path`` is read as
    ``[image features | text features | label]``: columns ``[0, feature_dim)``
    are image features, ``[feature_dim, 2 * feature_dim)`` are text features
    and column ``2 * feature_dim`` is the label. The image and text features of
    the selected rows are stacked and embedded jointly, so both modalities end
    up in the same 2-D space.

    Args:
        embeddings_path: path of the ``.npy`` file to load.
        sample_size: maximum number of rows to keep. Defaults to 1000 when
            neither ``sample_size`` nor ``sample_labels`` is given; replaced by
            ``len(sample_labels)`` whenever ``sample_labels`` is given.
        sample_labels: label values whose rows may be selected. When omitted,
            ``sample_size`` entries are drawn at random (without replacement)
            from the label column.
        feature_dim: width of each of the image and text feature blocks.
        random_state: optional int seed used for the row sampling and passed
            to TSNE. ``None`` samples from NumPy's global RNG.

    Returns:
        ``(image_xy, text_xy, n_selected, selected_labels)`` where
        ``image_xy`` and ``text_xy`` are ``(n_selected, 2)`` arrays of t-SNE
        coordinates, ``n_selected`` is the number of rows actually embedded
        and ``selected_labels`` holds their labels.

    Raises:
        ValueError: if no row's label is among the sample labels.
    """
    embeddings = np.load(embeddings_path)

    print("Embeddings shape: ", embeddings.shape)

    image_features = embeddings[:, :feature_dim]
    text_features = embeddings[:, feature_dim:2 * feature_dim]
    labels = embeddings[:, 2 * feature_dim]

    # Unseeded calls keep using NumPy's global RNG; a seed gives a private,
    # reproducible stream. Both expose a compatible ``choice``.
    rng = np.random if random_state is None else np.random.default_rng(random_state)

    if sample_size is None and sample_labels is None:
        sample_size = 1000
    elif sample_labels is not None:
        sample_size = len(sample_labels)
    if sample_labels is None:
        sample_labels = rng.choice(labels, sample_size, replace=False)
    print("Sample size: ", sample_size)
    print("Sample labels size: ", len(sample_labels))

    # np.isin replaces np.in1d, which is deprecated since NumPy 2.0.
    indices = np.where(np.isin(labels, sample_labels))[0]
    if len(indices) > sample_size:
        indices = rng.choice(indices, sample_size, replace=False)
    print("Indices size: ", len(indices))
    if len(indices) == 0:
        raise ValueError("No embeddings match the requested sample labels")

    # The requested labels may match fewer rows than sample_size, so the
    # image/text split and the perplexity must use the real selection size.
    n_selected = len(indices)

    image_features = image_features[indices, :]
    text_features = text_features[indices, :]
    print("Image features: ", image_features.shape)
    print("Text features: ", text_features.shape)

    # First n_selected rows are image features, the remaining ones text features.
    data = np.concatenate((image_features, text_features), axis=0)

    x_embedded = TSNE(
        n_components=2,
        learning_rate=300,
        init='pca',
        # t-SNE rejects a perplexity of 0, which int(n / 5) yields for n < 5.
        perplexity=max(1, int(n_selected / 5)),
        metric='cosine',
        random_state=random_state,
    ).fit_transform(data)

    image_xy = x_embedded[:n_selected, :]
    text_xy = x_embedded[n_selected:, :]
    print("TSNE image samples: ", image_xy.shape)
    print("TSNE text samples: ", text_xy.shape)
    return image_xy, text_xy, n_selected, labels[indices]

In [None]:
# Number of embedding rows sampled for the t-SNE plot; each selected row
# contributes one image point and one text point.
num_samples = 100

# Seed NumPy's global RNG so the np.random.choice sampling in tsne_embeddings
# (and TSNE, which falls back to NumPy's global RNG when random_state=None)
# gives the same figure on every Restart & Run All.
SEED = 0
np.random.seed(SEED)

In [None]:
# Root of the experiment outputs. Override with the INDUSTRIAL_CLIP_OUTPUT
# environment variable when the repository is not checked out under /root.
OUTPUT_ROOT = os.environ.get("INDUSTRIAL_CLIP_OUTPUT", "/root/industrial-clip/output")

zsclip_embeddings = tsne_embeddings(
    os.path.join(OUTPUT_ROOT, "zsclip", "embeddings.npy"),
    sample_size=num_samples,
)

In [None]:
font = FontProperties()
font.set_name('Arial')

fig, axs = plt.subplots(1, 3, figsize=(30, 10))

# One (title, tsne_embeddings result) entry per populated panel, in axs order.
# Only the zero-shot CLIP panel is filled so far; append further
# tsne_embeddings() results here to populate axs[1] and axs[2].
panels = [
    ("Zero-shot CLIP", zsclip_embeddings),
]

for ax, (title, (image_xy, text_xy, *_)) in zip(axs, panels):
    ax.set_title(title, fontproperties=font, fontweight='bold', fontsize=30, pad=9)
    # Image points as circles in colors[0], text points as crosses in colors[1].
    ax.scatter(image_xy[:, 0], image_xy[:, 1], color=colors[0], marker='o', s=80)
    ax.scatter(text_xy[:, 0], text_xy[:, 1], color=colors[1], marker='x', s=80)

# Hide ticks on every panel, including the not-yet-populated ones.
for ax in axs:
    ax.set_xticks([])
    ax.set_yticks([])

# Legend entries follow scatter order on the first panel: image, then text.
axs[0].legend(['Image embeddings', 'Text embeddings'], fontsize=25, loc='upper left')

fig.tight_layout(pad=0, h_pad=0.0, w_pad=1.0)

In [None]:
# Let the panels span the full width and bottom of the figure, keeping the top
# 6% of the height free for the panel titles, then export a transparent PDF.
fig.subplots_adjust(left=0, right=1, bottom=0.0, top=0.94)
fig.savefig("tsne.pdf", transparent=True, format='pdf')