
izharikov/state-farm-distracted-driver-detection

State Farm Distracted Driver Detection

Overview

Implementation of a solution for the Kaggle State Farm Distracted Driver Detection competition.

CNN models implemented: simple, vgg16, vgg19, inception_v3, xception.
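The --fc, --fc_dim and --dropout training options describe a stack of fully connected layers placed on top of the convolutional base for fine-tuning. The sketch below is hypothetical and not the repository's code. It only lists the layers such a head would contain, assuming ReLU activations and a 10-way softmax output (the competition has ten classes, c0 to c9). The .hdf5 weight files suggest Keras, but the README does not confirm this. In Keras, each tuple would map to a Dense or Dropout layer.

```python
from typing import List, Tuple

# The State Farm competition has 10 driver-behaviour classes (c0..c9).
NUM_CLASSES = 10


def head_layer_spec(fc: int = 2, fc_dim: int = 4096, dropout: float = 0.5,
                    num_classes: int = NUM_CLASSES) -> List[Tuple]:
    """Describe a hypothetical classifier head as a list of layer tuples.

    Follows the documented semantics of --fc, --fc_dim and --dropout:
    `fc` dense layers of width `fc_dim`, each followed by a dropout layer
    unless `dropout` is negative, then a softmax output layer.
    """
    layers: List[Tuple] = []
    for _ in range(fc):
        # ReLU is an assumption; the README does not name the activation.
        layers.append(("dense", fc_dim, "relu"))
        if dropout >= 0:  # "If < 0, then no dropout layers are added."
            layers.append(("dropout", dropout))
    layers.append(("dense", num_classes, "softmax"))
    return layers


print(head_layer_spec(fc=1, fc_dim=256, dropout=-1))
# [('dense', 256, 'relu'), ('dense', 10, 'softmax')]
```

With the defaults (fc=2, fc_dim=4096, dropout=0.5), the head has two 4096-unit layers, each followed by 50% dropout.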

Train and evaluate

Downloading and splitting data

Download and split the data using the Kaggle API:

./src/colab_config.sh
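The README does not say how colab_config.sh splits the data. For this competition, one common approach is to hold out whole drivers. Images of the same driver are highly correlated, so a random per-image split tends to overstate validation accuracy. The following is a minimal sketch of that approach, not the repository's method. It assumes rows shaped like Kaggle's driver_imgs_list.csv (columns subject, classname, img). The function names and the 20% validation fraction are illustrative.

```python
import csv
import random
from collections import defaultdict
from typing import Dict, Iterable, List, Tuple

Row = Dict[str, str]


def split_by_driver(rows: Iterable[Row], val_fraction: float = 0.2,
                    seed: int = 0) -> Tuple[List[Row], List[Row]]:
    """Split rows into train/validation so that no driver appears in both."""
    by_driver: Dict[str, List[Row]] = defaultdict(list)
    for row in rows:
        by_driver[row["subject"]].append(row)  # 'subject' is the driver id
    drivers = sorted(by_driver)
    random.Random(seed).shuffle(drivers)  # deterministic for a given seed
    n_val = max(1, round(len(drivers) * val_fraction))
    val_drivers = set(drivers[:n_val])
    train = [r for d in drivers if d not in val_drivers for r in by_driver[d]]
    val = [r for d in drivers if d in val_drivers for r in by_driver[d]]
    return train, val


def load_driver_list(path: str) -> List[Row]:
    """Read driver_imgs_list.csv (hypothetical helper) into a list of dicts."""
    with open(path, newline="") as f:
        return list(csv.DictReader(f))
```

Usage: `train, val = split_by_driver(load_driver_list("driver_imgs_list.csv"))`. The returned lists can then be copied into per-class train and validation folders.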

Train

python train.py [...options]

Training example:

python train.py --model vgg16 --width 224 --optimizer adam --lr 1e-5 --batch 16 --epochs 50

Options

  • --model <model>
    Model architecture. <model> is one of: simple, vgg16, vgg19, inception_v3, xception. Default is simple.
  • --epochs <number>
    Number of training epochs. Default is 20.
  • --width <width>
    Width of the model's input image. Images from the dataset are resized to <width>. Default is 150.
  • --optimizer <optimizer>
    Optimizer used for training. <optimizer> is one of: adam, sgd, rmsprop.
  • --summary [True|False]
    Print the model summary. Default is False.
  • --lr <learning_rate>
    Learning rate for <optimizer>. Default is 5e-5.
  • --weight_path <path>
    Path to a weights file. If specified, weights are loaded from this file; otherwise, weights are initialized randomly.
  • --fc <count_of_layers>
    Number of fully connected layers used for fine-tuning. Default is 2.
  • --fc_dim <dimension>
    Number of units in each fully connected layer. Default is 4096.
  • --dropout <dropout>
    Dropout rate applied after each fully connected layer. If < 0, no dropout layers are added. Default is 0.5.
  • --batch <batch>
    Batch size used during training. Default is 32.
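The options above map naturally onto an argparse parser. The following is a hypothetical sketch that reproduces the documented names and defaults, not the actual train.py. The README gives no default for --optimizer, so this sketch leaves it as None. Because argparse's type=bool treats any non-empty string (including "False") as True, a small converter handles --summary.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False' style strings for --summary."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected True or False, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="Train a distracted-driver classifier (sketch).")
    p.add_argument("--model", default="simple",
                   choices=["simple", "vgg16", "vgg19", "inception_v3", "xception"])
    p.add_argument("--epochs", type=int, default=20)
    p.add_argument("--width", type=int, default=150)
    # No default optimizer is documented, so None is used here.
    p.add_argument("--optimizer", default=None, choices=["adam", "sgd", "rmsprop"])
    p.add_argument("--summary", type=str2bool, default=False)
    p.add_argument("--lr", type=float, default=5e-5)
    p.add_argument("--weight_path", default=None)  # None means random initialization
    p.add_argument("--fc", type=int, default=2)
    p.add_argument("--fc_dim", type=int, default=4096)
    p.add_argument("--dropout", type=float, default=0.5)  # < 0 disables dropout layers
    p.add_argument("--batch", type=int, default=32)
    return p


args = build_parser().parse_args(
    ["--model", "vgg16", "--width", "224", "--optimizer", "adam",
     "--lr", "1e-5", "--batch", "16", "--epochs", "50"])
print(args.model, args.width, args.lr, args.batch, args.epochs)
```

The example arguments are the same as in the training example above.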

Evaluation

Evaluation uses the scripts predict.sh and predict-all.sh.

predict.sh

Makes predictions with a single saved model. --path_to_model is the file name of the saved weights in the /content/drive/models folder:

./predict.sh --path_to_model vgg-04-acc-0.9722-loss-0.0997.hdf5 --model vgg16 --width 224
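The example file name appears to encode the epoch and the metrics of each saved checkpoint. The README does not say whether these are training or validation metrics. The following hypothetical helper parses names of that shape and picks the checkpoint with the lowest loss. The regular expression is inferred only from the single example "vgg-04-acc-0.9722-loss-0.0997.hdf5".

```python
import re
from typing import Iterable, NamedTuple, Optional

# Pattern inferred from "vgg-04-acc-0.9722-loss-0.0997.hdf5" (an assumption).
CHECKPOINT_RE = re.compile(
    r"^(?P<prefix>.+)-(?P<epoch>\d+)-acc-(?P<acc>\d+\.\d+)-loss-(?P<loss>\d+\.\d+)\.hdf5$"
)


class Checkpoint(NamedTuple):
    prefix: str
    epoch: int
    acc: float
    loss: float


def parse_checkpoint_name(name: str) -> Optional[Checkpoint]:
    """Return the fields encoded in a checkpoint file name, or None if it does not match."""
    m = CHECKPOINT_RE.match(name)
    if m is None:
        return None
    return Checkpoint(m["prefix"], int(m["epoch"]), float(m["acc"]), float(m["loss"]))


def best_by_loss(names: Iterable[str]) -> Optional[str]:
    """Pick the file name with the lowest encoded loss, ignoring non-matching files."""
    best_name, best_loss = None, float("inf")
    for name in names:
        ckpt = parse_checkpoint_name(name)
        if ckpt is not None and ckpt.loss < best_loss:
            best_name, best_loss = name, ckpt.loss
    return best_name


print(parse_checkpoint_name("vgg-04-acc-0.9722-loss-0.0997.hdf5"))
```

Selecting by loss rather than accuracy is a design choice. Kaggle scored this competition on multi-class log loss, so the loss value is arguably the closer proxy.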

predict-all.sh

Makes predictions with every saved model in /content/drive/models and saves the results to /content/drive/submissions:

./predict-all.sh --model vgg16 --width 224
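Kaggle submissions for this competition are CSV files with an img column followed by one probability column per class, c0 to c9. The following is a minimal sketch of writing such a file for each model; it is not the repository's script. The naming scheme (reusing the model file's stem for the CSV) and both function names are hypothetical.

```python
import csv
import os
from typing import Sequence

CLASS_COLUMNS = [f"c{i}" for i in range(10)]


def submission_path(model_file: str, out_dir: str = "/content/drive/submissions") -> str:
    """Hypothetical naming: reuse the model file's stem for its submission CSV."""
    stem = os.path.splitext(os.path.basename(model_file))[0]
    return os.path.join(out_dir, stem + ".csv")


def write_submission(path: str, img_names: Sequence[str],
                     probs: Sequence[Sequence[float]]) -> None:
    """Write one row per test image with columns img, c0..c9."""
    if len(img_names) != len(probs):
        raise ValueError("img_names and probs must have the same length")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["img"] + CLASS_COLUMNS)
        for name, row in zip(img_names, probs):
            if len(row) != len(CLASS_COLUMNS):
                raise ValueError(f"expected {len(CLASS_COLUMNS)} probabilities for {name}")
            writer.writerow([name] + list(row))


print(submission_path("/content/drive/models/vgg-04-acc-0.9722-loss-0.0997.hdf5"))
# /content/drive/submissions/vgg-04-acc-0.9722-loss-0.0997.csv
```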

Environment

Training was run on Google Colaboratory with a GPU runtime.

About

State farm distracted driver detection solution
