In [None]:
!pip install pykeops
!pip install geomloss[full]
!pip install pot

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
from random import choices
import torch

import geomloss
import pykeops
from pykeops.torch import generic_sum
import pickle as pkl

from google.colab import drive
drive.mount('/content/drive')
import sys, os
project_path = '/content/drive/MyDrive'
sys.path.append(project_path)

import warnings

import time
from synthetic_utils import *
from utils_torch import cost_mat

import ot

from absl import flags, app


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
from tqdm import tqdm

In [None]:
# Make sure the folders that receive the figures and the checkpoints exist.
for name in ['imgs', 'generated_samples']:
  # exist_ok=True makes the cell safely re-runnable without a separate existence check.
  os.makedirs(f'{project_path}/{name}', exist_ok=True)

In [None]:
def get_tv(x1, x2, n_pts=40):
    """Estimate the total-variation distance between two 2-D point clouds.

    Both clouds are binned on a regular grid over [-2, 2] x [-2, 2] whose cells
    have side ``4 / n_pts``. The result is half the L1 distance between the two
    normalized histograms.

    Args:
        x1: array of shape (n, 2); every coordinate must lie strictly inside (-2, 2).
        x2: array of shape (m, 2); coordinates outside [-2, 2] are clipped onto
            the border of the grid before binning.
        n_pts: number of grid cells per axis.

    Returns:
        The TV estimate, a float between 0 (identical histograms) and 1
        (histograms with disjoint support).

    Raises:
        ValueError: if any coordinate of ``x1`` has absolute value >= 2.
    """
    if np.max(np.abs(x1)) >= 2:
        raise ValueError("argument 1 of get_tv has to be in the open (-2, 2) square")
    factor = n_pts / 4

    def _histogram(points):
        # Map each point to its integer grid cell, then count the points per occupied cell.
        cells = ((points + 2) * factor).astype(int)
        cells, counts = np.unique(cells, return_counts=True, axis=0)
        # One extra row/column holds points that sit exactly on the upper border (+2).
        hist = np.zeros((n_pts + 1, n_pts + 1))
        hist[cells[:, 0], cells[:, 1]] = counts / np.sum(counts)
        return hist

    hist1 = _histogram(x1)
    hist2 = _histogram(np.clip(x2, -2, 2))

    return (np.abs(hist1 - hist2)).sum() / 2

In [None]:
from pykeops.torch import generic_sum
# `transfer`: for every line index i, sums
#   Exp((F_i + G_j - SqDist(X_i, Y_j) / 2) / E) * L_j
# over the column index j.
transfer = generic_sum(
    "Exp( (F_i + G_j - IntInv(2)*SqDist(X_i,Y_j)) / E ) * L_j",  # summand; IntInv(2) is the constant 1/2
    "Lab = Vi(2)",  # Output:  one vector of size 2 per line
    "E   = Pm(1)",  # 1st arg: a scalar parameter, the temperature
    "X_i = Vi(2)",  # 2nd arg: one 2d-point per line
    "Y_j = Vj(2)",  # 3rd arg: one 2d-point per column
    "F_i = Vi(1)",  # 4th arg: one scalar value per line
    "G_j = Vj(1)",  # 5th arg: one scalar value per column
    "L_j = Vj(2)",  # 6th arg: one vector of size 2 per column
)

# `loss_l1`: for every line index i, sums
#   Exp((F_i + G_j - |X_i - Y_j|_1) / E) * |A_i - Y_j|_1
# over the column index j.
loss_l1 = generic_sum(
    "Exp( (F_i + G_j - Sum(Abs(X_i - Y_j))) / E ) * Sum(Abs(A_i-Y_j))",  # summand: kernel weight times the L1 distance from A_i to Y_j
    "Lab = Vi(2)",  # Output:  one value per line. NOTE(review): the summand is scalar-valued although this alias declares size 2 — confirm the actual output width
    "E   = Pm(1)",  # 1st arg: a scalar parameter, the temperature
    "X_i = Vi(2)",  # 2nd arg: one 2d-point per line
    "Y_j = Vj(2)",  # 3rd arg: one 2d-point per column
    "F_i = Vi(1)",  # 4th arg: one scalar value per line
    "G_j = Vj(1)",  # 5th arg: one scalar value per column
    "A_i = Vi(2)",  # 6th arg: one 2d-point per line
)


# `loss_l1_grad`: for every line index i, sums
#   Exp((F_i + G_j - |X_i - Y_j|_1) / E) * Sign(X_i - Y_j)
# over the column index j; Sign(X_i - Y_j) is the gradient of the L1 cost in X_i.
# NOTE(review): this reduction is not called anywhere in the visible part of the notebook.
loss_l1_grad = generic_sum(
    "Exp( (F_i + G_j - Sum(Abs(X_i-Y_j))) / E ) * Sign(X_i - Y_j)",  # summand: kernel weight times the sign vector
    "Lab = Vi(2)",  # Output:  one vector of size 2 per line
    "E   = Pm(1)",  # 1st arg: a scalar parameter, the temperature
    "X_i = Vi(2)",  # 2nd arg: one 2d-point per line
    "Y_j = Vj(2)",  # 3rd arg: one 2d-point per column
    "F_i = Vi(1)",  # 4th arg: one scalar value per line
    "G_j = Vj(1)",  # 5th arg: one scalar value per column
    # (no 6th argument: unlike loss_l1, this formula has no A_i term)
)

# Dense pairwise L1 distances between the rows of x and the rows of y
# (broadcast to an (n, m, dim) difference, summed over the last axis), clamped below at 1e-8.
l1_dist = lambda x, y: torch.clamp_min(torch.sum(torch.abs(
    x[:,None,:] - y[None,:,:]), axis=-1), 1e-8)
# The same cost as a formula string; the pair (l1_dist_formula, l1_dist) is what
# the notebook passes as `cost` to geomloss when params.cost == 'l1'.
l1_dist_formula = "Sum(Abs(X-Y))"


In [None]:
# Experiment configuration: start from the defaults of Opt and override the
# fields that define this run.
d = Opt()._asdict()
# Random run identifier of up to 16 digits. Drawing an integer directly avoids
# parsing the decimal expansion of a float, which fails whenever the float is
# printed in scientific notation (e.g. '4.2e-05' or '1e-05').
d['exp_id'] = int(np.random.randint(0, 10**16, dtype=np.int64))
d['shape_name'] = 'ellipsis'
d['N'] = 400000
d['eps'] = 5  # enters the noise scale below (get_sigma(eps=...) or 4/eps)
d['cost'] = 'l2'  # 'l2' or 'l1'; selects the ground cost and the noise distribution
d['n_iter'] = 500
d['scaling'] = .99
d['lr'] = 1e-3
d['use_bn']=False
d['hidden_neurons']=256
d['hidden_layers']=2
d['activation']='tanh'
# Generator inputs are drawn as base_noise * input_scale + input_mean.
d['input_scale'] = [2, 2]
d['input_mean'] = [-1, -1]
d['input_dist'] = 'uniform'
display(d)
# When True, the training loss is built from exact OT plans (POT) instead of
# the Sinkhorn potentials returned by geomloss.
use_emd = True

{'exp_id': 8973876331863215,
 'shape_name': 'ellipsis',
 'N': 400000,
 'eps': 5,
 'cost': 'l2',
 'n_iter': 500,
 'scaling': 0.99,
 'lr': 0.001,
 'use_bn': False,
 'hidden_neurons': 256,
 'hidden_layers': 2,
 'activation': 'tanh',
 'input_scale': [2, 2],
 'input_mean': [-1, -1],
 'input_dist': 'uniform',
 'load_path': None}

In [None]:
# Turn the config dict into the parameter record used by the rest of the
# notebook (its printed repr is an `Opt(...)` instance).
params = create_training_params_from_dict(d)
print('training with parameters', params)

training with parameters Opt(exp_id=8973876331863215, shape_name='ellipsis', N=400000, eps=5, cost='l2', n_iter=500, scaling=0.99, lr=0.001, use_bn=False, hidden_neurons=256, hidden_layers=2, activation='tanh', input_scale=[2.0, 2.0], input_mean=[-1.0, -1.0], input_dist='uniform', load_path=None)


In [None]:
device = torch.device("cpu")
shape_fn = make_shape_fn(params.shape_name)

# Keyword arguments shared by the two geomloss.SamplesLoss objects built at the
# end of this section. A custom (formula, function) cost is only supplied for 'l1'.
loss_args = dict(
    loss="sinkhorn",
    cost=None if params.cost != 'l1' else (l1_dist_formula, l1_dist),
    p=1 if params.cost != 'l2' else 2,
    scaling=params.scaling,
    debias=False,
    backend='multiscale',
)

model_name = f'{params.shape_name}_{params.N}_{params.exp_id}_gen'
# Two separate uniform weight vectors (every entry equal to 1/N).
a, b = (
    (torch.ones(params.N)/params.N).to(device).double()
    for _ in range(2)
)
print(f'Shape: {params.shape_name}')

# NOTE(review): `sensitivity` returned by shape_fn is not used in this cell; the
# noise scale below relies on hard-coded constants instead — confirm this is intended.
vals, sensitivity = shape_fn()
if params.cost == 'l2':
    sigma = get_sigma(eps=params.eps, delta=1e-4, sensitivity=2 * np.sqrt(2))
else:
    sigma = 4/params.eps

# With use_emd the Sinkhorn loss object gets a near-zero blur; otherwise the blur is sigma.
loss = geomloss.SamplesLoss(**loss_args, blur=1e-9 if use_emd else sigma)
loss_potentials = geomloss.SamplesLoss(**loss_args, blur=sigma, potentials=True)

if use_emd:
  # Exact-OT (POT) implementations that take over the names `loss` and `loss_potentials`.
  def loss(X, Y):
    """Mean squared distance between X and its barycentric projection onto Y.

    X and Y are cut into aligned mini-batches of ``split_size`` rows. For each
    pair an exact OT plan with uniform marginals is computed with POT, and every
    row of X is assigned the plan-weighted mean of the rows of Y. The target is
    detached, so gradients flow only through ``X``.

    Args:
        X: tensor of shape (n, dim), the points being fitted.
        Y: tensor of shape (m, dim), the reference points. ``zip`` stops at the
            shorter list of mini-batches, so the inputs are expected to have
            the same number of rows.

    Returns:
        Scalar tensor: the mean over rows of ``||X_i - target_i||^2``.
    """
    split_size = 1000
    means = []
    for x, y in tqdm(
        zip(torch.split(X, split_size), torch.split(Y, split_size)),
        total=np.ceil(X.shape[0] / split_size).astype(int)):
      C = cost_mat(x,y,params.cost)
      n_rows, n_cols = C.shape[0], C.shape[1]
      # NOTE(review): numItermax=400 is far below POT's default; confirm the
      # solver converges on 1000 x 1000 problems with this budget.
      pi =  ot.lp.emd(torch.ones(n_rows).to(C.device)/n_rows,
                      torch.ones(n_cols).to(C.device)/n_cols,
                      C, numThreads=10, numItermax=400)
      # Each row of the plan sums to its marginal 1/n_rows, so pi @ y is the
      # conditional mean scaled by 1/n_rows. Multiply by n_rows to undo that.
      cond_mean = torch.matmul(pi.float(), y.float()) * n_rows
      means.append(cond_mean)
    full_mean = torch.concatenate(means, dim = 0).detach()
    return torch.square(full_mean - X).sum(-1).mean()
  def loss_potentials(X, Y):
    """Return the dual potentials (u, v) of the exact OT problem between X and Y.

    Uses uniform marginals and the full dense cost matrix between all rows of X
    and all rows of Y (no mini-batching), so memory grows with n * m.
    """
    C = cost_mat(X,Y,params.cost)
    _, log = ot.lp.emd2(torch.ones(C.shape[0]).to(C.device)/C.shape[0],
                        torch.ones(C.shape[1]).to(C.device)/C.shape[1],
                        C, log=True)
    return log['u'], log['v']

# Input sampler for the generator: base noise of shape (n, 2), rescaled per
# coordinate as noise * input_scale + input_mean.
input_scale = np.array(params.input_scale)[None,:]
input_mean = np.array(params.input_mean)[None,:]
if params.input_dist == 'uniform':
    input_generator_fn = lambda n: (torch.rand(n,2)*input_scale+input_mean).double().to(device)
elif params.input_dist == 'normal':
    input_generator_fn = lambda n: (torch.randn(n,2)*input_scale+input_mean).double().to(device)
else:
    # Fail here with a clear message instead of a NameError at the first call below.
    raise ValueError(
        f"unknown input_dist {params.input_dist!r}; expected 'uniform' or 'normal'")

# generator & its optimizer
# Both RNGs are seeded first. The statements below run in a fixed order
# (validation inputs, generator construction, training inputs); keep that order
# to reproduce a run.
torch.manual_seed(0)
np.random.seed(0)
G_input_val = input_generator_fn(params.N)
G = Generator(
    use_bn=params.use_bn, hidden_neurons=params.hidden_neurons, hidden_layers=params.hidden_layers,
    activation=params.activation).to(device).double()

optimizer_g = torch.optim.RMSprop(G.parameters(), lr=params.lr)
G_input = input_generator_fn(params.N)


# data & its privatized version
# Draw target points (np.random.choice samples with replacement by default) and
# add noise of scale sigma: Gaussian for the 'l2' cost, Laplace otherwise.
sampled_idxs = np.random.choice(vals.shape[0], params.N if params.N!='all_random' else 10000)
Xt = vals[sampled_idxs]
Xt += np.random.normal(size=Xt.shape)*sigma if params.cost == 'l2' \
    else np.random.laplace(size=Xt.shape)*sigma

# Optionally resume from a checkpoint written by the training loop.
if params.load_path is not None:
    # NOTE: unpickling can execute arbitrary code; only load checkpoints you trust.
    with open(params.load_path, 'rb') as f:
        checkpoint = pkl.load(f)  # kept separate from the config dict `d`
    Xt = checkpoint['privatized']
    vals = checkpoint['target']
    G.load_state_dict(checkpoint['generator_state'])
    optimizer_g.load_state_dict(checkpoint['optimizer_state'])
    # The learning rate of the current configuration overrides the stored one.
    for p in optimizer_g.param_groups:
        p['lr'] = params.lr

    losses_gen = checkpoint['loss']
    # Restore the TV history as well so it stays aligned with the loss history.
    losses_tv = checkpoint.get('loss_tv', [])
    print('Loading model with params ', checkpoint['params'])
else:
    losses_gen = []
    losses_tv = []

# Build the training tensor only after the optional load, so that a resumed run
# trains against the privatized data stored in the checkpoint, not fresh noise.
xt = torch.tensor(Xt).to(device).double()

# calculate the loss between privatized data and true data as baseline
print("The loss between samples without noise and the privatized ones:",
      float(loss(torch.Tensor(vals[sampled_idxs]).to(device).double(),
                 xt.view(xt.shape[0], -1))))

Shape: ellipsis


100%|██████████| 400/400 [00:30<00:00, 13.18it/s]

The loss between samples without noise and the privatized ones: 0.2721991197714133





In [None]:

def _save_scatter(generated, alpha, lim, path, show=False):
    """Scatter privatized, target and generated points, save the figure, then close it.

    Args:
        generated: array of shape (n, 2) with the generated samples.
        alpha: opacity of the generated points.
        lim: both axes are limited to [-lim, lim].
        path: file the figure is written to.
        show: when True, the figure is also displayed before being closed.
    """
    plt.figure(figsize = (5,5))
    plt.scatter(*Xt.T, label = 'privatized', s=1, alpha = .3)
    plt.scatter(*vals.T, label = 'target', s = 10, c='k')
    plt.scatter(*(generated.T), label='generated', s=1, alpha = alpha, c='y')
    plt.legend(loc='lower right')
    plt.xlim(-lim, lim)
    plt.ylim(-lim, lim)
    plt.xticks([])
    plt.yticks([])
    plt.savefig(path)
    if show:
        plt.show()
    plt.close()


# Train the generator
for i in range(params.n_iter):
    # Single forward pass per iteration; the EMD branch back-propagates through it.
    gen_samples = G(G_input)
    if use_emd:
        start_time = time.time()
        loss_gen = loss(gen_samples, xt)
        end_time = time.time()
        start_time1 = time.time()
        loss_gen.backward()
        optimizer_g.step()
        optimizer_g.zero_grad()
        end_time1 = time.time()
        loss_gen0 = 0
        loss_diff = float(loss_gen.detach().cpu().numpy())
    else:
        # Sinkhorn potentials for the current generator output.
        start_time = time.time()
        pot1, pot2 = loss_potentials(gen_samples, xt)
        end_time = time.time()

        start_time1 = time.time()

        # Constant (detached) part of the reported loss: dual value minus the
        # surrogate evaluated at the current samples.
        loss_gen0 = (pot1.mean() + pot2.mean()).detach()

        if params.cost == 'l2':
            # Plan-weighted mean of xt for every generated sample.
            cond_mean = transfer(
                torch.Tensor([sigma**2]).type(torch.float64).to(device),
                gen_samples.detach(),
                xt,
                pot1.detach().view(-1, 1),
                pot2.detach().view(-1, 1),
                xt
                )/ params.N
            loss_gen0 -= (torch.square(gen_samples).sum(-1).mean() - \
                          2*(gen_samples*cond_mean).sum(-1).mean() + \
                          torch.square(cond_mean).sum(-1).mean()).detach()
        else:
            gen_samples_old = gen_samples
            # Plan-weighted L1 distance; the plan is frozen at the old samples
            # and only `gen_new` carries gradients.
            diff_pi_fn = lambda gen_new: (loss_l1(
                torch.Tensor([sigma]).type(torch.float64).to(device),
                gen_samples_old.detach(),
                xt,
                pot1.detach().view(-1, 1),
                pot2.detach().view(-1, 1),
                gen_new
                )/ params.N).mean()
            loss_gen0 -= diff_pi_fn(gen_samples.detach()).detach()

        loss_gen0 = float(loss_gen0.cpu().numpy())

        # One gradient step on the surrogate loss.
        gen_samples = G(G_input)
        if params.cost == 'l2':
            loss_gen = torch.square(gen_samples).sum(-1).mean() - 2*(gen_samples*cond_mean).sum(-1).mean() + torch.square(cond_mean).sum(-1).mean()
        else:
            loss_gen = diff_pi_fn(gen_samples)
        loss_gen.backward()
        optimizer_g.step()
        optimizer_g.zero_grad()
        end_time1 = time.time()
        loss_diff = float(loss_gen.detach().cpu().numpy())

    # TV is measured on the samples produced before this iteration's optimizer step.
    tv = get_tv(vals, gen_samples.detach().cpu().numpy())
    losses_gen.append(loss_diff + loss_gen0)
    losses_tv.append(tv)
    print(params.shape_name, 'iter', i+1, 'loss: ', losses_gen[-1],
          'loss calc time: ', end_time-start_time,
          'backward pass time: ', end_time1-start_time1,
          'TV: ', tv)

    # Every 10 iterations: checkpoint the run and refresh the two figures.
    if (i+1) % 10 == 0:
        with open(f'{project_path}/generated_samples/{model_name}.pkl', 'wb') as f:
            pkl.dump({
                'privatized': Xt,
                'target': vals,
                'generator_state': G.state_dict(),
                'optimizer_state': optimizer_g.state_dict(),
                'loss': losses_gen,
                'loss_tv': losses_tv,
                'params': params,
            }, f)
        new_gen_samples = G(G_input_val).cpu().detach().numpy()
        # Close-up view (also displayed inline) and a zoomed-out view (saved only).
        _save_scatter(new_gen_samples, alpha=.1, lim=3,
                      path=f'{project_path}/imgs/{model_name}.png', show=True)
        _save_scatter(new_gen_samples, alpha=.013, lim=10,
                      path=f'{project_path}/imgs/{model_name}_enlarged.png')

In [None]:
# Final figure: generated samples (from the training inputs G_input) drawn over
# the privatized data and the target points.
new_gen_samples = G(G_input).cpu().detach().numpy()
plt.figure(figsize = (5,5))
plt.scatter(*Xt.T, label = 'privatized', s=1, alpha = .3)
plt.scatter(*vals.T, label = 'target', s = 10, c='k')
plt.scatter(*(new_gen_samples.T), label='generated', s=1, alpha = .1, c='y')
plt.legend(loc='lower right')
plt.xlim(-3, 3)
plt.ylim(-3, 3)
plt.xticks([])
plt.yticks([])
plt.savefig(f'{project_path}/imgs/{model_name}.png')
# Show before closing: closing first leaves no figure to display.
plt.show()
plt.close()