In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '4'
from tensorflow.keras import Model
from tensorflow.keras import backend as K 
from tensorflow.keras.callbacks import LambdaCallback
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization, Activation, Flatten, Conv1D, Add, Multiply, Lambda, Conv2DTranspose, Concatenate, UpSampling2D, Reshape, Dot, Permute, RepeatVector, Embedding
import tensorflow as tf
import typing
from keras.utils import Sequence
import numpy as np
from generators.text_generator import C_T_I, I_T_C, CHR_TO_IDX, IDX_TO_CHR
from tqdm.notebook import tqdm
from generators import text_generator

from curricula import selection

from models import masked_language

#### Run Params

These settings could be moved into a config file; for now they live here.

In [3]:
def return_checkpoints(target_path, log_path, monitor='accuracy', use_tensorboard=False):
    """Build the Keras callbacks used for one training run.

    Args:
        target_path: File path the best model is written to by ``ModelCheckpoint``.
        log_path: Prefix for log artifacts. A ``_YYYYmmdd-HHMMSS`` timestamp is
            appended, so the CSV log goes to ``f"{log_path}_<timestamp>.csv"``.
        monitor: Metric used to pick the "best" epoch. Defaults to training
            ``'accuracy'`` (the original behaviour). Pass e.g. ``'val_accuracy'``
            to select on validation data instead. Metrics whose name ends in
            ``'loss'`` are minimised; everything else is maximised.
        use_tensorboard: If True, also append a ``TensorBoard`` callback logging
            to ``f"{log_path}_<timestamp>"``. Off by default because the original
            version built this callback but never returned it.

    Returns:
        ``[ModelCheckpoint, CSVLogger]``, followed by ``TensorBoard`` when
        ``use_tensorboard`` is True.
    """
    import datetime

    # Compute the timestamp once so every artifact from this call shares a suffix.
    stamp = datetime.datetime.now().strftime('%Y%m%d-%H%M%S')
    run_log = f"{log_path}_{stamp}"

    # Create the directories the files actually live in. The log files sit next
    # to `log_path` (it is a prefix), not inside a directory named `log_path`.
    for path in (target_path, run_log):
        parent = os.path.dirname(path)
        if parent:
            os.makedirs(parent, exist_ok=True)

    model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=target_path,
        monitor=monitor,
        mode='min' if monitor.endswith('loss') else 'max',
        save_best_only=True)

    csv_callback = tf.keras.callbacks.CSVLogger(f"{run_log}.csv", append=True)

    callbacks = [model_checkpoint_callback, csv_callback]
    if use_tensorboard:
        callbacks.append(
            tf.keras.callbacks.TensorBoard(log_dir=run_log, histogram_freq=1))
    return callbacks

In [4]:
run_name = "charpred_baseline_gru_rho_sel"

# Each of the three models gets its own checkpoint file, log prefix and callback list.
# Irreducible-loss model.
irred_chk = f"results/{run_name}_irred.keras"
irred_log = f"results/logs/{run_name}_irred"
irred_checkpoints = return_checkpoints(irred_chk, irred_log)

# Control model (trained without selection).
control_chk = f"results/{run_name}_control.keras"
control_log = f"results/logs/{run_name}_control"
control_checkpoints = return_checkpoints(control_chk, control_log)

# Target model (trained on the selected batches).
target_chk = f"results/{run_name}_target.keras"
target_log = f"results/logs/{run_name}_target"
target_checkpoints = return_checkpoints(target_chk, target_log)

length = 31  # passed as the first model argument; presumably the input sequence length — TODO confirm
embed_dim = 5

#! Model params
irred_dense = [128, 64, 32]

#! Large model for more capacity. Target and control share the same architecture,
#! but each gets its own list so mutating one cannot silently change the other
#! (a chained `a = b = [...]` would alias a single list object).
control_dense = [256, 128, 128, 64]
target_dense = list(control_dense)
control_dims = [128, 64, 64, 32]
target_dims = list(control_dims)

# Snapshot of every run setting in one dict, handy for logging alongside results.
run_params = dict(
    run_name=run_name,
    irred_chk=irred_chk,
    irred_log=irred_log,
    control_chk=control_chk,
    control_log=control_log,
    target_chk=target_chk,
    target_log=target_log,
    length=length,
    embed_dim=embed_dim,
    irred_dense=irred_dense,
    target_dense=target_dense,
    control_dense=control_dense,
    target_dims=target_dims,
    control_dims=control_dims,
)


#### Loading Dataloaders

In [5]:
# Positional arguments shared by every split. Judging by the keywords passed to
# rho_generator_audio later in this notebook, these are presumably
# (samples_per_word, batch_size, pad) — TODO confirm against text_generator.
_pretrain_args = (4, 4096, True)

train_gen = text_generator.pretraining_generator("data/text/cleaned_train.txt", *_pretrain_args)
val_gen = text_generator.pretraining_generator("data/text/cleaned_val.txt", *_pretrain_args)
test_gen = text_generator.pretraining_generator("data/text/cleaned_test.txt", *_pretrain_args)

#### Calculating Baseline

In [19]:
# Control model: the same GRU encoder as the target model, trained on train_gen
# without any batch selection, to serve as the baseline.
control_model = masked_language.GRU_ENC(
    length, embed_dim, control_dims, control_dense
)

# Same optimizer, loss and metric as the target model so the comparison is like-for-like.
control_model.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001),
    loss = tf.keras.losses.CategoricalCrossentropy(),
    metrics = ["accuracy"],
)

control_model.summary()



Epoch 1/10


  self._warn_if_super_not_called()


[1m  6/188[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m23:42[0m 8s/step - accuracy: 0.1282 - loss: 3.1901

KeyboardInterrupt: 

In [None]:
# Baseline run: every batch from train_gen is used, with validation on test_gen.
control_model.fit(
    train_gen,
    validation_data=test_gen,
    epochs=10,
    callbacks=control_checkpoints,
)

#### Irreducible Model Training

In [6]:
#! Swap this out to match your base model; its hyperparameters are also worth tuning.
irred_model = masked_language.baseline_dense_concat(
    length,
    embed_dim,
    irred_dense,
)

irred_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

irred_model.summary()






In [7]:
# The irreducible-loss model trains on the held-out val split (not train_gen)
# and is validated on test_gen.
irred_model.fit(
    val_gen,
    validation_data=test_gen,
    epochs=10,
    callbacks=irred_checkpoints,
)

Epoch 1/10


  self._warn_if_super_not_called()


[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 232ms/step - accuracy: 0.1340 - loss: 3.0216 - val_accuracy: 0.2241 - val_loss: 2.8805
Epoch 2/10
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 213ms/step - accuracy: 0.2343 - loss: 2.8617 - val_accuracy: 0.2389 - val_loss: 2.8168
Epoch 3/10
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m9s[0m 243ms/step - accuracy: 0.2429 - loss: 2.7982 - val_accuracy: 0.2569 - val_loss: 2.7672
Epoch 4/10
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 228ms/step - accuracy: 0.2565 - loss: 2.7556 - val_accuracy: 0.2606 - val_loss: 2.7376
Epoch 5/10
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 223ms/step - accuracy: 0.2622 - loss: 2.7247 - val_accuracy: 0.2658 - val_loss: 2.7148
Epoch 6/10
[1m33/33[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 220ms/step - accuracy: 0.2714 - loss: 2.6981 - val_accuracy: 0.2755 - val_loss: 2.6920
Epoch 7/10
[1m33/33[0m [32m━━━━━━━━━

<keras.src.callbacks.history.History at 0x7f1b89756910>

#### Reducible Model Training

In [8]:
# Target model: the model trained on selected batches. It is built from the
# target_* hyperparameters (not control_*), so the two can diverge independently.
target_model = masked_language.GRU_ENC(
    length, embed_dim, target_dims, target_dense
)

# Same optimizer, loss and metric as the control model so the comparison is like-for-like.
target_model.compile(
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.001),
    loss = tf.keras.losses.CategoricalCrossentropy(),
    metrics = ["accuracy"]
)

target_model.summary()



In [12]:
# Use the checkpoint written by ModelCheckpoint (the best epoch by the monitored
# metric) rather than whatever weights the in-memory model ended training with.
irred_model = tf.keras.models.load_model(irred_chk)

# Training generator that selects batches with irreducible_loss_selector, using
# both the frozen irreducible model and the target model being trained.
train_rho_gen = text_generator.rho_generator_audio(
    "data/text/cleaned_train.txt",
    samples_per_word=4,
    batch_size=4096,
    pad=True,
    selector=selection.irreducible_loss_selector,
    irred_model=irred_model,
    target_model=target_model,
    epoch_cutoff=0,
)

In [13]:
# Train the target model on the selected batches, validating on the same test_gen as the control run.
target_model.fit(
    train_rho_gen,
    validation_data=test_gen,
    epochs=10,
    callbacks=target_checkpoints,
)

Epoch 1/10


  self._warn_if_super_not_called()


[1m  2/188[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m30:52[0m 10s/step - accuracy: 0.0823 - loss: 3.2526  

KeyboardInterrupt: 