In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Fall back to the CPU when no CUDA device is available so the notebook still
# runs (slowly) on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def poly(cs, xs):
  """Evaluate the polynomial sum_i cs[i] * xs**i elementwise.

  cs: coefficients, lowest order first.
  xs: tensor of evaluation points. The result starts from
      torch.zeros_like(xs), so it has xs's shape, dtype and device.
  An empty cs gives all zeros.
  """
  acc = torch.zeros_like(xs)
  for power, coeff in enumerate(cs):
    acc.add_(coeff * xs.pow(power))
  return acc


def discMC(xs, logpts):
  """Synthetic discriminant for the simulated ("mc") sample.

  The product of a cubic in xs and a quadratic in log pT (logpts). It has the
  same form as discData but slightly different coefficients.
  """
  xterm = poly([0, 0.6, 1.2, 1.1], xs)
  ptterm = poly([1, 0.1, -0.01], logpts)
  return xterm * ptterm

def discData(xs, logpts):
  """Synthetic discriminant for the "data" sample.

  The product of a cubic in xs and a quadratic in log pT (logpts). It has the
  same form as discMC but slightly different coefficients.
  """
  xterm = poly([0, 0.7, 1.1, 1.3], xs)
  ptterm = poly([1, -0.1, 0.02], logpts)
  return xterm * ptterm


def logptMC(xs):
  """Map uniform variates xs to log pT for the simulated ("mc") sample.

  pT = 25 + 50*t + 7*t**2 with t = -log(xs); xs must be strictly positive.
  """
  t = -torch.log(xs)
  return torch.log(poly([25, 50, 7], t))

def logptData(xs):
  """Map uniform variates xs to log pT for the "data" sample.

  pT = 25 + 45*t + 5*t**2 with t = -log(xs); xs must be strictly positive.
  """
  t = -torch.log(xs)
  return torch.log(poly([25, 45, 5], t))


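# torch.rand samples from [0, 1); add a small offset so that no sample is exactly
# zero (the first set of samples goes through a log in logptMC/logptData).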
def genMC(n):
  """Draw n simulated ("mc") events.

  Returns an (n, 2) tensor on `device` whose rows are [log pT, discriminant].
  """
  eps = 1e-5  # keep the uniform samples strictly positive
  logpts = logptMC(torch.rand(n, device=device) + eps)
  ds = discMC(torch.rand(n, device=device) + eps, logpts)
  return torch.stack([logpts, ds]).T

def genData(n):
  """Draw n "data" events.

  The procedure is the same as genMC, but uses logptData/discData. Returns an
  (n, 2) tensor on `device` whose rows are [log pT, discriminant].
  """
  eps = 1e-5  # keep the uniform samples strictly positive
  logpts = logptData(torch.rand(n, device=device) + eps)
  ds = discData(torch.rand(n, device=device) + eps, logpts)
  return torch.stack([logpts, ds]).T


# TEST


def test(n):
  """Sanity-check plot of n generated mc and data events, side by side.

  Panels: pT (log-scaled counts), log pT, and the discriminant.
  """
  mc = genMC(n).cpu().numpy()
  data = genData(n).cpu().numpy()
  labels = ["mc", "data"]

  fig, (ax_pt, ax_logpt, ax_disc) = plt.subplots(1, 3, figsize=(30, 10))

  # column 0 holds log pT, so exponentiate to recover pT
  ax_pt.hist([np.exp(mc[:, 0]), np.exp(data[:, 0])], bins=25, label=labels)
  ax_pt.set(title="pT", xlabel="pT [GeV]", ylabel="count", yscale="log")
  ax_pt.legend()

  ax_logpt.hist([mc[:, 0], data[:, 0]], bins=25, label=labels)
  ax_logpt.set(title="log pT", xlabel="log pT", ylabel="count")
  ax_logpt.legend()

  ax_disc.hist([mc[:, 1], data[:, 1]], bins=25, label=labels)
  ax_disc.set(title="discriminant", xlabel="discriminant", ylabel="count")
  ax_disc.legend()

  plt.show()
  plt.close(fig)


test(10**6)

In [None]:
%reload_ext tensorboard
%tensorboard --logdir runs

In [None]:
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import torch.autograd as autograd


def _mlp(width=512):
  """Build a fully connected net 2 -> width -> width -> 1 with LeakyReLU activations.

  Input rows are [log pT, discriminant]. The single output is used either as
  the discriminant shift (transport) or as a real/fake logit (adversary).
  """
  return nn.Sequential(
      nn.Linear(2, width)
    , nn.LeakyReLU(inplace=True)
    , nn.Linear(width, width)
    , nn.LeakyReLU(inplace=True)
    , nn.Linear(width, 1)
    )


# A width-256 variant of both networks (_mlp(256)) with the same learning
# rates also worked well.
transport = _mlp(512)
adversary = _mlp(512)

# learning rates for the transport and adversary optimizers
tlr = 1e-7
alr = 5e-6

transport.to(device)
adversary.to(device)


def tloss(xs):
  """Transport-size penalty: the mean of the squared entries of xs."""
  return xs.pow(2).mean()

# Adversary loss on raw logits (sigmoid + binary cross-entropy), averaged over the batch.
aloss = nn.BCEWithLogitsLoss(reduction='mean')


# One Adam optimizer per network, each with its own learning rate.
toptim = torch.optim.Adam(params=transport.parameters(), lr=tlr)
aoptim = torch.optim.Adam(params=adversary.parameters(), lr=alr)


# show the transport architecture as the cell output
transport

In [None]:
from math import log
import torch.utils.data as d

def tonp(xs):
  """Convert tensor xs to a NumPy array, detached from autograd and on the CPU.

  For a tensor that is already on the CPU the array shares memory with xs.
  """
  return xs.detach().cpu().numpy()


# Weight of the adversarial term relative to the transport-size penalty (tloss)
# in the transport network's loss.
lam = 500


# Validation samples at fixed pT = 25 and 500 GeV (the networks take log pT).
nval = 2**20

pt25 = torch.full((nval,), log(25), device=device)
pt500 = torch.full((nval,), log(500), device=device)

# Data discriminants, kept as NumPy arrays for plotting.
# NOTE(review): the offset here is 1e-3, while genMC/genData use 1e-5.
valxs = torch.rand(nval, device=device) + 1e-3
datads_pt25, datads_pt500 = (tonp(discData(valxs, pt)) for pt in (pt25, pt500))

# MC discriminants use a fresh set of uniform variates. The (nval, 2)
# [log pT, discriminant] tensors are what the transport network consumes.
valxs = torch.rand(nval, device=device) + 1e-3
mcds_pt25 = discMC(valxs, pt25)
mcds_pt500 = discMC(valxs, pt500)

histmc_pt25 = torch.stack((pt25, mcds_pt25)).T
histmc_pt500 = torch.stack((pt500, mcds_pt500)).T

mcds_pt25, mcds_pt500 = tonp(mcds_pt25), tonp(mcds_pt500)


# Training configuration.
nepochs = 10**6
batchsize = 2**10
datasize = 2**20
nbatches = datasize // batchsize

# Fixed pools of events; minibatches are resampled from these at every step.
alldata = genData(datasize)
allmc = genMC(datasize)

writer = SummaryWriter()

def _transport_mc(mc):
  """Apply the transport network to a batch of [log pT, discriminant] rows.

  Returns (transporting, shifted):
  - transporting is the network output, i.e. the shift added to the discriminant;
  - shifted is a new tensor equal to mc with its discriminant column shifted
    by that amount, ready to be fed to the adversary.
  """
  transporting = transport(mc)
  transported = transporting + mc[:, 1:]
  return transporting, torch.cat([mc[:, :1], transported], dim=1)


def _adversary_step(data, mc):
  """Run one adversary update: data is labelled real (1), transported MC fake (0).

  Returns (real loss, fake loss, mean real logit, mean fake logit) as floats.
  """
  aoptim.zero_grad()

  real = adversary(data)

  # The transport network is not trained in this step, so skip recording its
  # graph (the old code back-propagated into it and then threw the grads away).
  with torch.no_grad():
    _, shifted = _transport_mc(mc)
  fake = adversary(shifted)

  real_loss = aloss(real, torch.ones_like(real))
  fake_loss = aloss(fake, torch.zeros_like(fake))

  (real_loss + fake_loss).backward()
  aoptim.step()

  return real_loss.item(), fake_loss.item(), torch.mean(real).item(), torch.mean(fake).item()


def _transport_step(mc):
  """Run one transport update.

  Loss = tloss(shift) + lam * (adversary loss for labelling transported MC as
  real). Returns (transport-size loss, weighted adversarial loss) as floats.
  """
  toptim.zero_grad()
  aoptim.zero_grad()

  transporting, shifted = _transport_mc(mc)
  fake = adversary(shifted)

  trans_loss = tloss(transporting)
  adv_loss = lam * aloss(fake, torch.ones_like(fake))

  (trans_loss + adv_loss).backward()
  toptim.step()

  return trans_loss.item(), adv_loss.item()


def _plot_validation(axes, pt, mcds, transported, datads, transporting,
                     rangex=(0, 5), rangey=(-0.75, 0.25)):
  """Draw the three validation panels for one fixed pT onto `axes` (3 Axes).

  Panels:
  - normalized discriminant histograms (mc / transported / data);
  - their per-bin differences to data;
  - a 2D histogram of the transport vector versus the original MC discriminant.
  """
  ax_hist, ax_diff, ax_2d = axes

  h, b, _ = ax_hist.hist(
        [mcds, transported, datads]
      , bins=25
      , range=rangex
      , density=True
      , label=["mc", "transported", "data"]
      )
  ax_hist.set(title=f"discriminant distribution, pT = {pt}", xlabel="discriminant")
  ax_hist.legend()

  centers = (b[:-1] + b[1:]) / 2.0
  ax_diff.plot(centers, h[0] - h[2], label="mc", linewidth=3)
  ax_diff.plot(centers, h[1] - h[2], label="transported", linewidth=3)
  ax_diff.set(
        ylim=(-0.5, 0.5)
      , title=f"discriminant difference to data, pT = {pt}"
      , xlabel="discriminant"
      , ylabel="prediction - data"
      )
  ax_diff.legend()

  ax_2d.hist2d(mcds, transporting, range=[rangex, rangey], bins=20)
  ax_2d.set(
        title=f"discriminant transport, pT = {pt}"
      , xlabel="mc discriminant"
      , ylabel="transport vector"
      )


def _validation_plots():
  """Transport the fixed-pT validation samples and draw the 6-panel figure."""
  # Evaluation only: do not build an autograd graph over the large validation
  # tensors (the graph previously stayed alive until tonp, costing GBs).
  with torch.no_grad():
    transporting_pt25 = transport(histmc_pt25).squeeze(1)
    transporting_pt500 = transport(histmc_pt500).squeeze(1)
    transported_pt25 = tonp(transporting_pt25 + histmc_pt25[:, 1])
    transported_pt500 = tonp(transporting_pt500 + histmc_pt500[:, 1])

  fig, axes = plt.subplots(1, 6, figsize=(36, 6))
  _plot_validation(axes[:3], 25, mcds_pt25, transported_pt25, datads_pt25,
                   tonp(transporting_pt25))
  _plot_validation(axes[3:], 500, mcds_pt500, transported_pt500, datads_pt500,
                   tonp(transporting_pt500))
  fig.subplots_adjust(wspace=0.4, hspace=0.4)
  plt.show()
  plt.close(fig)


# Per-epoch running sums, in the order they are written to TensorBoard.
_tags = ["radvloss", "fadvloss", "tadvloss", "ttransloss", "realavg", "fakeavg"]

# with autograd.detect_anomaly():
for epoch in range(nepochs):
  totals = dict.fromkeys(_tags, 0.0)

  for batch in range(nbatches):
    # resample minibatches (with replacement) from the fixed event pools
    data = alldata[torch.randint(alldata.size(0), size=(batchsize,), device=alldata.device)]
    mc = allmc[torch.randint(allmc.size(0), size=(batchsize,), device=allmc.device)]

    rloss, floss, ravg, favg = _adversary_step(data, mc)
    totals["radvloss"] += rloss
    totals["fadvloss"] += floss  # the fake-sample loss, not the real-sample one
    totals["realavg"] += ravg
    totals["fakeavg"] += favg

    trloss, taloss = _transport_step(mc)
    totals["ttransloss"] += trloss
    totals["tadvloss"] += taloss

  # write tensorboard info once per epoch
  for tag in _tags:
    writer.add_scalar(tag, totals[tag] / nbatches, epoch)

  # make validation plots once per epoch
  _validation_plots()

  print("\n\n")

In [None]:
# !rm -r runs