In [67]:
from jet_ML_tools import *
from data_import import data_import

%matplotlib inline
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np

In [68]:
# specify the file inputs
n_files = 1
n_ev_per_file = 10000
s_range = range(1, n_files + 1)  # one seed per file, starting at 1 (assumes data_import reads one file per seed)

# read in the data
nevents = n_files * n_ev_per_file
data = data_import(data_type='jetimage', seed_range=s_range, path='../images/', nevents=nevents, img_size=33, channels=[0])

# the labels below assume `data` holds nevents gluon jets followed by nevents quark jets
assert len(data) == 2 * nevents, 'expected {} jet images, got {}'.format(2 * nevents, len(data))

# split the data into its gluon and quark halves (split at nevents, not a hard-coded count)
# and build categorical labels: gluon = 0, quark = 1
g_data, q_data = data[:nevents], data[nevents:]
labels = to_categorical(np.concatenate((np.zeros(nevents), np.ones(nevents))), num_cat = 2)

# perform the train/validation/test split
x_train, y_train, x_val, y_val, x_test, y_test = data_split(data, labels, val_frac = 0.1, test_frac = 0.1)

In [82]:
# Adapted from: github.com/derylucio/weaksupervision

from keras.optimizers import Adam

# Weak supervision loss function
def weak_loss_function(ytrue, ypred):
    """Squared difference of batch-normalised prediction and target sums.

    Computes ``((sum(ypred) - sum(ytrue)) / batch_size) ** 2``, where the sums
    run over every element of the batch. In this notebook each batch is one
    bag of events (see ``weak_data_generator``), so the loss compares the
    network's average output over a bag with the bag's average target.

    Parameters
    ----------
    ytrue, ypred : tensors
        Target and prediction tensors for one batch; axis 0 is the batch axis.

    Returns
    -------
    A scalar tensor.
    """
    # ypred.shape[0] is the *static* batch dimension, which is None for a Keras
    # placeholder under the TensorFlow backend; K.shape gives the runtime size.
    batch_size = K.cast(K.shape(ypred)[0], K.dtype(ypred))
    return K.square((K.sum(ypred) - K.sum(ytrue)) / batch_size)

# Generates batches for training with keras
def weak_data_generator(samples, output):
    """Yield ``(samples[i], output[i])`` pairs forever, cycling through the bags in order.

    Each yielded pair is one batch for ``fit_generator``, so a bag of events
    is always presented to the network as a single batch.

    Raises
    ------
    ValueError
        If ``samples`` is empty (on the first ``next()``). Without this check
        the generator would loop forever without yielding and hang training.
    """
    if len(samples) == 0:
        raise ValueError('weak_data_generator needs at least one bag of samples')
    while True:
        for i in range(len(samples)):
            yield samples[i], output[i]

# Train a network with weak supervision
# UNDER CONSTRUCTION - EMM
def trainCNN(samples, targets, layersize = 30, nb_epoch = 10, learning_rate = 0.001, val_frac = 0.1, weak = True):
    """Train a CNN on bags of jet images, feeding one bag per generator step.

    Parameters
    ----------
    samples : sequence of arrays
        One image array per bag. Each bag is split into training and
        validation parts with ``data_split``.
    targets : sequence of arrays
        Per-bag targets aligned with ``samples``; ``targets[0].shape[1]`` sets
        the network's output dimension.
    layersize : int
        Currently unused; kept so existing calls keep working.
    nb_epoch : int
        Maximum number of epochs (early stopping on ``val_loss`` may stop sooner).
    learning_rate : float
        Learning rate for the Adam optimizer.
    val_frac : float
        Fraction of each bag held out for validation.
    weak : bool
        If True, compile with ``weak_loss_function``; otherwise with
        categorical cross-entropy.

    Returns
    -------
    The trained model built by ``conv_net_construct``.
    """
    listX_train, listX_val = [], []
    listf_train, listf_val = [], []

    # split each bag into training and validation parts. Elsewhere in this
    # notebook data_split returns (train, train targets, val, val targets,
    # test, test targets), so keep only the first four (test_frac is 0 here).
    for X, f in zip(samples, targets):
        X_train, f_train, X_val, f_val = data_split(X, f, val_frac = val_frac, test_frac = 0.0)[:4]
        listX_train.append(X_train)
        listf_train.append(f_train)
        listX_val.append(X_val)
        listf_val.append(f_val)

    # fit_generator's samples_per_epoch / nb_val_samples count samples, not
    # generator steps. Each step yields one whole bag, so both sizes are the
    # total number of events over all bags.
    trainsize = sum(X.shape[0] for X in listX_train)
    valsize = sum(X.shape[0] for X in listX_val)

    # construct the CNN
    hps = {'batch_size': 128, 'img_size': 33, 'nb_conv': [8,4,4], 'nb_filters': [64, 64, 64],
           'nb_neurons': 128, 'nb_pool': [2, 2, 2], 'dropout': [.25, .5, .5, .5],
           'nb_channels': 1, 'patience': 3, 'out_dim' : targets[0].shape[1]}

    CNN_model = conv_net_construct(hps, compiled = False)
    earlystopper = EarlyStopping(monitor="val_loss", patience= hps['patience'])

    # only the loss differs between weak and fully supervised training
    loss = weak_loss_function if weak else "categorical_crossentropy"
    CNN_model.compile(loss = loss, optimizer=Adam(lr=learning_rate), metrics = ['accuracy'])

    CNN_model.fit_generator(generator = weak_data_generator(listX_train, listf_train), samples_per_epoch = trainsize,
                            nb_epoch = nb_epoch, validation_data = weak_data_generator(listX_val, listf_val),
                            nb_val_samples = valsize, callbacks = [earlystopper])

    return CNN_model