In [1]:
# Fetch the three QCDToGGQQ test parquet files into the current working directory.
# -nc (no-clobber) skips files that already exist: without it, re-running this cell
# downloads everything again and saves a second copy as "<name>.parquet.1" (as in the
# log below), wasting bandwidth/disk and leaving duplicate files next to the originals.
!wget -nc https://cernbox.cern.ch/remote.php/dav/public-files/hqz8zE7oxyPjvsL/QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet
!wget -nc https://cernbox.cern.ch/remote.php/dav/public-files/hqz8zE7oxyPjvsL/QCDToGGQQ_IMGjet_RH1all_jet0_run1_n47540.test.snappy.parquet
!wget -nc https://cernbox.cern.ch/remote.php/dav/public-files/hqz8zE7oxyPjvsL/QCDToGGQQ_IMGjet_RH1all_jet0_run2_n55494.test.snappy.parquet

--2023-03-26 13:39:44--  https://cernbox.cern.ch/remote.php/dav/public-files/hqz8zE7oxyPjvsL/QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet
Resolving cernbox.cern.ch (cernbox.cern.ch)... 128.142.53.35, 128.142.170.17, 137.138.120.151, ...
Connecting to cernbox.cern.ch (cernbox.cern.ch)|128.142.53.35|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 184115523 (176M) [application/octet-stream]
Saving to: ‘QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet.1’


2023-03-26 13:39:59 (13.4 MB/s) - ‘QCDToGGQQ_IMGjet_RH1all_jet0_run0_n36272.test.snappy.parquet.1’ saved [184115523/184115523]



### Import libraries

In [3]:
import pandas as pd
import dask.dataframe as dd
import tensorflow as tf

### Load dataset

In [None]:
# Number of rows split between the train (80%) and test (20%) generators.
# Placeholder value, picked arbitrarily because the notebook could not be run locally.
# NOTE(review): the downloaded file names end in n36272 / n47540 / n55494, which look
# like per-file sample counts — confirm the real size with len(df) before relying on it.
dataset_length = 11_111

In [None]:
## Lazily load the three QCDToGGQQ test files fetched by the download cell.
## wget (without -P/-O) saved them in the current working directory, so read them from
## there instead of a hard-coded "/kaggle/working/" path (on Kaggle both are the same).
## The explicit pattern matches exactly the downloaded files. It keeps out any other
## parquet file in the directory and any "*.parquet.1" copy left by a repeated download.
df = dd.read_parquet("QCDToGGQQ_IMGjet_RH1all_jet0_run*.test.snappy.parquet", engine='pyarrow')

In [1]:
_TRAIN_FRACTION = 0.8  # share of `length` rows yielded by train_generator
_TEST_FRACTION = 0.2   # share of `length` rows yielded by test_generator


def _iter_rows(frame, start, stop):
    """Yield ``(X, y)`` pairs for the rows at positions ``[start, stop)`` of ``frame``.

    ``frame`` may be a pandas DataFrame or a dask DataFrame. Rows are selected by
    *position* (``.iloc``), never by index label. Dask Series do not support
    ``frame['X'][i]`` lookups, and on pandas such lookups are label-based, so negative
    or non-default keys fail. A dask frame is materialised one partition at a time with
    ``.compute()``. Iteration stops once ``stop`` is reached, so later partitions are
    never loaded. Earlier partitions must still be computed to learn their lengths.
    If ``frame`` has fewer than ``stop`` rows, iteration simply ends early.
    """
    if start >= stop:
        return
    if hasattr(frame, "npartitions"):  # dask DataFrame: stream partition by partition
        parts = (frame.get_partition(k).compute() for k in range(frame.npartitions))
    else:  # already in memory (e.g. pandas)
        parts = iter((frame,))
    offset = 0  # position of the current partition's first row within the whole frame
    for part in parts:
        size = len(part)
        lo, hi = max(start - offset, 0), min(stop - offset, size)
        if lo < hi:
            # assumes the image/label columns are named "X" and "y" — check df.columns
            yield from zip(part["X"].iloc[lo:hi], part["y"].iloc[lo:hi])
        offset += size
        if offset >= stop:  # stop before computing any further partition
            break


class _SplitGenerator:
    """Shared state for the split generators: the frame and the row count to split.

    Args:
        frame: pandas or dask DataFrame with "X" and "y" columns. ``None`` means the
            notebook-level ``df``, looked up each time the generator is called.
        length: number of rows to split into train/test. ``None`` means the
            notebook-level ``dataset_length``, looked up each time it is called.
    """

    def __init__(self, frame=None, length=None):
        self.frame = frame
        self.length = length

    def _resolve(self):
        # Globals are read at call time, matching the original cells.
        # tf.data therefore sees the current `df`/`dataset_length` on every epoch.
        frame = df if self.frame is None else self.frame
        length = dataset_length if self.length is None else self.length
        return frame, length


class train_generator(_SplitGenerator):
    """Callable for ``tf.data.Dataset.from_generator``: the first 80% of rows.

    Yields ``(X, y)`` for row positions ``[0, int(length * 0.8))``.
    """

    def __call__(self):
        frame, length = self._resolve()
        yield from _iter_rows(frame, 0, int(length * _TRAIN_FRACTION))


class test_generator(_SplitGenerator):
    """Callable for ``tf.data.Dataset.from_generator``: the held-out test rows.

    Yields ``int(length * 0.2)`` rows starting right after the training rows.
    The two splits are therefore disjoint, and no training row is evaluated as test data.
    """

    def __call__(self):
        frame, length = self._resolve()
        start = int(length * _TRAIN_FRACTION)
        yield from _iter_rows(frame, start, start + int(length * _TEST_FRACTION))

In [4]:
def _dataset_from(generator):
    """Wrap a row generator in a tf.data.Dataset of ``(image, label)`` pairs.

    Both splits share one output signature. Per the dataset documentation, each image
    is 125x125 with 3 channels (not verified in this notebook). Each label is a scalar
    int32.
    """
    image_spec = tf.TensorSpec(shape=(125, 125, 3), dtype=tf.float32)
    label_spec = tf.TensorSpec(shape=(), dtype=tf.int32)
    return tf.data.Dataset.from_generator(generator, output_signature=(image_spec, label_spec))


train_ds = _dataset_from(train_generator())
test_ds = _dataset_from(test_generator())

In [5]:
# Group both per-example datasets into mini-batches of 32 examples.
batched_train_ds, batched_test_ds = (ds.batch(32) for ds in (train_ds, test_ds))

### CNN Model

In [6]:
# Binary CNN classifier for 125x125x3 inputs. Three 3x3 ReLU convolutions
# (16 -> 32 -> 64 filters) with max-pooling and dropout are followed by a flattened
# 256-unit dense layer and a single sigmoid output unit.
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Conv2D(16, (3, 3), activation='relu', input_shape=(125, 125, 3)))
model.add(tf.keras.layers.MaxPooling2D(2, 2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu'))
model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(2, 2))
# Classification head
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

In [7]:
# Print the layer-by-layer output shapes and parameter counts of the CNN built above.
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 123, 123, 16)      448       
                                                                 
 max_pooling2d (MaxPooling2D  (None, 61, 61, 16)       0         
 )                                                               
                                                                 
 dropout (Dropout)           (None, 61, 61, 16)        0         
                                                                 
 conv2d_1 (Conv2D)           (None, 59, 59, 32)        4640      
                                                                 
 conv2d_2 (Conv2D)           (None, 57, 57, 64)        18496     
                                                                 
 max_pooling2d_1 (MaxPooling  (None, 28, 28, 64)       0         
 2D)                                                    

In [8]:
# The single sigmoid output unit pairs with binary cross-entropy. Accuracy and ROC AUC
# are tracked during training and validation.
model.compile(
    loss='binary_crossentropy',
    optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001),
    metrics=[
        'accuracy',
        tf.keras.metrics.AUC(),
    ],
)

In [None]:
# Train for 10 epochs. The held-out test batches are evaluated at the end of each epoch.
model.fit(batched_train_ds, epochs=10, validation_data=batched_test_ds)

In [None]:
# Save the model weights to ./best.h5.
# NOTE(review): fit() above runs without callbacks, so these are the weights after the
# final epoch, not the best-validation epoch that the filename suggests. Consider
# tf.keras.callbacks.ModelCheckpoint(save_best_only=True) in fit() if "best" is intended.
# NOTE(review): newer Keras releases (3.x) require save_weights filenames to end in
# ".weights.h5" — confirm against the installed version.
model.save_weights('./best.h5')