In [2]:
import os

import numpy as np
import matplotlib.pyplot as plt
import energyflow as ef
import energyflow.archs
from matplotlib import gridspec
from matplotlib import rc
rc('font', size=20)

In [None]:
# Same samples as in the OmniFold paper (https://arxiv.org/abs/1911.09107);
# dataset details are at https://energyflow.network/docs/datasets/.
# Pythia and Herwig are two generators: below, one is treated as the
# "simulation" and the other as the "data".
datasets = {
    generator: ef.zjets_delphes.load(generator, num_data=1000000)
    for generator in ('Pythia26', 'Herwig')
}

## Set up the observables

In [None]:
def is_charged(myin):
    """Return the electric charge associated with a remapped particle-ID code.

    The codes are the multiples of 0.1 from 0.0 to 1.3 and map to charges as
    0.0 -> 0, 0.1 -> +1, 0.2 -> -1, 0.3 -> 0, 0.4 -> -1, 0.5 -> +1,
    0.6 -> -1, 0.7 -> +1, 0.8 -> +1, 0.9 -> -1, 1.0 -> +1, 1.1 -> -1,
    1.2 -> 0, 1.3 -> 0.
    (Presumably this is the EnergyFlow PID remapping of the particle arrays --
    confirm against the dataset documentation.)

    The code is matched by rounding ``10 * myin`` to the nearest integer
    instead of comparing floats with ``==``, so values that carry ordinary
    floating-point noise (e.g. ``0.1 + 0.2`` or a float32 ``0.7``) are still
    recognised.

    Parameters
    ----------
    myin : float
        Remapped PID code.

    Returns
    -------
    int or None
        -1, 0 or +1 for a recognised code, ``None`` for anything else.
    """
    # Charge for code k/10 is stored at index k.
    charge_by_code = (0, 1, -1, 0, -1, 1, -1, 1, 1, -1, 1, -1, 0, 0)

    scaled = 10.0 * float(myin)
    if not np.isfinite(scaled):
        return None

    index = int(round(scaled))
    # Reject values that are not (numerically) a multiple of 0.1 or that fall
    # outside the table.
    if abs(scaled - index) > 1e-4 or not 0 <= index < len(charge_by_code):
        return None
    return charge_by_code[index]

In [None]:
def _jet_charge_observables(jets):
    """Compute two charge observables for every jet in ``jets``.

    Each jet is indexed as a 2-D array whose column 0 is used as the
    constituent pT and whose column 3 is passed to ``is_charged``.

    Returns
    -------
    (list, list)
        ``sum(q * pT**0.5) / sum(pT**0.5)`` per jet, and
        ``sum(|q| * pT) / sum(pT)`` per jet.
    """
    jet_charges = []
    charged_pT_fractions = []
    for jet in jets:
        pTs = jet[:, 0]
        # Slice the PID column once per jet instead of re-indexing the
        # dataset for every single constituent.
        charges = np.array([is_charged(pid) for pid in jet[:, 3]])
        jet_charges.append(np.sum(charges * pTs**0.5) / np.sum(pTs**0.5))
        charged_pT_fractions.append(np.sum(np.abs(charges) * pTs) / np.sum(pTs))
    return jet_charges, charged_pT_fractions


# Attach the charge observables to every sample, at particle level ("gen")
# and at detector level ("sim").
for dataset in datasets:
    for level in ('gen', 'sim'):
        mycharges, mycharges2 = _jet_charge_observables(datasets[dataset][level + '_particles'])
        datasets[dataset][level + '_charge'] = mycharges
        datasets[dataset][level + '_pTcharge'] = mycharges2

In [None]:
# List the arrays available for the Pythia sample.
datasets['Pythia26'].keys()

In [None]:
# Short aliases for the '*_tau2s' and '*_widths' arrays of both generators.
# NOTE(review): the tau1s_* names are filled from the '*_widths' arrays, not
# from a tau1 key -- confirm this is intended.

# Detector level ("sim")
tau2s_Pythia_sim, tau2s_Herwig_sim = (
    datasets[generator]['sim_tau2s'] for generator in ('Pythia26', 'Herwig')
)
tau1s_Pythia_sim, tau1s_Herwig_sim = (
    datasets[generator]['sim_widths'] for generator in ('Pythia26', 'Herwig')
)

# Particle level ("gen")
tau2s_Pythia_gen, tau2s_Herwig_gen = (
    datasets[generator]['gen_tau2s'] for generator in ('Pythia26', 'Herwig')
)
tau1s_Pythia_gen, tau1s_Herwig_gen = (
    datasets[generator]['gen_widths'] for generator in ('Pythia26', 'Herwig')
)

In [None]:
# Naming convention: *_true = particle level ("gen"), *_reco = detector level
# ("sim"); no suffix = Pythia, "_alt" suffix = Herwig.

# Jet-level quantities: column 0 of '*_jets' is stored as pT_*, column 3 as m_*.
pT_true, m_true = (datasets['Pythia26']['gen_jets'][:, column] for column in (0, 3))
pT_reco, m_reco = (datasets['Pythia26']['sim_jets'][:, column] for column in (0, 3))

pT_true_alt, m_true_alt = (datasets['Herwig']['gen_jets'][:, column] for column in (0, 3))
pT_reco_alt, m_reco_alt = (datasets['Herwig']['sim_jets'][:, column] for column in (0, 3))

# Jet widths.
w_true, w_reco = (datasets['Pythia26'][key] for key in ('gen_widths', 'sim_widths'))
w_true_alt, w_reco_alt = (datasets['Herwig'][key] for key in ('gen_widths', 'sim_widths'))

# sqrt(pT)-weighted jet charge (stored as lists, hence the array conversion).
q_true, q_reco = (np.array(datasets['Pythia26'][key]) for key in ('gen_charge', 'sim_charge'))
q_true_alt, q_reco_alt = (np.array(datasets['Herwig'][key]) for key in ('gen_charge', 'sim_charge'))

# pT fraction carried by charged constituents.
r_true, r_reco = (np.array(datasets['Pythia26'][key]) for key in ('gen_pTcharge', 'sim_pTcharge'))
r_true_alt, r_reco_alt = (np.array(datasets['Herwig'][key]) for key in ('gen_pTcharge', 'sim_pTcharge'))

In [None]:
# Distribution of the particle-level charged-pT fraction for the Pythia sample.
plt.hist(r_true,bins=np.linspace(0,1.5,20))
# NOTE(review): 0.66 is an unexplained reference value -- document what it represents.
plt.axvline(0.66,color="black",ls=":")

## Set up the binning 

The goal is to learn the average value of each of the four jet substructure observables (`m`, `w`, `q` and `r` in the code below) as a function of jet $p_T$.

In [None]:
# Build the jet-pT bin edges, growing from the low side.  Starting from 100,
# the upper edge of the current bin is raised in steps of 1 until the purity
#     (# jets with truth AND reco pT inside the bin) / (# jets with truth pT inside the bin)
# exceeds sqrt(0.5) ~ 0.71; that edge is then recorded and the next bin starts
# there.  The scan stops once an edge at or above 500 has been added.
binvals = [100]
i = 0
while binvals[-1] < 500:
    low = binvals[i]
    # These masks do not depend on the candidate upper edge.
    truth_above = pT_true > low
    both_above = truth_above & (pT_reco > low)
    for binhigh in range(low+1,1000):
        n_truth = np.count_nonzero(truth_above & (pT_true < binhigh))
        if n_truth == 0:
            # Empty candidate bin: the purity is undefined, try a wider bin.
            continue
        n_both = np.count_nonzero(both_above & (pT_true < binhigh) & (pT_reco < binhigh))
        purity = n_both / n_truth
        if (purity > 0.5**0.5):
            print(binhigh,purity)
            i+=1
            binvals+=[binhigh]
            break
    else:
        # No candidate edge below 1000 reached the purity target.  Stop here;
        # without this the while-loop would repeat the same scan forever.
        print("No bin above", low, "reaches the purity target; stopping.")
        break

In [None]:
fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(1, 1, height_ratios=[1])
ax0 = plt.subplot(gs[0])
ax0.yaxis.set_ticks_position('both')
ax0.xaxis.set_ticks_position('both')
ax0.tick_params(direction="in",which="both")
ax0.minorticks_on()
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

# Particle-level jet pT spectrum (Pythia) in the bins stored in `binvals`.
n,b,_=plt.hist(pT_true,bins=binvals)
plt.xlabel("Truth jet $p_T$ [GeV]")
plt.ylabel("Number of jets")
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)
fig.savefig('figures/jetpt.pdf',bbox_inches='tight')

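For the binned comparison, bin edges are grown from the low side and each bin is closed as soon as its purity exceeds the target: a threshold of $\sqrt{0.5}\approx 0.71$ for the jet $p_T$ bins above, and 50% for the observable bins below.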

In [None]:
# Next, let's construct the response matrix.
#
# First map every jet to a pT-bin index.  np.digitize returns 0 below the
# first edge and len(binvals) above the last one; clipping to
# [1, len(binvals)-1] and subtracting 1 folds under- and overflow into the
# first and last bin, giving indices 0 .. len(binvals)-2.
# Nominal (Pythia) sample first, then the alternative ("_alt", Herwig) sample.
pTbin_truth, pTbin_reco, pTbin_truth_alt, pTbin_reco_alt = (
    np.clip(np.digitize(jet_pT, binvals), 1, len(binvals) - 1) - 1
    for jet_pT in (pT_true, pT_reco, pT_true_alt, pT_reco_alt)
)

In [None]:
fig = plt.figure(figsize=(8, 6))
gs = gridspec.GridSpec(1, 1, height_ratios=[1])
ax0 = plt.subplot(gs[0])
ax0.yaxis.set_ticks_position('both')
ax0.xaxis.set_ticks_position('both')
ax0.tick_params(direction="in",which="both")
ax0.minorticks_on()
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)

# H_pT[t, r] counts jets in particle-level bin t and detector-level bin r.
H_pT, xedges, yedges = np.histogram2d(pTbin_truth,pTbin_reco,bins=[range(len(binvals)),range(len(binvals))])
# Normalise every particle-level row to 1: H_norm_pT[t, r] = Pr(r | t).
H_norm_pT = H_pT / H_pT.sum(axis=1, keepdims=True)
# imshow draws the first array index along y.  The axis labels (and the text
# annotations below) put the particle-level bin on x, so show the transpose.
plt.imshow(H_norm_pT.T,origin='lower',cmap="Reds",vmin = 0,vmax = 1)
cbar = plt.colorbar()
cbar.ax.set_ylabel('Pr(Detector | Particle)')
plt.xlabel("Particle-level $p_{T}$ bin",fontsize=20)
plt.ylabel("Detector-level $p_{T}$ bin",fontsize=20)

# Print the probability in every cell: x = particle-level bin j2, y = detector-level bin i.
for i in range(len(binvals)-1):
    for j2 in range(len(binvals)-1):
        plt.text(j2,i, "%0.2f" % H_norm_pT[j2,i],
                color="w", ha="center", va="center", fontweight="bold",fontsize=12)
# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)
# NOTE(review): the file name looks like a typo of "ResponsepT"; it is left
# unchanged because other documents may reference this path.
fig.savefig('figures/RsponsepT.pdf',bbox_inches='tight')

## Iterative Bayesian Unfolding (IBU)

In [None]:
def IBU(T,D,R,n):
    """Run ``n`` iterations of iterative Bayesian unfolding.

    Parameters
    ----------
    T : array_like, shape (n_truth,)
        Prior for the particle-level spectrum (starting point of the iteration).
    D : array_like, shape (n_reco,)
        Measured detector-level spectrum to unfold.
    R : array_like, shape (n_reco, n_truth)
        Response matrix; ``R[r, t]`` multiplies the truth estimate of bin ``t``.
    n : int
        Number of iterations.

    Returns
    -------
    numpy.ndarray, shape (n_truth,)
        The estimate after ``n`` iterations (the prior itself when ``n == 0``).

    Notes
    -----
    All inputs are converted to float arrays, so integer count arrays are
    accepted (the in-place normalisation below would otherwise fail on them).
    """
    phi = np.asarray(T, dtype=float)
    D = np.asarray(D, dtype=float)
    R = np.asarray(R, dtype=float)
    for _ in range(n):
        # Weight each response column by the current truth estimate ...
        m = R * phi
        # ... and normalise every detector-level row to 1; the tiny offset
        # avoids 0/0 for rows that are entirely zero.
        m /= (m.sum(axis=1)[:,np.newaxis] + 10**-50)
        # Redistribute the measured counts over the truth bins.
        phi = np.dot(m.T, D)
    return phi

In [None]:
# Marginals of the (particle-level, detector-level) count matrix:
# T is the particle-level pT spectrum, D the detector-level one.
T = np.sum(H_pT,axis=1)
D = np.sum(H_pT,axis=0)
# One IBU iteration, passing the transposed row-normalised matrix as the response.
IBU(T,D,H_norm_pT.T,1)

In [None]:
# Display the particle-level pT-bin counts.
T

In [None]:
pT_true = datasets['Pythia26']['gen_jets'][:,0]
pT_reco = datasets['Pythia26']['sim_jets'][:,0]

pT_true_alt = datasets['Herwig']['gen_jets'][:,0]
pT_reco_alt = datasets['Herwig']['sim_jets'][:,0]

# features[observable, level, sample] -> 1-D array with one entry per jet.
#   observable: "m" (column 3 of '*_jets'), "w" ('*_widths'),
#               "q" ('*_charge'), "r" ('*_pTcharge')
#   level:      "rec" (detector level, 'sim_*') or "tru" (particle level, 'gen_*')
#   sample:     "nom" (Pythia26) or "alt" (Herwig)
features = {}
for sample, dataset_name in (("nom", 'Pythia26'), ("alt", 'Herwig')):
    for level, prefix in (("rec", 'sim'), ("tru", 'gen')):
        features["m", level, sample] = np.array(datasets[dataset_name][prefix + '_jets'][:,3])
        for obs_name, suffix in (("w", 'widths'), ("q", 'charge'), ("r", 'pTcharge')):
            features[obs_name, level, sample] = np.array(datasets[dataset_name][prefix + '_' + suffix])

# Range scanned when building the observable bins.
maxvalues = {'m': 200, 'w': 0.7, 'q': 0.5, 'r': 1}
minvalues = {'m': 0, 'w': 0, 'q': -0.5, 'r': 0}

# OmniFold weights per observable.  The weight arrays are not created in the
# cells shown in this notebook, so only pick up the ones that actually exist;
# referencing them unconditionally raises NameError on a fresh kernel.
omnifolded = {}
for obs_name, weights_name in (('r', 'weightsR'), ('q', 'weightsQ'), ('w', 'weightsW'), ('m', 'weights')):
    if weights_name in globals():
        omnifolded[obs_name] = globals()[weights_name][-1,1]

In [None]:
# Quantities that depend only on jet pT are computed once, outside the loop
# over observables.
n_pT_bins = len(binvals) - 1
pT_centers = 0.5*(np.array(binvals[0:-1])+np.array(binvals[1:]))
pTbin_truth = np.clip(np.digitize(pT_true,binvals),1,len(binvals)-1)-1
pTbin_reco = np.clip(np.digitize(pT_reco,binvals),1,len(binvals)-1)-1
pTbin_truth_alt = np.clip(np.digitize(pT_true_alt,binvals),1,len(binvals)-1)-1
pTbin_reco_alt = np.clip(np.digitize(pT_reco_alt,binvals),1,len(binvals)-1)-1

# savefig does not create missing directories, so make sure the target exists.
os.makedirs('figures', exist_ok=True)

binvalsObs = {}
for obs in ['r','q','w','m']:
    binvalsObs[obs] = {}
    xt = features[obs,"tru","nom"]
    xr = features[obs,"rec","nom"]

    # ------------------------------------------------------------------
    # Observable binning, separately in every pT bin: grow each bin from
    # the low side until its purity exceeds 0.5 (at most 14 bins).
    # ------------------------------------------------------------------
    for ii in range(n_pT_bins):
        # Restrict once to the jets whose particle-level pT is in this bin.
        in_truth_bin = pTbin_truth == ii
        xt_ii = xt[in_truth_bin]
        xr_ii = xr[in_truth_bin]
        reco_same_bin = pTbin_reco[in_truth_bin] == ii
        n_diag = np.count_nonzero(reco_same_bin)

        binvalsObs[obs][ii] = [minvalues[obs]]
        i = 0
        disttotal = 0.
        breakloop = True
        while len(binvalsObs[obs][ii]) < 15 and binvalsObs[obs][ii][-1] < maxvalues[obs] and breakloop:
            mycount = 0
            binlow = binvalsObs[obs][ii][i]
            for binhigh in np.linspace(binlow+0.01,maxvalues[obs],100):
                mycount+=1
                truth_in = (xt_ii > binlow) & (xt_ii < binhigh)
                reco_in = reco_same_bin & (xr_ii > binlow) & (xr_ii < binhigh)
                # Fraction of jets with truth value in the bin that are also
                # reconstructed in the same observable and pT bin.
                purity = np.count_nonzero(truth_in & reco_in) / (0.00000001+np.count_nonzero(truth_in))
                distamount = np.count_nonzero(reco_in) / (0.00000001+n_diag)
                if (purity > 0.5):
                    i+=1
                    disttotal += distamount
                    binvalsObs[obs][ii]+=[binhigh]
                    break
                if (mycount==99):
                    # Scan (almost) exhausted without reaching the target:
                    # make this the last pass of the while-loop.
                    breakloop = False

        print(len(binvalsObs[obs][ii]))

    # ------------------------------------------------------------------
    # Flattened (pT bin, observable bin) index.  bin_offsets[k] is the
    # index of the first observable bin belonging to pT bin k.
    # ------------------------------------------------------------------
    n_obs_bins = [len(binvalsObs[obs][ii])-1 for ii in range(n_pT_bins)]
    bin_offsets = np.concatenate([[0], np.cumsum(n_obs_bins)]).astype(int)
    n_flat_bins = int(bin_offsets[-1])

    # Observable-bin index of every jet under each pT bin's binning
    # (shape after transposing: jets x pT bins); then pick the column that
    # corresponds to the jet's own pT bin.
    xbin_truth_all = np.array([np.clip(np.digitize(xt,binvalsObs[obs][ii]),1,len(binvalsObs[obs][ii])-1)-1 for ii in range(n_pT_bins)]).T
    xbin_reco_all = np.array([np.clip(np.digitize(xr,binvalsObs[obs][ii]),1,len(binvalsObs[obs][ii])-1)-1 for ii in range(n_pT_bins)]).T
    xbin_truth = xbin_truth_all[np.arange(len(pTbin_truth)), pTbin_truth]
    xbin_reco = xbin_reco_all[np.arange(len(pTbin_reco)), pTbin_reco]

    bin2_truth = bin_offsets[pTbin_truth] + xbin_truth
    bin2_reco = bin_offsets[pTbin_reco] + xbin_reco

    # Same for the alternative (Herwig) sample, using the nominal binning.
    xta = features[obs,"tru","alt"]
    xra = features[obs,"rec","alt"]

    xbin_truth_all_alt = np.array([np.clip(np.digitize(xta,binvalsObs[obs][ii]),1,len(binvalsObs[obs][ii])-1)-1 for ii in range(n_pT_bins)]).T
    xbin_reco_all_alt = np.array([np.clip(np.digitize(xra,binvalsObs[obs][ii]),1,len(binvalsObs[obs][ii])-1)-1 for ii in range(n_pT_bins)]).T
    xbin_truth_alt = xbin_truth_all_alt[np.arange(len(pTbin_truth_alt)), pTbin_truth_alt]
    xbin_reco_alt = xbin_reco_all_alt[np.arange(len(pTbin_reco_alt)), pTbin_reco_alt]

    bin2_truth_alt = bin_offsets[pTbin_truth_alt] + xbin_truth_alt
    bin2_reco_alt = bin_offsets[pTbin_reco_alt] + xbin_reco_alt

    # ------------------------------------------------------------------
    # Response matrix.
    # ------------------------------------------------------------------
    fig = plt.figure(figsize=(8, 6))
    gs = gridspec.GridSpec(1, 1, height_ratios=[1])
    ax0 = plt.subplot(gs[0])
    ax0.yaxis.set_ticks_position('both')
    ax0.xaxis.set_ticks_position('both')
    ax0.tick_params(direction="in",which="both")
    ax0.minorticks_on()
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)

    # Size the matrices from the binning itself rather than from the largest
    # populated bin, so an empty top bin cannot shrink them.
    flat_edges = np.arange(n_flat_bins+1)
    # H[t, r]: jets in flattened particle-level bin t and detector-level bin r.
    H, xedges, yedges = np.histogram2d(bin2_truth,bin2_reco,bins=[flat_edges,flat_edges])
    H_alt, xedges, yedges = np.histogram2d(bin2_truth_alt,bin2_reco_alt,bins=[flat_edges,flat_edges])
    # Row-normalise to Pr(detector | particle); empty particle-level rows
    # stay zero instead of becoming NaN.
    row_totals = H.sum(axis=1, keepdims=True)
    H_norm = np.divide(H, row_totals, out=np.zeros_like(H), where=row_totals > 0)
    # imshow draws the first array index along y; transpose so that the
    # particle-level bin is on x, as the axis labels say.
    plt.imshow(H_norm.T,origin='lower',cmap="Reds",vmin = 0,vmax = 1)
    cbar = plt.colorbar()
    cbar.ax.set_ylabel('Pr(Detector | Particle)')

    plt.xlabel("Particle-level "+obs+" and $p_{T}$ bin",fontsize=20)
    plt.ylabel("Detector-level "+obs+" and $p_{T}$ bin",fontsize=20)
    fig.savefig('figures/response'+obs+'.pdf',bbox_inches='tight')

    # ------------------------------------------------------------------
    # IBU: unfold the Herwig detector-level spectrum with the Pythia
    # response and prior.
    # ------------------------------------------------------------------
    T = np.sum(H,axis=1)
    D = np.sum(H,axis=0)
    D_alt = np.sum(H_alt,axis=0)
    T_alt = np.sum(H_alt,axis=1)
    ibu = IBU(T,D_alt,H_norm.T,10)

    for moment in [1,2]:

        # Unbinned moments in every pT bin.
        means_unbinnedx = np.array([np.mean(xt[pTbin_truth==i]**moment) for i in range(n_pT_bins)])
        means_unbinnedx_alt = np.array([np.mean(xta[pTbin_truth_alt==i]**moment) for i in range(n_pT_bins)])

        means_unbinnedx_reco = np.array([np.mean(xr[pTbin_reco==i]**moment) for i in range(n_pT_bins)])
        means_unbinnedx_reco_alt = np.array([np.mean(xra[pTbin_reco_alt==i]**moment) for i in range(n_pT_bins)])

        # Binned moments and corrections.
        means_binnedx = []
        means_binnedx_alt = []
        means_binnedx_alt_corrected = []
        for i in range(n_pT_bins):

            # Unfolded counts of the observable bins that belong to pT bin i.
            mybin = bin_offsets[i]
            ibu_i = ibu[bin_offsets[i]:bin_offsets[i+1]]

            x_centers = 0.5*(np.array(binvalsObs[obs][i][0:-1])+np.array(binvalsObs[obs][i][1:]))
            xvals = [np.count_nonzero((pTbin_truth==i)*(xbin_truth==j)) for j in range(len(binvalsObs[obs][i])-1)]
            xvals_alt = [np.count_nonzero((pTbin_truth_alt==i)*(xbin_truth_alt==j)) for j in range(len(binvalsObs[obs][i])-1)]
            # Moments evaluated at the bin centres: Pythia truth, and IBU-unfolded Herwig.
            means_binnedx += [np.sum(x_centers**moment*xvals)/np.sum(xvals)]
            means_binnedx_alt += [np.sum(x_centers**moment*ibu_i)/np.sum(ibu_i)]

            # Per-bin correction: replace the bin centre by the Pythia truth
            # average of x**moment inside each bin.
            xiavg = []
            for j in range(len(binvalsObs[obs][i])-1):
                xiavg+=[np.mean(xt[(pTbin_truth==i)*(xbin_truth==j)]**moment)]
            xiavg = np.array(xiavg)
            means_binnedx_alt_corrected += [np.sum(xiavg*ibu_i)/np.sum(ibu_i)]

        means_binnedx = np.array(means_binnedx)
        means_binnedx_alt = np.array(means_binnedx_alt)
        means_binnedx_alt_corrected = np.array(means_binnedx_alt_corrected)

        fig = plt.figure(figsize=(8, 8))
        gs = gridspec.GridSpec(2, 1, height_ratios=[3,4])
        ax0 = plt.subplot(gs[0])
        ax0.yaxis.set_ticks_position('both')
        ax0.xaxis.set_ticks_position('both')
        ax0.tick_params(direction="in",which="both")
        plt.xticks(fontsize=0)
        plt.yticks(fontsize=20)
        ax0.minorticks_on()

        plt.plot(pT_centers,means_unbinnedx,marker='^',ls="",label="Unbinned Pythia",color='red')
        plt.plot(pT_centers,means_unbinnedx_alt,marker='v',ls="",label="Unbinned Herwig",color='black')

        plt.legend(frameon=True,fontsize=14)
        if (moment==1):
            ylabel = r'$< '+obs+' >$'
        else:
            ylabel = r'$< '+obs+'^'+str(moment)+' >$'
        if (obs=='m'):
            # Only the mass carries units.
            if (moment==1):
                ylabel += ' [GeV]'
            else:
                ylabel += ' [GeV$^'+str(moment)+'$]'
        plt.ylabel(ylabel,fontsize=20)

        ax1 = plt.subplot(gs[1])
        ax1.yaxis.set_ticks_position('both')
        ax1.xaxis.set_ticks_position('both')
        ax1.tick_params(direction="in",which="both")
        ax1.minorticks_on()

        ax1.plot(pT_centers,means_binnedx_alt_corrected/means_unbinnedx_alt,marker='s',ls="",label="IBU + Bin Avg Correction")
        # Detector-level Herwig mean scaled by the Pythia truth/reco ratio:
        # no unfolding involved, i.e. the naive correction.
        ax1.plot(pT_centers,(means_unbinnedx_reco_alt*means_unbinnedx/means_unbinnedx_reco)/means_unbinnedx_alt,marker='o',ls="",label="Naive Correction")
        # IBU bin-centre mean scaled by the Pythia unbinned/binned ratio.
        ax1.plot(pT_centers,(means_binnedx_alt*means_unbinnedx/means_binnedx)/means_unbinnedx_alt,marker='s',ls="",label="IBU + Bin Correction")
        ax1.set_xlabel("Jet $p_{T}$ [GeV]",fontsize=20)
        ax1.axhline(1.,ls=":",color="black")
        ax1.set_ylim([0.85,1.2])
        ax1.legend(frameon=True,ncol=2,fontsize=14)
        ax1.set_ylabel('Ratio to \n Unbinned Herwig',fontsize=20)
        # plt.show() rather than fig.show(): the latter only warns under the
        # non-GUI inline backend used by notebooks.
        plt.show()

In [None]:
# Toy example: two unit-width Gaussians, "data" centred at 0 and "simulation"
# centred at -0.5.  A dedicated, seeded generator makes the samples identical
# on every Restart & Run All.
_toy_rng = np.random.default_rng(12345)
gauss_data = _toy_rng.normal(0,1,100000)
gauss_sim = _toy_rng.normal(-0.5,1,100000)

In [None]:
def weighted_binary_crossentropy(y_true, y_pred):
    """Binary cross-entropy with per-event weights packed into ``y_true``.

    Column 0 of ``y_true`` is used as the class label and column 1 as the
    event weight; ``y_pred`` is the predicted probability.  Returns the mean
    of the weighted per-event loss.

    NOTE(review): ``tf`` and ``K`` are not imported in the cells shown here;
    they must be provided elsewhere.
    """
    labels = tf.gather(y_true, [0], axis=1)
    event_weights = tf.gather(y_true, [1], axis=1)

    # Keep the prediction strictly inside (0, 1) so the logarithms stay finite.
    eps = K.epsilon()
    clipped = K.clip(y_pred, eps, 1. - eps)

    log_likelihood = (labels) * K.log(clipped) + (1 - labels) * K.log(1 - clipped)
    return K.mean(-event_weights * log_likelihood)

In [None]:
# Scan the reweighting parameter lambda1 on the Gaussian toy: the simulation
# is weighted by exp(lambda1 * x) (normalised to the size of the data sample),
# a classifier is trained to separate data from weighted simulation, and the
# training loss of the last epoch is recorded.
# NOTE(review): train_test_split, Input, Dense and Model are not imported in
# the cells shown here.

# Inputs and labels do not depend on lambda1.
xvals_1 = np.concatenate([gauss_data,gauss_sim])
yvals_1 = np.concatenate([np.ones(len(gauss_data)),np.zeros(len(gauss_sim))])

losses = []
for lambda1 in np.linspace(-1,1,20):
    sim_weights = np.exp(lambda1*gauss_sim)
    weights_1 = np.concatenate([np.ones(len(gauss_data)),sim_weights*len(gauss_data)/np.sum(sim_weights)])

    X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

    # The loss expects (label, weight) pairs as the target.
    Y_train_2 = np.stack((Y_train_1, w_train_1), axis=1)
    Y_test_2 = np.stack((Y_test_1, w_test_1), axis=1)

    # Fresh three-hidden-layer network for every scan point.
    inputs = Input((1, ))
    hidden_layer_1 = Dense(50, activation='relu')(inputs)
    hidden_layer_2 = Dense(50, activation='relu')(hidden_layer_1)
    hidden_layer_3 = Dense(50, activation='relu')(hidden_layer_2)
    outputs = Dense(1, activation='sigmoid')(hidden_layer_3)
    model = Model(inputs=inputs, outputs=outputs)

    model.compile(loss=weighted_binary_crossentropy, optimizer='Adam', metrics=['accuracy'])
    model.fit(X_train_1, Y_train_2, epochs=10, batch_size=1000, verbose=1)
    losses.append(model.history.history['loss'][-1])

In [None]:
# Closure check on the Gaussian toy for a fixed lambda1: the weighted
# simulation histogram is drawn on top of the data and unweighted simulation.
lambda1 = 1

xvals_1 = np.concatenate([gauss_data,gauss_sim])
yvals_1 = np.concatenate([np.ones(len(gauss_data)),np.zeros(len(gauss_sim))])
sim_weights = np.exp(lambda1*gauss_sim)
weights_1 = np.concatenate([np.ones(len(gauss_data)),sim_weights*len(gauss_data)/np.sum(sim_weights)])

X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

hist_bins = np.linspace(-4,4,20)
is_data = Y_test_1==1
is_mc = Y_test_1==0
_,_,_ = plt.hist(X_test_1[is_data], bins=hist_bins, alpha=0.5, label="data")
_,_,_ = plt.hist(X_test_1[is_mc], bins=hist_bins, alpha=0.5, label="MC")
_,_,_ = plt.hist(X_test_1[is_mc], bins=hist_bins, weights=w_test_1[is_mc],
                 histtype="step", color="black", ls=":", label="weighted MC")
plt.legend(fontsize=15)
plt.ylim([0,8000])

In [None]:
# Final training loss versus lambda1 for the Gaussian toy (same grid as the scan).
plt.plot(np.linspace(-1,1,20),losses)

In [None]:
#GW,_,_ = plt.hist(w_true,weights=weightsW[-1,1],bins=np.linspace(0,1,30),histtype="step",ls="-",color="black",label="rw")
#OFW,_,_  = plt.hist(w_true_alt,bins=np.linspace(0,1,30),ls=":",histtype="step",color="black",lw=3,label="target")
#TW,_,_  = plt.hist(w_true,bins=np.linspace(0,1,30),alpha=0.2,color="blue",label="prior")
#plt.legend()
#plt.xlabel("jet width")

In [None]:
# Same scan as for the Gaussian toy, now on the particle-level jet width:
# Herwig (w_true_alt) plays the role of data and Pythia (w_true) is weighted
# by exp(lambda1 * w), normalised to the size of the Herwig sample.
# NOTE(review): train_test_split, Input, Dense and Model are not imported in
# the cells shown here.

# Inputs and labels do not depend on lambda1.
xvals_1 = np.concatenate([w_true_alt,w_true])
yvals_1 = np.concatenate([np.ones(len(w_true_alt)),np.zeros(len(w_true))])

losses_width = []
for lambda1 in np.linspace(-2,10,20):
    sim_weights = np.exp(lambda1*w_true)
    weights_1 = np.concatenate([np.ones(len(w_true_alt)),sim_weights*len(w_true_alt)/np.sum(sim_weights)])

    X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

    # The loss expects (label, weight) pairs as the target.
    Y_train_2 = np.stack((Y_train_1, w_train_1), axis=1)
    Y_test_2 = np.stack((Y_test_1, w_test_1), axis=1)

    # Fresh three-hidden-layer network for every scan point.
    inputs = Input((1, ))
    hidden_layer_1 = Dense(50, activation='relu')(inputs)
    hidden_layer_2 = Dense(50, activation='relu')(hidden_layer_1)
    hidden_layer_3 = Dense(50, activation='relu')(hidden_layer_2)
    outputs = Dense(1, activation='sigmoid')(hidden_layer_3)
    model = Model(inputs=inputs, outputs=outputs)

    model.compile(loss=weighted_binary_crossentropy, optimizer='Adam', metrics=['accuracy'])
    model.fit(X_train_1, Y_train_2, epochs=10, batch_size=1000, verbose=1)
    losses_width.append(model.history.history['loss'][-1])

In [None]:
# Scan for the second moment of the jet width: identical to the first-moment
# scan except that Pythia is weighted by exp(lambda1 * w**2).
# NOTE(review): train_test_split, Input, Dense and Model are not imported in
# the cells shown here.

# Inputs and labels do not depend on lambda1.
xvals_1 = np.concatenate([w_true_alt,w_true])
yvals_1 = np.concatenate([np.ones(len(w_true_alt)),np.zeros(len(w_true))])

losses_width2 = []
for lambda1 in np.linspace(-2,10,20):
    sim_weights = np.exp(lambda1*w_true**2)
    weights_1 = np.concatenate([np.ones(len(w_true_alt)),sim_weights*len(w_true_alt)/np.sum(sim_weights)])

    X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

    # The loss expects (label, weight) pairs as the target.
    Y_train_2 = np.stack((Y_train_1, w_train_1), axis=1)
    Y_test_2 = np.stack((Y_test_1, w_test_1), axis=1)

    # Fresh three-hidden-layer network for every scan point.
    inputs = Input((1, ))
    hidden_layer_1 = Dense(50, activation='relu')(inputs)
    hidden_layer_2 = Dense(50, activation='relu')(hidden_layer_1)
    hidden_layer_3 = Dense(50, activation='relu')(hidden_layer_2)
    outputs = Dense(1, activation='sigmoid')(hidden_layer_3)
    model = Model(inputs=inputs, outputs=outputs)

    model.compile(loss=weighted_binary_crossentropy, optimizer='Adam', metrics=['accuracy'])
    model.fit(X_train_1, Y_train_2, epochs=10, batch_size=1000, verbose=1)
    losses_width2.append(model.history.history['loss'][-1])

In [None]:
# Final training loss versus lambda1 for the jet-width scan (same grid as the scan).
plt.plot(np.linspace(-2,10,20),losses_width)
# NOTE(review): 2.5 is a hand-picked reference value for lambda1 -- document how it was chosen.
plt.axvline(2.5)

In [None]:
# Closure check on the jet width for a fixed lambda1: the weighted Pythia
# histogram is drawn on top of Herwig ("data") and unweighted Pythia ("MC").
lambda1 = 2.5

xvals_1 = np.concatenate([w_true_alt,w_true])
yvals_1 = np.concatenate([np.ones(len(w_true_alt)),np.zeros(len(w_true))])
sim_weights = np.exp(lambda1*w_true)
weights_1 = np.concatenate([np.ones(len(w_true_alt)),sim_weights*len(w_true_alt)/np.sum(sim_weights)])

X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

hist_bins = np.linspace(0,1,30)
is_data = Y_test_1==1
is_mc = Y_test_1==0
_,_,_ = plt.hist(X_test_1[is_data], bins=hist_bins, alpha=0.5, label="data")
_,_,_ = plt.hist(X_test_1[is_mc], bins=hist_bins, alpha=0.5, label="MC")
_,_,_ = plt.hist(X_test_1[is_mc], bins=hist_bins, weights=w_test_1[is_mc],
                 histtype="step", color="black", ls=":", label="weighted MC")
plt.legend(fontsize=15)

In [None]:
# For every lambda1, print: lambda1, mean width of the "data" (Herwig) test
# sample, mean width of the unweighted Pythia test sample, and the
# exp(lambda1 * w)-weighted Pythia mean.

# Inputs and labels do not depend on lambda1.
xvals_1 = np.concatenate([w_true_alt,w_true])
yvals_1 = np.concatenate([np.ones(len(w_true_alt)),np.zeros(len(w_true))])

for lambda1 in np.linspace(-1,10,20):
    sim_weights = np.exp(lambda1*w_true)
    weights_1 = np.concatenate([np.ones(len(w_true_alt)),sim_weights*len(w_true_alt)/np.sum(sim_weights)])

    X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

    is_data = Y_test_1==1
    is_mc = Y_test_1==0
    print(lambda1,
          np.mean(X_test_1[is_data]),
          np.mean(X_test_1[is_mc]),
          np.average(X_test_1[is_mc],weights=w_test_1[is_mc]))

In [None]:
# Second-moment version of the printout: lambda1, mean of w**2 for the "data"
# (Herwig) test sample, for the unweighted Pythia test sample, and for Pythia
# weighted by exp(lambda1 * w**2).

# Inputs and labels do not depend on lambda1.
xvals_1 = np.concatenate([w_true_alt,w_true])
yvals_1 = np.concatenate([np.ones(len(w_true_alt)),np.zeros(len(w_true))])

for lambda1 in np.linspace(-1,10,20):
    sim_weights = np.exp(lambda1*w_true**2)
    weights_1 = np.concatenate([np.ones(len(w_true_alt)),sim_weights*len(w_true_alt)/np.sum(sim_weights)])

    X_train_1, X_test_1, Y_train_1, Y_test_1, w_train_1, w_test_1 = train_test_split(xvals_1, yvals_1, weights_1)

    is_data = Y_test_1==1
    is_mc = Y_test_1==0
    print(lambda1,
          np.mean(X_test_1[is_data]**2),
          np.mean(X_test_1[is_mc]**2),
          np.average(X_test_1[is_mc]**2,weights=w_test_1[is_mc]))

In [None]:
# Compare the loss scans for the first (w) and second (w**2) moment of the
# jet width; the dotted lines mark a reference lambda1 for each curve.
losses_width = np.array(losses_width)
losses_width2 = np.array(losses_width2)

lambda_grid = np.linspace(-2,10,20)
for scan_losses in (losses_width, losses_width2):
    plt.plot(lambda_grid, scan_losses)
for position, colour in ((2.5, "tab:blue"), (4.5, "tab:orange")):
    plt.axvline(position, color=colour, ls=":")
plt.ylim([0.65,0.7])