In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D,Input
from tensorflow.nn import relu
from tensorflow.keras import Model
from tensorflow.keras.utils import plot_model
import pydot,graphviz
from IPython.display import Image
from tensorflow.keras.initializers import TruncatedNormal
# import numpy as np
# import cv2

In [None]:
# Architecture switches read by _encoder / _decoder; all disabled here.
use_texture_conv, use_shape_conv, texture_downsample = False, False, False
# Dict that res_manipulator (whose default argument binds to it) fills with
# intermediate tensors under 'mani_*' keys for inspection.
probe_pt = dict()

In [None]:
def expand_dims_1_to_4(tensor, dims=None):
    """Insert three size-1 axes into ``tensor`` via successive ``tf.expand_dims``.

    Presumably used to lift a scalar / per-example factor so it broadcasts
    against NHWC feature maps (see its use in res_manipulator) -- confirm.

    Args:
        tensor: any value accepted by ``tf.expand_dims``.
        dims: sequence whose first three entries are the axes used by the three
            successive expansions; a falsy value (None or empty) means
            ``[-1, -1, -1]``.

    Returns:
        ``tensor`` with three extra size-1 dimensions.
    """
    axes = dims if dims else [-1, -1, -1]
    expanded = tensor
    for axis in (axes[0], axes[1], axes[2]):
        expanded = tf.expand_dims(expanded, axis=axis)
    return expanded

def res_manipulator(enc_a,
                    enc_b,
                    amplification_factor,
                    layer_dims=32,
                    num_resblk=1,
                    num_conv=0,
                    num_aft_conv=0,
                    probe_pt=probe_pt):
    """Amplify the difference between two encodings and add it back onto ``enc_b``.

    Pipeline: ``delta = enc_b - enc_a`` -> ``num_conv`` reflect-padded 7x7 ReLU
    convs -> multiply by ``amplification_factor - 1`` -> ``num_aft_conv``
    reflect-padded 3x3 convs (no activation) -> ``num_resblk`` residual blocks
    -> return ``enc_b + delta``.

    Args:
        enc_a: encoding that is subtracted.
        enc_b: encoding that is subtracted from and that the result is built on.
        amplification_factor: passed (minus 1.0) through expand_dims_1_to_4
            before multiplying the difference.
        layer_dims: filter count of every conv / residual block created here.
        num_resblk: number of residual blocks after the multiplication.
        num_conv: number of 7x7 convs before the multiplication.
        num_aft_conv: number of 3x3 convs right after the multiplication.
        probe_pt: dict that receives intermediate tensors under the keys
            'mani_diff', 'mani_after_conv', 'mani_after_mult', 'mani_after_res';
            None disables recording. The default is the module-level dict bound
            when this function was defined.

    Returns:
        ``enc_b`` plus the processed, amplified difference.
    """
    def record(key, value):
        # Store an intermediate tensor only when probing is enabled.
        if probe_pt is not None:
            probe_pt[key] = value

    delta = enc_b - enc_a
    record("mani_diff", delta)

    # 3-pixel reflect pad keeps the spatial size through the unpadded 7x7 conv.
    wide_pad = [[0, 0], [3, 3], [3, 3], [0, 0]]
    for idx in range(num_conv):
        delta = tf.pad(delta, wide_pad, "REFLECT")
        delta = Conv2D(layer_dims,
                       kernel_size=7,
                       strides=1,
                       activation='relu',
                       name=f'mani_conv_{idx}' + 'c',
                       kernel_initializer=TruncatedNormal(stddev=0.2))(delta)
    record("mani_after_conv", delta)

    delta = delta * expand_dims_1_to_4(amplification_factor - 1.0)
    record("mani_after_mult", delta)

    # 1-pixel reflect pad keeps the spatial size through the unpadded 3x3 conv.
    narrow_pad = [[0, 0], [1, 1], [1, 1], [0, 0]]
    for idx in range(num_aft_conv):
        delta = tf.pad(delta, narrow_pad, "REFLECT")
        delta = Conv2D(layer_dims,
                       kernel_size=3,
                       strides=1,
                       name=f'mani_aft_conv_{idx}' + 'c',
                       kernel_initializer=TruncatedNormal(stddev=0.2))(delta)

    for idx in range(num_resblk):
        delta = residual_block(delta, layer_dims, 3, 1, name=f'mani_resblk{idx}')
    record("mani_after_res", delta)

    return enc_b + delta


def res_encoder(image, no, layer_dims=32, num_resblk=5):
    """Common feature trunk used by _encoder.

    Reflect-padded 7x7 ReLU conv (``layer_dims // 2`` filters), then an
    unpadded stride-2 3x3 ReLU conv (``layer_dims`` filters), then
    ``num_resblk`` residual blocks.

    Args:
        image: NHWC input tensor.
        no: suffix appended to every layer name so that separate encoders in
            one model get distinct layer names.
        layer_dims: channel count of the returned feature map.
        num_resblk: number of residual blocks after the downsampling conv.

    Returns:
        Feature map with ``layer_dims`` channels. The 7x7 conv keeps H x W
        (3-pixel reflect pad), while the unpadded (Keras default 'valid')
        stride-2 conv yields (H - 3) // 2 + 1 rows, e.g. 224 for H = 450.
    """
    padded = tf.pad(image, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
    features = Conv2D(layer_dims // 2,
                      kernel_size=7,
                      strides=1,
                      activation='relu',
                      name='res_enc_conv1_c' + no,
                      kernel_initializer=TruncatedNormal(stddev=0.2))(padded)
    features = Conv2D(layer_dims,
                      kernel_size=3,
                      strides=2,
                      activation='relu',
                      name='res_enc_conv2_c' + no,
                      kernel_initializer=TruncatedNormal(stddev=0.2))(features)

    for block_idx in range(num_resblk):
        features = residual_block(features, layer_dims, 3, 1,
                                  name=f'res_encoder_resblk{block_idx}_' + no)
    return features



def res_decoder(activation,
                layer_dims=64,
                out_channels=3,
                num_resblk=4):
    """Decode a feature map back into an image tensor.

    ``num_resblk`` residual blocks -> 2x ``tf.image.resize`` -> reflect-padded
    3x3 ReLU conv (``layer_dims // 2`` filters) -> reflect-padded 7x7 linear
    conv with ``out_channels`` filters -> one extra reflected pixel per side.

    NOTE(review): the final 1-pixel pad compensates for res_encoder's unpadded
    stride-2 conv, which gives (H - 3) // 2 + 1 rows (224 for H = 450). This
    decoder doubles that to 448 and the pad restores 450. For odd input sizes
    the output would come out one pixel larger than the input. Padding the
    encoder's strided conv instead would remove the need for this pad.

    Args:
        activation: NHWC feature map with ``layer_dims`` channels (required by
            the residual skip additions).
        layer_dims: filters of the residual blocks.
        out_channels: channels of the returned image.
        num_resblk: number of residual blocks before upsampling.

    Returns:
        NHWC tensor with ``out_channels`` channels.
    """
    x = activation
    for block_idx in range(num_resblk):
        x = residual_block(x, layer_dims, 3, 1, name=f'res_decoder_resblk{block_idx}')

    doubled_hw = tf.shape(x)[1:3] * 2
    x = tf.image.resize(x, doubled_hw)
    x = tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
    x = Conv2D(layer_dims // 2,
               kernel_size=3,
               strides=1,
               activation='relu',
               name='res_dec_conv2_c',
               kernel_initializer=TruncatedNormal(stddev=0.2))(x)
    x = tf.pad(x, [[0, 0], [3, 3], [3, 3], [0, 0]], "REFLECT")
    x = Conv2D(out_channels,
               kernel_size=7,
               strides=1,
               name='res_pred_conv',
               kernel_initializer=TruncatedNormal(stddev=0.2))(x)
    # Size-restoring pad described in the docstring above.
    return tf.pad(x, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")

# Define the residual_block function



def residual_block(x, output_dim, ks=3, s=1, name='residual_block'):
    """Two reflect-padded convolutions (ReLU after the first) plus an identity skip.

    Args:
        x: NHWC input. For the skip addition it must already have
            ``output_dim`` channels and keep its spatial size through the
            convs (true for ``s == 1`` with an odd ``ks``).
        output_dim: filters of both convolutions.
        ks: kernel size; each conv is preceded by a ``(ks - 1) // 2`` reflect pad.
        s: stride of both convolutions.
        name: prefix of the layer names (``name + '_c1'``, ``name + '_c2'``).

    Returns:
        ``conv2(pad(relu(conv1(pad(x))))) + x``.
    """
    half = (ks - 1) // 2
    paddings = [[0, 0], [half, half], [half, half], [0, 0]]

    hidden = Conv2D(output_dim, ks, s,
                    activation='relu',
                    name=name + '_c1',
                    kernel_initializer=TruncatedNormal(stddev=0.2))(tf.pad(x, paddings, "REFLECT"))
    hidden = Conv2D(output_dim, ks, s,
                    name=name + '_c2',
                    kernel_initializer=TruncatedNormal(stddev=0.2))(tf.pad(hidden, paddings, "REFLECT"))
    return hidden + x

def _encoder(image, no):
    """Encode ``image`` into separate texture and shape representations.

    Both branches start from the same res_encoder features. Each branch may get
    an optional 3x3 ReLU conv, controlled by the module-level flags
    ``use_texture_conv`` / ``texture_downsample`` / ``use_shape_conv``. Each
    branch then ends in one 32-channel residual block.

    Args:
        image: NHWC image tensor (e.g. a Keras ``Input``).
        no: suffix appended to every layer name, so that several encoders in
            one model get distinct layer names.

    Returns:
        ``(texture_enc, shape_enc)``: 32-channel feature maps. They have equal
        spatial size unless the texture branch was downsampled, in which case
        the texture map has ceil(h / 2) rows.
    """
    enc = res_encoder(image, no)

    texture_enc = enc
    shape_enc = enc

    if use_texture_conv:
        # Reflect-pad first, as everywhere else in this file. Without it the
        # unpadded 3x3 conv would shrink the texture map by 2 px, so it would no
        # longer line up with the shape map in _decoder's concat.
        stride = 2 if texture_downsample else 1
        texture_enc = tf.pad(texture_enc, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
        texture_enc = Conv2D(32, 3, stride,
                             activation='relu',
                             name='enc_texture_conv_' + no,
                             kernel_initializer=TruncatedNormal(stddev=0.2))(texture_enc)

    if use_shape_conv:
        # Same size-preserving reflect pad as above.
        shape_enc = tf.pad(shape_enc, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
        shape_enc = Conv2D(32, 3, 1,
                           activation='relu',
                           name='enc_shape_conv_' + no,
                           kernel_initializer=TruncatedNormal(stddev=0.2))(shape_enc)

    # Layer names keep the original 'texture_enc__' / 'shape_enc__' form
    # (double underscore) for backward compatibility.
    texture_enc = residual_block(texture_enc, 32, 3, 1, 'resblk_texture_enc__' + no)
    shape_enc = residual_block(shape_enc, 32, 3, 1, 'resblk_shape_enc__' + no)
    return texture_enc, shape_enc

def _decoder(texture_enc, shape_enc):
    """Decode a texture encoding and a (manipulated) shape encoding into an image.

    Args:
        texture_enc: texture feature map from _encoder.
        shape_enc: shape feature map, e.g. the output of res_manipulator.

    Returns:
        The res_decoder output for the channel-wise concatenation of both inputs.
    """
    # Undo only a downsampling that _encoder actually performed: it strides the
    # texture branch only when use_texture_conv AND texture_downsample are set.
    # Upsampling on texture_downsample alone would double the texture map's
    # size and break the concat below.
    if use_texture_conv and texture_downsample:
        texture_enc = tf.image.resize(texture_enc, tf.shape(texture_enc)[1:3] * 2)
        texture_enc = tf.pad(texture_enc, [[0, 0], [1, 1], [1, 1], [0, 0]], "REFLECT")
        texture_enc = Conv2D(32, 3, 1,
                             activation='relu',
                             name='texture_upsample',
                             kernel_initializer=TruncatedNormal(stddev=0.2))(texture_enc)

    enc = tf.concat([texture_enc, shape_enc], axis=3)
    # res_decoder's default layer_dims=64 matches the 32 + 32 concatenated channels.
    return res_decoder(enc)

In [None]:
# Amplification factor baked into the graph. The TFRecords also carry a
# per-example 'amplification_factor' that this model does not consume yet (TODO).
AMPLIFICATION_FACTOR = 10.0

img_a = Input((450, 450, 3))
img_b = Input((450, 450, 3))
# Each frame goes through its own encoder (layer-name suffix '1' / '2'), so the
# encoder weights are NOT shared between the two frames.
_, s_a = _encoder(img_a, '1')
t_b, s_b = _encoder(img_b, '2')
# res_manipulator(enc_a, enc_b, ...) returns enc_b plus an amplified transform of
# (enc_b - enc_a). Pass frame A first so the A -> B motion is amplified on top of
# frame B's shape, consistent with the frame-B texture fed to the decoder.
ou = res_manipulator(s_a, s_b, AMPLIFICATION_FACTOR)
out = _decoder(t_b, ou)
model = Model(inputs=[img_a, img_b], outputs=out)

In [None]:
model.summary()
# L1 loss: MeanAbsoluteError averages |y_true - y_pred| over all elements, the same
# value the previous lambda computed, but unlike a lambda it can be serialized
# with the model. 'accuracy' (exact equality) is meaningless on continuous pixel
# values, so track MAE instead.
model.compile(optimizer='adam',
              loss=tf.keras.losses.MeanAbsoluteError(),
              metrics=['mae'])



In [None]:
# How the weights are stored: `sim` lists (layer_index, layer.weights) for every
# layer of `model` that owns weight variables (e.g. the Conv2D kernels/biases).
# It must be defined here; previously the definition was commented out, so
# `sim[2]` failed with NameError on a fresh kernel.
sim = [(idx, layer.weights) for idx, layer in enumerate(model.layers) if layer.weights]
sim[2]

In [None]:
# plot_model(model,'img.jpeg',show_shapes=True,
#     show_dtype=True,
#     show_layer_names=True,
#     rankdir='TB',
#     expand_nested=True,
#     show_layer_activations=True,
#     show_trainable=True)
# # Image('img.jpeg')

def parse_tfrecord_fn(example):
    """Parse one serialized ``tf.train.Example`` into its frames and amplification factor.

    Args:
        example: scalar string tensor with a serialized ``tf.train.Example``.
            It holds string (encoded image) features 'frameA', 'frameB',
            'frameC', 'frameAmp' and a float feature 'amplification_factor'.

    Returns:
        Tuple ``(frameA, frameB, frameC, frameAmp, amplification_factor)``. Each
        frame is a uint8 tensor of rank 3 with 3 channels.
    """
    feature_description = {
        'frameA': tf.io.FixedLenFeature([], tf.string),
        'frameB': tf.io.FixedLenFeature([], tf.string),
        'frameC': tf.io.FixedLenFeature([], tf.string),
        'frameAmp': tf.io.FixedLenFeature([], tf.string),
        'amplification_factor': tf.io.FixedLenFeature([], tf.float32)
    }

    parsed = tf.io.parse_single_example(example, feature_description)

    def _decode(key):
        # channels=3 matches the model's 3-channel Inputs. expand_animations=False
        # guarantees a rank-3 result; GIFs would otherwise decode to rank 4 and
        # the rank would be unknown to downstream ops.
        return tf.image.decode_image(parsed[key], channels=3, expand_animations=False)

    return (_decode('frameA'),
            _decode('frameB'),
            _decode('frameC'),
            _decode('frameAmp'),
            parsed['amplification_factor'])


# Placeholder: set to the TFRecord file path(s) before running.
filenames = 'tf record path'
SHUFFLE_BUFFER = 1000
BATCH_SIZE = 20

dataset = tf.data.TFRecordDataset(filenames)
# Shuffle the small serialized records *before* decoding. Shuffling after the
# map would hold SHUFFLE_BUFFER fully decoded examples (four images each) in memory.
dataset = dataset.shuffle(buffer_size=SHUFFLE_BUFFER)
dataset = dataset.map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
# TODO: model.fit expects ([frameA, frameB], target). Presumably the target is
# frameAmp -- add a map to that structure before training.
dataset = dataset.batch(BATCH_SIZE)
dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

