In [None]:
# cbow_keras.py
import tensorflow as tf
from tensorflow.keras import layers, models
import numpy as np
import re

# -----------------------------
# a. DATA PREPARATION
# -----------------------------
text = """
The speed of transmission is an important point of difference between the two viruses.
Influenza has a shorter median incubation period (the time from infection to appearance of symptoms)
and a shorter serial interval (the time between successive cases) than COVID-19 virus.
The serial interval for COVID-19 virus is estimated to be 5-6 days, while for influenza virus,
the serial interval is 3 days. This means that influenza can spread faster than COVID-19.

Further, transmission in the first 3-5 days of illness, or potentially pre-symptomatic transmission
–transmission of the virus before the appearance of symptoms – is a major driver of transmission
for influenza. In contrast, while we are learning that there are people who can shed COVID-19 virus
24-48 hours prior to symptom onset, at present, this does not appear to be a major driver of transmission.

The reproductive number – the number of secondary infections generated from one infected individual –
is understood to be between 2 and 2.5 for COVID-19 virus, higher than for influenza.
However, estimates for both COVID-19 and influenza viruses are very context and time-specific,
making direct comparisons more difficult.
"""

# Clean and tokenize
text = text.lower()
text = re.sub(r'[^a-z0-9\s]', '', text)
words = text.split()

# Build vocabulary
vocab = sorted(set(words))
word2idx = {w: i for i, w in enumerate(vocab)}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(vocab)
print("Vocab size:", vocab_size)

# -----------------------------
# b. GENERATE TRAINING DATA (CBOW)
# -----------------------------
window_size = 2
contexts, targets = [], []

for i in range(window_size, len(words) - window_size):
    context = (
        [word2idx[words[i - j]] for j in range(window_size, 0, -1)]
        + [word2idx[words[i + j]] for j in range(1, window_size + 1)]
    )
    target = word2idx[words[i]]
    contexts.append(context)
    targets.append(target)

contexts = np.array(contexts)
targets = np.array(targets)
print("Training samples:", len(contexts))

# -----------------------------
# c. BUILD + TRAIN MODEL
# -----------------------------
embedding_dim = 50

model = models.Sequential([
    layers.Embedding(input_dim=vocab_size, output_dim=embedding_dim, input_length=window_size * 2),
    layers.Lambda(lambda x: tf.reduce_mean(x, axis=1)),  # average context embeddings
    layers.Dense(vocab_size, activation='softmax')
])

model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')
model.fit(contexts, targets, epochs=50000, verbose=0)
print("✅ Training complete!")

# -----------------------------
# d. OUTPUT – example predictions
# -----------------------------
def predict_next_word(context_words):
    tokens = [word2idx[w] for w in context_words if w in word2idx]
    if len(tokens) < 2 * window_size:
        tokens = [0] * (2 * window_size - len(tokens)) + tokens
    tokens = np.array(tokens[-2 * window_size:]).reshape(1, -1)
    pred = model.predict(tokens, verbose=0)
    return idx2word[np.argmax(pred)]



# Print a few embeddings
embeddings = model.layers[0].get_weights()[0]
print("\nSample word embeddings:")
for w in ["virus", "transmission", "influenza", "covid19" if "covid19" in word2idx else "covid"]:
    print(w, embeddings[word2idx[w]][:5])


Vocab size: 98
Training samples: 179


In [4]:
test_contexts = [
    ["the", "speed", "of", "transmission"],   # should predict “is” or “an”
    ["shorter", "serial", "for", "virus"],    # should predict “interval”
    ["influenza", "can", "faster", "than"],   # should predict “spread”
    ["number", "of", "secondary", "infections"],  # should predict “generated”
    ["the", "time", "from", "infection"],     # should predict “to”
]

for ctx in test_contexts:
    pred = predict_next_word(ctx)
    print(f"Context: {ctx} → Predicted word: '{pred}'")


Context: ['the', 'speed', 'of', 'transmission'] → Predicted word: 'of'
Context: ['shorter', 'serial', 'for', 'virus'] → Predicted word: 'interval'
Context: ['influenza', 'can', 'faster', 'than'] → Predicted word: 'spread'
Context: ['number', 'of', 'secondary', 'infections'] → Predicted word: 'of'
Context: ['the', 'time', 'from', 'infection'] → Predicted word: 'time'
