In [195]:
import os

from pprint import pprint
from contextlib import redirect_stdout
import pandas as pd
import numpy as np
import tensorflow as tf

# Output library versions
print(f"os: Built-in module (no version)")
print(f"pprint: Built-in module (no version)")
print(f"contextlib: Built-in module (no version)")
print(f"numpy version: {np.__version__}")
print(f"tensorflow version: {tf.__version__}")


os: Built-in module (no version)
pprint: Built-in module (no version)
contextlib: Built-in module (no version)
numpy version: 1.26.4
tensorflow version: 2.18.0


In [196]:
# Updated encoding legend
ENCODING_LEGEND = {
    'MRI_CCS_11': 1, 'MRI_EXU_95': 2, 'MRI_FRR_18': 3, 'MRI_FRR_257': 4,
    'MRI_FRR_264': 5, 'MRI_FRR_3': 6, 'MRI_FRR_34': 7, 'MRI_MPT_1005': 8,
    'MRI_MSR_100': 9, 'MRI_MSR_104': 10, 'MRI_MSR_21': 11, 'MRI_MSR_34': 12,
    'START': 13,  # Start token
    'END': 14     # End token
}

CHAR_TO_INT = {
    '0': 0,
    '1': 1,
    '2': 2,
    '3': 3,
    '4': 4,
    '5': 5,
    '6': 6,
    '7': 7,
    '8': 8,
    '9': 9,

}


In [197]:
# --------------------------------------- DATA FUNCTIONS -----------------------------------

# Updated start and end tokens
START_TOKEN = 13
END_TOKEN = 14

def generate_data(data_size=100):
    """
    Generate synthetic sourceID data sequences.
    Each sequence starts with the START token (13) and ends with the END token (14).
    Random sourceIDs (from 1 to 12) are included in between.
    """
    
    def generate_cond_sequence(condition):
        """
        Generate a sequence of sourceIDs based on a condition.
        The condition is an integer (sourceID) between 1 and 12.
        """
        if condition < 1 or condition > 12:
            raise ValueError("Condition must be between 1 and 12.")

        random_number = np.random.choice([0, 1, 2])

        if random_number == 0:
            length = 3
        elif random_number == 1:
            length = 6
        elif random_number == 2:
            length = 10
        else:
            raise NotImplementedError()

        # Generate random sourceIDs for the sequence
        return np.random.randint(1, 13, size=length).tolist()

    conditions_list = []
    sequences_lists = []

    for i in range(data_size):
        # Randomly select a condition (sourceID between 1 and 12)
        condition = np.random.randint(1, 13)

        # Generate a sequence based on the condition
        seq = generate_cond_sequence(condition)

        # Add start and end tokens to the sequence
        seq = [START_TOKEN] + seq + [END_TOKEN]

        conditions_list.append(condition)
        sequences_lists.append(seq)

    return conditions_list, sequences_lists

In [198]:
def add_start_end_tokens(conds, seqs):
    new_seqs = []
    for seq in seqs:
        # Prepend START token and append END token
        new_seqs.append([13] + seq + [14])  # Use integers for START and END tokens
    return conds, new_seqs
    
    
def make_same_length(conds, seqs):
    new_seqs = []
    
    # Find the maximum sequence length
    max_length = max(len(seq) for seq in seqs)
    
    for seq in seqs:
        # Calculate how many tokens to add
        remaining_tokens = max_length - len(seq)
        # Pad the sequence with the END token (14 as per your legend)
        new_seqs.append(seq + [14] * remaining_tokens)  # END_TOKEN as an integer
    
    return conds, new_seqs



def make_training_data(conds, seqs):
    
    input_seqs = []
    
    for seq in seqs:
        input_seqs.append(seq[:-1])
    
    input_data = (conds, input_seqs)
    
    output_seqs = []
    
    for seq in seqs:
        output_seqs.append(seq[1:])
        
    output_data = output_seqs
    
    return input_data, output_data

In [199]:
# Define mappings
RAW_STRINGS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
CHAR_TO_INT = dict(zip(RAW_STRINGS, range(1, len(RAW_STRINGS) + 1)))
INT_TO_CHAR = {v: k for k, v in CHAR_TO_INT.items()}

class ConditionMapper:
    def __init__(self):
        # Define your mappings as before
        self.integer_map = ENCODING_LEGEND.copy()  # Your original map
        self.string_map = {v: k for k, v in self.integer_map.items()}  # Reverse mapping for integer to string
        self.dimension = len(self.integer_map) + 1  # Account for possible 'END' token

    def map_to_ints(self, input_string):
        # Check if the input is a string (e.g., 'MRI_MSR_100') or an integer (e.g., 10)
        if isinstance(input_string, int):
            # Check if the integer exists in the reverse mapping (for keys like 10 corresponding to 'MRI_MSR_100')
            input_string = self.string_map.get(input_string, None)
            if input_string is None:
                raise KeyError(f"Integer '{input_string}' not found in string map")
        
        # Now input_string should be a valid string that can be mapped to an integer
        if input_string not in self.integer_map:
            raise KeyError(f"Character '{input_string}' not found in integer map")
        
        return np.array([self.integer_map[input_string]])

    def map_list_to_ints_vectors(self, list_of_strings):
        return np.array([self.map_to_ints(input_string) for input_string in list_of_strings])

    def map_ints_to_string(self, input_ints):
        return "".join([self.string_map.get(integer, '') for integer in input_ints])

# Example usage:
cond_mapper = ConditionMapper()

# If you're working with integers:
input_string = 10  # This should map to a valid string if it corresponds to a known integer
output_int = cond_mapper.map_to_ints(input_string)
print(output_int)




[10]


In [200]:
class SeqMapper():
    
    def __init__(self):
        self.integer_map = CHAR_TO_INT.copy()
        
        n_ints = len(self.integer_map)
        
        # self.integer_map.update({START_TOKEN: n_ints, END_TOKEN: n_ints + 1})
        # self.dimension = len(self.integer_map)
        
        self.integer_map.update({START_TOKEN: n_ints + 1, END_TOKEN: n_ints + 2})
        
        self.dimension = len(self.integer_map) + 1
        
        self.string_map = dict(zip(list(self.integer_map.values()), list(self.integer_map.keys())))
    
    def map_to_ints(self, input_string):
        input_numbers = [self.integer_map[item] for item in input_string]
        
        return np.asarray(input_numbers)
    
    def map_list_to_ints_vectors(self, list_of_strings):
        input_length = len(list_of_strings)
        
        vectors = []
        
        for input_string in list_of_strings:
            vectors.append(self.map_to_ints(input_string))
            
        vectors = np.asarray(vectors)
        
        return vectors
    
    def map_ints_to_string(self, input_ints):
        string = ""
        
        for integer in input_ints:
            string += self.string_map[integer]
            
        return string


In [201]:
def convert_training_data(input_train, output_train, mappers):
    converted_input_train = mappers[0].map_list_to_ints_vectors(input_train[0]), mappers[1].map_list_to_ints_vectors(input_train[1])
    converted_output_train = mappers[1].map_list_to_ints_vectors(output_train)
    
    return converted_input_train, converted_output_train

In [202]:
# --------------------------------------- TENSORFLOW LAYERS AND MODELS -----------------------------------
    
# --------------------------------------- masked layers
def masked_loss(label, pred):
    mask = label != 0
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none')
    loss = loss_object(label, pred)

    mask = tf.cast(mask, dtype=loss.dtype)
    loss *= mask

    loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)
    return loss


def masked_accuracy(label, pred):
    pred = tf.argmax(pred, axis=2)
    label = tf.cast(label, pred.dtype)
    match = label == pred

    mask = label != 0

    match = match & mask

    match = tf.cast(match, dtype=tf.float32)
    mask = tf.cast(mask, dtype=tf.float32)
    return tf.reduce_sum(match) / tf.reduce_sum(mask)    

In [203]:
# --------------------------------------- positional embedding
def positional_encoding(length, depth):
    depth = depth / 2

    positions = np.arange(length)[:, np.newaxis]  # (seq, 1)
    depths = np.arange(depth)[np.newaxis, :] / depth  # (1, depth)

    angle_rates = 1 / (10000 ** depths)  # (1, depth)
    angle_rads = positions * angle_rates  # (pos, depth)

    pos_encoding = np.concatenate(
        [np.sin(angle_rads), np.cos(angle_rads)],
        axis=-1)

    return tf.cast(pos_encoding, dtype=tf.float32)


class PositionalEmbedding(tf.keras.layers.Layer):
    def __init__(self, vocab_size,
                 d_model,
                 use_embedding=True):

        super().__init__()
        self.d_model = d_model
        self.use_embedding = use_embedding

        if self.use_embedding:
            self.embedding = tf.keras.layers.Embedding(vocab_size, d_model, mask_zero=True)
        else:
            self.embedding = tf.keras.layers.Dense(d_model, activation="relu")

        self.pos_encoding = positional_encoding(length=2048, depth=d_model)

    def compute_mask(self, *args, **kwargs):
        if self.use_embedding:
            return self.embedding.compute_mask(*args, **kwargs)
        else:
            return None

    def call(self, x):
        length = tf.shape(x)[1]
        x = self.embedding(x)
        # This factor sets the relative scale of the embedding and positional_encoding.
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x = x + self.pos_encoding[tf.newaxis, :length, :]
        return x

In [204]:
# --------------------------------------- attention layers
class BaseAttention(tf.keras.layers.Layer):
    def __init__(self, **kwargs):
        super().__init__()
        self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
        self.layernorm = tf.keras.layers.LayerNormalization()
        self.add = tf.keras.layers.Add()


class CrossAttention(BaseAttention):
    def call(self, x, context):
        attn_output, attn_scores = self.mha(
            query=x,
            key=context,
            value=context,
            return_attention_scores=True)

        # Cache the attention scores for plotting later.
        self.last_attn_scores = attn_scores

        x = self.add([x, attn_output])
        x = self.layernorm(x)

        return x


class GlobalSelfAttention(BaseAttention):
    def call(self, x):
        attn_output = self.mha(
            query=x,
            value=x,
            key=x)
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x


class CausalSelfAttention(BaseAttention):
    def call(self, x):
        attn_output = self.mha(
            query=x,
            value=x,
            key=x,
            use_causal_mask=True)
        x = self.add([x, attn_output])
        x = self.layernorm(x)
        return x
    

class FeedForward(tf.keras.layers.Layer):
    def __init__(self, d_model, dff, dropout_rate=0.1):
        super().__init__()
        self.seq = tf.keras.Sequential([
            tf.keras.layers.Dense(dff, activation='relu'),
            tf.keras.layers.Dense(d_model),
            tf.keras.layers.Dropout(dropout_rate)
        ])
        self.add = tf.keras.layers.Add()
        self.layer_norm = tf.keras.layers.LayerNormalization()

    def call(self, x):
        x = self.add([x, self.seq(x)])
        x = self.layer_norm(x)
        return x


class SelfAttentionFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self, *, d_model, num_heads, dff, dropout_rate=0.1,
                 attention="global"):

        super().__init__()

        if attention == "global":
            self.self_attention = GlobalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        elif attention == "causal":
            self.self_attention = CausalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        else:
            raise NotImplemented(f"The choice {attention} for attention is not implemented.")

        self.ffn = FeedForward(d_model, dff)

    def call(self, x):
        x = self.self_attention(x)
        x = self.ffn(x)
        return x


class SelfAttentionCrossAttentionFeedForwardLayer(tf.keras.layers.Layer):
    def __init__(self,
                 *,
                 d_model,
                 num_heads,
                 dff,
                 dropout_rate=0.1,
                 attention="causal"):

        super(SelfAttentionCrossAttentionFeedForwardLayer, self).__init__()

        if attention == "global":
            self.self_attention = GlobalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        elif attention == "causal":
            self.self_attention = CausalSelfAttention(
                num_heads=num_heads,
                key_dim=d_model,
                dropout=dropout_rate)
        else:
            raise NotImplemented(f"The choice {attention} for attention is not implemented.")

        self.cross_attention = CrossAttention(
            num_heads=num_heads,
            key_dim=d_model,
            dropout=dropout_rate)

        self.ffn = FeedForward(d_model, dff)

    def call(self, x, context):
        x = self.self_attention(x=x)
        x = self.cross_attention(x=x, context=context)

        # Cache the last attention scores for plotting later
        self.last_attn_scores = self.cross_attention.last_attn_scores

        x = self.ffn(x)  # Shape `(batch_size, seq_len, d_model)`.
        return x
    
# --------------------------------------- encoder
class Encoder(tf.keras.Model):
    def __init__(self, *, 
                 num_layers, 
                 d_model, 
                 num_heads,
                 dff, 
                 vocab_size, 
                 dropout_rate=0.1):
        super().__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(
            vocab_size=vocab_size, d_model=d_model)

        self.enc_layers = [
            SelfAttentionFeedForwardLayer(d_model=d_model,
                                          num_heads=num_heads,
                                          dff=dff,
                                          dropout_rate=dropout_rate)
            for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)

    def call(self, x):
        x = self.pos_embedding(x)  # Shape `(batch_size, seq_len, d_model)`.
        x = self.dropout(x)
        for i in range(self.num_layers):
            x = self.enc_layers[i](x)

        return x  # Shape `(batch_size, seq_len, d_model)`.
    
# --------------------------------------- decoder
class Decoder(tf.keras.Model):
    def __init__(self, *, num_layers, d_model, num_heads, dff, vocab_size,
                 dropout_rate=0.1):
        super().__init__()
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.pos_embedding = PositionalEmbedding(vocab_size=vocab_size,
                                                 d_model=d_model)
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        self.dec_layers = [
            SelfAttentionCrossAttentionFeedForwardLayer(d_model=d_model, num_heads=num_heads,
                                                        dff=dff, dropout_rate=dropout_rate)
            for _ in range(num_layers)]

        self.last_attn_scores = None

    def call(self, x, context):
        x = self.pos_embedding(x)  # (batch_size, target_seq_len, d_model)

        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.dec_layers[i](x, context)

        self.last_attn_scores = self.dec_layers[-1].last_attn_scores

        return x # (batch_size, target_seq_len, d_model)

In [205]:
# --------------------------------------- transformer
class Transformer(tf.keras.Model):
    def __init__(self, *, 
                 num_layers, 
                 d_model, 
                 num_heads, 
                 dff,
                 input_vocab_size, 
                 target_vocab_size, 
                 dropout_rate=0.1):
        
        super().__init__()
        self.encoder = Encoder(num_layers=num_layers, 
                               d_model=d_model,
                               num_heads=num_heads, 
                               dff=dff,
                               vocab_size=input_vocab_size,
                               dropout_rate=dropout_rate)

        self.decoder = Decoder(num_layers=num_layers, 
                               d_model=d_model,
                               num_heads=num_heads, 
                               dff=dff,
                               vocab_size=target_vocab_size,
                               dropout_rate=dropout_rate)

        self.final_layer = tf.keras.Sequential([tf.keras.layers.Dense(target_vocab_size)])
        
    def save_model_weights(self, save_folder):
        self.encoder.save_weights(os.path.join(save_folder, "encoder.weights.h5"))
        self.decoder.save_weights(os.path.join(save_folder, "decoder.weights.h5"))
        self.final_layer.save_weights(os.path.join(save_folder, "final_layer.weights.h5"))

    def load_model_from_weights(self, save_folder):
        self.encoder.load_weights(os.path.join(save_folder, "encoder.weights.h5"))
        self.decoder.load_weights(os.path.join(save_folder, "decoder.weights.h5"))
        self.final_layer.load_weights(os.path.join(save_folder, "final_layer.weights.h5"))

    def get_models(self):
        return self.encoder, self.decoder

    def call(self, inputs):
        context, x = inputs
        context = self.encoder(context)  # (batch_size, context_len, d_model)
        x = self.decoder(x, context)  # (batch_size, target_len, d_model)
        logits = self.final_layer(x)  # (batch_size, target_len, target_vocab_size)

        try:
            # Drop the keras mask, so it doesn't scale the losses/metrics.
            # b/250038731
            del logits._keras_mask
        except AttributeError:
            pass

        return logits

In [206]:
# --------------------------------------- translator
class Translator(tf.Module):
    def __init__(self, transformer, mappers, sample=True, min_prob=0.05):
        self.cond_mapper = mappers[0]
        self.seq_mapper = mappers[1]
        self.transformer = transformer
        
        self.sample = sample
        
        assert 0 <= min_prob < 1
        
        self.min_prob = min_prob

    def __call__(self, cond_seq, max_length=15):
        if not isinstance(cond_seq, tf.Tensor):
            cond_seq = tf.convert_to_tensor(cond_seq)

        encoder_input = tf.expand_dims(cond_seq, 0)
        
        seq_start = self.seq_mapper.integer_map[START_TOKEN]
        seq_end = self.seq_mapper.integer_map[END_TOKEN]

        output_array = tf.TensorArray(dtype=tf.int64, size=0, dynamic_size=True)
        output_array = output_array.write(0, seq_start)

        for i in tf.range(max_length):
            output = tf.transpose(output_array.stack())
            predictions = self.transformer([encoder_input, tf.expand_dims(output, 0)], training=False)

            predictions = predictions[0, -1:, :]  # Shape `(batch_size, 1, vocab_size)`.
            
            if self.sample:
                # calculate and transform probabilities
                probabilities = tf.math.softmax(predictions)
                high_probs = tf.where(probabilities > self.min_prob, probabilities, 0)[0]
                high_probs = high_probs / tf.math.reduce_sum(high_probs)
                
                # print(high_probs)
                
                # sample next id
                predicted_id = np.random.choice(len(high_probs), size=1, p=high_probs.numpy())[0]
            else:
                # take id with highest probability
                predicted_id = tf.argmax(predictions, axis=-1)[0]

            output_array = output_array.write(i + 1, predicted_id)

            if predicted_id == seq_end:
                break

        output = tf.transpose(output_array.stack())

        text = self.seq_mapper.map_ints_to_string(output.numpy())  # Shape: `()`.

        self.transformer([encoder_input, tf.expand_dims(output[:-1], 0)], training=False)
        attention_weights = self.transformer.decoder.last_attn_scores

        return text, attention_weights

In [207]:
# --------------------------------------- main use case
def main(train=True, 
         compute_results=True,
         n_conds=-1,
         n_samples=3,
         epochs=1, 
         checkpoint_name="test_transformer"):
    
    # Path to tokenization directory
    tokenization_dir = "../data/Tokenization"  # Directory containing the 300 CSV files
    csv_files = sorted(
        [os.path.join(tokenization_dir, file) for file in os.listdir(tokenization_dir) if file.endswith(".csv")]
    )
    
    # Extract `sourceID` sequences from the CSV files
    conditions = []
    sequences = []
    
    for file in csv_files:
        data = pd.read_csv(file)  # Load CSV file
        source_ids = data['sourceID'].dropna().astype(int).tolist()  # Extract `sourceID` column, drop NaNs, convert to int
        
        if source_ids:
            conditions.append(source_ids[0])  # Use the first `sourceID` as the condition
            sequences.append(source_ids)     # Use the entire column as the sequence
    
    # Save raw data for reference
    os.makedirs(checkpoint_name, exist_ok=True)
    with open(os.path.join(checkpoint_name, "raw_data.txt"), "w") as f:
        print_dict = {}
        for i, cond in enumerate(conditions):
            if cond not in print_dict:
                print_dict[cond] = []
            
            print_dict[cond].append(sequences[i])
                
        with redirect_stdout(f):
            pprint(print_dict)
    
    # Add start and end tokens to sequences
    conditions, sequences = add_start_end_tokens(conditions, sequences)
    conditions, sequences = make_same_length(conditions, sequences)
    
    # Save processed data for training
    with open(os.path.join(checkpoint_name, "train_data.txt"), "w") as f:
        print_dict = {}
        for i, cond in enumerate(conditions):
            if cond not in print_dict:
                print_dict[cond] = []
            
            print_dict[cond].append(sequences[i])
                
        with redirect_stdout(f):
            pprint(print_dict)
    
    # Instantiate mappers for conditions and sequences
    cond_mapper = ConditionMapper()
    seq_mapper = SeqMapper()

    
    # generate training data
    raw_model_input_data, raw_model_output_data = make_training_data(conditions, sequences)
    model_input_data, model_output_data = convert_training_data(raw_model_input_data, 
                                                                raw_model_output_data, 
                                                                mappers=(cond_mapper, seq_mapper))
    
    # build transformer
    
    # tf example
    # num_layers = 4
    # d_model = 128
    # dff = 512
    # num_heads = 8
    # dropout_rate = 0.1
    
    transformer = Transformer(
        num_layers=3,
        d_model=32,
        num_heads=8,
        dff=128,
        input_vocab_size=cond_mapper.dimension,
        target_vocab_size=seq_mapper.dimension,
        dropout_rate=0.1,
    )
    
    transformer(model_input_data)
    transformer.summary()
    
    if not os.path.exists(checkpoint_name):
        os.mkdir(checkpoint_name)
    
    if os.path.exists(os.path.join(checkpoint_name, "encoder.weights.h5")):
        transformer.load_model_from_weights(checkpoint_name)
        
        print("Model loaded successfully!")
        print("\n")

    # compile transformer
    transformer.compile(
        loss=masked_loss,
        optimizer='Adam',
        metrics=[masked_accuracy])
    
    if train:
        print("---------------------------- Training model ----------------------------")
        
        # fit transformer on batches
        transformer.fit(model_input_data,
                        model_output_data,
                        epochs=epochs,
                        # validation_data=val_batches
                        )
        
        transformer.save_model_weights(checkpoint_name)
    
    # build translator
    translator = Translator(transformer, mappers=(cond_mapper, seq_mapper))
    
    # translate examples
    
    if compute_results:
        print("---------------------------- Using model ----------------------------")
            
        translation_results = dict()
        
        if n_conds == -1:
            strings_to_use = RAW_STRINGS
        elif 0 < n_conds < len(RAW_STRINGS):
            strings_to_use = RAW_STRINGS[:n_conds]
        else:
            raise NotImplementedError()
            
        for char in RAW_STRINGS:
            
            print(f"Generating output for condition {char}")
            
            translation_results[char] = []
            example_cond = cond_mapper.map_to_ints(char)
            
            for i in range(n_samples):
                model_text, _ = translator(example_cond)
                
                print(model_text)
                
                translation_results[char].append(model_text)
        
        with open(os.path.join(checkpoint_name, "results.txt"), "w") as f:
            with redirect_stdout(f):
                pprint(translation_results)

In [208]:
if __name__ == '__main__':
    main(train=True, 
         compute_results=True, 
         n_conds=-1,
         n_samples=10,
         epochs=5)

KeyError: 10