In [1]:
import os
import sys

# Make the local tf-transformers checkout importable. Set the TF_TRANSFORMERS_SRC
# environment variable to point elsewhere instead of editing this hard-coded path.
TF_TRANSFORMERS_SRC = os.environ.get(
    "TF_TRANSFORMERS_SRC", "/Users/sarathrnair/Projects/tf-transformers/src/"
)
# Guard the append so re-running this cell does not add duplicate sys.path entries.
if TF_TRANSFORMERS_SRC not in sys.path:
    sys.path.append(TF_TRANSFORMERS_SRC)

In [2]:
import tensorflow as tf
from unet import UnetModel
from tf_transformers.models import SentenceTransformer
from base_diffusion import BaseDiffusion



In [3]:
# Pretrained sentence-T5 text encoder used to condition the U-Net on text.
model_name = 'sentence-transformers/sentence-t5-base'
text_encoder = SentenceTransformer.from_pretrained(
    model_name,
    return_layer=True,
)
# Freeze the encoder weights; only the diffusion model is meant to be trained.
text_encoder.trainable = False

Metal device set to: Apple M1


INFO:absl:Successful ✅✅: Model checkpoints matched and loaded from /Users/sarathrnair/.cache/huggingface/hub/tftransformers__sentence-t5-base-sentence-transformers.main.d64dbdc4c8c15637da4215b81f38af99d48a586c/ckpt-1
INFO:absl:Successful ✅: Loaded model from tftransformers/sentence-t5-base-sentence-transformers


In [6]:
# U-Net hyperparameters.
out_channels = 512          # base width; level i uses channel_mult[i] * out_channels (512 ... 2048 in the log below)
channel_mult = [1, 2, 3, 4]
num_res_blocks = 3
time_emb = 128              # must equal config['time_emb_dimension'] used by BaseDiffusion
text_emb = 768              # must equal the output dimension of the text encoder
input_channels = 3

unet = UnetModel(
    text_embedding_dimension=text_emb,
    time_embedding_dimension=time_emb,
    out_channels=out_channels,
    channel_mult=channel_mult,
    input_channels=input_channels,
    num_res_blocks=num_res_blocks,
    attention_resolutions=[32, 16, 8],
    cross_attention_resolutions=[32, 16, 8],
    use_scale_shift_norm=True,
)


Ds res 32
Ds res 16
Ds res 8
Ds res 8
Current out channel 2048
Up res start 8
Up res 16
Up res 32
Up res 64
Up res 64
H start (None, 64, 64, 512)
H (None, 32, 32, 512)
H (None, 16, 16, 1024)
H (None, 8, 8, 1536)
H (None, 8, 8, 2048)
Hs [<tf.Tensor 'unet/down_block/downsample/BiasAdd:0' shape=(None, 32, 32, 512) dtype=float32>, <tf.Tensor 'unet/down_block/attention/self_attention_layer_norm/batchnorm_1/add_1:0' shape=(None, 16, 16, 1024) dtype=float32>, <tf.Tensor 'unet/down_block/attention/self_attention_layer_norm/batchnorm_3/add_1:0' shape=(None, 8, 8, 1536) dtype=float32>, <tf.Tensor 'unet/down_block/attention/self_attention_layer_norm/batchnorm_5/add_1:0' shape=(None, 8, 8, 2048) dtype=float32>]
Hs len 4
Hs len 3
Ublocks 4
H middle (None, 8, 8, 2048)
H up (None, 16, 16, 512)
H up (None, 32, 32, 1024)
H up (None, 64, 64, 1536)
H up (None, 64, 64, 2048)


In [5]:
# NOTE(review): this output appears stale — the identical cell below reports a different
# count, and this cell's execution count predates the cell that built `unet`. Consider deleting.
num_unet_params = unet.count_params()
num_unet_params

1278110723

In [7]:
# Total number of scalar weights in the U-Net (Keras `count_params`).
unet_param_count = unet.count_params()
unet_param_count

1305709571

In [5]:
# Diffusion settings. 'time_emb_dimension' must match the U-Net's time_embedding_dimension.
# NOTE(review): the UnetModel build log above traces a 64x64 input while images here are
# 32x32; presumably the U-Net is resolution-agnostic — confirm attention_resolutions still fire.
config = {
    'beta_schedule': 'cosine',
    'diffusion_steps': 128,
    'time_emb_dimension': time_emb,
    'image_height': 32,
    'image_width': 32,
    'input_channels': input_channels,
}

# Keep the BaseDiffusion wrapper under its own name rather than overwriting it with the
# Keras model it builds, so `get_model()` can be re-run without rebuilding the wrapper.
diffusion = BaseDiffusion(config, text_encoder_model=text_encoder, unet_model=unet)
model = diffusion.get_model()

In [6]:
# Inspect the model's named input signature.
model_input_spec = model.input
model_input_spec

{'input_pixels': <KerasTensor: shape=(None, 32, 32, 3) dtype=float32 (created by layer 'input_pixels')>,
 'input_ids': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_ids')>,
 'input_mask': <KerasTensor: shape=(None, None) dtype=int32 (created by layer 'input_mask')>,
 'time_steps': <KerasTensor: shape=(1, None) dtype=int32 (created by layer 'time_steps')>,
 'noise': <KerasTensor: shape=(None, 32, 32, 3) dtype=float32 (created by layer 'input_noise')>}

In [7]:
# Inspect the model's named output signature.
model_output_spec = model.output
model_output_spec

{'h': <KerasTensor: shape=(None, 32, 32, 3) dtype=float32 (created by layer 'diffusion')>,
 'noise': <KerasTensor: shape=(None, 32, 32, 3) dtype=float32 (created by layer 'diffusion')>}

In [8]:
# Optional (disabled): export the assembled diffusion model as a SavedModel.
# model.save_serialized("/tmp/diffusion_temp")

In [12]:
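# Optional (disabled): reload the exported model and use its serving signature.
# NOTE: SavedModel signature functions take keyword arguments, so call it as model(**inputs).
# loaded = tf.saved_model.load("/tmp/diffusion_temp")
# model = loaded.signatures['serving_default']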

In [8]:
# Smoke-test dimensions; image and step settings come from the diffusion config so the
# random batch matches the model's declared inputs.
batch_size = 4
text_sequence_length = 96
height, width = config['image_height'], config['image_width']
in_channels, diffusion_steps = config['input_channels'], config['diffusion_steps']

In [9]:
# Random smoke-test batch. Seeded so Restart & Run All reproduces the same tensors.
SEED = 42
tf.random.set_seed(SEED)

image = tf.random.uniform((batch_size, height, width, in_channels))  # stand-in for the original image

input_ids = tf.random.uniform(minval=0, maxval=100, shape=(batch_size, text_sequence_length), dtype=tf.int32)
input_mask = tf.random.uniform(minval=0, maxval=2, shape=(batch_size, text_sequence_length), dtype=tf.int32)
# One time step per example; shape (1, batch_size) matches the 'time_steps' input of shape (1, None).
time_steps = tf.random.uniform(minval=0, maxval=diffusion_steps, shape=(1, batch_size), dtype=tf.int32)

# Diffusion noise is standard Gaussian N(0, I), not uniform on [0, 1).
noise = tf.random.normal((batch_size, height, width, in_channels))



In [10]:
# Feed dict keyed by the model's input names (see `model.input`).
inputs = {
    'input_pixels': image,
    'noise': noise,
    'input_ids': input_ids,
    'input_mask': input_mask,
    'time_steps': time_steps,
}

In [11]:
# Smoke-test forward pass on the random batch.
model_outputs = model(inputs)

# Both outputs are declared with the same (batch, H, W, C) shape as the input pixels
# (see `model.output`); fail loudly if the forward pass disagrees.
for output_name in ('h', 'noise'):
    assert model_outputs[output_name].shape == inputs['input_pixels'].shape, output_name

In [None]:
# Draft: build the U-Net down path as an explicit list of layers per resolution level
# (presumably mirroring the DownBlock construction inside UnetModel).
# NOTE(review): these block classes are not imported anywhere else in this notebook; they are
# assumed to live in unet.py next to UnetModel — TODO confirm the import path.
from unet import ResNetBlock, ImageSelfAttention, ImageTextCrossAttention

# Same settings that were passed to UnetModel above.
attention_resolutions = [32, 16, 8]
cross_attention_resolutions = [32, 16, 8]
use_scale_shift_norm = True

# Spatial resolution of the current level; starts at the input image height.
ds_res = config['image_height']

# One entry per channel multiplier: ([(resnet, self_attn, cross_attn), ...], downsample).
down_levels = []
for index, ch_mult in enumerate(channel_mult):
    # Every level except the deepest one halves the spatial resolution.
    use_downsample = index != len(channel_mult) - 1
    current_out_channel = ch_mult * out_channels

    level_layers = []
    for resnet_counter in range(num_res_blocks):
        res = ResNetBlock(
            current_out_channel,
            use_scale_shift_norm=use_scale_shift_norm,
            # Include the level index: Keras layer names must be unique within a model.
            name='resnet_{}_{}'.format(index, resnet_counter),
        )
        # Attention only at the configured resolutions; tf.identity is a pass-through otherwise.
        self_attn = ImageSelfAttention() if ds_res in attention_resolutions else tf.identity
        cross_attn = ImageTextCrossAttention() if ds_res in cross_attention_resolutions else tf.identity
        level_layers.append((res, self_attn, cross_attn))

    # Downsample once per level, after all residual blocks, so ds_res tracks the
    # actual feature-map size (halving inside the resnet loop shrank it num_res_blocks times).
    down_sample = tf.identity
    if use_downsample:
        down_sample = tf.keras.layers.Conv2D(
            out_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            use_bias=True,
            padding='SAME',
            name='downsample_{}'.format(index),
        )
        ds_res = ds_res // 2

    down_levels.append((level_layers, down_sample))
