In [1]:
from transformers import AutoTokenizer, AutoModel

In [2]:
from pyspark.sql.functions import *
from pyspark.sql.types import *
import requests
import json

In [31]:
import os
from getpass import getpass

# OpenSearch connection settings. The endpoint and user name can be overridden
# through environment variables; the password is never stored in the notebook.
opensearch_url = os.environ.get(
    "OPENSEARCH_URL",
    "https://search-vector-db-lyyqoujvse7m6t32vwxjb2ui3i.us-east-1.es.amazonaws.com",
)
opensearch_user_name = os.environ.get("OPENSEARCH_USER", "OSMasterUser")
# Take the password from OPENSEARCH_PASSWORD, otherwise prompt for it interactively.
# NOTE(review): a password was previously hard-coded here and remains in this
# notebook's history/outputs — rotate it.
opensearch_password = os.environ.get("OPENSEARCH_PASSWORD") or getpass("OpenSearch password: ")

In [4]:
# Amazon PQA computer-cases JSON file on S3 (schema printed in the next cell).
datasource = "s3://amazon-pqa/amazon_pqa_computer_cases.json"
# Number of rows kept for this demo; set to None to process the whole file.
sample_rows = 100

df_source = spark.read.format("json").load(datasource)
df_raw = df_source if sample_rows is None else df_source.limit(sample_rows)

In [5]:
# Show the schema Spark inferred from the JSON; `answers` is an array of structs.
df_raw.printSchema()

root
 |-- answer_aggregated: string (nullable = true)
 |-- answers: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- answer_text: string (nullable = true)
 |-- asin: string (nullable = true)
 |-- brand_name: string (nullable = true)
 |-- bullet_point1: string (nullable = true)
 |-- bullet_point2: string (nullable = true)
 |-- bullet_point3: string (nullable = true)
 |-- bullet_point4: string (nullable = true)
 |-- bullet_point5: string (nullable = true)
 |-- item_name: string (nullable = true)
 |-- product_description: string (nullable = true)
 |-- question_id: string (nullable = true)
 |-- question_text: string (nullable = true)
 |-- question_type: string (nullable = true)



In [10]:
# Peek at the text of the first answer for two questions.
df_raw.select(df_raw.answers[0].answer_text).limit(2).show(truncate=False)

+------------------------------------------------------------------------+
|answers[0].answer_text                                                  |
+------------------------------------------------------------------------+
|yes it is.                                                              |
|It doesn't come with any. But it can hold up to four 80mm fans i believe|
+------------------------------------------------------------------------+



In [11]:
# Preview each question next to its first answer; the same projection is assigned to `df` below.
df_raw.select(col("question_id"),col("question_text"),df_raw.answers[0].answer_text.alias("answer_text")).limit(2).show(truncate=False)

+---------------+--------------------------------------------------------------------------------------------+------------------------------------------------------------------------+
|question_id    |question_text                                                                               |answer_text                                                             |
+---------------+--------------------------------------------------------------------------------------------+------------------------------------------------------------------------+
|Tx350WMH6IV09J8|Clarification: reviews make me uncertain about this, the ml03b is micro atx correct? thanks.|yes it is.                                                              |
|TxPFERD1HWC26S |How many fans does this case comes with                                                     |It doesn't come with any. But it can hold up to four 80mm fans i believe|
+---------------+---------------------------------------------------------------

In [12]:
# Flatten to one row per question, keeping only the text of its first answer.
df = df_raw.select(
    "question_id",
    "question_text",
    col("answers")[0]["answer_text"].alias("answer_text"),
)

In [13]:
# Sanity-check the flattened frame.
df.limit(2).show(truncate=False)

+---------------+--------------------------------------------------------------------------------------------+------------------------------------------------------------------------+
|question_id    |question_text                                                                               |answer_text                                                             |
+---------------+--------------------------------------------------------------------------------------------+------------------------------------------------------------------------+
|Tx350WMH6IV09J8|Clarification: reviews make me uncertain about this, the ml03b is micro atx correct? thanks.|yes it is.                                                              |
|TxPFERD1HWC26S |How many fans does this case comes with                                                     |It doesn't come with any. But it can hold up to four 80mm fans i believe|
+---------------+---------------------------------------------------------------

In [14]:
# Confirm the flattened schema: question_id, question_text, answer_text.
df.printSchema()

root
 |-- question_id: string (nullable = true)
 |-- question_text: string (nullable = true)
 |-- answer_text: string (nullable = true)



In [15]:
# Row count, bounded by the limit applied when df_raw was loaded.
df.count()

100

In [16]:
import numpy as np

def embed_func(df):
    """Embed each row's ``question_text`` with BERT, using the [CLS] hidden state.

    Used as the ``applyInPandas`` function below, so it is called once per group
    with a pandas DataFrame holding that group's rows. The signature must stay
    single-argument: ``applyInPandas`` passes ``(key, data)`` to two-argument
    functions.

    Args:
        df: pandas DataFrame with ``question_id``, ``question_text`` and
            ``answer_text`` columns. NOTE(review): a null ``question_text``
            would make the tokenizer fail — confirm the data has none.

    Returns:
        A pandas DataFrame with those three columns plus ``embedding`` (one
        numpy vector per row), in the same row order as ``df``. Only the
        question text is embedded; ``answer_text`` is passed through unchanged.
    """
    output_columns = ["question_id", "question_text", "answer_text"]
    texts = df["question_text"].tolist()

    # Nothing to embed: skip loading the model (and avoid concatenating zero batches).
    if not texts:
        return df[output_columns].assign(embedding=[])

    import torch  # local import: only needed when there is text to embed

    # Loaded on every call, i.e. once per group on the executor.
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased")

    batch_size = 20
    vectors = []
    # Inference only: no_grad avoids building an autograd graph for every batch.
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(batch, padding=True, truncation=True, return_tensors="pt", max_length=512)
            result = model(**inputs)
            # Position 0 of each sequence is the [CLS] token; use its final hidden state.
            cls_vectors = result.last_hidden_state[:, 0, :].cpu().detach().numpy()
            vectors.extend(cls_vectors)

    return df[output_columns].assign(embedding=vectors)

In [17]:
from pyspark.sql.types import StructType, StructField, StringType, ArrayType, FloatType
from pyspark.sql.functions import spark_partition_id

# Spark schema of the pandas frame returned by embed_func.
embed_schema = StructType(
    [
        StructField("question_id", StringType(), True),
        StructField("question_text", StringType(), True),
        StructField("answer_text", StringType(), True),
        StructField("embedding", ArrayType(FloatType()), True)
    ]
)

# Run embed_func once per group, grouping rows by the Spark partition they live in.
embed_df = (
    df
    .groupBy(spark_partition_id().alias("_pid"))
    .applyInPandas(embed_func, embed_schema)
    # BERT inference is expensive and embed_df feeds several later actions
    # (show, count, the rename and the upload map); cache it so those actions
    # reuse the computed embeddings instead of re-running the model each time.
    .cache()
)

embed_df.show()

+---------------+--------------------+--------------------+--------------------+
|    question_id|       question_text|         answer_text|           embedding|
+---------------+--------------------+--------------------+--------------------+
|Tx2RE2SOAZVEVY8|How do I resolve ...|Did you load the ...|[-0.26896772, -0....|
|Tx30YAFDRDBOHN9|Burning smell whe...|Jip. Just check t...|[0.061282303, 0.1...|
| TxLX89NHS1LJMS|Has anyone else h...|Nope. Works as it...|[-0.10376623, 0.0...|
|Tx2SEERIC2V38C6|Dose it have inst...|The instructions ...|[-0.24544035, 0.0...|
|Tx3UQWJPNPV9313|What size is the ...|Hm, I seem to rem...|[-0.33658928, 0.2...|
|Tx1M2ZC4N8J5BPN|In the product pi...|I don’t recall ho...|[-0.2180458, 0.21...|
| TxLN2W3PTWJLG4|I see the fan att...|It is one speed o...|[-0.15246768, 0.1...|
| TxGK7X8RU8A7GH|What style SD car...|small slot that a...|[-0.30430722, -0....|
|Tx3SI9GG93GT7SC|What style SD car...|             MicroSD|[-0.30430722, -0....|
|Tx1CVUCIT72Y1HN|Is this an 

In [18]:
# The embedding column comes back as array<float>, as declared in embed_schema.
embed_df.printSchema()

root
 |-- question_id: string (nullable = true)
 |-- question_text: string (nullable = true)
 |-- answer_text: string (nullable = true)
 |-- embedding: array (nullable = true)
 |    |-- element: float (containsNull = true)



In [19]:
# embed_func returns one row per input row, so this should match df.count().
embed_df.count()

100

In [20]:
# Rename question_id to id; the other columns keep their names and order.
embed_df = embed_df.withColumnRenamed("question_id", "id")

In [21]:
# Preview the renamed frame that is uploaded below.
embed_df.show(3)

+---------------+--------------------+--------------------+--------------------+
|             id|       question_text|         answer_text|           embedding|
+---------------+--------------------+--------------------+--------------------+
|Tx2RE2SOAZVEVY8|How do I resolve ...|Did you load the ...|[-0.26896772, -0....|
|Tx30YAFDRDBOHN9|Burning smell whe...|Jip. Just check t...|[0.061282303, 0.1...|
| TxLX89NHS1LJMS|Has anyone else h...|Nope. Works as it...|[-0.10376623, 0.0...|
+---------------+--------------------+--------------------+--------------------+
only showing top 3 rows



In [28]:
def loados(x, timeout=30):
    """Index one embedded question/answer row into the ``vector-index`` OpenSearch index.

    Args:
        x: a Spark ``Row`` (or any object with attribute access) carrying
            ``question_text``, ``answer_text`` and ``embedding``, and optionally
            ``id``.
        timeout: seconds ``requests`` waits for OpenSearch before giving up;
            without a timeout a stalled connection would hang the Spark task.

    Returns:
        The ``requests.Response`` from OpenSearch; callers inspect its status
        (e.g. ``201`` when a document is created).
    """
    upload_document_request_body = {
        "answer": x.answer_text,
        "embedded_vector": x.embedding,
        "question": x.question_text,
    }
    request_kwargs = dict(
        auth=(opensearch_user_name, opensearch_password),
        headers={'Content-type': 'application/json'},
        data=json.dumps(upload_document_request_body),
        timeout=timeout,
    )

    # When the row has an id, use it as the document _id so re-evaluating the
    # (lazy) upload overwrites the document instead of creating a duplicate.
    doc_id = getattr(x, "id", None)
    if doc_id is None:
        return requests.post(opensearch_url + '/vector-index/_doc', **request_kwargs)

    from urllib.parse import quote
    return requests.put(
        opensearch_url + '/vector-index/_doc/' + quote(str(doc_id), safe=""),
        **request_kwargs,
    )

In [32]:
# Lazily map every embedded row to an OpenSearch upload; nothing is sent until an action runs.
rdd_test = embed_df.rdd.map(loados)

In [33]:
# NOTE(review): RDD transformations are lazy, so this action only issues the
# uploads Spark needs to produce 3 results (likely 3 documents, not all rows),
# and every re-run of this cell sends those requests again.
rdd_test.take(3)

[<Response [201]>, <Response [201]>, <Response [201]>]