In [None]:
%run pet_classifier_modules.ipynb

def pet_classifier_training(datasetSize=800, epochs=600, pretrained=False):
  """Train the pet-image CNN classifier and convert its parameter directories.

  Builds a ``CNN`` (conv -> pool -> conv -> pool feature extractor followed by
  a single sigmoid-neuron MLP head), or, when ``pretrained`` is True, builds it
  from parameters loaded with ``fetchCNNParametersFromFile`` /
  ``fetchMLPParametersFromFile``. It then trains the model on images from
  ``genEE364PetImageStack`` (64x64, 1 color channel, "train" split, CPU),
  calls ``plotTrainingResults`` when training returns a non-empty epoch
  record, and finally runs ``pth_to_pkl`` on the CNN and MLP parameter
  directories.

  Args:
    datasetSize: number of images requested from ``genEE364PetImageStack``.
    epochs: number of training epochs passed to ``cnn.train``.
    pretrained: if True, initialize from the parameter files under
      ``params/parametersCNN/`` and ``params/parametersMLP/`` instead of from
      the architecture configs defined below.

  Returns:
    None.
  """
  # Input geometry, data split and compute device.
  imgHeight = 64
  imgWidth = 64
  colorChannels = 1
  use = "train"
  deviceType = torch.device("cpu")

  # Parameter directories. They are read when pretrained=True and are also
  # the source directories handed to pth_to_pkl at the end (without the
  # trailing slash). Presumably cnn.train(save_params=True) writes here;
  # confirm against the CNN implementation.
  cnnParameterPath = 'params/parametersCNN/'
  mlpParameterPath = 'params/parametersMLP/'

  # Feature extractor layers, one entry per layer: conv, pool, conv, pool.
  # NOTE: a pooling layer's filter count must match the filter count of the
  # convolutional layer that precedes it.
  isConvLayer = [True, False, True, False]
  filterCounts = [2, 2, 4, 4]
  kernelShapes = [(5, 5), (2, 2), (3, 3), (2, 2)]
  kernelStrides = [1, 2, 1, 2]
  CNNactivationFunctions = ["leakyReLU", "none", "leakyReLU", "none"]

  # Classifier head: a single sigmoid output neuron (trained with BCELoss below).
  neuronCounts = [1]
  MLPactivationFunctions = ["sigmoid"]

  CNNmodelConfig = {
    "is_conv_layer": isConvLayer,
    "filter_counts": filterCounts,
    "kernel_shapes": kernelShapes,
    "kernel_strides": kernelStrides,
    "CNN_activation_functions": CNNactivationFunctions,
  }
  MLPmodelConfig = {
    "neuron_counts": neuronCounts,
    "MLP_activation_functions": MLPactivationFunctions,
  }

  CNNHyperParameters = {
    "learn_rate": 0.01,
    "batch_size": 24,
    "loss_func": "BCELoss",
    "reduction": "mean",
    "optimizer": "adam",
    "lambda_L2": 1e-6,
    "dropout_rate": None,
  }
  MLPHyperParameters = {
    "learn_rate": 0.01,
    "optimizer": "adam",
    "lambda_L2": 1e-6,
    "dropout_rate": None,
  }

  (imgBatch, labelBatch) = genEE364PetImageStack(datasetSize, imgHeight, imgWidth, colorChannels, use, deviceType)

  # Constructor arguments shared by the fresh and the pretrained model.
  sharedCnnKwargs = {
    "training": True,
    "device_type": deviceType,
    "hyperparameters": CNNHyperParameters,
    "mlp_hyperparameters": MLPHyperParameters,
  }

  if not pretrained:
    cnn = CNN(
      pretrained=False,
      input_data_dim=(colorChannels, imgHeight, imgWidth),
      cnn_model_config=CNNmodelConfig,
      mlp_model_config=MLPmodelConfig,
      **sharedCnnKwargs,
    )
  else:
    # The architecture configs are not passed here; the network layout
    # presumably comes from the stored parameters (TODO confirm in CNN).
    cnn = CNN(
      pretrained=True,
      cnn_model_params=fetchCNNParametersFromFile(deviceType, cnnParameterPath),
      mlp_model_params=fetchMLPParametersFromFile(deviceType, mlpParameterPath),
      **sharedCnnKwargs,
    )

  (epochPlt, lossPlt) = cnn.train(imgBatch, labelBatch, epochs=epochs, save_params=True)
  # Only plot when training actually reported epochs.
  if epochPlt:
    plotTrainingResults(epochPlt, lossPlt)

  # Convert the parameter directories into the export location. Sources are
  # derived from the path constants above so both stay in sync.
  pth_to_pkl(cnnParameterPath.rstrip('/'), 'pet_classifier_trainedModel_NOT_USED/parametersCNN')
  pth_to_pkl(mlpParameterPath.rstrip('/'), 'pet_classifier_trainedModel_NOT_USED/parametersMLP')


