# ToS Mixedbread Embedding Creation

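This notebook splits each Terms of Service document into sentences and embeds every sentence with mixedbread.ai's state-of-the-art `mxbai-embed-large-v1` sentence-embedding model, saving the results as parquet files.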

Model card: <https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1>

In [1]:
# Packages
from sentence_transformers import SentenceTransformer
# NLTK for sentence tokenization
import nltk
nltk.download('punkt')
# Torch to move to GPU
import torch
import os
import pandas as pd
import time

  from .autonotebook import tqdm as notebook_tqdm




[nltk_data] Downloading package punkt to
[nltk_data]     /accounts/grad/ijyliu/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
# Run mode switch: set to True for a quick sample run that embeds only the first
# file in `filenames`; leave as False to embed every file in the corpus.
sample_run = False

## Load Model

In [3]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the pretrained embedding model and place it on the chosen device
model = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1").to(device)
# Show the model architecture, then the device in use
print(model)
print(device)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 1024, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})
)
cuda


## Function to Get Sentences and Embeddings from Doc

In [4]:
def get_doc_sentences_embeddings(filename):
    '''
    Split one corpus document into sentences and embed each sentence.

    The document is read from the Terms of Service corpus text directory,
    cleaned, sentence-tokenized with NLTK, and every sentence is prefixed with
    the company name (the part of `filename` before the first underscore)
    before being passed to the module-level `model`.

    Returns a tuple `(sentences, embeddings)`: the list of prefixed sentences
    and whatever `model.encode` returns for that list.
    '''
    # Load the raw document text
    with open('../Text Data/Terms of Service Corpus/text/' + filename, 'r', encoding='utf-8') as handle:
        raw_text = handle.read()

    # Replace every non-ASCII character with a single space
    ascii_text = ''.join(ch if ord(ch) < 128 else ' ' for ch in raw_text)

    # Keep only lines that contain sentence-ending punctuation ('.', '!' or '?')
    # and at least one letter
    kept_lines = []
    for line in ascii_text.split('\n'):
        has_terminator = any(mark in line for mark in '.!?')
        has_letter = any(ch.isalpha() for ch in line)
        if has_terminator and has_letter:
            kept_lines.append(line)
    cleaned_text = '\n'.join(kept_lines)

    # The company name is everything before the first underscore in the filename
    prefix = filename.split('_')[0] + ' '

    # Tokenize into sentences and put the company name in front of each one
    sentences = [prefix + sentence for sentence in nltk.sent_tokenize(cleaned_text)]

    # Embed all sentences in one call
    return sentences, model.encode(sentences)
    

## Get List of Files

In [5]:
# Collect the name of every entry in the corpus text directory
corpus_text_dir = '../Text Data/Terms of Service Corpus/text'
filenames = os.listdir(corpus_text_dir)
# Report how many entries were found and preview the first five
print('length of filenames:', len(filenames))
print(filenames[:5])

length of filenames: 9491
['FinanceBuzz_DMCA.txt', 'SiteMinder_PrivacyPolicy.txt', 'Nintendo_TermsofUse.txt', 'CommonSenseMedia_TermsofUse.txt', 'TheOmniGroup_SyncServerTermsofService.txt']


## Clear Pre-existing Output Folder

In [6]:
# Make sure the output directory exists, then empty it so parquet files left
# over from an earlier run cannot mix with this run's output
embeddings_dir = '../Text Data Embeddings'
# Creating the directory up front avoids a FileNotFoundError on a first run
os.makedirs(embeddings_dir, exist_ok=True)
for entry in os.listdir(embeddings_dir):
    entry_path = os.path.join(embeddings_dir, entry)
    # os.remove raises on real directories (e.g. .ipynb_checkpoints), so skip them
    if os.path.isdir(entry_path) and not os.path.islink(entry_path):
        continue
    os.remove(entry_path)

## Encode Documents and Create Parquet Files

In [7]:
def create_parquet(filename, output_dir='../Text Data Embeddings', rows_per_file=1000):
    '''
    Embed one document and save its sentences and embeddings as parquet.

    The output has one row per sentence with the columns `sentence`,
    `embed_element_0` ... `embed_element_{d-1}` (one per embedding dimension)
    and `filename`.

    Parameters
    ----------
    filename : str
        Name of the document inside the corpus text directory.
    output_dir : str
        Directory the parquet file(s) are written to.
    rows_per_file : int
        Maximum number of rows per parquet file. A document with more rows
        is split across `<stem>_0.parquet`, `<stem>_1.parquet`, ...;
        otherwise a single `<stem>.parquet` is written.

    `<stem>` is `filename` without its final extension only. Cutting at the
    first dot instead would map names such as 'Example.com_Terms.txt' and
    'Example.com_Privacy.txt' to the same output file, so one document would
    silently overwrite the other. Code that reads these files back must derive
    the name the same way (`os.path.splitext(filename)[0]`).
    '''
    # Get sentences and embeddings
    sentences, embeddings = get_doc_sentences_embeddings(filename)

    # One column per embedding dimension, with the sentence text in front
    embed_columns = [f'embed_element_{i}' for i in range(embeddings.shape[1])]
    df = pd.DataFrame(embeddings, columns=embed_columns)
    df.insert(0, 'sentence', sentences)
    # Record the source document on every row
    df['filename'] = filename

    # Strip only the final extension so dots inside the name are preserved
    stem = os.path.splitext(filename)[0]

    if len(df) > rows_per_file:
        # Large document: write consecutive chunks of at most rows_per_file rows
        for shard, start in enumerate(range(0, len(df), rows_per_file)):
            shard_path = os.path.join(output_dir, f'{stem}_{shard}.parquet')
            df.iloc[start:start + rows_per_file].to_parquet(shard_path)
    else:
        # Small document: a single parquet file
        df.to_parquet(os.path.join(output_dir, f'{stem}.parquet'))

In [8]:
# Sample run: process only the first file. Full run: process every file and
# record the ones that fail so they can be inspected or retried later.
if sample_run:
    create_parquet(filenames[0])
    # Optionally also try a specific larger document:
    #create_parquet('Honeywell_CookieNotice.txt')
else:
    errored_files = []
    for filename in filenames:
        try:
            create_parquet(filename)
        except Exception as error:
            # Catch Exception rather than using a bare except, so interrupting
            # the kernel (KeyboardInterrupt) stops the run instead of being
            # recorded as one more failed file
            print(f'Error processing {filename}: {error!r}')
            errored_files.append(filename)
    # save to disk if errored files
    if len(errored_files) > 0:
        with open('errored_files.txt', 'w') as f:
            for file in errored_files:
                f.write(file + '\n')
        print(f'{len(errored_files)} file(s) failed; names written to errored_files.txt')
    else:
        print('No errors in creating parquet files.')

number of lines in doc: 34


Time to get embeddings: 0.8928356170654297
Time to save parquet: 0.12570738792419434


In [9]:
# Read back the parquet written for the first file as a sanity check
first_stem = filenames[0].split('.')[0]
test_df = pd.read_parquet('../Text Data Embeddings/' + first_stem + '.parquet')
test_df.head()

Unnamed: 0,sentence,embed_element_0,embed_element_1,embed_element_2,embed_element_3,embed_element_4,embed_element_5,embed_element_6,embed_element_7,embed_element_8,...,embed_element_1015,embed_element_1016,embed_element_1017,embed_element_1018,embed_element_1019,embed_element_1020,embed_element_1021,embed_element_1022,embed_element_1023,filename
0,FinanceBuzz THIS NOTICE IS SUBJECT TO MODIFICA...,-0.129559,-0.358322,-0.432857,0.330654,-0.167946,-0.095894,0.059403,-0.297004,0.261421,...,-0.638149,0.160865,0.050397,0.275447,-0.004016,0.165371,0.1377,0.019452,0.134917,FinanceBuzz_DMCA.txt
1,FinanceBuzz YOU MUST CHECK BACK FREQUENTLY TO ...,0.178183,-0.409948,0.018674,0.550421,-0.163982,-0.06401,0.706781,-0.518907,0.975128,...,-0.026903,-0.362144,-0.642567,0.396593,0.658277,0.790721,0.183814,0.522085,0.370805,FinanceBuzz_DMCA.txt
2,FinanceBuzz It is our policy to respond to not...,-0.84362,-0.220504,-0.119484,0.311968,0.344784,0.439013,-0.044417,-0.821258,0.706464,...,-0.4253,-0.053282,0.580609,0.412872,0.322933,-0.00396,-0.15019,-0.197769,0.598518,FinanceBuzz_DMCA.txt
3,FinanceBuzz Responses may include removing or ...,-0.15474,-0.253713,0.231295,0.663021,0.446653,-0.131465,-0.257986,-0.186931,0.421412,...,-0.285569,-0.028328,0.438995,0.572426,0.340228,-0.412949,-0.271771,0.087636,0.514101,FinanceBuzz_DMCA.txt
4,FinanceBuzz If we remove or disable access in ...,-0.735254,-0.152386,-0.368491,0.824974,0.163802,0.147596,0.312946,-0.046971,0.347785,...,-0.592601,-0.139058,-0.04362,0.325281,0.482764,0.119737,0.14685,-0.459043,0.118706,FinanceBuzz_DMCA.txt


In [10]:
# Show every sentence stored for the document, one per line
for sentence_text in test_df['sentence'].tolist():
    print(sentence_text)

FinanceBuzz THIS NOTICE IS SUBJECT TO MODIFICATION OR TERMINATION AT ANY TIME, WHETHER FOR CHANGES IN THE LAW OR AT THE CONVENIENCE OF BUZZERY, LLC AND ITS AFFILIATES ("FINANCEBUZZ.COM") WITHOUT ADVANCE NOTICE.
FinanceBuzz YOU MUST CHECK BACK FREQUENTLY TO ENSURE THAT YOU SEE A CORRECT, CURRENT VERSION OF THE NOTICE.
FinanceBuzz It is our policy to respond to notices of alleged infringement that comply with the Digital Millennium Copyright Act and other applicable intellectual property laws.
FinanceBuzz Responses may include removing or disabling access to material claimed to be the subject of infringing activity and/or terminating subscribers.
FinanceBuzz If we remove or disable access in response to such a notice, we will make a good-faith attempt to contact the owner or administrator of the affected site or content so that they may make a counter notification pursuant to sections 512(g)(2) and (3) of that Act.
FinanceBuzz It is our policy to document all notices of alleged infringem

In [11]:
# Load Honeywell file to check
#test_df = pd.read_parquet('../Text Data Embeddings/Honeywell_CookieNotice.parquet')
#test_df.head()