In [151]:
import numpy as np
import matplotlib.pyplot as plt

from sklearn.datasets import load_digits
from sklearn.utils import check_random_state
from sklearn.metrics import pairwise_distances
from sklearn.datasets import fetch_lfw_people

class MDS:
    """Metric multidimensional scaling fitted by batch gradient descent.

    Given an ``(n, n)`` matrix of target distances ``D``, find ``n`` points in
    ``n_components`` dimensions whose pairwise Euclidean distances approximate
    ``D``. The objective (see :meth:`loss`) is the squared-residual stress
    normalised by ``sum(D ** 2)``.

    Parameters
    ----------
    n_components : int
        Dimensionality of the embedding.
    learning_rate : float
        Initial gradient-descent step size. It is never modified by fitting.
    epochs : int
        Number of full-batch gradient steps.
    decay_after : int
        The step size is halved at the start of every epoch whose index is
        strictly greater than this value.
    random_state : None, int or numpy.random.RandomState
        Source of the initial coordinates. ``None`` draws from NumPy's global
        random state (``np.random.rand``).
    verbose : bool
        When true, print the loss at the start of every epoch.
    """

    def __init__(self, n_components=2, learning_rate=0.01, epochs=10,
                 decay_after=50, random_state=None, verbose=True):
        self.n_components = n_components
        self.learning_rate = learning_rate
        self.epochs = epochs
        self.decay_after = decay_after
        self.random_state = random_state
        self.verbose = verbose

    def fit_transform(self, D):
        """Embed the points described by the distance matrix ``D``.

        Parameters
        ----------
        D : array-like of shape (n, n)
            Target pairwise distances.

        Returns
        -------
        numpy.ndarray of shape (n, n_components)
            The fitted coordinates.

        Raises
        ------
        ValueError
            If ``D`` is not a square 2-D matrix.
        """
        D = np.asarray(D, dtype=float)
        if D.ndim != 2 or D.shape[0] != D.shape[1]:
            raise ValueError(
                'D must be a square distance matrix, got shape %s' % (D.shape,)
            )

        shape = (D.shape[0], self.n_components)
        if self.random_state is None:
            X = np.random.rand(*shape)
        elif isinstance(self.random_state, np.random.RandomState):
            X = self.random_state.rand(*shape)
        else:
            X = np.random.RandomState(self.random_state).rand(*shape)

        # Decay a local copy so the configured rate survives repeated fits.
        step = self.learning_rate

        for epoch in range(self.epochs):
            if self.verbose:
                print('Epoch: ', epoch, ' Loss: ', self.loss(X, D))

            if epoch > self.decay_after:
                step /= 2

            # Full-batch update: the gradient is evaluated for every point
            # before any coordinate is moved.
            X -= step * self.compute_stress(X, D)

        return X

    @staticmethod
    def _pairwise(X):
        """Return ``(diff, dist)``: ``diff[i, j] = X[i] - X[j]`` and its norm."""
        diff = X[:, None, :] - X[None, :, :]
        return diff, np.linalg.norm(diff, axis=-1)

    def loss(self, X, D):
        """Return the normalised stress of the embedding ``X`` against ``D``.

        Computed as ``sum_ij (D_ij - ||X_i - X_j||) ** 2 / sum_ij D_ij ** 2``
        over all ordered pairs ``(i, j)``.
        """
        X = np.asarray(X, dtype=float)
        D = np.asarray(D, dtype=float)
        _, dist = self._pairwise(X)
        return ((D - dist) ** 2).sum() / (D ** 2).sum()

    def compute_stress(self, X, D):
        """Return the descent gradient of the stress with respect to ``X``.

        Row ``i`` is
        ``-4 * sum_j (D_ij - d_ij) * (X_i - X_j) / (d_ij + 1e-07)``
        divided by ``sqrt(sum(D ** 2))``, where ``d_ij = ||X_i - X_j||``.

        The residual ``D_ij - d_ij`` is signed: pairs that are too close are
        pushed apart and pairs that are too far are pulled together. The
        scale is ``1 / sqrt(sum(D ** 2))`` rather than the ``1 / sum(D ** 2)``
        used by :meth:`loss`, so the result is a positive multiple of the
        exact loss gradient; that only rescales the effective step size.

        Returns
        -------
        numpy.ndarray with the same shape as ``X``.
        """
        X = np.asarray(X, dtype=float)
        D = np.asarray(D, dtype=float)
        diff, dist = self._pairwise(X)

        # The small constant guards the division for coincident points
        # (including the diagonal, where diff is exactly zero anyway).
        weight = (D - dist) / (dist + 1e-07)
        gradient = -4.0 * np.einsum('ij,ijk->ik', weight, diff)

        return gradient / np.sqrt((D ** 2).sum())


if __name__ == '__main__':
    # Demo: embed the first `n_samples` digit images in 2-D and colour each
    # point by its digit label.
    n_samples = 200
    subset = slice(0, n_samples)

    digits = load_digits()
    labels = digits.target[subset]
    D = pairwise_distances(digits.data[subset])

    mds = MDS(n_components=2, learning_rate=0.5, epochs=100)
    embedding = mds.fit_transform(D)

    # To inspect the fit, compare D with pairwise_distances(embedding).

    # The two columns of the embedding are the x and y plot coordinates.
    xs, ys = embedding[:, 0], embedding[:, 1]
    plt.scatter(xs, ys, c=labels)
    plt.show()


Epoch:  0  Loss:  0.9792624305640373
Epoch:  1  Loss:  0.9172243208387868
Epoch:  2  Loss:  0.8598306082233247
Epoch:  3  Loss:  0.8069006265824079
Epoch:  4  Loss:  0.7581256574508846
Epoch:  5  Loss:  0.7131956020670182
Epoch:  6  Loss:  0.6718141894036627
Epoch:  7  Loss:  0.6337055325499622
Epoch:  8  Loss:  0.5986113474917838
Epoch:  9  Loss:  0.5662893057459648
Epoch:  10  Loss:  0.5365154810918733
Epoch:  11  Loss:  0.509082311133298
Epoch:  12  Loss:  0.4837987877805218
Epoch:  13  Loss:  0.46048888750831185
Epoch:  14  Loss:  0.4389907100740836
Epoch:  15  Loss:  0.4191578774476496
Epoch:  16  Loss:  0.4008566918546083
Epoch:  17  Loss:  0.3839646120854466
Epoch:  18  Loss:  0.36836976751565464
Epoch:  19  Loss:  0.3539688708253321
Epoch:  20  Loss:  0.3406656050169164
Epoch:  21  Loss:  0.3283723767063917
Epoch:  22  Loss:  0.31700778968964394
Epoch:  23  Loss:  0.3064963957543225
Epoch:  24  Loss:  0.29676804093401987
Epoch:  25  Loss:  0.28775781573976494
Epoch:  26  Loss: 

KeyboardInterrupt: 