In [None]:
from sentence_transformers import SentenceTransformer

In [None]:
import json
import sys
import torch
import time

In [None]:
# Run all encoders on the first GPU when CUDA is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Number of documents gathered before running the encoders and writing them out.
batch_size = 1_000

In [None]:
def _load_encoder(model_name, max_seq_length=512):
    """Load a SentenceTransformer onto `device` and set its max_seq_length."""
    encoder = SentenceTransformer(model_name, device=device)
    encoder.max_seq_length = max_seq_length
    return encoder


# Two STS-tuned English encoders plus a multilingual distilled-USE encoder.
mpnet = _load_encoder('stsb-mpnet-base-v2')
distilroberta = _load_encoder('stsb-distilroberta-base-v2')
usenc = _load_encoder('distiluse-base-multilingual-cased-v1')

In [None]:
# Load the index: one JSON document per line (JSON Lines).
# JSON text is UTF-8 by specification, so decode explicitly instead of relying
# on the platform's locale encoding (which is not UTF-8 on e.g. Windows).
# Blank lines carry no document and would make json.loads raise, so skip them.
documents = []
filename = 'top50kaaa_test_index'
with open(filename, encoding='utf-8') as f:
    for line in f:
        if line.strip():
            documents.append(json.loads(line))

In [None]:
# Sanity check: how many documents were loaded from the index file.
num_documents = len(documents)
num_documents

In [None]:
# -*- coding: utf-8 -*-
from __future__ import absolute_import
from __future__ import division, print_function, unicode_literals

from sumy.parsers.html import HtmlParser
from sumy.parsers.plaintext import PlaintextParser
from sumy.nlp.tokenizers import Tokenizer
from sumy.summarizers.lex_rank import LexRankSummarizer as Summarizer
from sumy.nlp.stemmers import Stemmer
from sumy.utils import get_stop_words

# Summarization settings: text language for tokenizing/stemming/stop words,
# and how many sentences the LexRank extract keeps per document.
LANGUAGE, SENTENCES_COUNT = "english", 7

# Fetch NLTK's Punkt sentence-tokenizer data; presumably required by sumy's
# Tokenizer for sentence splitting — confirm against the installed sumy version.
import nltk
nltk.download('punkt')

In [None]:
# Build a LexRank extract of SENTENCES_COUNT sentences from each document's
# `content_t` and store it as `_source['extract']`.
start = time.perf_counter()

# Tokenizer, stemmer, summarizer and stop-word list do not depend on the
# document, so construct them once instead of once per document.
tokenizer = Tokenizer(LANGUAGE)
stemmer = Stemmer(LANGUAGE)
summarizer = Summarizer(stemmer)
summarizer.stop_words = get_stop_words(LANGUAGE)

for doc in documents:
    parser = PlaintextParser.from_string(doc['_source']['content_t'], tokenizer)
    sentences = summarizer(parser.document, SENTENCES_COUNT)
    # Keep the established format: every selected sentence followed by one space.
    doc['_source']['extract'] = "".join(str(sentence) + " " for sentence in sentences)

print(time.perf_counter() - start)

In [None]:
def _embed_batch(batch):
    """Attach all embedding fields to each document of ``batch``, in place.

    Each document must have ``doc['_source']['extract']`` and
    ``doc['_source']['content_t']``. The extract is encoded with mpnet and
    distilroberta; the full content with mpnet, the multilingual USE model
    (``usenc``) and distilroberta. Fields are inserted in a fixed order so the
    serialized JSON key order is stable.
    """
    extracts = [doc['_source']['extract'] for doc in batch]
    contents = [doc['_source']['content_t'] for doc in batch]
    embeddings_by_field = {
        'mpnet_embedding_extract': mpnet.encode(extracts, device=device),
        'mpnet_embedding': mpnet.encode(contents, device=device),
        'use_embedding': usenc.encode(contents, device=device),
        'distilroberta_embedding_extract': distilroberta.encode(extracts, device=device),
        'distilroberta_embedding': distilroberta.encode(contents, device=device),
    }
    for idx, doc in enumerate(batch):
        for field, embeddings in embeddings_by_field.items():
            # Convert each row to a plain Python list so json.dumps can serialize it.
            doc['_source'][field] = embeddings[idx].tolist()


# Embed every document in chunks of `batch_size` and write them as JSON Lines.
start = time.perf_counter()
with open(filename + '_with_embeddings', 'w', encoding='utf-8') as fout:
    # The final slice may be shorter than batch_size; it must still be embedded
    # and written, otherwise the trailing documents are silently lost.
    for begin in range(0, len(documents), batch_size):
        batch = documents[begin:begin + batch_size]
        _embed_batch(batch)
        for doc in batch:
            fout.write(json.dumps(doc) + '\n')

print(time.perf_counter() - start)