# Domain Adaptation Between Digits

#### *Rémi Flamary, Nicolas Courty*

In this practical session, we apply the optimal transport (OT) based domain adaptation method proposed in the following paper to digit classification:

N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "[Optimal transport for domain adaptation](http://remi.flamary.com/biblio/courty2016optimal.pdf)", Pattern Analysis and Machine Intelligence, IEEE Transactions on, 2016.

![otda.png](http://remi.flamary.com/cours/otml/otda.png)

To this end, we adapt between the MNIST and USPS datasets. Since these datasets do not have the same resolution (28x28 for MNIST and 16x16 for USPS), the USPS digits are zero-padded to 28x28.
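
The data file used below is already padded, so the following is only an illustrative sketch of what zero-padding means here. The helper name `pad_to_28` and the centered placement of the digit are assumptions, not taken from the original preprocessing.

```python
import numpy as np

def pad_to_28(img16):
    """Embed a 16x16 image in the center of a 28x28 array of zeros (hypothetical helper)."""
    pad = (28 - 16) // 2  # 6 pixels of zeros on each side
    return np.pad(img16, pad_width=pad, mode='constant', constant_values=0)

img16 = np.ones((16, 16))   # stand-in for a USPS digit
img28 = pad_to_28(img16)
print(img28.shape)          # (28, 28)
```

Padding leaves the pixel values untouched, so the digit simply occupies a smaller part of the 28x28 frame.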


####  Import modules

First, we import the relevant modules. Note that you will need ```sklearn``` to train the Support Vector Machine classifier and to project the data with TSNE.

In [None]:
import numpy as np # always need it
import pylab as pl # do the plots

# Uncomment the next line to install sklearn with pip
# !pip install -U scikit-learn

# OR

# Uncomment the next line to install sklearn with conda
# !conda install scikit-learn

# Import sklearn
from sklearn.svm import SVC
from sklearn.manifold import TSNE

# Import pot
import ot

### Loading data and normalization

We load the data in memory and normalize the images so that the pixel values of each image sum to 1.

Note that every row of ```xs``` and ```xt``` is a flattened 28x28 image.

In [None]:
data=np.load('data/mnist_usps.npz')

xs,ys=data['xs'],data['ys']
xt,yt=data['xt'],data['yt']


# normalization
xs=xs/xs.sum(1,keepdims=True) # every row (image) sums to 1
xt=xt/xt.sum(1,keepdims=True)

ns=xs.shape[0]
nt=xt.shape[0]

### Visualizing Source (MNIST) and Target (USPS) datasets





In [None]:
# function for plotting images
def plot_image(x):
    pl.imshow(x.reshape((28,28)),cmap='gray')
    pl.xticks(())
    pl.yticks(())


nb=10

# First we plot MNIST
pl.figure(1,(nb,nb))
for i in range(nb*nb):
    pl.subplot(nb,nb,1+i)
    c=i%nb
    plot_image(xs[np.where(ys==c)[0][i//nb],:])
pl.gcf().suptitle("MNIST", fontsize=20);
pl.gcf().subplots_adjust(top=0.95)
    
# Then we plot USPS
pl.figure(2,(nb,nb))
for i in range(nb*nb):
    pl.subplot(nb,nb,1+i)
    c=i%nb
    plot_image(xt[np.where(yt==c)[0][i//nb],:])
pl.gcf().suptitle("USPS", fontsize=20);
pl.gcf().subplots_adjust(top=0.95)

Note that there is a large discrepancy between the two datasets, especially for the digits 1, 2 and 5, which have different shapes in each dataset.

Also, since the USPS digits have been zero-padded, they are on average slightly smaller than the MNIST digits, which can fill the whole image.


### Classification without domain adaptation

We learn a classifier on the MNIST dataset (we will not be state of the art with 1000 samples). We then evaluate this classifier on both the MNIST and the USPS datasets.

In [None]:
# Train SVM with reg parameter C=1 and RBF kernel parameter gamma=1e2
clf=SVC(C=1,gamma=1e2) # might take time
clf.fit(xs,ys)

# Compute accuracy
ACC_MNIST=clf.score(xs,ys) # beware of overfitting !
ACC_USPS=clf.score(xt,yt)

print('ACC_MNIST={:1.3f}'.format(ACC_MNIST))
print('ACC_USPS={:1.3f}'.format(ACC_USPS))

There is a very large loss in performance on the target domain. This can be better understood by performing a TSNE embedding of the data.

### TSNE of the Source/Target domains

[t-distributed stochastic neighbour embedding (TSNE)](http://www.jmlr.org/papers/volume9/vandermaaten08a/vandermaaten08a.pdf) is a well-known approach that allows the projection of complex high dimensional data in a lower-dimensional space while keeping its structure.

In [None]:
xtot=np.concatenate((xs,xt),axis=0) # all data

xp=TSNE().fit_transform(xtot) # this might take a while

# separate again; now in 2D
xps=xp[:ns,:] 
xpt=xp[ns:,:]

In [None]:
# Display some plots
pl.figure(3,(12,10))

pl.scatter(xps[:,0],xps[:,1],c=ys,marker='o',cmap='tab10',label='Source data')
pl.scatter(xpt[:,0],xpt[:,1],c=yt,marker='+',cmap='tab10',label='Target data')
pl.legend()
pl.colorbar()
pl.title('TSNE Embedding of the Source/Target data')

We can see that while the classes are relatively well clustered, the clusters from the source and target datasets rarely overlap. This is the main reason for the large loss in performance between source and target.

### Optimal Transport Domain Adaptation (OTDA)

Now we perform domain adaptation with the following 3 steps illustrated at the top of the notebook:

1. Compute the OT matrix between the source and target datasets.
1. Map the source samples onto the target domain with barycentric mapping (```np.dot```).
1. Estimate a classifier on the mapped source samples.

#### 1. OT between domains

First, we compute the cost matrix and visualize it. Note that the samples are sorted by class in both the source and target domains, which makes the class-based structure easier to see in the cost matrix and the OT matrix.

In [None]:
# Your code below






# Click or Run the "..." to see a possible solution

In [None]:
C=ot.dist(xs,xt)

pl.figure(4,(10,10))
pl.imshow(C)
pl.title('Cost matrix')

We can see the (noisy) block structure in the matrix. It is also interesting to note that class 1 in USPS (second block of columns) is particularly different from all the classes in the MNIST data (even class 1).
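
For reference, ```ot.dist``` uses the squared Euclidean distance by default. The NumPy sketch below shows that computation on toy points. The helper name `sqeuclidean_cost` and the toy arrays are illustrative assumptions, and the broadcasting used here is memory-hungry, so ```ot.dist``` remains the better choice for the real data.

```python
import numpy as np

def sqeuclidean_cost(xa, xb):
    """Pairwise squared Euclidean distances between rows of xa and rows of xb."""
    diff = xa[:, None, :] - xb[None, :, :]   # shape (na, nb, d)
    return (diff ** 2).sum(axis=-1)          # shape (na, nb)

xa = np.array([[0.0, 0.0], [1.0, 0.0]])
xb = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 0.0]])
C_toy = sqeuclidean_cost(xa, xb)
print(C_toy)
# [[1. 2. 4.]
#  [2. 1. 1.]]
```

Entry `(i, j)` is the cost of moving source sample `i` onto target sample `j`, which is why similar digits produce dark (low-cost) blocks.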


Next we compute the OT matrix using exact LP OT [ot.emd](http://pot.readthedocs.io/en/stable/all.html#ot.emd) or regularized OT with  [ot.sinkhorn](http://pot.readthedocs.io/en/stable/all.html#ot.sinkhorn).

In [None]:
# Your code below






# Click or Run the "..." to see a possible solution

In [None]:
G=ot.emd(ot.unif(ns),ot.unif(nt),C)

reg=.5e-4
# G=ot.sinkhorn(ot.unif(ns),ot.unif(nt),C,reg)

pl.figure(5,(10,10))
pl.imshow(G,interpolation='bilinear',vmax=G.max()/10)
pl.title('OT matrix')

We can see that most of the mass is transported within the block diagonal, which means that on average the samples from one class are assigned to the correct class in the target domain.
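
One way to quantify this visual impression is to measure the share of the mass that links samples with the same label. The sketch below does so on a hand-built toy coupling. The helper name `same_class_mass` and the toy arrays are assumptions for illustration, and since it uses target labels it is only a diagnostic, not part of the adaptation itself.

```python
import numpy as np

def same_class_mass(G, ys, yt):
    """Fraction of the transported mass that links samples with equal labels."""
    same = ys[:, None] == yt[None, :]   # boolean mask, shape (ns, nt)
    return G[same].sum() / G.sum()

ys_toy = np.array([0, 0, 1, 1])
yt_toy = np.array([0, 1, 1, 0])
G_toy = np.array([
    [0.25, 0.0, 0.0, 0.0],   # class 0 -> class 0: correct
    [0.0, 0.25, 0.0, 0.0],   # class 0 -> class 1: wrong
    [0.0, 0.0, 0.25, 0.0],   # class 1 -> class 1: correct
    [0.0, 0.0, 0.0, 0.25],   # class 1 -> class 0: wrong
])
frac = same_class_mass(G_toy, ys_toy, yt_toy)
print(frac)   # 0.5
```

On the notebook data, the same call with `G`, `ys` and `yt` would give a single number summarizing how block-diagonal the OT matrix is.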

#### 2/3 Mapping + Classification

Now we perform the barycentric mapping of the source samples and train the classifier on the mapped samples. We recommend using a smaller ```gamma=1e1``` here because some samples will be mislabeled and a smoother classifier will work better.
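
As a minimal sketch of the barycentric mapping, each source sample is sent to the average of the target samples, weighted by its row of the OT matrix. With uniform source weights every row of `G` sums to `1/ns`, so dividing by the row sums is the same as multiplying by `ns`. The helper name `barycentric_map` and the toy arrays are assumptions for illustration.

```python
import numpy as np

def barycentric_map(G, xt):
    """Map each source sample to the G-weighted average of the target samples."""
    return G.dot(xt) / G.sum(axis=1, keepdims=True)

xt_toy = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 4.0]])
# Coupling between 2 source and 3 target samples:
# each row sums to 1/2 and each column sums to 1/3.
G_toy = np.array([
    [1/3, 1/6, 0.0],
    [0.0, 1/6, 1/3],
])
xst_toy = barycentric_map(G_toy, xt_toy)

# With uniform source weights this equals ns * G.dot(xt)
print(np.allclose(xst_toy, 2 * G_toy.dot(xt_toy)))   # True
```

The row-sum normalization keeps the mapping valid when the source weights are not uniform.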

In [None]:
# Your code below






# Click or Run the "..." to see a possible solution

In [None]:
xst=ns*G.dot(xt)

clf=SVC(C=1,gamma=1e1)

clf.fit(xst,ys)

ACC_USPS2=clf.score(xt,yt)

print('ACC_MNIST={:1.3f}'.format(ACC_MNIST))
print('ACC_USPS={:1.3f}'.format(ACC_USPS))
print('ACC_USPS2={:1.3f}'.format(ACC_USPS2))

We can see that the adaptation with EMD leads to a performance gain of nearly 10%. You can get even better performance using entropic regularized OT or group lasso regularization.

#### TSNE visualization for OTDA

To see the effect of the adaptation we can perform a new TSNE embedding to see if the classes are better aligned.

In [None]:
# Your code below






# Click or Run the "..." to see a possible solution

In [None]:
xtot=np.concatenate((xst,xt),axis=0)

xp=TSNE().fit_transform(xtot)

xps=xp[:ns,:]
xpt=xp[ns:,:]


pl.figure(6,(12,10))

pl.scatter(xps[:,0],xps[:,1],c=ys,marker='o',cmap='tab10',label='Source data')
pl.scatter(xpt[:,0],xpt[:,1],c=yt,marker='+',cmap='tab10',label='Target data')
pl.legend()
pl.colorbar()
pl.title('TSNE Embedding of the OT Adapted Source/Target data')

We can see that with the emd solver the OT matrix is a permutation, so the transported source samples are exactly superimposed on target samples. On average the classes are well transported, but many samples are badly transported and end up in the wrong class.
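
The superposition can be checked on a toy example. The sketch below assumes the same number of source and target samples with uniform weights, which is the setting in which the exact OT matrix is a permutation matrix scaled by `1/n`. The permutation `perm` and the toy points are made up for illustration.

```python
import numpy as np

n = 4
rng = np.random.default_rng(0)
xt_toy = rng.random((n, 2))          # toy target samples

perm = np.array([2, 0, 3, 1])        # source sample i is matched to target perm[i]
G_toy = np.zeros((n, n))
G_toy[np.arange(n), perm] = 1.0 / n  # permutation coupling with uniform marginals

xst_toy = n * G_toy.dot(xt_toy)      # barycentric mapping
print(np.allclose(xst_toy, xt_toy[perm]))   # True
```

Each mapped source sample lands exactly on one target sample, which is why the two point clouds coincide in the embedding.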


#### Transported samples visualization

We can now also plot the transported samples.

In [None]:
# Your code below






# Click or Run the "..." to see a possible solution

In [None]:
# Plot the transported MNIST samples (new figure number to avoid reusing figure 1)
pl.figure(7,(nb,nb))
for i in range(nb*nb):
    pl.subplot(nb,nb,1+i)
    c=i%nb
    plot_image(xst[np.where(ys==c)[0][i//nb],:])
pl.gcf().suptitle("Transported MNIST", fontsize=20);
pl.gcf().subplots_adjust(top=0.95)

These are the same MNIST samples that were plotted above, but after transportation. Several samples are transported onto the wrong class, but again, on average, the class information is preserved, which explains the accuracy gain.

### OTDA with regularization

We now recommend trying regularized OT and redoing the classification, TSNE and visualization steps to see the impact of the regularization on the performance, the TSNE embedding and the transported samples.
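
To give an idea of what entropic regularization does, here is a minimal NumPy sketch of the Sinkhorn iterations on toy data. It is not POT's implementation: the function name `sinkhorn_sketch`, the toy data, the value `reg=0.5` and the fixed number of iterations are all assumptions for illustration. In the notebook itself, the commented ```ot.sinkhorn``` call shown earlier is the one to use.

```python
import numpy as np

def sinkhorn_sketch(a, b, C, reg, n_iter=500):
    """Entropic OT by alternately rescaling rows and columns of K = exp(-C/reg)."""
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)   # rescale to match the target marginal
        u = a / (K @ v)     # rescale to match the source marginal
    return u[:, None] * K * v[None, :]

rng = np.random.default_rng(0)
xs_toy = rng.random((5, 3))
xt_toy = rng.random((4, 3))
C_toy = ((xs_toy[:, None, :] - xt_toy[None, :, :]) ** 2).sum(axis=-1)

a = np.full(5, 1 / 5)   # uniform source weights
b = np.full(4, 1 / 4)   # uniform target weights
G_toy = sinkhorn_sketch(a, b, C_toy, reg=0.5)

print(np.allclose(G_toy.sum(axis=1), a), np.allclose(G_toy.sum(axis=0), b))
```

Unlike the exact solution, the regularized OT matrix is dense, so each source sample is mapped to a weighted average of several target samples rather than onto a single one.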

In [None]:
# Your code below



