In [1]:
import os
import random
import sys
from string import Template

# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

import time

# Initialize model and cache
# NOTE(review): the model and LoRA locations below are absolute paths on one
# machine; consider moving them to a single configurable base directory.

model_directory = (
    "/mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ"
)
config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)
model.load()

tokenizer = ExLlamaV2Tokenizer(config)
# max_tokens sizes the cache and is also read by generate_with_lora to work
# out how many tokens may still be generated after the prompt.
max_tokens=2048
cache = ExLlamaV2Cache(model, max_seq_len=max_tokens)

# Load LoRA
# Two adapters on the same base model: one is used by generate_post
# (submissions), the other by generate_comment (comments).

lora_submission_directory = "/mnt/Woo/text-generation-webui/loras/Wizard-Vicuna-13B-Uncensored-GPTQ-reddit-submissions"
lora_submission = ExLlamaV2Lora.from_directory(model, lora_submission_directory)

lora_comment_directory = "/mnt/Woo/text-generation-webui/loras/Wizard-Vicuna-13B-Uncensored-GPTQ-reddit-comments"
lora_comment = ExLlamaV2Lora.from_directory(model, lora_comment_directory)

# Initialize generators

streaming_generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
streaming_generator.warmup()

# Initial stop conditions; generate_post and generate_comment set their own
# stop conditions again before generating.
streaming_generator.set_stop_conditions(["\nUser:", tokenizer.eos_token_id])

# NOTE(review): simple_generator is not referenced anywhere else in the
# visible notebook.
simple_generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Sampling settings
# This settings object is passed to begin_stream() by generate_with_lora.

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.98
settings.top_p = 0.37
settings.token_repetition_penalty = 1.18

Loading model: /mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ


In [2]:
def generate_with_lora(prompt_, lora_, max_new_tokens):
    """Stream a completion for ``prompt_`` with the given LoRA applied.

    Uses the module-level ``tokenizer``, ``streaming_generator``, ``settings``
    and ``max_tokens``. The stop conditions currently configured on
    ``streaming_generator`` decide when the stream ends.

    Args:
        prompt_: Prompt text to encode and continue.
        lora_: LoRA adapter(s), forwarded as ``loras=`` to ``begin_stream``.
        max_new_tokens: Maximum number of ``stream()`` steps to run. It is
            reduced to the space left in the cache when the request does not
            fit.

    Returns:
        The concatenation of every text chunk returned by ``stream()`` up to
        and including the step that reported end-of-stream.

    Raises:
        ValueError: ``"Hit cache limit"`` when the request had to be reduced
            and fewer than 150 tokens of space remain after the prompt.
        ValueError: ``"Ran out of tokens on this prompt"`` when the token
            budget is used up before a stop condition is reached.
    """
    input_ids = tokenizer.encode(prompt_)
    # Assumes encode() returns a batch whose first row holds the prompt's
    # token ids, as the length computation has always done here.
    prompt_length = len(input_ids[0])

    # Space left in the cache once the prompt is stored (one slot kept spare).
    # The prompt is already accounted for here, so the request is compared to
    # this budget directly; adding prompt_length again would count it twice.
    tokens_left = max_tokens - prompt_length - 1
    if max_new_tokens > tokens_left:
        # See if we can get away with less.
        max_new_tokens = tokens_left
        if max_new_tokens < 150:
            raise ValueError("Hit cache limit")

    streaming_generator.begin_stream(input_ids, settings, loras=lora_)
    generated_tokens = 0
    chunks = []
    while True:
        chunk, eos, _ = streaming_generator.stream()
        generated_tokens += 1
        chunks.append(chunk)
        if eos:
            break
        # ">=" rather than "==" so a non-positive budget still terminates.
        if generated_tokens >= max_new_tokens:
            raise ValueError("Ran out of tokens on this prompt")

    return "".join(chunks)

In [3]:
from string import Template
import string

# Prompt skeleton for a post. generate_post substitutes $subreddit, $author,
# $media and $title, using the "[...]" placeholder tags from TAGS for any
# field the caller did not supply. The trailing space after each field is
# part of the prompt text.
POST_TEMPLATE = Template(
    """You are a Reddit post generator.
User: 
Subreddit: $subreddit 
Author: $author 
Media: $media 
Title: $title 
Write the Reddit post.
Assistant:"""
)

def contains_letter(s):
    """Return True when at least one character of ``s`` is alphabetic."""
    for ch in s:
        if ch.isalpha():
            return True
    return False

def gen_valid_first_character(include_digits=True):
    """Pick a random ASCII character to seed a generated name.

    Relies on the ``random`` and ``string`` modules being imported before the
    function is called.

    Args:
        include_digits: When true, a digit is returned about half of the time
            and an ASCII letter otherwise. When false, always a letter.

    Returns:
        A one-character string drawn from ``string.digits`` or
        ``string.ascii_letters``.
    """
    # The coin flip is only made when digits are allowed, so the letters-only
    # path consumes a single random draw.
    use_digit = include_digits and random.random() < 0.5
    alphabet = string.digits if use_digit else string.ascii_letters
    return random.choice(alphabet)


VALID_MEDIA = ["image", "video", "text", "article"]  # articles are links to external sites

# Per-field generation rules used by generate_post:
#   TAG        - placeholder text that stands in for a missing field in the prompt
#   CONSTRAINT - predicate a generated value must satisfy to be accepted
#   HELPER     - returns a prefix that both seeds the prompt and is prepended
#                to the generated value (the argument is currently ignored by
#                the Subreddit and Author helpers)
TAGS = {
    "Subreddit": {
        "TAG": "[SUBREDDIT]",
        # Length and letter checks apply to the name *after* "/r/". Checking
        # the whole string for a letter would always pass because of the "r"
        # in the prefix.
        # NOTE(review): names of exactly 3 characters are rejected (the bound
        # is "> 3"); confirm this is intended.
        "CONSTRAINT": lambda x: x.startswith("/r/")
        and 3 < len(x[len("/r/"):]) <= 21
        and contains_letter(x[len("/r/"):]),
        "HELPER": lambda x: f"/r/{gen_valid_first_character(include_digits=False)}",
    },
    "Author": {
        "TAG": "[AUTHOR]",
        "CONSTRAINT": lambda x: len(x) >= 3 and contains_letter(x) and len(x) <= 23,
        "HELPER": lambda x: gen_valid_first_character(include_digits=True),
    },
    "Media": {
        "TAG": "[MEDIA]",
        "CONSTRAINT": lambda x: x in VALID_MEDIA,
        "HELPER": lambda x: x,
    },
    "Title": {
        "TAG": "[TITLE]",
        "CONSTRAINT": lambda x: len(x) > 0 and contains_letter(x) and len(x) <= 300,
        "HELPER": lambda x: x,
    },
    # Sentinel entry: generate_post stops filling fields when it reaches it.
    "EOS": {
        "TAG": "Write the Reddit post.\nAssistant:",
        "CONSTRAINT": lambda x: x,
        "HELPER": lambda x: x,
    },
}

POST_TAGS_ORDER_OP = [
    TAGS[key]["TAG"] for key in ("Subreddit", "Author", "Media", "Title", "EOS")
]  # order matters

In [4]:
import random

def get_up_to_tag_line(prompt, tag):
    """Return ``prompt`` cut just before the line that contains ``tag``.

    The result ends right before the newline that precedes the first
    occurrence of ``tag``. If there is no such newline the result is the
    empty string. When ``tag`` does not occur, an error is printed and -1 is
    returned.
    """
    tag_pos = prompt.find(tag)
    if tag_pos < 0:
        print("Error: tag not found")
        return -1

    # Last "\n" before the tag; rfind gives -1 when the tag is on the first line.
    line_break = prompt.rfind("\n", 0, tag_pos)
    return prompt[:max(line_break, 0)]


def get_up_to_tag(prompt, tag):
    """Return the part of ``prompt`` before the first ``tag``, or -1 if absent."""
    cut = prompt.find(tag)
    return prompt[:cut] if cut != -1 else -1



def generate_post(
    **options
):
    """Generate a Reddit-style post, filling in any field the caller omitted.

    Fields not supplied are rendered into POST_TEMPLATE as their placeholder
    tag and then generated one at a time, in POST_TAGS_ORDER_OP order. Each
    field is generated from the prompt text that precedes its placeholder,
    and retried until its TAGS constraint accepts the value. The media type
    is chosen at random from VALID_MEDIA instead of being generated.

    Args:
        **options: Optional ``subreddit``, ``author``, ``media`` and ``title``
            values to use verbatim.

    Returns:
        The options dict, extended with the generated fields plus
        ``postPrompt`` (the fully filled-in prompt) and ``text`` (the
        generated post body).

    Raises:
        ValueError: Propagated from ``generate_with_lora`` (cache limit or
            token budget exhausted); it is not caught here.
    """
    default_options = {
        "subreddit": TAGS["Subreddit"]["TAG"],
        "author": TAGS["Author"]["TAG"],
        "media": TAGS["Media"]["TAG"],
        "title": TAGS["Title"]["TAG"],
    }
    options = {**default_options, **options}
    prompt_full = POST_TEMPLATE.substitute(options)

    # Placeholder text -> TAGS key. A direct lookup keeps the key resolution
    # independent of which fields the caller already supplied.
    tag_to_key = {tag_obj["TAG"]: key for key, tag_obj in TAGS.items()}

    # Single-line fields: stop at the first newline.
    streaming_generator.set_stop_conditions(
        ["\n", "\nUser:", tokenizer.eos_token_id]
    )

    # loop over tags in order
    for tag in POST_TAGS_ORDER_OP:
        if tag == TAGS["EOS"]["TAG"]:
            break
        prompt = get_up_to_tag(prompt_full, tag)
        if prompt == -1:
            continue  # field was supplied by the caller, nothing to fill in

        tag_key = tag_to_key[tag]
        valid_output = False
        while not valid_output:
            # The helper's prefix seeds the prompt and is re-attached to the
            # generated continuation below.
            helper_starter_char = TAGS[tag_key]["HELPER"]("")
            if tag == TAGS["Media"]["TAG"]:
                # random valid media
                output = random.choice(VALID_MEDIA)
            else:
                output = generate_with_lora(
                    f"{prompt}{helper_starter_char}", lora_submission, 250
                )
            output = helper_starter_char + output.strip()
            valid_output = TAGS[tag_key]["CONSTRAINT"](output)

        options[tag_key.lower()] = output
        # replace tag with output
        prompt_full = prompt_full.replace(tag, output)

    # The post body may span several lines, so drop the newline stop condition.
    streaming_generator.set_stop_conditions(
        ["\nUser:", tokenizer.eos_token_id]
    )
    options["postPrompt"] = prompt_full
    # NOTE(review): len(prompt_full) is a character count, while the third
    # argument is a token budget; confirm this is the intended limit.
    options["text"] = generate_with_lora(prompt_full, lora_submission, len(prompt_full) + 500)
    return options


# Generate a post with every field left to the model, then show the filled-in
# prompt followed by the generated body. postObj is reused by the comment
# generation cell further down.
postObj = generate_post()
print(postObj["postPrompt"])
print(postObj["text"])
# Alternative invocations with some fields supplied by the caller:
# generate_post(subreddit="r/aww", author="u/aww", media="https://i.redd.it/1q2w3e4r5t6y.jpg", title="Cute doggo")
#generate_post(subreddit="r/aww", author="u/aww", media="https://i.redd.it/1q2w3e4r5t6y.jpg")

You are a Reddit post generator.
User: 
Subreddit: /r/gifs 
Author: mikemc2017 
Media: article 
Title: 33 years ago today, the Challenger exploded in front of millions of people on live TV. The disaster was captured by many cameras and has been immortalized as one of the most tragic moments in American history. Here's how it happened. 
Write the Reddit post.
Assistant:
 
On January 28th, 1986, NASA launched its space shuttle mission STS-51L with seven astronauts aboard, including Christa McAuliffe, who had won a competition to become the first teacher in space. The launch took place from Kennedy Space Center in Florida at 11:38 am EST.
The shuttle lifted off without any problems, but shortly after reaching an altitude of about 48,000 feet (about 14 kilometers), there was a loud bang heard over the radio transmission. Moments later, smoke began pouring out of the right side of the shuttle's exterior.
As the crew tried to regain control of the situation, they quickly realized that someth

In [9]:
from pprint import pprint
# Dump the post dict produced by generate_post() before building comment
# prompts from it.
print(postObj)
# postObj: { subreddit, author, media, title, postPrompt, text }
def generate_comment(postObj, num_comments=None):
    """Generate a chain of ``USERNAME - COMMENT`` style comments for a post.

    The first comment is generated from a prompt describing the post. Each
    later comment is generated from the whole conversation so far plus a
    follow-up request, so the prompt grows with every accepted comment.

    Args:
        postObj: Dict with ``subreddit``, ``title``, ``media``, ``text`` and
            ``author`` keys (as returned by ``generate_post``).
        num_comments: Number of comments to attempt. When omitted, a fresh
            random count between 1 and 5 is drawn on every call.

    Returns:
        List of accepted comment strings. It may be shorter than
        ``num_comments``: generation stops early when ``generate_with_lora``
        reports "Hit cache limit" or when one comment fails validation more
        than 5 times.
    """
    # Draw the default per call; a random default in the signature would be
    # evaluated only once, when the function is defined.
    if num_comments is None:
        num_comments = random.randint(1, 5)

    initialPrompt = f"""You are a Reddit user comment generator. In the conversation you change your Reddit username often to simulate different users.
User: You are on the subreddit {postObj["subreddit"]}.
Post title: {postObj["title"]}
Post media type: {postObj["media"]}
The post submission is about: {postObj["text"]}
The Original Poster(OP) username is: {postObj["author"]}
Your username is made up by you. Generate a comment in the format: USERNAME - COMMENT
Assistant: """

    followup_prompt = f"""User: Generate a follow up comment in the format: USERNAME - COMMENT
Assistant: """ #  Undercoverotaku - Was looking for the comment pointing this out far too long.

    streaming_generator.set_stop_conditions(
        ["\nUser:", tokenizer.eos_token_id]
    )
    comments = []
    # Each comment is seeded with a random letter that becomes the first
    # character of the username.
    comment_first_char = gen_valid_first_character(include_digits=False)
    prompt_builder = initialPrompt + comment_first_char
    for _ in range(num_comments):
        attempts = 0
        comment = None
        while comment is None:
            print(len(prompt_builder))
            try:
                candidate = f"{comment_first_char}{generate_with_lora(prompt_builder, lora_comment, 350)}"
            except ValueError as e:
                if "Hit cache limit" in str(e):
                    # The conversation no longer fits in the cache: give up.
                    print("Hit cache limit")
                    break
                # Any other generation failure counts as an invalid attempt.
                candidate = ""
            candidate = candidate.strip()
            # length > 0 and contains a - character
            if len(candidate) > 0 and "-" in candidate:
                comment = candidate
            else:
                attempts += 1
                if attempts > 5:
                    break
        if comment is None:
            print('broke')
            break
        comments.append(comment)
        print(len(prompt_builder))
        comment_first_char = gen_valid_first_character(include_digits=False)
        prompt_builder = f"""{prompt_builder}{comment}\n{followup_prompt}{comment_first_char}"""
    print(prompt_builder)
    return comments


# Ask for up to 222 comments. generate_comment returns early once
# generate_with_lora raises "Hit cache limit" or a comment fails validation
# more than 5 times, so the returned list may be shorter than requested.
comments = generate_comment(postObj, 222)
pprint(comments)

{'subreddit': '/r/gifs', 'author': 'mikemc2017', 'media': 'article', 'title': "33 years ago today, the Challenger exploded in front of millions of people on live TV. The disaster was captured by many cameras and has been immortalized as one of the most tragic moments in American history. Here's how it happened.", 'postPrompt': "You are a Reddit post generator.\nUser: \nSubreddit: /r/gifs \nAuthor: mikemc2017 \nMedia: article \nTitle: 33 years ago today, the Challenger exploded in front of millions of people on live TV. The disaster was captured by many cameras and has been immortalized as one of the most tragic moments in American history. Here's how it happened. \nWrite the Reddit post.\nAssistant:", 'text': " \nOn January 28th, 1986, NASA launched its space shuttle mission STS-51L with seven astronauts aboard, including Christa McAuliffe, who had won a competition to become the first teacher in space. The launch took place from Kennedy Space Center in Florida at 11:38 am EST.\nThe sh

2626
3037
3037
3454
3454
3688
3688
4181
4181
4317
4317
4705
4705
4871
4871
4987
4987
5167
5167
5399
5399
5745
5745
5919
5919
6072
6072
6196
6196
6355
6355
6488
6488
6625
6625
6764
6764
6909
6909
7057
7057
7203
Hit cache limit
broke
You are a Reddit user comment generator. In the conversation you change your Reddit username often to simulate different users.
User: You are on the subreddit /r/gifs.
Post title: 33 years ago today, the Challenger exploded in front of millions of people on live TV. The disaster was captured by many cameras and has been immortalized as one of the most tragic moments in American history. Here's how it happened.
Post media type: article
The post submission is about:  
On January 28th, 1986, NASA launched its space shuttle mission STS-51L with seven astronauts aboard, including Christa McAuliffe, who had won a competition to become the first teacher in space. The launch took place from Kennedy Space Center in Florida at 11:38 am EST.
The shuttle lifted off with

In [10]:
# Show the raw list of generated comments returned by generate_comment.
print(comments)

['Briar_Rabbit - I remember watching it happen while my mom was getting her chemotherapy treatment. She said she cried because she knew Christa would have died too if she went through what she did. It hit me hard then, even though I didn’t understand why. And now, looking back as an adult, I still get choked up thinking about it.', 'pinkpantherlives - My grandma passed away last year and i found some old VHS tapes she recorded when she worked at NASA during the Apollo missions. She used to work directly under Neil Armstrong. I watched them and saw the original footage of the moon landing. Its so weird seeing the way things looked back then. But also really cool.', 'DontFuckWithMyDoggie - What a great find! Did you keep them? Hopefully not in a box somewhere where someone will throw them out or sell them for profit!', "kelpys - I donated mine to the Smithsonian. They wanted them bad since they didn't have much video material of the actual landing itself. They got a lot of audio recordin