In [None]:
import numpy as np
import tensorflow_addons as tfa
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Dropout, BatchNormalization
from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization, GlobalAveragePooling1D
from tensorflow.keras.layers import Reshape
from tensorflow.keras.optimizers.legacy import Adam
import tensorflow as tf

In [None]:

# PAWS paraphrase-detection corpus, "labeled_final" configuration.
dataset = load_dataset(path="paws", name="labeled_final")

# Sentence encoder shared by the feature-extraction helpers below.
embedding_model = SentenceTransformer(
    model_name_or_path='sentence-transformers/all-mpnet-base-v2'
)

def get_sentence_embedding(sentence):
    """Encode text with the module-level ``embedding_model`` as NumPy data.

    Args:
        sentence: Text to encode; passed straight through to
            ``embedding_model.encode``.

    Returns:
        The encoder output as a NumPy array.
    """
    # Request NumPy output from the encoder itself. Asking for a torch tensor
    # and then calling ``.numpy()`` on it fails whenever the model lives on a
    # GPU (or other accelerator), because such tensors must be moved to host
    # memory before they can be converted.
    return embedding_model.encode(sentence, convert_to_numpy=True)

def create_features(data):
    """Build paired-sentence feature vectors and labels for a dataset split.

    Every item's ``sentence1`` and ``sentence2`` are encoded with the
    module-level ``embedding_model`` and the two embeddings are joined end to
    end (sentence1 first) to form one feature row.

    Args:
        data: Iterable of mapping-like items exposing the keys
            ``'sentence1'``, ``'sentence2'`` and ``'label'``.

    Returns:
        Tuple ``(embeddings, labels)`` of NumPy arrays with one row / entry
        per item, in iteration order. Both arrays are empty when ``data``
        yields no items.
    """
    first_sentences, second_sentences, labels = [], [], []
    for item in data:
        first_sentences.append(item['sentence1'])
        second_sentences.append(item['sentence2'])
        labels.append(item['label'])

    if not labels:
        # Nothing to encode; return the same empty arrays an empty loop
        # would have produced.
        return np.array([]), np.array([])

    # Encode each column in a single batched call instead of one call per
    # sentence, so the encoder can batch its forward passes. The built-in
    # progress bar replaces printing every item id, which flooded the
    # notebook output.
    first_embeddings = embedding_model.encode(
        first_sentences, convert_to_numpy=True, show_progress_bar=True
    )
    second_embeddings = embedding_model.encode(
        second_sentences, convert_to_numpy=True, show_progress_bar=True
    )

    # Row i is [embedding(sentence1_i) | embedding(sentence2_i)].
    embeddings = np.concatenate([first_embeddings, second_embeddings], axis=1)
    return embeddings, np.array(labels)



In [None]:
def create_combined_model(bert_dim, num_layers=2, num_heads=8,
                          learning_rate=0.0005, pair_as_sequence=False):
    """Build and compile the binary classifier over concatenated embeddings.

    The network takes a flat vector of size ``bert_dim * 2``, runs it through
    ``num_layers`` attention + feed-forward blocks (each followed by a
    residual connection and layer normalisation), average-pools over the
    sequence axis and finishes with a 512-256-128 dense head and a sigmoid
    output.

    Args:
        bert_dim: Dimension of a single sentence embedding; the model input
            has ``bert_dim * 2`` features.
        num_layers: Number of attention/feed-forward blocks (default 2).
        num_heads: Attention heads per block (default 8); each head uses
            ``key_dim = bert_dim // num_heads``.
        learning_rate: Adam learning rate (default 0.0005).
        pair_as_sequence: When False (default), the input is reshaped to a
            sequence of length 1 with ``bert_dim * 2`` features, as before.
            With a single position the attention softmax has only one key to
            weight, so no information is mixed across positions. When True,
            the input is reshaped to ``(2, bert_dim)`` so each half of the
            input becomes its own token and the two halves can attend to
            each other.

    Returns:
        A compiled Keras ``Model`` using binary cross-entropy loss and an
        accuracy metric.

    Raises:
        ValueError: If ``num_heads`` is not positive or exceeds ``bert_dim``
            (which would give a zero-sized attention head).
    """
    if num_heads < 1 or bert_dim // num_heads < 1:
        raise ValueError(
            f"num_heads must be in [1, bert_dim]; got num_heads={num_heads}, "
            f"bert_dim={bert_dim}"
        )

    seq_len = 2 if pair_as_sequence else 1
    token_dim = (bert_dim * 2) // seq_len

    # Input for concatenated embeddings
    input_embedding = Input(shape=(bert_dim * 2,), dtype='float32')

    # Add the sequence axis expected by the attention layers. Row-major
    # reshape means that with seq_len == 2 the first half of the input
    # becomes token 0 and the second half token 1.
    sequence = Reshape((seq_len, token_dim))(input_embedding)

    for _ in range(num_layers):
        attention_output = MultiHeadAttention(
            num_heads=num_heads,
            key_dim=bert_dim // num_heads,
        )(sequence, sequence)

        # Residual connection + normalisation around the attention sub-layer.
        attention_output = LayerNormalization(epsilon=1e-6)(sequence + attention_output)

        # Position-wise feed-forward sub-layer; the last Dense projects back
        # to token_dim so the residual addition below is shape-compatible.
        ff_output = Dense(bert_dim * 4, activation='relu')(attention_output)
        ff_output = Dropout(0.2)(ff_output)
        ff_output = Dense(token_dim)(ff_output)

        sequence = LayerNormalization(epsilon=1e-6)(attention_output + ff_output)

    # Global average pooling collapses the sequence axis.
    pooled_output = GlobalAveragePooling1D()(sequence)

    # Classification head: three Dense -> BatchNorm -> Dropout stages.
    x = pooled_output
    for units in (512, 256, 128):
        x = Dense(units, activation='relu')(x)
        x = BatchNormalization()(x)
        x = Dropout(0.4)(x)
    output_layer = Dense(1, activation='sigmoid')(x)

    model = Model(inputs=input_embedding, outputs=output_layer)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    return model


In [None]:
# Work on shard 0 of 20 from every split to keep feature extraction short.
train_data, val_data, test_data = [
    dataset[split_name].shard(num_shards=20, index=0)
    for split_name in ('train', 'validation', 'test')
]

# Turn each sharded split into (embeddings, labels) arrays.
(
    (train_embeddings, train_labels),
    (val_embeddings, val_labels),
    (test_embeddings, test_labels),
) = [create_features(split) for split in (train_data, val_data, test_data)]

In [None]:
# Define and train the model
# Seed the Python, NumPy and TensorFlow RNGs so weight initialisation,
# dropout masks and batch shuffling are repeatable across runs.
SEED = 42
tf.keras.utils.set_random_seed(SEED)

bert_dim = embedding_model.get_sentence_embedding_dimension()
combined_model = create_combined_model(bert_dim)

# Use the validation split for model selection: stop once validation loss
# has not improved for 10 epochs and restore the best weights, so the test
# score below is not taken from whatever state the final epoch ended in.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

combined_model.fit(
    train_embeddings,
    train_labels,
    epochs=100,  # upper bound; early stopping may end training sooner
    batch_size=32,
    validation_data=(val_embeddings, val_labels),
    callbacks=[early_stopping],
)

# Evaluate the model
test_loss, test_accuracy = combined_model.evaluate(test_embeddings, test_labels)
print(f'Test Accuracy: {test_accuracy}')