In [1]:
# - initialization
# - batchnorm parameters
# - weight decay and regularization

# **TensorFlow**

In [2]:
dataset = "cifar100"
iterations = "2,2,2,2"
u_channels = "256,256,256,256"
f_channels = "256,256,256,256"
batch_size = 128
epochs = 150
epoch_step = 30
lr = .1
lr_step = 10
momentum = .9
wd = .0005
wd = 0

iterations = [int(x) for x in iterations.split(",")]
u_channels = [int(x) for x in u_channels.split(",")]
f_channels = [int(x) for x in f_channels.split(",")]

In [3]:
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
import tensorflow as tf
import numpy as np
import logging
tf.get_logger().setLevel(logging.ERROR)
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm.notebook import tqdm

gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
  tf.config.experimental.set_memory_growth(gpu, True)

In [4]:
(ds_train, ds_test), ds_info = tfds.load(
    dataset,
    split = ["train", "test"],
    as_supervised = True,
    with_info = True)

rescale = tf.keras.layers.Rescaling(1. / 255)
if dataset == "mnist":
  mean, variance = [.1307], np.square([.3081])
if dataset == "cifar10":
  mean, variance = [.4914, .4822, .4465], np.square([.2023, .1994, .2010])
if dataset == "cifar100":
  mean, variance = [.5071, .4865, .4409], np.square([.2673, .2564, .2762])
normalize = tf.keras.layers.Normalization(mean = mean,
                                          variance = variance)

def preprocess(ds, training):
  if training:
    layers = tf.keras.Sequential([
      rescale,
      tf.keras.layers.RandomTranslation(height_factor = .125,
                                        width_factor = .125,
                                        fill_mode = "constant"),
      tf.keras.layers.RandomFlip(mode = "horizontal"),
      normalize
    ])
    ds = ds.shuffle(ds_info.splits["train"].num_examples)
  else:
    layers = tf.keras.Sequential([rescale, normalize])

  ds = ds.batch(batch_size)
  ds = ds.map(lambda x, y: (layers(x), y),
              num_parallel_calls = tf.data.AUTOTUNE)
  ds = ds.cache()
  ds = ds.prefetch(tf.data.AUTOTUNE)

  return ds

ds_train = preprocess(ds_train, training = True)
ds_test = preprocess(ds_test, training = False)

In [5]:
class HeUniform(tf.keras.initializers.Initializer):
  
  def __init__(self, a, mode, nonlinearity, bound = None):
    self.a = a
    self.mode = mode
    self.nonlinearity = nonlinearity
    self.bound = bound
    
    if self.nonlinearity == "sigmoid":
      self.gain = 1
    elif self.nonlinearity == "tanh":
      self.gain = 5.0 / 3
    elif self.nonlinearity == "relu":
      self.gain = np.sqrt(2.0)
    elif self.nonlinearity == "leaky_relu":
      if self.a is None:
        self.gain = .01
      else:
        self.gain = np.sqrt(2.0 / (1 + self.a ** 2))
    elif self.nonlinearity == "selu":
      self.gain = 3.0 / 4
    
  def __call__(self, shape, dtype = None, **kwargs):
    if self.bound:
      return tf.random.uniform(shape,
                               minval = -self.bound,
                               maxval = self.bound)
    
    torch_shape = np.flip(shape)
    
    num_input_fmaps = torch_shape[1]
    num_output_fmaps = torch_shape[0]
    receptive_field_size = 1
    if len(torch_shape) > 2:
      for s in torch_shape[2:]:
        receptive_field_size *= s
    fan_in = num_input_fmaps * receptive_field_size    
    fan_out = num_output_fmaps * receptive_field_size
    
    if self.mode == "fan_in":
      fan = fan_in
    elif self.mode == "fan_out":
      fan = fan_out
      
    std = self.gain / np.sqrt(fan)
    bound = np.sqrt(3.0) * std
    
    return tf.random.uniform(shape,
                             minval = -bound,
                             maxval = bound)
  
class Conv2D(tf.keras.layers.Layer):
  
  def __init__(self,
               filters,
               kernel_size,
               strides = (1, 1),
               padding = "valid",
               data_format = None,
               dilation_rate = (1, 1),
               groups = 1,
               activation = None,
               use_bias = True,
               kernel_initializer = None,
               bias_initializer = "zeros",
               kernel_regularizer = None,
               bias_regularizer = None,
               activity_regularizer = None,
               kernel_constraint = None,
               bias_constraint = None,
               **kwargs
              ):
    super(Conv2D, self).__init__()
    
    if kernel_initializer is None:
      kernel_initializer = HeUniform(np.sqrt(5),
                                     "fan_in",
                                     "leaky_relu")
      
    self.torch_padding = None
    if isinstance(padding, list) or isinstance(padding, tuple):
      self.torch_padding = padding
      padding = "valid"
      
    self.conv2d = tf.keras.layers.Conv2D(filters = filters,
                                         kernel_size = kernel_size,
                                         strides = strides,
                                         padding = padding,
                                         data_format = data_format,
                                         dilation_rate = dilation_rate,
                                         groups = groups,
                                         activation = activation,
                                         use_bias = use_bias,
                                         kernel_initializer = kernel_initializer,
                                         bias_initializer = bias_initializer,
                                         kernel_regularizer = kernel_regularizer,
                                         bias_regularizer = bias_regularizer,
                                         activity_regularizer = activity_regularizer,
                                         kernel_constraint = kernel_constraint,
                                         bias_constraint = bias_constraint,
                                         **kwargs
                                        )
    
  def call(self, inputs):
    if self.torch_padding:
      inputs = tf.pad(inputs,
                      [[0, 0],
                      [self.torch_padding[0], self.torch_padding[0]],
                      [self.torch_padding[1], self.torch_padding[1]],
                      [0, 0]],
                      "CONSTANT")
    out = self.conv2d(inputs)
    return out

In [6]:
class MgSmooth(tf.keras.layers.Layer):

  def __init__(self,
               iterations,
               u_channels,
               f_channels,
               wd):
    super(MgSmooth, self).__init__()

    self.iterations = iterations
    self.A = Conv2D(u_channels,
                    (3, 3),
                    strides = (1, 1),
                    padding = (1, 1),
                    use_bias = False,
                    kernel_regularizer = 
                      tf.keras.regularizers.L2(wd))
    self.B = Conv2D(u_channels,
                    (3, 3),
                    strides = (1, 1),
                    padding = (1, 1),
                    use_bias = False,
                    kernel_regularizer = 
                      tf.keras.regularizers.L2(wd))

    self.A_bns, self.B_bns = [], []
    for _ in range(self.iterations):
      self.A_bns.append(tf.keras.layers.BatchNormalization(momentum = .9,
                                                           epsilon = 1e-5))
      self.B_bns.append(tf.keras.layers.BatchNormalization(momentum = .9,
                                                           epsilon = 1e-5))

  def call(self, u, f):
    for i in range(self.iterations):
      error = tf.nn.relu(self.A_bns[i](f - self.A(u)))
      u = u + tf.nn.relu(self.B_bns[i](self.B(error)))
    return u, f

class MgBlock(tf.keras.layers.Layer):

  def __init__(self,
               iterations,
               u_channels,
               f_channels,
               A_old,
               wd):
    super(MgBlock, self).__init__()

    self.iterations = iterations
    self.Pi = Conv2D(u_channels,
                     (3, 3),
                     strides = (2, 2),
                     padding = (1, 1),
                     use_bias = False,
                     kernel_regularizer = 
                       tf.keras.regularizers.L2(wd))
    self.R = Conv2D(u_channels,
                    (3, 3),
                    strides = (2, 2),
                    padding = (1, 1),
                    use_bias = False,
                    kernel_regularizer = 
                      tf.keras.regularizers.L2(wd))
    self.A_old = A_old
    self.MgSmooth = MgSmooth(self.iterations,
                             u_channels,
                             f_channels,
                             wd)

    self.Pi_bn = tf.keras.layers.BatchNormalization(momentum = .9,
                                                    epsilon = 1e-5)
    self.R_bn = tf.keras.layers.BatchNormalization(momentum = .9,
                                                   epsilon = 1e-5)

  def call(self, u0, f0):
    u1 = tf.nn.relu(self.Pi_bn(self.Pi(u0)))
    error = tf.nn.relu(self.R_bn(self.R(f0 - self.A_old(u0))))
    f1 = error + self.MgSmooth.A(u1)
    u, f = self.MgSmooth(u1, f1)
    return u, f

class MgNet(tf.keras.Model):

  def __init__(self,
               iterations,
               u_channels,
               f_channels,
               in_shape,
               out_shape,
               wd):
    super(MgNet, self).__init__()

    self._name = "mgnet_tensorflow"
    self.iterations = iterations
    self.in_shape = in_shape
    self.A_init = Conv2D(u_channels[0],
                         (3, 3),
                         strides = (1, 1),
                         padding = (1, 1),
                         use_bias = False,
                         kernel_regularizer = 
                           tf.keras.regularizers.L2(wd))
    self.A_bn = tf.keras.layers.BatchNormalization(momentum = .9,
                                                   epsilon = 1e-5)

    self.blocks = []
    for i in range(len(self.iterations)):
      if i == 0:
        self.blocks.append(MgSmooth(iterations[i],
                                    u_channels[i],
                                    f_channels[i],
                                    wd))
        continue
      if i == 1:
        self.blocks.append(MgBlock(iterations[i],
                                   u_channels[i],
                                   f_channels[i],
                                   self.blocks[0].A,
                                   wd))
        continue
      self.blocks.append(MgBlock(iterations[i],
                                 u_channels[i],
                                 f_channels[i],
                                 self.blocks[i - 1].MgSmooth.A,
                                 wd))

    x = in_shape[0]
    for i in range(len(self.blocks) - 1):
      x = ((x + 2 - 3) // 2) + 1
    self.pool = tf.keras.layers.AveragePooling2D(pool_size = (x, x))
    self.fc = tf.keras.layers.Dense(out_shape,
                                    kernel_initializer = 
                                      HeUniform(np.sqrt(5),
                                                "fan_in",
                                                "leaky_relu"),
                                    bias_initializer = 
                                      HeUniform(np.sqrt(5),
                                                "fan_in",
                                                "leaky_relu",
                                                1 / np.sqrt(u_channels[-1])),
                                    kernel_regularizer = 
                                      tf.keras.regularizers.L2(wd))
    
    self.A_init._name = "initial_A_conv"
    self.A_bn._name = "initial_A_bn"
    for i, block in enumerate(self.blocks):
      block._name = f"block{i}"
      if i == 0:
        block.A._name = "block0_A_conv"
        block.B._name = "block0_B_conv"
        for j, bn in enumerate(block.A_bns):
          bn._name = f"block0_A_batchnorm{j}"
        for j, bn in enumerate(block.B_bns):
          bn._name = f"block0_B_batchnorm{j}"
      else:
        block.MgSmooth._name = f"block{i}_MgSmooth"
        block.MgSmooth.A._name = f"block{i}_A_conv"
        block.MgSmooth.B._name = f"block{i}_B_conv"
        for j, bn in enumerate(block.MgSmooth.A_bns):
          bn._name = f"block{i}_A_batchnorm{j}"
        for j, bn in enumerate(block.MgSmooth.B_bns):
          bn._name = f"block{i}_B_batchnorm{j}"
        block.Pi._name = f"block{i}_Pi_conv"
        block.R._name = f"block{i}_R_conv"
        block.Pi_bn._name = f"block{i}_Pi_batchnorm"
        block.R_bn._name = f"block{i}_R_batchnorm"
    self.pool._name = "final_average_pool"
    self.fc._name = "output_softmax"
  
  def call(self, u0):
    f = tf.nn.relu(self.A_bn(self.A_init(u0)))
    u = tf.multiply(f, 0)

    for block in self.blocks:
      u, f = block(u, f)
    u = self.pool(u)
    u = tf.squeeze(u, [-2, -3])
    u = self.fc(u)
    return u

In [7]:
# no batchnorm

class MgSmooth(tf.keras.layers.Layer):

  def __init__(self,
               iterations,
               u_channels,
               f_channels,
               wd):
    super(MgSmooth, self).__init__()

    self.iterations = iterations
    self.A = Conv2D(u_channels,
                    (3, 3),
                    strides = (1, 1),
                    padding = (1, 1),
                    use_bias = False,
                    kernel_regularizer = 
                      tf.keras.regularizers.L2(wd))
    self.B = Conv2D(f_channels,
                    (3, 3),
                    strides = (1, 1),
                    padding = (1, 1),
                    use_bias = False,
                    kernel_regularizer = 
                      tf.keras.regularizers.L2(wd))

  def call(self, u, f):
    for i in range(self.iterations):
      error = tf.nn.relu((f - self.A(u)))
      u = u + tf.nn.relu((self.B(error)))
    return u, f

class MgBlock(tf.keras.layers.Layer):

  def __init__(self,
               iterations,
               u_channels,
               f_channels,
               A_old,
               wd):
    super(MgBlock, self).__init__()

    self.iterations = iterations
    self.Pi = Conv2D(u_channels,
                     (3, 3),
                     strides = (2, 2),
                     padding = (1, 1),
                     use_bias = False,
                     kernel_regularizer = 
                       tf.keras.regularizers.L2(wd))
    self.R = Conv2D(f_channels,
                    (3, 3),
                    strides = (2, 2),
                    padding = (1, 1),
                    use_bias = False,
                    kernel_regularizer = 
                      tf.keras.regularizers.L2(wd))
    self.A_old = A_old
    self.MgSmooth = MgSmooth(self.iterations,
                             u_channels,
                             f_channels,
                             wd)

  def call(self, u0, f0):
    u1 = tf.nn.relu((self.Pi(u0)))
    error = tf.nn.relu((self.R(f0 - self.A_old(u0))))
    f1 = error + self.MgSmooth.A(u1)
    u, f = self.MgSmooth(u1, f1)
    return u, f

class MgNet(tf.keras.Model):

  def __init__(self,
               iterations,
               u_channels,
               f_channels,
               in_shape,
               out_shape,
               wd):
    super(MgNet, self).__init__()

    self._name = "mgnet_tensorflow"
    self.iterations = iterations
    self.in_shape = in_shape
    self.A_init = Conv2D(u_channels[0],
                         (3, 3),
                         strides = (1, 1),
                         padding = (1, 1),
                         use_bias = False,
                         kernel_regularizer = 
                           tf.keras.regularizers.L2(wd))

    self.blocks = []
    for i in range(len(self.iterations)):
      if i == 0:
        self.blocks.append(MgSmooth(iterations[i],
                                    u_channels[i],
                                    f_channels[i],
                                    wd))
        continue
      if i == 1:
        self.blocks.append(MgBlock(iterations[i],
                                   u_channels[i],
                                   f_channels[i],
                                   self.blocks[0].A,
                                   wd))
        continue
      self.blocks.append(MgBlock(iterations[i],
                                 u_channels[i],
                                 f_channels[i],
                                 self.blocks[i - 1].MgSmooth.A,
                                 wd))

    x = in_shape[0]
    for i in range(len(self.blocks) - 1):
      x = ((x + 2 - 3) // 2) + 1
    self.pool = tf.keras.layers.AveragePooling2D(pool_size = (x, x))
    self.fc = tf.keras.layers.Dense(out_shape,
                                    kernel_initializer = 
                                      tf.keras.initializers.Constant(0.),
                                    bias_initializer = 
                                      tf.keras.initializers.Constant(0.),
                                    kernel_regularizer = 
                                      tf.keras.regularizers.L2(wd))
    
    self.A_init._name = "initial_A_conv"
    for i, block in enumerate(self.blocks):
      block._name = f"block{i}"
      if i == 0:
        block.A._name = "block0_A_conv"
        block.B._name = "block0_B_conv"
      else:
        block.MgSmooth._name = f"block{i}_MgSmooth"
        block.MgSmooth.A._name = f"block{i}_A_conv"
        block.MgSmooth.B._name = f"block{i}_B_conv"
        block.Pi._name = f"block{i}_Pi_conv"
        block.R._name = f"block{i}_R_conv"
    self.pool._name = "final_average_pool"
    self.fc._name = "output_softmax"
  
  def call(self, u0):
    f = tf.nn.relu((self.A_init(u0)))
    u = tf.multiply(f, 0)

    for block in self.blocks:
      u, f = block(u, f)
      # f_temp = tf.transpose(f, [0, 3, 1, 2])
      # print(f_temp[0][0][0])
    u = self.pool(u)
    u = tf.squeeze(u, [-2, -3])
    u = self.fc(u)
    return u

In [8]:
def log_weights(writer, model, epoch, grads):
  with writer.as_default():
    with tf.summary.record_if(True):
      for layer in model.layers:
        for weight in layer.weights:
          weight_name = weight.name.replace(":", "_")
          histogram_weight_name = f"{model.name}{weight_name}"
          tf.summary.histogram(histogram_weight_name,
                               weight,
                               step = epoch)
      if grads:
        weight_names = [x.name.replace(":", "_") for x in model.trainable_weights]
        for i, grad in enumerate(grads):
          tf.summary.histogram(weight_names[i] + "_gradient",
                               grad,
                               step = epoch)
      writer.flush()

In [None]:
def lr_schedule(epoch, lr):
  if (epoch + 1) % epoch_step == 0:
    return lr / lr_step
  return lr

tf.debugging.set_log_device_placement(True)
gpus = tf.config.list_logical_devices("GPU")
strategy = tf.distribute.MirroredStrategy(gpus)
with strategy.scope():
  model = MgNet(iterations = iterations,
                u_channels = u_channels,
                f_channels = f_channels,
                in_shape = ds_info.features["image"].shape,
                out_shape = ds_info.features["label"].num_classes,
                wd = wd)
  
  loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)
  
  log_dir = "logs/tensorflow/" + datetime.now().strftime("%Y%m%d-%H%M%S")
  tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir = log_dir,
                                                        histogram_freq = 1)

  lr_s = tf.keras.callbacks.LearningRateScheduler(lr_schedule)
  # optimizer = tf.keras.optimizers.SGD(learning_rate = lr,
  #                                     momentum = momentum)
  optimizer = tf.keras.optimizers.SGD(learning_rate = lr)
  
  model.compile(optimizer = optimizer,
                loss = loss,
                metrics = ["accuracy"])

  history = model.fit(ds_train,
                      epochs = epochs,
                      validation_data = ds_test,
                      callbacks = [lr_s,
                                   tensorboard_callback])

model.summary()

In [None]:
# custom loop
model = MgNet(iterations = iterations,
              u_channels = u_channels,
              f_channels = f_channels,
              in_shape = ds_info.features["image"].shape,
              out_shape = ds_info.features["label"].num_classes,
              wd = wd)

loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)
optimizer = tf.keras.optimizers.SGD(learning_rate = lr)

train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc = tf.keras.metrics.SparseCategoricalAccuracy()

log_dir = "logs/tensorflow/" + datetime.now().strftime("%Y%m%d-%H%M%S")
train_writer = tf.summary.create_file_writer(log_dir + "/train")
val_writer = tf.summary.create_file_writer(log_dir + "/validation")

for epoch in range(epochs):
  iterate = tqdm(enumerate(ds_train), total = 
                 -(ds_info.splits["train"].num_examples // -batch_size))
  for batch, (images, labels) in iterate:
    with tf.GradientTape() as tape:
      logits = model(images, training = True)
      loss_val = loss(labels, logits)
    grads = tape.gradient(loss_val, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    
    train_loss.update_state(loss_val)
    train_acc.update_state(labels, logits)
    iterate.set_description(f"loss: {train_loss.result():.2f} \
                            - accuracy: {train_acc.result():.4f}")
    
  for images, labels in ds_test:
    logits = model(images, training = False)
    val_loss.update_state(labels, logits)
    val_acc.update_state(labels, logits)

  log_weights(train_writer, model, epoch, grads)
  with train_writer.as_default():
    tf.summary.scalar("epoch_loss", train_loss.result(), epoch)
    tf.summary.scalar("epoch_accuracy", train_acc.result(), epoch)
  with val_writer.as_default():
    tf.summary.scalar("epoch_loss", val_loss.result(), epoch)
    tf.summary.scalar("epoch_accuracy", val_acc.result(), epoch)
  
  print(f"epoch: {epoch + 1} - validation loss: {val_loss.result():.4f} - validation accuracy: {val_acc.result():.4f}")
  
  train_loss.reset_states()
  train_acc.reset_states()
  val_loss.reset_states()
  val_acc.reset_states()

In [8]:
# Log weight initializations
model = MgNet(iterations = iterations,
              u_channels = u_channels,
              f_channels = f_channels,
              in_shape = ds_info.features["image"].shape,
              out_shape = ds_info.features["label"].num_classes,
              wd = wd)
model.build((batch_size,) + ds_info.features["image"].shape)
writer = tf.summary.create_file_writer("logs/tensorflow/init")
log_weights(writer, model, 0)

# **PyTorch**

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
import torchvision
import numpy as np
import logging
import os

In [10]:
def load_data(path, minibatch_size, dataset):
  if dataset == "cifar100":
    normalize = torchvision.transforms.Normalize(mean=(0.5071, 0.4865, 0.4409),
                                                 std=(0.2673, 0.2564, 0.2762))
    transform_train = torchvision.transforms.Compose(
      [torchvision.transforms.RandomCrop(32, padding = 4),
       torchvision.transforms.RandomHorizontalFlip(),
       torchvision.transforms.ToTensor(),
       normalize])
    transform_test  = torchvision.transforms.Compose(
      [torchvision.transforms.ToTensor(),
       normalize])

    trainset = torchvision.datasets.CIFAR100(root = path,
                                             train = True,
                                             download = True,
                                             transform = transform_train)
    trainloader = torch.utils.data.DataLoader(trainset,
                                              batch_size = minibatch_size,
                                              num_workers = 4,
                                              shuffle = True)

    testset = torchvision.datasets.CIFAR100(root = path,
                                            train = False,
                                            download = True,
                                            transform = transform_test)
    testloader = torch.utils.data.DataLoader(testset,
                                             batch_size = minibatch_size,
                                             num_workers = 4,
                                             shuffle = False)
    num_classes = 100
  
  return trainloader, testloader, num_classes

trainloader, testloader, num_classes = load_data("~/pytorch_datasets",
                                                 128,
                                                 "cifar100")

Files already downloaded and verified
Files already downloaded and verified


In [12]:
class MgIte(nn.Module):
  
  def __init__(self,
               A,
               B):
    super().__init__()
    self.A = A
    self.B = B        
    self.bn1 = nn.BatchNorm2d(A.weight.size(0))
    self.bn2 = nn.BatchNorm2d(B.weight.size(0))

  def forward(self, out):
    u, f = out
    u = u + F.relu(self.bn2(self.B(F.relu(self.bn1((f - self.A(u)))))))
    out = (u, f)
    return out
    
class MgRestriction(nn.Module):
  
  def __init__(self,
               A_old,
               A_conv,
               Pi_conv,
               R_conv):
    super().__init__()
    self.A_old = A_old
    self.A_conv = A_conv
    self.Pi_conv = Pi_conv
    self.R_conv = R_conv

    self.bn1 = nn.BatchNorm2d(Pi_conv.weight.size(0))
    self.bn2 = nn.BatchNorm2d(A_old.weight.size(0))

  def forward(self, out):
    u_old, f_old = out
    u = F.relu(self.bn1(self.Pi_conv(u_old)))
    f = F.relu(self.bn2(self.R_conv(f_old - self.A_old(u_old)))) + self.A_conv(u)
    out = (u, f)
    return out

class MgNet(nn.Module):
  
  def __init__(self,
               dataset,
               num_iterations,
               num_channel_f,
               num_channel_u,
               wise_B,
               num_classes):
    super().__init__()
    self.num_iterations = num_iterations
    self.num_channel_f = num_channel_f
    self.num_channel_u = num_channel_u
    self.wise_B = wise_B
    
    if dataset == "mnist":
      self.num_channel_input = 1
    else:
      self.num_channel_input = 3
      
    self.conv1 = nn.Conv2d(self.num_channel_input,
                           self.num_channel_f,
                           kernel_size = 3,
                           stride = 1,
                           padding = 1,
                           bias = False)
    self.bn1 = nn.BatchNorm2d(self.num_channel_f)        

    A_conv = nn.Conv2d(self.num_channel_u,
                       self.num_channel_f,
                       kernel_size = 3,
                       stride = 1,
                       padding = 1,
                       bias = False)
    if not self.wise_B:
      B_conv = nn.Conv2d(self.num_channel_f,
                         self.num_channel_u,
                         kernel_size = 3,
                         stride = 1,
                         padding = 1,
                         bias = False)
    layers = []
    for l, num_iteration_l in enumerate(self.num_iterations):
      for i in range(num_iteration_l):
        if self.wise_B:
          B_conv = nn.Conv2d(self.num_channel_f,
                             self.num_channel_u,
                             kernel_size = 3,
                             stride = 1,
                             padding = 1,
                             bias = False)
        layers.append(MgIte(A_conv,
                            B_conv))
      setattr(self,
              "layer" + str(l),
              nn.Sequential(*layers))

      if l < len(self.num_iterations) - 1:
        A_old = A_conv
        A_conv = nn.Conv2d(self.num_channel_u,
                           self.num_channel_f,
                           kernel_size = 3,
                           stride = 1,
                           padding = 1,
                           bias = False)
        if not self.wise_B:
          B_conv = nn.Conv2d(self.num_channel_f,
                             self.num_channel_u,
                             kernel_size = 3,
                             stride = 1,
                             padding = 1,
                             bias = False)
        Pi_conv = nn.Conv2d(self.num_channel_u,
                            self.num_channel_u,
                            kernel_size = 3,
                            stride = 2,
                            padding = 1,
                            bias = False)
        R_conv = nn.Conv2d(self.num_channel_f,
                           self.num_channel_u,
                           kernel_size = 3,
                           stride = 2,
                           padding = 1,
                           bias = False)
        layers= [MgRestriction(A_old,
                               A_conv,
                               Pi_conv,
                               R_conv)]

    self.pooling = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Linear(self.num_channel_u,
                        num_classes)

  def forward(self, u):
    f = F.relu(self.bn1(self.conv1(u)))
    if torch.cuda.is_available():
      u = torch.zeros(f.size(),
                      device = torch.device("cuda"))
    else:
      u = torch.zeros(f.size())
    out = (u, f)

    for l in range(len(self.num_iterations)):
      out = getattr(self,
                    "layer" + str(l))(out)
    u, f = out
    u = self.pooling(u)
    u = u.view(u.shape[0], -1)
    u = self.fc(u)
    return u

In [18]:
# no batchnorm

class MgIte(nn.Module):
  
  def __init__(self,
               A,
               B):
    super().__init__()
    self.A = A
    self.B = B        

  def forward(self, out):
    u, f = out
    u = u + F.relu((self.B(F.relu(((f - self.A(u)))))))
    out = (u, f)
    return out
    
class MgRestriction(nn.Module):
  
  def __init__(self,
               A_old,
               A_conv,
               Pi_conv,
               R_conv):
    super().__init__()
    self.A_old = A_old
    self.A_conv = A_conv
    self.Pi_conv = Pi_conv
    self.R_conv = R_conv

  def forward(self, out):
    u_old, f_old = out
    u = F.relu((self.Pi_conv(u_old)))
    f = F.relu((self.R_conv(f_old - self.A_old(u_old)))) + self.A_conv(u)
    out = (u, f)
    return out

class MgNet(nn.Module):
  
  def __init__(self,
               dataset,
               num_iterations,
               num_channel_f,
               num_channel_u,
               wise_B,
               num_classes):
    super().__init__()
    self.num_iterations = num_iterations
    self.num_channel_f = num_channel_f
    self.num_channel_u = num_channel_u
    self.wise_B = wise_B
    
    if dataset == "mnist":
      self.num_channel_input = 1
    else:
      self.num_channel_input = 3
      
    self.conv1 = nn.Conv2d(self.num_channel_input,
                           self.num_channel_f,
                           kernel_size = 3,
                           stride = 1,
                           padding = 1,
                           bias = False)

    A_conv = nn.Conv2d(self.num_channel_u,
                       self.num_channel_f,
                       kernel_size = 3,
                       stride = 1,
                       padding = 1,
                       bias = False)
    if not self.wise_B:
      B_conv = nn.Conv2d(self.num_channel_f,
                         self.num_channel_u,
                         kernel_size = 3,
                         stride = 1,
                         padding = 1,
                         bias = False)
    layers = []
    for l, num_iteration_l in enumerate(self.num_iterations):
      for i in range(num_iteration_l):
        if self.wise_B:
          B_conv = nn.Conv2d(self.num_channel_f,
                             self.num_channel_u,
                             kernel_size = 3,
                             stride = 1,
                             padding = 1,
                             bias = False)
        layers.append(MgIte(A_conv,
                            B_conv))
      setattr(self,
              "layer" + str(l),
              nn.Sequential(*layers))

      if l < len(self.num_iterations) - 1:
        A_old = A_conv
        A_conv = nn.Conv2d(self.num_channel_u,
                           self.num_channel_f,
                           kernel_size = 3,
                           stride = 1,
                           padding = 1,
                           bias = False)
        if not self.wise_B:
          B_conv = nn.Conv2d(self.num_channel_f,
                             self.num_channel_u,
                             kernel_size = 3,
                             stride = 1,
                             padding = 1,
                             bias = False)
        Pi_conv = nn.Conv2d(self.num_channel_u,
                            self.num_channel_u,
                            kernel_size = 3,
                            stride = 2,
                            padding = 1,
                            bias = False)
        R_conv = nn.Conv2d(self.num_channel_f,
                           self.num_channel_u,
                           kernel_size = 3,
                           stride = 2,
                           padding = 1,
                           bias = False)
        layers = [MgRestriction(A_old,
                                A_conv,
                                Pi_conv,
                                R_conv)]

    self.pooling = nn.AdaptiveAvgPool2d(1)
    self.fc = nn.Linear(self.num_channel_u,
                        num_classes)

  def forward(self, u):
    f = F.relu((self.conv1(u)))
    if torch.cuda.is_available():
      u = torch.zeros(f.size(),
                      device = torch.device("cuda"))
    else:
      u = torch.zeros(f.size())
    out = (u, f)

    for l in range(len(self.num_iterations)):
      out = getattr(self,
                    "layer" + str(l))(out)
      u, f = out
      # print(f[0][0][0])
    u, f = out
    u = self.pooling(u)
    u = u.view(u.shape[0], -1)
    u = self.fc(u)
    return u

In [11]:
def log_weights_pytorch(writer, model, epoch):
  writer.add_histogram("mgnet_pytorch/initial_A_conv/kernel_0", model.conv1.weight, epoch)
  writer.add_histogram("mgnet_pytorch/initial_A_bn/gamma_0", model.bn1.weight, epoch)
  writer.add_histogram("mgnet_pytorch/initial_A_bn/beta_0", model.bn1.bias, epoch)
  writer.add_histogram("mgnet_pytorch/initial_A_bn/moving_mean_0", model.bn1.running_mean, epoch)
  writer.add_histogram("mgnet_pytorch/initial_A_bn/moving_variance_0", model.bn1.running_var, epoch)
  for l in range(len(model.num_iterations)):
    seq = getattr(model, "layer" + str(l))
    if l == 0:
      writer.add_histogram("mgnet_pytorch/block0/block0_A_conv/kernel_0", seq[0].A.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_conv/kernel_0", seq[0].B.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm0/gamma_0", seq[0].bn1.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm0/beta_0", seq[0].bn1.bias, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm0/moving_mean_0", seq[0].bn1.running_mean, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm0/moving_variance_0", seq[0].bn1.running_var, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm0/gamma_0", seq[0].bn2.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm0/beta_0", seq[0].bn2.bias, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm0/moving_mean_0", seq[0].bn2.running_mean, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm0/moving_variance_0", seq[0].bn2.running_var, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm1/gamma_0", seq[1].bn1.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm1/beta_0", seq[1].bn1.bias, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm1/moving_mean_0", seq[1].bn1.running_mean, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_batchnorm1/moving_variance_0", seq[1].bn1.running_var, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm1/gamma_0", seq[1].bn2.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm1/beta_0", seq[1].bn2.bias, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm1/moving_mean_0", seq[1].bn2.running_mean, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_batchnorm1/moving_variance_0", seq[1].bn2.running_var, epoch)
    else:
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_conv/kernel_0", seq[0].Pi_conv.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_conv/kernel_0", seq[0].R_conv.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_batchnorm/gamma_0", seq[0].bn1.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_batchnorm/beta_0", seq[0].bn1.bias, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_batchnorm/moving_mean_0", seq[0].bn1.running_mean, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_batchnorm/moving_variance_0", seq[0].bn1.running_var, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_batchnorm/gamma_0", seq[0].bn2.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_batchnorm/beta_0", seq[0].bn2.bias, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_batchnorm/moving_mean_0", seq[0].bn2.running_mean, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_batchnorm/moving_variance_0", seq[0].bn2.running_var, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_conv/kernel_0", seq[1].A.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_conv/kernel_0", seq[1].B.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm0/gamma_0", seq[1].bn1.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm0/beta_0", seq[1].bn1.bias, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm0/moving_mean_0", seq[1].bn1.running_mean, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm0/moving_variance_0", seq[1].bn1.running_var, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm0/gamma_0", seq[1].bn2.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm0/beta_0", seq[1].bn2.bias, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm0/moving_mean_0", seq[1].bn2.running_mean, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm0/moving_variance_0", seq[1].bn2.running_var, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm1/gamma_0", seq[2].bn1.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm1/beta_0", seq[2].bn1.bias, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm1/moving_mean_0", seq[2].bn1.running_mean, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_batchnorm1/moving_variance_0", seq[2].bn1.running_var, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm1/gamma_0", seq[2].bn2.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm1/beta_0", seq[2].bn2.bias, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm1/moving_mean_0", seq[2].bn2.running_mean, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_batchnorm1/moving_variance_0", seq[2].bn2.running_var, epoch)
  writer.add_histogram("mgnet_pytorch/output_softmax/kernel_0", model.fc.weight, epoch)
  writer.add_histogram("mgnet_pytorch/output_softmax/bias_0", model.fc.bias, epoch)

In [12]:
def log_weights_pytorch(writer, model, epoch):
  # model = model.module
  writer.add_histogram("mgnet_pytorch/initial_A_conv/kernel_0", model.conv1.weight, epoch)
  writer.add_histogram("mgnet_pytorch/initial_A_conv/kernel_0_grad", model.conv1.weight.grad, epoch)
  for l in range(len(model.num_iterations)):
    seq = getattr(model, "layer" + str(l))
    if l == 0:
      writer.add_histogram("mgnet_pytorch/block0/block0_A_conv/kernel_0", seq[0].A.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_conv/kernel_0", seq[0].B.weight, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_A_conv/kernel_0_grad", seq[0].A.weight.grad, epoch)
      writer.add_histogram("mgnet_pytorch/block0/block0_B_conv/kernel_0_grad", seq[0].B.weight.grad, epoch)
    else:
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_conv/kernel_0", seq[0].Pi_conv.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_conv/kernel_0", seq[0].R_conv.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_conv/kernel_0", seq[1].A.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_conv/kernel_0", seq[1].B.weight, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_Pi_conv/kernel_0_grad", seq[0].Pi_conv.weight.grad, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_R_conv/kernel_0_grad", seq[0].R_conv.weight.grad, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_A_conv/kernel_0_grad", seq[1].A.weight.grad, epoch)
      writer.add_histogram(f"mgnet_pytorch/block{l}/block{l}_B_conv/kernel_0_grad", seq[1].B.weight.grad, epoch)
  writer.add_histogram("mgnet_pytorch/output_softmax/kernel_0", model.fc.weight, epoch)
  writer.add_histogram("mgnet_pytorch/output_softmax/bias_0", model.fc.bias, epoch)
  writer.add_histogram("mgnet_pytorch/output_softmax/kernel_0_grad", model.fc.weight.grad, epoch)
  writer.add_histogram("mgnet_pytorch/output_softmax/bias_0_grad", model.fc.bias.grad, epoch)

In [135]:
def adjust_learning_rate(optimizer, epoch, init_lr):
  if epoch == 0:
    return init_lr
  lr = init_lr * 0.1 ** (epoch // 30)
  for param_group in optimizer.param_groups:
    param_group["lr"] = lr
  return lr

def train_process(model, num_epochs, lr, trainloader, testloader):
  criterion = nn.CrossEntropyLoss()
  # optimizer = optim.SGD(model.parameters(),
  #                       lr = lr,
  #                       momentum = 0.9,
  #                       weight_decay = 0.0005)
  optimizer = optim.SGD(model.parameters(),
                        lr = lr)

  log_dir = "logs/pytorch/" + datetime.now().strftime("%Y%m%d-%H%M%S")
  train_writer = SummaryWriter(log_dir = log_dir + "/train")
  val_writer = SummaryWriter(log_dir = log_dir + "/validation")
  
  for epoch in range(num_epochs):
    current_lr = adjust_learning_rate(optimizer, epoch, lr)
    total_batches = -(50000 // -batch_size)
    iterate = tqdm(enumerate(trainloader), total = 
                   total_batches)
    model.train()
    total_train_loss = 0
    for i, (images, labels) in iterate:
      if use_cuda:
        images = images.cuda()
        labels = labels.cuda()

      outputs = model(images) 
      loss = criterion(outputs, labels)
      total_train_loss += loss

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    total_val_batches = -(10000 // -batch_size)
    def calculate_acc(loader, training, total_batches):
      if training:
        model.train()
      if not training:
        model.eval()
      total_loss = 0
      correct, total = 0, 0
      for i, (images, labels) in enumerate(loader):
        with torch.no_grad():
          if use_cuda:
            images = images.cuda()
            labels = labels.cuda()
          outputs = model(images)
          loss = criterion(outputs, labels)
          p_max, predicted = torch.max(outputs, 1) 
          total += labels.size(0)
          correct += (predicted == labels).sum()
          total_loss += loss
      return float(correct) / total, total_loss / total_batches

    train_acc, train_loss = calculate_acc(trainloader, True, total_batches)
    val_acc, val_loss = calculate_acc(testloader, False, total_val_batches)
    
    train_writer.add_scalar("epoch_loss", train_loss, epoch)
    train_writer.add_scalar("epoch_accuracy", train_acc, epoch)
    val_writer.add_scalar("epoch_loss", val_loss, epoch)
    val_writer.add_scalar("epoch_accuracy", val_acc, epoch)
    
    log_weights_pytorch(train_writer, model, epoch)
    
    print(f"training loss: {train_loss} - validation loss: {val_loss}")
    print(f"epoch: {epoch + 1} - training accuracy: {train_acc} - validation accuracy: {val_acc}")

In [None]:
model = MgNet(dataset = dataset,
              num_iterations = iterations,
              num_channel_f = 256,
              num_channel_u = 256,
              wise_B = False,
              num_classes = 100)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.DataParallel(model)
model.to(device)
train_process(model = model,
              num_epochs = 150,
              lr = .1,
              trainloader = trainloader,
              testloader = testloader)

In [9]:
# Log weight initializations
model = MgNet(dataset = dataset,
              num_iterations = iterations,
              num_channel_f = 256,
              num_channel_u = 256,
              wise_B = False,
              num_classes = 100)
model = model.cuda()
writer = SummaryWriter(log_dir = "logs/pytorch/init")
init_iter = iter(trainloader)
images, labels = init_iter.next()
writer.add_graph(model, images.cuda())
log_weights_pytorch(writer, model, 0)

# **TEST**

In [13]:
dataset = []
for batch, (images, labels) in enumerate(trainloader):
  dataset.append((images, labels))
test_dataset = []
for batch, (images, labels) in enumerate(testloader):
  test_dataset.append((images, labels))

In [14]:
def calc_bound(shape):  
  gain = np.sqrt(2.0 / (1 + np.sqrt(5) ** 2))
  torch_shape = np.flip(shape) 
  num_input_fmaps = torch_shape[1]
  receptive_field_size = 1
  if len(torch_shape) > 2:
    for s in torch_shape[2:]:
      receptive_field_size *= s
  fan_in = num_input_fmaps * receptive_field_size

  std = gain / np.sqrt(fan_in)
  return np.sqrt(3.0) * std

A_bound = calc_bound((3, 3, 3, 256))
A_init = np.random.uniform(low = -A_bound,
                           high = A_bound,
                           size = (3, 3, 3, 256))
conv_bound = calc_bound((3, 3, 256, 256))
conv_init = []
for i in range(14):
  conv_init.append(np.random.uniform(low = -conv_bound,
                                     high = conv_bound,
                                     size = (3, 3, 256, 256)))
fc_bound = calc_bound((256, 100))
fc_init = np.random.uniform(low = -fc_bound,
                            high = fc_bound,
                            size = (256, 100))
bias_bound = 1 / np.sqrt(256)
bias_init = np.random.uniform(low = -bias_bound,
                              high = bias_bound,
                              size = (100))

In [15]:
# tensorflow test

model_tf = MgNet(iterations = iterations,
              u_channels = u_channels,
              f_channels = f_channels,
              in_shape = ds_info.features["image"].shape,
              out_shape = ds_info.features["label"].num_classes,
              wd = wd)
model_tf.build((128, 32, 32, 3))

In [16]:
model_tf.A_init.weights[0] = tf.Variable(initial_value = A_init)
model_tf.blocks[0].A.weights[0] = tf.Variable(initial_value = conv_init[0])
model_tf.blocks[0].B.weights[0] = tf.Variable(initial_value = conv_init[1])
model_tf.blocks[1].Pi.weights[0] = tf.Variable(initial_value = conv_init[2])
model_tf.blocks[1].R.weights[0] = tf.Variable(initial_value = conv_init[3])
model_tf.blocks[1].MgSmooth.A.weights[0] = tf.Variable(initial_value = conv_init[4])
model_tf.blocks[1].MgSmooth.B.weights[0] = tf.Variable(initial_value = conv_init[5])
model_tf.blocks[2].Pi.weights[0] = tf.Variable(initial_value = conv_init[6])
model_tf.blocks[2].R.weights[0] = tf.Variable(initial_value = conv_init[7])
model_tf.blocks[2].MgSmooth.A.weights[0] = tf.Variable(initial_value = conv_init[8])
model_tf.blocks[2].MgSmooth.B.weights[0] = tf.Variable(initial_value = conv_init[9])
model_tf.blocks[3].Pi.weights[0] = tf.Variable(initial_value = conv_init[10])
model_tf.blocks[3].R.weights[0] = tf.Variable(initial_value = conv_init[11])
model_tf.blocks[3].MgSmooth.A.weights[0] = tf.Variable(initial_value = conv_init[12])
model_tf.blocks[3].MgSmooth.B.weights[0] = tf.Variable(initial_value = conv_init[13])
model_tf.fc.weights[0] = tf.Variable(initial_value = fc_init)
model_tf.fc.weights[1] = tf.Variable(initial_value = bias_init)

In [24]:
model_tf(tf.transpose(tf.convert_to_tensor(dataset[0][0].numpy()), [0, 2, 3, 1]))

<tf.Tensor: shape=(128, 100), dtype=float32, numpy=
array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)>

In [None]:
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits = True)
optimizer = tf.keras.optimizers.SGD(learning_rate = lr)

train_loss = tf.keras.metrics.Mean()
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss = tf.keras.metrics.SparseCategoricalCrossentropy()
val_acc = tf.keras.metrics.SparseCategoricalAccuracy()

log_dir = "logs/tensorflow/" + datetime.now().strftime("%Y%m%d-%H%M%S")
train_writer = tf.summary.create_file_writer(log_dir + "/train")
val_writer = tf.summary.create_file_writer(log_dir + "/validation")

for epoch in range(epochs):
  iterate = tqdm(enumerate(dataset), total = 
                 -(ds_info.splits["train"].num_examples // -batch_size))
  for batch, (images, labels) in iterate:
    images = tf.transpose(tf.convert_to_tensor(images.numpy()), [0, 2, 3, 1])
    labels = tf.convert_to_tensor(labels.numpy())
    with tf.GradientTape() as tape:
      logits = model_tf(images, training = True)
      loss_val = loss(labels, logits)
    grads = tape.gradient(loss_val, model_tf.trainable_weights)
    optimizer.apply_gradients(zip(grads, model_tf.trainable_weights))
    
    train_loss.update_state(loss_val)
    train_acc.update_state(labels, logits)
    iterate.set_description(f"loss: {train_loss.result():.2f} \
                            - accuracy: {train_acc.result():.4f}")
    
  for images, labels in test_dataset:
    images = tf.transpose(tf.convert_to_tensor(images.numpy()), [0, 2, 3, 1])
    labels = tf.convert_to_tensor(labels.numpy())
    logits = model_tf(images, training = False)
    val_loss.update_state(labels, logits)
    val_acc.update_state(labels, logits)

  log_weights(train_writer, model_tf, epoch, grads)
  with train_writer.as_default():
    tf.summary.scalar("epoch_loss", train_loss.result(), epoch)
    tf.summary.scalar("epoch_accuracy", train_acc.result(), epoch)
  with val_writer.as_default():
    tf.summary.scalar("epoch_loss", val_loss.result(), epoch)
    tf.summary.scalar("epoch_accuracy", val_acc.result(), epoch)
  
  print(f"epoch: {epoch + 1} - validation loss: {val_loss.result():.4f} - validation accuracy: {val_acc.result():.4f}")
  
  train_loss.reset_states()
  train_acc.reset_states()
  val_loss.reset_states()
  val_acc.reset_states()

In [19]:
# pytorch test

model_pt = MgNet(dataset = dataset,
              num_iterations = iterations,
              num_channel_f = 256,
              num_channel_u = 256,
              wise_B = False,
              num_classes = 100)

In [20]:
model_pt.conv1.weight = nn.Parameter(torch.from_numpy(np.transpose(A_init, [3, 2, 0, 1])).float())
getattr(model_pt, "layer0")[0].A.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[0], [3, 2, 0, 1])).float())
getattr(model_pt, "layer0")[0].B.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[1], [3, 2, 0, 1])).float())
getattr(model_pt, "layer1")[0].Pi_conv.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[2], [3, 2, 0, 1])).float())
getattr(model_pt, "layer1")[0].R_conv.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[3], [3, 2, 0, 1])).float())
getattr(model_pt, "layer1")[1].A.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[4], [3, 2, 0, 1])).float())
getattr(model_pt, "layer1")[1].B.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[5], [3, 2, 0, 1])).float())
getattr(model_pt, "layer2")[0].Pi_conv.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[6], [3, 2, 0, 1])).float())
getattr(model_pt, "layer2")[0].R_conv.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[7], [3, 2, 0, 1])).float())
getattr(model_pt, "layer2")[1].A.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[8], [3, 2, 0, 1])).float())
getattr(model_pt, "layer2")[1].B.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[9], [3, 2, 0, 1])).float())
getattr(model_pt, "layer3")[0].Pi_conv.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[10], [3, 2, 0, 1])).float())
getattr(model_pt, "layer3")[0].R_conv.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[11], [3, 2, 0, 1])).float())
getattr(model_pt, "layer3")[1].A.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[12], [3, 2, 0, 1])).float())
getattr(model_pt, "layer3")[1].B.weight = nn.Parameter(torch.from_numpy(np.transpose(conv_init[13], [3, 2, 0, 1])).float())
model_pt.fc.weight = nn.Parameter(torch.from_numpy(np.transpose(fc_init, [1, 0])).float())
model_pt.fc.bias = nn.Parameter(torch.from_numpy(bias_init).float())

In [21]:
test_input = dataset[0][0].cuda()
model_pt = model_pt.cuda()
model_pt(test_input)

tensor([[ 8.3921e-03, -3.7010e-02,  1.9178e-02,  ...,  1.1835e-02,
         -8.6755e-03, -6.8535e-02],
        [ 8.7179e-03, -3.3621e-02,  2.6003e-02,  ...,  1.2554e-02,
         -7.5911e-03, -6.9567e-02],
        [ 2.8036e-02, -2.4973e-02,  3.3461e-02,  ...,  1.4026e-02,
         -3.5621e-03, -7.0676e-02],
        ...,
        [ 4.5102e-03, -4.1469e-02,  1.7230e-02,  ...,  1.0431e-02,
         -3.6491e-03, -6.3158e-02],
        [ 2.3616e-03, -4.7221e-02,  9.6849e-03,  ...,  1.0257e-02,
         -5.8721e-03, -6.5567e-02],
        [ 1.8216e-02, -3.5550e-02,  2.4981e-02,  ...,  1.0037e-02,
         -2.5053e-06, -6.2280e-02]], device='cuda:0', grad_fn=<AddmmBackward0>)

In [162]:
def adjust_learning_rate(optimizer, epoch, init_lr):
  if epoch == 0:
    return init_lr
  lr = init_lr * 0.1 ** (epoch // 30)
  for param_group in optimizer.param_groups:
    param_group["lr"] = lr
  return lr

def train_process(model, num_epochs, lr, trainloader, testloader):
  criterion = nn.CrossEntropyLoss()
  # optimizer = optim.SGD(model.parameters(),
  #                       lr = lr,
  #                       momentum = 0.9,
  #                       weight_decay = 0.0005)
  optimizer = optim.SGD(model.parameters(),
                        lr = lr)

  log_dir = "logs/pytorch/" + datetime.now().strftime("%Y%m%d-%H%M%S")
  train_writer = SummaryWriter(log_dir = log_dir + "/train")
  val_writer = SummaryWriter(log_dir = log_dir + "/validation")
  
  for epoch in range(num_epochs):
    current_lr = adjust_learning_rate(optimizer, epoch, lr)
    total_batches = -(50000 // -batch_size)
    iterate = tqdm(enumerate(trainloader), total = 
                   total_batches)
    model.train()
    total_train_loss = 0
    for i, (images, labels) in iterate:
      if use_cuda:
        images = images.cuda()
        labels = labels.cuda()

      outputs = model(images) 
      loss = criterion(outputs, labels)
      total_train_loss += loss

      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

    total_val_batches = -(10000 // -batch_size)
    def calculate_acc(loader, training, total_batches):
      if training:
        model.train()
      if not training:
        model.eval()
      total_loss = 0
      correct, total = 0, 0
      for i, (images, labels) in enumerate(loader):
        with torch.no_grad():
          if use_cuda:
            images = images.cuda()
            labels = labels.cuda()
          outputs = model(images)
          loss = criterion(outputs, labels)
          p_max, predicted = torch.max(outputs, 1) 
          total += labels.size(0)
          correct += (predicted == labels).sum()
          total_loss += loss
      return float(correct) / total, total_loss / total_batches

    train_acc, train_loss = calculate_acc(trainloader, True, total_batches)
    val_acc, val_loss = calculate_acc(testloader, False, total_val_batches)
    
    train_writer.add_scalar("epoch_loss", train_loss, epoch)
    train_writer.add_scalar("epoch_accuracy", train_acc, epoch)
    val_writer.add_scalar("epoch_loss", val_loss, epoch)
    val_writer.add_scalar("epoch_accuracy", val_acc, epoch)
    
    log_weights_pytorch(train_writer, model, epoch)
    
    print(f"training loss: {train_loss} - validation loss: {val_loss}")
    print(f"epoch: {epoch + 1} - training accuracy: {train_acc} - validation accuracy: {val_acc}")

In [None]:
model_pt = model_pt.cuda()
train_process(model = model_pt,
              num_epochs = 150,
              lr = .1,
              trainloader = dataset,
              testloader = test_dataset)

# **MgNet SCAN**

In [None]:
class AttentionModule(tf.keras.layers.Layer):

  def __init__(self,
                channels,
                wd):
    super(AttentionModule, self).__init__()

    self.conv = Conv2D(channels,
                        (3, 3),
                        strides = (2, 2),
                        padding = (1, 1),
                        use_bias = False,
                        kernel_regularizer = 
                          tf.keras.regularizers.L2(wd))
    self.bn1 = tf.keras.layers.BatchNormalization(momentum = .9,
                                                  epsilon = 1e-5)
    self.deconv = tf.keras.layers.Conv2DTranspose(channels,
                                                  (3, 3),
                                                  strides = (2, 2),
                                                  padding = (1, 1),
                                                  use_bias = False,
                                                  kernel_regularizer = 
                                                    tf.keras.regularizers.L2(wd))
    self.bn2 = tf.keras.layers.BatchNormalization(momentum = .9,
                                                  epsilon = 1e-5)

    self.attn_conv1 = tf.keras.layers.Conv2D()
    self.attn_conv2 = tf.keras.layers.Conv2D()

  def call(self, u):
    u_d = self.conv(u)
    u_d = tf.nn.relu(self.bn1(u_d))
    u_d = self.deconv(u_d)
    u_d = tf.math.sigmoid(self.bn2(u_d))

    u_a = self.attn_conv1(u)
    u_a = self.attn_conv2(u_a)

    u = tf.matmul(u_d, u_a)
    return u

class ShallowClassifier(tf.keras.layers.Layer):

  def __init__(self,
               in_shape,
               out_shape):
    super(ShallowClassifier, self).__init__()

    self.bottleneck_channels = 512
    self.bottleneck = tf.keras.layers.Dense(self.bottleneck_channels,
                                            kernel_initializer = 
                                              HeUniform(np.sqrt(5),
                                                        "fan_in",
                                                        "leaky_relu"),
                                            bias_initializer = 
                                              HeUniform(np.sqrt(5),
                                                        "fan_in",
                                                        "leaky_relu",
                                                        1 / np.sqrt(in_shape[-1])),
                                            kernel_regularizer = 
                                              tf.keras.regularizers.L2(wd))
    self.fc = tf.keras.layers.Dense(out_shape,
                                    kernel_initializer = 
                                      HeUniform(np.sqrt(5),
                                                "fan_in",
                                                "leaky_relu"),
                                    bias_initializer = 
                                      HeUniform(np.sqrt(5),
                                                "fan_in",
                                                "leaky_relu",
                                                1 / np.sqrt(self.bottleneck_channels)),
                                    kernel_regularizer = 
                                      tf.keras.regularizers.L2(wd))

  def call(self, u):
    u = self.bottleneck(u)
    u = self.fc(u)
    return u