diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..23adf76
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,17 @@
+FROM python:3.7.3-slim-stretch
+
+RUN apt-get -y update && apt-get -y install gcc
+
+WORKDIR /
+COPY checkpoint /checkpoint
+
+# Make changes to the requirements/app here.
+# This Dockerfile order allows Docker to cache the checkpoint layer
+# and improve build times if making changes.
+RUN pip3 --no-cache-dir install tensorflow gpt-2-simple starlette uvicorn ujson
+COPY app.py /
+
+# Clean up APT when done.
+RUN apt-get clean && rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
+
+ENTRYPOINT ["python3", "-X", "utf8", "app.py"]
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..4f47092
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2019 Max Woolf
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..e9e02f0
--- /dev/null
+++ b/README.md
@@ -0,0 +1,15 @@
+# mtg-gpt-2-cloud-run
+
+## Maintainer/Creator
+
+Max Woolf ([@minimaxir](https://minimaxir.com))
+
+*Max's open-source projects are supported by his [Patreon](https://www.patreon.com/minimaxir). If you found this project helpful, any monetary contributions to the Patreon are appreciated and will be put to good creative use.*
+
+## License
+
+MIT
+
+## Disclaimer
+
+This repo has no affiliation or relationship with OpenAI.
\ No newline at end of file
diff --git a/api_ui.html b/api_ui.html
new file mode 100644
index 0000000..7061ed2
--- /dev/null
+++ b/api_ui.html
@@ -0,0 +1,287 @@
+<!DOCTYPE html>
+<html>
+<head>
+  <meta charset="utf-8">
+  <meta name="viewport" content="width=device-width, initial-scale=1">
+  <title>AI-Generated Reddit Submission titles with GPT-2</title>
+  <!-- NOTE: the markup of the original 287-line file (stylesheet links, styling,
+       and the page JavaScript) did not survive extraction. Only the form
+       structure and its visible text are reconstructed below; tag names,
+       classes, and field names are approximations inferred from app.py. -->
+</head>
+<body>
+  <section>
+    <div class="container">
+      <form id="gpt2-form">
+        <div class="field">
+          <label for="subreddit">Subreddit</label>
+          <input type="text" id="subreddit" name="subreddit">
+          <p class="help">Reddit subreddit to generate text from. Not case-sensitive.</p>
+        </div>
+        <div class="field">
+          <label for="prefix">Prefix</label>
+          <input type="text" id="prefix" name="prefix" maxlength="100">
+          <p class="help">Starts the generated title with the specified text. (Optional: max 100 characters)</p>
+        </div>
+        <div class="field">
+          <label>Keywords</label>
+          <!-- app.py concatenates every parameter whose name contains "key" -->
+          <input type="text" name="key1">
+          <input type="text" name="key2">
+          <input type="text" name="key3">
+          <p class="help">Keywords/phrases to base the generated title upon. Use case-sensitive inputs for better results. (Optional)</p>
+        </div>
+        <div class="field">
+          <label for="num_titles">Number of Titles</label>
+          <select id="num_titles" name="num_titles">
+            <option>1</option>
+            <option>2</option>
+            <option>3</option>
+            <option>4</option>
+            <option>5</option>
+          </select>
+          <p class="help">Number of titles to generate with the given parameters. Titles are generated in parallel. (min 1, max 5)</p>
+        </div>
+        <button type="submit">Generate Text</button>
+      </form>
+      <div id="output">
+        Generated text will appear here! Use the form to configure GPT-2 and press Generate Text
+        to get your Reddit submission title!
+      </div>
+    </div>
+  </section>
+  <!-- The original file closes with the JavaScript that submits the form to the
+       deployed endpoint and renders the returned titles; that script was not
+       preserved in this extract. -->
+</body>
+</html>
\ No newline at end of file
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..9a9faa2
--- /dev/null
+++ b/app.py
@@ -0,0 +1,102 @@
+from starlette.applications import Starlette
+from starlette.responses import UJSONResponse
+import gpt_2_simple as gpt2
+import tensorflow as tf
+import uvicorn
+import os
+import re
+
+
+MIN_LENGTH = 50
+MAX_LENGTH = 200
+STEP_LENGTH = 50
+
+INVALID_SUBREDDITS = set([
+    "me_irl",
+    "2meirl4meirl",
+    "anime_irl",
+    "furry_irl",
+    "cursedimages",
+    "meirl",
+    "hmmm",
+    "ooer"
+])
+
+app = Starlette(debug=False)
+
+sess = gpt2.start_tf_sess(threads=1)
+gpt2.load_gpt2(sess)
+
+# Needed to avoid cross-domain issues
+response_header = {
+    'Access-Control-Allow-Origin': '*'
+}
+
+generate_count = 0
+
+
+@app.route('/', methods=['GET', 'POST', 'HEAD'])
+async def homepage(request):
+    global generate_count
+    global sess
+
+    if request.method == 'GET':
+        params = request.query_params
+    elif request.method == 'POST':
+        params = await request.json()
+    elif request.method == 'HEAD':
+        return UJSONResponse({'text': ''},
+                             headers=response_header)
+
+    subreddit = params.get('subreddit', '').lower().strip()
+
+    if subreddit == '':
+        subreddit = 'askreddit'
+
+    if subreddit in INVALID_SUBREDDITS:
+        return UJSONResponse({'text': 'ಠ_ಠ'},
+                             headers=response_header)
+
+    keywords = " ".join([v.replace(' ', '-').strip() for k, v in params.items()
+                         if 'key' in k and v != ''])
+
+    prepend = "<|startoftext|>~`{}~^{}~@".format(subreddit, keywords)
+    text = prepend + params.get('prefix', '')[:100]
+
+    length = MIN_LENGTH
+
+    while '<|endoftext|>' not in text and length <= MAX_LENGTH:
+        text = gpt2.generate(sess,
+                             length=STEP_LENGTH,
+                             temperature=0.7,
+                             top_k=40,
+                             prefix=text,
+                             include_prefix=True,
+                             return_as_list=True
+                             )[0]
+        length += STEP_LENGTH
+
+    generate_count += 1
+    if generate_count == 8:
+        # Reload model to prevent Graph/Session from going OOM
+        tf.reset_default_graph()
+        sess.close()
+        sess = gpt2.start_tf_sess(threads=1)
+        gpt2.load_gpt2(sess)
+        generate_count = 0
+
+    prepend_esc = re.escape(prepend)
+    eot_esc = re.escape('<|endoftext|>')
+
+    if '<|endoftext|>' not in text:
+        pattern = '(?:{})(.*)'.format(prepend_esc)
+    else:
+        pattern = '(?:{})(.*)(?:{})'.format(prepend_esc, eot_esc)
+
+    trunc_text = re.search(pattern, text)
+
+    return UJSONResponse({'text': trunc_text.group(1)},
+                         headers=response_header)
+
+if __name__ == '__main__':
+    uvicorn.run(app, host='0.0.0.0', port=int(os.environ.get('PORT', 8080)))
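
For reference, a minimal client sketch against the API defined in `app.py`: it POSTs JSON containing a `subreddit`, an optional `prefix`, and any number of parameters whose names contain `key` (treated as keywords), then reads the generated title from the `text` field of the JSON response. The service URL is a placeholder for wherever the container is deployed (e.g. Cloud Run or a local Docker run), and the `requests` dependency is assumed rather than part of this repo.

```python
import requests

# Placeholder URL: substitute the Cloud Run (or local Docker) endpoint
# serving the container built from the Dockerfile above.
API_URL = "https://your-cloud-run-service.example.com/"

# app.py accepts GET, POST, and HEAD on "/"; for POST it reads the JSON body.
payload = {
    "subreddit": "AskReddit",   # lowercased and stripped server-side
    "prefix": "What is",        # optional; truncated to 100 characters
    "key1": "computers",        # any param whose name contains "key" is a
    "key2": "the future",       # keyword; spaces are replaced with hyphens
}

response = requests.post(API_URL, json=payload)
response.raise_for_status()

# The response is JSON of the form {"text": "<generated title>"}.
print(response.json()["text"])
```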