In [82]:
import tensorflow as tf
from tensorflow.keras.layers import Embedding, Dense, GlobalAveragePooling1D, Softmax, LSTM
from tensorflow.keras import Model
import numpy as np

# Toy training corpus: a list of (input, response) string pairs.
# Element 0 of each pair is the user utterance fed to the encoder and
# element 1 is the reply the decoder is trained to reproduce. Words are
# later obtained by lower-casing and whitespace splitting, so punctuation
# stays attached to the neighbouring word (e.g. "hi," is its own token).
train_data = [
    ("hello there", "hi, how can I help"),
    ("hi", "hello, what can I do"),
    ("goodbye", "goodbye, have a nice day"),
    ("see you later", "see you soon, goodbye"),
    ("I want to order pizza", "sure, what toppings do you want"),
    ("can I get a burger", "what size burger would you like"),
    ("what is the weather", "the weather today is sunny"),
    ("is it raining", "no rain expected today"),
    ("hey, I want some pasta", "what kind of pasta would you prefer"),
    ("do you have vegetarian options?", "yes, we have several vegetarian dishes"),
    ("good morning", "good morning, how may I assist you?"),
    ("bye", "take care, see you later"),
    ("will it be hot today?", "expect warm temperatures all day"),
    ("can I order a salad?", "what dressing would you like on your salad?"),
    ("thanks, goodbye", "you're welcome, goodbye!"),
    ("tell me the forecast", "the forecast shows clear skies"),
    ("what's your name?", "i am your assistant, here to help"),
    ("can I have a coffee?", "sure, would you like it black or with milk?"),
    ("thank you for the help", "happy to assist you anytime"),
    ("are you open today?", "yes, we are open from 9 am to 9 pm"),
    ("could you help me with my order", "of course, what would you like to order"),
    ("are there any gluten free options", "yes, we have several gluten free dishes available"),
    ("what are today's specials", "today's special is grilled salmon with vegetables"),
    ("how late are you open", "we are open until 10 pm tonight"),
    ("can you recommend a dessert", "our chocolate lava cake is very popular"),
    ("I need to change my order", "sure, what changes would you like to make"),
    ("do you deliver", "yes, we deliver within a 5 mile radius"),
    ("what payment methods do you accept", "we accept cash, credit cards, and mobile payments"),
    ("is there a parking facility", "yes, free parking is available behind the restaurant"),
    ("thank you very much", "you're welcome, happy to help"),
    ("I have a food allergy", "please let us know your allergy, and we will accommodate"),
    ("can I book a table", "yes, for how many people and what time"),
    ("what's your restaurant address", "we are located at 123 Main Street"),
    ("do you have vegan meals", "yes, we offer delicious vegan options"),
    ("can I get nutritional information", "nutritional info is available on our website"),
    ("how long is the wait time", "usually about 15 minutes during peak hours"),
    ("do you have a kids menu", "yes, we have a special menu for children"),
    ("can I cancel my order", "please provide your order number to cancel"),
    ("what are your opening hours", "we are open from 9 am to 10 pm daily"),
    ("is takeout available", "yes, you can order takeout anytime during opening hours"),
]

# Build vocabulary
# Every prompt is joined with its reply so both sides contribute words.
all_texts = [" ".join((prompt, reply)) for prompt, reply in train_data]
all_words = {word for sentence in all_texts for word in sentence.lower().split()}

# Sort once so that word2idx and idx2word are guaranteed to agree.
# Id 0 is reserved for '<pad>'; real words are numbered from 1.
vocab_words = sorted(all_words)
word2idx = {word: index for index, word in enumerate(vocab_words, start=1)}
idx2word = np.array(['<pad>'] + vocab_words)
vocab_size = len(idx2word)

# Fixed sequence lengths (in tokens) for prompts and replies.
# NOTE: encode_sentence truncates anything longer than these caps.
max_input_len = 6
max_resp_len = 8

def encode_sentence(sent, max_len):
    """Convert a sentence into a list of exactly ``max_len`` token ids.

    The sentence is lower-cased and split on whitespace. Each word is looked
    up in the module-level ``word2idx`` mapping; words that are not in the
    mapping receive id 0, the same id that is used for padding.

    Args:
        sent: Sentence to encode.
        max_len: Length of the returned list. Longer sentences are cut off,
            shorter ones are right-padded with 0.

    Returns:
        A list of integer token ids.
    """
    kept_words = sent.lower().split()[:max_len]
    token_ids = [word2idx.get(word, 0) for word in kept_words]
    # A non-positive repeat count yields an empty list, so nothing is
    # appended once the sentence already fills the requested length.
    padding = [0] * (max_len - len(token_ids))
    return token_ids + padding

# Encoder input: one fixed-length row of token ids per prompt.
X_input = np.array([encode_sentence(prompt, max_input_len) for prompt, _ in train_data])
# Decoder input: the encoded reply itself.
X_resp_in = np.array([encode_sentence(reply, max_resp_len) for _, reply in train_data])
# Decoder target: the decoder input shifted one step to the left, with a
# trailing 0, so position i is trained to predict reply token i + 1.
# NOTE(review): no start-of-sequence token is prepended, so the first reply
# word is never a prediction target.
X_resp_out = np.pad(X_resp_in[:, 1:], ((0, 0), (0, 1)))

class Expert(Model):
    """One expert of the mixture: a small normalised feed-forward block.

    Applies, in order, LayerNormalization, a 64-unit ReLU Dense layer,
    Dropout(0.2) and a final Dense projection to ``d_model`` units.
    """

    def __init__(self, d_model):
        """Create the sub-layers; ``d_model`` is the output width."""
        super().__init__()
        self.norm = tf.keras.layers.LayerNormalization()
        self.dense1 = Dense(64, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.2)
        self.dense2 = Dense(d_model)

    def call(self, x):
        """Run ``x`` through norm -> dense1 -> dropout -> dense2."""
        # No explicit training flag is passed to the dropout layer here.
        # NOTE(review): presumably Keras supplies it from the enclosing
        # call -- confirm for the Keras version in use.
        hidden = self.dense1(self.norm(x))
        return self.dense2(self.dropout(hidden))

class GatingNetwork(Model):
    """Produces a probability distribution over the experts.

    A single Dense layer emits one logit per expert and a softmax over the
    last axis turns the logits into mixture weights.
    """

    def __init__(self, d_model, num_experts):
        """Create the gate; ``d_model`` is accepted but not used here."""
        super().__init__()
        self.dense = Dense(num_experts)
        self.softmax = Softmax(axis=-1)

    def call(self, x):
        """Return the softmax-normalised expert weights for ``x``."""
        return self.softmax(self.dense(x))

class MoEResponseGenerator(Model):
    """Mixture-of-experts prompt encoder feeding an LSTM reply decoder.

    The prompt is embedded and average-pooled into a single vector. A gating
    network weights the outputs of ``num_experts`` Expert blocks applied to
    that vector, and the weighted sum is projected into the initial hidden
    and cell state of an LSTM that runs over the embedded reply tokens.
    The same embedding layer is shared by prompt and reply.
    """

    def __init__(self, vocab_size, d_model, num_experts, max_resp_len, lstm_units=128):
        """Create all sub-layers.

        Args:
            vocab_size: Number of token ids, including the padding id 0.
            d_model: Embedding width and expert output width.
            num_experts: Number of Expert blocks and gate outputs.
            max_resp_len: Stored on the instance; not used by ``call``.
            lstm_units: Width of the decoder LSTM and of its initial states.
        """
        super().__init__()
        self.num_experts = num_experts
        self.lstm_units = lstm_units
        self.max_resp_len = max_resp_len

        # Shared token embedding; id 0 is treated as padding (mask_zero).
        self.embedding = Embedding(vocab_size, d_model, mask_zero=True)
        self.global_pool = GlobalAveragePooling1D()

        # Mixture of experts over the pooled prompt vector.
        self.experts = [Expert(d_model) for _ in range(num_experts)]
        self.gating_net = GatingNetwork(d_model, num_experts)

        # Decoder and the projections that seed its state.
        self.lstm = LSTM(lstm_units, return_sequences=True, return_state=True)
        self.to_h = Dense(lstm_units)
        self.to_c = Dense(lstm_units)
        self.output_layer = Dense(vocab_size)

    def call(self, inputs, training=False):
        """Compute per-position vocabulary logits for a reply.

        Args:
            inputs: Pair ``(input_seq, resp_in_seq)`` of token-id batches for
                the prompt and for the decoder input.
            training: Accepted for the Keras call convention; it is not
                forwarded explicitly to any sub-layer in this method.

        Returns:
            ``(logits, gating_probs)``: the output layer applied to every
            LSTM step, and the gate's weights over the experts.
        """
        prompt_ids, decoder_ids = inputs

        # Collapse the embedded prompt into one vector per example.
        prompt_vec = self.global_pool(self.embedding(prompt_ids))

        # Weight each expert's output by its gate probability and sum.
        gating_probs = self.gating_net(prompt_vec)
        per_expert = [expert(prompt_vec) for expert in self.experts]
        expert_outs = tf.stack(per_expert, axis=1)
        mixture_weights = tf.expand_dims(gating_probs, 2)
        gated_rep = tf.reduce_sum(mixture_weights * expert_outs, axis=1)

        # The mixed representation initialises both LSTM states.
        initial_state = [self.to_h(gated_rep), self.to_c(gated_rep)]
        decoder_emb = self.embedding(decoder_ids)
        decoder_out = self.lstm(decoder_emb, initial_state=initial_state)[0]

        return self.output_layer(decoder_out), gating_probs

# Hyperparams and dataset
d_model = 32        # embedding width and expert output width
num_experts = 4     # number of experts in the mixture
batch_size = 2
epochs = 250

# Each element is ((prompt_ids, decoder_input_ids), target_ids).
# The shuffle buffer spans the whole training set: a buffer smaller than the
# dataset only shuffles within a sliding window, so examples near the end
# could never appear in the first batches of an epoch.
dataset = (
    tf.data.Dataset.from_tensor_slices(((X_input, X_resp_in), X_resp_out))
    .shuffle(buffer_size=len(X_input))
    .batch(batch_size)
)

model = MoEResponseGenerator(vocab_size, d_model, num_experts, max_resp_len)
# The model returns raw logits, hence from_logits=True.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

@tf.function
def train_step(inputs, labels):
    """Run one optimisation step on a batch and return its loss.

    Args:
        inputs: Batch passed straight to ``model`` (the prompt ids and the
            decoder input ids).
        labels: Target token ids compared against the model's logits.

    Returns:
        The loss tensor computed by ``loss_fn`` for this batch.
    """
    with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        # The model also returns gating probabilities; only logits are scored.
        loss = loss_fn(labels, outputs[0])
    # Read the variable list only after the forward pass, keeping the
    # original ordering (layers may create their weights on first call).
    variables = model.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

# Train for a fixed number of epochs, reporting the mean batch loss every
# 50th epoch.
for epoch in range(epochs):
    total_loss = sum(
        train_step(in_batch, out_batch).numpy() for in_batch, out_batch in dataset
    )
    if (epoch + 1) % 50 != 0:
        continue
    print(f"Epoch {epoch+1}/{epochs} Loss: {total_loss/len(dataset):.4f}")

Epoch 50/250 Loss: 0.2209
Epoch 100/250 Loss: 0.0372
Epoch 150/250 Loss: 0.0153
Epoch 200/250 Loss: 0.0062


2025-08-30 19:11:57.774868: I tensorflow/core/framework/local_rendezvous.cc:407] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence


Epoch 250/250 Loss: 0.0032


In [83]:
def sample_from_logits(logits, temperature=1.0, top_k=5):
    """Sample one token id per row from a batch of logits.

    Args:
        logits: Two-dimensional float tensor of unnormalised scores, one row
            per example and one column per token id.
        temperature: Divisor applied to the logits before sampling. Values
            below 1 sharpen the distribution, values above 1 flatten it.
            A value of 0 or less selects the highest-scoring token
            deterministically instead of dividing by zero.
        top_k: If positive, only the ``top_k`` highest-scoring tokens of each
            row can be sampled (ties with the k-th score are kept). Values
            larger than the number of columns are clamped to it. Zero or
            negative disables the filter.

    Returns:
        A NumPy integer array with one sampled token id per row.
    """
    if temperature <= 0:
        # Greedy decoding is the limit of sampling as temperature -> 0.
        return tf.argmax(logits, axis=-1).numpy()

    logits = logits / temperature

    if top_k > 0:
        # tf.math.top_k rejects k larger than the last dimension.
        num_columns = logits.shape[-1]
        k = top_k if num_columns is None else min(top_k, num_columns)
        top_values, _ = tf.math.top_k(logits, k=k)
        kth_largest = top_values[:, -1, None]
        # Scores below the k-th largest get zero probability.
        neg_inf = tf.constant(float('-inf'), dtype=logits.dtype)
        logits = tf.where(logits < kth_largest, neg_inf, logits)

    # tf.random.categorical expects unnormalised log-probabilities, so the
    # logits can be passed directly without a softmax/log round trip.
    next_token = tf.random.categorical(logits, num_samples=1)
    return tf.squeeze(next_token, axis=-1).numpy()

def generate_response(model, input_text, max_len=30, temperature=1.0, top_k=5):
    """Sample a reply to ``input_text`` one token at a time.

    The decoder input starts as all padding. At each step the whole model is
    re-run on the prompt and the tokens generated so far, one token is
    sampled, and it is written into the decoder input. Generation stops when
    the padding id 0 is sampled or after ``max_len`` steps.

    Args:
        model: Callable taking ``((input_seq, response_seq), training=False)``
            and returning ``(logits, gating_probs)``.
        input_text: Prompt string; encoded with ``encode_sentence``.
        max_len: Maximum number of tokens to generate.
        temperature: Sampling temperature passed to ``sample_from_logits``.
        top_k: Top-k cutoff passed to ``sample_from_logits``.

    Returns:
        ``(text, top_expert, gating)``: the generated words joined by spaces,
        the index of the expert with the largest gate weight, and the gate
        weights of the last model call as a NumPy array. When no decoding
        step ran (``max_len <= 0``) the result is ``("", -1, None)``.
    """
    input_seq = np.array([encode_sentence(input_text, max_input_len)])
    response_seq = np.zeros((1, max_resp_len), dtype=np.int32)
    generated_tokens = []
    gating_probs = None
    last_pos = max_resp_len - 1

    for i in range(max_len):
        logits, gating_probs = model((input_seq, response_seq), training=False)
        # While the window is still filling, read the logits at slot i (the
        # first empty slot). Once it is full, the newest token always sits in
        # the last slot, so read there instead of wrapping around to a stale
        # position at the front of the window.
        logits_step = logits[:, min(i, last_pos), :]

        next_token = sample_from_logits(logits_step, temperature=temperature, top_k=top_k)[0]

        if next_token == 0:  # end on padding token
            break

        generated_tokens.append(idx2word[next_token])
        if i < max_resp_len:
            # Fill every slot, including the last one, before sliding;
            # sliding one step early would rotate a padding 0 into the
            # middle of the context.
            response_seq[0, i] = next_token
        else:
            # Window is full: drop the oldest token and append the new one.
            response_seq = np.roll(response_seq, -1, axis=1)
            response_seq[0, -1] = next_token

    text = " ".join(generated_tokens)
    if gating_probs is None:
        # The model was never called, so there is no gate output to report.
        return text, -1, None

    gating = gating_probs[0].numpy()
    return text, np.argmax(gating), gating


# Demo
# The first three prompts are taken verbatim from the training pairs; the
# rest are unseen phrasings.
inference_prompts = [
    "hello there",
    "I want to order pizza",
    "goodbye",
    "can you help me order food",
    "hi, what's going on?",
    "will it rain tomorrow?",
    "how do I say goodbye politely?",
    "what toppings do you have?",
    "is today sunny or cloudy?",
    "see you soon"
]
for input_text in inference_prompts:
    response, expert_used, gating_distribution = generate_response(model, input_text, max_len=500)

    # One labelled line per field, followed by a separator.
    report = (
        ("Input:", input_text),
        ("Generated response:", response),
        ("Top expert used:", expert_used),
        ("Gating probabilities:", np.round(gating_distribution, 3)),
    )
    for label, value in report:
        print(label, value)
    print("---------------------------------")

Input: hello there
Generated response: how can i help
Top expert used: 0
Gating probabilities: [0.267 0.244 0.252 0.237]
---------------------------------
Input: I want to order pizza
Generated response: what toppings do you want
Top expert used: 0
Gating probabilities: [0.273 0.25  0.267 0.21 ]
---------------------------------
Input: goodbye
Generated response: have a nice day
Top expert used: 0
Gating probabilities: [0.321 0.238 0.2   0.241]
---------------------------------
Input: can you help me order food
Generated response: what would like like like
Top expert used: 1
Gating probabilities: [0.262 0.267 0.245 0.226]
---------------------------------
Input: hi, what's going on?
Generated response: how how i help you?
Top expert used: 2
Gating probabilities: [0.261 0.228 0.27  0.24 ]
---------------------------------
Input: will it rain tomorrow?
Generated response: warm temperatures all day
Top expert used: 2
Gating probabilities: [0.275 0.202 0.291 0.233]
------------------------