In [1]:
import sys, os
from string import Template

# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

import time

# Initialize model and cache
#
# Everything below creates module-level objects (config, model, tokenizer,
# cache, lora, streaming_generator, simple_generator, settings) that the
# functions in later cells read as globals.
# NOTE(review): both directory paths are absolute and machine-specific; they
# must be edited before this cell can run anywhere else.

model_directory = (
    "/mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ"
)
config = ExLlamaV2Config()
config.model_dir = model_directory
# prepare() is called after model_dir is set and before the config is handed
# to the model and tokenizer constructors below.
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)
model.load()

tokenizer = ExLlamaV2Tokenizer(config)

# The cache is built from the already-loaded model and shared by both
# generators created below.
cache = ExLlamaV2Cache(model)

# Load LoRA
# The adapter is loaded against `model`; it is only applied when passed to a
# generator call (see the `loras=` argument used in generate_with_lora).

lora_directory = "/mnt/Woo/text-generation-webui/loras/Wizard-Vicuna-13B-Uncensored-GPTQ-reddit-submissions"
lora = ExLlamaV2Lora.from_directory(model, lora_directory)

# Initialize generators

streaming_generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
streaming_generator.warmup()

# Default stop conditions: the literal text "\nUser:" or the tokenizer's EOS
# token id. generate_post() temporarily replaces these and then restores them.
streaming_generator.set_stop_conditions(["\nUser:", tokenizer.eos_token_id])

# NOTE(review): simple_generator is not referenced by any code visible in this
# notebook - confirm it is still needed.
simple_generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Sampling settings
# These values are passed to begin_stream() on every generation call.

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.98
settings.top_p = 0.37
settings.token_repetition_penalty = 1.18

Loading model: /mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ


In [2]:
def generate_with_lora(prompt_, lora_, max_new_tokens):
    """Stream a completion for ``prompt_`` and return it as a single string.

    Reads the module-level ``tokenizer``, ``streaming_generator`` and
    ``settings`` objects created in the setup cell.

    Args:
        prompt_: Prompt text; encoded with ``tokenizer.encode`` and passed to
            ``streaming_generator.begin_stream``.
        lora_: Forwarded unchanged as the ``loras`` argument of
            ``begin_stream``.
        max_new_tokens: Maximum number of ``stream()`` steps to take. A value
            of zero or less returns ``""`` without starting a stream (the
            previous ``==`` check never matched such values and kept
            generating until end-of-stream).

    Returns:
        The concatenation of all chunks produced before the generator
        signalled end-of-stream or the step budget was used up.
    """
    if max_new_tokens <= 0:
        return ""

    input_ids = tokenizer.encode(prompt_)
    streaming_generator.begin_stream(input_ids, settings, loras=lora_)

    # Collect chunks in a list and join once instead of growing a string.
    chunks = []
    steps_taken = 0
    while steps_taken < max_new_tokens:
        chunk, eos, _ = streaming_generator.stream()
        steps_taken += 1
        chunks.append(chunk)
        if eos:
            break

    return "".join(chunks)

In [3]:
from string import Template
import string

# Prompt skeleton for one post. Each header line ends with a single trailing
# space before the newline; that space is part of the prompt text, so the
# lines are spelled out individually to keep it visible.
POST_TEMPLATE = Template(
    "\n".join(
        [
            "You are a Reddit post generator.",
            "User: ",
            "Subreddit: $subreddit ",
            "Author: $author ",
            "Media: $media ",
            "Title: $title ",
            "Write the Reddit post.",
            "Assistant:",
        ]
    )
)

def contains_letter(s):
    """Return True if at least one element of ``s`` is alphabetic, else False."""
    for ch in s:
        if ch.isalpha():
            return True
    return False

def gen_valid_first_character(include_digits=True):
    """Return one random ASCII character to seed a generated field.

    When ``include_digits`` is truthy, a ``random.random()`` draw below 0.5
    selects from ``string.digits``; in every other case the character comes
    from ``string.ascii_letters``.
    """
    # `and` short-circuits, so no random.random() draw happens when digits
    # are disallowed.
    pick_digit = include_digits and random.random() < 0.5
    alphabet = string.digits if pick_digit else string.ascii_letters
    return random.choice(alphabet)


VALID_MEDIA = ["image", "video", "text", "article"] # articles are links to external sites
TAGS = {
    "Subreddit": {
        "TAG": "[SUBREDDIT]",
        "CONSTRAINT": lambda x: x.startswith("/r/") and len(x) > len("/r/") + 3 and contains_letter(x) and len(x) <= len("/r/") + 21,
        "HELPER": lambda x: f"/r/{gen_valid_first_character(include_digits=False)}",
    },
    "Author": {
        "TAG": "[AUTHOR]",
        "CONSTRAINT": lambda x: len(x) >= 3 and contains_letter(x) and len(x) <= 23,
        "HELPER": lambda x: gen_valid_first_character(include_digits=True),
    },
    "Media": {
        "TAG": "[MEDIA]",
        "CONSTRAINT": lambda x: x in VALID_MEDIA,
        "HELPER": lambda x: x,
    },
    "Title": {
        "TAG": "[TITLE]",
        "CONSTRAINT": lambda x: len(x) > 0 and contains_letter(x) and len(x) <= 300,
        "HELPER": lambda x: x,
    },
    "EOS": {
        "TAG": "Write the Reddit post.\nAssistant:",
        "CONSTRAINT": lambda x: x,
        "HELPER": lambda x: x,
    },
}

POST_TAGS_ORDER_OP = [
    TAGS["Subreddit"]["TAG"],
    TAGS["Author"]["TAG"],
    TAGS["Media"]["TAG"],
    TAGS["Title"]["TAG"],
    TAGS["EOS"]["TAG"],
]  # order matters

In [5]:
import random

def get_up_to_tag_line(prompt, tag):
    """Return the part of ``prompt`` that precedes the line containing ``tag``.

    The result ends just before the last newline that comes ahead of the first
    occurrence of ``tag``; it is ``""`` when the tag sits on the first line.
    When ``tag`` is absent, an error message is printed and ``-1`` is returned.
    """
    tag_pos = prompt.find(tag)
    if tag_pos < 0:
        print("Error: tag not found")
        return -1

    # Last line break ahead of the tag; clamp "not found" (-1) to 0 so the
    # slice below yields an empty string.
    line_break = prompt.rfind("\n", 0, tag_pos)
    return prompt[: max(line_break, 0)]


def get_up_to_tag(prompt, tag):
    """Return everything in ``prompt`` before the first ``tag``, or ``-1`` if absent."""
    cut = prompt.find(tag)
    return -1 if cut == -1 else prompt[:cut]



def generate_post(
    **options
):
    """Build a complete prompt, generating missing header fields, then the post.

    Any of ``subreddit``, ``author``, ``media`` and ``title`` may be given as
    keyword arguments. A field that is not given starts out as its placeholder
    from ``TAGS`` and is filled in, in ``POST_TAGS_ORDER_OP`` order: the media
    field by a random choice from ``VALID_MEDIA``, every other field by
    generating text from the prompt up to that placeholder. A candidate is
    regenerated until the field's ``CONSTRAINT`` accepts it. Caller-supplied
    values are used as-is and are not checked against the constraints.

    Relies on the module-level ``POST_TEMPLATE``, ``TAGS``,
    ``POST_TAGS_ORDER_OP``, ``VALID_MEDIA``, ``streaming_generator``,
    ``tokenizer`` and ``lora`` objects.

    Returns:
        A ``(prompt, post)`` tuple: the fully substituted prompt and the text
        generated from it.
    """
    default_options = {
        "subreddit": TAGS["Subreddit"]["TAG"],
        "author": TAGS["Author"]["TAG"],
        "media": TAGS["Media"]["TAG"],
        "title": TAGS["Title"]["TAG"],
    }
    options = {**default_options, **options}
    prompt_full = POST_TEMPLATE.substitute(options)

    # Placeholder text -> TAGS key, built once. The earlier per-tag scan read
    # a "next tag" variable before it was assigned whenever the first
    # placeholder had been supplied by the caller, raising UnboundLocalError.
    tag_to_key = {tag_obj["TAG"]: key for key, tag_obj in TAGS.items()}
    eos_tag = TAGS["EOS"]["TAG"]
    media_tag = TAGS["Media"]["TAG"]

    for tag in POST_TAGS_ORDER_OP:
        if tag == eos_tag:
            # Nothing to fill in for the closing instruction.
            break
        prompt = get_up_to_tag(prompt_full, tag)
        if prompt == -1:
            continue  # placeholder absent: the caller supplied this field

        tag_key = tag_to_key[tag]
        make_starter = TAGS[tag_key]["HELPER"]
        is_valid = TAGS[tag_key]["CONSTRAINT"]

        # Header fields are single-line, so also stop at the first newline.
        streaming_generator.set_stop_conditions(
            ["\n", "\nUser:", tokenizer.eos_token_id]
        )
        valid_output = False
        while not valid_output:
            # A fresh starter on every attempt, so retries do not repeat the
            # same opening character.
            helper_starter_char = make_starter("")
            if tag == media_tag:
                # random valid media
                output = random.choice(VALID_MEDIA)
            else:
                output = generate_with_lora(f"{prompt}{helper_starter_char}", lora, 250)
            output = output.strip()
            output = helper_starter_char + output
            valid_output = is_valid(output)

        # Replace only the first occurrence (the one located above) so a
        # caller-supplied field that happens to contain the same placeholder
        # text is left untouched.
        prompt_full = prompt_full.replace(tag, output, 1)

    # Restore the stop conditions used for full-post generation.
    streaming_generator.set_stop_conditions(
        ["\nUser:", tokenizer.eos_token_id]
    )
    # NOTE(review): the budget adds the prompt's length in characters to 500,
    # while generate_with_lora counts generated steps only - confirm whether a
    # plain 500 was intended. Kept as-is to preserve output length.
    return prompt_full, generate_with_lora(prompt_full, lora, len(prompt_full) + 500)


# Demo run: let the model invent every header field, then show the final
# prompt followed by the generated post.
prompt, post = generate_post()
print(prompt, post, sep="\n")
# Header fields can also be pinned by keyword, for example:
# generate_post(subreddit="r/aww", author="u/aww", media="https://i.redd.it/1q2w3e4r5t6y.jpg", title="Cute doggo")
#generate_post(subreddit="r/aww", author="u/aww", media="https://i.redd.it/1q2w3e4r5t6y.jpg")

You are a Reddit post generator.
User: 
Subreddit: /r/EarthPorn 
Author: 1989_Girl 
Media: article 
Title: 20,000+ acres of California desert will be protected from development thanks to the state's largest conservation deal in decades 
Write the Reddit post.
Assistant:
 
California has just made history with its biggest land conservation deal ever — protecting over 20,000 acres (8,094 hectares) of pristine Mojave Desert wilderness from future development. The agreement between the state and two private companies is expected to help preserve some of America's most unique wildlife habitats while also providing new opportunities for outdoor recreation.
The project involves purchasing two large tracts of land north of Los Angeles that were previously owned by Tejon Ranch Company and Lennar Corporation. Together, these properties form what's known as the "Tehachapi Mountains" region, which includes rugged canyons, towering peaks, and other stunning natural features.
According to Governor G