In [None]:
import os
import h5py
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.model_selection import train_test_split
# Expose only the first GPU to TensorFlow. CUDA reads this variable when the
# driver is first initialised, which has not happened yet at this point.
os.environ['CUDA_VISIBLE_DEVICES']='0'

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession

# Let TensorFlow grow GPU memory on demand instead of reserving the whole card.
# set_memory_growth is the TF2 API that governs eager / tf.function execution
# (which the rest of this notebook uses); it must run before any GPU is
# initialised, so it is done before the session below is created.
for gpu in tf.config.experimental.list_physical_devices('GPU'):
  tf.config.experimental.set_memory_growth(gpu, True)

# Legacy TF1-style session with the same allow_growth setting, kept so that any
# code relying on `config` / `session` keeps working.
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

print(tf.__version__)

In [None]:
# Experiment hyper-parameters.
num_classes     = 9    # classes left after the data cell drops class index 1
batch_size      = 16
train_src_days  = 6    # leading source-environment days used for source train/test
train_trg_days  = 0    # following source-environment days used as target training data
train_trg_env_days = 2 # leading server-environment days used for training
epochs          = 500
init_lr         = 0.0001
num_features    = 256  # width of the classifier's penultimate dense layer

# Run tag; it names the log directory so runs with different settings don't collide.
notes           = ('classes-{}_bs-{}_train_src_days-{}_train_trg_days-{}'
                   '_train_trgenv_days-{}_initlr-{}_num_features-{}').format(num_classes,
                                                                             batch_size,
                                                                             train_src_days,
                                                                             train_trg_days,
                                                                             train_trg_env_days,
                                                                             init_lr,
                                                                             num_features)

# Data and log locations. Both can be overridden through environment variables
# so the notebook can run outside the original cluster; the defaults are the
# original hard-coded paths.
dataset_path    = os.environ.get('MMWAVE_DATASET_PATH', '/users/kjakkala/mmwave/data/')
log_dir         = os.path.join(os.environ.get('MMWAVE_LOG_ROOT',
                                              '/users/kjakkala/mmwave/logs/new_logs/CycleGAN'),
                               notes)

In [None]:
def resize_data(data, output_shape=(256, 256)):
  """Resize every image of a 4-D batch to `output_shape` with one resize call.

  The sample axis is moved to the end and merged with the channel axis, so the
  whole batch is resized as a single multi-channel image. The axes are then
  split and moved back.

  Args:
    data: 4-D array indexed (sample, height, width, channel).
    output_shape: target (height, width).

  Returns:
    Array of shape (samples, *output_shape, channels).
  """
  height, width, channels = data.shape[1:]
  stacked = np.moveaxis(data, 0, -1).reshape(height, width, -1)
  resized = resize(stacked, output_shape)
  unstacked = resized.reshape(*output_shape, channels, -1)
  return np.moveaxis(unstacked, -1, 0)

#Read the source-environment recordings. The context manager closes the HDF5
#file even if reading or resizing raises part-way through.
with h5py.File(os.path.join(dataset_path, 'source_data.h5'), 'r') as hf:
  X_data = resize_data(np.expand_dims(hf.get('X_data'), axis=-1))
  y_data = np.array(hf.get('y_data'))
  # class names are stored as bytes; decode them to str
  classes = [n.decode("ascii", "ignore") for n in hf.get('classes')]
print(X_data.shape, y_data.shape, "\n", classes)

#balance the dataset: keep at most 95 samples per (class, day) pair.
#y_data[:, 0] is the class index and y_data[:, 1] the recording day.
X_balanced, y_balanced = [], []
for day in range(10):
  for person in range(len(classes)):
    cell = (y_data[:, 0] == person) & (y_data[:, 1] == day)
    X_balanced.extend(X_data[cell][:95])
    y_balanced.extend(y_data[cell][:95])
X_data, y_data = np.array(X_balanced), np.array(y_balanced)
del X_balanced, y_balanced
print(X_data.shape, y_data.shape)

#remove harika's data (class index 1), then shift every higher class index
#down by one so the remaining 9 classes are numbered contiguously
harika_rows = np.where(y_data[:, 0] == 1)[0]
X_data = np.delete(X_data, harika_rows, 0)
y_data = np.delete(y_data, harika_rows, 0)
y_data[y_data[:, 0] >= 2, 0] -= 1
del classes[1]
print(X_data.shape, y_data.shape, "\n", classes)

#split days of data into train and test sets:
#  days [0, train_src_days)                              -> source train/test (90/10, stratified)
#  days [train_src_days, train_src_days+train_trg_days)  -> target train
#  days [train_src_days+train_trg_days, ...)             -> target test
day_of   = y_data[:, 1]
label_of = y_data[:, 0]
one_hot  = np.eye(len(classes))
trg_test_start = train_src_days + train_trg_days

src_mask   = day_of < train_src_days
src_labels = one_hot[label_of[src_mask]]
X_train_src, X_test_src, y_train_src, y_test_src = train_test_split(X_data[src_mask],
                                                                    src_labels,
                                                                    stratify=src_labels,
                                                                    test_size=0.10,
                                                                    random_state=42)

trg_train_mask = (day_of >= train_src_days) & (day_of < trg_test_start)
X_train_trg = X_data[trg_train_mask]
y_train_trg = one_hot[label_of[trg_train_mask]]

trg_test_mask = day_of >= trg_test_start
X_test_trg = X_data[trg_test_mask]
y_test_trg = one_hot[label_of[trg_test_mask]]

del day_of, label_of, one_hot, trg_test_start, src_mask, src_labels, trg_train_mask, trg_test_mask, X_data, y_data

#standardise dataset: centre on the training split's mean, then min-max scale
#the centred values to [-1, 1]. Each test split reuses its training split's
#statistics.
def _fit_scale(x):
  """Centre and min-max scale `x` to [-1, 1]; also return (mean, min, ptp) used."""
  mean = np.mean(x)
  centred = x - mean
  lo = np.min(centred)
  span = np.ptp(centred)
  return 2.*(centred - lo)/span-1, mean, lo, span

def _apply_scale(x, mean, lo, span):
  """Apply statistics produced by `_fit_scale` to another array."""
  return 2.*((x - mean) - lo)/span-1

X_train_src, src_mean, src_min, src_ptp = _fit_scale(X_train_src)
X_test_src = _apply_scale(X_test_src, src_mean, src_min, src_ptp)

if X_train_trg.shape[0] != 0:
  X_train_trg, trg_mean, trg_min, trg_ptp = _fit_scale(X_train_trg)
  X_test_trg = _apply_scale(X_test_trg, trg_mean, trg_min, trg_ptp)
else:
  #no target training days: fall back to the source statistics
  X_test_trg = _apply_scale(X_test_trg, src_mean, src_min, src_ptp)

X_train_src, X_test_src, X_train_trg, X_test_trg = (
    arr.astype(np.float32) for arr in (X_train_src, X_test_src, X_train_trg, X_test_trg))
y_train_src, y_test_src, y_train_trg, y_test_trg = (
    arr.astype(np.uint8) for arr in (y_train_src, y_test_src, y_train_trg, y_test_trg))

print(X_train_src.shape, y_train_src.shape,  X_test_src.shape, y_test_src.shape, X_train_trg.shape, y_train_trg.shape, X_test_trg.shape, y_test_trg.shape)

def get_trg_data(fname, src_classes, train_trg_days):
  """Load one target-environment HDF5 file and split/normalise it by day.

  Samples recorded on days ``< train_trg_days`` form the training split and the
  rest form the test split. In both splits, each label's class index is mapped
  by class name to its position in ``src_classes`` before one-hot encoding, so
  target labels line up with the source classifier's outputs.

  Args:
    fname: path to an HDF5 file with 'X_data', 'y_data' and 'classes' datasets.
    src_classes: list of source class names; defines the one-hot layout.
    train_trg_days: number of leading days used for the training split.

  Returns:
    (X_train, y_train, X_test, y_test): float32 images and uint8 one-hot labels.
    If the training split is non-empty, it is centred on its mean and min-max
    scaled to [-1, 1], and the test split reuses those statistics. Otherwise
    the test split is scaled with its own statistics.

  Raises:
    ValueError: if a labelled sample's class name is not in ``src_classes``.
  """
  # The context manager closes the file even if reading or resizing fails.
  with h5py.File(fname, 'r') as hf:
    X_data_trg = resize_data(np.expand_dims(hf.get('X_data'), axis=-1))
    y_data_trg = np.array(hf.get('y_data'))
    trg_classes = [n.decode("ascii", "ignore") for n in hf.get('classes')]

  def to_src_one_hot(trg_labels):
    # Translate target class indices to source class indices by name, then
    # one-hot encode them in the source layout. dtype=np.int64 keeps an empty
    # split indexable (np.array([]) would otherwise be float64 and fail).
    src_idx = np.array([src_classes.index(trg_classes[i]) for i in trg_labels],
                       dtype=np.int64)
    return np.eye(len(src_classes))[src_idx]

  #split days of data to train and test
  train_rows = y_data_trg[:, 1] < train_trg_days
  test_rows = y_data_trg[:, 1] >= train_trg_days

  X_train_trg = X_data_trg[train_rows]
  y_train_trg = to_src_one_hot(y_data_trg[train_rows, 0])

  X_test_trg = X_data_trg[test_rows]
  y_test_trg = to_src_one_hot(y_data_trg[test_rows, 0])

  if(X_train_trg.shape[0] != 0):
    trg_mean = np.mean(X_train_trg)
    X_train_trg -= trg_mean
    trg_min = np.min(X_train_trg)
    trg_ptp = np.ptp(X_train_trg)
    X_train_trg = 2.*(X_train_trg - trg_min)/trg_ptp-1

    X_test_trg -= trg_mean
    X_test_trg = 2.*(X_test_trg - trg_min)/trg_ptp-1
  else:
    # No training days: normalise the test split with its own statistics.
    X_test_trg -= np.mean(X_test_trg)
    trg_min = np.min(X_test_trg)
    trg_ptp = np.ptp(X_test_trg)
    X_test_trg = 2.*(X_test_trg - trg_min)/trg_ptp-1

  return X_train_trg.astype(np.float32), y_train_trg.astype(np.uint8), X_test_trg.astype(np.float32), y_test_trg.astype(np.uint8)

#Load the three target environments. Conference and office use their first 3
#days as the "train" split; the server uses `train_trg_env_days`.
X_train_conf,   y_train_conf,   X_test_conf,   y_test_conf   = get_trg_data(os.path.join(dataset_path, 'target_conf_data.h5'),   classes, 3)
X_train_server, y_train_server, X_test_server, y_test_server = get_trg_data(os.path.join(dataset_path, 'target_server_data.h5'), classes, train_trg_env_days)
X_data_office,  y_data_office,  _,             _             = get_trg_data(os.path.join(dataset_path, 'target_office_data.h5'), classes, 3)

for split in ((X_train_conf, y_train_conf, X_test_conf, y_test_conf),
              (X_train_server, y_train_server, X_test_server, y_test_server),
              (X_data_office, y_data_office)):
  print(*(arr.shape for arr in split))

def _make_dataset(images, labels, training=False, repeat=False):
  """Wrap (images, labels) arrays in a batched, prefetched tf.data pipeline.

  Training pipelines are shuffled over the whole array and drop the final
  partial batch. Evaluation pipelines keep the partial batch. With `repeat`,
  the pipeline cycles forever.
  """
  ds = tf.data.Dataset.from_tensor_slices((images, labels))
  if training:
    ds = ds.shuffle(images.shape[0])
  ds = ds.batch(batch_size, drop_remainder=training).prefetch(batch_size)
  if repeat:
    ds = ds.repeat(-1)
  return ds

#Test
#NOTE(review): conf_test_set is built from the conference *train* split (first
#3 days), like the office set; confirm X_test_conf was not intended.
conf_test_set   = _make_dataset(X_train_conf, y_train_conf)
server_test_set = _make_dataset(X_test_server, y_test_server)
office_test_set = _make_dataset(X_data_office, y_data_office)
src_test_set    = _make_dataset(X_test_src, y_test_src)
time_test_set   = _make_dataset(X_test_trg, y_test_trg)

#Train
src_train_set    = _make_dataset(X_train_src, y_train_src, training=True)
server_train_set = _make_dataset(X_train_server, y_train_server, training=True, repeat=True)

In [None]:
class InstanceNormalization(tf.keras.layers.Layer):
  """Instance Normalization Layer (https://arxiv.org/abs/1607.08022).

  Each sample is normalised independently over axes 1 and 2 (the spatial axes
  for NHWC input). A learned per-channel scale and offset are then applied.

  Args:
    epsilon: small constant added to the variance before the inverse sqrt.
    **kwargs: forwarded to tf.keras.layers.Layer (e.g. ``name``, ``dtype``).
  """

  def __init__(self, epsilon=1e-5, **kwargs):
    super(InstanceNormalization, self).__init__(**kwargs)
    self.epsilon = epsilon

  def build(self, input_shape):
    # One scale/offset per channel (last axis).
    self.scale = self.add_weight(
        name='scale',
        shape=input_shape[-1:],
        initializer=tf.random_normal_initializer(1., 0.02),
        trainable=True)

    self.offset = self.add_weight(
        name='offset',
        shape=input_shape[-1:],
        initializer='zeros',
        trainable=True)

  def call(self, x):
    mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
    inv = tf.math.rsqrt(variance + self.epsilon)
    normalized = (x - mean) * inv
    return self.scale * normalized + self.offset

  def get_config(self):
    """Return the layer config so models containing this layer can be saved/reloaded."""
    cfg = super(InstanceNormalization, self).get_config()
    cfg.update({'epsilon': self.epsilon})
    return cfg
  
def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
  """Build a stride-2 Conv2D -> optional norm -> LeakyReLU block.

  Args:
    filters: number of convolution filters.
    size: convolution kernel size.
    norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Any other
      value adds no normalisation layer.
    apply_norm: if False, the normalisation layer is skipped entirely.
  Returns:
    A tf.keras.Sequential model.
  """
  norm_factories = {
      'batchnorm': tf.keras.layers.BatchNormalization,
      'instancenorm': InstanceNormalization,
  }

  block = tf.keras.Sequential()
  block.add(tf.keras.layers.Conv2D(
      filters, size, strides=2, padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False))

  if apply_norm:
    make_norm = norm_factories.get(norm_type.lower())
    if make_norm is not None:
      block.add(make_norm())

  block.add(tf.keras.layers.LeakyReLU())
  return block


def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
  """Build a stride-2 Conv2DTranspose -> norm -> [Dropout] -> ReLU block.

  Args:
    filters: number of output filters.
    size: transposed-convolution kernel size.
    norm_type: 'batchnorm' or 'instancenorm' (case-insensitive). Any other
      value adds no normalisation layer.
    apply_dropout: if True, insert Dropout(0.5) after the normalisation.
  Returns:
    A tf.keras.Sequential model.
  """
  norm_factories = {
      'batchnorm': tf.keras.layers.BatchNormalization,
      'instancenorm': InstanceNormalization,
  }

  block = tf.keras.Sequential()
  block.add(tf.keras.layers.Conv2DTranspose(
      filters, size, strides=2,
      padding='same',
      kernel_initializer=tf.random_normal_initializer(0., 0.02),
      use_bias=False))

  make_norm = norm_factories.get(norm_type.lower())
  if make_norm is not None:
    block.add(make_norm())

  if apply_dropout:
    block.add(tf.keras.layers.Dropout(0.5))

  block.add(tf.keras.layers.ReLU())
  return block


def unet_generator(output_channels, norm_type='batchnorm', input_channels=1):
  """Modified u-net generator model (https://arxiv.org/abs/1611.07004).

  The encoder has eight stride-2 downsampling blocks. The decoder has seven
  upsampling blocks, and each decoder output is concatenated with the mirrored
  encoder activation. A final stride-2 transposed convolution with tanh
  produces the image. Shape comments assume a 256x256 input.

  Args:
    output_channels: Output channels.
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
    input_channels: Input channels (default 1, matching the single-channel
      data used in this notebook).
  Returns:
    Generator model
  """

  down_stack = [
      downsample(64, 4, norm_type, apply_norm=False),  # (bs, 128, 128, 64)
      downsample(128, 4, norm_type),  # (bs, 64, 64, 128)
      downsample(256, 4, norm_type),  # (bs, 32, 32, 256)
      downsample(512, 4, norm_type),  # (bs, 16, 16, 512)
      downsample(512, 4, norm_type),  # (bs, 8, 8, 512)
      downsample(512, 4, norm_type),  # (bs, 4, 4, 512)
      downsample(512, 4, norm_type),  # (bs, 2, 2, 512)
      downsample(512, 4, norm_type),  # (bs, 1, 1, 512)
  ]

  # Channel counts in the comments are after concatenation with the skip.
  up_stack = [
      upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 2, 2, 1024)
      upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 4, 4, 1024)
      upsample(512, 4, norm_type, apply_dropout=True),  # (bs, 8, 8, 1024)
      upsample(512, 4, norm_type),  # (bs, 16, 16, 1024)
      upsample(256, 4, norm_type),  # (bs, 32, 32, 512)
      upsample(128, 4, norm_type),  # (bs, 64, 64, 256)
      upsample(64, 4, norm_type),  # (bs, 128, 128, 128)
  ]

  initializer = tf.random_normal_initializer(0., 0.02)
  last = tf.keras.layers.Conv2DTranspose(
      output_channels, 4, strides=2,
      padding='same', kernel_initializer=initializer,
      activation='tanh')  # (bs, 256, 256, output_channels)

  concat = tf.keras.layers.Concatenate()

  inputs = tf.keras.layers.Input(shape=[None, None, input_channels])
  x = inputs

  # Downsampling through the model
  skips = []
  for down in down_stack:
    x = down(x)
    skips.append(x)

  # The bottleneck (last encoder output) is not reused as a skip.
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    x = concat([x, skip])

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)


def discriminator(norm_type='batchnorm', target=True, input_channels=1):
  """PatchGan discriminator model (https://arxiv.org/abs/1611.07004).

  Args:
    norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'
      (case-insensitive).
    target: Bool, indicating whether target image is an input or not.
    input_channels: channels of each input image (default 1, matching the
      single-channel data used in this notebook).
  Returns:
    Discriminator model. It outputs a grid of unactivated real/fake scores
    ((bs, 30, 30, 1) for a 256x256 input).
  Raises:
    ValueError: if norm_type is not 'batchnorm' or 'instancenorm'.
  """
  # Validate up front. An unsupported value used to fail later with an opaque
  # UnboundLocalError when the normalisation layer was built.
  norm = norm_type.lower()
  if norm not in ('batchnorm', 'instancenorm'):
    raise ValueError(
        "norm_type must be 'batchnorm' or 'instancenorm', got {!r}".format(norm_type))

  initializer = tf.random_normal_initializer(0., 0.02)

  inp = tf.keras.layers.Input(shape=[None, None, input_channels], name='input_image')
  x = inp

  if target:
    tar = tf.keras.layers.Input(shape=[None, None, input_channels], name='target_image')
    x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, channels*2)

  down1 = downsample(64, 4, norm_type, False)(x)  # (bs, 128, 128, 64)
  down2 = downsample(128, 4, norm_type)(down1)  # (bs, 64, 64, 128)
  down3 = downsample(256, 4, norm_type)(down2)  # (bs, 32, 32, 256)

  zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
  conv = tf.keras.layers.Conv2D(
      512, 4, strides=1, kernel_initializer=initializer,
      use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

  if norm == 'batchnorm':
    norm1 = tf.keras.layers.BatchNormalization()(conv)
  else:
    norm1 = InstanceNormalization()(conv)

  leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

  zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

  last = tf.keras.layers.Conv2D(
      1, 4, strides=1,
      kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

  if target:
    return tf.keras.Model(inputs=[inp, tar], outputs=last)
  else:
    return tf.keras.Model(inputs=inp, outputs=last)

In [None]:
# Shared regularisation / batch-norm settings for the classifier blocks below.
L2_WEIGHT_DECAY, BATCH_NORM_DECAY, BATCH_NORM_EPSILON = 1e-4, 0.9, 1e-5

class IdentityBlock(tf.keras.Model):
  """ResNet bottleneck block whose shortcut is the unmodified input.

  Main path: 1x1 conv -> BN -> act -> kxk conv ('same') -> BN -> act ->
  1x1 conv -> BN. The result is added to the input and passed through the
  activation. The input must therefore already have ``filters[2]`` channels.

  Args:
    kernel_size: kernel size of the middle convolution.
    filters: (filters1, filters2, filters3) for the three convolutions.
    stage: integer stage label, used for layer names.
    block: block label ('a', 'b', ...), used for layer names.
    activation: activation name passed to tf.keras.layers.Activation.
  """

  def __init__(self, kernel_size, filters, stage, block, activation='relu'):
    super().__init__(name='stage-' + str(stage) + '_block-' + block)
    self.activation = activation

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    filters1, filters2, filters3 = filters
    bn_axis = -1

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2a')
    self.bn2a = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2a')
    self.act1  = tf.keras.layers.Activation(self.activation)

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size,
                                         padding='same',
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2b')
    self.bn2b = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2b')
    self.act2  = tf.keras.layers.Activation(self.activation)

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2c')
    self.bn2c = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2c')
    self.act3  = tf.keras.layers.Activation(self.activation)

  def call(self, input_tensor, training=False):
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = self.act1(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = self.act2(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    # Residual sum. Plain tensor addition avoids creating a new Add layer on
    # every call; Keras advises creating layers in __init__, not in call().
    x = x + input_tensor
    x = self.act3(x)
    return x


class ConvBlock(tf.keras.Model):
  """ResNet bottleneck block with a projection (1x1 conv + BN) shortcut.

  Main path: 1x1 conv -> BN -> act -> kxk conv with `strides` ('same') -> BN ->
  act -> 1x1 conv -> BN. Shortcut: 1x1 conv with `strides` -> BN. The two are
  summed and passed through the activation, so the block can change both the
  channel count (to ``filters[2]``) and the spatial resolution.

  Note that from stage 3, the middle conv layer on the main path uses
  strides=(2, 2), and the shortcut uses strides=(2, 2) as well.

  Args:
    kernel_size: the kernel size of the middle conv layer on the main path.
    filters: (filters1, filters2, filters3), the filters of the 3 main-path convs.
    stage: integer, current stage label, used for generating layer names.
    block: 'a', 'b', ..., current block label, used for generating layer names.
    strides: strides for the middle main-path conv and for the shortcut conv.
    activation: activation name passed to tf.keras.layers.Activation.
  """

  def __init__(self, kernel_size, filters, stage, block, strides=(2, 2), activation='relu'):
    super().__init__(name='stage-' + str(stage) + '_block-' + block)
    self.activation = activation

    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'

    filters1, filters2, filters3 = filters
    bn_axis = -1

    self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2a')
    self.bn2a = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2a')
    self.act1  = tf.keras.layers.Activation(self.activation)

    self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size,
                                         strides=strides,
                                         padding='same',
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2b')
    self.bn2b = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2b')
    self.act2  = tf.keras.layers.Activation(self.activation)

    self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1),
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '2c')
    self.bn2c = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '2c')

    # Projection shortcut: matches the main path's channels and stride.
    self.conv2s = tf.keras.layers.Conv2D(filters3, (1, 1),
                                         strides=strides,
                                         use_bias=False,
                                         kernel_initializer='he_normal',
                                         kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                         name=conv_name_base + '1')
    self.bn2s = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                   momentum=BATCH_NORM_DECAY,
                                                   epsilon=BATCH_NORM_EPSILON,
                                                   name=bn_name_base + '1')
    self.act3  = tf.keras.layers.Activation(self.activation)

  def call(self, input_tensor, training=False):
    x = self.conv2a(input_tensor)
    x = self.bn2a(x, training=training)
    x = self.act1(x)

    x = self.conv2b(x)
    x = self.bn2b(x, training=training)
    x = self.act2(x)

    x = self.conv2c(x)
    x = self.bn2c(x, training=training)

    shortcut = self.conv2s(input_tensor)
    shortcut = self.bn2s(shortcut, training=training)

    # Residual sum. Plain tensor addition avoids creating a new Add layer on
    # every call; Keras advises creating layers in __init__, not in call().
    x = x + shortcut
    x = self.act3(x)
    return x


class ResNet50(tf.keras.Model):
  """Compact ResNet-style classifier (not the full 50-layer ResNet despite the name).

  Architecture:
    7x7/2 conv (32 filters, 'valid') -> BN -> activation -> 3x3/2 max-pool.
    Stage 2: ConvBlock [32, 32, 128] with stride 1, then IdentityBlock.
    Stage 3: ConvBlock [64, 64, 256] with stride 2, then IdentityBlock.
    Stage 4: ConvBlock [64, 64, 256] with stride 2, then IdentityBlock.
    Global average pooling -> Dense(num_features, activation) -> Dense(num_classes).

  Args:
    num_classes: `int` number of classes for image classification.
    num_features: width of the penultimate dense layer (`fc1`).
    activation: activation name used throughout the network.

  ``call(img_input, training=False)`` returns ``(logits, fc1)``. The logits are
  unactivated, and ``fc1`` holds the penultimate-layer features.

  NOTE(review): the model is constructed with name='generator' even though it is
  used as a classifier. The name is kept unchanged because it may appear in
  weight names and logs.
  """
  def __init__(self, num_classes, num_features, activation='relu'):
    super().__init__(name='generator')
    bn_axis = -1
    self.activation = activation
    self.num_classes = num_classes

    self.conv1 = tf.keras.layers.Conv2D(32, (7, 7),
                                        strides=(2, 2),
                                        padding='valid',
                                        use_bias=False,
                                        kernel_initializer='he_normal',
                                        kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                        name='conv1')
    self.bn1 = tf.keras.layers.BatchNormalization(axis=bn_axis,
                                                  momentum=BATCH_NORM_DECAY,
                                                  epsilon=BATCH_NORM_EPSILON,
                                                  name='bn_conv1')
    self.act1 = tf.keras.layers.Activation(self.activation, name=self.activation+'1')
    self.max_pool1 = tf.keras.layers.MaxPooling2D((3, 3),
                                                  strides=(2, 2),
                                                  padding='same',
                                                  name='max_pool1')

    self.blocks = []
    self.blocks.append(ConvBlock(3, [32, 32, 128], strides=(1, 1), stage=2, block='a', activation=self.activation))
    self.blocks.append(IdentityBlock(3, [32, 32, 128], stage=2, block='b', activation=self.activation))

    self.blocks.append(ConvBlock(3, [64, 64, 256], stage=3, block='a', activation=self.activation))
    self.blocks.append(IdentityBlock(3, [64, 64, 256], stage=3, block='b', activation=self.activation))

    self.blocks.append(ConvBlock(3, [64, 64, 256], stage=4, block='a', activation=self.activation))
    self.blocks.append(IdentityBlock(3, [64, 64, 256], stage=4, block='b', activation=self.activation))

    self.avg_pool = tf.keras.layers.GlobalAveragePooling2D(name='avg_pool')
    self.fc1 = tf.keras.layers.Dense(num_features,
                                     activation=self.activation,
                                     kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
                                     kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                     bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                     name='fc1')

    self.logits = tf.keras.layers.Dense(num_classes,
                                        activation=None,
                                        kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01),
                                        kernel_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                        bias_regularizer=tf.keras.regularizers.l2(L2_WEIGHT_DECAY),
                                        name='logits')

  def call(self, img_input, training=False):
    x = self.conv1(img_input)
    x = self.bn1(x, training=training)
    x = self.act1(x)
    x = self.max_pool1(x)

    for block in self.blocks:
      x = block(x, training=training)

    x = self.avg_pool(x)
    fc1 = self.fc1(x)
    logits = self.logits(fc1)

    return logits, fc1

In [None]:
# Weight of the identity-mapping term (scaled by 0.5 in identity_loss).
LAMBDA = 10

# Shared binary cross-entropy on unactivated discriminator outputs.
loss_obj = tf.keras.losses.BinaryCrossentropy(from_logits=True)

def discriminator_loss(real, generated):
  """Average of the BCE pushing real logits towards 1 and generated towards 0."""
  on_real = loss_obj(tf.ones_like(real), real)
  on_fake = loss_obj(tf.zeros_like(generated), generated)
  return (on_real + on_fake) * 0.5

def generator_loss(generated):
  """BCE rewarding the generator when the discriminator scores fakes as real."""
  all_real = tf.ones_like(generated)
  return loss_obj(all_real, generated)

def identity_loss(real_image, same_image):
  """Mean absolute difference between an image and its mapping, scaled by LAMBDA * 0.5."""
  mean_abs_diff = tf.reduce_mean(tf.abs(real_image - same_image))
  return LAMBDA * 0.5 * mean_abs_diff

def get_cross_entropy_loss(labels, logits):
  """Batch-mean softmax cross-entropy between (soft or one-hot) labels and logits."""
  per_example = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=logits)
  return tf.reduce_mean(per_example)

In [None]:
# Running means of every loss term, logged for TensorBoard.
_Mean = tf.keras.metrics.Mean
tb_gen_src_entropy_loss           = _Mean(name='gen_src_entropy_loss')
tb_gen_trg_entropy_loss_src_guide = _Mean(name='gen_trg_entropy_loss_src_guide')
tb_gen_trg_entropy_loss_src_data  = _Mean(name='gen_trg_entropy_loss_src_data')
tb_disc_src_loss                  = _Mean(name='disc_src_loss')
tb_disc_trg_loss                  = _Mean(name='disc_trg_loss')
tb_gen_src_loss                   = _Mean(name='gen_src_loss')
tb_gen_trg_loss                   = _Mean(name='gen_trg_loss')
tb_src_identity_loss              = _Mean(name='src_identity_loss')
tb_trg_identity_loss              = _Mean(name='trg_identity_loss')
tb_total_gen_s_loss               = _Mean(name='total_gen_s_loss')
tb_total_gen_t_loss               = _Mean(name='total_gen_t_loss')
tb_total_clas_s_loss              = _Mean(name='total_clas_s_loss')
tb_total_clas_t_loss              = _Mean(name='total_clas_t_loss')

# Accuracy trackers for each train/test split.
_Accuracy = tf.keras.metrics.CategoricalAccuracy
temporal_test_acc    = _Accuracy(name='temporal_test_acc')
source_train_acc     = _Accuracy(name='source_train_acc')
source_test_acc      = _Accuracy(name='source_test_acc')
office_test_acc      = _Accuracy(name='office_test_acc')
server_train_acc     = _Accuracy(name='server_train_acc')
server_test_acc      = _Accuracy(name='server_test_acc')
conference_test_acc  = _Accuracy(name='conference_test_acc')
del _Mean, _Accuracy

@tf.function
def train_step(src_x, src_y, trg_x, trg_y):
  """Run one joint update of both generators, discriminators and classifiers.

  Roles, as used by the calls below:
    generator_s: target images -> fake source images.
    generator_t: source images -> fake target images.
    discriminator_s / discriminator_t: real-vs-fake scores for source / target.
    classifier_s / classifier_t: return (logits, features); only logits are used.

  Args:
    src_x: batch of source-domain images.
    src_y: one-hot labels for src_x (cross-entropy targets).
    trg_x: batch of target-domain images (the server set in this notebook).
    trg_y: labels for trg_x. Only fed to the server_train_acc metric, never
      to a loss.

  Every model is stepped by its own optimizer using gradients from a single
  persistent tape. The tb_* running means and the two training accuracies are
  updated as side effects; nothing is returned.
  """
  # persistent is set to True because the tape is used more than
  # once to calculate the gradients.
  with tf.GradientTape(persistent=True) as tape:
    # NOTE(review): trg_x_cycled and src_x_cycled are computed but never used
    # in any loss, so no cycle-consistency term is applied (unlike standard
    # CycleGAN). Confirm this is intended; otherwise these passes are wasted.
    src_x_fake   = generator_s(trg_x,      training=True)
    trg_x_cycled = generator_t(src_x_fake, training=True)

    trg_x_fake   = generator_t(src_x,      training=True)
    src_x_cycled = generator_s(trg_x_fake, training=True)
    
    # Identity mappings: each generator fed an image already in its output domain.
    src_x_same = generator_s(src_x, training=True)
    trg_x_same = generator_t(trg_x, training=True)
    
    src_logits, _ = classifier_s(src_x, training=True)
    trg_logits, _ = classifier_t(trg_x, training=True)
    
    src_logits_fake, _ = classifier_s(src_x_fake, training=True)
    trg_logits_fake, _ = classifier_t(trg_x_fake, training=True)
    
    #Classifier##########################################
    # Supervised loss of classifier_s on labelled source images.
    gen_src_entropy_loss = get_cross_entropy_loss(labels=src_y, 
                                                  logits=src_logits)
    
    # classifier_t on real target images, supervised by classifier_s's soft
    # predictions on the corresponding fake-source images (pseudo-labels).
    gen_trg_entropy_loss_src_guide = get_cross_entropy_loss(labels=tf.nn.softmax(src_logits_fake), 
                                                            logits=trg_logits)   
    # classifier_t on source images translated to the target domain, using the
    # original source labels.
    gen_trg_entropy_loss_src_data  = get_cross_entropy_loss(labels=src_y, 
                                                            logits=trg_logits_fake)
    
    #Discriminator##########################################
    disc_src      = discriminator_s(src_x, training=True)
    disc_src_fake = discriminator_s(src_x_fake, training=True)
    disc_src_loss = discriminator_loss(disc_src, disc_src_fake)

    disc_trg      = discriminator_t(trg_x, training=True)
    disc_trg_fake = discriminator_t(trg_x_fake, training=True)
    disc_trg_loss = discriminator_loss(disc_trg, disc_trg_fake)

    #Generator##########################################
    gen_src_loss = generator_loss(disc_src_fake)
    gen_trg_loss = generator_loss(disc_trg_fake)
    
    src_identity_loss = identity_loss(src_x, src_x_same)
    trg_identity_loss = identity_loss(trg_x, trg_x_same)
    
    #Loss##########################################
    # NOTE(review): gen_src_entropy_loss is computed from classifier_s(src_x)
    # and does not depend on generator_s, so it adds nothing to
    # generator_s's gradient here.
    total_gen_s_loss = gen_src_loss + src_identity_loss + gen_src_entropy_loss
    total_gen_t_loss = gen_trg_loss + \
                       trg_identity_loss + \
                       gen_trg_entropy_loss_src_guide + \
                       gen_trg_entropy_loss_src_data 
    
    total_clas_s_loss = gen_src_entropy_loss
    total_clas_t_loss = gen_trg_entropy_loss_src_guide + gen_trg_entropy_loss_src_data 
    
  # Calculate the gradients for generator and discriminator
  # (each loss is differentiated only w.r.t. its own model's variables).
  classifier_s_gradients = tape.gradient(total_clas_s_loss, 
                                         classifier_s.trainable_variables)
  classifier_t_gradients = tape.gradient(total_clas_t_loss, 
                                         classifier_t.trainable_variables)
  
  generator_s_gradients = tape.gradient(total_gen_s_loss, 
                                        generator_s.trainable_variables)
  generator_t_gradients = tape.gradient(total_gen_t_loss, 
                                        generator_t.trainable_variables)
  
  discriminator_s_gradients = tape.gradient(disc_src_loss, 
                                            discriminator_s.trainable_variables)
  discriminator_t_gradients = tape.gradient(disc_trg_loss, 
                                            discriminator_t.trainable_variables)
    
  # Apply the gradients to the optimizer
  classifier_s_optimizer.apply_gradients(zip(classifier_s_gradients, 
                                             classifier_s.trainable_variables))

  classifier_t_optimizer.apply_gradients(zip(classifier_t_gradients, 
                                             classifier_t.trainable_variables))
  
  generator_s_optimizer.apply_gradients(zip(generator_s_gradients, 
                                            generator_s.trainable_variables))

  generator_t_optimizer.apply_gradients(zip(generator_t_gradients, 
                                            generator_t.trainable_variables))
  
  discriminator_s_optimizer.apply_gradients(zip(discriminator_s_gradients,
                                                discriminator_s.trainable_variables))
  
  discriminator_t_optimizer.apply_gradients(zip(discriminator_t_gradients,
                                                discriminator_t.trainable_variables))
  
  # Accumulate running means / accuracies for logging.
  tb_gen_src_entropy_loss(gen_src_entropy_loss)
  tb_gen_trg_entropy_loss_src_guide(gen_trg_entropy_loss_src_guide)
  tb_gen_trg_entropy_loss_src_data(gen_trg_entropy_loss_src_data)
  tb_disc_src_loss(disc_src_loss)
  tb_disc_trg_loss(disc_trg_loss)
  tb_gen_src_loss(gen_src_loss)
  tb_gen_trg_loss(gen_trg_loss)
  tb_src_identity_loss(src_identity_loss)
  tb_trg_identity_loss(trg_identity_loss)
  tb_total_gen_s_loss(total_gen_s_loss)      
  tb_total_gen_t_loss(total_gen_t_loss)    
  tb_total_clas_s_loss(total_clas_s_loss)    
  tb_total_clas_t_loss(total_clas_t_loss)    
  source_train_acc(src_y, tf.nn.softmax(src_logits))
  server_train_acc(trg_y, tf.nn.softmax(trg_logits))
  
@tf.function
def test_src(images):
  """Source-side inference.

  Runs ``generator_s`` with ``training=False``, unpacks its two outputs, and
  returns the softmax of the second one (the class logits).
  """
  _unused, class_logits = generator_s(images, training=False)
  probabilities = tf.nn.softmax(class_logits)
  return probabilities

@tf.function
def test_trg(images):
  """Target-side inference.

  Runs ``generator_t`` with ``training=False``, unpacks its two outputs, and
  returns the softmax of the second one (the class logits).
  """
  _unused, class_logits = generator_t(images, training=False)
  probabilities = tf.nn.softmax(class_logits)
  return probabilities

In [None]:
class lr_schedule():
  """Annealed learning rate: ``init_lr / (1 + alpha * p) ** beta``.

  ``p`` is the training progress. It is set externally through ``set_p``; the
  epoch loop in this notebook passes ``epoch / epochs``. An instance is handed
  directly to the Keras optimizers as a zero-argument callable learning rate.

  Why ``p`` lives in a ``tf.Variable``: the optimizer evaluates this callable
  while the ``@tf.function`` training step is being traced. A Python float
  result would be frozen into the traced graph as a constant, so later
  ``set_p`` calls would silently have no effect. Building the rate from a
  variable read makes the graph pick up the current progress on every step.

  Args:
    init_lr: learning rate at ``p == 0``.
    alpha: progress scaling factor inside the denominator.
    beta: exponent applied to the denominator.
  """
  def __init__(self, init_lr=0.01, alpha=10, beta=0.75):
    self.init_lr = init_lr
    self.alpha = alpha
    self.beta = beta
    # Non-trainable, so it is never touched by any optimizer.
    self.p = tf.Variable(0.0, trainable=False, dtype=tf.float32, name='lr_progress')

  def set_p(self, p):
    """Update the training progress (the loop here passes values in [0, 1))."""
    self.p.assign(p)

  def __call__(self):
    """Return the current learning rate as a float32 tensor."""
    return self.init_lr / tf.pow(1.0 + self.alpha * self.p, self.beta)

# Networks. The construction order is unchanged on purpose (auto-generated
# Keras layer names may depend on it).
generator_t = unet_generator(1, norm_type='instancenorm')
generator_s = unet_generator(1, norm_type='instancenorm')
discriminator_s = discriminator(norm_type='instancenorm', target=False)
discriminator_t = discriminator(norm_type='instancenorm', target=False)
classifier_s = ResNet50(num_classes, num_features, "selu")
classifier_t = ResNet50(num_classes, num_features, "selu")

# A single lr_schedule instance is shared by every optimizer, so all of them
# follow the same schedule. Each network still gets its own Adam instance
# (beta_1=0.5), created in the order listed on the left-hand side.
learning_rate = lr_schedule(init_lr=init_lr)
(classifier_s_optimizer,
 classifier_t_optimizer,
 generator_s_optimizer,
 generator_t_optimizer,
 discriminator_s_optimizer,
 discriminator_t_optimizer) = [tf.keras.optimizers.Adam(learning_rate, beta_1=0.5)
                               for _ in range(6)]

# TensorBoard event writer used by the epoch loop.
summary_writer = tf.summary.create_file_writer(log_dir)

In [None]:
# Per-epoch TensorBoard scalars, in logging order. Keys are the exact tag
# names. Every metric is written once per epoch and then reset, so each point
# reflects only that epoch.
epoch_metrics = {
    "tb_gen_src_entropy_loss": tb_gen_src_entropy_loss,
    "tb_gen_trg_entropy_loss_src_guide": tb_gen_trg_entropy_loss_src_guide,
    "tb_gen_trg_entropy_loss_src_data": tb_gen_trg_entropy_loss_src_data,
    "tb_disc_src_loss": tb_disc_src_loss,
    "tb_disc_trg_loss": tb_disc_trg_loss,
    "tb_gen_src_loss": tb_gen_src_loss,
    "tb_gen_trg_loss": tb_gen_trg_loss,
    "tb_src_identity_loss": tb_src_identity_loss,
    "tb_trg_identity_loss": tb_trg_identity_loss,
    "tb_total_gen_s_loss": tb_total_gen_s_loss,
    "tb_total_gen_t_loss": tb_total_gen_t_loss,
    "tb_total_clas_s_loss": tb_total_clas_s_loss,
    "tb_total_clas_t_loss": tb_total_clas_t_loss,
    "source_train_acc": source_train_acc,
    "server_train_acc": server_train_acc,
    "temporal_test_acc": temporal_test_acc,
    "source_test_acc": source_test_acc,
    "office_test_acc": office_test_acc,
    "server_test_acc": server_test_acc,
    "conference_test_acc": conference_test_acc,
}

# Evaluation datasets, each paired with the accuracy metric it updates.
# All of them are scored with test_trg.
eval_sets = (
    (time_test_set, temporal_test_acc),
    (src_test_set, source_test_acc),
    (office_test_set, office_test_acc),
    (server_test_set, server_test_acc),
    (conf_test_set, conference_test_acc),
)

for epoch in range(epochs):
  print(epoch)
  # Training progress in [0, 1); feeds the learning-rate decay.
  learning_rate.set_p(epoch/epochs)

  # zip() stops as soon as either training set is exhausted.
  for source_data, server_data in zip(src_train_set, server_train_set):
    train_step(source_data[0], source_data[1], server_data[0], server_data[1])

  for test_set, test_acc in eval_sets:
    for data in test_set:
      # Keras metrics take (y_true, y_pred), the same order train_step uses.
      test_acc(data[1], test_trg(data[0]))

  with summary_writer.as_default():
    for tag, metric in epoch_metrics.items():
      tf.summary.scalar(tag, metric.result(), step=epoch)
  # Make this epoch's scalars visible to TensorBoard right away instead of
  # waiting for the writer's periodic flush.
  summary_writer.flush()

  for metric in epoch_metrics.values():
    metric.reset_states()