# Convolutional Neural Network (CNN) Example

## I/O

In [None]:
TMVA::Tools::Instance();

// for using Keras
//gSystem->Setenv("KERAS_BACKEND","tensorflow"); 
//TMVA::PyMethodBase::PyInitialize();

auto outputFile = TFile::Open("CNN_ClassificationOutput.root", "RECREATE");

TMVA::Factory factory("TMVA_CNN_Classification", outputFile,
                      "!V:ROC:!Silent:Color:!DrawProgressBar:AnalysisType=Classification" ); 


## Load Data and Features

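Each input event is an 8x8-pixel image (64 pixel values, registered below as the variables `var0` … `var63`) of an electromagnetic (EM) shower produced by a photon or an electron.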

In [None]:
TMVA::DataLoader * loader = new TMVA::DataLoader("dataset");

int imgSize = 8 * 8; 

for(auto i = 0; i < imgSize; i++)
     loader->AddVariable(Form("var%d",i),'F');

## Setup Dataset

In [None]:
TString inputFileName = "data/images_data.root";

auto inputFile = TFile::Open( inputFileName );

TTree *signalTree     = (TTree*)inputFile->Get("sig_tree");
TTree *backgroundTree = (TTree*)inputFile->Get("bkg_tree");

Double_t signalWeight     = 1.0;
Double_t backgroundWeight = 1.0;
   
loader->AddSignalTree    ( signalTree,     signalWeight     );
loader->AddBackgroundTree( backgroundTree, backgroundWeight );

TCut mycuts = ""; 
TCut mycutb = "";

loader->PrepareTrainingAndTestTree( mycuts, mycutb,
                                    "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V" );

## Boosted Decision Trees (BDT)

In [None]:
// Boosted Decision Trees: a classical reference classifier using the same pixel
// variables. 800 shallow trees (depth 2), AdaBoost with beta 0.5, bagged boosting
// on half of the sample, Gini-index splitting with 20 cut points per variable.
factory.BookMethod(loader, TMVA::Types::kBDT, "BDT",
                   "!V:NTrees=800:MinNodeSize=2.5%:MaxDepth=2:"
                   "BoostType=AdaBoost:AdaBoostBeta=0.5:"
                   "UseBaggedBoost:BaggedSampleFraction=0.5:"
                   "SeparationType=GiniIndex:nCuts=20");

## Deep Neural Networks (Dense and Convolutional)

In [None]:
// Switches read by the following cells to decide which neural networks are booked.
bool useDNN{true};    // dense network (DL_DENSE)
bool useCNN{true};    // convolutional network (DL_CNN)
// NOTE(review): useKeras is not read by any cell visible in this notebook.
bool useKeras{false};

### Dense Neural Network (DNN)

In [None]:
if (useDNN) {
    // Fully connected network that sees each event as one flat vector of 64 pixel values.
    TString inputLayoutString = "InputLayout=1|1|64";
    TString batchLayoutString = "BatchLayout=1|128|64";
    // Four hidden tanh layers of 64 nodes each, followed by a single linear output node.
    TString layoutString("Layout=DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,DENSE|64|TANH,DENSE|1|LINEAR");

    // Single training phase: ADAM optimiser, learning rate 1e-3, batches of 128 events,
    // at most 20 epochs, L2 regularisation with weight decay 1e-4.
    // Additional phases could be chained with "|" in the TrainingStrategy string.
    TString training1("LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
                      "ConvergenceSteps=20,BatchSize=128,TestRepetitions=1,"
                      "MaxEpochs=20,WeightDecay=1e-4,Regularization=L2,"
                      "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.");

    TString trainingStrategyString("TrainingStrategy=");
    trainingStrategyString += training1;

    // General MethodDL options: cross-entropy loss, no input transformation,
    // Xavier-uniform weight initialisation.
    TString dnnOptions("!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
                       "WeightInitialization=XAVIERUNIFORM");
    dnnOptions.Append(":"); dnnOptions.Append(inputLayoutString);
    dnnOptions.Append(":"); dnnOptions.Append(batchLayoutString);
    dnnOptions.Append(":"); dnnOptions.Append(layoutString);
    dnnOptions.Append(":"); dnnOptions.Append(trainingStrategyString);

    // Use the CPU backend, the same as the CNN cell. MethodDL no longer supports the
    // STANDARD architecture; recent ROOT versions warn and fall back to CPU anyway.
    dnnOptions += ":Architecture=CPU";
    factory.BookMethod(loader, TMVA::Types::kDL, "DL_DENSE", dnnOptions);
}

### Convolutional Neural Networks (CNN)

In [None]:
if (useCNN) {
    // Image geometry as seen by the network: 1 channel of 8 x 8 pixels, fed in
    // batches of 128 events (batch layout string as used in the TMVA CNN tutorial).
    const TString cnnInput = "InputLayout=1|8|8";
    const TString cnnBatch = "BatchLayout=128|1|64";

    // Two 3x3 convolutions with 10 filters each (stride 1, padding 1, ReLU), one 2x2
    // max-pooling stage, then flatten into a 64-node tanh layer and one linear output.
    const TString cnnLayers = "Layout="
                              "CONV|10|3|3|1|1|1|1|RELU,"
                              "CONV|10|3|3|1|1|1|1|RELU,"
                              "MAXPOOL|2|2|1|1,"
                              "RESHAPE|FLAT,"
                              "DENSE|64|TANH,"
                              "DENSE|1|LINEAR";

    // One training phase: ADAM, learning rate 1e-3, batches of 128 events,
    // at most 20 epochs, no regularisation term.
    const TString cnnTraining = "TrainingStrategy="
                                "LearningRate=1e-3,Momentum=0.9,Repetitions=1,"
                                "ConvergenceSteps=20,BatchSize=128,TestRepetitions=1,"
                                "MaxEpochs=20,WeightDecay=1e-4,Regularization=None,"
                                "Optimizer=ADAM,DropConfig=0.0+0.0+0.0+0.0";

    // Assemble the full MethodDL option string: general settings first, then the
    // colon-separated layout / training pieces, and finally the CPU backend.
    TString cnnOptions = "!H:V:ErrorStrategy=CROSSENTROPY:VarTransform=None:"
                         "WeightInitialization=XAVIERUNIFORM";
    cnnOptions += ":" + cnnInput + ":" + cnnBatch + ":" + cnnLayers + ":" + cnnTraining;
    cnnOptions += ":Architecture=CPU";

    factory.BookMethod(loader, TMVA::Types::kDL, "DL_CNN", cnnOptions);
}

## Train Methods

In [None]:
// Train every method booked on the factory (BDT, and DL_DENSE / DL_CNN if enabled).
factory.TrainAllMethods();

## Test and Evaluate Algorithms

In [None]:
// Apply every trained method to the test sample prepared by the data loader.
factory.TestAllMethods();

In [None]:
// Compute and print the performance figures of every tested method.
factory.EvaluateAllMethods();    

## Plot ROC Curve

In [None]:
//%jsroot on

In [None]:
c1 = factory.GetROCCurve(loader);
c1->Draw();


In [None]:
// Close the output file so that everything the factory wrote to it is saved to disk.
outputFile->Close();