# Train second set of one vs. rest (OVR) classifiers.

We train a second set of one-vs-rest classifiers on top of the first-level ones.  
For each target class, the negative examples are drawn only from classes that are similar to it; the class similarities are computed in *classifier_similarity.ipynb*.

## Set up

In [1]:
import os
import sys

import numpy as np

import pandas as pd
import glob

import warnings
warnings.filterwarnings('ignore')

In [2]:
# Directory where the trained second-level models are saved.
BASE_MODEL_PATH = "trained_model"
# Pure-Python equivalent of `%mkdir -p`: portable and independent of IPython magics.
os.makedirs(BASE_MODEL_PATH, exist_ok=True)

In [3]:
from models.modelutils import ModelCompiler

Using TensorFlow backend.


In [4]:
# Builds compiled models (SecondLevelClassifierTrainer calls its generate_compiled_model);
# constructed with the model output directory.
compiler = ModelCompiler(BASE_MODEL_PATH)

In [5]:
from models.processor import create_generators

# Training and validation data generators, consumed by OneVsAllModelTrainer below.
TRAIN_DATAGEN, VALID_DATAGEN = create_generators()

In [6]:
from models.modelutils import dir2filedict_sorted
import random

Load category and file path information.

In [7]:
# Class key (directory name) -> list of image paths (sorted, per the helper's name), for each split.
trdict, valdict = [dir2filedict_sorted("data_fgvc/" + split) for split in ("train", "valid")]
# The FGVC classes are labelled "0" .. "99".
categories = list(map(str, range(100)))

In [8]:
# Sanity check: the first five validation paths of class "0" must match the expected output below.
valdict['0'][0:5]

['data_fgvc/valid/0/0062781.jpg',
 'data_fgvc/valid/0/0113201.jpg',
 'data_fgvc/valid/0/0450014.jpg',
 'data_fgvc/valid/0/0602177.jpg',
 'data_fgvc/valid/0/0716386.jpg']

Expected output. This cell must print exactly the same list in *train.ipynb*, *classifier_similarity.ipynb*, *train_multiclass_classifier.ipynb*, and *train_second.ipynb*, so that all notebooks see the same file ordering:

['data_fgvc/valid/0/0062781.jpg',  
 'data_fgvc/valid/0/0113201.jpg',  
 'data_fgvc/valid/0/0450014.jpg',  
 'data_fgvc/valid/0/0602177.jpg',  
 'data_fgvc/valid/0/0716386.jpg']

## Train second level classifiers

Define a class for training second level classifiers.

In [9]:
from models.one_vs_all import OneVsAllModelTrainer
from models.modelutils import split_files

In [10]:
# One shared one-vs-all trainer; SecondLevelClassifierTrainer re-configures it
# (set_model / set_savepath / set_dataset) for every target class.
trainer = OneVsAllModelTrainer(TRAIN_DATAGEN, VALID_DATAGEN)

In [11]:
from models.one_vs_all import FilesPair, TrValFiles

In [12]:
class SecondLevelClassifierTrainer:
    """Train second-level one-vs-rest classifiers, one per target class.

    Each second-level model separates a target class from the classes that are
    similar to it, instead of from every other class.

    Args:
        base_model_name: prefix of the saved model file names.
        basedir: directory the model files are written to.
        trainer: object exposing ``set_model``, ``set_savepath``, ``set_dataset``,
            ``train_model``, ``remove_checkpoint`` and a ``retrainings`` attribute.
        compiler: object exposing ``generate_compiled_model(path)``.
    """

    def __init__(self, base_model_name, basedir, trainer, compiler):
        self.base_model_name = base_model_name
        self.basedir = basedir

        self.compiler = compiler
        self.trainer = trainer

    def setup_filedict(self, train_files_dict, valid_files_dict):
        """Register the class-key -> file-path-list dicts for training and validation."""
        self.train_files_dict = train_files_dict
        self.valid_files_dict = valid_files_dict
        # Alias (not a copy) of the validation dict as originally supplied.
        self.valid_files_dict_org = self.valid_files_dict

    def _model_path(self, target_key):
        """Return the save path ``<basedir>/<base_model_name>_<target_key>``."""
        return os.path.join(self.basedir, "{}_{}".format(self.base_model_name, target_key))

    @staticmethod
    def _negative_keys(target_key, highcat_keyset):
        """Return the keys of ``highcat_keyset`` other than ``target_key``.

        ``{target_key}`` (not ``set(target_key)``) is essential: ``set("10")`` is
        ``{"1", "0"}``, which would keep class "10" as its own negative and
        wrongly drop classes "1" and "0".
        """
        return set(highcat_keyset) - {target_key}

    def _split_by_set(self, target_key, false_keyset, files_dict):
        """Pair the files of ``target_key`` with the files of every key in ``false_keyset``."""
        trues = files_dict[target_key]
        # sorted(): iteration order of a set of str varies between interpreter runs
        # (hash randomization), so sort to keep the negative file order reproducible.
        falses = [path for key in sorted(false_keyset) for path in files_dict[key]]
        return FilesPair(trues, falses)

    def _split_files(self, targetkey, files_dict):
        """Build a FilesPair from ``split_files(targetkey, files_dict)``."""
        return FilesPair(*split_files(targetkey, files_dict))

    def train_second_level(self, target_key, highcat_keyset, eachepochs=10, retrainings=1,
                           removecheckpoint=True, steps_per_epoch=(100, 10)):
        """Train the second-level model for ``target_key``.

        Positives are the files of ``target_key``; negatives are the files of every
        other key in ``highcat_keyset``. The model is saved under
        ``_model_path("sec_" + target_key)``.

        Args:
            target_key: class key (str) to train for.
            highcat_keyset: keys of the classes similar to ``target_key``; may
                contain ``target_key`` itself.
            eachepochs: passed to ``trainer.train_model``.
            retrainings: stored on ``trainer.retrainings``.
            removecheckpoint: if true, call ``trainer.remove_checkpoint()`` afterwards.
            steps_per_epoch: passed as ``hard_coded_steps_per_epoch``; presumably
                (train steps, validation steps) — TODO confirm in OneVsAllModelTrainer.
        """
        self.trainer.retrainings = retrainings
        falseset = self._negative_keys(target_key, highcat_keyset)
        trs = self._split_by_set(target_key, falseset, self.train_files_dict)
        vals = self._split_by_set(target_key, falseset, self.valid_files_dict)
        trvals = TrValFiles(trs, vals)
        self._train_one_core("sec_" + target_key, trvals, eachepochs, removecheckpoint,
                             steps_per_epoch)

    def _train_one_setup(self, model_key, trvals):
        """Compile a fresh model for ``model_key`` and hand it, its save path and data to the trainer."""
        model_save_path = self._model_path(model_key)

        model = self.compiler.generate_compiled_model(model_save_path)
        self.trainer.set_model(model)
        self.trainer.set_savepath(model_save_path)
        self.trainer.set_dataset(trvals)

    def _train_one_core(self, model_key, trvals, eachepochs, removecheckpoint,
                        steps_per_epoch=(100, 10)):
        """Set up, train, and optionally remove the checkpoint for ``model_key``."""
        self._train_one_setup(model_key, trvals)

        self.trainer.train_model(eachepochs=eachepochs, hard_coded_steps_per_epoch=steps_per_epoch)
        if removecheckpoint:
            self.trainer.remove_checkpoint()

    def remove_checkpoint(self, model_key):
        """Remove the checkpoint of ``model_key``; for cleaning up after an interrupted run."""
        self.trainer.set_savepath(self._model_path(model_key))
        self.trainer.remove_checkpoint()

In [13]:
# Second-level models are saved under BASE_MODEL_PATH with the prefix "modelfgvc_sec_<class key>".
sec_trainer = SecondLevelClassifierTrainer("modelfgvc", BASE_MODEL_PATH, trainer, compiler)

In [14]:
# Register the class-key -> file-path dicts loaded above for training and validation.
sec_trainer.setup_filedict(trdict, valdict)

Load $ClassSim$ results to gather similar classes for each target class. 

In [15]:
# ClassSim similarity matrix computed in classifier_similarity.ipynb. train_seconds reads
# column classsim[key] and keeps the index labels whose value is >= SIM_THRESHOLD.
# NOTE(review): read_pickle can execute arbitrary code — only load files this project produced.
classsim = pd.read_pickle("results/valid_sim_df_fgvc.dat")

### Execute training

In [16]:
# Minimum ClassSim similarity for a class to count as "similar" to a target class.
# This dataset is very fine-grained, so the threshold is set higher than the previous
# value of 0.1; 0.4 yields about 18 similar classes per target class on average.
SIM_THRESHOLD = 0.4


In [None]:
def train_seconds(keys, eachepochs=5, threshold=None, similarity=None, trainer=None):
    """Train a second-level classifier for each class key in ``keys``.

    For each key, the classes whose similarity to it is at least ``threshold``
    form its similar set. A second-level model is trained only when that set
    contains at least one class other than the key itself; otherwise the
    first-level classifier alone is used. A ValueError raised while training
    one key is reported and that key is skipped.

    Args:
        keys: iterable of class keys (column labels of ``similarity``).
        eachepochs: epochs per training run, passed to ``train_second_level``.
        threshold: minimum similarity; defaults to ``SIM_THRESHOLD`` (read at call time).
        similarity: DataFrame of class similarities labelled by class key;
            defaults to ``classsim``.
        trainer: object with ``train_second_level``; defaults to ``sec_trainer``.
    """
    if threshold is None:
        threshold = SIM_THRESHOLD
    if similarity is None:
        similarity = classsim
    if trainer is None:
        trainer = sec_trainer

    for targetkey in keys:
        column = similarity[targetkey]
        similarkeyset = set(column[column >= threshold].index)
        # Only the *other* similar classes can serve as negatives; the target is
        # usually similar to itself, but that must not be assumed.
        if not (similarkeyset - {targetkey}):
            print("no similar category. only first classifier is enough. skip second training.")
            continue
        try:
            trainer.train_second_level(targetkey, similarkeyset, eachepochs=eachepochs)
        except ValueError as e:
            print("ValueError, skip {0}: {1}".format(targetkey, e))

In [None]:
# Train every category; raise first_category to resume an interrupted run.
first_category = 0
train_seconds(categories[first_category:], eachepochs=2)

Epoch 1/2
  8/100 [=>............................] - ETA: 10:24 - loss: 0.6801 - acc: 0.5781Epoch 00001: saving model to trained_model/modelfgvc_sec_0-01-0.905.h5
  8/100 [=>............................] - ETA: 21:48 - loss: 0.6801 - acc: 0.5781 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00Epoch 2/2
  8/100 [=>............................] - ETA: 8:46 - loss: 0.6507 - acc: 0.6719Epoch 00002: saving model to trained_model/modelfgvc_sec_0-02-0.881.h5
  8/100 [=>............................] - ETA: 11:30 - loss: 0.6507 - acc: 0.6719 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00Epoch 1/2
  8/100 [=>............................] - ETA: 10:53 - loss: 0.7033 - acc: 0.5312Epoch 00001: saving model to trained_model/modelfgvc_sec_1-01-0.962.h5
  8/100 [=>............................] - ETA: 33:52 - loss: 0.7033 - acc: 0.5312 - val_loss: 0.0000e+00 - val_acc: 0.0000e+00Epoch 2/2
  4/100 [>.............................] - ETA: 8:13 - loss: 0.5058 - acc: 0.7344