<a href="https://colab.research.google.com/github/mnansary/pyF2O/blob/master/colab_gen_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab-specific tasks
*   Mount Google Drive
*   Check the TPU connection
*   Check the TF version
*   Change into the git repo

In [0]:
# Mount Google Drive at /content/gdrive; the repo and MODEL_DIR used below live under it.
import google.colab.drive as drive
drive.mount('/content/gdrive')

In [0]:

# tpu check
import os
import pprint
import tensorflow as tf

# The presence of COLAB_TPU_ADDR is treated as "this runtime has a TPU attached".
if 'COLAB_TPU_ADDR' in os.environ:
  TPU_ADDRESS = 'grpc://{}'.format(os.environ['COLAB_TPU_ADDR'])
  print('TPU address is', TPU_ADDRESS)

  # Short-lived session against the TPU master, used only to enumerate its devices.
  with tf.Session(TPU_ADDRESS) as session:
    devices = session.list_devices()

  print('TPU devices:')
  pprint.pprint(devices)
else:
  print('ERROR: Not connected to a TPU runtime; please see the first cell in this notebook for instructions!')

# Last expression: display the installed TensorFlow version.
tf.__version__

In [0]:
cd /content/gdrive/My\ Drive/PROJECTS/HACT/pyACTRECOG/

# GCS-specific tasks
* **Authenticate** the user
* **Save** and **upload** the credentials to the **TPU**
* Set the project information

In [0]:
# Authenticate the Colab user so the Cloud SDK / GCS client can act on their behalf.
import google.colab.auth as auth
auth.authenticate_user()

In [0]:
# Save credentials
import json
SERVICE_KEY_PATH='/content/adc.json' # @param

# Read the credentials JSON once (assumed to be the file produced by the auth
# cell above — confirm the path). The same dict is reused below; previously the
# file was re-opened with a bare open() that was never closed.
with open(SERVICE_KEY_PATH, 'r') as f:
    auth_info = json.load(f)

# Upload credentials to TPU so the TPU workers can read gs:// paths.
with tf.Session(TPU_ADDRESS) as sess:
    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)

# set service_account
# NOTE(review): this takes the prefix of 'client_id' before the first '.'; confirm
# it is really the identifier wanted as "service account".
JSON_DATA = auth_info
SERVICE_ACCOUNT=str(JSON_DATA['client_id']).split('.')[0]
print('Service Account:',SERVICE_ACCOUNT)

#### SET PROJECT INFORMATION 

In [0]:
# GCP project / bucket identifiers (Colab form fields).
PROJECT_ID    ='record-1106154'     # @param 
BUCKET        ='tfalldata'          # @param 
TFIDEN        ='TFRECORD'            # @param
# LIST FILES
# Directory holding the TFRecord shards: gs://<BUCKET>/<TFIDEN>/
TFRECORDS_DIR= 'gs://{}/{}/'.format(BUCKET,TFIDEN)
# Point gcloud at the project, then list the shard directory as a sanity check.
!gcloud config set project {PROJECT_ID}
!gsutil ls {TFRECORDS_DIR}


# ConvNet3D Model Training

#### Data 
* Set **FLAGS** and **PARAMS**
* Create **Train** and **Eval** Data Generator


In [0]:
import sys
sys.path.append('.')
import numpy as np 

class FLAGS:
    """Input-pipeline hyper-parameters, passed to data_input_fn as a namespace.

    Each example is parsed as a (MIN_SEQ_LEN, IMAGE_DIM, IMAGE_DIM, NB_CHANNELS)
    float32 clip plus an integer label in [0, NB_CLASSES).
    """
    BATCH_SIZE      = 32  #@param
    IMAGE_DIM       = 64   #@param
    NB_CHANNELS     = 3    #@param
    MIN_SEQ_LEN     = 6    #@param
    NB_CLASSES      = 17   #@param
    SHUFFLE_BUFFER  = 6400 #@param

MODEL_DIR           = '/content/gdrive/My Drive/PROJECTS/HACT/Model/' # @param
MODEL_NAME          = 'convNet3D' # @param
EPOCHS              =  250           # @param
NB_TRAIN_DATA       =  49920       # @param
NB_EVAL_DATA        =  3456        # @param
NB_TOTAL_DATA       =  NB_TRAIN_DATA + NB_EVAL_DATA 
# The training pipeline only reads the 'Train' shards, so one epoch is
# NB_TRAIN_DATA samples. Deriving it from NB_TOTAL_DATA made each "epoch"
# overrun the training set (the dataset repeats, so samples were silently re-used).
STEPS_PER_EPOCH     =  NB_TRAIN_DATA // FLAGS.BATCH_SIZE 
VALIDATION_STEPS    =  NB_EVAL_DATA  // FLAGS.BATCH_SIZE 
CHECK_DATA          =  False
LEARNING_RATE       = 1e-5 #@param

LOAD_WEIGHTS=False #@param
EPOCHS_DONE_BEFORE_RECONNECT=0  #@param
# When resuming after a disconnect, only train the epochs that are still left.
# EPOCHS is reset from the form value above on every run, so this cell is re-runnable.
assert 0 <= EPOCHS_DONE_BEFORE_RECONNECT <= EPOCHS, 'EPOCHS_DONE_BEFORE_RECONNECT must be in [0, EPOCHS]'
EPOCHS=EPOCHS-EPOCHS_DONE_BEFORE_RECONNECT

#### Data Generator

In [0]:
from google.cloud import storage
from functools import partial

client = storage.Client(PROJECT_ID)
# get bucket from the project
bucket=client.get_bucket(BUCKET)
print(bucket)

def data_input_fn(FLAGS,mode):
    """Build a shuffled, batched, infinitely repeating tf.data pipeline for one split.

    Args:
      FLAGS: config namespace providing MIN_SEQ_LEN, IMAGE_DIM, NB_CHANNELS,
        NB_CLASSES, SHUFFLE_BUFFER and BATCH_SIZE.
      mode: split name (e.g. 'Train', 'Eval'); every blob in BUCKET whose name
        starts with '<TFIDEN>/<mode>' is read as a TFRecord file.

    Returns:
      tf.data.Dataset of (feats, label) batches: feats is float32
      (BATCH_SIZE, MIN_SEQ_LEN, IMAGE_DIM, IMAGE_DIM, NB_CHANNELS), label is a
      float32 one-hot of shape (BATCH_SIZE, NB_CLASSES).

    Raises:
      ValueError: if no TFRecord files match the prefix.
    """
    feat_shape = (FLAGS.MIN_SEQ_LEN, FLAGS.IMAGE_DIM, FLAGS.IMAGE_DIM, FLAGS.NB_CHANNELS)

    def _parser(example):
        """Decode one serialized example into (feats, one-hot label)."""
        schema = {
            'feats': tf.io.FixedLenFeature(feat_shape, tf.float32),
            'label': tf.io.FixedLenFeature((), tf.int64),
        }
        parsed = tf.io.parse_single_example(example, schema)
        # FixedLenFeature already yields float32 with the full feat_shape,
        # so no extra cast/reshape is needed.
        feats = parsed['feats']
        # float32 one-hot to match categorical_crossentropy on float predictions
        # (int64 is also a poor fit for TPU execution).
        label = tf.one_hot(parsed['label'], FLAGS.NB_CLASSES, dtype=tf.float32)
        return feats, label

    prefix = '{}/{}'.format(TFIDEN, mode)
    filenames = ['gs://{}/{}'.format(BUCKET, blob.name)
                 for blob in bucket.list_blobs(prefix=prefix)]
    if not filenames:
        # An empty file list would otherwise silently give an empty dataset.
        raise ValueError('No TFRecord files found under gs://{}/{}'.format(BUCKET, prefix))

    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.cache()
    dataset = dataset.map(_parser)
    dataset = dataset.shuffle(FLAGS.SHUFFLE_BUFFER, reshuffle_each_iteration=True)
    dataset = dataset.repeat()
    # drop_remainder keeps every batch the same static size.
    dataset = dataset.batch(FLAGS.BATCH_SIZE, drop_remainder=True)
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    return dataset

def train_in_fn():
    """Training split pipeline."""
    return data_input_fn(FLAGS,'Train')

def eval_in_fn():
    """Evaluation split pipeline."""
    return data_input_fn(FLAGS,'Eval')


#### COMPILE MODEL




In [0]:
import sys
if '.' not in sys.path:
    sys.path.append('.')
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from coreLib.model import LRCN

tf.logging.set_verbosity(tf.logging.INFO)
# Connect to the Colab TPU and build a distribution strategy over its cores.
resolver = tf.contrib.cluster_resolver.TPUClusterResolver('grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)
# Build, compile and (optionally) restore the model inside the strategy scope
# so its variables are created under the TPU strategy.
with strategy.scope():
  model=LRCN(seq_len=FLAGS.MIN_SEQ_LEN,
              img_dim=FLAGS.IMAGE_DIM,
              nb_channels=FLAGS.NB_CHANNELS,
              nb_classes=FLAGS.NB_CLASSES)
  model.summary()
  model.compile(optimizer=Adam(learning_rate=LEARNING_RATE),   
                loss=categorical_crossentropy,
                metrics=['accuracy'])
  if LOAD_WEIGHTS:
    # Resume from the best checkpoint path used by the training cell.
    # Must *call* load_weights; assigning to it replaced the method and loaded nothing.
    model.load_weights(os.path.join(MODEL_DIR,'{}.h5'.format(MODEL_NAME)))


#### Train


In [0]:
from tensorflow.keras.callbacks import ModelCheckpoint

# Keep only the best checkpoint (by ModelCheckpoint's default monitored quantity)
# at MODEL_DIR/<MODEL_NAME>.h5.
best_weights_path = os.path.join(MODEL_DIR, '{}.h5'.format(MODEL_NAME))
checkpoint = ModelCheckpoint(filepath=best_weights_path,
                             save_best_only=True,
                             verbose=1)

history = model.fit(
    train_in_fn(),
    validation_data=eval_in_fn(),
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    callbacks=[checkpoint],
    verbose=1,
)

#### Save Model



In [0]:
# Store the last-epoch weights separately from the best checkpoint.
final_weights_path = os.path.join(MODEL_DIR, '{}_final.h5'.format(MODEL_NAME))
model.save_weights(final_weights_path)

#### Plot Training History

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('LOSS History')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.savefig(os.path.join(MODEL_DIR,'{}_history.png'.format(MODEL_NAME)))