# Import Libraries

In [1]:
# Reload imported modules automatically before each cell executes, so edits to
# project code (e.g. the `src` package) take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Reduce TensorFlow's native log output. This is set before the project imports
# in the next cell — presumably those import TensorFlow, which would need to see
# the variable at import time; confirm before reordering these cells.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # show only errors

import sys
import pandas as pd

# Put the parent directory first on sys.path so the `src.*` imports below can be
# resolved. The membership guard keeps re-runs of this cell from adding duplicates.
if '..' not in sys.path:
    sys.path.insert(0,'..')

In [3]:
from src.base.data_loaders.data_loader import DLName
from src.base.gt_loaders.gt_names import GTName

from src.base.experiment.training.base_models import BaseModel
from src.base.experiment.training.model_creator import Optimizer

from src.base.experiment.dataset.benchmark_dataset import BenchmarkDataset
from src.base.experiment.evaluation.model_evaluator import DataSource, DataPredSelection

from src.exp_runner import ExperimentRunner

from src.m_utils.mtl_approach import MTLApproach

from src.base.experiment.tasks.task import ICAO_REQ
from src.base.experiment.tasks.task import MNIST_TASK
from src.base.experiment.tasks.task import FASHION_MNIST_TASK
from src.base.experiment.tasks.task import CIFAR_10_TASK
from src.base.experiment.tasks.task import CELEB_A_TASK

 ==> Restrict GPU memory growth: True


# Start Network Runner

In [4]:
# Single source of truth for the epoch count: it feeds both the training
# parameters and the experiment tag, so the two cannot drift apart.
N_EPOCHS = 10

# Full configuration handed to ExperimentRunner.
kwargs = {
    'use_neptune': True,
    # Experiment metadata: name, description, tags and tracked source files.
    'exp_params': {
        'name': 'train_vgg16_mtl_3_approach',
        'description': 'Training MTL network for MNIST tasks with Handcrafted 3 network architecture',
        'tags': ['mtl', 'handcrafted_3', 'handcrafted_3_exp', 'mnist', 'benchmark', f'{N_EPOCHS} epochs'],
        'src_files': ["../src/**/*.py"]
    },
    'properties': {
        'approach': MTLApproach.HAND_3,
        # This run uses the MNIST benchmark dataset, with every member of
        # MNIST_TASK as a task.
        'benchmarking': {
            'use_benchmark_data': True,
            'benchmark_dataset': BenchmarkDataset.MNIST,
            'tasks': list(MNIST_TASK)
        },
        # ICAO settings. Both the ground-truth and data-loader sources are
        # switched off here; presumably the whole section is ignored when
        # benchmark data is used — confirm in ExperimentRunner.
        'icao_data': {
            'icao_gt': {
                'use_gt_data': False,
                'gt_names': {
                    'train_validation': [],
                    'test': [],
                    'train_validation_test': [GTName.FVC]
                },
            },
            'icao_dl': {
                'use_dl_data': False,
                'tagger_model': None
            },
            'reqs': list(ICAO_REQ),
            'aligned': False,
        },
        'balance_input_data': False,
        'train_model': True,
        'save_trained_model': False,
        'exec_nas': False,
        'orig_model_experiment_id': '',
        'sample_training_data': False,
        'sample_prop': 1.0
    },
    'net_train_params': {
        'base_model': BaseModel.VGG16,
        'batch_size': 32,
        'n_epochs': N_EPOCHS,
        # NOTE(review): 200 is larger than n_epochs. If this value is an
        # early-stopping patience measured in epochs, it can never trigger in
        # a 10-epoch run — confirm the intended value.
        'early_stopping': 200,
        'learning_rate': 1e-3,
        'optimizer': Optimizer.ADAMAX,
        'dropout': 0.3
    },
    'nas_params': {}
}

runner = ExperimentRunner(**kwargs)

-------------------- Init ExperimentRunner -------------------
---------------------------
Parent Process ID: 111753
Process ID: 126096
---------------------------
-----
Use Neptune:  True
-----
-------------------
Args: 
{'exp_params': {'description': 'Training MTL network for MNIST tasks with '
                               'Handcrafted 3 network architecture',
                'name': 'train_vgg16_mtl_3_approach',
                'src_files': ['../src/**/*.py'],
                'tags': ['mtl',
                         'handcrafted_3',
                         'handcrafted_3_exp',
                         'mnist',
                         'benchmark',
                         '10 epochs']},
 'nas_params': {},
 'net_train_params': {'base_model': <BaseModel.VGG16: {'name': 'vgg16', 'target_size': (224, 224), 'prep_function': <function preprocess_input at 0x7f28a339f820>}>,
                      'batch_size': 32,
                      'dropout': 0.3,
                      'early_stoppin

# Load Data

In [5]:
# Load the train / validation / test data; the runner prints each split's shape.
runner.load_training_data()

-------------------- load training data -------------------
Loading data
TrainData.shape: (48000, 11)
ValidationData.shape: (12000, 11)
TestData.shape: (10000, 11)
Data loaded


# Sampling Training Data

In [6]:
# Effectively skipped in this run: 'sample_training_data' is False in kwargs and
# the runner logs that no subsampling is applied.
runner.sample_training_data()

-------------------- sample training data -------------------
Not applying subsampling in training data!


# Data Balancing

In [7]:
# Effectively skipped in this run: 'balance_input_data' is False in kwargs and
# the runner logs that the input data is not balanced.
runner.balance_input_data()

-------------------- balance input data -------------------
Not balancing input_data


# Data Generators

In [8]:
# Start the data generators for the three splits; the runner reports how many
# validated image filenames each one found.
runner.setup_data_generators()

-------------------- setup data generators -------------------
Starting data generators
Found 48000 validated image filenames.
Found 12000 validated image filenames.
Found 10000 validated image filenames.
TOTAL: 70000

Logging class indices
 .. MTL model not logging class indices!

Using benchmarking dataset. Not logging class labels!


# Setup Experiment

In [9]:
# Set up the experiment's Neptune properties and parameters
# ('use_neptune' is True in kwargs).
runner.setup_experiment()

-------------------- create experiment -------------------
Setup neptune properties and parameters
Properties and parameters setup done!


# Labels Distribution

In [10]:
# Skipped for benchmark data: the runner logs that it does not compute the
# label-distribution summary in that case.
runner.summary_labels_dist()

-------------------- summary labels dist -------------------
Using benchmark data. Not doing summary_labels_dist()


# Create Model

In [11]:
# Build the network. Presumably the architecture follows 'approach' and
# 'base_model' from kwargs — confirm in ExperimentRunner.create_model().
runner.create_model()

-------------------- create model -------------------
Creating model...
Model created


# Visualize Model

In [12]:
%%capture
runner.visualize_model(outfile_path=f"figs/handcrafted_mtl_model_3.png")

In [13]:
%%capture
# %%capture swallows the summary printout so it is not stored in the notebook;
# remove the magic to inspect the summary.
runner.model_summary()

# Training Model

In [None]:
# Run training. In the captured log the VGG16 base is not fine-tuned:
# 14,714,688 of the 17,014,484 parameters are non-trainable.
runner.train_model()

-------------------- train model -------------------
Training VGG16 network
 .. Not fine tuning base model...
  .. Total params: 17,014,484
  .. Trainable params: 2,299,796
  .. Non-trainable params: 14,714,688
Epoch 1/10

# Plots

In [None]:
# Presumably plots the history recorded by runner.train_model() above, so it
# needs that cell to have completed in this kernel session — confirm.
runner.draw_training_history()

# Load Best Model

In [None]:
# NOTE(review): how the "best" model is selected and where it is loaded from is
# not visible in this notebook — confirm in ExperimentRunner.load_best_model().
runner.load_best_model()

# Saving Trained Model

In [None]:
# NOTE(review): 'save_trained_model' is False in kwargs, so this call presumably
# does nothing for this run — confirm in ExperimentRunner.save_model().
runner.save_model()

# Test Trained Model

## Validation Split

In [None]:
# Point the model evaluator at the validation split, then evaluate with
# verbose output turned off.
runner.set_model_evaluator_data_src(DataSource.VALIDATION)
runner.test_model(verbose=False)

## Test Split

In [None]:
# Point the model evaluator at the test split, then evaluate with verbose
# output turned off.
runner.set_model_evaluator_data_src(DataSource.TEST)
runner.test_model(verbose=False)

# Visualize Model Classification

# Finishing Experiment Manager

In [None]:
# Close out the experiment. With 'use_neptune' set to True this presumably also
# ends the Neptune run — confirm in ExperimentRunner.finish_experiment().
runner.finish_experiment()