In [1]:
# Automatically re-import edited modules (e.g. the deepshape package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
from deepshape.surfaces import * 
import numpy as np
import matplotlib.pyplot as plt 

In [3]:
import idx2numpy
from sklearn.cluster import AgglomerativeClustering
from sklearn.manifold import MDS
from sklearn_extra.cluster import KMedoids
from torchvision.transforms.functional import gaussian_blur


from sklearn.metrics import confusion_matrix, multilabel_confusion_matrix

In [4]:
def load_digits(examples_per_digit, data_dir="../data", blur_kernel_size=(3, 3)):
    """Load the first `examples_per_digit` MNIST test images of each digit 0-9.

    All test images are Gaussian-blurred and then scaled by the maximum pixel
    value over the whole blurred test set, so the brightest pixel equals 1.

    Parameters
    ----------
    examples_per_digit : int
        Number of images to take for each digit class.
    data_dir : str, optional
        Directory containing ``t10k-images.idx3-ubyte`` and
        ``t10k-labels.idx1-ubyte``.
    blur_kernel_size : int or sequence of two ints, optional
        Kernel size passed to ``gaussian_blur``.

    Returns
    -------
    digits : torch.Tensor, shape (10 * examples_per_digit, 28, 28)
        Images grouped by digit: all 0s first, then all 1s, and so on.
    sublabels : torch.Tensor (int64), shape (10 * examples_per_digit,)
        Digit label of each row of ``digits``.

    Raises
    ------
    ValueError
        If some digit has fewer test images than `examples_per_digit`.
    """
    # Load Data
    imgfile = f"{data_dir}/t10k-images.idx3-ubyte"
    imgdata = torch.tensor(idx2numpy.convert_from_file(imgfile), dtype=torch.float)
    labelfile = f"{data_dir}/t10k-labels.idx1-ubyte"
    labels = torch.tensor(idx2numpy.convert_from_file(labelfile))
    imgdata = gaussian_blur(imgdata, blur_kernel_size)
    imgdata /= imgdata.max()

    digits = torch.empty(10 * examples_per_digit, 28, 28)
    sublabels = torch.empty(10 * examples_per_digit, dtype=torch.long)

    for digit in range(10):
        digit_images = imgdata[labels == digit]
        # Fail loudly instead of with an opaque shape-mismatch error below.
        if digit_images.shape[0] < examples_per_digit:
            raise ValueError(
                f"Requested {examples_per_digit} examples of digit {digit}, "
                f"but only {digit_images.shape[0]} are available."
            )
        start = digit * examples_per_digit
        stop = start + examples_per_digit
        digits[start:stop] = digit_images[:examples_per_digit]
        sublabels[start:stop] = digit
    return digits, sublabels

In [5]:
class ImageDistance(SurfaceDistance, ComponentDistance):
    """Sum of squared differences between ``self.Q`` and ``self.r(Y)``, divided by ``self.k``.

    ``Q``, ``r`` and ``k`` are not set here; presumably they are provided by the
    project base classes (SurfaceDistance / ComponentDistance) — confirm there.
    """

    def loss_func(self, U, Y):
        # U is not used by this implementation (presumably kept for signature
        # compatibility with the base class — confirm).
        residual = self.Q - self.r(Y)
        squared_error = (residual ** 2).sum()
        return squared_error / self.k

In [None]:
# Number of MNIST test images used per digit class (10 classes in total).
# Defined here so the plotting code below does not depend on hidden kernel state.
examples_per_digit = 2
digits, sublabels = load_digits(examples_per_digit)

# LBFGS optimizer factory used for every pairwise reparametrization.
optim_builder = optimizer_builder(torch.optim.LBFGS, lr=1.0, max_iter=100)

# Total number of images
num_images = digits.shape[0]

# Pairwise distance matrix. Entry (i, j) is ImageDistance(image i, image j)
# evaluated at the reparametrization found by matching image i to image j,
# so the matrix is not necessarily symmetric. The diagonal stays 0.
distance_matrix = np.zeros((num_images, num_images))

for i in range(num_images):
    print(f"{'=' * 30 } Digit {sublabels[i]} {'=' * 30}")

    # Create qmap for current image
    f = SingleChannelImageSurface(digits[i])
    q = Qmap(f)

    # Loop through all other images and compare them to image i.
    for j in range(num_images):
        # 0 distance to itself
        if i == j:
            continue

        # Define qmap for other image.
        g = SingleChannelImageSurface(digits[j])
        r = Qmap(g)

        # A new network is built for every pair, so no optimized parameters
        # carry over from one match to the next.
        rn = SurfaceReparametrizer(
            [SineLayer(6, init_scale=0.) for _ in range(8)]
        )
        optimizer = optim_builder(rn)

        loss = ImageComponentDistance(q, r, k=32, h=None)

        # Match data
        print(f"{'-'*20} Comparing digits {sublabels[i]} and {sublabels[j]} {'-'*20}")
        reparametrize(rn, loss, optimizer, 100, Silent())

        # Insert into distance matrix. Alternatives tried:
        #   ShapeDistance(q, r, k=32, h=3.4e-4)(rn)
        #   ImageWarpDistance(k=32)(rn)
        # float() stores a plain scalar even if a (0-d) tensor is returned.
        distance_matrix[i, j] = float(ImageDistance(f, g, k=28)(rn))
        print(f"Loss: {distance_matrix[i, j]:.5f}")


save_distance_matrix('distance_matrix.pickle', distance_matrix, sublabels)

# squeeze=False keeps `axes` 2-D even when examples_per_digit == 1.
fig, axes = plt.subplots(10, examples_per_digit, figsize=(14, 20), squeeze=False)
for i in range(10):
    for j in range(examples_per_digit):
        axes[i][j].imshow(digits[examples_per_digit * i + j])

plt.show()

plot_distance_matrix(distance_matrix)

-------------------- Comparing digits 0 and 0 --------------------


In [None]:
D, y = load_distance_matrix("distance_matrix.pickle")
# Normalize so the largest pairwise distance is 1.
D /= D.max()
plot_distance_matrix(D)
S, A = symmetric_part(D), antisymmetric_part(D)

# Rebuild the label list, as labels were not stored with older matrices.
# load_digits orders images by digit (0..9) with the same number of images per
# digit, so the group size follows from the matrix size instead of being
# hard-coded to 2.
num_digit_classes = 10
assert D.shape[0] % num_digit_classes == 0, (
    f"distance matrix size {D.shape[0]} is not a multiple of {num_digit_classes}"
)
images_per_digit = D.shape[0] // num_digit_classes
y = [digit for digit in range(num_digit_classes) for _ in range(images_per_digit)]
print(y)

In [None]:
# One cluster per digit class; KMedoids would otherwise default to n_clusters=8.
N_CLUSTERS = 10

X1 = mds(S)
y1 = agglomerative_clustering(S)

# Fit KMedoids once and reuse the fitted model for both outputs.
# X2: distances from each image to each medoid (metric='precomputed').
# y2: cluster label of each image.
kmedoids = KMedoids(
    n_clusters=N_CLUSTERS, metric='precomputed', method='pam', random_state=0
).fit(S)
X2 = kmedoids.transform(S)
y2 = kmedoids.labels_

In [None]:
# MDS embedding X1 with labels y and agglomerative-clustering labels y1
# (argument roles assumed from names — confirm in plot_clustering).
plot_clustering(X1, y, y1)

In [None]:
# KMedoids medoid-distance features X2 with labels y and KMedoids labels y2.
plot_clustering(X2, y, y2)

In [None]:
# Cross-combination: KMedoids features X2 paired with agglomerative labels y1.
plot_clustering(X2, y, y1)

In [None]:
# Cross-combination: MDS embedding X1 paired with KMedoids labels y2.
plot_clustering(X1, y, y2)

In [None]:
# Display any matplotlib figures that are still open.
plt.show()