In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from scipy.special import lambertw

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (more slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def poly(cs, xs):
  """Evaluate the polynomial sum_i cs[i] * xs**i element-wise on a tensor.

  cs lists the coefficients in order of increasing power (cs[0] is the
  constant term). The result has the same shape, dtype and device as xs;
  an empty cs yields all zeros.
  """
  total = torch.zeros_like(xs)
  for power, coeff in enumerate(cs):
    # accumulate in place so the result keeps the dtype of xs
    total.add_(coeff * xs.pow(power))

  return total


def discMC(xs, logpts):
  """MC discriminant: a cubic in xs scaled by a quadratic in logpts."""
  shape_term = poly([0, 0.6, 1.2, 1.1], xs)
  pt_term = poly([1, 0.1, -0.01], logpts)
  return shape_term * pt_term

def discData(xs, logpts):
  """Data discriminant: a cubic in xs scaled by a quadratic in logpts."""
  shape_term = poly([0, 0.7, 1.1, 1.3], xs)
  pt_term = poly([1, -0.1, 0.02], logpts)
  return shape_term * pt_term


def logptMC(xs):
  """Map xs to log(pT) for MC: log of a quadratic in -log(xs)."""
  neglog = -torch.log(xs)
  pts = poly([25, 50, 7], neglog)
  return torch.log(pts)

def logptData(xs):
  """Map xs to log(pT) for data: log of a quadratic in -log(xs)."""
  neglog = -torch.log(xs)
  pts = poly([25, 45, 5], neglog)
  return torch.log(pts)



def bootstrap(n, dev=None):
  """Draw n random, non-negative bootstrap weights.

  Uniform samples u in [0, 1) are mapped through -W(-(1 - u)/e) - 1, where W
  is the k=-1 branch of the Lambert W function. Since W(z) * exp(W(z)) = z,
  the returned weights w satisfy (1 + w) * exp(-w) = 1 - u, i.e. this is
  inverse-CDF sampling from the distribution with CDF 1 - (1 + w) * exp(-w).

  Parameters:
    n: number of weights to draw.
    dev: torch device for the result; defaults to the notebook-wide `device`.

  Returns:
    A 1-D float64 tensor of length n on the requested device.
  """
  if dev is None:
    dev = device

  us = np.random.rand(n)
  # lambertw always returns a complex array. The arguments here lie in
  # [-1/e, 0), where the k=-1 branch is real, so keep the real part explicitly
  # (np.float was removed in NumPy 1.24; float64 is the dtype it aliased).
  ws = lambertw((us-1)/np.e, k=-1).real.astype(np.float64)
  return - torch.from_numpy(ws).to(dev) - 1


def genMC(n):
  """Generate n MC events as an (n, 2) tensor with columns (discriminant, log pT).

  Both uniform draws are shifted by a small number to avoid values of zero
  (the first one is passed through a logarithm in logptMC).
  """
  eps = 1e-5
  pt_draws = torch.rand(n, device=device) + eps
  logpts = logptMC(pt_draws)
  disc_draws = torch.rand(n, device=device) + eps
  ds = discMC(disc_draws, logpts)
  # stack as rows, then view as one (discriminant, log pT) pair per event
  return torch.stack((ds, logpts), dim=0).t()

def genData(n):
  """Generate n data events as an (n, 2) tensor with columns (discriminant, log pT).

  Both uniform draws are shifted by a small number to avoid values of zero
  (the first one is passed through a logarithm in logptData).
  """
  eps = 1e-5
  pt_draws = torch.rand(n, device=device) + eps
  logpts = logptData(pt_draws)
  disc_draws = torch.rand(n, device=device) + eps
  ds = discData(disc_draws, logpts)
  # stack as rows, then view as one (discriminant, log pT) pair per event
  return torch.stack((ds, logpts), dim=0).t()


def test(n):
  """Generate n MC and n data events and histogram them side by side.

  Three panels are drawn in one figure: pT (log-scaled y axis), log pT and
  the discriminant, each overlaying the MC and data samples.
  """
  mc_events = genMC(n).cpu().numpy()
  data_events = genData(n).cpu().numpy()
  sample_labels = ["mc", "data"]

  # (title, mc values, data values, use a log-scaled y axis)
  panels = [
      ("pT", np.exp(mc_events[:,1]), np.exp(data_events[:,1]), True),
      ("log pT", mc_events[:,1], data_events[:,1], False),
      ("discriminant", mc_events[:,0], data_events[:,0], False),
  ]

  plt.figure(figsize=(30, 10))

  for position, (title, mc_vals, data_vals, logy) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.hist([mc_vals, data_vals], bins=25, label=sample_labels)
    plt.title(title)
    if logy:
      plt.yscale("log")
    plt.legend()

  plt.show()

# Sanity-check the generators: draw one million events per sample and plot them.
test(int(1e6))

In [None]:
# Load (or reload) the TensorBoard notebook extension and open the dashboard
# on the `runs` log directory.
%reload_ext tensorboard
%tensorboard --logdir runs

In [None]:
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.autograd as autograd

# Number of extra `theta` inputs concatenated to each MC event before it is
# passed to the transport network (presumably nuisance parameters).
nps = 5

# An earlier configuration worked well: it used the same adversary, the same
# hidden layers and the same learning rates (tlr = 1e-7, alr = 5e-6) as below,
# but the transport network took only the two event features as input, i.e.
# its first layer was nn.Linear(2, 512).

# Transport network: (discriminant, log pT, nps thetas) -> one output value.
transport = nn.Sequential(
    nn.Linear(2 + nps, 512),
    nn.LeakyReLU(inplace=True),
    nn.Linear(512, 512),
    nn.LeakyReLU(inplace=True),
    nn.Linear(512, 1),
)

# Adversary network: (discriminant, log pT) -> one logit.
adversary = nn.Sequential(
    nn.Linear(2, 512),
    nn.LeakyReLU(inplace=True),
    nn.Linear(512, 512),
    nn.LeakyReLU(inplace=True),
    nn.Linear(512, 1),
)

# Learning rates of the transport and adversary optimizers.
tlr = 1e-7
alr = 5e-6

# Move both networks to the compute device.
transport.to(device)
adversary.to(device)


def tloss(xs):
  """Return the mean of the squared elements of xs as a scalar tensor."""
  squared = xs.pow(2)
  return squared.mean()


# NOTE: the adversarial loss terms are computed with
# binary_cross_entropy_with_logits in the training cell.

# One Adam optimizer per network so each can be stepped on its own, with its
# own learning rate.
toptim = torch.optim.Adam(transport.parameters(), lr=tlr)
aoptim = torch.optim.Adam(adversary.parameters(), lr=alr)


# Show the transport network's architecture as the cell output.
transport

In [None]:
from math import exp, log
from torch.nn.functional import binary_cross_entropy_with_logits

def tonp(xs):
  """Convert a tensor to a NumPy array on the host, dropping any autograd history."""
  host_tensor = xs.detach().cpu()
  return host_tensor.numpy()


def plotPtTheta(logpt, toys):
  """Draw validation plots of the transport network at one fixed log(pT).

  Every toy is turned into an MC and a data discriminant at the given logpt
  (via discMC / discData), and the transport network is evaluated on the MC
  events with all thetas at zero (nominal) and with each theta in turn set
  to +1 and -1.

  Three panels are shown:
    1. density histograms of the MC, nominal-transported and data discriminants;
    2. per-bin difference to data of the MC and nominal-transported histograms;
    3. the transport network output against the MC discriminant: black for
       the nominal thetas, one colour per theta for the +1 / -1 variations.

  Parameters:
    logpt: scalar log(pT) assigned to every toy event.
    toys: 1-D tensor of values passed as `xs` to discMC / discData.

  Returns:
    None. The figure is shown and blank lines are printed after it.
  """
  ntoys = toys.size()[0]
  zeros = torch.zeros((ntoys, nps), device=device)
  logpts = torch.ones(ntoys, device=device)*logpt

  # Discriminants are sorted so the curves in the third panel are drawn
  # left to right along the x axis.
  data = torch.stack([torch.sort(discData(toys, logpts))[0], logpts]).transpose(0, 1)
  mc = torch.stack([torch.sort(discMC(toys, logpts))[0], logpts]).transpose(0, 1)

  def shift(thetas):
    # transport-network output for every MC toy at the given thetas
    return transport(torch.cat([mc, thetas], dim=1))

  # This is pure evaluation: skip building autograd graphs that would be
  # thrown away immediately.
  with torch.no_grad():
    transporting = shift(zeros)
    nomtrans = tonp(transporting)
    nom = tonp(transporting + mc[:,0:1])

    postrans = []
    negtrans = []
    for i in range(nps):
      thetas = zeros.clone()
      thetas[:,i] = 1
      postrans.append(tonp(shift(thetas)))

      thetas = zeros.clone()
      thetas[:,i] = -1
      negtrans.append(tonp(shift(thetas)))

  datanp = tonp(data)
  mcnp = tonp(mc)

  plt.figure(figsize=(18, 6))

  # panel 1: discriminant distributions
  plt.subplot(1, 3, 1)

  rangex = (0, 5)
  rangey = (-0.75, 0.25)

  h, b, _ = plt.hist(
      [mcnp[:,0], nom[:,0], datanp[:,0]],
      bins=25,
      range=rangex,
      density=True,
      label=["mc", "nominal transported", "data"],
  )

  plt.title("discriminant distribution, (pT = %0.2f)" % exp(logpt))
  plt.xlabel("discriminant")
  plt.legend()

  # panel 2: per-bin difference to data (h[0] = mc, h[1] = transported, h[2] = data)
  plt.subplot(1, 3, 2)

  centers = (b[:-1] + b[1:]) / 2.0
  plt.plot(centers, h[0] - h[2], label="mc", linewidth=3)
  plt.plot(centers, h[1] - h[2], label="transported", linewidth=3)

  plt.ylim(-0.5, 0.5)
  plt.title("discriminant difference to data, (pT = %0.2f)" % exp(logpt))
  plt.xlabel("discriminant")
  plt.ylabel("prediction - data")
  plt.legend()

  # panel 3: transport output as a function of the MC discriminant
  plt.subplot(1, 3, 3)

  cols = ["blue", "green", "red", "orange", "magenta"]
  for i in range(nps):
    # cycle through the palette so nps may exceed the number of colours
    col = cols[i % len(cols)]
    plt.plot(mcnp[:,0], postrans[i], c=col)
    plt.plot(mcnp[:,0], negtrans[i], c=col)

  plt.plot(mcnp[:,0], nomtrans, c="black")

  plt.xlim(rangex)
  plt.ylim(rangey)
  plt.title("discriminant transport, (pT = %0.2f)" % exp(logpt))
  plt.xlabel("mc discriminant")
  plt.ylabel("transport vector")

  plt.subplots_adjust(wspace=0.4)

  plt.show()

  print("\n\n")



# Weight of the adversarial term in the transport network's loss.
lam = 500


# toys for validation samples
nval = 2**15
valtoys = torch.rand(nval, device=device)


nepochs = 2**20
batchsize = 2**12

datasize = 2**20

# Fixed pools of events; every batch is drawn from them with replacement.
alldata = genData(datasize)
allmc = genMC(datasize)

writer = SummaryWriter()
nbatches = datasize // batchsize


# with autograd.detect_anomaly():
for epoch in range(nepochs):
  # Running sums over the batches of this epoch; divided by nbatches when logged.
  radvloss = 0
  fadvloss = 0
  tadvloss = 0
  ttransloss = 0
  realavg = 0
  fakeavg = 0

  for batch in range(nbatches):
    # Per-event bootstrap weights (applied to the real-data term only) and
    # random thetas for the transport network.
    straps = bootstrap(batchsize).unsqueeze(1)
    thetas = torch.randn((batchsize, nps), device=device)

    data = alldata[torch.randint(alldata.size()[0], size=(batchsize,), device=device)]
    mc = allmc[torch.randint(allmc.size()[0], size=(batchsize,), device=device)]

    # Transport network input: MC event features followed by the thetas.
    mcin = torch.cat([mc, thetas], dim=1)

    # ---- adversary update: real data -> 1, transported MC -> 0 ----
    toptim.zero_grad()
    aoptim.zero_grad()

    real = adversary(data)

    # Only the adversary is stepped here, so the transport network is
    # evaluated without autograd; its gradients would be discarded anyway.
    with torch.no_grad():
      transported = transport(mcin) + mc[:,0:1]
    fake = adversary(torch.cat([transported, mc[:,1:]], dim=1))

    realavg += torch.mean(real).item()
    fakeavg += torch.mean(fake).item()

    realloss = \
      binary_cross_entropy_with_logits(
          real
        , torch.ones_like(real)
        , weight=straps
        , reduction='mean'
        )

    radvloss += realloss.item()

    fakeloss = \
      binary_cross_entropy_with_logits(
          fake
        , torch.zeros_like(fake)
        , reduction='mean'
        )

    fadvloss += fakeloss.item()

    loss = realloss + fakeloss

    loss.backward()
    aoptim.step()


    # ---- transport update: make transported MC look like data (-> 1) ----
    # The backward pass below also deposits gradients on the adversary; they
    # are cleared by aoptim.zero_grad() at the start of the next batch.
    toptim.zero_grad()
    aoptim.zero_grad()

    transporting = transport(mcin)
    transported = transporting + mc[:,0:1]
    fake = adversary(torch.cat([transported, mc[:,1:]], dim=1))

    # Mean squared size of the transport; logged below but currently not
    # part of the optimized loss.
    transloss = tloss(transporting)
    ttransloss += transloss.item()

    advloss = \
      lam * \
      binary_cross_entropy_with_logits(
          fake
        , torch.ones_like(fake)
        , reduction='mean'
        )

    tadvloss += advloss.item()

    loss = advloss # transloss + advloss

    loss.backward()
    toptim.step()


  # write tensorboard info once per epoch
  writer.add_scalar('radvloss', radvloss / nbatches, epoch)
  writer.add_scalar('fadvloss', fadvloss / nbatches, epoch)
  writer.add_scalar('tadvloss', tadvloss / nbatches, epoch)
  writer.add_scalar('ttransloss', ttransloss / nbatches, epoch)
  writer.add_scalar('realavg', realavg / nbatches, epoch)
  writer.add_scalar('fakeavg', fakeavg / nbatches, epoch)


  # make validation plots once per epoch
  for pt in (25, 250, 500, 1000):
    plotPtTheta(log(pt), valtoys)


In [None]:
# Uncomment the next line to delete the `runs` TensorBoard log directory.
# !rm -r runs