In [9]:
import requests
import os

# All seven books live in the same GitHub folder; only the (already
# percent-encoded) file name differs from one URL to the next.
_BOOKS_BASE_URL = "https://raw.githubusercontent.com/formcept/whiteboard/master/nbviewer/notebooks/data/harrypotter/"
_BOOK_FILES = (
    "Book%201%20-%20The%20Philosopher's%20Stone.txt",
    "Book%202%20-%20The%20Chamber%20of%20Secrets.txt",
    "Book%203%20-%20The%20Prisoner%20of%20Azkaban.txt",
    "Book%204%20-%20The%20Goblet%20of%20Fire.txt",
    "Book%205%20-%20The%20Order%20of%20the%20Phoenix.txt",
    "Book%206%20-%20The%20Half%20Blood%20Prince.txt",
    "Book%207%20-%20The%20Deathly%20Hallows.txt",
)

# Download URLs, in series order.
books = [_BOOKS_BASE_URL + file_name for file_name in _BOOK_FILES]

# Create the output directory; exist_ok avoids the separate
# check-then-create step.
os.makedirs('books', exist_ok=True)

# Download every book to books/book_<index>.txt.
for i, book in enumerate(books):
    # Fetch before opening the file so a failed request does not leave an
    # empty or truncated book on disk; the timeout keeps the cell from
    # hanging forever on a stalled connection.
    response = requests.get(book, timeout=60)
    # Fail loudly instead of saving an HTTP error page as if it were a book.
    response.raise_for_status()
    # NOTE(review): the file is written with the platform default encoding,
    # the same default the later cells use to read it back.
    with open(f'books/book_{i}.txt', 'w+') as book_file:
        book_file.write(response.text)



In [10]:
import re
import json
import glob
from collections import Counter


def tokenize_book(text):
    """Normalise raw book text and split it into lowercase word tokens.

    Steps: lowercase, drop every line that starts with "page", collapse
    whitespace, strip a trailing "'s" from words, remove all remaining
    non-alphanumeric characters, and split on whitespace.

    Args:
        text: Full text of one book.

    Returns:
        List of non-empty word tokens.
    """
    text = text.lower()
    # Drop lines that start with "page".
    text = re.sub(r'^[Pp]age.*?$', '', text, flags=re.MULTILINE)
    text = re.sub(r'\s+', ' ', text)
    # Treat the typographic apostrophe like the ASCII one.
    text = text.replace('’', "'")
    # Strip "'s" only at the end of a word. A plain str.replace also ate the
    # start of words that follow an apostrophe ("'sorry" -> "orry").
    text = re.sub(r"'s\b", '', text)

    # remove all non-alphanumeric characters
    text = re.sub(r"[^a-zA-Z0-9\s]", '', text)

    # split() without an argument skips the empty tokens that split(' ')
    # produced; those were always discarded by the length filter below.
    return text.split()


# Word frequencies accumulated over all downloaded books.
words_counter = Counter()

for book in sorted(glob.glob('books/*.txt')):
    with open(book, 'r') as book_file:
        words_counter.update(tokenize_book(book_file.read()))

# Keep words longer than one character that occur more than 10 times and do
# not end in 's'. Each entry is [{}, word, count].
# NOTE(review): the leading empty dict is presumably a placeholder for the
# consumer of hp-words.json — confirm.
# NOTE(review): dropping every word ending in 's' also removes non-plurals
# such as "was" or "his" — confirm this is intended.
words = {
    word: [{}, word, count]
    for word, count in words_counter.items()
    if len(word) > 1 and count > 10 and not word.endswith('s')
}

with open('hp-words.json', 'w') as words_file:
    json.dump(words, words_file)


def split_paragraphs(text):
    """Split one book into lowercase, single-line paragraphs.

    Lines starting with "page" are dropped, the text is split on blank
    lines, and chunks of 50 characters or fewer are discarded.

    Args:
        text: Full text of one book.

    Returns:
        List of paragraphs with every run of whitespace (including line
        breaks) collapsed to a single space.
    """
    text = text.lower()
    text = re.sub(r'^[Pp]age.*?$', '', text, flags=re.MULTILINE)
    chunks = re.split(r'\n\s*\n', text)
    # Join wrapped lines with a space: deleting the newline outright glues
    # the last word of one line to the first word of the next whenever the
    # line does not already end in a space.
    return [' '.join(chunk.split()) for chunk in chunks if len(chunk) > 50]


# Retrieval corpus: paragraphs of all books. The files are visited in sorted
# order (as in the word-count loop) so the corpus order does not depend on
# the order in which the filesystem lists them.
paragraphs = []
for book in sorted(glob.glob('books/*.txt')):
    with open(book, 'r') as book_file:
        paragraphs.extend(split_paragraphs(book_file.read()))


In [11]:
from sentence_transformers import SentenceTransformer, util

# Load the sentence encoder and place it on the GPU in one step.
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1').to('cuda')

In [12]:
# Embed every paragraph once up front. do_autocomplete later maps each search
# hit's corpus_id back into `paragraphs`, so the two must stay aligned.
corpus_embeddings = model.encode(paragraphs, show_progress_bar=True)

Batches:   0%|          | 0/1051 [00:00<?, ?it/s]

In [13]:
# Move the encoder to the CPU; do_autocomplete encodes queries with it from
# here on.
# NOTE(review): presumably done to free GPU memory for the BART model — confirm.
model = model.to('cpu')


In [99]:
from transformers import BartTokenizerFast, BartForConditionalGeneration
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.models.bart.modeling_bart import shift_tokens_right
import torch
import datetime
import transformers
# Only surface error-level messages from the transformers library.
transformers.logging.set_verbosity_error()

# Local checkpoint shared by the tokenizer and the model.
# NOTE(review): the path starts with '~' and is passed on unexpanded —
# confirm it resolves to the intended directory.
BART_CHECKPOINT = '~/models/bart-fine-tuned-msmarco-with-context-1.65'

bart_tokenizer = BartTokenizerFast.from_pretrained(BART_CHECKPOINT)
bart_model = BartForConditionalGeneration.from_pretrained(BART_CHECKPOINT)
bart_model = bart_model.cuda()
# bart_tokenizer = T5Tokenizer.from_pretrained('t5-small')
# bart_model = T5ForConditionalGeneration.from_pretrained('~/models/t5-autocomplete').cpu()
# from optimum.bettertransformer import BetterTransformer
# bart_model = BetterTransformer.transform(bart_model, keep_original_model=False)

# Every token id in the tokenizer vocabulary.
all_tokens = [token_id for token_id in bart_tokenizer.get_vocab().values()]

def do_autocomplete(query, top_k=20):
    """Suggest completions for a partially typed query, using book paragraphs as context.

    The query is embedded with the sentence encoder, the ``top_k`` closest
    paragraphs are retrieved from ``corpus_embeddings``, and the BART model
    generates from the prompt ``<query><mask>#<p1>; <p2>; ...``.

    Args:
        query: Text typed so far.
        top_k: Number of paragraphs to retrieve as context (default 20).

    Returns:
        dict with keys
        ``results`` (list of generated completions),
        ``knn`` (list of retrieved paragraphs) and
        ``time`` (elapsed seconds as a float).
        A blank query yields the same shape with empty lists, so callers
        (including the web page, which reads ``json.results``) never have to
        special-case a bare list.
    """
    started = datetime.datetime.now()
    if not query or not query.strip():
        return {'results': [], 'knn': [], 'time': 0.0}

    encoded_query = model.encode([query])
    hits = util.semantic_search(encoded_query, corpus_embeddings, top_k=top_k)[0]
    knn = [paragraphs[hit['corpus_id']] for hit in hits]
    prompt = f'{query}<mask>#{"; ".join(knn)}'

    # Truncate so the prompt never exceeds 1024 tokens.
    tokenized = bart_tokenizer(prompt, return_tensors='pt', max_length=1024, truncation=True)
    # Send the inputs to whichever device the model lives on instead of
    # assuming CUDA.
    device = bart_model.device
    # NOTE(review): early_stopping and length_penalty are beam-search
    # settings; with num_beams=1 they presumably have no effect — confirm.
    output = bart_model.generate(
        input_ids=tokenized['input_ids'].to(device),
        attention_mask=tokenized['attention_mask'].to(device),
        early_stopping=True,
        num_beams=1,
        num_return_sequences=1,
        length_penalty=4.0,
    )
    results = [bart_tokenizer.decode(o, skip_special_tokens=True) for o in output]
    elapsed = (datetime.datetime.now() - started).total_seconds()
    return {'results': results, 'knn': knn, 'time': elapsed}


In [103]:
# Smoke test: run one query through the pipeline and show what comes back.
query = 'why did Dudley'
suggestions = do_autocomplete(query)

for label, value in (('query: ', query), ('suggestions: ', suggestions)):
    print(label, value)


query:  why did Dudley
suggestions:  {'results': ['why did Dudley get a new job'], 'knn': ['he looked so dangerous with half his mustache missing that no one dared argue. ten minutes later they had wrenched their way through the boarded-up doors and were in the car, speeding toward the highway. dudley was sniffling in the back seat; his father had hit him round the head for holding them up while he tried to pack his television, vcr, and computer in his sports bag. ', 'on the last day of august he thought he’d better speak to his aunt and uncle about getting to king’s cross station the next day, so he went down to the living room where they were watching a quiz show on television. he cleared his throat to let them know he was there, and dudley screamed and ran from the room. ', 'this boy was another good reason for keeping the potters away; they didn’t want dudley mixing with a child like that. ', 'but dudley either could not or would not move. he was still on the ground, trembling and 

In [66]:
# Single-page front end served at "/": a text box that queries /autocomplete
# (debounced) and shows the completions, the retrieved paragraphs and the
# server-side time.
index = '''
<!DOCTYPE html>
<html>
<head>
    <meta charset='utf-8'>
    <meta http-equiv='X-UA-Compatible' content='IE=edge'>
    <title>Page Title</title>
    <meta name='viewport' content='width=device-width, initial-scale=1'>
    <script>
        let timeout = null;
        // Bumped on every keystroke so that responses to older requests can be ignored.
        let latestRequest = 0;

        // Show the given lines separated by line breaks. Each line is added as a
        // text node, so generated text and book text are never parsed as HTML.
        function render(lines) {
            const target = document.getElementById('completion');
            target.textContent = '';
            lines.forEach((line, i) => {
                if (i > 0) {
                    target.appendChild(document.createElement('br'));
                }
                target.appendChild(document.createTextNode(line));
            });
        }

        function complete() {
            const requestId = ++latestRequest;
            render([]);
            if (timeout !== null) {
                clearTimeout(timeout);
            }
            // Debounce: only ask the server after 300 ms without a keystroke.
            timeout = setTimeout(async () => {
                let query = document.getElementById('query').value;
                if (query.length < 5) {
                    return;
                }
                let result = await fetch('/autocomplete?query=' + encodeURIComponent(query));
                if (!result.ok) {
                    return;
                }
                let json = await result.json();
                // A newer keystroke arrived while waiting: drop this stale answer.
                if (requestId !== latestRequest) {
                    return;
                }
                render([...json.results, '', ...json.knn, '', ' time: ' + json.time]);
            }, 300);
        }
    </script>
</head>
<body>
    <input type="text" id="query" onkeyup="complete()" placeholder="Search..." title="Type in a name">
    <p id="completion"></p>
</body>
</html>
'''

In [67]:
from fastapi import FastAPI
from fastapi.responses import HTMLResponse

app = FastAPI()


@app.get("/")
async def root():
    """Serve the single-page search UI."""
    return HTMLResponse(index)


@app.get('/autocomplete')
def autocomp(query: str):
    """Return completions for ``query`` as JSON.

    Declared as a plain ``def`` on purpose: do_autocomplete runs blocking
    model inference, and FastAPI executes non-async endpoints in a worker
    thread, so the event loop stays free to serve other requests.
    """
    return do_autocomplete(query)

In [68]:
import asyncio
import uvicorn

if __name__ == "__main__":
    config = uvicorn.Config(app, host='0.0.0.0', port=3000)
    server = uvicorn.Server(config)
    await server.serve()

INFO:     Started server process [1281]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:3000 (Press CTRL+C to quit)


INFO:     147.235.213.96:3668 - "GET /autocomplete?query=can%20hermione%20tim HTTP/1.1" 200 OK
INFO:     147.235.213.96:3680 - "GET /autocomplete?query=can%20hermione%20time HTTP/1.1" 200 OK
INFO:     147.235.213.96:3680 - "GET /autocomplete?query=can%20hermione%20 HTTP/1.1" 200 OK
INFO:     147.235.213.96:3680 - "GET /autocomplete?query=how%20do%20 HTTP/1.1" 200 OK
INFO:     147.235.213.96:3680 - "GET /autocomplete?query=how%20do%20you%20destroy HTTP/1.1" 200 OK
INFO:     147.235.213.96:3680 - "GET /autocomplete?query=how%20do%20you%20destroy HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [1281]
