diff --git a/.env b/.env
index 2510758678..63d677513f 100644
--- a/.env
+++ b/.env
@@ -29,3 +29,5 @@ WIKIDATA_DIALOGUE_SERVICE_URL=http://wikidata-dial-service:8092/model
 NEWS_API_ANNOTATOR_URL=http://news-api-annotator:8112/respond
 WIKI_FACTS_URL=http://wiki-facts:8116/respond
 FACT_RANDOM_SERVICE_URL=http://fact-random:8119/respond
+INFILLING_SERVICE_URL=http://infilling:8122/respond
+
diff --git a/README.md b/README.md
index 9447d3592c..d4e5ca2d2a 100644
--- a/README.md
+++ b/README.md
@@ -187,6 +187,12 @@ Dream Architecture is presented in the following image:
 | User Persona Extractor | 40 MiB RAM  | determines which age category the user belongs to based on some key words |
 | Wiki parser            | 100 MiB RAM | extracts Wikidata triplets for the entities detected with Entity Linking  |
 
+## Services
+| Name      | Requirements           | Description                                                                                                                                                                                                          |
+|-----------|------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| DialoGPT  | 1.3 GiB RAM, 1 GiB GPU | generative service based on a Transformers generative model; the model is set in the docker-compose argument `PRETRAINED_MODEL_NAME_OR_PATH` (for example, `microsoft/DialoGPT-small`, 0.2-0.5 sec per response on GPU) |
+| Infilling | 1.7 GiB RAM, 1 GiB GPU | generative service based on an Infilling model; for the given utterance, returns the utterance with each `_` in the original text replaced by generated tokens                                                        |
+
 ## Skills
 | Name                      | Requirements            | Description                                                                                                                                                                                                          |
 |---------------------------|-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
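For a quick sanity check of the new service, a request can be sent directly to its `/respond` endpoint. This is a minimal sketch based on the payload and response format defined in `services/infilling/server.py` and `services/infilling/test.py` below; the host assumes the container's port 8122 is published locally (inside the compose network the URL is `INFILLING_SERVICE_URL` from `.env`).

```python
import requests

# One "_" per blank to fill in; the service replaces each blank with generated tokens.
texts = ["Chris was bad at _.", "Chris was _ so he could not come."]

response = requests.post(
    "http://0.0.0.0:8122/respond",
    json={"texts": texts},
    timeout=5,
)
print(response.json()["infilled_text"])  # e.g. ["Chris was bad at math.", ...] (output varies)
```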
diff --git a/assistant_dists/dream/cpu.yml b/assistant_dists/dream/cpu.yml
index 1d5e8e6916..eaf049edb1 100644
--- a/assistant_dists/dream/cpu.yml
+++ b/assistant_dists/dream/cpu.yml
@@ -49,3 +49,6 @@ services:
   dialogpt:
     environment:
       CUDA_VISIBLE_DEVICES: ""
+  infilling:
+    environment:
+      CUDA_VISIBLE_DEVICES: ""
diff --git a/assistant_dists/dream/dev.yml b/assistant_dists/dream/dev.yml
index 6c1b9a311d..00ab64e0fa 100755
--- a/assistant_dists/dream/dev.yml
+++ b/assistant_dists/dream/dev.yml
@@ -399,6 +399,11 @@ services:
       - "./services/dialogpt:/src"
     ports:
       - 8125:8125
+  infilling:
+    volumes:
+      - "./services/infilling:/src"
+    ports:
+      - 8122:8122
   dff-template-skill:
     volumes:
       - "./skills/dff_template_skill:/src"
diff --git a/assistant_dists/dream/docker-compose.override.yml b/assistant_dists/dream/docker-compose.override.yml
index caa542b492..8a559d7388 100644
--- a/assistant_dists/dream/docker-compose.override.yml
+++ b/assistant_dists/dream/docker-compose.override.yml
@@ -19,7 +19,7 @@ services:
           dff-funfact-skill:8104, dff-bot-persona-skill:8105, news-api-annotator:8112,
           dff-gossip-skill:8109, dff-wiki-skill:8111, dff-gaming-skill:8115, topic-recommendation:8113,
           user-persona-extractor:8114, wiki-facts:8116, dff-music-skill:8099, entity-detection:8103, dff-art-skill:8117,
-          midas-predictor:8121, dialogpt:8125, dff-template-skill:8120"
+          midas-predictor:8121, dialogpt:8125, infilling:8122, dff-template-skill:8120"
         WAIT_HOSTS_TIMEOUT: ${WAIT_TIMEOUT:-480}
   convers-evaluator-annotator:
     env_file: [.env]
@@ -1144,7 +1144,7 @@ services:
           memory: 50M
         reservations:
           memory: 50M
-
+
   dialogpt:
     env_file: [ .env ]
    build:
@@ -1164,6 +1164,23 @@ services:
        reservations:
          memory: 2G
 
+  infilling:
+    env_file: [ .env ]
+    build:
+      context: ./services/infilling/
+      args:
+        SERVICE_PORT: 8122
+    command: flask run -h 0.0.0.0 -p 8122
+    environment:
+      - CUDA_VISIBLE_DEVICES=0
+      - FLASK_APP=server
+    deploy:
+      resources:
+        limits:
+          memory: 2.5G # ?
+        reservations:
+          memory: 2.5G # ?
+
   dff-template-skill:
     env_file: [.env]
     build:
diff --git a/assistant_dists/dream/gpu1.yml b/assistant_dists/dream/gpu1.yml
index d9e545e504..82c63a5a75 100644
--- a/assistant_dists/dream/gpu1.yml
+++ b/assistant_dists/dream/gpu1.yml
@@ -189,6 +189,10 @@ services:
     restart: unless-stopped
     environment:
       - CUDA_VISIBLE_DEVICES=9
+  infilling:
+    restart: unless-stopped
+    environment:
+      - CUDA_VISIBLE_DEVICES=7
   dff-template-skill:
     restart: unless-stopped
 version: '3.7'
diff --git a/assistant_dists/dream/proxy.yml b/assistant_dists/dream/proxy.yml
index 2f57601fa5..b654ee4839 100644
--- a/assistant_dists/dream/proxy.yml
+++ b/assistant_dists/dream/proxy.yml
@@ -593,7 +593,16 @@ services:
     environment:
       - PROXY_PASS=dream.deeppavlov.ai:8125
       - PORT=8125
-
+
+  infilling:
+    command: [ "nginx", "-g", "daemon off;" ]
+    build:
+      context: dp/proxy/
+      dockerfile: Dockerfile
+    environment:
+      - PROXY_PASS=dream.deeppavlov.ai:8122
+      - PORT=8122
+
   dff-template-skill:
     command: [ "nginx", "-g", "daemon off;" ]
     build:
diff --git a/assistant_dists/dream/test.yml b/assistant_dists/dream/test.yml
index f503619a6e..72fdef5c46 100644
--- a/assistant_dists/dream/test.yml
+++ b/assistant_dists/dream/test.yml
@@ -130,5 +130,8 @@ services:
   dialogpt:
     environment:
       - CUDA_VISIBLE_DEVICES=6
+  infilling:
+    environment:
+      - CUDA_VISIBLE_DEVICES=8
   dff-template-skill:
 version: '3.7'
diff --git a/common/infilling.py b/common/infilling.py
new file mode 100644
index 0000000000..300dce8aa8
--- /dev/null
+++ b/common/infilling.py
@@ -0,0 +1,10 @@
+import os
+import requests
+
+
+INFILLING_SERVICE_URL = os.getenv("INFILLING_SERVICE_URL", "http://0.0.0.0:8122/respond")
+
+
+def infill_texts(texts, timeout=1):
+    result = requests.post(INFILLING_SERVICE_URL, json={"texts": texts}, timeout=timeout).json()["infilled_text"]
+    return result
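The helper above mirrors the other `common/*.py` service clients; a skill would typically call it with a short timeout and fall back to the original text on failure. A minimal sketch (the fallback logic is an illustration, not part of this PR):

```python
from common.infilling import infill_texts

templates = ["I like _ very much.", "My favorite food is _."]
try:
    # One request for the whole batch; the service fills every "_" in each template.
    filled = infill_texts(templates, timeout=1)
except Exception:
    # On timeout or connection errors, keep the unfilled templates.
    filled = templates
print(filled)
```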
diff --git a/services/infilling/Dockerfile b/services/infilling/Dockerfile
new file mode 100644
index 0000000000..8762e20c9d
--- /dev/null
+++ b/services/infilling/Dockerfile
@@ -0,0 +1,34 @@
+# syntax=docker/dockerfile:experimental
+
+FROM pytorch/pytorch:1.5-cuda10.1-cudnn7-runtime
+
+RUN apt-get update && apt-get install -y --allow-unauthenticated wget
+
+WORKDIR /src
+
+ARG MODEL_DIR=/data/
+ENV MODEL_DIR ${MODEL_DIR}
+ARG SERVICE_PORT
+ENV SERVICE_PORT ${SERVICE_PORT}
+
+COPY ./requirements.txt /src/requirements.txt
+RUN pip install -r /src/requirements.txt
+
+COPY . /src
+
+RUN mkdir /data/
+RUN ls
+RUN python download_model.py model sto ilm | bash
+WORKDIR /data/
+
+RUN wget http://files.deeppavlov.ai/dream/infilling/additional_ids_to_tokens.pkl
+RUN wget http://files.deeppavlov.ai/dream/infilling/vocab.bpe
+RUN wget http://files.deeppavlov.ai/dream/infilling/encoder.json
+RUN wget http://files.deeppavlov.ai/dream/infilling/config.json
+
+WORKDIR /src
+
+HEALTHCHECK --interval=5s --timeout=90s --retries=3 CMD curl --fail 127.0.0.1:${SERVICE_PORT}/healthcheck || exit 1
+
+
+CMD gunicorn --workers=1 server:app -b 0.0.0.0:${SERVICE_PORT} --timeout=300
diff --git a/services/infilling/README.md b/services/infilling/README.md
new file mode 100644
index 0000000000..7a8e77316b
--- /dev/null
+++ b/services/infilling/README.md
@@ -0,0 +1,5 @@
+GPU RAM = 1 GiB
+CPU time = 0.5-2 sec
+GPU time = 0.1-0.5 sec
+
+Inference time is very unstable.
\ No newline at end of file
diff --git a/services/infilling/constants.py b/services/infilling/constants.py
new file mode 100644
index 0000000000..f8195d4812
--- /dev/null
+++ b/services/infilling/constants.py
@@ -0,0 +1,3 @@
+GPT2_MODEL_NAMES = ["gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"]
+GPT2_TOKENIZER_LEN = 50257
+GPT2_EOS_TOKEN_ID = 50256
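The image above expects five artifacts in `MODEL_DIR` before the server starts: the ILM checkpoint fetched via `download_model.py` (`pytorch_model.bin`, `config.json`, `additional_ids_to_tokens.pkl`) and the GPT-2 BPE files fetched with `wget` (`encoder.json`, `vocab.bpe`). A small sanity check along these lines (a hypothetical helper, not part of the PR) makes missing downloads fail fast:

```python
import os

REQUIRED_FILES = [
    "pytorch_model.bin",             # ILM checkpoint downloaded by download_model.py
    "config.json",                   # model config (also re-downloaded from files.deeppavlov.ai)
    "additional_ids_to_tokens.pkl",  # special infill tokens added on top of the GPT-2 vocab
    "encoder.json",                  # GPT-2 BPE vocabulary
    "vocab.bpe",                     # GPT-2 BPE merges
]


def check_model_dir(model_dir=os.environ.get("MODEL_DIR", "/data/")):
    missing = [f for f in REQUIRED_FILES if not os.path.isfile(os.path.join(model_dir, f))]
    if missing:
        raise FileNotFoundError(f"Missing model artifacts in {model_dir}: {missing}")
```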
"https://drive.google.com/open?id=1PdVg-TnG5VQt8GCQOQA841AGw1GR44yl", + "lyr_word": "https://drive.google.com/open?id=1Td-yr6g5cTxW4yoz_Wv4gSi-wbu1376R", + }, +} + +PRETRAINED_MODELS = { + # Trained on stories + "sto_lm": "https://drive.google.com/open?id=1-FGKu-bodqOsCGrFCYY6Yyp2rTk2rRpc", + "sto_lmrev": "https://drive.google.com/open?id=1_uCgugc57tPGfFofKbU8doJN23cf4lEY", + "sto_lmall": "https://drive.google.com/open?id=1dPOLkggPbe-Pzn8VVkcrinuGJv2yRieR", + "sto_ilm": "https://drive.google.com/open?id=1oYFLxkX6mWbmpEwQH8BmgE7iKix2W0zU", + "sto_lmscratch": "https://drive.google.com/open?id=1vGxdfZUWtOB5ajpDgSGUXuHK5_BGY9GA", + "sto_lmrevscratch": "https://drive.google.com/open?id=1xbyQ5bMJpTxlsPtL1YsH2jmUUh_49gOI", + "sto_lmallscratch": "https://drive.google.com/open?id=1Qy13Dw60Jd5HqN89q8WvCMtwvTXJw7tj", + "sto_ilmscratch": "https://drive.google.com/open?id=14BFLWSaPi2JSsKsa68lcTSnCOnYV9jPm", + # Trained on abstracts + "abs_lm": "https://drive.google.com/open?id=1BSIFfuSTznmHIKa4R-AnwIxN93b1Ap-b", + "abs_lmrev": "https://drive.google.com/open?id=1yl36oZq9R_d3IhlFWLlMGq46n8F9Lq1q", + "abs_lmall": "https://drive.google.com/open?id=1qyM0OCL8pI5dL7sfAag-y9X_bnlTS_1Z", + "abs_ilm": "https://drive.google.com/open?id=1FBY9DR60WWX05orILaFHuyZYlB4ChTpS", + "abs_lmscratch": "https://drive.google.com/open?id=103Cw2ZSb5g5PlTKslmbmhqCaxn3N65OO", + "abs_lmrevscratch": "https://drive.google.com/open?id=1HeuxA2A6iEs6SW26jlCom3x_tFQHnIGu", + "abs_lmallscratch": "https://drive.google.com/open?id=1XU61GMduqJeCzYqDk8BQ7S4M8tbzqF9g", + "abs_ilmscratch": "https://drive.google.com/open?id=1ZTZOO5fVTlnPBw7EC_4OOEzHmcs6tAFO", + # Trained on lyrics + "lyr_lm": "https://drive.google.com/open?id=1FJBgz26lZPcanZTEf0iWxZCXIEM6esu6", + "lyr_lmrev": "https://drive.google.com/open?id=1XAug1jhm7sa5lksDV6GMyF8sFQLwk1Y6", + "lyr_lmall": "https://drive.google.com/open?id=1nrNkd4cBsdZS0eajA3wD1i5b6t6R6bow", + "lyr_ilm": "https://drive.google.com/open?id=1nYuYCS5fDP2_vB7A92guk0PWh5CC2I5x", + "lyr_lmscratch": "https://drive.google.com/open?id=1JzDRUSWVeyGnNaWKVYM8t1BPAs58t6uB", + "lyr_lmrevscratch": "https://drive.google.com/open?id=1Kkli5Brmc3D6qE0b5ww5daZdZroaN1YB", + "lyr_lmallscratch": "https://drive.google.com/open?id=18JYIBOtDfnksZPl4TW9cOzjOh_qDBCJP", + "lyr_ilmscratch": "https://drive.google.com/open?id=1RObPpSttNtMw4UQ1bGiVzEM-94QqkwHT", +} + +PRETRAINED_MODEL_CONFIG_JSON = "https://drive.google.com/open?id=15JnXi7L6LeEB2fq4dFK2WRvDKyX46hVi" +PRETRAINED_SPECIAL_VOCAB_PKL = "https://drive.google.com/open?id=1nTQVe2tfkWV8dumbrLIHzMgPwpLIbYUd" + +PAPER_TASK_TO_INTERNAL = { + "lm": "lm", + "lmrev": "reverse_lm", + "lmall": "naive", + "ilm": "ilm", +} + +_DOWNLOAD_TEMPLATE = """ +wget -nc --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget \ + --quiet --save-cookies /tmp/cookies.txt \ + --keep-session-cookies \ + --no-check-certificate \ + 'https://docs.google.com/uc?export=download&id={gdrive_id}' \ + -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\\1\\n/p')&id={gdrive_id}" \ + -O {local_path} && rm -rf /tmp/cookies.txt +""".strip() + + +if __name__ == "__main__": + import os + import sys + from pathlib import Path + + MODEL_DIR = Path(os.environ.get("MODEL_DIR", "/data/")) + + if sys.argv[1] == "model": + data_tag, model_type = sys.argv[2:] + model_tag = "{}_{}".format(data_tag[:3], model_type) + gdrive_urls = [PRETRAINED_MODELS[model_tag], PRETRAINED_MODEL_CONFIG_JSON, PRETRAINED_SPECIAL_VOCAB_PKL] + local_fns = ["pytorch_model.bin", "config.json", "additional_ids_to_tokens.pkl"] + 
diff --git a/services/infilling/encoder.py b/services/infilling/encoder.py
new file mode 100644
index 0000000000..9c1e96a8db
--- /dev/null
+++ b/services/infilling/encoder.py
@@ -0,0 +1,106 @@
+"""Byte pair encoding utilities"""
+
+import regex as re
+from functools import lru_cache
+
+
+@lru_cache()
+def bytes_to_unicode():
+    """
+    Returns list of utf-8 byte and a corresponding list of unicode strings.
+    The reversible bpe codes work on unicode strings.
+    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+    This is a significant percentage of your normal, say, 32K bpe vocab.
+    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+    And avoids mapping to whitespace/control characters the bpe code barfs on.
+    """
+    bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
+    cs = bs[:]
+    n = 0
+    for b in range(2 ** 8):
+        if b not in bs:
+            bs.append(b)
+            cs.append(2 ** 8 + n)
+            n += 1
+    cs = [chr(n) for n in cs]
+    return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+    """Return set of symbol pairs in a word.
+    Word is represented as tuple of symbols (symbols being variable-length strings).
+    """
+    pairs = set()
+    prev_char = word[0]
+    for char in word[1:]:
+        pairs.add((prev_char, char))
+        prev_char = char
+    return pairs
+
+
+class Encoder:
+    def __init__(self, encoder, bpe_merges, errors="replace"):
+        self.encoder = encoder
+        self.decoder = {v: k for k, v in self.encoder.items()}
+        self.errors = errors  # how to handle errors in decoding
+        self.byte_encoder = bytes_to_unicode()
+        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+        self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
+        self.cache = {}
+
+        # Should have added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
+        self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
+
+    def bpe(self, token):
+        if token in self.cache:
+            return self.cache[token]
+        word = tuple(token)
+        pairs = get_pairs(word)
+
+        if not pairs:
+            return token
+
+        while True:
+            bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
+            if bigram not in self.bpe_ranks:
+                break
+            first, second = bigram
+            new_word = []
+            i = 0
+            while i < len(word):
+                try:
+                    j = word.index(first, i)
+                    new_word.extend(word[i:j])
+                    i = j
+                except Exception:
+                    new_word.extend(word[i:])
+                    break
+
+                if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
+                    new_word.append(first + second)
+                    i += 2
+                else:
+                    new_word.append(word[i])
+                    i += 1
+            new_word = tuple(new_word)
+            word = new_word
+            if len(word) == 1:
+                break
+            else:
+                pairs = get_pairs(word)
+        word = " ".join(word)
+        self.cache[token] = word
+        return word
+
+    def encode(self, text):
+        bpe_tokens = []
+        for token in re.findall(self.pat, text):
+            token = "".join(self.byte_encoder[b] for b in token.encode("utf-8"))
+            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" "))
+        return bpe_tokens
+
+    def decode(self, tokens):
+        text = "".join([self.decoder[token] for token in tokens])
+        text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors)
+        return text
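`encoder.py` is a standalone GPT-2 BPE implementation; `tokenize_util.py` below instantiates it from the `encoder.json` and `vocab.bpe` files that the Dockerfile downloads into `/data/`. A minimal sketch of using it directly (the file paths assume the container layout):

```python
import json

from encoder import Encoder

# /data/encoder.json and /data/vocab.bpe are fetched in services/infilling/Dockerfile.
with open("/data/encoder.json") as f:
    encoder_json = json.load(f)
with open("/data/vocab.bpe", encoding="utf-8") as f:
    bpe_merges = [tuple(line.split()) for line in f.read().split("\n")[1:-1]]

bpe = Encoder(encoder=encoder_json, bpe_merges=bpe_merges)
ids = bpe.encode("Chris was bad at _.")
assert bpe.decode(ids) == "Chris was bad at _."  # BPE encode/decode round-trips the text
```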
+ """ + pairs = set() + prev_char = word[0] + for char in word[1:]: + pairs.add((prev_char, char)) + prev_char = char + return pairs + + +class Encoder: + def __init__(self, encoder, bpe_merges, errors="replace"): + self.encoder = encoder + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except Exception: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ")) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = bytearray([self.byte_decoder[c] for c in text]).decode("utf-8", errors=self.errors) + return text diff --git a/services/infilling/infer.py b/services/infilling/infer.py new file mode 100644 index 0000000000..6fdfabdfe8 --- /dev/null +++ b/services/infilling/infer.py @@ -0,0 +1,107 @@ +import copy + +import torch +import torch.nn.functional as F + + +def sample_from_logits(logits, temp=1.0, topk=None, nucleus=1.0): + if temp == 0: + return torch.argmax(logits, dim=-1).unsqueeze(-1) + elif temp != 1: + logits /= temp + + probs = F.softmax(logits, dim=-1) + + if topk is not None: + top_probs = torch.topk(probs, topk) + mask = F.one_hot(top_probs.indices, probs.shape[-1]).float() + mask = mask.sum(dim=1) + probs *= mask + probs /= probs.sum(dim=-1) + + if nucleus != 1: + probs_sorted = torch.sort(probs, descending=True, dim=-1) + # sorted_indices = probs_sorted.indices + sorted_values = probs_sorted.values + + cumsum = torch.cumsum(sorted_values, dim=-1) + ks = (cumsum < nucleus).long().sum(dim=-1) + ks = torch.max(ks, torch.ones_like(ks)) + + # TODO: Make this more efficient using gather + ks = F.one_hot(ks, probs.shape[-1]).float() + cutoffs = (sorted_values * ks).sum(-1) + + mask = (probs > cutoffs.unsqueeze(1)).float() + probs *= mask + + probs /= probs.sum(keepdim=True, dim=-1) + + next_tokens = torch.multinomial(probs, num_samples=1) + + return next_tokens + + +def infill_with_ilm(model, special_tokens_to_ids, x, num_infills=1, max_sequence_length=256, nucleus=0.95): + + _sep_id = special_tokens_to_ids["<|startofinfill|>"] + _end_span_id = 
special_tokens_to_ids["<|endofinfill|>"] + _special_ids = special_tokens_to_ids.values() + + # Make sure example doesn't already ends with [sep] + if x[-1] == _sep_id: + x = x[:-1] + + # Count number of blanks + blank_idxs = [] + for i, tok_id in enumerate(x): + if tok_id in _special_ids: + blank_idxs.append(i) + k = len(blank_idxs) + if k == 0: + raise ValueError() + + # Decode until we have that many blanks + with torch.no_grad(): + device = next(model.parameters()).device + context = torch.tensor(x + [_sep_id], dtype=torch.long, device=device).unsqueeze(0).repeat(num_infills, 1) + + terminated = [] + + while context.shape[0] > 0: + logits = model(context)[0][:, -1] + next_tokens = sample_from_logits(logits, nucleus=nucleus) + context = torch.cat((context, next_tokens), dim=1) + + num_predicted_spans = (context == _end_span_id).long().sum(dim=1) + + terminate_expected = num_predicted_spans >= k + terminate_toolong = torch.ones_like(context).long().sum(dim=1) >= max_sequence_length + terminate = terminate_expected | terminate_toolong + + if torch.any(terminate): + terminated_seqs = context[terminate, len(x) + 1 :] + terminated.extend([list(s) for s in terminated_seqs.cpu().numpy()]) + context = context[~terminate, :] + + # Collect generated spans + generated_spans = [] + for gen in terminated: + spans = [] + while _end_span_id in gen: + spans.append(gen[: gen.index(_end_span_id)]) + gen = gen[gen.index(_end_span_id) + 1 :] + while len(spans) < k: + spans.append([]) + generated_spans.append(spans) + + # Insert into context + generated = [] + for spans in generated_spans: + context = copy.deepcopy(x) + for i, j in enumerate(blank_idxs[::-1]): + del context[j] + context[j:j] = spans[k - 1 - i] + generated.append(context) + + return generated diff --git a/services/infilling/paths.py b/services/infilling/paths.py new file mode 100644 index 0000000000..cfb8dac97f --- /dev/null +++ b/services/infilling/paths.py @@ -0,0 +1,9 @@ +import os +import pathlib + +_LIB_DIR = os.path.dirname(os.path.join(pathlib.Path(__file__).absolute())) +_REPO_DIR = os.path.dirname(_LIB_DIR) + +OFFICIAL_GPT2_ENCODER_DIR = os.path.join(_LIB_DIR, "official_gpt2_encoder") + +RAW_DATA_DIR = os.path.join(_REPO_DIR, "data", "raw_data") diff --git a/services/infilling/requirements.txt b/services/infilling/requirements.txt new file mode 100644 index 0000000000..280b89a829 --- /dev/null +++ b/services/infilling/requirements.txt @@ -0,0 +1,10 @@ +transformers==4.0.1 +sentencepiece==0.1.94 +flask==1.1.1 +itsdangerous==2.0.1 +gunicorn==19.9.0 +requests==2.22.0 +sentry-sdk[flask]==0.14.1 +healthcheck==1.3.3 +jinja2<=3.0.3 +Werkzeug<=2.0.3 diff --git a/services/infilling/server.py b/services/infilling/server.py new file mode 100644 index 0000000000..9d72ecb5a0 --- /dev/null +++ b/services/infilling/server.py @@ -0,0 +1,87 @@ +import logging +import time +import os +import pickle +from pathlib import Path + +import sentry_sdk +import tokenize_util +import torch +from flask import Flask, request, jsonify +from infer import infill_with_ilm +from sentry_sdk.integrations.flask import FlaskIntegration +from transformers import GPT2LMHeadModel + + +sentry_sdk.init(dsn=os.getenv("SENTRY_DSN"), integrations=[FlaskIntegration()]) + +logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO) +logger = logging.getLogger(__name__) + +MODEL_DIR = os.environ.get("MODEL_DIR", "/data/") +logging.info(f"MODEL_DIR = {MODEL_DIR}") +# MASK_ID = 103 + +try: + if torch.cuda.is_available(): + device = 
torch.device("cuda") + else: + device = torch.device("cpu") + + logger.info(f"infilling is set to run on {device}") + + # init model + tokenizer = tokenize_util.Tokenizer.GPT2 + with open(Path(MODEL_DIR).joinpath("additional_ids_to_tokens.pkl"), "rb") as f: + additional_ids_to_tokens = pickle.load(f) + additional_tokens_to_ids = {v: k for k, v in additional_ids_to_tokens.items()} + try: + tokenize_util.update_tokenizer(additional_ids_to_tokens, tokenizer) + except ValueError: + logger.info("Tokenizer already updated") + logger.info(additional_tokens_to_ids) + model = GPT2LMHeadModel.from_pretrained(MODEL_DIR) + model.eval() + if torch.cuda.is_available(): + additional_tokens_to_ids = { + k: torch.tensor(v, dtype=torch.int).cuda() for k, v in additional_tokens_to_ids.items() + } + model.to(device) + + logger.info("infilling model is ready") +except Exception as e: + sentry_sdk.capture_exception(e) + logger.exception(e) + raise e + +app = Flask(__name__) + + +@app.route("/respond", methods=["POST"]) +def respond(): + st_time = time.time() + + texts = request.json.get("texts", []) + logger.info(f"Input: {texts}") + try: + output = [] + for txt in texts: + inputs = tokenize_util.encode(txt, tokenizer) + _blank_id = tokenize_util.encode(" _", tokenizer)[0] + flag = 0 + while not flag: # надо исправить костыль + try: + inputs[inputs.index(_blank_id)] = additional_tokens_to_ids["<|infill_ngram|>"] + except Exception: + flag = 1 + generated = infill_with_ilm(model, additional_tokens_to_ids, inputs, num_infills=1) + output.append(tokenize_util.decode(generated[0], tokenizer)) + except Exception as exc: + logger.exception(exc) + sentry_sdk.capture_exception(exc) + output = [""] * len(texts) + + logger.info(f"Output: {output}") + total_time = time.time() - st_time + logger.info(f"infilling exec time: {total_time:.3f}s") + return jsonify({"infilled_text": output}) diff --git a/services/infilling/test.py b/services/infilling/test.py new file mode 100644 index 0000000000..bbf609f8df --- /dev/null +++ b/services/infilling/test.py @@ -0,0 +1,23 @@ +import requests + + +def test_respond(): + url = "http://0.0.0.0:8122/respond" + + texts = ["Chris was bad at _.", "Chris was _ so he could not come."] + + request_data = {"texts": texts} + + result = requests.post(url, json=request_data).json() + + assert result["infilled_text"][0].startswith("Chris was bad at"), print( + f"Got\n{result}\n, but had to be starting with 'Chris was bad at math'" + ) + assert result["infilled_text"][1].startswith("Chris was") and result["infilled_text"][1].endswith( + "so he could not come." 
diff --git a/services/infilling/tokenize_util.py b/services/infilling/tokenize_util.py
new file mode 100644
index 0000000000..ccf2a41b8e
--- /dev/null
+++ b/services/infilling/tokenize_util.py
@@ -0,0 +1,266 @@
+import json
+import os
+import regex as re
+import warnings
+from enum import Enum
+from functools import lru_cache
+from pathlib import Path
+
+from encoder import Encoder as OfficialEncoder
+
+
+MODEL_DIR = os.environ.get("MODEL_DIR", "/data/")
+
+
+class Tokenizer(Enum):
+    CUSTOM = 0
+    GPT2 = 1
+
+
+DEFAULT_TOKENIZER = Tokenizer.GPT2
+_CUSTOM_ID_TO_TOKEN = None
+
+
+def set_custom_vocab_fp(vocab_fp):
+    global _CUSTOM_ID_TO_TOKEN
+    with open(vocab_fp, "r") as f:
+        _CUSTOM_ID_TO_TOKEN = f.read().strip().splitlines()
+
+
+_TOKENIZER_TO_STATE = {}
+
+
+def _get_tokenizer_state(tokenizer):
+    if type(tokenizer) == str:
+        try:
+            tokenizer = Tokenizer[tokenizer.upper()]
+        except Exception:
+            raise ValueError("Unknown tokenizer specified")
+
+    if type(tokenizer) != Tokenizer:
+        raise ValueError("Tokenizer must be from Tokenizer enum")
+
+    if tokenizer not in _TOKENIZER_TO_STATE:
+        if tokenizer == Tokenizer.GPT2:
+            with open(Path(MODEL_DIR).joinpath("encoder.json"), "r") as f:
+                encoder_json = json.load(f)
+            with open(Path(MODEL_DIR).joinpath("vocab.bpe"), "r", encoding="utf-8") as f:
+                bpe_data = f.read()
+            bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split("\n")[1:-1]]
+            official_encoder = OfficialEncoder(encoder=encoder_json, bpe_merges=bpe_merges)
+            _TOKENIZER_TO_STATE[tokenizer] = official_encoder
+        elif tokenizer == Tokenizer.CUSTOM:
+            if _CUSTOM_ID_TO_TOKEN is None:
+                raise Exception("Must call set_custom_vocab_fp first")
+            CUSTOM_TOKEN_TO_ID = {v: k for k, v in enumerate(_CUSTOM_ID_TO_TOKEN)}
+            if len(_CUSTOM_ID_TO_TOKEN) != len(CUSTOM_TOKEN_TO_ID):
+                raise ValueError("Duplicate tokens")
+            _TOKENIZER_TO_STATE[tokenizer] = (_CUSTOM_ID_TO_TOKEN, CUSTOM_TOKEN_TO_ID)
+        else:
+            assert False
+
+    return _TOKENIZER_TO_STATE[tokenizer]
+
+
+def update_tokenizer(additional_ids_to_tokens, tokenizer=DEFAULT_TOKENIZER):
+    state = _get_tokenizer_state(tokenizer)
+
+    additional_tokens_to_ids = {v: k for k, v in additional_ids_to_tokens.items()}
+    if len(additional_ids_to_tokens) != len(additional_tokens_to_ids):
+        raise ValueError()
+
+    if tokenizer == Tokenizer.GPT2:
+        vocab_size_before = len(state.encoder)
+        state.encoder.update(additional_tokens_to_ids)
+        state.decoder.update(additional_ids_to_tokens)
+        vocab_size_after = len(state.encoder)
+    elif tokenizer == Tokenizer.CUSTOM:
+        raise NotImplementedError()
+    else:
+        assert False
+
+    if vocab_size_after != (vocab_size_before + len(additional_ids_to_tokens)):
+        raise ValueError()
+
+    return vocab_size_after
+
+
+def tokenize(s, tokenizer=DEFAULT_TOKENIZER):
+    state = _get_tokenizer_state(tokenizer)
+
+    if tokenizer == Tokenizer.GPT2:
+        tokens_regex = re.findall(state.pat, s)
+        tokens_ids = []
+        for token in tokens_regex:
+            token = "".join(state.byte_encoder[b] for b in token.encode("utf-8"))
+            token_ids = [state.encoder[bpe_token] for bpe_token in state.bpe(token).split(" ")]
+            tokens_ids.extend(token_ids)
+        raw_tokens = [state.decoder[token_id] for token_id in tokens_ids]
+        tokens = [
+            bytearray([state.byte_decoder[c] for c in token]).decode("utf-8", errors=state.errors)
+            for token in raw_tokens
+        ]
+    elif tokenizer == Tokenizer.CUSTOM:
+        tokens = s.strip().split()
+    else:
+        assert False
+
+    return tokens
+
+
+def tokens_to_ids(tokens, tokenizer=DEFAULT_TOKENIZER):
+    state = _get_tokenizer_state(tokenizer)
+
+    if tokenizer == Tokenizer.GPT2:
+        tokens_ids = []
+        for token in tokens:
+            token = "".join(state.byte_encoder[b] for b in token.encode("utf-8"))
+            tokens_ids.extend(state.encoder[bpe_token] for bpe_token in state.bpe(token).split(" "))
+    elif tokenizer == Tokenizer.CUSTOM:
+        tokens_ids = [state[1][t] for t in tokens]
+    else:
+        assert False
+
+    if len(tokens_ids) != len(tokens):
+        raise Exception("Token ids not equal in length to tokens")
+
+    return tokens_ids
+
+
+def ids_to_tokens(tokens_ids, tokenizer=DEFAULT_TOKENIZER):
+    state = _get_tokenizer_state(tokenizer)
+
+    if tokenizer == Tokenizer.GPT2:
+        tokens = [state.decoder[token_id] for token_id in tokens_ids]
+        tokens = [
+            bytearray([state.byte_decoder[c] for c in token]).decode("utf-8", errors=state.errors) for token in tokens
+        ]
+    elif tokenizer == Tokenizer.CUSTOM:
+        tokens = [state[0][t] for t in tokens_ids]
+    else:
+        assert False
+
+    if len(tokens) != len(tokens_ids):
+        raise Exception("Tokens not equal in length to token ids")
+
+    return tokens
+
+
+def detokenize(tokens, tokenizer=DEFAULT_TOKENIZER):
+    if tokenizer == Tokenizer.GPT2:
+        s = "".join(tokens)
+    elif tokenizer == Tokenizer.CUSTOM:
+        s = " ".join(tokens)
+    else:
+        assert False
+
+    return s
+
+
+def encode(s, tokenizer=DEFAULT_TOKENIZER):
+    return tokens_to_ids(tokenize(s, tokenizer=tokenizer), tokenizer=tokenizer)
+
+
+def decode(tokens_ids, tokenizer=DEFAULT_TOKENIZER):
+    return detokenize(ids_to_tokens(tokens_ids, tokenizer=tokenizer), tokenizer=tokenizer)
+
+
+def vocab_size(tokenizer=DEFAULT_TOKENIZER):
+    state = _get_tokenizer_state(tokenizer)
+
+    if tokenizer == Tokenizer.GPT2:
+        vocab_size = len(state.encoder)
+    elif tokenizer == Tokenizer.CUSTOM:
+        vocab_size = len(state[0])
+    else:
+        assert False
+
+    return vocab_size
+
+
+@lru_cache(maxsize=128)
+def _tokens_offsets_and_residuals_memoized(x, x_tok):
+    x_remaining_off = 0
+    x_remaining = x[:]
+
+    offsets = []
+    residuals = []
+
+    for i, t in enumerate(x_tok):
+        if len(t) == 0:
+            warnings.warn("Encountered empty token")
+
+        try:
+            t_off_in_x_remaining = x_remaining.index(t)
+            t_res = x_remaining[:t_off_in_x_remaining]
+            t_off = x_remaining_off + t_off_in_x_remaining
+        except Exception:
+            t_off = None
+            t_res = ""
+
+        offsets.append(t_off)
+        residuals.append(t_res)
+
+        if t_off is not None:
+            trim = t_off_in_x_remaining + len(t)
+            x_remaining_off += trim
+            x_remaining = x_remaining[trim:]
+
+    rres = x_remaining
+
+    return offsets, residuals, rres
+
+
+def tokens_offsets(x, x_tok):
+    if type(x_tok) != tuple:
+        x_tok = tuple(x_tok)
+    return _tokens_offsets_and_residuals_memoized(x, x_tok)[0]
+
+
+def tokens_residuals(x, x_tok):
+    if type(x_tok) != tuple:
+        x_tok = tuple(x_tok)
+    return _tokens_offsets_and_residuals_memoized(x, x_tok)[1:]
+
+
+def align_charspan_to_tokenspan(x, x_tok, char_offset, char_len):
+    if len(x_tok) == 0:
+        raise ValueError()
+    if char_offset < 0 or char_len < 0 or (char_offset + char_len) > len(x):
+        raise ValueError()
+
+    if type(x_tok) != tuple:
+        x_tok = tuple(x_tok)
+    x_tok_offsets, x_tok_residuals, x_tok_rres = _tokens_offsets_and_residuals_memoized(x, x_tok)
+    if None in x_tok_offsets:
+        raise ValueError()
+    x_tok_residuals.append(x_tok_rres)
+    x_tok_lens = [len(t) for t in x_tok]
+
+    # Build char_idx_to_token of appropriate token for each cursor index
+    # NOTE: This is one greater than len(x) because cursor can be at beginning or end.
+    char_idx_to_token = [0] * len(x_tok_residuals[0])
+    for i in range(len(x_tok)):
+        char_idx_to_token += [i] * (x_tok_lens[i] + len(x_tok_residuals[i + 1]))
+    char_idx_to_token += [len(x_tok) - 1]
+
+    if char_len == 0:
+        token_offset = char_idx_to_token[char_offset]
+        token_len = 0
+        char_offset = x_tok_offsets[token_offset]
+        char_len = 0
+    else:
+        selected_x_tok = set(char_idx_to_token[char_offset : char_offset + char_len])
+        token_offset = min(selected_x_tok)
+        token_len = max(selected_x_tok) - token_offset + 1
+
+        char_offset = x_tok_offsets[token_offset]
+        token_end = token_offset + token_len - 1
+        char_end = x_tok_offsets[token_end] + x_tok_lens[token_end]
+        char_len = char_end - char_offset
+
+    return (
+        char_offset,
+        char_len,
+        token_offset,
+    )
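The offset helpers at the end of `tokenize_util.py` map token positions back to character positions in the original string, which is useful when post-processing infilled spans. A small illustration (the token boundaries shown in the comments are approximate GPT-2 BPE output):

```python
import tokenize_util

x = "Chris was bad at math."
x_tok = tokenize_util.tokenize(x)              # e.g. ['Chris', ' was', ' bad', ' at', ' math', '.']

# Character offset of each token within x, e.g. [0, 5, 9, 13, 16, 21].
print(tokenize_util.tokens_offsets(x, x_tok))

# Snap the character span covering "bad" to token boundaries;
# returns the aligned (char_offset, char_len, token_offset).
print(tokenize_util.align_charspan_to_tokenspan(x, x_tok, x.index("bad"), len("bad")))
```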
diff --git a/tests/runtests.sh b/tests/runtests.sh
index d9ef03300c..f8424a7ef4 100755
--- a/tests/runtests.sh
+++ b/tests/runtests.sh
@@ -149,7 +149,7 @@ if [[ "$MODE" == "test_skills" || "$MODE" == "all" ]]; then
                    dff-gossip-skill dff-wiki-skill topic-recommendation dff-science-skill personal-info-skill \
                    user-persona-extractor small-talk-skill wiki-facts dff-art-skill dff-funfact-skill \
                    meta-script-skill spelling-preprocessing dff-gaming-skill dialogpt \
-                   dff-music-skill dff-bot-persona-skill entity-detection midas-predictor; do
+                   dff-music-skill dff-bot-persona-skill entity-detection midas-predictor infilling; do
       echo "Run tests for $container"
       dockercompose_cmd exec -T -u $(id -u) $container ./test.sh