### CF HybridBERT4Rec

In [2]:
!pip install transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [62]:
import tensorflow as tf
from transformers import TFBertModel

class CFHybridBERT4Rec(tf.keras.Model):
    def __init__(self, num_users, num_items, num_numeric_features, bert_model_name, hidden_size, dropout_prob):
        super(CFHybridBERT4Rec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_numeric_features = num_numeric_features
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

        # Embedding layers for users and items
        self.user_embedding = tf.keras.layers.Embedding(num_users, hidden_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, hidden_size)

        # Numeric features layer
        self.numeric_layer = tf.keras.layers.Dense(hidden_size)

        # BERT model for item tokens
        self.bert_model = TFBertModel.from_pretrained(bert_model_name)

        # Dropout layer
        self.dropout = tf.keras.layers.Dropout(dropout_prob)

        # Fully connected layers
        self.fc1 = tf.keras.layers.Dense(hidden_size, activation='relu')
        self.fc2 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, user_ids, item_ids, item_tokens, attention_mask, numeric_features):
        # Embedding lookup for users and items
        user_embedded = self.user_embedding(user_ids)
        item_embedded = self.item_embedding(item_ids)

        # BERT encoding for item tokens
        _, item_token_embeddings = self.bert_model(item_tokens, attention_mask=attention_mask)

        # Combine item embeddings with BERT embeddings
        #item_embedded = item_embedded + item_token_embeddings[:, 0, :]
        #item_embedded = self.dropout(item_embedded)
        # Convert item_embedded to int64 data type
        item_embedded = tf.dtypes.cast(item_embedded, tf.int64)

        # Combine item embeddings with BERT embeddings
        # item_embedded = item_embedded + item_token_embeddings[:, 0]
        # item_embedded = self.dropout(item_embedded)
        # print('item token embeddings')
        # print(item_token_embeddings)
        # print('item embedded')
        # print(item_embedded)

        # Combine item embeddings with BERT embeddings
        item_embedded = item_embedded + item_token_embeddings[:, 0]
        item_embedded = self.dropout(item_embedded)


        # Numeric features layer
        numeric_embedded = self.numeric_layer(numeric_features)

        # Concatenate user and item embeddings with numeric features
        user_item_concat = tf.concat((user_embedded, item_embedded, numeric_embedded), axis=1)

        # Hidden layer and dropout
        user_item_hidden = self.fc1(user_item_concat)
        user_item_hidden = self.dropout(user_item_hidden)

        # Output prediction
        prediction = self.fc2(user_item_hidden)

        return prediction


In [94]:
import tensorflow as tf
from transformers import TFBertModel

class CFHybridBERT4Rec(tf.keras.Model):
    def __init__(self, num_users, num_items, num_numeric_features, bert_model_name, hidden_size, dropout_prob):
        super(CFHybridBERT4Rec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_numeric_features = num_numeric_features
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

        # Embedding layers for users and items
        self.user_embedding = tf.keras.layers.Embedding(num_users, hidden_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, hidden_size)

        # Numeric features layer
        self.numeric_layer = tf.keras.layers.Dense(hidden_size)

        # BERT model for item tokens
        self.bert_model = TFBertModel.from_pretrained(bert_model_name)

        # Dropout layer
        self.dropout = tf.keras.layers.Dropout(dropout_prob)

        # Fully connected layers
        self.fc1 = tf.keras.layers.Dense(hidden_size, activation='relu')
        self.fc2 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, user_ids, item_ids, item_tokens, attention_mask, numeric_features):
        # Embedding lookup for users and items
        user_embedded = self.user_embedding(user_ids)
        item_embedded = self.item_embedding(item_ids)

        # BERT encoding for item tokens
        outputs = self.bert_model(item_tokens, attention_mask=attention_mask)
        item_token_embeddings = outputs['last_hidden_state']

        # Combine item embeddings with BERT embeddings
        print(item_token_embeddings[:, 0, :].numpy().shape)
        print(item_embedded.numpy().shape)
        item_embedded = item_embedded + item_token_embeddings[:, 0, :]
        item_embedded = self.dropout(item_embedded)

        # Numeric features layer
        numeric_embedded = self.numeric_layer(numeric_features)

        # Concatenate user and item embeddings with numeric features
        user_item_concat = tf.concat((user_embedded, item_embedded, numeric_embedded), axis=1)

        # Hidden layer and dropout
        user_item_hidden = self.fc1(user_item_concat)
        user_item_hidden = self.dropout(user_item_hidden)

        # Output prediction
        prediction = self.fc2(user_item_hidden)

        return prediction


In [98]:
import tensorflow as tf
from transformers import TFBertModel

class CFHybridBERT4Rec(tf.keras.Model):
    def __init__(self, num_users, num_items, num_numeric_features, bert_model_name, hidden_size, dropout_prob):
        super(CFHybridBERT4Rec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_numeric_features = num_numeric_features
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

        # Embedding layers for users and items
        self.user_embedding = tf.keras.layers.Embedding(num_users, hidden_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, hidden_size)

        # Numeric features layer
        self.numeric_layer = tf.keras.layers.Dense(hidden_size)

        # BERT model for item tokens
        self.bert_model = TFBertModel.from_pretrained(bert_model_name)

        # Dropout layer
        self.dropout = tf.keras.layers.Dropout(dropout_prob)

        # Linear transformation for item embeddings
        self.linear_transform = tf.keras.layers.Dense(hidden_size)

        # Fully connected layers
        self.fc1 = tf.keras.layers.Dense(hidden_size, activation='relu')
        self.fc2 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, user_ids, item_ids, item_tokens, attention_mask, numeric_features):
        # Embedding lookup for users and items
        user_embedded = self.user_embedding(user_ids)
        item_embedded = self.item_embedding(item_ids)

        # BERT encoding for item tokens
        outputs = self.bert_model(item_tokens, attention_mask=attention_mask)
        item_token_embeddings = outputs['last_hidden_state']

        # Linear transformation for item embeddings
        item_embedded_transformed = self.linear_transform(item_embedded)
        print(item_embedded_transformed.numpy().shape)

        # Combine item embeddings with BERT embeddings
        item_combined = item_embedded_transformed + item_token_embeddings[:, 0, :]
        item_combined = self.dropout(item_combined)

        # Numeric features layer
        numeric_embedded = self.numeric_layer(numeric_features)

        # Concatenate user and item embeddings with numeric features
        user_item_concat = tf.concat((user_embedded, item_combined, numeric_embedded), axis=1)

        # Hidden layer and dropout
        user_item_hidden = self.fc1(user_item_concat)
        user_item_hidden = self.dropout(user_item_hidden)

        # Output prediction
        prediction = self.fc2(user_item_hidden)

        return prediction


In [180]:
class CFHybridBERT4Rec(tf.keras.Model):
    def __init__(self, num_users, num_items, num_numeric_features, bert_model_name, hidden_size, dropout_prob):
        super(CFHybridBERT4Rec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_numeric_features = num_numeric_features
        self.hidden_size = hidden_size
        self.dropout_prob = dropout_prob

        # Embedding layers for users and items
        self.user_embedding = tf.keras.layers.Embedding(num_users, hidden_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, hidden_size)

        # Numeric features layer
        self.numeric_layer = tf.keras.layers.Dense(hidden_size)

        # BERT model for item tokens
        self.bert_model = TFBertModel.from_pretrained(bert_model_name)

        # Dropout layer
        self.dropout = tf.keras.layers.Dropout(dropout_prob)

        # Linear transformation for item embeddings
        self.linear_transform = tf.keras.layers.Dense(hidden_size)

        # Fully connected layers
        self.fc1 = tf.keras.layers.Dense(hidden_size, activation='relu')
        self.fc2 = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, user_ids, item_ids, item_tokens, attention_mask, numeric_features):
        # Embedding lookup for users and items
        user_embedded = self.user_embedding(user_ids)
        item_embedded = self.item_embedding(item_ids)

        # BERT encoding for item tokens
        #attention_mask = tf.expand_dims(attention_mask,axis=1)
        outputs = self.bert_model(item_tokens, attention_mask=attention_mask)#tf.squeeze(attention_mask, axis=1))
        item_token_embeddings = outputs['last_hidden_state']

        # Linear transformation for item embeddings
        #item_embedded_transformed = self.linear_transform(item_embedded)
        item_token_embeddings_transformed = self.linear_transform(item_token_embeddings)


        # Reshape item_embedded to match the shape of item_token_embeddings_transformed
        item_embedded = tf.reshape(item_embedded, (-1, 1, self.hidden_size))

        print(item_token_embeddings_transformed.numpy().shape)
        print(item_embedded.numpy().shape)

        # Combine item embeddings with BERT embeddings
        #item_combined = item_embedded_transformed + item_token_embeddings
        item_combined = item_embedded + item_token_embeddings_transformed
        item_combined = self.dropout(item_combined)


        # Numeric features layer
        numeric_embedded = self.numeric_layer(numeric_features)

        # Reshape item_combined to match the shape of numeric_embedded

        # Reshape item_combined to match the batch size of user_embedded and numeric_embedded
        batch_size = tf.shape(user_embedded)[0]
        item_combined = tf.reshape(item_combined, (batch_size, -1, self.hidden_size))

        # Reshape item_combined to match the shape of numeric_embedded
        item_combined_reshaped = tf.reshape(item_combined, (batch_size, -1))


        # Concatenate user and item embeddings with numeric features
        print(user_embedded.numpy().shape, item_combined_reshaped.numpy().shape, numeric_embedded.numpy().shape)
        user_item_concat = tf.concat((user_embedded, item_combined_reshaped, numeric_embedded), axis=1)

        # Hidden layer and dropout
        user_item_hidden = self.fc1(user_item_concat)
        user_item_hidden = self.dropout(user_item_hidden)

        # Output prediction
        prediction = self.fc2(user_item_hidden)

        return prediction

In [147]:
#

In [67]:
bert_model_test = TFBertModel.from_pretrained(bert_model_name)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [139]:
attention_mask.shape

TensorShape([1000, 10])

In [140]:
tf.expand_dims(attention_mask,axis=1)

<tf.Tensor: shape=(1000, 1, 10), dtype=int32, numpy=
array([[[1, 1, 1, ..., 1, 1, 1]],

       [[1, 1, 1, ..., 1, 1, 1]],

       [[1, 1, 1, ..., 1, 1, 1]],

       ...,

       [[1, 1, 1, ..., 1, 1, 1]],

       [[1, 1, 1, ..., 1, 1, 1]],

       [[1, 1, 1, ..., 1, 1, 1]]], dtype=int32)>

In [138]:
tf.squeeze(attention_mask, axis=1)

InvalidArgumentError: ignored

In [86]:
a, b = ber_model_out
a, b

('last_hidden_state', 'pooler_output')

In [152]:
ber_model_out = bert_model_test(item_tokens, attention_mask=attention_mask)

In [153]:
ber_model_out['last_hidden_state'].numpy().shape

(1000, 10, 768)

In [72]:
_, item_token_embeddings_some = ber_model_out

In [80]:
ber_model_out[0].numpy().shape

(1000, 10, 768)

In [89]:
ber_model_out[0][:, 0, :].numpy().shape

(1000, 768)

In [69]:
attention_mask

<tf.Tensor: shape=(1000, 10), dtype=int32, numpy=
array([[1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       ...,
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1],
       [1, 1, 1, ..., 1, 1, 1]], dtype=int32)>

In [164]:
import numpy as np
# import tensorflow as tf
from transformers import BertTokenizer

In [181]:


# Sample data
num_users = 1000
num_items = 2000
num_numeric_features = 5
max_sequence_length = 10
bert_model_name = 'bert-base-uncased'

tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Generate random training data
num_samples = 1000
user_ids = np.random.randint(0, num_users, size=num_samples)
item_ids = np.random.randint(0, num_items, size=num_samples)
item_tokens = np.random.randint(0, 100, size=(num_samples, max_sequence_length))
attention_mask = np.random.randint(0, 2, size=(num_samples, max_sequence_length))
numeric_features = np.random.rand(num_samples, num_numeric_features)
ratings = np.random.randint(0, 2, size=num_samples)



# Sample data
num_samples = 1000
max_sequence_length = 10

# Generate random training data
item_tokens = []
attention_mask = []
for _ in range(num_samples):
    item_text = "This is a sample item description."
    tokens = tokenizer.encode(item_text, max_length=max_sequence_length, padding='max_length')
    item_tokens.append(tokens)
    mask = [1] * len(tokens) + [0] * (max_sequence_length - len(tokens))
    attention_mask.append(mask)

item_tokens = np.array(item_tokens)
attention_mask = np.array(attention_mask)


# Convert data to TensorFlow tensors with the correct data type
user_ids = tf.constant(user_ids, dtype=tf.int32)
item_ids = tf.constant(item_ids, dtype=tf.int32)
item_tokens = tf.constant(item_tokens, dtype=tf.int32)
attention_mask = tf.constant(attention_mask, dtype=tf.int32)
numeric_features = tf.constant(numeric_features, dtype=tf.float32)
ratings = tf.constant(ratings, dtype=tf.float32)

# Convert item_tokens and attention_mask to int64 data type
#item_tokens = tf.cast(item_tokens, tf.int64)
#attention_mask = tf.cast(attention_mask, tf.int64)

# # Convert data to TensorFlow tensors
# user_ids = tf.constant(user_ids, dtype=tf.int32)
# item_ids = tf.constant(item_ids, dtype=tf.int32)
# item_tokens = tf.constant(item_tokens, dtype=tf.int32)
# attention_mask = tf.constant(attention_mask, dtype=tf.int32)
# numeric_features = tf.constant(numeric_features, dtype=tf.float32)
# ratings = tf.constant(ratings, dtype=tf.float32)

# Instantiate CFHybridBERT4Rec model
hidden_size = 128
dropout_prob = 0.2
model = CFHybridBERT4Rec(num_users, num_items, num_numeric_features, bert_model_name, hidden_size, dropout_prob)


Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [182]:

# Define loss function and optimizer
loss_function = tf.keras.losses.BinaryCrossentropy()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)

# Training loop
num_epochs = 10
batch_size = 32
num_steps = num_samples // batch_size


In [183]:

for epoch in range(num_epochs):
    for step in range(num_steps):
        # Extract mini-batch from the data
        batch_start = step * batch_size
        batch_end = batch_start + batch_size

        batch_user_ids = user_ids[batch_start:batch_end]
        batch_item_ids = item_ids[batch_start:batch_end]
        batch_item_tokens = item_tokens[batch_start:batch_end]
        batch_attention_mask = attention_mask[batch_start:batch_end]
        batch_numeric_features = numeric_features[batch_start:batch_end]
        batch_ratings = ratings[batch_start:batch_end]

        # Forward pass
        with tf.GradientTape() as tape:
            predictions = model(batch_user_ids, batch_item_ids, batch_item_tokens, batch_attention_mask, batch_numeric_features)
            loss = loss_function(batch_ratings, predictions)

        # Backward pass
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # Print training progress
        if (step + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{num_epochs} | Step {step+1}/{num_steps} | Loss: {loss.numpy():.4f}")



(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 1/10 | Step 10/31 | Loss: 0.7700
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 1/10 | Step 20/31 | Loss: 0.7105
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 1/10 | Step 30/31 | Loss: 0.7679
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 2/10 | Step 10/31 | Loss: 0.6757
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 2/10 | Step 20/31 | Loss: 0.6745
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 2/10 | Step 30/31 | Loss: 0.7692
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 3/10 | Step 10/31 | Loss: 0.6413
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 3/10 | Step 20/31 | Loss: 0.6687
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 3/10 | Step 30/31 | Loss: 0.5839
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 4/10 | Step 10/31 | Loss: 0.5580
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 4/10 | Step 20/31 | Loss: 0.5450
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 4/10 | Step 30/31 | Loss: 0.4596
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 5/10 | Step 10/31 | Loss: 0.4377
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 5/10 | Step 20/31 | Loss: 0.4042
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 5/10 | Step 30/31 | Loss: 0.3316
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 6/10 | Step 10/31 | Loss: 0.3311
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 6/10 | Step 20/31 | Loss: 0.3044
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 6/10 | Step 30/31 | Loss: 0.2377
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 7/10 | Step 10/31 | Loss: 0.3218
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 7/10 | Step 20/31 | Loss: 0.3182
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 7/10 | Step 30/31 | Loss: 0.2156
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 8/10 | Step 10/31 | Loss: 0.3110
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 8/10 | Step 20/31 | Loss: 0.2611
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 8/10 | Step 30/31 | Loss: 0.1969
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 9/10 | Step 10/31 | Loss: 0.3010
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 9/10 | Step 20/31 | Loss: 0.2997
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 9/10 | Step 30/31 | Loss: 0.1818
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 10/10 | Step 10/31 | Loss: 0.2421
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 10/10 | Step 20/31 | Loss: 0.2356
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




Epoch 10/10 | Step 30/31 | Loss: 0.1897
(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)




In [22]:
model.summary()

Model: "cf_hybrid_bert4_rec_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 embedding_8 (Embedding)     multiple                  128000    
                                                                 
 embedding_9 (Embedding)     multiple                  256000    
                                                                 
 dense_10 (Dense)            multiple                  0 (unused)
                                                                 
 tf_bert_model_4 (TFBertMode  multiple                 109482240 
 l)                                                              
                                                                 
 dropout_187 (Dropout)       multiple                  0 (unused)
                                                                 
 dense_11 (Dense)            multiple                  0 (unused)
                                             

In [186]:
predictions = model(batch_user_ids, batch_item_ids, batch_item_tokens, batch_attention_mask, batch_numeric_features)

(32, 10, 128)
(32, 1, 128)
(32, 128) (32, 1280) (32, 128)


In [189]:
batch_ratings

<tf.Tensor: shape=(32,), dtype=float32, numpy=
array([0., 0., 0., 1., 0., 1., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0.,
       1., 0., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 0.],
      dtype=float32)>

In [190]:
predictions

<tf.Tensor: shape=(32, 1), dtype=float32, numpy=
array([[0.00563574],
       [0.04057363],
       [0.05061217],
       [0.9901481 ],
       [0.65675175],
       [0.9943834 ],
       [0.99491274],
       [0.04208165],
       [0.03709753],
       [0.0261826 ],
       [0.03206322],
       [0.8445351 ],
       [0.02718589],
       [0.99544483],
       [0.9976927 ],
       [0.88735306],
       [0.72908926],
       [0.9955383 ],
       [0.00664295],
       [0.7462278 ],
       [0.02873709],
       [0.8986698 ],
       [0.3220855 ],
       [0.02681949],
       [0.77466637],
       [0.9941871 ],
       [0.9942861 ],
       [0.01569121],
       [0.7390073 ],
       [0.9373528 ],
       [0.01710252],
       [0.00503302]], dtype=float32)>

In [None]:

# Example prediction
test_user_ids = tf.constant([0, 1, 2], dtype=tf.int32)
test_item_ids = tf.constant([100, 101, 102], dtype=tf.int32)
test_item_tokens = tf.constant(tokenizer.encode(["item1", "item2", "item3"], max_length=max_sequence_length, padding='max_length'), dtype=tf.int32)
test_attention_mask = tf.constant(np.ones((3, max_sequence_length)), dtype=tf.int32)
test_numeric_features = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5], [0.6, 0.7, 0.8, 0.9, 1.0], [0.2, 0.4, 0.6, 0.8, 1.0]], dtype=tf.float32)

predictions = model(test_user_ids, test_item_ids, test_item_tokens, test_attention_mask, test_numeric_features)
print("Predictions:", predictions.numpy())

### Code

In [4]:
import tensorflow as tf
from transformers import TFBertModel

class CFHybridBERT4Rec(tf.keras.Model):
    def __init__(self, num_users, num_items, num_numerical_features, bert_model_name, hidden_size=128):
        super(CFHybridBERT4Rec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_numerical_features = num_numerical_features
        self.hidden_size = hidden_size

        # Collaborative Filtering Embeddings
        self.user_embedding = tf.keras.layers.Embedding(num_users, hidden_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, hidden_size)

        # BERT model
        self.bert = TFBertModel.from_pretrained(bert_model_name)

        # Linear layers for numerical features
        self.numerical_features_linear = tf.keras.layers.Dense(hidden_size)

        # Final prediction layer
        self.predict_layer = tf.keras.layers.Dense(1)

    def call(self, user_ids, item_ids, item_tokens, numerical_features, attention_mask):
        # Collaborative Filtering Embeddings
        user_embedded = self.user_embedding(user_ids)
        item_embedded = self.item_embedding(item_ids)

        # BERT-based embeddings
        bert_outputs = self.bert(input_ids=item_tokens, attention_mask=attention_mask)
        bert_pooled_output = bert_outputs.pooler_output

        # Numerical features
        numerical_features_embedded = self.numerical_features_linear(numerical_features)

        # Concatenate embeddings
        concatenated = tf.concat((user_embedded, item_embedded, bert_pooled_output, numerical_features_embedded), axis=1)

        # Prediction
        predicted_rating = self.predict_layer(concatenated)

        return tf.squeeze(predicted_rating)



In [104]:
# Example usage
num_users = 1000
num_items = 5000
num_numerical_features = 10
bert_model_name = 'bert-base-uncased'

model = CFHybridBERT4Rec(num_users, num_items, num_numerical_features, bert_model_name)


TypeError: ignored

In [None]:
user_ids = tf.constant([1, 2, 3, 4])
item_ids = tf.constant([100, 200, 300, 400])
item_tokens = tf.constant([[101, 1045, 2061, 1012, 102,  1037, 2036],  # Example tokenized item 1
                           [101, 2054, 2023, 2003, 1037, 2047, 102],  # Example tokenized item 2
                           [101, 2049, 1005, 1055, 2342, 102, 2031],  # Example tokenized item 3
                           [101, 2031, 2028, 2145, 2000, 2228, 102]])  # Example tokenized item 4
numerical_features = tf.constant([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0],
                                  [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
                                  [0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5],
                                  [0.9, 0.1, 0.8, 0.2, 0.7, 0.3, 0.6, 0.4, 0.5, 0.5]])
attention_mask = tf.constant([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                              [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])

predictions = model(user_ids, item_ids, item_tokens, numerical_features, attention_mask)
print(predictions)


### Version - 3

In [44]:
import tensorflow as tf
from transformers import TFBertModel

class CFHybridBERT4Rec(tf.keras.Model):
    def __init__(self, num_users, num_items, num_numerical_features, bert_model_name, hidden_size=128):
        super(CFHybridBERT4Rec, self).__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_numerical_features = num_numerical_features
        self.hidden_size = hidden_size

        # Collaborative Filtering Embeddings
        self.user_embedding = tf.keras.layers.Embedding(num_users, hidden_size)
        self.item_embedding = tf.keras.layers.Embedding(num_items, hidden_size)

        # BERT model
        self.bert = TFBertModel.from_pretrained(bert_model_name)

        # Linear layers for numerical features
        self.numerical_features_linear = tf.keras.layers.Dense(hidden_size)

        # Final prediction layer
        self.predict_layer = tf.keras.layers.Dense(1)

    def call(self, user_ids, item_ids, item_tokens, numerical_features, attention_mask):
        # Collaborative Filtering Embeddings
        user_embedded = self.user_embedding(user_ids)
        item_embedded = self.item_embedding(item_ids)

        # BERT-based embeddings
        bert_outputs = self.bert(input_ids=item_tokens, attention_mask=attention_mask)
        bert_pooled_output = bert_outputs.pooler_output

        # Numerical features
        numerical_features_embedded = self.numerical_features_linear(numerical_features)

        # Concatenate embeddings
        concatenated = tf.concat((user_embedded, item_embedded, bert_pooled_output, numerical_features_embedded), axis=1)

        # Prediction
        predicted_rating = self.predict_layer(concatenated)

        return tf.squeeze(predicted_rating)


In [46]:
import tensorflow as tf
from transformers import BertTokenizer

# Set hyperparameters and data paths
num_users = 1000
num_items = 2000
num_numerical_features = 5
hidden_size = 128
bert_model_name = 'bert-base-uncased'
train_data_path = 'train_data.csv'
valid_data_path = 'valid_data.csv'
batch_size = 32
num_epochs = 10
learning_rate = 0.001

In [47]:


# Create an instance of the CFHybridBERT4Rec model
model = CFHybridBERT4Rec(num_users, num_items, num_numerical_features, bert_model_name, hidden_size)

# Define loss function and optimizer
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = tf.keras.optimizers.Adam(learning_rate)

# Create tokenizer for BERT
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions w

In [48]:
# Function to preprocess a batch of data
def preprocess_batch(data_batch):
    user_ids = data_batch['user_ids']
    item_ids = data_batch['item_ids']
    item_tokens = tokenizer.batch_encode_plus(data_batch['item_texts'], padding='longest', truncation=True, max_length=128, return_tensors='tf')['input_ids']
    numerical_features = data_batch['numerical_features']
    attention_mask = tf.cast(item_tokens != 0, tf.int32)
    ratings = data_batch['ratings']

    return user_ids, item_ids, item_tokens, numerical_features, attention_mask, ratings

In [49]:
# Function to compute the forward pass and loss
@tf.function
def compute_loss(user_ids, item_ids, item_tokens, numerical_features, attention_mask, ratings):
    with tf.GradientTape() as tape:
        predicted_ratings = model(user_ids, item_ids, item_tokens, numerical_features, attention_mask)
        loss = loss_fn(ratings, predicted_ratings)

    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    return loss


### Generate Data

In [50]:
import numpy as np
import pandas as pd

# Generate random training data
num_train_samples = 1000

train_user_ids = np.random.randint(1, num_users + 1, size=num_train_samples)
train_item_ids = np.random.randint(1, num_items + 1, size=num_train_samples)
train_item_texts = [f"Item {item_id}" for item_id in train_item_ids]
train_numerical_features = np.random.randn(num_train_samples, num_numerical_features)
train_ratings = np.random.randint(1, 6, size=num_train_samples)

train_data = pd.DataFrame({
    'user_ids': train_user_ids,
    'item_ids': train_item_ids,
    'item_texts': train_item_texts,
    'numerical_features': train_numerical_features.tolist(),
    'ratings': train_ratings
})

# Generate random validation data
num_valid_samples = 200

valid_user_ids = np.random.randint(1, num_users + 1, size=num_valid_samples)
valid_item_ids = np.random.randint(1, num_items + 1, size=num_valid_samples)
valid_item_texts = [f"Item {item_id}" for item_id in valid_item_ids]
valid_numerical_features = np.random.randn(num_valid_samples, num_numerical_features)
valid_ratings = np.random.randint(1, 6, size=num_valid_samples)

valid_data = pd.DataFrame({
    'user_ids': valid_user_ids,
    'item_ids': valid_item_ids,
    'item_texts': valid_item_texts,
    'numerical_features': valid_numerical_features.tolist(),
    'ratings': valid_ratings
})

# # Save the generated data to CSV files
# train_data.to_csv('train_data.csv', index=False)
# valid_data.to_csv('valid_data.csv', index=False)


In [51]:
train_data.head()

Unnamed: 0,user_ids,item_ids,item_texts,numerical_features,ratings
0,36,714,Item 714,"[-1.5870669548122613, 0.3729184445593504, 0.53...",1
1,1000,1068,Item 1068,"[0.002471331942127893, -0.05856419713042212, -...",3
2,953,1751,Item 1751,"[-1.0458181107350684, 0.8712534899759317, 0.34...",1
3,23,523,Item 523,"[-1.112180638235938, 1.3135616534328622, 0.192...",4
4,370,1500,Item 1500,"[0.39134713293964674, -1.8222047071104766, 1.7...",4


In [52]:
user_ids, item_ids, item_tokens, numerical_features, attention_mask, ratings = preprocess_batch(train_data)

In [56]:
predicted_ratings = model(user_ids, item_ids, item_tokens, numerical_features, attention_mask)

AttributeError: ignored

In [None]:





# # Load and preprocess the training and validation data
# train_data = load_data(train_data_path)  # Replace with your data loading function
# valid_data = load_data(valid_data_path)  # Replace with your data loading function

# Training loop
for epoch in range(num_epochs):
    train_loss = tf.keras.metrics.Mean()
    valid_loss = tf.keras.metrics.Mean()

    # Training
    for batch in get_batches(train_data, batch_size):
        user_ids, item_ids, item_tokens, numerical_features, attention_mask, ratings = preprocess_batch(batch)
        loss = compute_loss(user_ids, item_ids, item_tokens, numerical_features, attention_mask, ratings)
        train_loss(loss)

    # Validation
    for batch in get_batches(valid_data, batch_size):
        user_ids, item_ids, item_tokens, numerical_features, attention_mask, ratings = preprocess_batch(batch)
        predicted_ratings = model(user_ids, item_ids, item_tokens, numerical_features, attention_mask)
        loss = loss_fn(ratings, predicted_ratings)
        valid_loss(loss)

    # Print training progress
    print(f"Epoch {epoch+1}: Train Loss = {train_loss.result()}, Valid Loss = {valid_loss.result()}")

# Save the trained model
model.save_weights('cf_hybridbert4rec_weights.h5')
