In [None]:
from pyspark.context import SparkContext
from pyspark.sql.session import SparkSession
from pyspark.sql.functions import from_unixtime, when, col, lit

In [None]:
# Start (or reuse) a local Spark session. getOrCreate() returns the already-running
# session on a re-run instead of failing with "Cannot run multiple SparkContexts at once".
spark = SparkSession.builder.master('local').getOrCreate()
sc = spark.sparkContext

In [None]:
%matplotlib inline
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

In [None]:
# Load the data sample: first CSV line is the header, column types are inferred.
df = (
    spark.read.format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load('Files/class_data.csv')
)

In [None]:
# Number of rows (events) in the data sample.
total_events = df.count()
print('There are', total_events, 'events')

In [None]:
# Show the column names and the types inferred while reading the CSV.
df.printSchema()

In [None]:
# Row count for each distinct `lep_flavour` value.
(
    df.groupBy('lep_flavour')
    .count()
    .show()
)

In [None]:
# Tag every row of the class_data sample with label 0 (compute_hist treats label 0 as "signal").
df_data = df.withColumn(colName="label", col=lit(0))

In [None]:
# Spot-check the new label column on the first row.
df_data.select(col('label')).head(1)

In [None]:
# Load the ttZ Monte Carlo sample (same CSV options as the data sample).
df_bkg_1 = (
    spark.read.format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load('Files/class_mc_ttZ.csv')
)

In [None]:
# Load the ttW Monte Carlo sample (same CSV options as the data sample).
df_bkg_2 = (
    spark.read.format("csv")
    .options(sep=",", inferSchema="true", header="true")
    .load('Files/class_mc_ttW.csv')
)

In [None]:
# Tag the ttZ MC rows with label 1.
df_mc_1 = df_bkg_1.withColumn(colName="label", col=lit(1))

In [None]:
# Tag the ttW MC rows with label 2.
df_mc_2 = df_bkg_2.withColumn(colName="label", col=lit(2))

In [None]:
# Stack the labelled data (label 0), ttZ MC (label 1) and ttW MC (label 2) frames into
# one frame. Spark DataFrames have no `.concat`; `unionByName` appends rows, matching
# columns by name rather than by position.
# NOTE(review): unionByName raises if the inputs do not have the same column set —
# confirm the three CSVs share identical headers.
df_concat_0 = df_data.unionByName(df_mc_1).unionByName(df_mc_2)

In [None]:
def compute_hist(data, feature, target='label', n_bins=100, x_lim=(0, 3)):
    """Histogram ``feature`` separately for label-0 rows and label>=1 rows.

    Both histograms use the same ``n_bins`` equal-width bins spanning ``x_lim``,
    so their counts can be overlaid and compared bin by bin.

    Parameters
    ----------
    data : pyspark.sql.DataFrame
        Frame containing the ``feature`` and ``target`` columns.
    feature : str
        Name of the numeric column to histogram.
    target : str
        Name of the label column. Rows with ``target == 0`` go into the first
        ("signal") histogram, rows with ``target >= 1`` into the second
        ("background") histogram.
    n_bins : int
        Number of bins; must be >= 1.
    x_lim : sequence of two numbers
        ``(low, high)`` range; rows with ``feature`` outside ``[low, high]``
        are discarded before binning.

    Returns
    -------
    tuple
        ``((edges, counts_sgn), (edges, counts_bkg))``, where ``edges`` is a list
        of ``n_bins + 1`` bin edges and each ``counts_*`` is a list of ``n_bins``
        counts, as returned by ``RDD.histogram``.

    Raises
    ------
    ValueError
        If ``n_bins < 1`` or ``x_lim[0] >= x_lim[1]``.
    """
    low, high = x_lim
    if n_bins < 1:
        raise ValueError("n_bins must be >= 1")
    if not low < high:
        raise ValueError("x_lim must satisfy x_lim[0] < x_lim[1]")

    ## Fix the range
    in_range = data.where((col(feature) >= low) & (col(feature) <= high))

    sgn = in_range.where(col(target) == 0.0)
    bkg = in_range.where(col(target) >= 1.0)

    # Explicit, shared bin edges. Passing an int to RDD.histogram derives the edges
    # from each subset's own min/max, so the two histograms would not line up, and
    # it cannot build buckets for an empty subset. The last edge is pinned to `high`
    # so values equal to the upper limit are not lost to floating-point rounding.
    width = (high - low) / n_bins
    edges = [low + i * width for i in range(n_bins)] + [high]

    ## Compute the histograms
    bins_sgn, counts_sgn = sgn.select(feature).rdd.flatMap(lambda row: row).histogram(list(edges))
    bins_bkg, counts_bkg = bkg.select(feature).rdd.flatMap(lambda row: row).histogram(list(edges))

    return (bins_sgn, counts_sgn), (bins_bkg, counts_bkg)

In [None]:
# Overlay the Mll01 distribution of label-0 rows ("signal") and label>=1 rows
# ("background") from the combined, labelled frame built earlier in the notebook.
# The inline backend set at the top of the notebook is kept; switching to the
# `notebook` backend mid-notebook is unsupported in JupyterLab.

## Each hist_* is a tuple (bin_edges, counts_per_bin)
hist_signal, hist_bkg = compute_hist(data=df_concat_0, feature='Mll01', target='label', n_bins=50, x_lim=[0,3])

fig, ax = plt.subplots()
for (edges, counts), name in ((hist_signal, 'signal'), (hist_bkg, 'background')):
    # Re-draw a pre-binned histogram: one sample per left edge, weighted by its count.
    ax.hist(edges[:-1], bins=edges, weights=counts, alpha=0.5, label=name)
ax.set_xlabel('$M_{l0l1}$')
ax.set_ylabel('counts')
ax.set_title("Distribution of $M_{l0l1}$")
ax.legend()
plt.show()