In [1]:
pip install memory_profiler

Collecting memory_profiler
  Downloading https://files.pythonhosted.org/packages/8f/fd/d92b3295657f8837e0177e7b48b32d6651436f0293af42b76d134c3bb489/memory_profiler-0.58.0.tar.gz
Building wheels for collected packages: memory-profiler
  Building wheel for memory-profiler (setup.py) ... [?25l[?25hdone
  Created wheel for memory-profiler: filename=memory_profiler-0.58.0-cp36-none-any.whl size=30181 sha256=03be3cc83f6371637556d4606d22e1de7fc070c28cf806f7980a4a97b97554f3
  Stored in directory: /root/.cache/pip/wheels/02/e4/0b/aaab481fc5dd2a4ea59e78bc7231bb6aae7635ca7ee79f8ae5
Successfully built memory-profiler
Installing collected packages: memory-profiler
Successfully installed memory-profiler-0.58.0


In [2]:
%load_ext memory_profiler

In [3]:
import random
import time
import math
import numpy as np
from sklearn.neighbors import NearestNeighbors
from scipy.spatial import KDTree
from scipy.stats import wasserstein_distance

import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.autograd.variable import Variable
from torch.utils.data import DataLoader

In [4]:
device = torch.device('cuda') if torch.cuda.is_available else torch.device('cpu')
# device = torch.device('cpu')

In [5]:
np.random.seed(6345)
torch.manual_seed(6345)
set_dist = []


for i in range(50): 
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(2), torch.eye(2))
    x = m.sample([250])
    set_dist.append(x) 
    
for i in range(50): 
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([0.0, 1.0]), torch.tensor([[1,.5],[.5,1]]))
    x = m.sample([250])
    set_dist.append(x) 
    
for i in range(50): 
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.ones(2), covariance_matrix=torch.tensor([[.7,.1],[.1,1]]))
    x = m.sample([250])
    set_dist.append(x) 
    
for i in range(50): 
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([1.0, 0.0]), torch.tensor([[.2, -.1], [-.1, 1]]))
    x = m.sample([250])
    set_dist.append(x) 
    
for i in range(50): 
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([.5, .5]), torch.tensor([[.8,.4],[.4,1]]))
    x = m.sample([250])
    set_dist.append(x) 
    
for i in range(50): 
    m = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([-.25, -.5]), torch.eye(2)*.5)
    x = m.sample([250])
    set_dist.append(x) 
    


In [6]:
set_dist = torch.stack(set_dist)

In [7]:
set_dist.shape

torch.Size([300, 250, 2])

In [8]:
class Set2Set(nn.Module):
    def __init__(self, input_dim, hidden_dim, act_fn=nn.Tanh, num_layers=1):
        '''
        Args:
            input_dim: input dim of Set2Set. 
            hidden_dim: the dim of set representation, which is also the INPUT dimension of 
                the LSTM in Set2Set. 
                This is a concatenation of weighted sum of embedding (dim input_dim), and the LSTM
                hidden/output (dim: self.lstm_output_dim).
        '''
        super(Set2Set, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        if hidden_dim <= input_dim:
            print('ERROR: Set2Set output_dim should be larger than input_dim')
        # the hidden is a concatenation of weighted sum of embedding and LSTM output
        self.lstm_output_dim = hidden_dim - input_dim
        self.lstm = nn.LSTM(hidden_dim, input_dim, num_layers=num_layers, batch_first=True)

        # convert back to dim of input_dim
       # self.pred = nn.Linear(hidden_dim, input_dim)
        self.pred = nn.Linear(hidden_dim,4)
        self.act = act_fn()

    def forward(self, embedding):
        '''
        Args:
            embedding: [batch_size x n x d] embedding matrix
        Returns:
            aggregated: [batch_size x d] vector representation of all embeddings
        '''
        batch_size = embedding.size()[0]
        n = embedding.size()[1]

        hidden = (torch.zeros(self.num_layers, batch_size, self.lstm_output_dim).cuda(),
                  torch.zeros(self.num_layers, batch_size, self.lstm_output_dim).cuda())

        q_star = torch.zeros(batch_size, 1, self.hidden_dim).cuda()
        for i in range(n):
            # q: batch_size x 1 x input_dim
            q, hidden = self.lstm(q_star, hidden)
            # e: batch_size x n x 1
            e = embedding @ torch.transpose(q, 1, 2)
            a = nn.Softmax(dim=1)(e)
            r = torch.sum(a * embedding, dim=1, keepdim=True)
            q_star = torch.cat((q, r), dim=2)
        q_star = torch.squeeze(q_star, dim=1)
        out = self.act(self.pred(q_star))

        return out

In [9]:
class DeepSet(nn.Module):

    def __init__(self, in_features, set_features):
        super(DeepSet, self).__init__()
        self.in_features = in_features
        self.out_features = set_features
        self.feature_extractor = nn.Sequential(
            nn.Linear(in_features, 50),
            nn.ELU(inplace=True),
            nn.Linear(50, 100),
            nn.ELU(inplace=True),
            nn.Linear(100, set_features)
        )

        self.regressor = nn.Sequential(
            nn.Linear(set_features, 30),
            nn.ELU(inplace=True),
            nn.Linear(30, 30),
            nn.ELU(inplace=True),
            nn.Linear(30, 10),
            nn.ELU(inplace=True),
            nn.Linear(10, 2),
        )
        
        
    def forward(self, input):
        x = input
        x = self.feature_extractor(x)
        x = x.sum(dim=1)
        x = self.regressor(x)
        return x


In [10]:
class Encoder(nn.Module):
    """ Set Encoder 
    """
    def __init__(self, dim_Q, dim_K, dim_V, d_model, num_heads, ln=False, skip=True):
        super(Encoder, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.skip = skip
       # self.s_max = s_max
        #Maximum set size
        self.d_model = d_model
        self.fc_q = nn.Linear(dim_Q, d_model)
        self.fc_k = nn.Linear(dim_K, d_model)
        self.fc_v = nn.Linear(dim_K, d_model)
        if ln:
            self.ln0 = nn.LayerNorm(d_model)
            self.ln1 = nn.LayerNorm(d_model)
        #This is the classic pointwise feedforward in "Attention is All you need"
        self.ff = nn.Sequential(
        nn.Linear(d_model, 4 * d_model),
        nn.ReLU(),
        nn.Linear(4 * d_model, d_model))
        # I have experimented with just a smaller version of this 
       # self.fc_o = nn.Linear(d_model,d_model)
        
     #   self.fc_rep = nn.Linear(s_max, 1)
#number of heads must divide output size = d_model
        

    def forward(self, Q, K):
        Q = self.fc_q(Q)
      
        K, V = self.fc_k(K), self.fc_v(K)

        dim_split = self.d_model // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), 0)
        K_ = torch.cat(K.split(dim_split, 2), 0)
        V_ = torch.cat(V.split(dim_split, 2), 0)
  

        A = torch.softmax(Q_.bmm(K_.transpose(-2,-1))/math.sqrt(self.d_model), dim=-1)
        A_1 = A.bmm(V_)
        
 
        O = torch.cat((A_1).split(Q.size(0), 0), 2)
       
        O = torch.cat((Q_ + A_1).split(Q.size(0), 0), 2) if getattr(self, 'skip', True) else \
             torch.cat((A_1).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
       # O = O + F.relu(self.fc_o(O)) if getattr(self, 'skip', None) is None else F.relu(self.fc_o(O))
        # For the classic transformers paper it is 
        O = O + self.ff(O)
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        O = torch.mean(O,dim=1)
#         O = pad_sequence(O, batch_first=True, padding_value=0)
#         O = O.transpose(-2,-1)
#         O = F.pad(O, (0, self.s_max- O.shape[-1]), 'constant', 0)
      #  O = self.fc_rep(O)
       # O = self.fc_rep(O.transpose(-2,-1))
      #  O = O.squeeze()

        return O

In [11]:
class SelfAttention(nn.Module):
    def __init__(self, dim_in=18, dim_out=8, num_heads=2, ln=True, skip=True):
        super(SelfAttention, self).__init__()
        self.Encoder = Encoder(dim_in, dim_in, dim_in, dim_out, num_heads, ln=ln, skip=skip)

    def forward(self, X):
        return self.Encoder(X, X)


In [12]:
eps = 1e-15
"""Approximating KL divergences between two probability densities using samples. 
    It is buggy. Use at your own peril
"""

def knn_distance(point, sample, k):
    """ Euclidean distance from `point` to it's `k`-Nearest
    Neighbour in `sample` """
    norms = np.linalg.norm(sample-point, axis=1)
    return np.sort(norms)[k]


def verify_sample_shapes(s1, s2, k):
    # Expects [N, D]
    assert(len(s1.shape) == len(s2.shape) == 2)
    # Check dimensionality of sample is identical
    assert(s1.shape[1] == s2.shape[1])


def naive_estimator(s1, s2, k=1):
    """ KL-Divergence estimator using brute-force (numpy) k-NN
        s1: (N_1,D) Sample drawn from distribution P
        s2: (N_2,D) Sample drawn from distribution Q
        k: Number of neighbours considered (default 1)
        return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n, m = len(s1), len(s2)
    D = np.log(m / (n - 1))
    d = float(s1.shape[1])

    for p1 in s1:
        nu = knn_distance(p1, s2, k-1)  # -1 because 'p1' is not in 's2'
        rho = knn_distance(p1, s1, k)
        D += (d/n)*np.log((nu/rho)+eps)
    return D


def scipy_estimator(s1, s2, k=1):
    """ KL-Divergence estimator using scipy's KDTree
        s1: (N_1,D) Sample drawn from distribution P
        s2: (N_2,D) Sample drawn from distribution Q
        k: Number of neighbours considered (default 1)
        return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n, m = len(s1), len(s2)
    d = float(s1.shape[1])
    D = np.log(m / (n - 1))

    nu_d,  nu_i   = KDTree(s2).query(s1, k)
    rho_d, rhio_i = KDTree(s1).query(s1, k+1)

    # KTree.query returns different shape in k==1 vs k > 1
    if k > 1:
        D += (d/n)*np.sum(np.log(nu_d[::, -1]/rho_d[::, -1]))
    else:
        D += (d/n)*np.sum(np.log(nu_d/rho_d[::, -1]))

    return D


def skl_estimator(s1, s2, k=1):
    """ KL-Divergence estimator using scikit-learn's NearestNeighbours
        s1: (N_1,D) Sample drawn from distribution P
        s2: (N_2,D) Sample drawn from distribution Q
        k: Number of neighbours considered (default 1)
        return: estimated D(P|Q)
    """
    verify_sample_shapes(s1, s2, k)

    n, m = len(s1), len(s2)
    d = float(s1.shape[1])
    D = np.log(m / (n - 1))

    s1_neighbourhood = NearestNeighbors(k+1, 10).fit(s1)
    s2_neighbourhood = NearestNeighbors(k, 10).fit(s2)

    for p1 in s1:
        s1_distances, indices = s1_neighbourhood.kneighbors([p1], k+1)
        s2_distances, indices = s2_neighbourhood.kneighbors([p1], k)
        rho = s1_distances[0][-1]
        nu = s2_distances[0][-1]
        D += (d/n)*np.log(nu/rho)
    return D

def calculate_loss(batch, n_data, a, y, y_a, y_translate):
  loss = [0,0,0]
  y_norm = torch.pdist(y)
  n_data_pairwise_1 = []
  n_data_pairwise_2 = []
  for i in range(len(batch)):
      for j in range(i+1,len(batch)):
        n_data_pairwise_1.append(n_data[i])
        n_data_pairwise_2.append(n_data[j])
  n_data_pairwise_1 = torch.stack(n_data_pairwise_1)
  n_data_pairwise_2 = torch.stack(n_data_pairwise_2)
  w_norm = sinkhorn(n_data_pairwise_2, n_data_pairwise_1)
  y_a_norm = torch.pdist(y_a)
  y_translate_norm = torch.pdist(y_translate)
  loss[0] = (y_norm - w_norm).abs().sum()
  loss[1] = ((y_a_norm - a * y_norm) ** 2).sum()
  loss[2] = ((y_translate_norm - y_norm) ** 2).sum()
  loss = sum(loss)
  return loss

# List of all estimators
Estimators = [naive_estimator, scipy_estimator, skl_estimator]

In [13]:
class SinkhornDistance(nn.Module):
    r"""
    Given two empirical measures each with :math:`P_1` locations
    :math:`x\in\mathbb{R}^{D_1}` and :math:`P_2` locations :math:`y\in\mathbb{R}^{D_2}`,
    outputs an approximation of the regularized OT cost for point clouds.
    Args:
        eps (float): regularization coefficient
        max_iter (int): maximum number of Sinkhorn iterations
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
            'mean': the sum of the output will be divided by the number of
            elements in the output, 'sum': the output will be summed. Default: 'none'
    Shape:
        - Input: :math:`(N, P_1, D_1)`, :math:`(N, P_2, D_2)`
        - Output: :math:`(N)` or :math:`()`, depending on `reduction`
    """
    def __init__(self, eps, max_iter, reduction='none'):
        super(SinkhornDistance, self).__init__()
        self.eps = eps
        self.max_iter = max_iter
        self.reduction = reduction

    def forward(self, x, y):
        # The Sinkhorn algorithm takes as input three variables :
        C = self._cost_matrix(x, y)  # Wasserstein cost function
        x_points = x.shape[-2]
        y_points = y.shape[-2]
        if x.dim() == 2:
            batch_size = 1
        else:
            batch_size = x.shape[0]

        # both marginals are fixed with equal weights
        mu = torch.empty(batch_size, x_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / x_points).to(device).squeeze()
        nu = torch.empty(batch_size, y_points, dtype=torch.float,
                         requires_grad=False).fill_(1.0 / y_points).to(device).squeeze()

        u = torch.zeros_like(mu).to(device)
        v = torch.zeros_like(nu).to(device)
        # To check if algorithm terminates because of threshold
        # or max iterations reached
        actual_nits = 0
        # Stopping criterion
        thresh = 1e-1

        # Sinkhorn iterations
        for i in range(self.max_iter):
            u1 = u  # useful to check the update
            u = self.eps * (torch.log(mu+1e-8) - torch.logsumexp(self.M(C, u, v), dim=-1)) + u
            v = self.eps * (torch.log(nu+1e-8) - torch.logsumexp(self.M(C, u, v).transpose(-2, -1), dim=-1)) + v
            err = (u - u1).abs().sum(-1).mean()

            actual_nits += 1
            if err.item() < thresh:
                break

        U, V = u, v
        # Transport plan pi = diag(a)*K*diag(b)
        pi = torch.exp(self.M(C, U, V))
        # Sinkhorn distance
        cost = torch.sum(pi * C, dim=(-2, -1))

        if self.reduction == 'mean':
            cost = cost.mean()
        elif self.reduction == 'sum':
            cost = cost.sum()

      #  return cost, pi, C
        return cost

    def M(self, C, u, v):
        "Modified cost for logarithmic updates"
        "$M_{ij} = (-c_{ij} + u_i + v_j) / \epsilon$"
        return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / self.eps

    @staticmethod
    def _cost_matrix(x, y, p=2):
        "Returns the matrix of $|x_i-y_j|^p$."
        x_col = x.unsqueeze(-2)
        y_lin = y.unsqueeze(-3)
        C = torch.sum((torch.abs(x_col - y_lin)) ** p, -1)
        return C

    @staticmethod
    def ave(u, u1, tau):
        "Barycenter subroutine, used by kinetic acceleration through extrapolation."
        return tau * u + (1 - tau) * u1

In [14]:
sinkhorn = SinkhornDistance(eps=0.1, max_iter=100, reduction=None).to(device)

In [15]:
class MyDataset(Dataset):
    def __init__(self, data, transform=None):
        self.data = data.float()
        
        self.transform = transform

    def __getitem__(self, index):
        x = self.data[index]
        
        if self.transform:
            x = self.transform(x)
           
        return x

    def __len__(self):
        return len(self.data)
    

In [16]:
dataset = MyDataset(set_dist)
loader = DataLoader(dataset, batch_size = 12, shuffle = True)


In [17]:
model = DeepSet(2, 36).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# checkpoint = torch.load('normal_2D_2condition1.pt')
# model.load_state_dict(checkpoint['model_state_dict'])
# optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
# loss = checkpoint['loss']

model.train()

DeepSet(
  (feature_extractor): Sequential(
    (0): Linear(in_features=2, out_features=50, bias=True)
    (1): ELU(alpha=1.0, inplace=True)
    (2): Linear(in_features=50, out_features=100, bias=True)
    (3): ELU(alpha=1.0, inplace=True)
    (4): Linear(in_features=100, out_features=36, bias=True)
  )
  (regressor): Sequential(
    (0): Linear(in_features=36, out_features=30, bias=True)
    (1): ELU(alpha=1.0, inplace=True)
    (2): Linear(in_features=30, out_features=30, bias=True)
    (3): ELU(alpha=1.0, inplace=True)
    (4): Linear(in_features=30, out_features=10, bias=True)
    (5): ELU(alpha=1.0, inplace=True)
    (6): Linear(in_features=10, out_features=2, bias=True)
  )
)

Wasserstein distance has the following properties: 
1) W(aX,aY) = |a|W(X,Y)
2) W(X+x, Y+x) = W(X,Y)

Only implement these properties

In [18]:
np.random.seed(6345)
torch.manual_seed(6345)
num_epochs = 500
running_loss = []

for t in range(num_epochs):
    for n_batch, batch in enumerate(loader):
        n_data = Variable(batch.to(device), requires_grad=True)
        a = torch.rand(1).to(device)
        b = torch.rand(2).to(device)
       
    
        optimizer.zero_grad()
        y = model(n_data)
        y_a = model(a*n_data)
        y_translate = model(n_data + b)
        
        loss = calculate_loss(batch, n_data, a, y, y_a, y_translate)
        
        loss = loss/(len(batch)*(len(batch)-1)/2)
        
       
        loss.backward()
    
        optimizer.step()
    
        
    running_loss.append(loss)
    print(loss)
   
   

tensor(0.1534, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2441, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2164, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.3078, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2485, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2501, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2621, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2232, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2564, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.1616, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2856, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2004, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.1936, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2027, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2791, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.1717, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2849, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.2188, device='cuda:0', grad_fn=<DivBack

In [19]:
# 196+41 epochs in
torch.save({
           
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss
            
            }, 'normal_2D_2condition1.pt')

In [20]:
len(running_loss)

500

In [21]:
running_loss

[tensor(0.3933, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.1872, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2383, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2933, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2532, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.1736, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2241, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2035, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2803, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2068, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2478, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2558, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.1811, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2434, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2536, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.2752, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.3033, device='cuda:0', grad_fn=<DivBackward0>),
 tensor(0.1932

In [None]:
#Test ground truth
#Cov mat_1 = ID, Cov mat_2 = [[1,.5], [.5,1]], m_1 = (0,0) , m_2 = (0,1)
#Real Wass dist^2 = ||m_1 - m_2||^2 + (4-\sqrt(2)-\sqrt(6))

In [34]:
np.random.seed(0)
torch.manual_seed(0)
m = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(2), torch.eye(2))

In [35]:
m1 = m.sample([250]).view(1,-1,2).to(device)
m2 = m.sample([250]).view(1,-1,2).to(device)

In [36]:
 n = torch.distributions.multivariate_normal.MultivariateNormal(torch.tensor([0.0, 1.0]), torch.tensor([[1,.5],[.5,1]]))

In [37]:
n1 = n.sample([250]).view(1,-1,2).to(device)
n2 = n.sample([250]).view(1,-1,2).to(device)

In [38]:
model(m1)

tensor([[-2.4903, -5.9006]], device='cuda:0', grad_fn=<AddmmBackward>)

In [39]:
model(m2)

tensor([[-2.2666, -6.0188]], device='cuda:0', grad_fn=<AddmmBackward>)

In [40]:
model(m1*.5)

tensor([[-2.4615, -6.1914]], device='cuda:0', grad_fn=<AddmmBackward>)

In [41]:
model(n1*.5)

tensor([[-1.7539, -5.7813]], device='cuda:0', grad_fn=<AddmmBackward>)

In [42]:
model(n1)

tensor([[-1.0995, -5.1148]], device='cuda:0', grad_fn=<AddmmBackward>)

In [None]:
#calculated distance = 1.336, scaling by .5 get distance to be .7 and moving them around got 1.323

In [43]:
model(m1+.8)

tensor([[-1.9079, -4.2803]], device='cuda:0', grad_fn=<AddmmBackward>)

In [44]:
model(n1+.8)

tensor([[-0.5489, -3.4706]], device='cuda:0', grad_fn=<AddmmBackward>)

In [45]:
sinkhorn(m1+.5, n1+.5)

tensor([1.4789], device='cuda:0')

In [46]:
sinkhorn(m1,n1)

tensor([1.4789], device='cuda:0')

In [47]:
sinkhorn(m1*.5, n1*.5)

tensor([0.4361], device='cuda:0')