In [1]:
# Imports: stdlib, third-party, then the project packages (models/, experiments/,
# data_processing/), which are made importable by adding the parent directory
# to sys.path first.
import importlib
import sys

sys.path.append('../')

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

import data_processing.data_parser as dp
import data_processing.dataset_synapse as ds
import experiments.config as conf
import models.decoder_layers as decoder_layers
# Was `import models.decoder_layers as encoder_layers`, which made
# `encoder_layers` a second name for the decoder module, so reloading it
# never refreshed the encoder layers.
import models.encoder_layers as encoder_layers
import models.resnet_v2 as res
import models.transunet as vit
import models.utils as ut

# Short aliases used throughout the notebook.
tfk = tf.keras
tfkl = tfk.layers

In [20]:
# Toy float64 tensors for checking channel-wise concatenation:
# 3 channels of zeros followed by 1 channel of ones -> 4 channels.
x = tf.zeros((1, 224, 224, 3), dtype=tf.float64)
y = tf.ones((1, 224, 224, 1), dtype=tf.float64)
z = tf.concat([x, y], axis=-1)

In [29]:
z[..., :3].shape  # first three channels of z (the zeros part)

TensorShape([1, 224, 224, 3])

In [7]:
class StdConv2D(tfkl.Conv2D):
    """Conv2D with weight standardization applied on every forward pass.

    Each output filter of the kernel (shape ``(kh, kw, in_channels, filters)``)
    is shifted to zero mean and scaled to unit variance over axes 0-2 before
    the convolution. The stored ``self.kernel`` variable is left unchanged.

    Apart from the standardization, the layer behaves like its ``Conv2D``
    parent: ``use_bias`` and ``activation`` are honoured.
    """

    def call(self, x):
        kernel = self.kernel
        mean, var = tf.nn.moments(kernel, axes=[0, 1, 2], keepdims=True)
        # 1e-5 keeps the division finite for (near-)constant filters.
        kernel = (kernel - mean) / tf.sqrt(var + 1e-5)

        data_format = "NHWC" if self.data_format == "channels_last" else "NCHW"
        outputs = tf.nn.conv2d(
            x, kernel, self.strides, self.padding.upper(),
            data_format=data_format, dilations=self.dilation_rate, name=self.name)

        # Conv2D creates a bias by default (use_bias=True). Calling tf.nn.conv2d
        # directly bypasses it, so add it back explicitly. Otherwise it is a
        # trainable variable with no effect on the output.
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias, data_format=data_format)
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs

In [8]:
# A 7x7 stride-2 weight-standardized conv and a random image-sized input.
conv = StdConv2D(filters=64, kernel_size=7, strides=2, padding="same")
# Seeded generator so the input (and the activations printed below) are
# reproducible on Restart & Run All; values are float64 in [0, 1) as before.
dummy = np.random.default_rng(seed=0).random((1, 224, 224, 3))

In [9]:
# Forward pass of `dummy` through `conv` (a StdConv2D in linear cell order);
# the resulting tensor is displayed below.
conv(dummy)

<tf.Tensor: shape=(1, 112, 112, 64), dtype=float32, numpy=
array([[[[-1.0168917e+00, -2.7894324e-01,  2.9173000e+00, ...,
          -1.6748133e+00,  3.6070344e+00,  4.3997159e+00],
         [ 2.2356491e+00,  8.9061537e+00,  2.2993343e+00, ...,
          -6.3995085e+00, -4.8971076e+00, -2.7893918e+00],
         [-1.9005616e+00,  2.3726332e+00,  3.0995746e+00, ...,
           1.5501729e+00, -3.4278903e+00,  1.1393967e+00],
         ...,
         [ 7.9036576e-01,  3.7301404e+00,  2.0970452e+00, ...,
          -9.5213127e+00, -2.0025275e+00,  1.2449298e+00],
         [ 5.7282060e-01,  7.1527004e+00, -1.3646924e+00, ...,
          -4.5222034e+00,  2.9717882e+00, -3.3893976e+00],
         [ 6.0545838e-01,  4.7432780e-01, -2.0039618e+00, ...,
           4.9592608e-01, -3.3688035e+00,  5.8225274e-01]],

        [[-2.5489106e+00, -5.8954654e+00,  2.1209078e+00, ...,
           1.3287851e+00, -6.3876629e-01,  5.3617811e+00],
         [-5.2728481e+00,  1.1989442e+00, -1.0543882e-02, ...,
        

In [3]:
# Group normalization with 32 groups, eps=1e-6, and an L2 penalty (1e-4)
# on the beta (offset) parameter.
gn1 = tfa.layers.GroupNormalization(
    groups=32,
    epsilon=1e-6,
    beta_regularizer=tfk.regularizers.L2(1e-4),
)

In [114]:
# Plain 3x3 conv (no bias) followed by GroupNorm with one group per channel.
conv = tfk.layers.Conv2D(filters=64, kernel_size=3, use_bias=False)
gn = tfa.layers.GroupNormalization(groups=64)
# Run once on `dummy`; the result is discarded.
_ = gn(conv(dummy))


In [94]:
# Asymmetric zero padding: one extra row at the bottom and one extra column
# on the right. Build a small explicit input (values 0..15, shape (1, 2, 2, 4))
# instead of relying on whatever `x` holds in the kernel. Use plain `print`
# (`vprint` is not defined anywhere in this notebook).
padding_demo_in = tf.reshape(tf.range(16), (1, 2, 2, 4))
padding_demo_out = tf.keras.layers.ZeroPadding2D(padding=((0, 1), (0, 1)))(padding_demo_in)
print(padding_demo_out)

tf.Tensor(
[[[[ 0  1  2  3]
   [ 4  5  6  7]
   [ 0  0  0  0]]

  [[ 8  9 10 11]
   [12 13 14 15]
   [ 0  0  0  0]]

  [[ 0  0  0  0]
   [ 0  0  0  0]
   [ 0  0  0  0]]]], shape=(1, 3, 3, 4), dtype=int32)


In [148]:
# Open the .npz checkpoint (presumably pretrained R50+ViT-B/16 weights, per the
# file name). np.load returns a lazily-reading NpzFile; `weights.files` lists the
# stored array names. The archive stays open until `weights.close()` is called.
weights = np.load(file="../data/R50+ViT-B_16.npz")

In [150]:
# Rebuild the project's ResNetV2 from a freshly reloaded module and load the
# .npz checkpoint into it. Statement order matters: reload -> construct -> call -> load.
importlib.reload(res)  # pick up on-disk edits to models/resnet_v2.py
model = res.ResNetV2([3,4,9])  # NOTE(review): [3,4,9] presumably = units per stage; confirm in resnet_v2.py
model(dummy)  # one forward pass, presumably so the variables exist before loading
model.load_weights(weights)  # NOTE(review): tf.keras.Model.load_weights expects a file path; passing an NpzFile only works if ResNetV2 overrides load_weights — confirm

In [134]:
test = np.zeros((1, 1, 1, 64))
# Dropping the three leading singleton axes leaves only the last axis.
test.squeeze(axis=(0, 1, 2)).shape

(64,)

In [85]:
# 2x3 int32 tensor [[1, 2, 3], [4, 5, 6]], built by reshaping a flat list.
t = tf.constant([1, 2, 3, 4, 5, 6], shape=(2, 3))
t.shape

TensorShape([2, 3])

In [10]:
# Load training images X and labels y from the Synapse .npz directory.
# NOTE(review): the second argument (10) presumably limits the number of
# samples; confirm against dataset_synapse.load_data.
X, y = ds.load_data(
    "../data/train_npz/",
    10,
    output_size=224,
)

In [8]:
# Image and label array shapes, one per line.
# NOTE(review): the saved output below shows 384x384, which presumably does not
# match output_size=224 in the load cell — it looks stale; re-run to refresh.
print(X.shape, y.shape, sep="\n")

(10, 384, 384, 3)
(10, 384, 384, 9)


In [2]:
# Reference Keras ResNet50V2 without its classification head, with ImageNet
# weights (the Keras default; downloaded on first use).
resnet = tfk.applications.ResNet50V2(
    include_top=False,
    weights="imagenet",
    input_shape=(224, 224, 3),
)

In [37]:
# Sanity-check weight standardization on a real pretrained kernel: standardize
# the conv1_conv kernel per output filter (axes 0-2) and confirm the result is finite.
w = resnet.get_layer('conv1_conv').kernel
m, v = tf.nn.moments(w, axes=[0, 1, 2], keepdims=True)
w = (w - m) / tf.sqrt(v + 1e-5)
# A bare `tf.math.is_nan(w)` mid-cell is computed and thrown away; asserting
# makes any NaN/Inf actually surface.
tf.debugging.assert_all_finite(w, "standardized conv1_conv kernel contains NaN/Inf")
w.shape


TensorShape([7, 7, 3, 64])

In [60]:
class StdConv2D(tfkl.Conv2D):
    """Conv2D with weight standardization applied on every forward pass.

    Each output filter of the kernel (shape ``(kh, kw, in_channels, filters)``)
    is shifted to zero mean and scaled to unit variance over axes 0-2 before
    the convolution. The stored ``self.kernel`` variable is left unchanged
    (Keras ``weights`` is read-only, so it cannot be reassigned here anyway).

    Apart from the standardization, the layer behaves like its ``Conv2D``
    parent: ``use_bias`` and ``activation`` are honoured.

    NOTE(review): this cell redefines a StdConv2D class that is already defined
    earlier in the notebook; keep only one definition.
    """

    def call(self, x):
        # Use the named kernel variable rather than relying on `self.weights` ordering.
        kernel = self.kernel
        mean, var = tf.nn.moments(kernel, axes=[0, 1, 2], keepdims=True)
        # 1e-5 keeps the division finite for (near-)constant filters.
        kernel = (kernel - mean) / tf.sqrt(var + 1e-5)

        data_format = "NHWC" if self.data_format == "channels_last" else "NCHW"
        outputs = tf.nn.conv2d(
            x, kernel, self.strides, self.padding.upper(),
            data_format=data_format, dilations=self.dilation_rate, name=self.name)

        # Conv2D creates a bias by default (use_bias=True). Calling tf.nn.conv2d
        # directly bypasses it, so add it back explicitly.
        if self.use_bias:
            outputs = tf.nn.bias_add(outputs, self.bias, data_format=data_format)
        if self.activation is not None:
            outputs = self.activation(outputs)
        return outputs

In [67]:
# A 7x7 stride-2 weight-standardized conv and a random image-sized input.
conv = StdConv2D(filters=64, kernel_size=7, strides=2, padding="same")
# Seeded generator for reproducibility; values are float64 in [0, 1) as before.
# Run `conv(dummy)` in a separate cell to inspect the output.
dummy = np.random.default_rng(seed=0).random((1, 224, 224, 3))

In [3]:
# Layer-by-layer architecture of the Keras ResNet50V2 reference model (output below, truncated).
resnet.summary()

Model: "resnet50v2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D)       (None, 230, 230, 3)  0           input_1[0][0]                    
__________________________________________________________________________________________________
conv1_conv (Conv2D)             (None, 112, 112, 64) 9472        conv1_pad[0][0]                  
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D)       (None, 114, 114, 64) 0           conv1_conv[0][0]                 
_________________________________________________________________________________________

In [19]:
# Pick up on-disk edits to the project modules, then build and compile a
# TransUnet from the R50+ViT-B/16 config.
# Dependencies are reloaded before the module that uses them (vit last).
# Reloading vit first would leave it holding stale objects for anything it binds
# at import time (from-imports, base classes).
importlib.reload(conf)
importlib.reload(ut)
importlib.reload(res)
importlib.reload(encoder_layers)
importlib.reload(decoder_layers)
importlib.reload(vit)
config = conf.get_r50_b16()
trans = vit.TransUnet(config)
trans.compile()

In [20]:
# One epoch on the loaded X, y; Keras returns a History object (shown below).
# NOTE(review): batch_size=24 exceeds the 10 samples presumably requested from
# load_data, so this is a single smaller batch.
trans.model.fit(X, y, batch_size=24, epochs=1, verbose=1)



<tensorflow.python.keras.callbacks.History at 0x1a4c7167e80>

In [24]:
# Layer summary of the ResNet backbone held by the TransUnet instance (output below).
trans.resnet50v2.summary()

Model: "resnet_v2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
sequential_20 (Sequential)   (None, 112, 112, 64)      9536      
_________________________________________________________________
sequential_21 (Sequential)   (None, 55, 55, 256)       215808    
_________________________________________________________________
sequential_22 (Sequential)   (None, 28, 28, 512)       1219584   
_________________________________________________________________
sequential_23 (Sequential)   (None, 14, 14, 1024)      10449920  
Total params: 11,894,848
Trainable params: 11,894,848
Non-trainable params: 0
_________________________________________________________________


In [11]:
# Pick up on-disk edits to the project modules, then build and compile a
# TransUnet from the ViT-B/16 config without a CNN backbone (get_b16_none).
# Same reload set and order as the R50 cell: dependencies first, vit last,
# so vit re-binds to fresh definitions.
importlib.reload(conf)
importlib.reload(ut)
importlib.reload(res)
importlib.reload(encoder_layers)
importlib.reload(decoder_layers)
importlib.reload(vit)
config = conf.get_b16_none()
trans = vit.TransUnet(config)
trans.compile()