In [1]:
import copy
from pandas import Series, DataFrame
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer

In [18]:
class TopicExtractor:
    """Score documents against topics using idf-weighted n-gram counts.

    A TF-IDF vectorizer is fitted on the training documents. Every n-gram of
    its vocabulary is then tagged with each key of ``topicDictionary`` whose
    value contains at least one of the n-gram's space-separated words.

    Attributes set by ``__init__``:
        tfidfVectorizer: the fitted ``TfidfVectorizer``.
        ngramTopics: DataFrame with columns ``ngram``, ``idf`` and ``topic``;
            one row per vocabulary n-gram, ``topic`` holding a list of keys.
        ngramsByTopic: DataFrame with a single ``idf`` column, indexed by
            ``(topic, ngram)`` and sorted by topic, then by descending idf.
    """

    def __init__(self, trainingDocuments, topicDictionary):
        """Fit the vectorizer and build the n-gram/topic lookup tables.

        Args:
            trainingDocuments: iterable of raw text documents used to fit the
                TF-IDF vectorizer.
            topicDictionary: mapping of topic key -> container of words. An
                n-gram is assigned to a topic when any of its words is ``in``
                that container.
        """
        # generate idf weights for the n-grams of the training documents
        self.tfidfVectorizer = TfidfVectorizer(
            token_pattern="[a-z']{2,}",
            stop_words='english',
            ngram_range=(1, 2),
            min_df=0.01,
            max_df=0.99,
        )
        self.tfidfVectorizer.fit(trainingDocuments)

        # get_feature_names() was removed in scikit-learn 1.2; prefer its
        # replacement and fall back to the old name on older releases.
        featureNames = getattr(self.tfidfVectorizer,
                               'get_feature_names_out',
                               None)
        if featureNames is None:
            featureNames = self.tfidfVectorizer.get_feature_names
        ngrams = Series(list(featureNames()))

        def topicsOf(ngram):
            # adapted from Sven Marnach's post at
            # https://stackoverflow.com/questions/8122079/python-how-to-check-a-string-for-substrings-from-a-list
            words = ngram.split(' ')
            return [key for key in topicDictionary.keys()
                    if any(word in topicDictionary[key] for word in words)]

        # store n-grams with associated weights and topics in a dataframe
        self.ngramTopics = DataFrame({'ngram': ngrams,
                                      'idf': self.tfidfVectorizer.idf_,
                                      'topic': ngrams.apply(topicsOf)})

        # organize n-grams by topic (split n-grams with multiple topics into
        # separate rows). Built with a plain comprehension rather than
        # DataFrame.apply with an appending lambda: apply is not meant for
        # side effects, and older pandas versions may call the function twice
        # on the first row, which would duplicate entries.
        newRows = [[ngram, idf, topic]
                   for ngram, idf, topics in zip(self.ngramTopics['ngram'],
                                                 self.ngramTopics['idf'],
                                                 self.ngramTopics['topic'])
                   for topic in topics]

        self.ngramsByTopic = (DataFrame(newRows,
                                        columns=['ngram', 'idf', 'topic'])
                              .sort_values(['topic', 'idf'],
                                           ascending=[True, False])
                              .set_index(['topic', 'ngram'])
                             )

    def getTfIdfScore(self, documents, topic):
        """Return the total, topic-specific tf-idf score of ``documents``.

        The score is the sum, over every n-gram assigned to ``topic``, of the
        n-gram's idf weight times its total count across ``documents``.

        Args:
            documents: iterable of raw text documents to score.
            topic: a topic key present in ``ngramsByTopic``.

        Returns:
            The score as a float.

        Raises:
            KeyError: if no n-gram was assigned to ``topic``.
        """
        idfs = self.ngramsByTopic.loc[topic]

        # Count with the same tokenization the vocabulary was built with.
        # With CountVectorizer's defaults (unigrams only, different token
        # pattern, no stop-word removal) the bigrams in the vocabulary could
        # never match and would silently contribute zero to the score.
        trained = self.tfidfVectorizer
        countVectorizer = CountVectorizer(
            vocabulary=idfs.index.values.tolist(),
            token_pattern=trained.token_pattern,
            stop_words=trained.stop_words,
            ngram_range=trained.ngram_range,
            lowercase=trained.lowercase,
        )
        counts = countVectorizer.fit_transform(documents)

        # column order of `counts` follows the vocabulary list, i.e. the
        # row order of `idfs`, so the element-wise product lines up
        totals = counts.toarray().sum(axis=0)
        return float((idfs['idf'] * totals).sum())