# GradRev in FE

In [1]:
import os

import tensorflow as tf
import numpy as np

import fastestimator as fe

## Initial Imports

In [2]:
from fastestimator.dataset import mnist, usps
from fastestimator.op.numpyop import ImageReader
from fastestimator import RecordWriter

usps_train_csv, usps_eval_csv, usps_parent_dir = usps.load_data()
mnist_train_csv, mnist_eval_csv, mnist_parent_dir = mnist.load_data()

Downloading train data to /root/fastestimator_data/USPS
100% [......................................................]    1.83 / 1.83 MB
Extracting /root/fastestimator_data/USPS/zip.train.gz
Downloading test data to /root/fastestimator_data/USPS
100% [........................................................]  0.44 / 0.44 MB
Extracting /root/fastestimator_data/USPS/zip.test.gz
Writing image data to /root/fastestimator_data/USPS/image
Data summary is saved at /root/fastestimator_data/USPS/train.csv
Data summary is saved at /root/fastestimator_data/USPS/eval.csv
Writing image data to /root/fastestimator_data/MNIST/image
Data summary is saved at /root/fastestimator_data/MNIST/train.csv
Data summary is saved at /root/fastestimator_data/MNIST/eval.csv


In [3]:
batch_size = 128
epochs = 100

In [4]:
import pandas as pd

df = pd.read_csv(mnist_train_csv)
df.columns = ['source_img', 'source_label']
df.to_csv(mnist_train_csv, index=False)

df = pd.read_csv(usps_train_csv)
df.columns = ['target_img', 'target_label']
df.to_csv(usps_train_csv, index=False)

In [5]:
from fastestimator.op.tensorop import Resize, Minmax

writer = RecordWriter(save_dir=os.path.join(os.path.dirname(mnist_parent_dir), 'dann', 'tfr'),
                      train_data=(usps_train_csv, mnist_train_csv),
                      ops=(
                          [ImageReader(inputs="target_img", outputs="target_img", parent_path=usps_parent_dir, grey_scale=True)], # first tuple element
                          [ImageReader(inputs="source_img", outputs="source_img", parent_path=mnist_parent_dir, grey_scale=True)])) # second tuple element

In [6]:
pipeline = fe.Pipeline(
    batch_size=batch_size,
    data=writer,
    ops=[
        Resize(inputs="target_img", outputs="target_img", size=(28, 28)),
        Resize(inputs="source_img", outputs="source_img", size=(28, 28)),
        Minmax(inputs="target_img", outputs="target_img"),
        Minmax(inputs="source_img", outputs="source_img")
    ]
)

## Gradient Reversal Layer

In [7]:
from tensorflow.keras import layers, Model

@tf.custom_gradient
def grad_reverse(x, l):
    def custom_grad(dy):
        return tf.math.negative(dy)*l, None    
    return tf.identity(x), custom_grad


class GradReversalLayer(tf.keras.layers.Layer):
    def __init__(self, l=1):
        super().__init__()
        self.l = l
        
    def call(self, x):
        return grad_reverse(x, self.l)
    
    def get_config(self):
        return {'l': self.l}

In [8]:
alpha = tf.Variable(0.0, dtype=tf.float32, trainable=False)

def build_feature_extractor(img_shape=(28, 28, 1)):
    x0 = layers.Input(shape=img_shape)
    x = layers.Conv2D(32, 5, activation="relu", padding="same")(x0)
    x = layers.MaxPooling2D()(x)
    x = layers.Conv2D(48, 5, activation="relu", padding="same")(x)
    x = layers.MaxPooling2D()(x)
    feat_map = layers.Flatten()(x)
    return Model(inputs=x0, outputs=feat_map)


def build_label_predictor(feat_dim):
    x0 = layers.Input(shape=(feat_dim,))
    x = layers.Dense(100, activation="relu")(x0)
    x = layers.Dense(100, activation="relu")(x)
    return Model(inputs=x0, outputs=x)

def build_domain_predictor(feat_dim):
    x0 = layers.Input(shape=(feat_dim,))
    x = GradReversalLayer(l=alpha)(x0)
    x = layers.Dense(100, activation="relu")(x)
    x = layers.Dense(1, activation="sigmoid")(x)
    return Model(inputs=x0, outputs=x)

In [9]:
img_shape=(28, 28, 1)
feat_dim = 7 * 7 * 48

feature_extractor = fe.build(
    model_def=lambda: build_feature_extractor(img_shape),
    model_name="feature_extractor",
    loss_name="fe_loss",
    optimizer=tf.keras.optimizers.Adam(1e-4)
)

label_predictor = fe.build(
    model_def=lambda: build_label_predictor(feat_dim),
    model_name="label_predictor",
    loss_name="fe_loss",
    optimizer=tf.keras.optimizers.Adam(1e-4)
)

domain_predictor = fe.build(
    model_def=lambda: build_domain_predictor(feat_dim),
    model_name="domain_predictor",
    loss_name="fe_loss",
    optimizer=tf.keras.optimizers.Adam(1e-4)
)

In [10]:
from fastestimator.op.tensorop.loss import Loss, BinaryCrossentropy, SparseCategoricalCrossentropy
from tensorflow.keras import losses

class FELoss(Loss):
    def __init__(self, inputs, outputs=None, mode=None):
        super().__init__(inputs=inputs, outputs=outputs, mode=mode)        
        self.label_loss_obj = losses.SparseCategoricalCrossentropy(reduction=losses.Reduction.NONE)
        self.domain_loss_obj = losses.BinaryCrossentropy(reduction=losses.Reduction.NONE)        
        
    def forward(self, data, state):
        src_c_logit, src_c_label, src_d_logit, tgt_d_logit = data
        c_loss = self.label_loss_obj(y_true=src_c_label, y_pred=src_c_logit)
        src_d_loss = self.domain_loss_obj(y_true=tf.zeros_like(src_d_logit), y_pred=src_d_logit) 
        tgt_d_loss = self.domain_loss_obj(y_true=tf.ones_like(tgt_d_logit), y_pred=tgt_d_logit)
        return c_loss + src_d_loss + tgt_d_loss

In [11]:
from fastestimator.op.tensorop.model import ModelOp
network = fe.Network(ops=[
    ModelOp(inputs="source_img", outputs="src_feat", model=feature_extractor),
    ModelOp(inputs="target_img", outputs="tgt_feat", model=feature_extractor),
    ModelOp(inputs="src_feat", outputs="src_c_logit", model=label_predictor),
    ModelOp(inputs="src_feat", outputs="src_d_logit", model=domain_predictor),
    ModelOp(inputs="tgt_feat", outputs="tgt_d_logit", model=domain_predictor),
    FELoss(inputs=("src_c_logit","source_label", "src_d_logit", "tgt_d_logit"), outputs="fe_loss")    
])

In [12]:
from fastestimator.trace import Trace
from tensorflow.python.keras import backend

class GRLWeightController(Trace):
    def __init__(self, alpha):
        super().__init__(inputs=None, outputs=None, mode="train")
        self.alpha = alpha
        
    def on_begin(self, state):
        self.total_steps = state['total_train_steps']
        
    def on_batch_begin(self, state):
        p = state['train_step'] / self.total_steps
        current_alpha = float(2.0 / (1.0 + np.exp(-10.0 * p)) - 1.0)
        backend.set_value(self.alpha, current_alpha)

In [13]:
traces = [GRLWeightController(alpha=alpha)]

estimator = fe.Estimator(
    pipeline= pipeline, 
    network=network,
    traces = traces,
    epochs = epochs
)

In [14]:
estimator.fit()

    ______           __  ______     __  _                 __            
   / ____/___ ______/ /_/ ____/____/ /_(_)___ ___  ____ _/ /_____  _____
  / /_  / __ `/ ___/ __/ __/ / ___/ __/ / __ `__ \/ __ `/ __/ __ \/ ___/
 / __/ / /_/ (__  ) /_/ /___(__  ) /_/ / / / / / / /_/ / /_/ /_/ / /    
/_/    \__,_/____/\__/_____/____/\__/_/_/ /_/ /_/\__,_/\__/\____/_/     
                                                                        

FastEstimator: Saving tfrecord to /root/fastestimator_data/dann/tfr
FastEstimator: Converting Train TFRecords 0.0%, Speed: 0.00 record/sec
FastEstimator: Converting Train TFRecords 5.0%, Speed: 13031.75 record/sec
FastEstimator: Converting Train TFRecords 10.0%, Speed: 13740.44 record/sec
FastEstimator: Converting Train TFRecords 15.0%, Speed: 14161.23 record/sec
FastEstimator: Converting Train TFRecords 20.0%, Speed: 13921.10 record/sec
FastEstimator: Converting Train TFRecords 25.0%, Speed: 13652.25 record/sec
FastEstimator: Converting Train TFRecords 3

FastEstimator-Train: step: 4500; fe_loss: 2.004524; examples/sec: 5736.6; progress: 80.4%; 
FastEstimator-Train: step: 4600; fe_loss: 2.4861064; examples/sec: 6489.4; progress: 82.1%; 
FastEstimator-Train: step: 4700; fe_loss: 1.7448742; examples/sec: 2757.3; progress: 83.9%; 
FastEstimator-Train: step: 4800; fe_loss: 1.9282994; examples/sec: 5722.5; progress: 85.7%; 
FastEstimator-Train: step: 4900; fe_loss: 1.3160886; examples/sec: 5739.0; progress: 87.5%; 
FastEstimator-Train: step: 5000; fe_loss: 1.5141902; examples/sec: 6585.9; progress: 89.3%; 
FastEstimator-Train: step: 5100; fe_loss: 3.653502; examples/sec: 5676.4; progress: 91.1%; 
FastEstimator-Train: step: 5200; fe_loss: 2.950117; examples/sec: 2728.6; progress: 92.9%; 
FastEstimator-Train: step: 5300; fe_loss: 2.2227886; examples/sec: 5666.0; progress: 94.6%; 
FastEstimator-Train: step: 5400; fe_loss: 1.6200781; examples/sec: 6464.9; progress: 96.4%; 
FastEstimator-Train: step: 5500; fe_loss: 1.6394211; examples/sec: 5701.0