In [2]:
from sentence_transformers import SentenceTransformer
import torch
# from Pooling import Pooling
from transformers import AutoTokenizer
import onnxruntime as ort
import pandas as pd
from tqdm import tqdm
import time
import psutil
import memory_profiler as mem_profile
from pprint import pprint
import tracemalloc

# Column label of the text field in the header-less CSV: main() loads the file
# with header=None and reads row[TEXT_COL] as the string to embed.
TEXT_COL = 5
# Sample input string. NOTE(review): not referenced by any cell visible in this
# notebook export -- presumably left over from single-sentence experiments.
sentence = "@switchfoot http://twitpic.com/2y1zl - Awww, that's a bummer.  You shoulda got David Carr of Third Day to do it. ;D"

In [3]:
def print_mem_usage():
    """Return the first memory reading from memory_profiler as ``"<x.xx> MB"``.

    Despite its name this helper prints nothing; it only builds and returns
    the formatted string, leaving it to the caller to display it.

    Returns:
        str: The first value reported by ``mem_profile.memory_usage()``,
        rendered with two decimals and an ``" MB"`` suffix.
    """
    readings = mem_profile.memory_usage()
    first_reading = readings[0]
    return "{:.2f} MB".format(first_reading)

In [4]:
def mean_pooling(tokens, inputs):
    """Average token embeddings along dim 1, ignoring masked-out positions.

    Args:
        tokens: Tensor of per-token embeddings. The mask is unsqueezed on its
            last axis and expanded to ``tokens.size()``, and the reduction runs
            over dim 1, so a ``(batch, seq_len, hidden)`` layout is expected.
        inputs: Mapping that provides an ``'attention_mask'`` entry (NumPy
            array, nested list or tensor) with one value per token; positions
            whose mask value is 0 do not contribute to the average.

    Returns:
        list: A one-element list holding the pooled tensor (``tokens`` with
        dim 1 reduced away). The list wrapper is kept for existing callers,
        which index the result as ``result[0]``.
    """
    # as_tensor accepts arrays and tensors alike without an extra copy;
    # torch.tensor() emits a UserWarning when handed an existing tensor.
    # Placing the mask on tokens' device keeps the product below valid when
    # the embeddings do not live on the CPU.
    attention_mask = torch.as_tensor(inputs['attention_mask'], device=tokens.device)
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(tokens.size()).to(tokens.dtype)
    )

    # Masked sum over the sequence axis divided by the number of real tokens.
    sum_embeddings = torch.sum(tokens * input_mask_expanded, 1)
    # Clamp so a fully masked row divides by a tiny value instead of zero.
    sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9)

    return [sum_embeddings / sum_mask]

In [20]:
def main(
    data_path="/home/testuser/repositories/hugot/text_data.csv",
    onnx_model_path="model.onnx",
    max_rows=5000,
    num_iters=1,
):
    """Embed the text column of a CSV with an ONNX model and report timings.

    Loads the tokenizer, the ONNX Runtime session and the first ``max_rows``
    rows of the header-less CSV, then runs ``num_iters`` timed passes that
    tokenize, infer and mean-pool every row one at a time. Startup time,
    per-pass durations and their average are printed as a metrics dict.

    Args:
        data_path: CSV file to read (parsed with ``header=None``). The default
            is the original author's machine-specific location; pass your own.
        onnx_model_path: Path of the ONNX model handed to ``InferenceSession``.
        max_rows: Number of leading rows of the CSV to embed.
        num_iters: Number of timed passes over the data; must be at least 1.

    Returns:
        list[list[float]]: One embedding per successfully processed row from
        the final pass. Rows whose inference or pooling raised are reported
        and skipped, so the result can be shorter than the input.

    Raises:
        ValueError: If ``num_iters`` is smaller than 1.
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")

    # initialize model and load data
    # perf_counter() is a monotonic, high-resolution clock meant for intervals.
    initialize_start = time.perf_counter()

    tokenizer = AutoTokenizer.from_pretrained("KnightsAnalytics/all-MiniLM-L6-v2")
    session = ort.InferenceSession(onnx_model_path)
    data = pd.read_csv(data_path, header=None)
    data = data[:max_rows]
    # Pull the text column out once so the timed loop does not pay for
    # building a pandas row object on every iteration.
    texts = data[TEXT_COL].tolist()
    initialize_time = time.perf_counter() - initialize_start

    # The model's first three inputs are fed positionally as input_ids,
    # attention_mask and token_type_ids; resolve their names once here rather
    # than calling session.get_inputs() three times per row.
    ids_name, mask_name, type_ids_name = [
        node.name for node in session.get_inputs()[:3]
    ]

    time_per_iter = []
    print('Starting loop')

    text_embeddings = []

    # data loop
    for i in range(num_iters):
        # Start every pass from an empty list: repeated passes exist only for
        # timing and must not return duplicated embeddings.
        text_embeddings = []
        skipped = 0
        start_time = time.perf_counter()

        for text in tqdm(texts):
            inputs = tokenizer(text, return_tensors="np")
            onnx_inputs = {
                ids_name: inputs['input_ids'],
                mask_name: inputs['attention_mask'],
                type_ids_name: inputs['token_type_ids'],
            }
            try:
                outputs = session.run(None, onnx_inputs)
                tokens = torch.tensor(outputs[0])
                # Pool with the tokenizer's own mask so this does not depend
                # on what the ONNX graph happens to call its inputs.
                sentence_embedding = mean_pooling(tokens, inputs)
                # mean_pooling returns [tensor]; [0][0] selects the vector of
                # the single sentence in the batch.
                text_embeddings.append(sentence_embedding[0][0].tolist())
            except Exception as e:
                # Best effort: report the failing row and keep going.
                print(f"Error: {e}")
                skipped += 1
                continue

        iter_duration = time.perf_counter() - start_time
        time_per_iter.append(iter_duration)
        print(f"Iteration {i+1} took {iter_duration:.2f} seconds")
        if skipped:
            print(f"Skipped {skipped} rows because of errors")

    metrics = {
        'startup time': initialize_time,
        'time per iteration': time_per_iter,
        'average runtime': sum(time_per_iter) / num_iters,
    }

    print("Metrics:")
    print(metrics)
    # Kept from the original: changes how tensors print for the rest of the
    # kernel session (global torch setting).
    torch.set_printoptions(precision=6, sci_mode=False)

    return text_embeddings







In [19]:
###
# Run baseline code
###
# NOTE: tracemalloc reports memory blocks allocated by Python; buffers that
# native libraries allocate on their own are presumably not included -- verify
# before comparing this figure with process-level memory numbers.
tracemalloc.start()
try:
    main()
    # get_traced_memory() returns (current, peak); index 1 is the peak in bytes.
    peak_bytes = tracemalloc.get_traced_memory()[1]
    print(f"iteration used {peak_bytes / 1024 / 1024} MB of memory")
finally:
    # Always stop tracing, even if main() raises, so later cells do not run
    # with allocation tracing still enabled.
    tracemalloc.stop()

Starting loop


100%|██████████| 1000/1000 [00:07<00:00, 137.62it/s]


Iteration 1 took 7.27 seconds


100%|██████████| 1000/1000 [00:08<00:00, 121.60it/s]


Iteration 2 took 8.23 seconds


100%|██████████| 1000/1000 [00:07<00:00, 130.83it/s]


Iteration 3 took 7.65 seconds
Metrics:
{'startup time': 8.713757514953613, 'time per iteration': [7.273473501205444, 8.230931282043457, 7.652797222137451], 'average runtime': 7.719067335128784}
iteration used 464.18289852142334 MB of memory


In [140]:
import csv

filename = "output.csv"

# Run the full embedding pipeline and keep the resulting vectors.
text_embeddings = main()

# Dump the embeddings to disk, one vector per CSV line.
with open(filename, 'w', newline='') as csvfile:
    csv.writer(csvfile).writerows(text_embeddings)

    

Starting loop


100%|██████████| 100/100 [00:01<00:00, 73.13it/s]


Iteration 1 took 1.37 seconds
Metrics:
{'startup time': 3.891197919845581, 'time per iteration': [1.37239670753479], 'average runtime': 1.37239670753479}
