In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model

In [None]:
class CrossAttentionBlock(layers.Layer):
    """Single-head cross-attention from an image feature map to a context sequence.

    Queries come from the flattened image features; keys and values come from
    ``context``. The projected attention output is added back to the input as
    a residual, so the output has the same shape as ``x``.

    Args:
        channels: Width of the query/key/value/output projections. Must equal
            the channel dimension of ``x`` for the residual addition to work.
    """

    def __init__(self, channels, **kwargs):
        super().__init__(**kwargs)
        self.channels = channels
        # Separate norms: image features and context can have different widths
        # and statistics, so they must not share LayerNorm parameters.
        self.norm = layers.LayerNormalization()
        self.context_norm = layers.LayerNormalization()
        self.q_proj = layers.Dense(channels)
        self.k_proj = layers.Dense(channels)
        self.v_proj = layers.Dense(channels)
        self.proj = layers.Dense(channels)

    def call(self, x, context):
        """Attend from every spatial position of ``x`` over ``context``.

        Args:
            x: Feature map (B, H, W, C) with C == ``channels``.
            context: Sequence tensor, presumably (B, N, D) text embeddings —
                TODO confirm against the caller's text encoder.

        Returns:
            Tensor of shape (B, H, W, C): projected attention output + ``x``.
        """
        residual = x
        # Batch/spatial sizes may be unknown when traced, so read them
        # dynamically; the channel count is static (Dense needs it anyway).
        dyn_shape = tf.shape(x)
        batch, height, width = dyn_shape[0], dyn_shape[1], dyn_shape[2]
        in_channels = x.shape[-1]

        x = self.norm(x)
        context = self.context_norm(context)

        x_flat = tf.reshape(x, (batch, height * width, in_channels))
        q = self.q_proj(x_flat)   # Query (B, H*W, channels)
        k = self.k_proj(context)  # Key   (B, N, channels)
        v = self.v_proj(context)  # Value (B, N, channels)

        # Scaled dot-product attention; scale by sqrt of the key width.
        scale = tf.math.sqrt(tf.cast(self.channels, q.dtype))
        scores = tf.matmul(q, k, transpose_b=True) / scale  # (B, H*W, N)
        weights = tf.nn.softmax(scores, axis=-1)

        out = tf.matmul(weights, v)  # (B, H*W, channels)
        out = tf.reshape(out, (batch, height, width, self.channels))

        out = self.proj(out)
        return out + residual  # Skip connection

In [None]:
class TimestepEmbedding(tf.keras.layers.Layer):
    """Learned embedding of diffusion timesteps.

    ``t`` is cast to float32 and encoded sinusoidally with ``time_dim // 2``
    frequencies ``exp(-ln(10000) * i / (time_dim // 2))``. The sine half and
    the cosine half are concatenated, and one zero column is appended when
    ``time_dim`` is odd. The encoding then passes through two swish-activated
    Dense layers, each ``time_dim`` wide.

    Args:
        time_dim: Output width of the embedding.
    """

    def __init__(self, time_dim):
        super().__init__()
        self.time_dim = time_dim
        self.dense1 = layers.Dense(time_dim, activation='swish')
        self.dense2 = layers.Dense(time_dim, activation='swish')

    def call(self, t):
        n_freqs = self.time_dim // 2
        odd_padding = self.time_dim % 2

        index = tf.range(n_freqs, dtype=tf.float32)
        freqs = tf.exp(-tf.math.log(10000.0) * index / n_freqs)

        # Outer product: one row per timestep, one column per frequency.
        angles = tf.cast(t, dtype=tf.float32)[:, tf.newaxis] * freqs[tf.newaxis, :]
        sinusoid = tf.concat([tf.sin(angles), tf.cos(angles)], axis=-1)
        # Pad one zero column so the width equals time_dim when it is odd.
        sinusoid = tf.pad(sinusoid, [[0, 0], [0, odd_padding]])

        return self.dense2(self.dense1(sinusoid))

In [None]:
class ImprovedUNet(Model):
    """Three-level convolutional U-Net with a timestep-conditioned,
    text-cross-attention bottleneck.

    Each encoder block halves the spatial size (MaxPooling2D(2)) and each
    decoder block doubles it (stride-2 Conv2DTranspose). Input height and
    width should therefore be divisible by 8 so the skip concatenations line
    up. The output has 3 channels squashed by tanh.

    Args:
        time_dim: Width of the timestep embedding.
    """

    # Channel width of the deepest encoder level and of the bottleneck.
    _BOTTLENECK_CHANNELS = 256

    def __init__(self, time_dim=256):
        super().__init__()
        self.time_dim = time_dim
        self.time_embed = TimestepEmbedding(time_dim)
        # The embedding is broadcast-added to e3, so its width must match the
        # bottleneck. Project it only when it does not already match.
        if time_dim != self._BOTTLENECK_CHANNELS:
            self.time_proj = layers.Dense(self._BOTTLENECK_CHANNELS)
        else:
            self.time_proj = None

        self.enc1 = self.make_encoder_block(64)
        self.enc2 = self.make_encoder_block(128)
        self.enc3 = self.make_encoder_block(self._BOTTLENECK_CHANNELS)

        # Cross-attention needs a second `context` input, which
        # keras.Sequential cannot forward to its layers. The attention block
        # therefore lives outside the conv stack.
        self.bottleneck_attn = CrossAttentionBlock(self._BOTTLENECK_CHANNELS)
        self.bottleneck = tf.keras.Sequential([
            layers.Conv2D(self._BOTTLENECK_CHANNELS, 3, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('swish')
        ])

        self.dec3 = self.make_decoder_block(128)
        self.dec2 = self.make_decoder_block(64)
        self.dec1 = self.make_decoder_block(64)
        self.final = layers.Conv2D(3, 1, activation='tanh')

    def make_encoder_block(self, channels):
        """Two 3x3 conv-BN-swish layers followed by 2x max-pooling."""
        return tf.keras.Sequential([
            layers.Conv2D(channels, 3, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('swish'),
            layers.Conv2D(channels, 3, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('swish'),
            layers.MaxPooling2D(2)
        ])

    def make_decoder_block(self, channels):
        """A stride-2 transposed conv (2x upsample) then a 3x3 conv, each with BN and swish."""
        return tf.keras.Sequential([
            layers.Conv2DTranspose(channels, 3, strides=2, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('swish'),
            layers.Conv2D(channels, 3, padding='same'),
            layers.BatchNormalization(),
            layers.Activation('swish')
        ])

    def call(self, x, t, text_embeddings):
        """Predict noise for images ``x`` at timesteps ``t``.

        Args:
            x: Image batch, presumably (B, H, W, 3) with H and W divisible
                by 8 — TODO confirm.
            t: One timestep per sample (cast to float32 by TimestepEmbedding).
            text_embeddings: Context for the bottleneck cross-attention,
                presumably (B, N, D) — TODO confirm against the text encoder.

        Returns:
            Tensor with the spatial size of ``x`` and 3 channels in [-1, 1].
        """
        t_emb = self.time_embed(t)
        if self.time_proj is not None:
            t_emb = self.time_proj(t_emb)
        # (B, C) -> (B, 1, 1, C) so it broadcasts over the bottleneck map.
        t_emb = tf.reshape(t_emb, (-1, 1, 1, self._BOTTLENECK_CHANNELS))

        e1 = self.enc1(x)   # 1/2 resolution
        e2 = self.enc2(e1)  # 1/4 resolution
        e3 = self.enc3(e2)  # 1/8 resolution

        # Bottleneck: cross-attention over the text context, then conv stack.
        middle = self.bottleneck_attn(e3 + t_emb, context=text_embeddings)
        middle = self.bottleneck(middle)

        # Decoder with skip connections.
        h = self.dec3(tf.concat([middle, e3], axis=-1))  # 1/4 resolution
        h = self.dec2(tf.concat([h, e2], axis=-1))       # 1/2 resolution
        h = self.dec1(tf.concat([h, e1], axis=-1))       # full resolution

        return self.final(h)

In [None]:
class ImprovedStableDiffusion(Model):
    """DDPM-style text-conditioned diffusion wrapper around ImprovedUNet.

    Uses a linear beta schedule from 1e-4 to 0.02 over ``time_steps`` steps.

    - Training mode: ``call`` noises the input at one random timestep per
      sample and returns ``(pred_noise, noise)`` for an external loss.
    - Otherwise: ``call`` runs the full reverse process via ``sample``.

    Args:
        img_size: Height and width of generated images.
        time_steps: Number of diffusion steps.
    """

    def __init__(self, img_size=64, time_steps=1000):
        super().__init__()
        self.img_size = img_size
        self.time_steps = time_steps

        # Noise scheduling. The NumPy copies give exact per-step Python
        # scalars in `sample`; the TF constant serves batched gathers.
        self.beta = np.linspace(0.0001, 0.02, time_steps)
        self.alpha = 1. - self.beta
        self.alpha_bar_np = np.cumprod(self.alpha)
        self.alpha_bar = tf.constant(self.alpha_bar_np, dtype=tf.float32)

        self.unet = ImprovedUNet()

    def diffusion_schedule(self, n):
        """Return cumulative alpha_bar (float32) at timestep indices ``n``."""
        # tf.gather requires integer indices.
        return tf.gather(self.alpha_bar, tf.cast(n, tf.int32))

    def forward_diffusion(self, x, t):
        """Sample q(x_t | x_0): ``sqrt(a_bar) * x + sqrt(1 - a_bar) * noise``.

        Returns:
            ``(noisy_x, noise)``, where ``noise`` is the standard-normal draw used.
        """
        alpha_bar = self.diffusion_schedule(t)
        alpha_bar = tf.reshape(alpha_bar, (-1, 1, 1, 1))

        # Dynamic shape so this also works when the batch size is unknown.
        noise = tf.random.normal(tf.shape(x))
        noisy_x = tf.sqrt(alpha_bar) * x + tf.sqrt(1 - alpha_bar) * noise
        return noisy_x, noise

    def call(self, x, training=True, text_embeddings=None):
        """Training: return ``(pred_noise, noise)``. Inference: return samples.

        Raises:
            ValueError: if ``text_embeddings`` is not provided.
        """
        if text_embeddings is None:
            raise ValueError("text_embeddings must be provided for cross-attention.")

        b = tf.shape(x)[0]
        if not training:
            # Inference generates fresh images; a noised-input prediction
            # would be discarded, so skip it.
            return self.sample(b, text_embeddings)

        # Integer timesteps: they index the schedule via tf.gather.
        t = tf.random.uniform((b,), 0, self.time_steps, dtype=tf.int32)

        x_noisy, noise = self.forward_diffusion(x, t)
        pred_noise = self.unet(x_noisy, t, text_embeddings)
        return pred_noise, noise

    def sample(self, batch_size, text_embeddings):
        """Run the DDPM reverse process from pure Gaussian noise.

        Each step applies
        ``x = (x - (1-a)/sqrt(1-a_bar) * eps) / sqrt(a) + sqrt(beta) * z``,
        with ``z`` omitted at t == 0.
        """
        x = tf.random.normal((batch_size, self.img_size, self.img_size, 3))

        for t in range(self.time_steps - 1, -1, -1):
            # tf.fill needs a 1-D dims argument.
            t_batch = tf.fill([batch_size], t)
            pred_noise = self.unet(x, t_batch, text_embeddings)

            # Python floats mix cleanly with float32 tensors (NumPy float64
            # scalars and float64 tensors would not).
            alpha = float(self.alpha[t])
            alpha_bar = float(self.alpha_bar_np[t])
            beta = float(self.beta[t])

            eps_coef = (1.0 - alpha) / (1.0 - alpha_bar) ** 0.5
            x = (x - eps_coef * pred_noise) / alpha ** 0.5
            if t > 0:
                x = x + beta ** 0.5 * tf.random.normal(tf.shape(x))
        return x