# Learning embeddings into entropic Wasserstein spaces

## Imports

In [1]:
import matplotlib
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
import skimage.io as io
import skimage.transform as transform
import numpy as np 

import torch
import torch.nn as nn 
import torch.optim as optim
import h5py
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch.nn.functional as F 
from tqdm import tqdm
import math

# Report whether a CUDA device is visible. The rest of the notebook calls
# .cuda() explicitly, so it will not run without a GPU.
print(
    'Using GPU acceleration'
    if torch.cuda.is_available()
    else 'Unable to access CUDA compatible GPU!! Please fix this before running the notebook.'
)
# Device and tensor type are fixed to CUDA regardless of the check above.
device = torch.device('cuda')
dtype = torch.cuda.FloatTensor

Using GPU acceleration


## Comparing t-SNE and Wasserstein embeddings on precomputed image representations

### Loading image representations

In [2]:
# Load the precomputed image representations from the 'img_emb' dataset into
# memory. The context manager closes the file even if the read fails.
with h5py.File('img_representations.h5', 'r') as h5f:
    image_representations = h5f['img_emb'][:]

### Computing the t-SNE embedding and the associated visualization

In [None]:
def imscatter(x, y, paths, ax=None, zoom=1, linewidth=0):
    """Draw each image from ``paths`` as a thumbnail at its (x, y) position.

    Args:
        x, y: Thumbnail coordinates (scalars or array-likes); they are paired
            with ``paths`` positionally via ``zip``.
        paths: Image file paths passed to ``skimage.io.imread``.
        ax: Axes to draw on; defaults to the current axes (``plt.gca()``).
        zoom: Scale factor passed to ``OffsetImage``.
        linewidth: Width of the red frame drawn around each thumbnail.

    Returns:
        List of the ``AnnotationBbox`` artists added to ``ax``. Paths whose
        image cannot be read are printed and skipped.
    """
    if ax is None:
        ax = plt.gca()
    x, y = np.atleast_1d(x, y)
    artists = []
    for x0, y0, path in zip(x, y, paths):
        try:
            image = io.imread(path)
        except (OSError, ValueError):
            # Missing or undecodable file: report it and skip this point.
            # A bare ``except`` would also swallow KeyboardInterrupt and
            # genuine programming errors.
            print(path)
            continue
        # Bring every thumbnail to a common 224x224 size before zooming.
        image = transform.resize(image, (224, 224))
        box = AnnotationBbox(OffsetImage(image, zoom=zoom), (x0, y0),
                             xycoords='data', frameon=True, pad=0.1,
                             bboxprops=dict(edgecolor='red',
                                            linewidth=linewidth))
        artists.append(ax.add_artist(box))
    # Data limits cover every coordinate (including skipped images) so that
    # autoscale frames the full scatter.
    ax.update_datalim(np.column_stack([x, y]))
    ax.autoscale()
    return artists

# 2-D t-SNE embedding of the representations, shown as a plain point cloud.
img_emb_tsne = TSNE(perplexity=30).fit_transform(image_representations)

plt.figure(figsize=(10, 10))
# Transposing the (N, 2) array unpacks into the x and y columns.
plt.scatter(*img_emb_tsne.T)
plt.xticks([])
plt.yticks([])
plt.show()

import os

# Image paths in sorted filename order. NOTE(review): this assumes the rows of
# image_representations were produced in the same sorted order — confirm.
paths = [f"images_resize/{name}" for name in sorted(os.listdir("images_resize/"))]

# Large canvas so the image thumbnails remain distinguishable.
fig, ax = plt.subplots(figsize=(50, 50))
imscatter(img_emb_tsne[:, 0], img_emb_tsne[:, 1], paths, ax=ax, zoom=0.5)
plt.show()

### Defining the entropy-regularized Wasserstein distance

In [3]:
def entropy_regularized_wasserstein_distance(x, y, entropy_level, nb_sinkhorn_iterations, support_size=None):
    """Entropy-regularized (Sinkhorn) transport cost between two point clouds.

    Each point cloud is treated as a uniform discrete measure over its rows.
    Sinkhorn scaling is carried out in the log domain with ``torch.logsumexp``.
    As a result, ground distances that are large relative to ``entropy_level``
    cannot underflow ``exp(-D / entropy_level)`` to zero and produce NaNs.

    Args:
        x: Tensor of shape (n, d), support points of the first measure.
        y: Tensor of shape (m, d), support points of the second measure
            (same dtype/device as ``x``).
        entropy_level: Entropic regularization strength (epsilon > 0).
        nb_sinkhorn_iterations: Number of Sinkhorn row+column updates (>= 1).
        support_size: Retained for backward compatibility and ignored; the
            support sizes are read from ``x.shape[0]`` and ``y.shape[0]``.

    Returns:
        0-dim tensor ``sum_ij P_ij * ||x_i - y_j||_2``, where P is the Sinkhorn
        transport plan with uniform marginals. Differentiable w.r.t. x and y.

    Raises:
        ValueError: if ``nb_sinkhorn_iterations < 1``.
    """
    if nb_sinkhorn_iterations < 1:
        raise ValueError('nb_sinkhorn_iterations must be at least 1')

    # Euclidean ground cost between every pair of support points, shape (n, m).
    cost = torch.cdist(x, y, p=2, compute_mode='donot_use_mm_for_euclid_dist')
    log_kernel = -cost / entropy_level

    n, m = cost.shape
    # Uniform marginals a = 1/n and b = 1/m, stored as logs.
    log_a = torch.full((n,), -math.log(n), dtype=cost.dtype, device=cost.device)
    log_b = torch.full((m,), -math.log(m), dtype=cost.dtype, device=cost.device)

    # log_u / log_v are logs of the scaling vectors in P = diag(u) K diag(v).
    # The initial scale of v is irrelevant: it cancels after the first u update.
    log_v = torch.zeros(m, dtype=cost.dtype, device=cost.device)
    for _ in range(nb_sinkhorn_iterations):
        # u = a / (K v): rescale rows so P has row sums a.
        log_u = log_a - torch.logsumexp(log_kernel + log_v[None, :], dim=1)
        # v = b / (K^T u): rescale columns so P has column sums b.
        log_v = log_b - torch.logsumexp(log_kernel + log_u[:, None], dim=0)

    transport = torch.exp(log_u[:, None] + log_kernel + log_v[None, :])
    # <P, C> as an elementwise sum (equal to trace(C^T P)).
    return torch.sum(cost * transport)

### Defining the Wasserstein-based mapping

In [5]:
# Number of 2-D points in each embedded point cloud.
support_size = 5


class Mapping(nn.Module):
    """Two-layer MLP mapping a representation vector to a 2-D point cloud.

    The output is a flat vector of ``2 * support_size`` coordinates; callers
    reshape it to ``(batch, support_size, 2)``.
    """

    def __init__(self, representation_size, hidden_size, support_size):
        super().__init__()
        self.hidden = nn.Linear(representation_size, hidden_size)
        self.embedding = nn.Linear(hidden_size, 2 * support_size)

    def forward(self, x):
        # Without a nonlinearity the two Linear layers would compose into a
        # single affine map, so the hidden layer would add no capacity.
        intermediate_representation = F.relu(self.hidden(x))
        return self.embedding(intermediate_representation)

### Training the Wasserstein embedding

In [None]:
# Train the mapping so that entropic Wasserstein distances between embedded
# point clouds match Euclidean distances between the original representations.
representations = torch.Tensor(image_representations)
trainloader = torch.utils.data.DataLoader(representations, batch_size=128,
                                          shuffle=True, num_workers=6)

mapping = Mapping(representation_size=2048, hidden_size=64, support_size=support_size).cuda()
optimizer = optim.Adam(mapping.parameters())

num_epochs = 1
for epoch in range(num_epochs):
    progress = tqdm(trainloader, desc=f'epoch {epoch}')
    for data in progress:
        num_representations = data.shape[0]
        if num_representations < 2:
            # A single-sample batch has no pairs; math.comb(1, 2) == 0 would
            # divide by zero.
            continue
        data = data.cuda()
        optimizer.zero_grad()
        embeddings = mapping(data).view(num_representations, support_size, 2)
        # Target distances in the original space, computed once per batch.
        original_distances = torch.cdist(
            data, data, p=2, compute_mode='donot_use_mm_for_euclid_dist')
        loss = torch.zeros(1, device=data.device)
        # The original-space distance is symmetric (and the Sinkhorn cost is
        # symmetric up to convergence), so only unordered pairs i < j are used.
        for i in range(num_representations):
            for j in range(i + 1, num_representations):
                embeddings_distance = entropy_regularized_wasserstein_distance(
                    embeddings[i], embeddings[j], 0.05, 20, support_size)
                loss += (original_distances[i, j] - embeddings_distance) ** 2
        # Mean squared error over the C(n, 2) unordered pairs
        # (math.comb(n, k): n items, choose k).
        loss /= math.comb(num_representations, 2)
        loss.backward()
        optimizer.step()
        progress.set_postfix(loss=loss.item())

In [None]:
    for i,representation in enumerate(representations):
        for j,representation in enumerate(representations):
            if j != i:
                original_distance = np.linalg.norm(representation[i]-representation[j],2) ** 2
                wasserstein_distance = entropy_regularized_wasserstein_distance(y[i]-y[j])
                
# Toy sanity check: torch.cdist vs. explicit numpy norms on two small 2-D
# point clouds (4 points and 3 points, float64).
a1, b1, c1, d1 = (np.array(point) for point in
                  ([0., 0.], [0.1, 0.1], [0.2, 0.2], [-0.1, -0.1]))
A = [a1, b1, c1, d1]

a2, b2, c2 = (np.array(point) for point in ([0., 0.], [0.05, 0.05], [0.2, 0.2]))
B = [a2, b2, c2]

x = torch.tensor(np.stack(A))
y = torch.tensor(np.stack(B))

# D[i, j] = ||A[i] - B[j]||_2, the reference distance matrix.
D = np.array([[np.linalg.norm(p - q, 2) for q in B] for p in A])

# Sinkhorn distance between the two clouds (disabled):
#entropy_regularized_wasserstein_distance(x,y,0.05,20,3)
print(torch.cdist(x, y, p=2, compute_mode='donot_use_mm_for_euclid_dist'))
print(D)
