<a href="https://colab.research.google.com/github/mnansary/pyF2O/blob/master/colab_gen_unet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab-specific tasks
*   Mount Google Drive
*   Check for a TPU runtime
*   Check the TF version

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:

# tpu check
import os
import pprint
import tensorflow as tf

if 'COLAB_TPU_ADDR' not in os.environ:
  print('ERROR: Not connected to a TPU runtime; select Runtime > Change runtime type and choose TPU as the hardware accelerator.')
else:
  TPU_ADDRESS = 'grpc://' + os.environ['COLAB_TPU_ADDR']
  print ('TPU address is', TPU_ADDRESS)

  with tf.Session(TPU_ADDRESS) as session:
    devices = session.list_devices()
    
  print('TPU devices:')
  pprint.pprint(devices)



TPU address is grpc://10.42.124.66:8470
TPU devices:
[_DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:CPU:0, CPU, -1, 9360561929670052793),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 17179869184, 9631908168161504717),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 4283045023088742688),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 4476928320087680980),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 14323803195418573646),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:3, TPU, 17179869184, 9045814018959949270),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:4, TPU, 17179869184, 13455168068812267422),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:5, TPU, 17179869184, 6502888340727842632),
 _DeviceAttributes(/job:tpu_worker/replica:0/task:0/device:TPU:6, TPU, 17179869184, 18112862821238
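
The TPU check above only defines `TPU_ADDRESS` when a TPU is attached, so later cells that use it would raise a `NameError` on a non-TPU runtime. A minimal sketch of a helper that makes the missing case explicit; the name `resolve_tpu_address` is hypothetical and not part of this notebook:

```python
import os

def resolve_tpu_address(environ=None):
    """Return the gRPC address of the Colab TPU, or None when no TPU is attached."""
    # allow a plain dict to be passed in for testing; default to the real environment
    environ = os.environ if environ is None else environ
    addr = environ.get('COLAB_TPU_ADDR')
    if not addr:
        return None
    return 'grpc://' + addr

TPU_ADDRESS = resolve_tpu_address()
if TPU_ADDRESS is None:
    print('Not connected to a TPU runtime')
else:
    print('TPU address is', TPU_ADDRESS)
```

With this shape, downstream cells can test `TPU_ADDRESS is None` instead of failing on an undefined name.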

In [0]:
!pip3 install tensorflow==1.14.0



# GCS-specific tasks
* **Authenticate** the user
* **Save** and **upload** credentials to the **TPU**
* Set project information


In [0]:
# auth user for cloud SDK
from google.colab import auth
auth.authenticate_user()

In [0]:
# Save credentials
import json
SERVICE_KEY_PATH='/content/adc.json' # @param
# Load the credentials once and reuse them below
with open(SERVICE_KEY_PATH, 'r') as f:
    auth_info = json.load(f)
# Upload credentials to TPU.
with tf.Session(TPU_ADDRESS) as sess:
    tf.contrib.cloud.configure_gcs(sess, credentials=auth_info)
# set service_account (the part of client_id before the first '.')
SERVICE_ACCOUNT=str(auth_info['client_id']).split('.')[0]
print('Service Account:',SERVICE_ACCOUNT)

The TensorFlow contrib module will not be included in TensorFlow 2.0.
For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
  * https://github.com/tensorflow/io (for I/O related ops)
If you depend on functionality not listed there, please file an issue.

Service Account: 32555940559
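
The value printed above is the part of the credential file's `client_id` that precedes the first dot. A small sketch of that extraction in isolation; the helper name `account_number` is hypothetical and the sample `client_id` is made up for illustration:

```python
def account_number(client_id):
    """Return the portion of a client_id string before the first '.'."""
    return str(client_id).split('.')[0]

# made-up client_id, used only to illustrate the split
sample = '1234567890.apps.googleusercontent.com'
print(account_number(sample))
```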


#### SET PROJECT INFORMATION 

In [0]:
PROJECT_ID    ='f2oitpu'               # @param 
BUCKET        ='f2odata'               # @param 

# LIST FILES
TFRECORDS_DIR= 'gs://{}/{}/'.format(BUCKET,'tfrecord')
#
# change TFRECORDS_DIR to match your bucket's directory structure
#
!gcloud config set project {PROJECT_ID}
!gsutil ls {TFRECORDS_DIR}

Updated property [core/project].
gs://f2odata/tfrecord/eval/
gs://f2odata/tfrecord/train/


# UNET generator Model Training

## Data Pipeline
* set FLAGS and PARAMS
* define input functions

#### FLAGS AND PARAMS

In [0]:
from glob import glob

class FLAGS:
    IMAGE_DIM       = 256 # @param
    NB_CHANNELS     = 3   # @param
    BATCH_SIZE      = 128 # @param
    SHUFFLE_BUFFER  = 1000 # @param
    
NB_TOTAL_DATA       = 16128 # @param
NB_EVAL_DATA        = 3200  # @param
N_EPOCHS            = 10   # @param
LEARNING_RATE       = 1e-2  # @param

STEPS_PER_EPOCH     =  NB_TOTAL_DATA // FLAGS.BATCH_SIZE 
VALIDATION_STEPS    =  NB_EVAL_DATA  // FLAGS.BATCH_SIZE 

print('Steps Per epoch:',STEPS_PER_EPOCH)
print('Validation Steps:', VALIDATION_STEPS)



Steps Per epoch: 126
Validation Steps: 25
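
The floor division above discards any partial batch, which is consistent with batching under `drop_remainder=True`. A quick sketch, using a hypothetical helper `steps_for`, that also reports how many samples would be left over with the values set in this cell:

```python
def steps_for(n_samples, batch_size):
    """Return (full_batches, leftover_samples) for one pass over the data."""
    return divmod(n_samples, batch_size)

print(steps_for(16128, 128))  # (126, 0)
print(steps_for(3200, 128))   # (25, 0)
```

Both dataset sizes are exact multiples of the batch size, so no samples are skipped per epoch with these settings.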


#### Data Input Functions
* get **bucket** 
* define **train_in_fn()** and **eval_in_fn()**
* **NOTE: AVOID USING PARTIALS**; define plain wrapper functions such as **train_in_fn()** instead
 

In [0]:
from google.cloud import storage
client = storage.Client(PROJECT_ID)
# get bucket from the project
bucket=client.get_bucket(BUCKET)
print(bucket)
def data_input_fn(FLAGS,mode): 
    
    def _parser(example):
        feature ={  'image'  : tf.io.FixedLenFeature([],tf.string) ,
                    'target' : tf.io.FixedLenFeature([],tf.string)
        }    
        parsed_example=tf.io.parse_single_example(example,feature)
        image_raw=parsed_example['image']
        image=tf.image.decode_png(image_raw,channels=FLAGS.NB_CHANNELS)
        image=tf.cast(image,tf.float32)/255.0
        image=tf.reshape(image,(FLAGS.IMAGE_DIM,FLAGS.IMAGE_DIM,FLAGS.NB_CHANNELS))
        
        target_raw=parsed_example['target']
        target=tf.image.decode_png(target_raw,channels=FLAGS.NB_CHANNELS)
        target=tf.cast(target,tf.float32)/255.0
        target=tf.reshape(target,(FLAGS.IMAGE_DIM,FLAGS.IMAGE_DIM,FLAGS.NB_CHANNELS))
        
        return image,target

    # build a gs:// path for every blob under tfrecord/<mode> in the bucket
    file_paths = [os.path.join('gs://{}/'.format(BUCKET), f.name)
                  for f in bucket.list_blobs(prefix='tfrecord/{}'.format(mode))]
    dataset = tf.data.TFRecordDataset(file_paths)
    dataset = dataset.shuffle(FLAGS.SHUFFLE_BUFFER,reshuffle_each_iteration=True)
    dataset = dataset.map(_parser)
    dataset = dataset.repeat()
    dataset = dataset.batch(FLAGS.BATCH_SIZE,drop_remainder=True)
    return dataset

def train_in_fn():
  return data_input_fn(FLAGS,'train')
def eval_in_fn():
  return data_input_fn(FLAGS,'eval')


<Bucket: f2odata>
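
The input function above turns every blob name under `tfrecord/<mode>` into a `gs://` path. Listing by prefix may also return zero-byte "directory" placeholder objects whose names end in `/`, which are not TFRecord files. A minimal sketch of the path construction with such entries filtered out; the helper name `tfrecord_paths` and the blob names are hypothetical:

```python
import posixpath

def tfrecord_paths(bucket_name, blob_names):
    """Turn blob names into gs:// URLs, skipping 'directory' placeholder entries."""
    root = 'gs://{}'.format(bucket_name)
    return [posixpath.join(root, name)
            for name in blob_names
            if not name.endswith('/')]

# hypothetical listing: one placeholder entry and one record file
names = ['tfrecord/train/', 'tfrecord/train/part_0.tfrecord']
print(tfrecord_paths('f2odata', names))
```

In the notebook, `blob_names` would correspond to `f.name for f in bucket.list_blobs(...)`.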


## Change to git repo dir

In [0]:
cd /content/gdrive/My\ Drive/PROJECTS/F2O/pyF2O/

/content/gdrive/My Drive/PROJECTS/F2O/pyF2O


## CREATE MODEL
* **resolve** TPU cluster
* define **strategy**
* **compile** within **strategy scope**




In [0]:
from tensorflow.keras.optimizers import Adam
import tensorflow.keras.backend as K 
from F2O.generators import unet

tf.logging.set_verbosity(tf.logging.INFO)

resolver = tf.contrib.cluster_resolver.TPUClusterResolver('grpc://' + os.environ['COLAB_TPU_ADDR'])
tf.contrib.distribute.initialize_tpu_system(resolver)
strategy = tf.contrib.distribute.TPUStrategy(resolver)

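def patchLoss(y_true,y_pred):
  thresh=0.31
  weight_factor=1e-2
  # per-pixel weights: weight_factor for the background (image construction loss),
  # 1.0 for the manipulated region (y_true >= thresh).
  # tf.where keeps the tensor shape static; boolean-masking the two regions would
  # give tensors of different lengths that cannot be added together.
  ones=tf.ones_like(y_true)
  weights=tf.where(y_true < thresh, weight_factor*ones, ones)
  # abs() is already non-negative, so no extra relu is needed
  return weights*K.abs(y_true-y_pred)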
  



with strategy.scope():
  model = unet(image_dim=FLAGS.IMAGE_DIM)
  model.compile(optimizer=Adam(learning_rate=LEARNING_RATE),loss=patchLoss)

model.summary()



INFO:tensorflow:Initializing the TPU system.
INFO:tensorflow:Finished initializing TPU system.
INFO:tensorflow:Querying Tensorflow master (grpc://10.42.124.66:8470) for TPU system metadata.
INFO:tensorflow:Found TPU system:
INFO:tensorflow:*** Num TPU Cores: 8
INFO:tensorflow:*** Num TPU Workers: 1
INFO:tensorflow:*** Num TPU Cores Per Worker: 8
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, -1, 9360561929670052793)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 17179869184, 4283045023088742688)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 17179869184, 4476928320087680980)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 17179869184, 14323803195418573646)
INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 1717
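
As read from its comments, `patchLoss` is meant to be an L1 loss in which background pixels (`y_true < thresh`) count with a small weight and manipulated-region pixels count fully. A pure-Python reference of that reading on flat lists of pixel values, useful for checking a tensor implementation by hand; the function name `patch_loss_reference` is hypothetical, and this is one interpretation of the intent rather than the notebook's own code:

```python
def patch_loss_reference(y_true, y_pred, thresh=0.31, weight_factor=1e-2):
    """Per-pixel weighted L1 loss on flat sequences of pixel values."""
    losses = []
    for t, p in zip(y_true, y_pred):
        # background pixels are down-weighted, manipulated pixels keep weight 1
        weight = weight_factor if t < thresh else 1.0
        losses.append(weight * abs(t - p))
    return losses

# first pixel is background (0.0 < 0.31), second is in the manipulated region
print(patch_loss_reference([0.0, 0.5], [0.1, 0.2]))
```

The background error of 0.1 is scaled by 0.01, while the manipulated-region error of 0.3 is kept as is.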

## Train
* define **checkpoints** and **callbacks** (tensorboard avoided)
* train model
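
The checkpoint in the next cell is created with `save_best_only=True`, so the weights file is only overwritten when the monitored quantity improves (Keras monitors `val_loss` by default). A minimal sketch of that rule in plain Python, assuming lower is better; the helper `epochs_that_save` and the loss values are hypothetical:

```python
def epochs_that_save(val_losses):
    """Return the 1-based epochs at which a 'save best only' rule writes a checkpoint."""
    best = float('inf')
    saved = []
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:  # strictly better than anything seen so far
            best = loss
            saved.append(epoch)
    return saved

print(epochs_that_save([0.9, 0.7, 0.8, 0.6]))  # [1, 2, 4]
```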

In [0]:
from tensorflow.keras.callbacks import ModelCheckpoint
import h5py
model_dir=os.path.join(os.getcwd(),'F2O','model_weights')
model_name='uNet' # @param
checkpoint = ModelCheckpoint(filepath=os.path.join(model_dir,'{}.h5'.format(model_name)), verbose=1, save_best_only=True)
history =  model.fit(
              train_in_fn(),
              epochs= N_EPOCHS,
              steps_per_epoch= STEPS_PER_EPOCH,
              validation_data=eval_in_fn(),
              validation_steps= VALIDATION_STEPS,
              callbacks=[checkpoint],
              verbose=1
            )
    

ValueError: ignored

#### Plot Training History

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('LOSS History')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'validation'], loc='upper left')
plt.savefig(os.path.join(model_dir,'{}_history.png'.format(model_name)))