In [305]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Dense, SimpleRNN
from tensorflow.keras.models import Model
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam

In [306]:
# Experiment configuration.
num_generations = 4  # learner generations per iterated-learning run
SEED = 42
# Seed Python's `random`, NumPy and TensorFlow in one call. tf.random.set_seed
# alone only seeds TensorFlow and may not make Keras 3 weight initialisation
# repeatable.
tf.keras.utils.set_random_seed(SEED)

In [307]:
# Toy corpus. Every sentence must have the same number of words: the model
# treats the whole corpus as one fixed-size (sentences x words) grid.
train_data_raw = [
    "a red cat sat on the mat",
    "the clever dog played with a ball",
    "one stray cat sat on a ball",
    "a blue buffalo kicked a small tree",
    "a big ball flew through our window"
]

tokenized = [sentence.split(' ') for sentence in train_data_raw]

vocab = set(word for words in tokenized for word in words)

# Sort before enumerating. Iteration order of a set of strings depends on the
# per-process string-hash seed, so an unsorted enumeration would assign
# different integer codes (and give different results) after every kernel restart.
encoding_to_word = dict(enumerate(sorted(vocab)))
word_to_encoding = {word: code for code, word in encoding_to_word.items()}

vocab_size = len(vocab)
sentence_size = len(tokenized[0])
data_size = len(train_data_raw)

# Fail loudly on ragged input. Otherwise a longer sentence raises an IndexError
# below, and a shorter one is silently padded with code 0 (a real word).
lengths = {len(words) for words in tokenized}
assert lengths == {sentence_size}, (
    f"all sentences must have {sentence_size} words, got lengths {sorted(lengths)}"
)

# (data_size, sentence_size) grid of word codes, stored as floats because it is
# fed directly to the model as input.
train_data = np.zeros((data_size, sentence_size))
for i, words in enumerate(tokenized):
    train_data[i] = [word_to_encoding[word] for word in words]

In [308]:
def create_model(hidden_units=64, learning_rate=0.001):
    """Build and compile a fresh learner network for one generation.

    The whole corpus is one sample: a (data_size, sentence_size) grid of word
    codes. For each (sentence, position) slot, the network outputs a probability
    distribution over the vocabulary.

    Args:
        hidden_units: width of the linear bottleneck layer (default 64).
        learning_rate: Adam learning rate (default 0.001).

    Returns:
        A compiled Keras Model.
        - Input shape: (batch, data_size, sentence_size).
        - Output shape: (batch, data_size * sentence_size, vocab_size).
        - Each output row sums to 1.
    """
    input_shape = (data_size, sentence_size)
    output_shape = (data_size * sentence_size, vocab_size)

    input_layer = Input(shape=input_shape)
    flattened_input = tf.keras.layers.Flatten()(input_layer)

    # Linear bottleneck: no non-linearity between the two Dense layers.
    mid_layer = Dense(hidden_units, activation='linear')(flattened_input)

    # Emit raw logits, reshape to one row per slot, then softmax over the
    # vocabulary axis only. A softmax over the flat output would normalise
    # across all slots at once, so no row would be a valid distribution.
    logits = Dense(data_size * sentence_size * vocab_size)(mid_layer)
    logits = tf.keras.layers.Reshape(output_shape)(logits)
    output_layer = tf.keras.layers.Softmax(axis=-1)(logits)

    model = Model(inputs=input_layer, outputs=output_layer)

    # Per-slot one-hot targets (see create_label) + categorical cross-entropy.
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=categorical_crossentropy,
        metrics=['accuracy'],
    )
    return model

In [309]:
def create_label(dataset_encoding, n_classes=None):
    """One-hot encode a grid of word codes as per-slot training targets.

    Args:
        dataset_encoding: array-like of integer (or integer-valued float) word
            codes, e.g. ``train_data`` or the output of ``get_prediction``.
        n_classes: one-hot width. Defaults to the global ``vocab_size``.

    Returns:
        float64 array of shape (number_of_slots, n_classes), one one-hot row per
        slot in row-major order.

    Raises:
        ValueError: if any code lies outside [0, n_classes).
    """
    if n_classes is None:
        n_classes = vocab_size
    # astype(int) truncates toward zero, same as int() on each float code.
    codes = np.asarray(dataset_encoding).astype(int)
    if codes.size and (codes.min() < 0 or codes.max() >= n_classes):
        raise ValueError(
            f"word codes must lie in [0, {n_classes}), "
            f"got range [{codes.min()}, {codes.max()}]"
        )
    # Row k of the identity matrix is the one-hot vector for class k.
    return np.eye(n_classes)[codes].reshape(-1, n_classes)

def train_model(model, input, labels, epochs):
    """Fit ``model`` silently for ``epochs`` epochs on one batch entry.

    Both ``input`` and ``labels`` get a leading batch axis of length 1 before
    being handed to ``model.fit``.
    """
    batch_inputs = np.asarray(input)[np.newaxis, ...]
    batch_targets = np.asarray(labels)[np.newaxis, ...]
    model.fit(batch_inputs, batch_targets, epochs=epochs, verbose=0)

def get_prediction(model, input):
    """Run the model on one grid of word codes and pick the likeliest word per slot.

    Args:
        model: callable mapping a batch of shape (1, rows, cols) to per-slot
            scores of shape (1, rows * cols, n_classes). A model from
            ``create_model`` qualifies.
        input: 2-D array-like (rows, cols) of word codes.

    Returns:
        Integer array of shape (rows, cols) holding the argmax code per slot.
    """
    grid = np.asarray(input)
    scores = np.asarray(model(grid[np.newaxis, ...]))[0]
    # Score rows are slots in row-major order. Restore the grid shape and let
    # numpy infer the class axis, then take the argmax over classes.
    return np.argmax(scores.reshape(grid.shape + (-1,)), axis=-1)

def get_sentence(prediction):
    """Decode a grid of word codes into text, one line per sentence.

    Reads the first ``data_size`` rows and ``sentence_size`` columns of
    ``prediction`` and maps each code through ``encoding_to_word``.
    - Every word is followed by a single space, so each line ends with a trailing space.
    - Every line ends with a newline.
    """
    lines = []
    for row in range(data_size):
        words = [encoding_to_word[prediction[row, col]] for col in range(sentence_size)]
        lines.append("".join(word + " " for word in words) + "\n")
    return "".join(lines)

In [310]:
def iterated_learning(epochs, generations=None, seed=None):
    """Run an iterated-learning chain and print the last generation's output.

    Generation 1 learns to map ``train_data`` onto its own one-hot encoding.
    Every later generation is a freshly built model trained to map the same
    ``train_data`` onto the previous generation's predictions. Any deviation a
    generation introduces is therefore passed on to the next one.

    Args:
        epochs: training epochs per generation.
        generations: number of generations. Defaults to the global ``num_generations``.
        seed: if given, reseed all RNGs before the chain so that this call is
            reproducible on its own. With the default ``None`` the chain continues
            from the current global RNG state, so results depend on which calls
            ran earlier in the session.

    Returns:
        The decoded sentences of the last generation ("" if ``generations`` is 0).
    """
    if generations is None:
        generations = num_generations
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)

    label = create_label(train_data)
    last_sent = ""

    for _ in range(generations):
        model = create_model()
        train_model(model, train_data, label, epochs)

        prediction = get_prediction(model, train_data)
        # This generation's output becomes the next generation's targets.
        label = create_label(prediction)
        last_sent = get_sentence(prediction)

    print("Epochs : " + str(epochs))
    print(last_sent)
    return last_sent

In [311]:
# Sweep the per-generation training budget: one full chain per epoch count.
epoch_list = [1, 5, 10, 15, 50]
for n_epochs in epoch_list:
    iterated_learning(epochs=n_epochs)

Epochs : 1
tree on red dog tree blue flew 
kicked red cat mat big through red 
window sat the with big red stray 
blue ball flew through tree cat buffalo 
blue kicked the dog stray our buffalo 

Epochs : 5
cat mat the dog clever red window 
window mat mat blue dog tree big 
with played stray dog played the ball 
through one a dog a with one 
the flew sat through blue tree red 

Epochs : 10
red the cat the small the window 
the clever tree played flew one ball 
one stray blue sat played a ball 
a our buffalo tree a small blue 
blue big ball flew through our window 

Epochs : 15
a red cat sat on the mat 
the clever dog played with a ball 
one stray flew sat on a ball 
a blue buffalo kicked a small tree 
a big ball flew through our window 

Epochs : 50
a red cat sat on the mat 
the clever dog played with a ball 
one stray cat sat on a ball 
a blue buffalo kicked a small tree 
a big ball flew through our window 

