In [37]:
# ---- Experiment configuration ----------------------------------------------
# Name of the TensorFlow Datasets dataset to load.
dataset = "cifar100"

# Per-level settings, written as comma-separated strings (one entry per
# level) and parsed into lists of ints at the bottom of this cell.
iterations = "2,2,2,2"
u_channels = "256,256,256,256"
f_channels = "256,256,256,256"

# Input pipeline / training length.
batch_size = 1024
epochs = 150

# Learning-rate schedule: `lr_schedule` divides the rate by `lr_step`
# whenever (epoch + 1) is a multiple of `epoch_step`.
epoch_step = 30
lr = 0.1
lr_step = 10

# Optimizer hyper-parameters (momentum and weight decay).
momentum = 0.9
wd = 0.0005

# Whether the last cell plots and saves the training curves.
graph = True

# "2,2,2,2" -> [2, 2, 2, 2]
iterations = list(map(int, iterations.split(",")))
u_channels = list(map(int, u_channels.split(",")))
f_channels = list(map(int, f_channels.split(",")))

In [38]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
import logging
tf.get_logger().setLevel(logging.ERROR)
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from datetime import datetime

In [39]:
# Load the train/test splits as (image, label) pairs, plus the dataset
# metadata that is used later for the input shape and the class count.
(ds_train, ds_test), ds_info = tfds.load(
    dataset,
    split = ["train", "test"],
    shuffle_files = True,
    as_supervised = True,
    with_info = True)

# Per-channel (mean, standard deviation) for each supported dataset. These
# are the commonly cited statistics for pixel values scaled to [0, 1].
channel_stats = {
  "mnist": ([.1307], [.3081]),
  "cifar10": ([.4914, .4822, .4465], [.2023, .1994, .2010]),
  "cifar100": ([.5071, .4865, .4409], [.2673, .2564, .2762]),
}
if dataset not in channel_stats:
  # Fail with a clear message instead of a NameError further down.
  raise ValueError(
    f"No normalization statistics for dataset {dataset!r}; "
    f"expected one of {sorted(channel_stats)}")
mean, std = channel_stats[dataset]
# `Normalization` computes (x - mean) / sqrt(variance), so the standard
# deviations have to be squared before being handed to the layer.
variance = [s ** 2 for s in std]

# tfds decodes images to uint8 values in [0, 255]; bring them to [0, 1] first
# so that they are on the scale the statistics above were computed on.
normalize = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1. / 255),
  tf.keras.layers.Normalization(mean = mean,
                                variance = variance)
])

# Training-time augmentation (random shift, rotation and horizontal flip)
# followed by normalization. Out-of-image pixels are filled with zeros.
train_layers = tf.keras.Sequential([
  tf.keras.layers.RandomTranslation(height_factor = .125,
                                    width_factor = .125,
                                    fill_mode = "constant"),
  tf.keras.layers.RandomRotation(factor = .08,
                                 fill_mode = "constant"),
  tf.keras.layers.RandomFlip(mode = "horizontal"),
  normalize
])
# Evaluation uses normalization only.
test_layers = tf.keras.Sequential([normalize])

def preprocess(ds, layers, shuffle_buffer = None):
  """Turn a dataset of (image, label) pairs into batched model input.

  Args:
    ds: a `tf.data.Dataset` yielding (image, label) pairs.
    layers: callable (e.g. a Keras `Sequential`) applied to every image.
    shuffle_buffer: optional buffer size; when given, examples are
      reshuffled every epoch with a buffer of this size.

  Returns:
    The cached, (optionally shuffled), mapped, batched and prefetched dataset.
  """
  # Cache the raw examples *before* `layers` runs. Caching after the map would
  # store the first epoch's random augmentations and replay exactly the same
  # images in every later epoch.
  ds = ds.cache()
  if shuffle_buffer:
    ds = ds.shuffle(shuffle_buffer)
  ds = ds.map(lambda x, y: (layers(x), y),
              num_parallel_calls = tf.data.AUTOTUNE)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.AUTOTUNE)

  return ds

# Shuffle the training set over its full length; keep the test set in order.
ds_train = preprocess(ds_train,
                      train_layers,
                      shuffle_buffer = ds_info.splits["train"].num_examples)
ds_test = preprocess(ds_test, test_layers)

In [40]:
class MgSmooth(tf.keras.layers.Layer):
  """Applies `iterations` correction steps to `u` at a fixed resolution.

  Each step computes
      error = relu(BN(f - A(u)))
      u     = u + relu(BN(B(error)))
  where `A` and `B` are 3x3, stride-1, bias-free convolutions shared by all
  steps, while every step has its own pair of batch-normalization layers.

  Args:
    iterations: number of correction steps.
    u_channels: channel count of `u`.
    f_channels: channel count of `f`.
  """

  def __init__(self, iterations, u_channels, f_channels):
    super(MgSmooth, self).__init__()

    self.iterations = iterations
    # `A(u)` is subtracted from `f`, so A must produce `f_channels` channels.
    self.A = tf.keras.layers.Conv2D(f_channels,
                                    (3, 3),
                                    strides = (1, 1),
                                    padding = "same",
                                    use_bias = False)
    # `B(error)` is added to `u`, so B must produce `u_channels` channels.
    self.B = tf.keras.layers.Conv2D(u_channels,
                                    (3, 3),
                                    strides = (1, 1),
                                    padding = "same",
                                    use_bias = False)

    # One batch-norm pair per step (created in A, B order for each step).
    self.A_bns, self.B_bns = [], []
    for _ in range(self.iterations):
      self.A_bns.append(tf.keras.layers.BatchNormalization())
      self.B_bns.append(tf.keras.layers.BatchNormalization())

  def call(self, u, f):
    """Returns the updated `u` together with the unchanged `f`."""
    for i in range(self.iterations):
      error = tf.nn.relu(self.A_bns[i](f - self.A(u)))
      u = u + tf.nn.relu(self.B_bns[i](self.B(error)))
    return u, f

class MgBlock(tf.keras.layers.Layer):
  """Moves (u, f) to half the spatial resolution and smooths there.

  Args:
    iterations: number of correction steps of the contained `MgSmooth`.
    u_channels: channel count of `u` on the new level.
    f_channels: channel count of `f` on the new level.
    A_old: the `A` convolution of the previous level; it is applied to the
      incoming `u0` to form the residual `f0 - A_old(u0)`.
  """

  def __init__(self, iterations, u_channels, f_channels, A_old):
    super().__init__()
    self.iterations = iterations

    # Settings shared by the two 3x3, stride-2, bias-free convolutions.
    downsample = dict(kernel_size = (3, 3),
                      strides = (2, 2),
                      padding = "same",
                      use_bias = False)
    self.Pi = tf.keras.layers.Conv2D(u_channels, **downsample)
    self.R = tf.keras.layers.Conv2D(f_channels, **downsample)

    self.A_old = A_old
    self.MgSmooth = MgSmooth(iterations, u_channels, f_channels)

    self.Pi_bn = tf.keras.layers.BatchNormalization()
    self.R_bn = tf.keras.layers.BatchNormalization()

  def call(self, u0, f0):
    """Returns the (u, f) pair produced by smoothing on the new level."""
    # Downsampled starting guess for u.
    coarse_u = tf.nn.relu(self.Pi_bn(self.Pi(u0)))

    # Residual of the previous level, downsampled by R.
    fine_residual = f0 - self.A_old(u0)
    coarse_residual = tf.nn.relu(self.R_bn(self.R(fine_residual)))

    # Right-hand side for the new level, built with this level's own A.
    coarse_f = coarse_residual + self.MgSmooth.A(coarse_u)

    return self.MgSmooth(coarse_u, coarse_f)

class MgNet(tf.keras.Model):
  """Image classifier built from one `MgSmooth` level and `MgBlock` levels.

  Level 0 is an `MgSmooth`; every further level is an `MgBlock` that halves
  the resolution and receives the previous level's `A` convolution. The final
  `u` is averaged over its spatial dimensions and fed to a softmax layer.

  Args:
    iterations: list with the number of correction steps per level.
    u_channels: list with the channel count of `u` per level.
    f_channels: list with the channel count of `f` per level.
    in_shape: shape of one input image (stored on the model, not otherwise
      used).
    out_shape: number of output classes.

  Raises:
    ValueError: if the three per-level lists are empty or differ in length.
  """

  def __init__(self, iterations, u_channels, f_channels, in_shape, out_shape):
    super(MgNet, self).__init__()

    n_levels = len(iterations)
    if n_levels == 0 or not (n_levels == len(u_channels) == len(f_channels)):
      raise ValueError(
        "iterations, u_channels and f_channels must be non-empty and of "
        f"equal length, got lengths {len(iterations)}, {len(u_channels)} "
        f"and {len(f_channels)}")

    self.iterations = iterations
    self.in_shape = in_shape
    # Channel count of the all-zero initial guess for u (see `call`).
    self.u0_channels = u_channels[0]
    # A_init produces the initial `f`, so it is sized with f_channels[0].
    self.A_init = tf.keras.layers.Conv2D(f_channels[0],
                                         (3, 3),
                                         strides = (1, 1),
                                         padding = "same",
                                         use_bias = False)
    self.A_bn = tf.keras.layers.BatchNormalization()

    self.blocks = []
    for i in range(n_levels):
      if i == 0:
        self.blocks.append(MgSmooth(iterations[i],
                                    u_channels[i],
                                    f_channels[i]))
        continue
      # Level 0 is a bare MgSmooth; later levels wrap theirs in `.MgSmooth`.
      previous = self.blocks[i - 1]
      A_old = previous.A if i == 1 else previous.MgSmooth.A
      self.blocks.append(MgBlock(iterations[i],
                                 u_channels[i],
                                 f_channels[i],
                                 A_old))

    # Average over whatever spatial extent is left. This equals a single
    # full-size average-pooling window followed by a squeeze, but does not
    # need the feature-map size to be precomputed and also handles non-square
    # inputs.
    self.pool = tf.keras.layers.GlobalAveragePooling2D()
    self.softmax = tf.keras.layers.Dense(out_shape,
                                         activation = "softmax")

  def call(self, u0):
    """Maps a batch of images `u0` to class probabilities."""
    f = tf.nn.relu(self.A_bn(self.A_init(u0)))
    # Zero initial guess with f's batch/spatial shape and u_channels[0]
    # channels (channels-last layout, the Conv2D default).
    u = tf.tile(tf.zeros_like(f[..., :1]), [1, 1, 1, self.u0_channels])

    for block in self.blocks:
      u, f = block(u, f)
    u = self.pool(u)
    u = self.softmax(u)
    return u

In [41]:
def lr_schedule(epoch, lr):
  """Step decay for `tf.keras.callbacks.LearningRateScheduler`.

  Args:
    epoch: epoch index supplied by the callback (Keras counts from 0).
    lr: learning rate currently in use.

  Returns:
    `lr / lr_step` when `epoch + 1` is a multiple of `epoch_step`, otherwise
    `lr` unchanged. With a 0-based `epoch`, the first reduction therefore
    applies to epoch index `epoch_step - 1`.
  """
  at_boundary = (epoch + 1) % epoch_step == 0
  return lr / lr_step if at_boundary else lr

# Log the device every op is placed on, then mirror the model across all
# logical GPUs.
tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_logical_devices("GPU")
strategy = tf.distribute.MirroredStrategy(devices = gpus)

# The callback creates no variables, so it can be built outside the scope.
lr_s = tf.keras.callbacks.LearningRateScheduler(schedule = lr_schedule)

with strategy.scope():
  # Input shape and class count come from the dataset metadata.
  model = MgNet(iterations = iterations,
                u_channels = u_channels,
                f_channels = f_channels,
                in_shape = ds_info.features["image"].shape,
                out_shape = ds_info.features["label"].num_classes)

  # The model ends in a softmax, so the loss consumes probabilities.
  loss = tf.keras.losses.SparseCategoricalCrossentropy()

  # NOTE(review): `weight_decay` stays fixed while `lr_schedule` lowers the
  # learning rate; the tensorflow_addons docs advise decaying both together.
  # Confirm that a constant weight decay is intended.
  optimizer = tfa.optimizers.SGDW(weight_decay = wd,
                                  learning_rate = lr,
                                  momentum = momentum)

  model.compile(loss = loss,
                optimizer = optimizer,
                metrics = ["accuracy"])

  # `history` is read by the plotting cell below.
  history = model.fit(ds_train,
                      validation_data = ds_test,
                      epochs = epochs,
                      callbacks = [lr_s])

model.summary()

2022-10-21 09:53:51.809632: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 1/150

2022-10-21 09:54:55.569813: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:547] The `assert_cardinality` transformation is currently not handled by the auto-shard rewrite and will be removed.


Epoch 2/150
Epoch 3/150
 49/391 [==>...........................] - ETA: 15s - loss: 3.4746 - accuracy: 0.1649

KeyboardInterrupt: 

In [None]:
if graph:
  # Per-epoch curves recorded by `model.fit`.
  loss, accuracy = history.history["loss"], history.history["accuracy"]
  val_loss = history.history["val_loss"]
  val_accuracy = history.history["val_accuracy"]
  timerange = range(len(loss))

  # Left axis: losses.
  fig, ax = plt.subplots()
  train_loss_plot, = ax.plot(timerange, loss,
                             color = "blue", label = "Train Loss")
  val_loss_plot, = ax.plot(timerange, val_loss,
                           color = "cyan", label = "Validation Loss")
  ax.set_xlabel("Epoch")
  ax.set_ylabel("Loss")
  ax.legend(loc = "upper left")

  # Right axis (shares the x axis): accuracies.
  ax2 = ax.twinx()
  train_acc_plot, = ax2.plot(timerange, accuracy,
                             color = "purple", label = "Train Accuracy")
  val_acc_plot, = ax2.plot(timerange, val_accuracy,
                           color = "pink", label = "Validation Accuracy")
  ax2.set_ylabel("Accuracy")
  ax2.legend(loc = "upper right")

  plt.title("Loss vs Accuracy")
  # File name carries the dataset and a timestamp so that runs do not
  # overwrite each other.
  plt.savefig(f"{dataset}_mgnet_{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}.png")