In [None]:
import torch
import math
from itertools import combinations
# Compute device for every tensor created below: the GPU when CUDA is available, else the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def L2discrepancy(x):
    """L2 star discrepancy of each point set in a batch, via Warnock's closed-form formula.

    D^2 = 3^-d - (2^(1-d) / N) * sum_i prod_k (1 - x_ik^2)
               + (1 / N^2) * sum_{i,j} prod_k (1 - max(x_ik, x_jk))

    Args:
        x: tensor indexed as (batch, point, coordinate). The formula is meant for
           coordinates in [0, 1]; NOTE(review): this is not checked here.

    Returns:
        1-D tensor holding one discrepancy value (D, not D^2) per batch entry.
    """
    n_points, n_dims = x.size(1), x.size(2)

    # sum over points of prod_k (1 - x_ik^2)
    single_term = (1. - x ** 2.).prod(dim=2).sum(dim=1)

    # sum over all ordered point pairs (i, j) of prod_k (1 - max(x_ik, x_jk))
    pair_max = torch.maximum(x.unsqueeze(2), x.unsqueeze(1))
    pair_term = (1 - pair_max).prod(dim=3).sum(dim=(1, 2))

    squared = (math.pow(3., -n_dims)
               - (1. / n_points) * math.pow(2., 1. - n_dims) * single_term
               + 1. / math.pow(n_points, 2.) * pair_term)
    return torch.sqrt(squared)

def hickernell_all_emphasized(x, dim_emphasize):
    """Sum of L2 star discrepancies over every coordinate projection of the emphasized sizes.

    For each size ``d`` in ``dim_emphasize``, every one of the C(dim, d) subsets of
    coordinates is enumerated and the discrepancy of ``x`` restricted to that subset
    is added. This is the exhaustive counterpart of the randomized
    ``MPMC_net.approx_hickernell`` emphasized term.

    Args:
        x: tensor indexed as (batch, point, coordinate).
        dim_emphasize: iterable of projection sizes (ints, or 0-dim integer tensors).

    Returns:
        1-D tensor with one value per batch entry, on the same device and with the
        same dtype as ``x``. Note that it is a *sum* over projections, not a mean.
    """
    nbatch, dim = x.size(0), x.size(2)
    # Allocate on x's own device/dtype rather than the global `device`, so this also
    # works for tensors living elsewhere (e.g. CPU copies) and keeps x's precision.
    total = x.new_zeros(nbatch)
    for d in dim_emphasize:
        for subset in combinations(range(dim), int(d)):
            # list index => advanced indexing along the coordinate axis
            total += L2discrepancy(x[:, :, list(subset)])
    return total

In [None]:
# Report which device (GPU or CPU) the tensors in the following cells are created on.
print(str(device))

cuda


In [None]:
!pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
!pip install torch-geometric

Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
Collecting torch-scatter
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m108.0/108.0 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: torch-scatter
  Building wheel for torch-scatter (setup.py) ... [?25l[?25hdone
  Created wheel for torch-scatter: filename=torch_scatter-2.1.2-cp311-cp311-linux_x86_64.whl size=3622730 sha256=b0d556722d956772278b18db8c484b1bd055f1875f352a57e38f302cf98ba459
  Stored in directory: /root/.cache/pip/wheels/b8/d4/0e/a80af2465354ea7355a2c153b11af2da739cfcf08b6c0b28e2
Successfully built torch-scatter
Installing collected packages: torch-scatter
Successfully installed torch-scatter-2.1.2
Looking in links: https://pytorch-geometric.com/whl/torch-1.9.0+cu102.html
Collecting torch-sparse
  Downloading torch_sparse-0.6.18.tar.gz (20

In [None]:
import math
from torch import nn
from torch_cluster import radius_graph
from torch_geometric.nn import MessagePassing, InstanceNorm

class MPNN_layer(MessagePassing):
    """One message-passing layer followed by InstanceNorm.

    message: two Linear+ReLU stages applied to the concatenated endpoint features
             (x_i, x_j), using PyG's per-edge endpoint naming.
    aggregate: MessagePassing's default (no ``aggr`` argument is passed).
    update: two Linear+ReLU stages applied to [node features, aggregated message].
    The result is normalized with ``InstanceNorm``, given the ``batch`` vector.
    """

    def __init__(self, ninp, nhid):
        super().__init__()
        self.ninp = ninp
        self.nhid = nhid

        def linear_relu(n_in, n_out):
            # Linear followed by ReLU, wrapped in a Sequential (submodules `.0` / `.1`).
            return nn.Sequential(nn.Linear(n_in, n_out), nn.ReLU())

        # message MLP: input is both endpoint feature vectors concatenated
        self.message_net_1 = linear_relu(2 * ninp, nhid)
        self.message_net_2 = linear_relu(nhid, nhid)
        # update MLP: input is the node's own features plus its aggregated message
        self.update_net_1 = linear_relu(ninp + nhid, nhid)
        self.update_net_2 = linear_relu(nhid, nhid)
        self.norm = InstanceNorm(nhid)

    def forward(self, x, edge_index, batch):
        propagated = self.propagate(edge_index, x=x)
        return self.norm(propagated, batch)

    def message(self, x_i, x_j):
        # parameter names x_i / x_j are read by PyG's signature inspection; keep them
        endpoints = torch.cat((x_i, x_j), dim=-1)
        return self.message_net_2(self.message_net_1(endpoints))

    def update(self, message, x):
        # `message` holds the aggregated messages; `x` is forwarded from propagate()
        combined = torch.cat((x, message), dim=-1)
        return self.update_net_2(self.update_net_1(combined))


class MPMC_net(nn.Module):
    """Message-passing network that transforms random points into low-discrepancy point sets.

    ``nbatch`` independent sets of ``nsamples`` uniform points in [0, 1]^dim are drawn
    once at construction and connected by a fixed radius graph. ``forward`` maps them
    through a linear encoder, ``nlayers`` ``MPNN_layer``s and a sigmoid-squashed linear
    decoder, and returns the mean discrepancy loss together with the transformed points.

    Args:
        dim: dimension of the points.
        nhid: hidden width of the encoder and MPNN layers.
        nlayers: number of ``MPNN_layer`` message-passing layers.
        nsamples: points per set.
        nbatch: number of independent point sets.
        radius: connection radius for the graph built on the initial random points.
        loss_fn: ``'L2'`` (L2 star discrepancy) or ``'approx_hickernell'``
            (randomized projections, see ``approx_hickernell``).
        dim_emphasize: projection sizes (each in 1..dim) sampled separately by
            ``approx_hickernell``. Required for that loss; may be ``None`` for ``'L2'``.
        n_projections: projection pairs sampled per ``approx_hickernell`` evaluation.

    Raises:
        ValueError: unknown ``loss_fn``, or missing / out-of-range ``dim_emphasize``
            when ``loss_fn == 'approx_hickernell'``.
    """

    def __init__(self, dim, nhid, nlayers, nsamples, nbatch, radius, loss_fn, dim_emphasize, n_projections):
        super(MPMC_net, self).__init__()

        # Validate the configuration before building anything. An unknown loss must
        # fail here rather than as a missing `self.loss_fn` inside forward(). A None
        # dim_emphasize is allowed because the 'L2' loss never reads it.
        if loss_fn not in ('L2', 'approx_hickernell'):
            raise ValueError('Loss function not implemented: ' + repr(loss_fn))
        emphasized = None if dim_emphasize is None else torch.tensor(dim_emphasize).long()
        if loss_fn == 'approx_hickernell':
            if emphasized is None or emphasized.numel() == 0:
                raise ValueError("loss_fn='approx_hickernell' requires a non-empty dim_emphasize")
            # sizes are 1-based; 0 would silently wrap to the last index below
            if int(emphasized.min()) < 1 or int(emphasized.max()) > dim:
                raise ValueError('dim_emphasize entries must lie in 1..' + str(dim))

        self.enc = nn.Linear(dim, nhid)
        self.convs = nn.ModuleList()
        for i in range(nlayers):
            self.convs.append(MPNN_layer(nhid, nhid))
        self.dec = nn.Linear(nhid, dim)
        self.nlayers = nlayers
        self.mse = torch.nn.MSELoss()  # NOTE(review): not used by any method of this class
        self.nbatch = nbatch
        self.nsamples = nsamples
        self.dim = dim
        self.n_projections = n_projections
        self.dim_emphasize = emphasized

        ## random input points for transformation:
        self.x = torch.rand(nsamples * nbatch, dim).to(device)
        # graph id of every point: nsamples consecutive points per set
        batch = torch.arange(nbatch).unsqueeze(-1).to(device)
        batch = batch.repeat(1, nsamples).flatten()
        self.batch = batch
        # NOTE(review): no max_num_neighbors is passed, so torch_cluster's default
        # neighbour cap applies; for dense point sets some neighbours within `radius`
        # may be dropped -- confirm this is intended.
        self.edge_index = radius_graph(self.x, r=radius, loop=True, batch=batch).to(device)

        if loss_fn == 'L2':
            self.loss_fn = self.L2discrepancy
        else:
            self.loss_fn = self.approx_hickernell

    def approx_hickernell(self, X):
        """Randomized projection loss: sum of discrepancies over sampled coordinate subsets.

        In each of ``n_projections`` rounds, one subset whose size is NOT in
        ``dim_emphasize`` and one whose size IS in ``dim_emphasize`` are drawn. Each
        size is uniform over the allowed sizes, and the coordinates are the first
        entries of a random permutation. The L2 star discrepancy of ``X`` restricted
        to each subset is added. If every size is emphasized, only the emphasized
        draw is made.

        Args:
            X: points, viewed as (nbatch, nsamples, dim).

        Returns:
            1-D tensor with one summed value per batch entry (same device/dtype as X).
        """
        X = X.view(self.nbatch, self.nsamples, self.dim)
        disc_projections = X.new_zeros(self.nbatch)

        # Loop-invariant: the candidate projection sizes (1-based).
        all_sizes = torch.arange(1, self.dim + 1)
        mask = torch.ones(self.dim, dtype=torch.bool)
        mask[self.dim_emphasize - 1] = False
        other_sizes = all_sizes[mask]
        emphasized_sizes = all_sizes[self.dim_emphasize - 1]

        for _ in range(self.n_projections):
            ## sample among non-emphasized dimensionality (skipped if none remain)
            if other_sizes.numel() > 0:
                disc_projections += self._random_projection_disc(X, other_sizes)
            ## sample among emphasized dimensionality
            disc_projections += self._random_projection_disc(X, emphasized_sizes)

        return disc_projections

    def _random_projection_disc(self, X, sizes):
        """Discrepancy of X on a random coordinate subset whose size is drawn uniformly from `sizes`."""
        projection_dim = sizes[torch.randint(low=0, high=sizes.size(0), size=(1,))].item()
        projection_indices = torch.randperm(self.dim)[:projection_dim]
        return self.L2discrepancy(X[:, :, projection_indices])

    def L2discrepancy(self, x):
        """L2 star discrepancy per batch entry via Warnock's formula.

        Expects ``x`` indexed as (batch, point, coordinate) with coordinates in [0, 1].
        Returns a 1-D tensor with one value per batch entry.
        """
        N = x.size(1)
        dim = x.size(2)
        prod1 = 1. - x ** 2.
        prod1 = torch.prod(prod1, dim=2)
        sum1 = torch.sum(prod1, dim=1)
        pairwise_max = torch.maximum(x[:, :, None, :], x[:, None, :, :])
        product = torch.prod(1 - pairwise_max, dim=3)
        sum2 = torch.sum(product, dim=(1, 2))
        one_dive_N = 1. / N
        out = torch.sqrt(math.pow(3., -dim) - one_dive_N * math.pow(2., 1. - dim) * sum1 + 1. / math.pow(N, 2.) * sum2)
        return out

    def forward(self):
        """Transform the stored random points; return (mean loss over the batch, points of shape (nbatch, nsamples, dim))."""
        X = self.x
        edge_index = self.edge_index

        X = self.enc(X)
        for i in range(self.nlayers):
            X = self.convs[i](X, edge_index, self.batch)
        X = torch.sigmoid(self.dec(X))  ## clamping with sigmoid needed so that warnock's formula is well-defined
        X = X.view(self.nbatch, self.nsamples, self.dim)
        loss = torch.mean(self.loss_fn(X))
        return loss, X

In [None]:
#from models import *
#import torch
import torch.optim as optim
import numpy as np
from pathlib import Path
import argparse
#from utils import L2discrepancy, hickernell_all_emphasized

# Compute device, same definition as in the first cell: GPU when CUDA is available, else CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def train(args):
    """Train an MPMC_net configured by ``args`` and checkpoint its best point sets.

    Every 100 epochs the current point sets are scored: with the L2 star discrepancy
    for ``'L2'``, or with the exhaustive sum over emphasized projections for
    ``'approx_hickernell'``. Whenever the best score in the batch improves, it is
    appended to ``results/dim_<dim>/nsamples_<nsamples>.txt``, and the whole batch of
    points, shape (nbatch, nsamples, dim), is written to
    ``outputs/dim_<dim>/nsamples_<nsamples>.npy``.

    From epoch ``start_reduce`` on, the learning rate is divided by 10 after every
    ``reduce_point`` non-improving evaluations. Training stops, logging the epoch,
    once the rate drops below 1e-6.

    Args:
        args: configuration object (e.g. ``Customargs``) providing ``dim, nhid, nlayers,
            nsamples, nbatch, radius, loss_fn, dim_emphasize, n_projections, lr,
            weight_decay, epochs``, and optionally ``start_reduce`` (default 100000).
            ``args`` is not modified.

    Raises:
        ValueError: if ``args.loss_fn`` is neither ``'L2'`` nor ``'approx_hickernell'``.
    """
    if args.loss_fn not in ('L2', 'approx_hickernell'):
        # Fail up front instead of reaching an unbound `batched_discrepancies` later.
        raise ValueError('Loss function not implemented: ' + repr(args.loss_fn))

    model = MPMC_net(args.dim, args.nhid, args.nlayers, args.nsamples, args.nbatch,
                     args.radius, args.loss_fn, args.dim_emphasize, args.n_projections).to(device)
    # Decay a local copy of the learning rate so the caller's `args` is left untouched.
    lr = args.lr
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=args.weight_decay)
    best_loss = float('inf')
    patience = 0

    ## could be tuned for better performance
    start_reduce = getattr(args, 'start_reduce', 100000)
    reduce_point = 10

    results_dir = Path('results') / ('dim_' + str(args.dim))
    outputs_dir = Path('outputs') / ('dim_' + str(args.dim))
    results_dir.mkdir(parents=True, exist_ok=True)
    outputs_dir.mkdir(parents=True, exist_ok=True)
    log_path = results_dir / ('nsamples_' + str(args.nsamples) + '.txt')
    points_path = outputs_dir / ('nsamples_' + str(args.nsamples) + '.npy')

    for epoch in range(args.epochs):
        model.train()
        optimizer.zero_grad()
        loss, X = model()
        loss.backward()
        optimizer.step()

        if epoch % 100 != 0:
            continue

        points = X.detach()
        if args.loss_fn == 'L2':
            batched_discrepancies = L2discrepancy(points)
        else:
            ## compute sum over all projections with emphasized dimensionality:
            batched_discrepancies = hickernell_all_emphasized(points, args.dim_emphasize)
        min_discrepancy = torch.min(batched_discrepancies).item()

        if min_discrepancy < best_loss:
            best_loss = min_discrepancy
            with open(log_path, 'a') as f:
                f.write(str(best_loss) + '\n')
            ## save MPMC points:
            np.save(points_path, points.cpu().numpy())

        if min_discrepancy > best_loss and (epoch + 1) >= start_reduce:
            patience += 1

        if (epoch + 1) >= start_reduce and patience == reduce_point:
            patience = 0
            lr /= 10.
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        if lr < 1e-6:
            with open(log_path, 'a') as f:
                f.write('### epochs: ' + str(epoch) + '\n')
            break


class Customargs:
    """Notebook stand-in for the argparse namespace consumed by ``train``.

    Every keyword argument is stored as an attribute of the same name.
    """

    def __init__(self, dim=2, nhid=128, nlayers=3, nsamples=64, nbatch=16, radius=0.35, loss_fn='L2',
                 dim_emphasize=[1], n_projections=15, lr=0.001, start_reduce=100000, epochs=200000,
                 weight_decay=1e-6):
        self.dim = dim
        self.nhid = nhid
        self.nlayers = nlayers
        self.nsamples = nsamples
        self.nbatch = nbatch
        self.radius = radius
        self.loss_fn = loss_fn
        # Copy list inputs: the default `[1]` is one list object shared by every call,
        # so storing it directly would let edits on one instance leak into others.
        self.dim_emphasize = list(dim_emphasize) if isinstance(dim_emphasize, list) else dim_emphasize
        self.n_projections = n_projections
        self.lr = lr
        # epoch from which learning-rate decay may start; stored like every other option
        self.start_reduce = start_reduce
        self.epochs = epochs
        self.weight_decay = weight_decay



# Train on point sets of 32 samples; every other option keeps its Customargs default.
args = Customargs(nsamples=32)
train(args)





In [None]:
# Second configuration: point sets of 96 samples, other options at their defaults.
args8 = Customargs(nsamples=96)
train(args8)