# Import libraries

In [1]:
import sys

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
import transformers
from transformers import RobertaTokenizerFast, RobertaModel


print('Python:'.ljust(16), sys.version.split('\n')[0])
print('Transformers:'.ljust(16), transformers.__version__)

Python:          3.11.5 (main, Aug 24 2023, 15:09:45) [Clang 14.0.3 (clang-1403.0.22.14.1)]
Transformers:    4.33.2


# Define constants

In [2]:
# Paths
DATA_PATH = './suggestion-data/'

# Pretrained model name (checkpoint)
MODEL_NAME = 'roberta-base'
MODEL_INTERNAL_DIM = 768

# Number of suggestions
SUGGESTION_NUM = 5

# Load tokenizer and model

In [3]:
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_NAME)
model = RobertaModel.from_pretrained(MODEL_NAME)

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Read terms ans sample text

In [4]:
# Read terms
with open(DATA_PATH + 'terms.csv') as f:
    terms = f.read()

# Preprocess terms
terms = terms.lower().split('\n')
terms = [' ' + term for term in terms]

# Read sample text
with open(DATA_PATH + 'sample-text.txt') as f:
    sample = f.read()
    
# Print terms and sample text
print(terms)
print()
print(sample)

[' optimal performance', ' utilise resources', ' enhance productivity', ' conduct an analysis', ' maintain a high standard', ' implement best practices', ' ensure compliance', ' streamline operations', ' foster innovation', ' drive growth', ' leverage synergies', ' demonstrate leadership', ' exercise due diligence', ' maximize stakeholder value', ' prioritise tasks', ' facilitate collaboration', ' monitor performance metrics', ' execute strategies', ' gauge effectiveness', ' champion change']

In today's meeting, we discussed a variety of issues affecting our department. The weather was unusually sunny, a pleasant backdrop to our serious discussions. We came to the consensus that we need to do better in terms of performance. Sally brought doughnuts, which lightened the mood. It's important to make good use of what we have at our disposal. During the coffee break, we talked about the upcoming company picnic. We should aim to be more efficient and look for ways to be more creative in our

# Get embeddings of the terms

In [5]:
def get_hidden_state(text):
    # Get offset mapping
    model_input = tokenizer(
        text,
        return_tensors='pt',
        return_special_tokens_mask=True,
        return_offsets_mapping=True,
    )
    offset_mapping = model_input['offset_mapping'].detach().numpy().squeeze()

    # Get last hidden state
    model_output = model(
        input_ids=model_input['input_ids'],
        attention_mask=model_input['attention_mask'],
    )
    last_hidden_state = model_output['last_hidden_state'].detach().numpy().squeeze()
    
    # Mask to get no special tokens
    mask = model_input['special_tokens_mask'].detach().numpy().squeeze() == 0
    last_hidden_state = last_hidden_state[mask, :]
    offset_mapping = offset_mapping[mask, :]

    return last_hidden_state, offset_mapping

In [6]:
# Compute mean pooling embedding
term_embedding = np.empty((len(terms), MODEL_INTERNAL_DIM))
for i, term in enumerate(terms):
    last_hidden_state, _ = get_hidden_state(term)
    term_embedding[i] = np.sum(last_hidden_state, axis=0) / last_hidden_state.shape[0]

# Print shape
print(term_embedding.shape)

(20, 768)


# Get embeddings of the sample text

In [7]:
# get last_hidden_state and offset_mapping
last_hidden_state, offset_mapping = get_hidden_state(sample)

# Compute number of contexts (array length)
length = 0
for i in range(last_hidden_state.shape[0]):
    for j in range(i, i+8):
        if j >= last_hidden_state.shape[0]:
            continue
        length += 1

# Compute mean pooling embeddings for all contexts in sample text
context_embedding = np.empty((length, MODEL_INTERNAL_DIM))
context_mapping = np.empty((length, 2), int)
index = 0
for i in range(last_hidden_state.shape[0]):
    for j in range(i, i+8):
        if j < last_hidden_state.shape[0]:
            slice = last_hidden_state[i:j+1, :]
            context_embedding[index] = np.sum(slice, axis=0) / slice.shape[0]
            context_mapping[index] = np.array((offset_mapping[i, 0], offset_mapping[j, 1]))
            index += 1

# Print shapes
print(context_embedding.shape)
print(context_mapping.shape)

(1372, 768)
(1372, 2)


# Get suggestions

In [8]:
# Compute similarity and sort
if context_embedding.shape[0]:
    similarity = cosine_similarity(context_embedding, term_embedding)
else:
    similarity = np.array([[]])
flat_indices = np.flip(np.argsort(similarity, axis=None))

# Get suggestions
spans = []
original_phrases = []
replacements = []
scores = []
for index in flat_indices:
    row = index // len(terms)
    col = index % len(terms)

    # Get current suggestion in the order
    text_span = context_mapping[row].tolist()
    original_phrase = sample[text_span[0]:text_span[1]]
    replacement = terms[col][1:]
    score = similarity[row, col].item()

    # Check if it is a new span
    new_span = True
    for span in spans:
        if ((span[0] <= text_span[0] < span[1])
            or (span[0] < text_span[1] <= span[1])
            or (text_span[0] < span[0] and span[1] < text_span[1])):
            new_span = False

    # Add suggestion if it is a new span
    if new_span:
        spans.append(text_span)
        original_phrases.append(original_phrase)
        replacements.append(replacement)
        scores.append(score)
        
        if len(spans) >= SUGGESTION_NUM:
            break

# Print result
for i in range(len(spans)):
    print(spans[i])
    print(original_phrases[i])
    print(replacements[i])
    print(scores[i])
    print()

[53, 78]
affecting our department.
monitor performance metrics
0.9237443330959721

[139, 159]
serious discussions.
monitor performance metrics
0.9220517883443872

[406, 421]
company picnic.
monitor performance metrics
0.9208585375958583

[281, 291]
mood. It's
monitor performance metrics
0.9111545349126645

[220, 241]
of performance. Sally
monitor performance metrics
0.909594951782594

