In [1]:
import numpy as np

In [164]:
import torch

class BregmanSubnet(torch.nn.Module):
    def __init__(self, input_dim, hidden_dims):
        super(BregmanSubnet, self).__init__()
        dims = [input_dim] + hidden_dims + [1]
        self.layers = [
            torch.nn.Linear(input_dim, hidden_dims[0]),
#             torch.nn.ReLU(),
            torch.nn.Linear(hidden_dims[0], hidden_dims[1]),
#             torch.nn.ReLU(),
            torch.nn.Linear(hidden_dims[1], 1),
        ]
        self.layers = torch.nn.Sequential(*self.layers)
    
    def forward(self, x):
#         print(self.layers(x).shape)
        return self.layers(x)
    
class BregmanLearner(torch.nn.Module):
    def __init__(self, input_dim, hidden_dims, subnets):
        super(BregmanLearner, self).__init__()
        dims = [input_dim] + hidden_dims
        self.layers = [
            torch.nn.Linear(input_dim, hidden_dims[0]),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dims[0], hidden_dims[1]),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dims[1],hidden_dims[2]),
        ]                    
        self.layers = torch.nn.Sequential(*self.layers)
        self.heads = subnets
    
    def forward(self, x):
        h = self.layers(x)
        return torch.cat([head(h) for head in self.heads], dim=-1)

from pytorch_metric_learning import losses
from pytorch_metric_learning import distances

# class SameMaxDivergence(distances.BaseDistance):
#     def __init__(self, *args, **kwargs):
#         super(SameMaxDivergence, self).__init__(*args, **kwargs)
        
#     def compute_mat(self, query_emb, ref_emb):
#         q_max_idxs = torch.argmax(query_emb, dim=-1)
#         r_max_idxs = torch.argmax(ref_emb, dim=-1)
#         qm = q_max_idxs[:, None]
#         rm = r_max_idxs[None, :]
#         mask = torch.where(qm - rm == 0, 0, 1).float()

#         return  mask * ((query_emb[:, None, :] - ref_emb[None, :, :])**2).sum(dim=-1)
# #         return  ((query_emb[:, None, :] - ref_emb[None, :, :])**self.p).sum(dim=-1)
    
#     def pairwise_distance(self, query_emb, ref_emb):
# #         q = torch.rand(10,5).float()
# #         r = torch.rand(10,5).float()
#         q_max_idxs = torch.argmax(query_emb, dim=-1)
#         r_max_idxs = torch.argmax(ref_emb, dim=-1)
#         mask = torch.where(q_max_idxs - r_max_idxs == 0, 0, 1)
# #         mask = 1
#         return mask * torch.nn.functional.pairwise_distance(query_emb, ref_emb, p=self.p)

# TODO this is wrong
# instead of doing 0 for if same class and normal otherwise
# need to do
# p* = argmax D(x)
# q* = argmax D(y)
# find output of D(x) @ q*

def train_bregman_distance(n_classes, dataloader):
    
    bregman_input_dim = dataloader.dataset.X.shape[1]
    bregman_hidden_dims = [100, 50, 10]
    n_subnets = n_classes
    subnets_input_dim = bregman_hidden_dims[-1]
    subnets_dim = [10, 10]
    subnets = [BregmanSubnet(subnets_input_dim, subnets_dim) for _ in range(n_subnets)]
    embedder = BregmanLearner(bregman_input_dim, bregman_hidden_dims, subnets)
    
    optimizer = torch.optim.Adam(lr=1e-2, params=embedder.parameters())
    def dist_fn(x, y):
        v_x = embedder(x)
        v_y = embedder(y)
        phi_x, p_star = torch.max(v_x, dim=-1)
        phi_y, q_star = torch.max(v_y, dim=-1)
        return (phi_x - v_x[:, q_star])**2
    
    loss_func = torch.nn.TripletMarginWithDistanceLoss(distance_function=dist_fn, margin=1)#losses.TripletMarginLoss()#distance=SameMaxDivergence())
    
    for epoch in range(100):
        for i, (anchor, positive, negative) in enumerate(dataloader):
            optimizer.zero_grad()
            
#             anchor = embedder(anchor.float())
#             positive = embedder(positive.float())
#             negative = embedder(negative.float())
            loss = loss_func(anchor.float(), positive.float(), negative.float())
#             print(labels)
            print(loss.item())
            loss.backward()
            optimizer.step()
    return embedder, dist_fn

In [165]:
from sklearn import datasets
import random
class IrisTripletDataset(torch.utils.data.Dataset):
    def __init__(self):
        iris = datasets.load_iris()
        self.X = iris.data[:, :]
        
        self.labels = iris.target
        self.all_classes = set(self.labels)
        
        self.by_label = {}
        for idx, (x, label) in enumerate(zip(self.X, self.labels)):
            if label not in self.by_label: self.by_label[label]=[]
            self.by_label[label].append(idx)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        # anchor, positive, negative
        label = self.labels[idx]
        negative_classes = list(self.all_classes - {label})
        negative = random.choice(negative_classes)
        return self.X[idx], self.X[random.choice(self.by_label[label])], self.X[random.choice(self.by_label[negative])]
        
ds = IrisTripletDataset()
# dl = torch.utils.data.DataLoader(ds, batch_size=5)
dl = torch.utils.data.DataLoader(ds, batch_size=len(ds.X))

In [166]:
n_classes = len(set(ds.labels))
embedder, dist_fn = train_bregman_distance(n_classes, dl)

0.9996938109397888
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0
1.0


In [167]:
dist_fn(torch.tensor(ds.X).float(), torch.tensor(ds.X).float())

tensor([[0.0000e+00, 3.0009e-02, 3.3293e-02,  ..., 9.2317e-02, 9.9879e-02,
         1.5783e-02],
        [3.0009e-02, 0.0000e+00, 8.5200e-05,  ..., 2.2760e-01, 2.3938e-01,
         8.9320e-02],
        [3.3293e-02, 8.5200e-05, 0.0000e+00,  ..., 2.3649e-01, 2.4850e-01,
         9.4922e-02],
        ...,
        [9.2317e-02, 2.2760e-01, 2.3649e-01,  ..., 0.0000e+00, 1.4880e-04,
         3.1757e-02],
        [9.9879e-02, 2.3938e-01, 2.4850e-01,  ..., 1.4880e-04, 0.0000e+00,
         3.6253e-02],
        [1.5783e-02, 8.9320e-02, 9.4922e-02,  ..., 3.1757e-02, 3.6253e-02,
         0.0000e+00]], grad_fn=<PowBackward0>)

In [168]:
## distances.LpDistance()(torch.tensor([[1,2,3]]).float(),torch.tensor([[4,5,6]]).float())

In [169]:
# SameMaxDivergence().compute_mat(torch.tensor(ds.X).float(), torch.tensor(ds.X).float())

In [170]:
embedded_X = embedder(torch.tensor(ds.X).float()).detach().numpy()

In [171]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
%matplotlib widget
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(embedded_X[:, 0], embedded_X[:, 1], embedded_X[:, 2], c = ds.labels)
plt.show()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [172]:
import umap.umap_ as umap
embeddings = embedder(torch.tensor(ds.X).float()).detach().numpy()
mapper = umap.UMAP().fit(embeddings)

In [173]:
import umap.plot
umap.plot.points(mapper, labels=ds.labels)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<AxesSubplot:>

In [176]:
umap_embeds = umap.umap_.UMAP().fit(ds.X)
umap.plot.points(umap_embeds, labels=ds.labels, cmap=['#000', '#aaa', '#f00'], facecolor='b')

TypeError: points() got an unexpected keyword argument 'facecolor'