In [43]:
import sys, os
from string import Template

# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from exllamav2 import (
    ExLlamaV2,
    ExLlamaV2Config,
    ExLlamaV2Cache,
    ExLlamaV2Tokenizer,
    ExLlamaV2Lora,
)

from exllamav2.generator import (
    ExLlamaV2BaseGenerator,
    ExLlamaV2StreamingGenerator,
    ExLlamaV2Sampler,
)

import time

# Initialize model and cache

# The model and LoRA locations can be overridden through environment variables
# so the notebook also runs on machines with a different directory layout.
# When the variables are unset, the original paths are used unchanged.
model_directory = os.environ.get(
    "EXLLAMAV2_MODEL_DIR",
    "/mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ",
)
lora_directory = os.environ.get(
    "EXLLAMAV2_LORA_DIR",
    "/mnt/Woo/text-generation-webui/loras/Wizard-Vicuna-13B-Uncensored-GPTQ-reddit-submissions",
)

config = ExLlamaV2Config()
config.model_dir = model_directory
config.prepare()

model = ExLlamaV2(config)
print("Loading model: " + model_directory)
model.load()

tokenizer = ExLlamaV2Tokenizer(config)

cache = ExLlamaV2Cache(model)

# Load LoRA

lora = ExLlamaV2Lora.from_directory(model, lora_directory)

# Initialize generators

# Both generators are built on the same model, cache and tokenizer objects.
streaming_generator = ExLlamaV2StreamingGenerator(model, cache, tokenizer)
streaming_generator.warmup()

# Default stop conditions; generate_post() replaces them before each generation.
streaming_generator.set_stop_conditions(["\nUser:", tokenizer.eos_token_id])

simple_generator = ExLlamaV2BaseGenerator(model, cache, tokenizer)

# Sampling settings

settings = ExLlamaV2Sampler.Settings()
settings.temperature = 0.85
settings.top_k = 50
settings.top_p = 0.8
settings.token_repetition_penalty = 1.1

Loading model: /mnt/Woo/text-generation-webui/models/TheBloke_Wizard-Vicuna-13B-Uncensored-GPTQ


In [55]:
def generate_with_lora(prompt_, lora_, max_new_tokens):
    """Stream a completion for ``prompt_`` and return the generated text.

    The prompt is echoed to stdout first, then every streamed chunk is printed
    as soon as it arrives. The function relies on the module-level
    ``tokenizer``, ``streaming_generator`` and ``settings`` objects.

    Args:
        prompt_: Prompt text to encode and continue.
        lora_: Forwarded unchanged as ``loras=`` to ``begin_stream``.
        max_new_tokens: Upper bound on the number of ``stream()`` calls. A
            value of zero or less produces no output at all.

    Returns:
        The concatenation of all streamed chunks (the prompt is not included).
    """
    print(prompt_, end="")
    sys.stdout.flush()

    input_ids = tokenizer.encode(prompt_)
    streaming_generator.begin_stream(input_ids, settings, loras=lora_)

    chunks = []
    generated_tokens = 0
    # Use `<` rather than an equality test on the counter: with `==`, a budget
    # of 0, a negative number or a non-integer is never matched and the loop
    # only ends when the generator reports EOS.
    while generated_tokens < max_new_tokens:
        chunk, eos, _ = streaming_generator.stream()
        generated_tokens += 1
        chunks.append(chunk)
        print(chunk, end="")
        sys.stdout.flush()
        if eos:
            break

    print()

    print("done")
    return "".join(chunks)

In [45]:
from string import Template

# Prompt skeleton for one Reddit post. Every field line ends with a space
# before the newline; the explicit "\n" literals keep that whitespace visible.
POST_TEMPLATE = Template(
    "You are a Reddit post generator.\n"
    "User: \n"
    "Subreddit: $subreddit \n"
    "Author: $author \n"
    "Media: $media \n"
    "Title: $title \n"
    "Write the Reddit post.\n"
    "Assistant:"
)

# Placeholder put into the prompt for each field the caller leaves out.
# "EOS" is the fixed tail of the prompt, not a field placeholder.
TAGS = dict(
    Subreddit="[SUBREDDIT]",
    Author="[AUTHOR]",
    Media="[MEDIA]",
    Title="[TITLE]",
    EOS="Write the Reddit post.\nAssistant:",
)

# Placeholders in the order they appear in POST_TEMPLATE -- order matters.
POST_TAGS_ORDER_OP = [
    TAGS[key] for key in ("Subreddit", "Author", "Media", "Title", "EOS")
]

In [59]:
def get_up_to_tag_line(prompt, tag):
    """Return the part of ``prompt`` that precedes the line containing ``tag``.

    The result ends just before the newline that starts the tag's line. If the
    tag sits on the first line, the result is the empty string.

    Returns:
        The truncated prompt, or ``-1`` (after printing an error message) when
        ``tag`` does not occur in ``prompt``.
    """
    tag_pos = prompt.find(tag)
    if tag_pos < 0:
        print("Error: tag not found")
        return -1

    # Position of the newline that opens the tag's line; rfind gives -1 when
    # the tag is on the first line, which is clamped to 0.
    line_break = prompt.rfind("\n", 0, tag_pos)
    return prompt[: max(line_break, 0)]


def get_up_to_tag(prompt, tag):
    """Return everything in ``prompt`` before the first occurrence of ``tag``.

    Returns:
        The prefix of ``prompt`` that ends right before ``tag``, or ``-1`` when
        ``tag`` does not occur in ``prompt``.
    """
    tag_pos = prompt.find(tag)
    return -1 if tag_pos < 0 else prompt[:tag_pos]


def generate_post(
    subreddit=TAGS["Subreddit"],
    author=TAGS["Author"],
    media=TAGS["Media"],
    title=TAGS["Title"],
):
    """Fill in the missing fields of a Reddit post prompt with the model.

    Each argument left at its default shows up in the prompt as its
    placeholder from ``TAGS``. The placeholders are visited in
    ``POST_TAGS_ORDER_OP`` order. For each one still present, the text before
    it is sent to ``generate_with_lora`` with a newline added to the stop
    conditions, and the result replaces the placeholder before the next field
    is handled.

    Relies on the module-level ``POST_TEMPLATE``, ``TAGS``,
    ``POST_TAGS_ORDER_OP``, ``streaming_generator``, ``tokenizer`` and
    ``lora``. As a side effect, the stop conditions of ``streaming_generator``
    are replaced whenever a field is generated.

    Args:
        subreddit: Subreddit field, or its placeholder to have it generated.
        author: Author field, or its placeholder to have it generated.
        media: Media field, or its placeholder to have it generated.
        title: Title field, or its placeholder to have it generated.

    Returns:
        The prompt with every generated field substituted in.
    """
    prompt_full = POST_TEMPLATE.substitute(
        subreddit=subreddit, author=author, media=media, title=title
    )

    # loop over tags in order
    for tag in POST_TAGS_ORDER_OP:
        if tag == TAGS["EOS"]:
            break
        prompt = get_up_to_tag(prompt_full, tag)
        if prompt == -1:
            continue  # field was supplied by the caller (or already filled)
        print(prompt)

        print("generate from here to: '\\n'")
        print()
        streaming_generator.set_stop_conditions(
            ["\n", "\nUser:", tokenizer.eos_token_id]
        )
        output = generate_with_lora(prompt, lora, 100)
        print(f"generated output: {output}")

        # Write the generated value back into the prompt. Without this, the
        # prompt built for the next missing field would still contain this
        # field's raw placeholder (e.g. "[AUTHOR]") instead of real text.
        prompt_full = prompt_full.replace(tag, output.strip(), 1)

    return prompt_full


# Example run: the title is left out, so it stays a placeholder in the prompt
# and generate_post() asks the model to write it.
print()
generate_post(
    subreddit="r/aww",
    author="u/aww",
    media="https://i.redd.it/1q2w3e4r5t6y.jpg",
)


You are a Reddit post generator.
User: 
Subreddit: r/aww 
Author: u/aww 
Media: https://i.redd.it/1q2w3e4r5t6y.jpg 
Title: 
generate from here to: '\n'

You are a Reddit post generator.
User: 
Subreddit: r/aww 
Author: u/aww 
Media: https://i.redd.it/1q2w3e4r5t6y.jpg 
Title: 19 years ago today, I was born to be your mom. 
done
generated output: 19 years ago today, I was born to be your mom. 
