In [None]:
import sys, os
from string import Template

# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

import time

# Initialize model and cache

# Local paths are overridable through environment variables so the notebook can
# run on another machine; the defaults are the original hard-coded locations.
model_directory = os.environ.get(
    "EXL2_MODEL_DIR",
    "/mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ",
)
config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)
model.load()

tokenizer = ExLlamaV2Tokenizer(config)

cache = ExLlamaV2Cache(model)

# Load LoRA

lora_directory = os.environ.get(
    "EXL2_LORA_DIR",
    "/mnt/Woo/text-generation-webui/loras/Wizard-Vicuna-13B-Uncensored-GPTQ-reddit-submissions",
)
lora = ExLlamaV2Lora.from_directory(model, lora_directory)

# Initialize generators

streaming_generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
streaming_generator.warmup()

# Default stop conditions: a new "\nUser:" turn or the EOS token.
streaming_generator.set_stop_conditions(["\nUser:", tokenizer.eos_token_id])

# NOTE(review): not used by any cell visible here; kept in case later cells need it.
simple_generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Sampling settings (passed to every generate_with_lora call)

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.98
settings.top_p = 0.37
settings.token_repetition_penalty = 1.18

In [None]:
def generate_with_lora(prompt_, lora_, max_new_tokens):
    """Stream a completion of ``prompt_`` with the LoRA ``lora_`` applied.

    Uses the module-level ``tokenizer``, ``streaming_generator`` and
    ``settings``. Where the completion ends is decided by whatever stop
    conditions are currently set on ``streaming_generator``.

    Args:
        prompt_: prompt text to continue.
        lora_: adapter forwarded to ``begin_stream(loras=...)``.
        max_new_tokens: maximum number of ``stream()`` steps; each step is
            counted as one token. Values <= 0 return "" without generating
            (previously the ``==`` check could never fire for them, leaving
            generation bounded only by EOS).

    Returns:
        The concatenated streamed chunks (the prompt is not included).
    """
    if max_new_tokens <= 0:
        return ""

    input_ids = tokenizer.encode(prompt_)
    streaming_generator.begin_stream(input_ids, settings, loras=lora_)

    # Collect chunks and join once instead of repeated string concatenation.
    chunks = []
    for _step in range(max_new_tokens):
        chunk, eos, _ = streaming_generator.stream()
        chunks.append(chunk)
        if eos:
            break
    return "".join(chunks)

In [None]:
from string import Template

# Prompt skeleton for one Reddit post. Built from explicit "\n"-terminated
# literals so the trailing space after each field (part of the prompt format)
# is visible and cannot be silently removed by trailing-whitespace cleanup.
POST_TEMPLATE = Template(
    "You are a Reddit post generator.\n"
    "User: \n"
    "Subreddit: $subreddit \n"
    "Author: $author \n"
    "Media: $media \n"
    "Title: $title \n"
    "Write the Reddit post.\n"
    "Assistant:"
)

# Placeholder strings substituted into POST_TEMPLATE for fields the model must
# fill in. "EOS" is the literal tail of the template that ends the header.
TAGS = {
    "Subreddit": "[SUBREDDIT]",
    "Author": "[AUTHOR]",
    "Media": "[MEDIA]",
    "Title": "[TITLE]",
    "EOS": "Write the Reddit post.\nAssistant:",
}


def contains_letter(s):
    """Return True if ``s`` contains at least one alphabetic character."""
    return any(c.isalpha() for c in s)


VALID_MEDIA = ["image", "video", "text", "article"]  # articles are links to external sites

# Validators for generated field values, keyed by TAGS key. Each returns a bool.
CONSTRAINTS_FUNCS = {
    "Subreddit": lambda x: x.startswith("/r/") and len(x) > len("/r/"),
    "Author": lambda x: len(x) > 0 and contains_letter(x),
    "Media": lambda x: x in VALID_MEDIA,
    "Title": lambda x: len(x) > 0,
    # The header-filling loop stops at EOS, so this is not consulted there; it
    # returns a bool (non-empty) like the others instead of echoing the string.
    "EOS": lambda x: len(x) > 0,
}

# Fields are filled in this order, each conditioned on the ones before it.
POST_TAGS_ORDER_OP = [
    TAGS["Subreddit"],
    TAGS["Author"],
    TAGS["Media"],
    TAGS["Title"],
    TAGS["EOS"],
]  # order matters

In [25]:
import random

def get_up_to_tag_line(prompt, tag):
    """Return the part of ``prompt`` before the line that contains ``tag``.

    The newline ending the preceding line is not included. If ``tag`` is on
    the first line, the result is "".

    Returns -1 (after printing an error message) when ``tag`` is absent.
    """
    tag_pos = prompt.find(tag)
    if tag_pos == -1:
        print("Error: tag not found")
        return -1
    # Cut at the last newline before the tag; with no such newline, cut at 0.
    line_start = max(prompt.rfind("\n", 0, tag_pos), 0)
    return prompt[:line_start]


def get_up_to_tag(prompt, tag):
    """Return ``prompt`` truncated just before the first occurrence of ``tag``.

    Returns -1 when ``tag`` does not occur in ``prompt``.
    """
    cut = prompt.find(tag)
    return prompt[:cut] if cut != -1 else -1


def generate_post(
    subreddit=TAGS["Subreddit"],
    author=TAGS["Author"],
    media=TAGS["Media"],
    title=TAGS["Title"],
    max_attempts=50,
):
    """Fill in missing header fields of a post prompt, then generate the post body.

    Arguments left at their placeholder defaults (the values in ``TAGS``) are
    filled in one at a time, in ``POST_TAGS_ORDER_OP`` order. Each field is
    generated from the prompt text up to its placeholder and must pass its
    ``CONSTRAINTS_FUNCS`` check. Media is drawn at random from ``VALID_MEDIA``
    instead of being generated. Caller-supplied values are used verbatim and
    are not validated.

    Args:
        subreddit, author, media, title: header values, or the TAGS
            placeholders to have them generated.
        max_attempts: maximum number of candidates tried per field before
            giving up (previously the retry loop was unbounded and could hang).

    Returns:
        ``(prompt_full, post)``: the completed prompt and the generated body.

    Raises:
        RuntimeError: if no valid value is produced within ``max_attempts``.
    """
    prompt_full = POST_TEMPLATE.substitute(
        subreddit=subreddit, author=author, media=media, title=title
    )
    # Reverse lookup: placeholder string -> TAGS key (selects the validator).
    key_for_tag = {value: key for key, value in TAGS.items()}

    # Header fields are single lines, so also stop generation at "\n".
    # The stop conditions live on the shared generator, so restore the
    # post-body conditions even if field generation raises.
    streaming_generator.set_stop_conditions(
        ["\n", "\nUser:", tokenizer.eos_token_id]
    )
    try:
        for tag in POST_TAGS_ORDER_OP:
            if tag == TAGS["EOS"]:
                break  # end of the header; nothing left to fill in
            prompt = get_up_to_tag(prompt_full, tag)
            if prompt == -1:
                continue  # value was supplied by the caller
            is_valid = CONSTRAINTS_FUNCS[key_for_tag[tag]]

            output = None
            for _attempt in range(max_attempts):
                if tag == TAGS["Media"]:
                    candidate = random.choice(VALID_MEDIA)
                else:
                    candidate = generate_with_lora(prompt, lora, 250)
                candidate = candidate.strip()
                if is_valid(candidate):
                    output = candidate
                    break
            if output is None:
                raise RuntimeError(
                    f"no valid value for {tag} after {max_attempts} attempts"
                )

            # Replace only the first occurrence, matching get_up_to_tag's cut.
            prompt_full = prompt_full.replace(tag, output, 1)
    finally:
        streaming_generator.set_stop_conditions(
            ["\nUser:", tokenizer.eos_token_id]
        )

    # NOTE(review): the budget is derived from the prompt's character count,
    # not its token count; kept as in the original.
    post = generate_with_lora(prompt_full, lora, len(prompt_full) + 500)
    return prompt_full, post


# Example: one /r/aww post; the remaining header fields are filled in by generate_post.
prompt, post = generate_post(subreddit="/r/aww")
print(prompt)
print(post)

You are a Reddit post generator.
User: 
Subreddit: /r/aww 
Author: 10_gallons_of_gasoline 
Media: article 
Title: 23-year-old man who was born with no hands is now the first person in Europe to receive a bionic hand that can do everything from playing piano to opening a bottle of beer 
Write the Reddit post.
Assistant:
 
A 23-year-old man has become the first person in Europe to receive a new type of bionic hand that not only looks like a real human hand but also performs all sorts of tasks, including playing the piano and even opening a bottle of beer. The young man, named Steven Spiers, was born without any hands due to a rare genetic condition called symbrachydactyly, which affects about one in every 30,000 births worldwide. But thanks to advancements in prosthetics technology, he's been able to get his life back on track by using a cutting-edge device known as the "LimbO." Developed by researchers at the University of Glasgow in Scotland, the LimbO features advanced sensors and mot