In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import tensorflow_hub as hub

# Output vocabulary shared by targets and predictions; id 0 is reserved for padding.
out_vocab = {"<pad>": 0, "cat": 1, "dog": 2, "meow": 3, "bark": 4}
# Reverse lookup (id -> token) used when decoding predictions.
inv_out_vocab = dict(zip(out_vocab.values(), out_vocab.keys()))

def encode_output_seq(seq, max_len=1):
    """Encode a whitespace-separated token string as a fixed-length id array.

    Args:
        seq: Tokens separated by whitespace; every token must be a key of ``out_vocab``.
        max_len: Length of the returned array. Shorter sequences are right-padded with
            0 (``<pad>``); longer ones are truncated to the first ``max_len`` tokens.

    Returns:
        np.ndarray of integer ids with exactly ``max_len`` elements.

    Raises:
        KeyError: if a token is not in ``out_vocab``.
    """
    ids = [out_vocab[token] for token in seq.split()]
    # Truncate before padding: padding alone ([0] * negative == []) would leave
    # over-long sequences longer than max_len.
    ids = ids[:max_len]
    ids += [0] * (max_len - len(ids))
    return np.array(ids)

def decode(ids):
    """Join the tokens for ``ids`` with spaces, skipping padding (ids that are not > 0)."""
    words = []
    for token_id in ids:
        if token_id > 0:
            words.append(inv_out_vocab[token_id])
    return " ".join(words)

class HubEmbeddingLayer(tf.keras.layers.Layer):
    """Keras layer that wraps a TF-Hub text-embedding model.

    Wrapping ``hub.KerasLayer`` inside a custom Layer lets it be used in a functional
    ``Model`` (presumably to work around Keras 3 / TF-Hub incompatibilities - confirm for
    your TF version).

    Args:
        hub_url: TF-Hub handle; defaults to Universal Sentence Encoder v4.
        fine_tune: Whether the Hub model's weights are trainable (default True, as before).
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``, ``dtype``), which
            also makes ``from_config`` / model cloning work.
    """
    def __init__(self, hub_url="https://tfhub.dev/google/universal-sentence-encoder/4",
                 fine_tune=True, **kwargs):
        super().__init__(**kwargs)
        self.hub_url = hub_url
        self.fine_tune = fine_tune
        self.hub_layer = hub.KerasLayer(hub_url, trainable=fine_tune)

    def call(self, inputs):
        return self.hub_layer(inputs)

    def get_config(self):
        # Lets Keras rebuild the layer from its constructor arguments.
        config = super().get_config()
        config.update({"hub_url": self.hub_url, "fine_tune": self.fine_tune})
        return config


def build_text_expert():
    """Text expert: raw string -> Hub sentence embedding -> MLP -> softmax over ``out_vocab``.

    Returns:
        A ``Model`` named "TextExpert" that maps a batch of scalar strings to
        ``len(out_vocab)`` class probabilities.
    """
    text_in = layers.Input(shape=(), dtype=tf.string, name="text_in")
    # Sentence embedding (USE v4 by default, 512-d); fine-tuned during training.
    hidden = HubEmbeddingLayer()(text_in)
    for units in (512, 128):
        hidden = layers.Dense(units, activation='relu')(hidden)
    hidden = layers.Dropout(0.2)(hidden)
    probs = layers.Dense(len(out_vocab), activation='softmax')(hidden)
    return Model(inputs=text_in, outputs=probs, name="TextExpert")

def build_image_expert():
    """Image expert: 4x4x1 image -> small CNN -> softmax over ``out_vocab``.

    Returns:
        A ``Model`` named "ImageExpert".
    """
    image_in = layers.Input(shape=(4, 4, 1), name="image_in")
    hidden = image_in
    for layer in (layers.Conv2D(16, (2, 2), activation='relu'),
                  layers.Flatten(),
                  layers.Dense(64, activation='relu'),
                  layers.Dropout(0.2)):
        hidden = layer(hidden)
    probs = layers.Dense(len(out_vocab), activation='softmax')(hidden)
    return Model(inputs=image_in, outputs=probs, name="ImageExpert")

text_expert = build_text_expert()
image_expert = build_image_expert()

# Text expert: Adam on a staircase exponential-decay schedule
# (learning rate multiplied by 0.96 after every 100 optimizer steps).
initial_lr = 1e-3
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_lr, 100, 0.96, staircase=True
)
opt = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

print("Compiling models...")
text_expert.compile(
    loss='sparse_categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
image_expert.compile(
    loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Toy supervision: the text expert maps an animal word to its sound; the image
# expert maps a constant 4x4 image (all ones / all zeros) to an animal word.
text_inputs = ["cat", "dog"]
text_targets = ["meow", "bark"]

image_inputs = np.array([np.ones((4, 4)), np.zeros((4, 4))])
image_targets = ["cat", "dog"]

X_text_in = tf.constant(text_inputs, dtype=tf.string)
Y_text_out = np.array([out_vocab[label] for label in text_targets])
X_img_in = np.expand_dims(image_inputs, axis=-1)  # (2, 4, 4) -> (2, 4, 4, 1)
Y_img_out = np.array([out_vocab[label] for label in image_targets])

print("Train text expert...")
text_expert.fit(X_text_in, Y_text_out, verbose=1, epochs=100)

print("Train image expert...")
image_expert.fit(X_img_in, Y_img_out, verbose=1, epochs=100)

Epoch 1/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 816ms/step - accuracy: 0.5000 - loss: 1.5911
Epoch 2/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 18ms/step - accuracy: 1.0000 - loss: 1.4788
Epoch 3/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 1.0000 - loss: 1.3847
Epoch 4/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 1.0000 - loss: 1.2850
Epoch 5/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 1.0000 - loss: 1.2736
Epoch 6/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 1.0000 - loss: 1.0831
Epoch 7/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 17ms/step - accuracy: 1.0000 - loss: 0.9123
Epoch 8/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 1.0000 - loss: 0.9072
Epoch 9/100
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[

<keras.src.callbacks.history.History at 0x31ef68f40>

In [57]:
# -------- Inference functions --------
def predict_from_text(word):
    """Classify one string with the text expert and return the highest-scoring token.

    Side effect: prints the full probability vector, which shows how confidently
    unseen words are mapped.
    """
    probs = text_expert.predict(tf.constant([word]), verbose=0)
    print(f"Input: {word}, Probabilities: {probs}")
    best_id = int(probs[0].argmax())
    return inv_out_vocab[best_id]

def predict_from_image(img):
    """Two-stage pipeline: image expert -> predicted token, then text expert -> its sound.

    Args:
        img: A single image of shape (4, 4) or (4, 4, 1).

    Returns:
        (animal, sound): the image expert's top token, and the text expert's top token
        for that word. The image expert scores the whole ``out_vocab``, so ``animal`` is
        not guaranteed to be an animal name.

    Raises:
        ValueError: if ``img`` is neither 2-D nor 3-D.
    """
    batch = np.asarray(img)
    if batch.ndim == 2:
        batch = batch[..., None]  # add the channel axis the model expects
    if batch.ndim != 3:
        raise ValueError(f"expected a (H, W) or (H, W, 1) image, got shape {batch.shape}")
    # Step 1: predict the animal token from the image (batch of one).
    animal_probs = image_expert.predict(batch[None, ...], verbose=0)
    animal = inv_out_vocab[int(np.argmax(animal_probs[0]))]
    # Step 2: predict the sound with the text expert, using the animal label as input.
    sound = predict_from_text(animal)
    return animal, sound

# -------- Testing --------
# Dog-like and cat-like words, mostly outside the 2-word training set, to probe how the
# text expert generalizes through the sentence embedding.
probe_words = ["dog", "canine", "wolf", "labrador", "coyote", "puppy",
               "cat", "feline", "ragdoll", "lion", "tiger"]
label_width = max(len(w) for w in probe_words) + 2  # +2 for the surrounding quotes
for probe_word in probe_words:
    label = f"'{probe_word}'"
    print(f"Text input {label:<{label_width}} =>", predict_from_text(probe_word))


print("Cat-like image       =>", predict_from_image(np.ones((4,4))))
print("Dog-like image       =>", predict_from_image(np.zeros((4,4))))


Input: dog, Probabilities: [[6.7759820e-06 1.6639984e-05 3.5206945e-06 3.2574616e-04 9.9964726e-01]]
Text input 'dog'     => bark
Input: canine, Probabilities: [[1.64616809e-04 3.40726372e-04 1.06246895e-04 3.48554272e-03
  9.95902836e-01]]
Text input 'canine'  => bark
Input: wolf, Probabilities: [[0.00292729 0.00583772 0.0015113  0.18139388 0.80832976]]
Text input 'wolf'    => bark
Input: labrador, Probabilities: [[1.1411209e-03 1.9138358e-03 7.9245231e-04 1.0606636e-02 9.8554593e-01]]
Text input 'labrador'    => bark
Input: coyote, Probabilities: [[0.0075675  0.01140293 0.00480268 0.08793327 0.8882936 ]]
Text input 'coyote'  => bark
Input: puppy, Probabilities: [[1.00294594e-04 2.29717669e-04 5.90392046e-05 4.15754551e-03
  9.95453358e-01]]
Text input 'puppy'  => bark
Input: cat, Probabilities: [[1.0049924e-05 4.3694115e-05 1.9005014e-06 9.9954093e-01 4.0346046e-04]]
Text input 'cat'     => meow
Input: feline, Probabilities: [[3.45496868e-04 1.03776553e-03 1.12378846e-04 9.89282012e-

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dense, GlobalAveragePooling1D, Softmax, LSTM, Conv2D, Flatten, MaxPooling2D, LayerNormalization, Dropout, Input, Conv2DTranspose, Reshape, Add, Lambda, BatchNormalization, add, Activation
from tensorflow.keras import Model
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

In [None]:
# ---------------------------------------------
# MoE Text Generation Model Components
# ---------------------------------------------

class Expert(Model):
    """One feed-forward expert: LayerNorm -> Dense(hidden, ReLU) -> Dropout -> Dense(d_model).

    Args:
        d_model: Output width, so that expert outputs can be stacked and mixed.
        hidden_units: Width of the hidden layer (default 64, as before).
        dropout_rate: Dropout rate on the hidden activations (default 0.2, as before).
    """
    def __init__(self, d_model, hidden_units=64, dropout_rate=0.2):
        super().__init__()
        self.norm = LayerNormalization()
        self.dense1 = Dense(hidden_units, activation='relu')
        self.dropout = Dropout(dropout_rate)
        self.dense2 = Dense(d_model)

    def call(self, x, training=None):
        x = self.norm(x)
        x = self.dense1(x)
        # Pass `training` explicitly so dropout is active only while training,
        # regardless of how the caller propagates the Keras call context.
        x = self.dropout(x, training=training)
        return self.dense2(x)

class GatingNetwork(Model):
    """Router that turns a pooled representation into a probability distribution over experts.

    ``d_model`` is accepted for signature symmetry with ``Expert`` but is unused: the Dense
    layer infers its input width on first call.
    """
    def __init__(self, d_model, num_experts):
        super().__init__()
        # Attribute names are kept so that saved weights still line up.
        self.dense = Dense(num_experts)
        self.softmax = Softmax(axis=-1)

    def call(self, x):
        return self.softmax(self.dense(x))

class MoEResponseGenerator(Model):
    """Seq2seq responder: mixture-of-experts encoder + LSTM decoder.

    Encoder: token embedding -> masked mean over the input tokens -> a softmax gate mixes the
    outputs of ``num_experts`` ``Expert`` MLPs into one ``d_model`` vector, which initializes
    the decoder LSTM's (h, c) state. The decoder embeds the response ids with the same
    embedding table and predicts next-token logits.

    Args:
        vocab_size: Number of token ids; id 0 is padding and is masked.
        d_model: Embedding / expert output width.
        num_experts: Number of experts mixed by the gate.
        max_resp_len: Decoder window length (stored as an attribute; not used in ``call``).
        lstm_units: Size of the decoder LSTM.
    """
    def __init__(self, vocab_size, d_model, num_experts, max_resp_len, lstm_units=128):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, mask_zero=True)
        # Weightless; kept for backward compatibility. Pooling is done by _masked_mean.
        self.global_pool = GlobalAveragePooling1D()
        self.num_experts = num_experts
        self.experts = [Expert(d_model) for _ in range(num_experts)]
        self.gating_net = GatingNetwork(d_model, num_experts)
        self.lstm_units = lstm_units
        self.lstm = LSTM(lstm_units, return_sequences=True, return_state=True)
        self.to_h = Dense(lstm_units)
        self.to_c = Dense(lstm_units)
        self.output_layer = Dense(vocab_size)
        self.max_resp_len = max_resp_len

    @staticmethod
    def _masked_mean(embedded, token_ids):
        """Mean of ``embedded`` over non-padding positions; all-padding rows yield 0, not NaN.

        A masked GlobalAveragePooling1D divides by the number of unmasked steps, which is 0
        when every input word is out-of-vocabulary (encode_sentence maps those to 0).
        """
        mask = tf.cast(tf.not_equal(token_ids, 0), embedded.dtype)[..., tf.newaxis]
        total = tf.reduce_sum(embedded * mask, axis=1)
        count = tf.reduce_sum(mask, axis=1)
        return total / tf.maximum(count, 1.0)

    def call(self, inputs, training=False):
        """Return ``(logits, gating_probs)`` for ``inputs = (input_seq, resp_in_seq)``.

        logits: unnormalized next-token scores, one row per decoder position.
        gating_probs: softmax gate weights, one per expert.
        """
        input_seq, resp_in_seq = inputs
        enc_emb = self.embedding(input_seq)
        pooled = self._masked_mean(enc_emb, input_seq)

        gating_probs = self.gating_net(pooled)
        expert_outs = tf.stack(
            [expert(pooled, training=training) for expert in self.experts], axis=1)
        gated_rep = tf.reduce_sum(tf.expand_dims(gating_probs, 2) * expert_outs, axis=1)

        h_state = self.to_h(gated_rep)
        c_state = self.to_c(gated_rep)

        resp_emb = self.embedding(resp_in_seq)
        lstm_out, _, _ = self.lstm(
            resp_emb, initial_state=[h_state, c_state], training=training)

        logits = self.output_layer(lstm_out)
        return logits, gating_probs

In [None]:
# ---------------------------------------------
# Utility functions for text encoding and generation
# ---------------------------------------------

# Training corpus: (user input, assistant response) pairs for a toy restaurant /
# small-talk assistant. The vocabulary below is built from both sides of every pair.
# Tokenization is lower-case + whitespace split, so punctuation stays attached to
# words ("help?" and "help" are different tokens).
train_data = [
    ("hello there", "hi, how can I help"),
    ("hi", "hello, what can I do"),
    ("goodbye", "goodbye, have a nice day"),
    ("see you later", "see you soon, goodbye"),
    ("I want to order pizza", "sure, what toppings do you want"),
    ("can I get a burger", "what size burger would you like"),
    ("what is the weather", "the weather today is sunny"),
    ("is it raining", "no rain expected today"),
    ("hey, I want some pasta", "what kind of pasta would you prefer"),
    ("do you have vegetarian options?", "yes, we have several vegetarian dishes"),
    ("good morning", "good morning, how may I assist you?"),
    ("bye", "take care, see you later"),
    ("will it be hot today?", "expect warm temperatures all day"),
    ("can I order a salad?", "what dressing would you like on your salad?"),
    ("thanks, goodbye", "you're welcome, goodbye!"),
    ("tell me the forecast", "the forecast shows clear skies"),
    ("what's your name?", "i am your assistant, here to help"),
    ("can I have a coffee?", "sure, would you like it black or with milk?"),
    ("thank you for the help", "happy to assist you anytime"),
    ("are you open today?", "yes, we are open from 9 am to 9 pm"),
    ("could you help me with my order", "of course, what would you like to order"),
    ("are there any gluten free options", "yes, we have several gluten free dishes available"),
    ("what are today's specials", "today's special is grilled salmon with vegetables"),
    ("how late are you open", "we are open until 10 pm tonight"),
    ("can you recommend a dessert", "our chocolate lava cake is very popular"),
    ("I need to change my order", "sure, what changes would you like to make"),
    ("do you deliver", "yes, we deliver within a 5 mile radius"),
    ("what payment methods do you accept", "we accept cash, credit cards, and mobile payments"),
    ("is there a parking facility", "yes, free parking is available behind the restaurant"),
    ("thank you very much", "you're welcome, happy to help"),
    ("I have a food allergy", "please let us know your allergy, and we will accommodate"),
    ("can I book a table", "yes, for how many people and what time"),
    ("what's your restaurant address", "we are located at 123 Main Street"),
    ("do you have vegan meals", "yes, we offer delicious vegan options"),
    ("can I get nutritional information", "nutritional info is available on our website"),
    ("how long is the wait time", "usually about 15 minutes during peak hours"),
    ("do you have a kids menu", "yes, we have a special menu for children"),
    ("can I cancel my order", "please provide your order number to cancel"),
    ("what are your opening hours", "we are open from 9 am to 10 pm daily"),
    ("is takeout available", "yes, you can order takeout anytime during opening hours"),        
]

# Vocabulary: every lower-cased, whitespace-split token from both sides of each pair.
# Id 0 is reserved for padding / unknown words, so real words are numbered from 1.
all_texts = [f"{user_text} {reply_text}" for user_text, reply_text in train_data]
all_words = {token for text in all_texts for token in text.lower().split()}
word2idx = {token: position for position, token in enumerate(sorted(all_words), start=1)}
idx2word = np.array(['<pad>', *sorted(all_words)])
vocab_size = len(idx2word)
max_input_len = 6  # prompts are truncated/padded to this many tokens
max_resp_len = 8   # decoder window length

def encode_sentence(sent, max_len):
    """Map a sentence to exactly ``max_len`` vocabulary ids.

    Tokens are lower-cased and split on whitespace; unknown tokens map to 0 (the padding
    id). The id list is truncated, or right-padded with 0, to ``max_len``.
    """
    ids = [word2idx.get(token, 0) for token in sent.lower().split()][:max_len]
    return ids + [0] * (max_len - len(ids))

def sample_from_logits(logits, temperature=1.0, top_k=5):
    """Draw one token id per row of ``logits`` using temperature scaling and top-k filtering.

    Args:
        logits: (batch, vocab) unnormalized scores (tensor or array).
        temperature: Must be > 0. Values < 1 sharpen the distribution; values > 1 flatten it.
        top_k: Keep only the ``top_k`` largest logits per row (values tied with the k-th
            largest are kept too); ``top_k <= 0`` disables filtering. Values larger than
            the vocabulary are clamped instead of raising.

    Returns:
        np.ndarray of shape (batch,) with the sampled ids.

    Raises:
        ValueError: if ``temperature <= 0`` (dividing by it would give inf/NaN logits).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    logits = tf.convert_to_tensor(logits) / temperature
    if top_k > 0:
        vocab = logits.shape[-1]
        k = int(top_k) if vocab is None else min(int(top_k), vocab)
        kth_largest = tf.math.top_k(logits, k=k).values[:, -1:]
        logits = tf.where(
            logits < kth_largest,
            tf.fill(tf.shape(logits), tf.cast(float('-inf'), logits.dtype)),
            logits,
        )
    # tf.random.categorical takes unnormalized log-probabilities, so the filtered logits
    # are used directly; softmax followed by log only adds rounding error.
    next_token = tf.random.categorical(logits, num_samples=1)
    return tf.squeeze(next_token, axis=-1).numpy()

def generate_response(model, input_text, max_len=30, temperature=1.0, top_k=5):
    """Sample a response token by token from ``model`` for ``input_text``.

    The decoder input is a window of ``max_resp_len`` token ids that starts as all padding.
    Sampled tokens fill the window left to right. Once it is full, it slides left by one
    per step, so the newest token always sits in the last slot. Generation stops after
    ``max_len`` tokens, or when id 0 is sampled (padding, which training uses as
    end-of-sequence).

    Args:
        model: Called as ``model((input_seq, response_seq), training=False)``; must return
            ``(logits, gating_probs)`` like ``MoEResponseGenerator``.
        input_text: Prompt; encoded with ``encode_sentence`` to ``max_input_len`` ids.
        max_len: Maximum number of tokens to generate.
        temperature, top_k: Forwarded to ``sample_from_logits``.

    Returns:
        ``(response, top_expert, gating_probs)``: the generated words joined by spaces, the
        index of the most-weighted expert, and the gate distribution as a NumPy array.
        If ``max_len <= 0`` the model is never run and ``(response, -1, None)`` is returned.

    NOTE(review): during training the decoder input never starts with padding (there is
    no start token), so the first token is predicted from an all-padding window and is
    likely weakly conditioned on the prompt. Consider adding a <start> token and retraining.
    """
    input_seq = np.array([encode_sentence(input_text, max_input_len)])
    response_seq = np.zeros((1, max_resp_len), dtype=np.int32)
    generated_tokens = []
    gating_probs = None

    for i in range(max_len):
        logits, gating_probs = model((input_seq, response_seq), training=False)
        # Read the decoder output at the newest real token: slot i-1 while the window is
        # filling, the last slot once it is full (slot 0 when nothing has been fed yet).
        read_pos = max(min(i, max_resp_len) - 1, 0)
        next_token = sample_from_logits(
            logits[:, read_pos, :], temperature=temperature, top_k=top_k)[0]
        if next_token == 0:
            break
        generated_tokens.append(idx2word[next_token])
        if i < max_resp_len:
            # A free slot is still available (the old `i + 1 < max_resp_len` skipped the last one).
            response_seq[0, i] = next_token
        else:
            response_seq = np.roll(response_seq, -1, axis=1)
            response_seq[0, -1] = next_token

    response = " ".join(generated_tokens)
    if gating_probs is None:
        return response, -1, None
    gating = gating_probs[0].numpy()
    return response, int(np.argmax(gating)), gating

# Teacher-forcing arrays: the decoder input is the encoded response; the target is the
# same response shifted left by one token, with a trailing 0 that acts as end-of-sequence.
X_input = np.array([encode_sentence(prompt, max_input_len) for prompt, _ in train_data])
X_resp_in = np.array([encode_sentence(reply, max_resp_len) for _, reply in train_data])
X_resp_out = np.array([encode_sentence(reply, max_resp_len)[1:] + [0] for _, reply in train_data])

In [None]:
# ---------------------------------------------
# MNIST Digit Classifier Components
# ---------------------------------------------

class DigitClassifier(Model):
    """Small CNN digit classifier: two conv + max-pool stages, a 64-unit hidden layer,
    and 10 output logits (no softmax; pair it with ``from_logits=True``)."""
    def __init__(self):
        super().__init__()
        self.conv1 = Conv2D(32, (3, 3), activation='relu')
        self.pool1 = MaxPooling2D((2, 2))
        self.conv2 = Conv2D(64, (3, 3), activation='relu')
        self.pool2 = MaxPooling2D((2, 2))
        self.flatten = Flatten()
        self.dense = Dense(64, activation='relu')
        self.out = Dense(10)  # digits 0-9 logits

    def call(self, x):
        for layer in (self.conv1, self.pool1, self.conv2, self.pool2,
                      self.flatten, self.dense, self.out):
            x = layer(x)
        return x

def prepare_mnist_data():
    """Load MNIST, scale pixels to [0, 1], add a channel axis, and one-hot encode the labels.

    Returns:
        ((x_train, y_train), (x_test, y_test)) with images of shape (N, 28, 28, 1)
        (float32) and labels of shape (N, 10).
    """
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    def _to_model_images(images):
        return np.expand_dims(images.astype('float32') / 255.0, -1)

    def _to_one_hot(labels):
        return tf.keras.utils.to_categorical(labels, 10)

    train_split = (_to_model_images(x_train), _to_one_hot(y_train))
    test_split = (_to_model_images(x_test), _to_one_hot(y_test))
    return train_split, test_split

In [None]:
# ---------------------------------------------
# Training setup and training loops
# ---------------------------------------------

# 1) MNIST digit classifier: 3 epochs, 10% of the training images held out for validation.
#    DigitClassifier outputs raw logits (hence from_logits=True), and prepare_mnist_data
#    one-hot encodes the labels (hence CategoricalCrossentropy, not the sparse variant).
digit_classifier = DigitClassifier()
(x_train_mnist, y_train_mnist), (x_test_mnist, y_test_mnist) = prepare_mnist_data()
digit_classifier.compile(optimizer='adam', loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])
print("Training MNIST digit classifier...")
digit_classifier.fit(x_train_mnist, y_train_mnist, epochs=3, batch_size=128, validation_split=0.1)

# 2) MoE text generator hyper-parameters.
d_model = 32        # embedding width, also the output width of each Expert
num_experts = 5
lstm_units = 128    # decoder LSTM size
max_epochs = 300
batch_size = 2      # tiny batches; the corpus has only len(train_data) pairs

moe_model = MoEResponseGenerator(vocab_size, d_model, num_experts, max_resp_len, lstm_units)
# output_layer has no activation, so the loss consumes raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

# ((input ids, decoder input ids), decoder target ids); tf.data reshuffles every epoch by default.
train_dataset = tf.data.Dataset.from_tensor_slices(((X_input, X_resp_in), X_resp_out)).shuffle(50).batch(batch_size)

@tf.function
def train_step(inputs, labels):
    """Run one teacher-forced optimization step on ``moe_model`` and return the batch loss.

    The gating probabilities the model returns are ignored, so no load-balancing term
    is added to the loss.
    """
    with tf.GradientTape() as tape:
        logits = moe_model(inputs, training=True)[0]
        loss = loss_fn(labels, logits)
    params = moe_model.trainable_variables
    gradients = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(gradients, params))
    return loss

print("Training MoE text generation model...")
for epoch in range(max_epochs):
    total_loss = 0.0
    for batch in train_dataset:
        # batch == ((input ids, decoder input ids), decoder target ids)
        loss = train_step(batch[0], batch[1])
        total_loss += loss.numpy()
    # Report the mean per-batch loss every 50 epochs.
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch + 1}/{max_epochs}: Loss = {total_loss / len(train_dataset):.4f}")

In [None]:
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input

# Load and normalize MNIST data: float32 pixels in [0, 1] with a trailing channel axis.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train.astype("float32") / 255.0, axis=-1)
x_test = np.expand_dims(x_test.astype("float32") / 255.0, axis=-1)


def _make_increment_targets(images, labels, rng=np.random):
    """For each image labelled d, pick a random image from the same split labelled (d + 1) % 10.

    Args:
        images: Array of images, aligned with ``labels``.
        labels: Integer digit labels 0-9.
        rng: Object with a ``choice`` method (the ``np.random`` module or a Generator).

    Returns:
        Array like ``images`` holding the target ("next digit") image for each sample.

    Indices are grouped per digit once, which is O(n), instead of rescanning the whole
    label array for every sample (the previous per-sample np.where was O(n^2)).
    """
    targets = np.zeros_like(images)
    target_labels = (labels + 1) % 10
    for digit in range(10):
        pool = np.flatnonzero(labels == digit)           # images that show `digit`
        needing = np.flatnonzero(target_labels == digit)  # samples whose target is `digit`
        targets[needing] = images[rng.choice(pool, size=needing.size)]
    return targets


# Prepare target images for increment digits
train_target_images = _make_increment_targets(x_train, y_train)
test_target_images = _make_increment_targets(x_test, y_test)

# Residual block
def residual_block(x, filters, kernel_size=3):
    """Conv-BN-ReLU-Conv-BN residual block, with a ReLU after the skip addition.

    Args:
        x: 4-D (batch, H, W, C) Keras tensor.
        filters: Number of output channels of both convolutions.
        kernel_size: Convolution kernel size; 'same' padding keeps H and W unchanged.

    Returns:
        Tensor of shape (batch, H, W, filters).

    If ``C != filters``, the identity shortcut cannot be added, so it is replaced by a 1x1
    convolution projection. When they match (as in the model below), the shortcut remains
    a plain identity and the block is unchanged.
    """
    shortcut = x
    if x.shape[-1] != filters:
        shortcut = Conv2D(filters, 1, padding="same")(x)
    y = Conv2D(filters, kernel_size, padding="same")(x)
    y = BatchNormalization()(y)
    y = Activation("relu")(y)
    y = Conv2D(filters, kernel_size, padding="same")(y)
    y = BatchNormalization()(y)
    y = add([shortcut, y])
    return Activation("relu")(y)

# Build model with residual blocks
def build_improved_increment_model():
    """Encoder-decoder that maps a 28x28x1 image to a 28x28x1 image with values in [0, 1].

    Encoder: two stride-2 conv stages (64 and 128 filters), each followed by BN and a
    residual block -> 7x7x128. Bottleneck: Dense(256) + dropout, then Dense back to
    7*7*128. Decoder: two stride-2 transposed convs back to 28x28, then a sigmoid conv.
    """
    inputs = Input(shape=(28, 28, 1))
    x = inputs
    for filters in (64, 128):
        x = Conv2D(filters, 3, strides=2, padding="same", activation="relu")(x)
        x = BatchNormalization()(x)
        x = residual_block(x, filters)

    x = Flatten()(x)
    x = Dense(256, activation="relu")(x)
    x = Dropout(0.3)(x)
    x = Dense(7 * 7 * 128, activation="relu")(x)
    x = Reshape((7, 7, 128))(x)

    for filters in (128, 64):
        x = Conv2DTranspose(filters, 3, strides=2, padding="same", activation="relu")(x)
        x = BatchNormalization()(x)

    outputs = Conv2D(1, 3, padding="same", activation="sigmoid")(x)
    return Model(inputs, outputs)

increment_model = build_improved_increment_model()

# Frozen ImageNet VGG16, used only as a feature extractor for the perceptual loss.
# 32x32x3 is the smallest input VGG16 accepts without its classifier head.
vgg = VGG16(include_top=False, weights='imagenet', input_shape=(32, 32, 3))
vgg.trainable = False
feature_extractor = Model(inputs=vgg.input, outputs=vgg.get_layer("block3_conv3").output)

# Preprocessing function to resize and replicate grayscale channel for VGG16
def preprocess_for_vgg(x):
    """Turn a batch of 1-channel images in [0, 1] into VGG16-ready input.

    Resizes to 32x32, replicates the gray channel to RGB, rescales to [0, 255], and then
    applies VGG16's ``preprocess_input``.
    """
    rgb = tf.image.grayscale_to_rgb(tf.image.resize(x, [32, 32]))
    return preprocess_input(rgb * 255.0)

# Perceptual loss: distance between VGG16 features of target and prediction
def perceptual_loss(y_true, y_pred):
    """Mean squared difference between the VGG16 ``block3_conv3`` features of both images."""
    features_true = feature_extractor(preprocess_for_vgg(y_true))
    features_pred = feature_extractor(preprocess_for_vgg(y_pred))
    squared_gap = tf.square(features_true - features_pred)
    return tf.reduce_mean(squared_gap)

def combined_loss(y_true, y_pred):
    """Pixel-wise MSE plus 0.1 x the VGG perceptual loss."""
    pixel_mse = tf.reduce_mean(tf.math.squared_difference(y_true, y_pred))
    return pixel_mse + 0.1 * perceptual_loss(y_true, y_pred)

# Adam with the pixel-MSE + perceptual combined loss.
# NOTE(review): each target is a *random* image of the next digit, so the loss-minimizing
# output is roughly an average of those images; expect blurry generations.
increment_model.compile(optimizer='adam', loss=combined_loss)

# Train with learning rate scheduler callback
# ReduceLROnPlateau monitors 'val_loss' by default: halve the LR after 3 epochs without improvement.
lr_scheduler = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=3, verbose=1)

increment_model.fit(
    x_train,
    train_target_images,
    epochs=30,
    batch_size=128,
    validation_data=(x_test, test_target_images),
    callbacks=[lr_scheduler]
)

# Inference helper
def get_incremented_digit_image(input_image):
    """Generate the "next digit" image for one MNIST-style image with ``increment_model``.

    Args:
        input_image: (28, 28) or (28, 28, 1) array, either raw integer pixels in [0, 255]
            or floats already scaled to [0, 1] (like ``x_train`` / ``x_test``).

    Returns:
        The model's (28, 28, 1) output for that image.
    """
    raw = np.asarray(input_image)
    img = raw.astype("float32")
    # Rescale only raw 0-255 pixels. x_train/x_test are already divided by 255, and dividing
    # them again would feed the model an almost-black image.
    if np.issubdtype(raw.dtype, np.integer) or img.max() > 1.0:
        img = img / 255.0
    if img.ndim == 2:
        img = np.expand_dims(img, axis=-1)
    img = np.expand_dims(img, axis=0)
    generated_img = increment_model.predict(img)
    return generated_img[0]

In [None]:
# Visualization for one random test sample:
# input digit | model output | a reference test image of the target (incremented) digit.
idx = np.random.randint(len(x_test))
input_img = x_test[idx]
original_digit = y_test[idx]
inc_digit = (original_digit + 1) % 10
generated_img = get_incremented_digit_image(input_img)
# Any test image of the target digit works as a reference; use the first one.
expected_img = x_test[np.flatnonzero(y_test == inc_digit)[0]]

fig, axes = plt.subplots(1, 3, figsize=(12, 4))
panels = [
    (f"Input: {original_digit}", input_img),
    # The generated image is not classified; the label shown is the *target* digit.
    (f"Generated (target {inc_digit})", generated_img),
    (f"Expected: {inc_digit}", expected_img),
]
for ax, (title, image) in zip(axes, panels):
    ax.imshow(image[:, :, 0], cmap='gray')
    ax.set_title(title)
    ax.axis('off')

plt.show()


In [None]:
def get_incremented_digit_image_generative(input_image):
    """Classify a digit image and generate an image of the next digit.

    Args:
        input_image: (28, 28) or (28, 28, 1) array, raw integer pixels in [0, 255] or
            floats already scaled to [0, 1] (like ``x_test_mnist``).

    Returns:
        (pred_digit, inc_digit, inc_image): the digit predicted by ``digit_classifier``,
        ``(pred_digit + 1) % 10``, and the (28, 28) image produced by ``increment_model``.
        The generated image is not checked to actually show ``inc_digit``.
    """
    raw = np.asarray(input_image)
    img = raw.astype("float32")
    if np.issubdtype(raw.dtype, np.integer) or img.max() > 1.0:
        img = img / 255.0  # both models were trained on pixels scaled to [0, 1]
    if img.ndim == 2:
        img = img[..., None]
    batch = img[None, ...]  # (1, 28, 28, 1)
    pred_digit = int(np.argmax(digit_classifier.predict(batch, verbose=0)[0]))
    inc_digit = (pred_digit + 1) % 10
    inc_image = increment_model.predict(batch, verbose=0)[0, :, :, 0]
    return pred_digit, inc_digit, inc_image


def multimodal_handler(input_text=None, input_image=None):
    """Route a text prompt and/or a digit image to the matching model and report the results.

    Args:
        input_text: If truthy, a response is sampled from ``moe_model`` with
            ``generate_response`` and printed together with the gate's expert choice.
        input_image: If not None, an MNIST-style image. The classified digit and the
            generated "next digit" image are printed and plotted side by side.
    """
    if input_text:
        response, expert, gating_probs = generate_response(moe_model, input_text)
        print(f"Text Input: {input_text}")
        print(f"Response: {response}")
        print(f"Top expert used: {expert}")
        print(f"Gating probs: {np.round(gating_probs, 3)}")
    if input_image is not None:
        pred_digit, inc_digit, inc_image = get_incremented_digit_image_generative(input_image)
        print(f"Image input digit: {pred_digit}, incremented output digit: {inc_digit}")
        fig, (ax_in, ax_out) = plt.subplots(1, 2)
        # imshow documents (M, N), (M, N, 3) or (M, N, 4) arrays; squeeze (28, 28, 1) -> (28, 28).
        ax_in.imshow(np.squeeze(input_image), cmap='gray')
        ax_in.set_title(f"Input: {pred_digit}")
        ax_in.axis('off')
        ax_out.imshow(np.squeeze(inc_image), cmap='gray')
        ax_out.set_title(f"Output: {inc_digit}")
        ax_out.axis('off')
        plt.show()


In [None]:
# ---------------------------------------------
# Demo code to test multimodal system
# ---------------------------------------------

# 1) Text only: routed to the MoE response generator.
multimodal_handler(input_text="hello, how are you?")

# 2) Image only: MNIST test image #15 is classified and "incremented".
multimodal_handler(input_image=x_test_mnist[15])

# 3) Both modalities in one call (text and image are handled independently).
multimodal_handler("What number is this?", x_test_mnist[42])

In [None]:
# Prompts grouped by intent; a few appear verbatim in train_data, most do not.
greeting_examples = [
    "hello",
    "hi there",
    "good morning",
    "hey, how are you?",
    "what's up?"
]

goodbye_examples = [
    "goodbye",
    "see you later",
    "bye for now",
    "talk to you soon",
    "have a great day"
]

order_food_examples = [
    "I want to order a pizza",
    "can I get a burger please?",
    "what sides do you have?",
    "I'd like a vegetarian pasta",
    "do you have gluten free options?"
]

weather_examples = [
    "what's the weather today?",
    "will it rain tomorrow?",
    "is it sunny outside?",
    "what's the forecast for this week?",
    "do I need an umbrella today?"
]

miscellaneous_examples = [
    "what's your name?",
    "can you tell me a joke?",
    "how do I say goodbye politely?",
    "are you open on weekends?",
    "what time do you close?"
]

inference_prompts = [
    *greeting_examples,
    *goodbye_examples,
    *order_food_examples,
    *weather_examples,
    *miscellaneous_examples,
]
for input_text in inference_prompts:
    multimodal_handler(input_text)  # prints input, response, top expert and gate weights
    print("---------------------------------")

# Old Code (earlier standalone MoE chatbot; re-running it overwrites the globals defined above)

In [None]:
# Sample data (input, response)
# NOTE: same pairs as the train_data defined in the earlier cell; running this cell
# rebinds `train_data` for the "old code" path below.
train_data = [
    ("hello there", "hi, how can I help"),
    ("hi", "hello, what can I do"),
    ("goodbye", "goodbye, have a nice day"),
    ("see you later", "see you soon, goodbye"),
    ("I want to order pizza", "sure, what toppings do you want"),
    ("can I get a burger", "what size burger would you like"),
    ("what is the weather", "the weather today is sunny"),
    ("is it raining", "no rain expected today"),
    ("hey, I want some pasta", "what kind of pasta would you prefer"),
    ("do you have vegetarian options?", "yes, we have several vegetarian dishes"),
    ("good morning", "good morning, how may I assist you?"),
    ("bye", "take care, see you later"),
    ("will it be hot today?", "expect warm temperatures all day"),
    ("can I order a salad?", "what dressing would you like on your salad?"),
    ("thanks, goodbye", "you're welcome, goodbye!"),
    ("tell me the forecast", "the forecast shows clear skies"),
    ("what's your name?", "i am your assistant, here to help"),
    ("can I have a coffee?", "sure, would you like it black or with milk?"),
    ("thank you for the help", "happy to assist you anytime"),
    ("are you open today?", "yes, we are open from 9 am to 9 pm"),
    ("could you help me with my order", "of course, what would you like to order"),
    ("are there any gluten free options", "yes, we have several gluten free dishes available"),
    ("what are today's specials", "today's special is grilled salmon with vegetables"),
    ("how late are you open", "we are open until 10 pm tonight"),
    ("can you recommend a dessert", "our chocolate lava cake is very popular"),
    ("I need to change my order", "sure, what changes would you like to make"),
    ("do you deliver", "yes, we deliver within a 5 mile radius"),
    ("what payment methods do you accept", "we accept cash, credit cards, and mobile payments"),
    ("is there a parking facility", "yes, free parking is available behind the restaurant"),
    ("thank you very much", "you're welcome, happy to help"),
    ("I have a food allergy", "please let us know your allergy, and we will accommodate"),
    ("can I book a table", "yes, for how many people and what time"),
    ("what's your restaurant address", "we are located at 123 Main Street"),
    ("do you have vegan meals", "yes, we offer delicious vegan options"),
    ("can I get nutritional information", "nutritional info is available on our website"),
    ("how long is the wait time", "usually about 15 minutes during peak hours"),
    ("do you have a kids menu", "yes, we have a special menu for children"),
    ("can I cancel my order", "please provide your order number to cancel"),
    ("what are your opening hours", "we are open from 9 am to 10 pm daily"),
    ("is takeout available", "yes, you can order takeout anytime during opening hours"),        
]

# Build vocabulary: every lower-cased, whitespace-split token from both sides of each pair.
# Id 0 is reserved for padding / unknown words, so real words are numbered from 1.
all_texts = [f"{user_text} {reply_text}" for user_text, reply_text in train_data]
all_words = {token for text in all_texts for token in text.lower().split()}
word2idx = {token: position for position, token in enumerate(sorted(all_words), start=1)}
idx2word = np.array(['<pad>', *sorted(all_words)])
vocab_size = len(idx2word)

max_input_len = 6  # prompts are truncated/padded to this many tokens
max_resp_len = 8   # decoder window length

def encode_sentence(sent, max_len):
    """Map a sentence to exactly ``max_len`` vocabulary ids.

    Tokens are lower-cased and split on whitespace; unknown tokens map to 0 (the padding
    id). The id list is truncated, or right-padded with 0, to ``max_len``.
    """
    ids = [word2idx.get(token, 0) for token in sent.lower().split()][:max_len]
    return ids + [0] * (max_len - len(ids))

# Teacher forcing: the decoder sees the response and is trained to predict it shifted
# left by one token, with a trailing 0 acting as end-of-sequence.
X_input = np.array([encode_sentence(prompt, max_input_len) for prompt, _ in train_data])
X_resp_in = np.array([encode_sentence(reply, max_resp_len) for _, reply in train_data])
X_resp_out = np.array([encode_sentence(reply, max_resp_len)[1:] + [0] for _, reply in train_data])

class Expert(Model):
    """One feed-forward expert: LayerNorm -> Dense(hidden, ReLU) -> Dropout -> Dense(d_model).

    Args:
        d_model: Output width, so that expert outputs can be stacked and mixed.
        hidden_units: Width of the hidden layer (default 64, as before).
        dropout_rate: Dropout rate on the hidden activations (default 0.2, as before).
    """
    def __init__(self, d_model, hidden_units=64, dropout_rate=0.2):
        super().__init__()
        self.norm = tf.keras.layers.LayerNormalization()
        self.dense1 = Dense(hidden_units, activation='relu')
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dense2 = Dense(d_model)

    def call(self, x, training=None):
        x = self.norm(x)
        x = self.dense1(x)
        # Pass `training` explicitly so dropout is active only while training,
        # regardless of how the caller propagates the Keras call context.
        x = self.dropout(x, training=training)
        return self.dense2(x)

class GatingNetwork(Model):
    """Router that turns a pooled representation into a probability distribution over experts.

    ``d_model`` is accepted for signature symmetry with ``Expert`` but is unused: the Dense
    layer infers its input width on first call.
    """
    def __init__(self, d_model, num_experts):
        super().__init__()
        # Attribute names are kept so that saved weights still line up.
        self.dense = Dense(num_experts)
        self.softmax = Softmax(axis=-1)

    def call(self, x):
        return self.softmax(self.dense(x))

class MoEResponseGenerator(Model):
    """Seq2seq responder: mixture-of-experts encoder + LSTM decoder.

    Encoder: token embedding -> masked mean over the input tokens -> a softmax gate mixes the
    outputs of ``num_experts`` ``Expert`` MLPs into one ``d_model`` vector, which initializes
    the decoder LSTM's (h, c) state. The decoder embeds the response ids with the same
    embedding table and predicts next-token logits.

    Args:
        vocab_size: Number of token ids; id 0 is padding and is masked.
        d_model: Embedding / expert output width.
        num_experts: Number of experts mixed by the gate.
        max_resp_len: Decoder window length (stored as an attribute; not used in ``call``).
        lstm_units: Size of the decoder LSTM.
    """
    def __init__(self, vocab_size, d_model, num_experts, max_resp_len, lstm_units=128):
        super().__init__()
        self.embedding = Embedding(vocab_size, d_model, mask_zero=True)
        # Weightless; kept for backward compatibility. Pooling is done by _masked_mean.
        self.global_pool = GlobalAveragePooling1D()
        self.num_experts = num_experts
        self.experts = [Expert(d_model) for _ in range(num_experts)]
        self.gating_net = GatingNetwork(d_model, num_experts)
        self.lstm_units = lstm_units
        self.lstm = LSTM(lstm_units, return_sequences=True, return_state=True)
        self.to_h = Dense(lstm_units)
        self.to_c = Dense(lstm_units)
        self.output_layer = Dense(vocab_size)
        self.max_resp_len = max_resp_len

    @staticmethod
    def _masked_mean(embedded, token_ids):
        """Mean of ``embedded`` over non-padding positions; all-padding rows yield 0, not NaN.

        A masked GlobalAveragePooling1D divides by the number of unmasked steps, which is 0
        when every input word is out-of-vocabulary (encode_sentence maps those to 0).
        """
        mask = tf.cast(tf.not_equal(token_ids, 0), embedded.dtype)[..., tf.newaxis]
        total = tf.reduce_sum(embedded * mask, axis=1)
        count = tf.reduce_sum(mask, axis=1)
        return total / tf.maximum(count, 1.0)

    def call(self, inputs, training=False):
        """Return ``(logits, gating_probs)`` for ``inputs = (input_seq, resp_in_seq)``.

        logits: unnormalized next-token scores, one row per decoder position.
        gating_probs: softmax gate weights, one per expert.
        """
        input_seq, resp_in_seq = inputs
        enc_emb = self.embedding(input_seq)
        pooled = self._masked_mean(enc_emb, input_seq)

        gating_probs = self.gating_net(pooled)
        expert_outs = tf.stack(
            [expert(pooled, training=training) for expert in self.experts], axis=1)
        gated_rep = tf.reduce_sum(tf.expand_dims(gating_probs, 2) * expert_outs, axis=1)

        h_state = self.to_h(gated_rep)
        c_state = self.to_c(gated_rep)

        resp_emb = self.embedding(resp_in_seq)
        lstm_out, _, _ = self.lstm(
            resp_emb, initial_state=[h_state, c_state], training=training)

        logits = self.output_layer(lstm_out)
        return logits, gating_probs

# Hyperparams and dataset
# NOTE: these rebind globals set in the earlier cells (d_model, num_experts, batch_size,
# loss_fn, optimizer); this older version uses 4 experts instead of 5.
d_model = 32
num_experts = 4
batch_size = 2
epochs = 250

# Shuffle buffer of 20 is smaller than the 40-pair corpus, so shuffling is only partial.
dataset = tf.data.Dataset.from_tensor_slices(((X_input, X_resp_in), X_resp_out)).shuffle(20).batch(batch_size)

model = MoEResponseGenerator(vocab_size, d_model, num_experts, max_resp_len)
# output_layer has no activation, so the loss consumes raw logits.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(inputs, labels):
    """Run one teacher-forced optimization step on ``model`` and return the batch loss.

    The gating probabilities the model returns are ignored, so no load-balancing term
    is added to the loss.
    """
    with tf.GradientTape() as tape:
        logits = model(inputs, training=True)[0]
        loss = loss_fn(labels, logits)
    params = model.trainable_variables
    gradients = tape.gradient(loss, params)
    optimizer.apply_gradients(zip(gradients, params))
    return loss

# Custom training loop; every 50 epochs prints the mean per-batch loss.
for epoch in range(epochs):
    total_loss = 0
    for in_batch, out_batch in dataset:
        # in_batch == (input ids, decoder input ids); out_batch == decoder target ids
        loss = train_step(in_batch, out_batch)
        total_loss += loss.numpy()
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}/{epochs} Loss: {total_loss/len(dataset):.4f}")

In [None]:
def sample_from_logits(logits, temperature=1.0, top_k=5):
    """Draw one token id per row of ``logits`` using temperature scaling and top-k filtering.

    Args:
        logits: (batch, vocab) unnormalized scores (tensor or array).
        temperature: Must be > 0. Values < 1 sharpen the distribution; values > 1 flatten it.
        top_k: Keep only the ``top_k`` largest logits per row (values tied with the k-th
            largest are kept too); ``top_k <= 0`` disables filtering. Values larger than
            the vocabulary are clamped instead of raising.

    Returns:
        np.ndarray of shape (batch,) with the sampled ids.

    Raises:
        ValueError: if ``temperature <= 0`` (dividing by it would give inf/NaN logits).
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")
    logits = tf.convert_to_tensor(logits) / temperature
    if top_k > 0:
        vocab = logits.shape[-1]
        k = int(top_k) if vocab is None else min(int(top_k), vocab)
        kth_largest = tf.math.top_k(logits, k=k).values[:, -1:]
        logits = tf.where(
            logits < kth_largest,
            tf.fill(tf.shape(logits), tf.cast(float('-inf'), logits.dtype)),
            logits,
        )
    # tf.random.categorical takes unnormalized log-probabilities, so the filtered logits
    # are used directly; softmax followed by log only adds rounding error.
    next_token = tf.random.categorical(logits, num_samples=1)
    return tf.squeeze(next_token, axis=-1).numpy()

def generate_response(model, input_text, max_len=30, temperature=1.0, top_k=5):
    """Autoregressively sample a response to ``input_text`` from the MoE model.

    Args:
        model: called as ``model((input_seq, response_seq), training=False)``
            and expected to return ``(logits, gating_probs)``.
        input_text: prompt string, encoded with ``encode_sentence``.
        max_len: maximum number of tokens to generate. It may exceed
            ``max_resp_len``; the decoder window then slides left one slot
            per step.
        temperature, top_k: forwarded to ``sample_from_logits``.

    Returns:
        ``(response_text, top_expert, gating_distribution)``. ``top_expert`` is
        the argmax of the last gating distribution. When no forward pass ran
        (``max_len <= 0``), ``top_expert`` is -1 and ``gating_distribution``
        is None.
    """
    input_seq = np.array([encode_sentence(input_text, max_input_len)])
    response_seq = np.zeros((1, max_resp_len), dtype=np.int32)
    generated_tokens = []
    gating_probs = None

    for step in range(max_len):
        if step < max_resp_len:
            pos = step
        else:
            # Window is full: shift everything one slot left and clear the
            # last slot, so the newest token always lands at the end with no
            # zero gap. (Previously slot max_resp_len-1 was skipped, and later
            # steps read logits at i % max_resp_len, i.e. near the start.)
            response_seq = np.roll(response_seq, -1, axis=1)
            response_seq[0, -1] = 0
            pos = max_resp_len - 1

        logits, gating_probs = model((input_seq, response_seq), training=False)
        # NOTE(review): this keeps the original convention of reading the
        # logits at the slot that is about to be filled. If X_resp_in is a
        # right-shifted copy of X_resp_out, the sampled token may belong in
        # pos + 1 instead. Confirm against the data-prep cell.
        next_token = sample_from_logits(
            logits[:, pos, :], temperature=temperature, top_k=top_k
        )[0]

        if next_token == 0:  # end on padding token
            break

        generated_tokens.append(idx2word[next_token])
        response_seq[0, pos] = next_token

    response_text = " ".join(generated_tokens)
    if gating_probs is None:
        # No forward pass happened; guard both outputs, not just top_expert.
        return response_text, -1, None
    gating_distribution = np.asarray(gating_probs[0])
    return response_text, int(np.argmax(gating_distribution)), gating_distribution


# Demo: sample a reply for each prompt and report which expert the gate favoured.
inference_prompts = [
    "hello there",
    "I want to order pizza",
    "goodbye",
    "can you help me order food",
    "hi, what's going on?",
    "will it rain tomorrow?",
    "how do I say goodbye politely?",
    "what toppings do you have?",
    "is today sunny or cloudy?",
    "see you soon"
]
divider = "---------------------------------"
for prompt in inference_prompts:
    reply, top_expert_idx, gate_dist = generate_response(model, prompt, max_len=500)

    print("Input:", prompt)
    print("Generated response:", reply)
    print("Top expert used:", top_expert_idx)
    print("Gating probabilities:", np.round(gate_dist, 3))
    print(divider)

In [None]:
from PIL import Image
import numpy as np

def classify_image(image_path):
    """Report which RGB channel has the highest mean intensity in an image.

    This is a colour heuristic, not a learned classifier.

    Args:
        image_path: anything ``PIL.Image.open`` accepts (path or file object).

    Returns:
        The string ``"Dominant color is <red|green|blue>."``.
    """
    color_names = ('red', 'green', 'blue')
    # Use a context manager so the file is closed deterministically. Pillow
    # keeps multi-frame files (GIF, TIFF, ...) open even after the data is
    # loaded, so relying on convert() alone can leak the handle.
    with Image.open(image_path) as image:
        image_array = np.array(image.convert("RGB"))
    mean_colors = image_array.mean(axis=(0, 1))
    # np.argmax resolves ties toward the earlier channel (red, then green).
    dominant_color = color_names[int(np.argmax(mean_colors))]
    return f"Dominant color is {dominant_color}."

def multimodal_response(model, input_text=None, image_path=None, max_len=30):
    """Build a text report from an optional prompt and an optional image.

    Args:
        model: forwarded to ``generate_response`` when ``input_text`` is given.
        input_text: prompt string; skipped when falsy.
        image_path: image file for ``classify_image``; skipped when falsy.
        max_len: maximum generated tokens (previously hard-coded to 30).

    Returns:
        Concatenated report lines, or ``""`` when neither input is given.
    """
    sections = []
    if input_text:
        # The gating distribution is not reported here, so discard it.
        response, top_expert, _ = generate_response(model, input_text, max_len=max_len)
        sections.append(f"Text response: {response}\nTop expert used: {top_expert}\n")
    if image_path:
        sections.append(f"Image analysis: {classify_image(image_path)}\n")
    return "".join(sections)

# Example usage. The sample image is optional: if it is missing, fall back to
# text-only so Restart & Run All does not die with FileNotFoundError.
from pathlib import Path

sample_image_path = "sample_image.jpg"
if not Path(sample_image_path).is_file():
    print(f"NOTE: {sample_image_path} not found; running text-only.")
print(multimodal_response(
    model,
    input_text="what toppings do you have?",
    image_path=sample_image_path if Path(sample_image_path).is_file() else None,
))

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dense, GlobalAveragePooling1D, Softmax, LSTM, Conv2D, Flatten, MaxPooling2D
from tensorflow.keras import Model
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

# ----------- MNIST DIGIT CLASSIFIER -----------

class DigitClassifier(Model):
    """Small CNN producing 10 unnormalised digit scores (logits).

    The architecture is two Conv2D + 2x2 max-pool stages, flatten, a 64-unit
    ReLU dense layer, and a linear 10-unit output. The output has no softmax,
    so pair it with a ``from_logits=True`` loss. In this notebook it is fed
    MNIST batches shaped (N, 28, 28, 1) by ``prepare_mnist_data``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept stable so weight
        # tracking/checkpoint layout does not change.
        self.conv1 = Conv2D(32, (3, 3), activation='relu')
        self.pool1 = MaxPooling2D((2, 2))
        self.conv2 = Conv2D(64, (3, 3), activation='relu')
        self.pool2 = MaxPooling2D((2, 2))
        self.flatten = Flatten()
        self.dense = Dense(64, activation='relu')
        self.out = Dense(10)  # digits 0-9 logits

    def call(self, x):
        """Apply the layers in sequence and return the logits."""
        pipeline = (
            self.conv1, self.pool1,
            self.conv2, self.pool2,
            self.flatten, self.dense, self.out,
        )
        for layer in pipeline:
            x = layer(x)
        return x

def prepare_mnist_data():
    """Load MNIST and convert it to the format the digit classifier trains on.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. The images are float32
        scaled by 1/255 with a trailing channel axis added. The labels are
        one-hot encoded over 10 classes.
    """
    def _to_model_input(images):
        # Scale pixel values to [0, 1] and append a channel axis.
        return (images.astype('float32') / 255.0)[..., np.newaxis]

    def _to_one_hot(labels):
        return tf.keras.utils.to_categorical(labels, 10)

    (train_images, train_labels), (test_images, test_labels) = mnist.load_data()
    train_split = (_to_model_input(train_images), _to_one_hot(train_labels))
    test_split = (_to_model_input(test_images), _to_one_hot(test_labels))
    return train_split, test_split

# Train a fresh digit classifier (no pretrained weights are loaded here).
(x_train, y_train), (x_test, y_test) = prepare_mnist_data()
digit_classifier = DigitClassifier()
digit_classifier.compile(
    optimizer='adam',
    # Labels are one-hot and the model emits raw logits.
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)
print("Training digit classifier (this may take a bit)...")
digit_classifier.fit(
    x_train, y_train,
    epochs=3,
    batch_size=128,
    validation_split=0.1,
)

# ----------- YOUR MOE TEXT GENERATION MODEL (copy your existing) -----------

# (Paste your current MoEResponseGenerator, Expert, GatingNetwork, encode_sentence, sample_from_logits, etc. here)
# In this notebook the trained MoE model from the earlier cell is named `model`.

# Fallback generate_response. An earlier cell may already define the real MoE
# generate_response. Redefining the name unconditionally would silently shadow
# it for every later (or re-run) cell. Instead, keep a handle to the real one
# and only simulate when no model is passed or no real implementation exists.
_real_generate_response = globals().get("generate_response")
# If this cell already ran, unwrap to the real implementation rather than
# wrapping the previous placeholder again.
_real_generate_response = getattr(_real_generate_response, "_moe_impl", _real_generate_response)


def generate_response(model, input_text, max_len=30, temperature=1.0, top_k=5):
    """Delegate to the real MoE generator when possible, else simulate a reply.

    Returns:
        ``(response_text, top_expert, gating_distribution)``. In the simulated
        case these are a canned string, expert 1, and a fixed 4-way
        distribution.
    """
    if model is not None and _real_generate_response is not None:
        return _real_generate_response(
            model, input_text, max_len=max_len, temperature=temperature, top_k=top_k
        )
    return f"Simulated response to: {input_text}", 1, np.array([0.1, 0.7, 0.1, 0.1])


generate_response._moe_impl = _real_generate_response

# ----------- MNIST INCREMENT LOGIC -----------

def get_incremented_digit_image(input_image, x_dataset, y_dataset, classifier=None):
    """Classify a digit image and fetch a dataset example of the next digit.

    Args:
        input_image: a single digit image shaped (H, W), (H, W, 1), or already
            batched as (1, H, W, 1). Images whose max exceeds 1.0 are treated
            as raw 0-255 pixels and divided by 255.
        x_dataset: candidate output images, shaped (N, H, W, 1) or (N, H, W).
        y_dataset: integer labels aligned with ``x_dataset``.
        classifier: callable mapping a (1, H, W, 1) float32 batch to per-class
            scores. Defaults to the global ``digit_classifier``.

    Returns:
        ``(pred_digit, inc_digit, inc_image)``, where ``inc_digit`` is
        ``(pred_digit + 1) % 10`` and ``inc_image`` is the first example of
        ``inc_digit`` in ``x_dataset``, as a 2-D array ready for display.

    Raises:
        ValueError: if ``inc_digit`` does not occur in ``y_dataset``.
    """
    if classifier is None:
        classifier = digit_classifier

    # astype copies, so the caller's array is never modified in place.
    img = np.asarray(input_image).astype('float32')
    if img.max() > 1.0:  # heuristic: raw 0-255 pixels
        img /= 255.0
    if img.ndim == 2:          # (H, W) -> (H, W, 1)
        img = img[..., np.newaxis]
    if img.ndim == 3:          # (H, W, 1) -> (1, H, W, 1); batched input is left as-is
        img = img[np.newaxis, ...]

    scores = np.asarray(classifier(img))
    pred_digit = int(np.argmax(scores, axis=1)[0])
    inc_digit = (pred_digit + 1) % 10

    matches = np.flatnonzero(np.asarray(y_dataset) == inc_digit)
    if matches.size == 0:
        raise ValueError(f"no example of digit {inc_digit} in y_dataset")
    inc_image = np.asarray(x_dataset[matches[0]])
    if inc_image.ndim == 3:    # drop the channel axis for display
        inc_image = inc_image[..., 0]

    return pred_digit, inc_digit, inc_image


# ----------- MULTIMODAL HANDLER -----------

def multimodal_handler(input_text=None, input_image=None, model=None):
    """Print a text response and/or plot a digit next to its "incremented" digit.

    Args:
        input_text: optional prompt forwarded to ``generate_response``.
        input_image: optional digit image, 2-D or with a trailing channel axis.
            It is classified and shown beside the first test-set example of
            the next digit.
        model: forwarded to ``generate_response``. Defaults to ``None``, which
            matches the original hard-coded call.
    """
    if input_text:
        response, expert, gating = generate_response(model, input_text)
        print(f"Text Input: {input_text}")
        print(f"Response: {response}")
        print(f"Top expert used: {expert}")
        print(f"Gating probabilities: {gating}")
    if input_image is not None:
        test_labels = np.argmax(y_test, axis=1)  # one-hot -> integer labels
        pred_digit, inc_digit, inc_image = get_incremented_digit_image(input_image, x_test, test_labels)
        print(f"Input Digit: {pred_digit}, Incremented Digit: {inc_digit}")
        # Explicit figure/axes instead of the pyplot state machine, so each
        # call draws into its own fresh figure. Squeeze (H, W, 1) to (H, W)
        # for imshow.
        fig, (ax_in, ax_out) = plt.subplots(1, 2)
        panels = (
            (ax_in, np.squeeze(np.asarray(input_image)), f"Input: {pred_digit}"),
            (ax_out, inc_image, f"Output: {inc_digit}"),
        )
        for ax, image, title in panels:
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        plt.show()

# ----------- DEMO -----------

# Each entry is the keyword arguments for one multimodal_handler call:
# text-only, image-only, and both modalities, in that order.
demo_cases = [
    {"input_text": "I need help with my order."},
    {"input_image": x_test[5]},
    {"input_text": "What digit is this?", "input_image": x_test[7]},
]
for demo_kwargs in demo_cases:
    multimodal_handler(**demo_kwargs)
