In [61]:
import numpy as np
import itertools

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

In [62]:
doc = """
         Supervised learning is the machine learning task of 
         learning a function that maps an input to an output based 
         on example input-output pairs.[1] It infers a function 
         from labeled training data consisting of a set of 
         training examples.[2] In supervised learning, each 
         example is a pair consisting of an input object 
         (typically a vector) and a desired output value (also 
         called the supervisory signal). A supervised learning 
         algorithm analyzes the training data and produces an 
         inferred function, which can be used for mapping new 
         examples. An optimal scenario will allow for the algorithm 
         to correctly determine the class labels for unseen 
         instances. This requires the learning algorithm to  
         generalize from the training data to unseen situations 
         in a 'reasonable' way (see inductive bias).
      """


In [74]:
n_gram_range = (3,3) # trigrams
stop_words = "english"

# Extract candidate words/phrases
cv = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words).fit([doc])
candidates = cv.get_feature_names_out()

print(candidates)
print(len(candidates))

['algorithm analyzes training' 'algorithm correctly determine'
 'algorithm generalize training' 'allow algorithm correctly'
 'analyzes training data' 'based example input'
 'called supervisory signal' 'class labels unseen'
 'consisting input object' 'consisting set training'
 'correctly determine class' 'data consisting set'
 'data produces inferred' 'data unseen situations' 'desired output value'
 'determine class labels' 'example input output' 'example pair consisting'
 'examples optimal scenario' 'examples supervised learning'
 'function labeled training' 'function maps input' 'function used mapping'
 'generalize training data' 'inferred function used'
 'infers function labeled' 'input object typically' 'input output based'
 'input output pairs' 'instances requires learning'
 'labeled training data' 'labels unseen instances'
 'learning algorithm analyzes' 'learning algorithm generalize'
 'learning example pair' 'learning function maps'
 'learning machine learning' 'learning task lea

In [75]:
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
doc_embedding = model.encode([doc])
candidate_embeddings = model.encode(candidates)
distances = cosine_similarity(doc_embedding, candidate_embeddings)

print('document_embedding:', doc_embedding.shape)
print('candidate_embeddings:', candidate_embeddings.shape)
print('distances:', distances.shape)

document_embedding: (1, 768)
candidate_embeddings: (72, 768)
distances: (1, 72)


# 1. Basic Key BERT

In [76]:
# pick top_n words closest to the doc.
top_n = 5

keywords = [candidates[index] for index in distances.argsort()[0][-top_n:]]  # return top n keywords
print(keywords)

['algorithm analyzes training', 'learning algorithm generalize', 'learning machine learning', 'learning algorithm analyzes', 'algorithm generalize training']


# 2. Max Sum Similarity

In [66]:
candidates_num = 10

dist_doc_word = cosine_similarity(doc_embedding, candidate_embeddings) # (1, 72)
dist_word_word = cosine_similarity(candidate_embeddings, candidate_embeddings) # (72, 72)

words_idx = list(np.argsort(dist_doc_word)[0][-candidates_num:])
words = [candidates[idx] for idx in words_idx]

dist_word_word = dist_word_word[np.ix_(words_idx, words_idx)]

min_sim = np.inf
candidate = None

# find the least similar words
for combination in itertools.combinations(range(len(words_idx)), top_n):
    sim = sum([dist_word_word[i][j] for i in combination for j in combination if i != j])
    if sim < min_sim:
        candidate = combination
        min_sim = sim
        
result = [words[idx] for idx in candidate]

print(result)

['requires learning algorithm', 'signal supervised learning', 'learning function maps', 'algorithm analyzes training', 'learning machine learning']


### np.ix_
 - construct an open mesh from multiple sequences.
```python
a = np.array([[1,2,3,4,5,6,7,8,9,10],[11,12,13,14,15,16,17,18,19,20],[21,22,23,24,25,26,27,28,29,30]])
_a = a[np.ix_([0,2,1],[0,1,2])]
print(_a) 
"""
[[ 1  2  3]
 [21 22 23]
 [11 12 13]]
"""
```

### itertools.combinations
 - return r length subsequences of elements from the input iterable.
 - combinations('ABCD', 2) --> AB AC AD BC BD CD
 - combinations(range(4), 3) --> 012 013 023 123
```python
for combination in itertools.combinations(range(5), 3):
  print(combination)
"""
(0, 1, 2)
(0, 1, 3)
(0, 1, 4)
(0, 2, 3)
(0, 2, 4)
(0, 3, 4)
(1, 2, 3)
(1, 2, 4)
(1, 3, 4)
(2, 3, 4)
"""
```

# 3. Maximal Marginal Relevance

In [93]:
diversity = 0.2

dist_word_doc = cosine_similarity(candidate_embeddings, doc_embedding) # (72, 1)
dist_word_word_2 = cosine_similarity(candidate_embeddings) # (72, 72)

most_similar_keyword = [np.argmax(dist_word_doc)] # [2]

# cand_indexes of words except most_similar_keyword
cand_indexes = [i for i in range(len(candidates)) if i not in most_similar_keyword] 

for _ in range(top_n-1):
    candidate_similarities = dist_word_doc[cand_indexes]
    target_similarities = np.max(dist_word_word_2[cand_indexes][:, most_similar_keyword], axis=1)
    # print(candidate_similarities.shape)
    # print(target_similarities.reshape(-1,1).shape)
    
    mmr = (1-diversity) * candidate_similarities - diversity * target_similarities.reshape(-1,1)
    mmr_idx = cand_indexes[np.argmax(mmr)]
    
    most_similar_keyword.append(mmr_idx)
    cand_indexes.remove(mmr_idx)
    
result = [candidates[idx] for idx in most_similar_keyword]
print(result)

['algorithm generalize training', 'supervised learning algorithm', 'learning machine learning', 'learning algorithm analyzes', 'learning algorithm generalize']


### np.array
```python
a = np.array([1,2,3,4,5,6,7,8,9])
a = a.reshape(3,3)
print(a)
"""
[[1 2 3]
 [4 5 6]
 [7 8 9]]
"""
print(a[[0,1,2]][:,[2]])
"""
[[3]
 [6]
 [9]]
"""
```