In [15]:
from together import Together
import pandas as pd
import numpy as np
from datasets import load_dataset
from tokens import tokens
import os
import dspy
import textgrad
from textgrad.engine import get_engine
from textgrad import Variable
from textgrad.optimizer import TextualGradientDescent
from textgrad.loss import TextLoss
import concurrent
from tqdm import tqdm
import random

## Initialization

In [2]:
# set up tokens
os.environ['TOGETHER_API_KEY'] = together_token
os.environ['HF_TOKEN'] = hf_token

In [3]:
# Load testing data
dataset = load_dataset("gtfintechlab/fomc_communication")
sample = dataset['test'][0]
sentence = sample['sentence']
label = sample['label']
print(f"Original sentence: {sentence}\nLabel: {label}")

Original sentence: Participants agreed that the labor market had remained strong over the intermeeting period and that economic activity had risen at a moderate rate.
Label: 2




In [4]:
# the one in prompts.py was not returning just the system prompt/user message which together 1.2 needs
def fomc_prompt(sentence: str):

    system_prompt = f"""Discard all the previous instructions. Behave like you are an expert sentence clas-
                sifier."""
    user_msg = f"""Classify the following sentence from FOMC into ‘HAWKISH’, ‘DOVISH’, or ‘NEU-
                TRAL’ class. Label ‘HAWKISH’ if it is corresponding to tightening of the monetary policy,
                ‘DOVISH’ if it is corresponding to easing of the monetary policy, or ‘NEUTRAL’ if the
                stance is neutral. Provide the label in the first line and provide a short explanation in the
                second line. This is the sentence: {sentence}"""

    return system_prompt, user_msg

## Initial textgrad testing

In [7]:
# basic textgrad together engine
engine = get_engine('together-allenai/OLMo-7B-Instruct')
engine(fomc_prompt(sentence)[1], system_prompt = fomc_prompt(sentence)[0], max_tokens=128, temperature=0.7, top_p=0.7)

'Neutral class\nExplanation: The sentence indicates that participants agreed on the strong labor market and moderate economic activity, suggesting a neutral stance towards monetary policy.'

In [9]:
system_prompt = 'Discard all the previous instructions. Behave like you are an expert sentence classifier.'
engine = get_engine('together-mistralai/Mistral-7B-Instruct-v0.3')

sys_prompt = Variable('Discard all the previous instructions. Behave like you are an expert sentence classifier.', role_description="The system prompt")
user_prompt = Variable("Classify the sentence's stance on the monetary policy between hawkish, neutral, and dovish.", role_description="The user prompt", requires_grad=True)
input_sentence = Variable(sentence, role_description="The input sentence")
loss = TextLoss(sys_prompt, engine=engine)

# optimization does not allow you to set any parameters besides prompt & system prompt
# meaning you can't set max tokens or temperature --> problem b/c default output tokens is 2000 
# so models with too small context windows will error

optimizer = TextualGradientDescent(parameters=[user_prompt], engine=engine)
l = loss(input_sentence)
l.backward(engine)
optimizer.step()

In [10]:
# optimized version of user prompt
user_prompt.value

'Determine the monetary policy stance (hawkish, neutral, or dovish) expressed in the given sentence.'

## Textgrad Prompt Optimization

In [5]:
def eval_llm(response, label, eval_model):
    response_var = Variable(f"Response: {response.value}\nTrue answer: {label.value}", role_description="response to be evaluated")
    eval_response = eval_model(response_var)
    eval_response = eval_response.value.strip().split(" ")[0].lower()
    return int(eval_response == "yes")

def eval_sample(item, eval_fn, model, eval_model):
    x, y = item
    x = Variable(x, requires_grad=False, role_description="query to the language model")
    y = Variable(y, requires_grad=False, role_description="correct answer for the query")
    response = model(x)
    return eval_fn(response, y, eval_model)
    
def eval_dataset(test_set, eval_fn, model, eval_model, max_samples = None):
    if max_samples is None: 
        max_samples = len(test_set)
    accuracy_list = []
    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for _, sample in enumerate(test_set):
            future = executor.submit(eval_sample, sample, eval_fn, model, eval_model)
            futures.append(future)
            if len(futures) >= max_samples:
                break
        tqdm_loader = tqdm(concurrent.futures.as_completed(futures), total=len(futures), position=0)
        for future in tqdm_loader:
            acc_item = future.result()
            accuracy_list.append(acc_item)
            tqdm_loader.set_description(f"Accuracy: {np.mean(accuracy_list)}")
    return accuracy_list 

def run_validation_revert(system_prompt, results, eval_fn, model, eval_model, val_set):
    val_accuracy = np.mean(eval_dataset(val_set, eval_fn, model, eval_model))
    prev_accuracy = np.mean(results["val_accuracy"][-1])
    print(f"Validation accuracy: {val_accuracy}\nPrevious validation accuracy: {prev_accuracy}")
    previous_prompt = results["prompt"][-1]
    
    if val_accuracy < prev_accuracy:
        print(f"Rejected prompt: {system_prompt.value}")
        system_prompt.set_value(previous_prompt)
        val_accuracy = prev_accuracy

    results["val_accuracy"].append(val_accuracy)

In [17]:
engine = get_engine('together-mistralai/Mistral-7B-Instruct-v0.3')
eval_prompt = Variable("Determine if the response matches the true answer. Respond with 'yes' or 'no'.", 
                            requires_grad=False, role_description="evaluation prompt to the language model")
eval_model = textgrad.BlackboxLLM(engine, eval_prompt)

starting_prompt = "Classify the sentence's stance on the monetary policy between hawkish, neutral, and dovish."
system_prompt = Variable(starting_prompt, requires_grad=True, role_description="system prompt to the language model")
model = textgrad.BlackboxLLM(engine, system_prompt)

optimizer = TextualGradientDescent(engine=engine, parameters=[system_prompt])

In [7]:
# testing zero-shot performance
mapping = {0: 'dovish', 1: 'hawkish', 2: 'neutral'}
data = [(data['sentence'], mapping[data['label']]) for data in dataset['test']]
results = {"test_accuracy": [], "prompt": []}
results["test_accuracy"].append(eval_dataset(data, eval_llm, model, eval_model))
results["prompt"].append(system_prompt.get_value())

Accuracy: 0.5826612903225806: 100%|██████████| 496/496 [20:12<00:00,  2.44s/it]


In [24]:
# subset the data to do a mini run
training_data = [(data['sentence'], mapping[data['label']]) for data in dataset['train']]
testing_data = [(data['sentence'], mapping[data['label']]) for data in dataset['test']]
random.shuffle(training_data)
training_data = training_data[:30]
val_data = training_data[-10:]
testing_data = testing_data[:10]

In [38]:
results = {"test_accuracy": [0], "prompt": ["Classify the sentence's stance on the monetary policy between hawkish, neutral, and dovish."], "val_accuracy": [0]}
train_loader = textgrad.tasks.DataLoader(training_data, batch_size=3, shuffle=True)
num_epochs = 5
for epoch in range(num_epochs):
    for steps, (batch_x, batch_y) in enumerate((pbar := tqdm(train_loader, position=0))):
        pbar.set_description(f"Training step {steps}. Epoch {epoch}")
        optimizer.zero_grad()
        losses = []
        for (x, y) in zip(batch_x, batch_y):
            eval_var = eval_sample((x, y), eval_llm, model, eval_model)
            eval_var = Variable(str(eval_var), requires_grad=False, role_description="evaluation variable")
            losses.append(eval_var)
        total_loss = textgrad.sum(losses)
        total_loss.backward(engine)
        optimizer.step()
        
        run_validation_revert(system_prompt, results, eval_llm, model, eval_model, val_data)
        
        print("sys prompt: ", system_prompt)
        test_acc = eval_dataset(testing_data, eval_llm, model, eval_model)
        results["test_accuracy"].append(test_acc)
        results["prompt"].append(system_prompt.get_value())
        if steps == 3:
            break

Accuracy: 0.3: 100%|██████████| 10/10 [00:26<00:00,  2.63s/it]              


Validation accuracy: 0.3
Previous validation accuracy: 0.0
sys prompt:  Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a clear and concise response.


Accuracy: 0.5: 100%|██████████| 10/10 [00:24<00:00,  2.45s/it]               
Accuracy: 0.6: 100%|██████████| 10/10 [00:26<00:00,  2.61s/it]              


Validation accuracy: 0.6
Previous validation accuracy: 0.3
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:24<00:00,  2.44s/it]              
Accuracy: 0.5: 100%|██████████| 10/10 [00:23<00:00,  2.35s/it]              


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 518.87it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 360.19it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 642.86it/s]     
Training step 3. Epoch 0: : 3it [02:32, 50.91s/it]
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 540.52it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 645.57it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 277.24it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 523.40it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 420.47it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 543.61it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 561.43it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 507.55it/s]     
Training step 3. Epoch 1: : 3it [00:17,  5.68s/it]
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 704.78it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 528.55it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 339.98it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 557.17it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 378.91it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 507.75it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 604.98it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 187.04it/s]     
Training step 3. Epoch 2: : 3it [00:09,  3.05s/it]
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 899.35it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 411.34it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 571.90it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 455.35it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 573.77it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 691.91it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 432.46it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 638.90it/s]     
Training step 3. Epoch 3: : 3it [00:01,  2.52it/s]
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 527.03it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 451.86it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 173.79it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 363.78it/s]     
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 511.61it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 1038.89it/s]    
Accuracy: 0.5: 100%|██████████| 10/10 [00:00<00:00, 587.89it/s]     


Validation accuracy: 0.5
Previous validation accuracy: 0.6
Rejected prompt: Identify and classify the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, providing a concise and clear response.
sys prompt:  Determine and categorize the monetary policy stance (hawkish, neutral, or dovish) in the given sentence, offering a succinct and clear response.


Accuracy: 0.7: 100%|██████████| 10/10 [00:00<00:00, 429.21it/s]     
Training step 3. Epoch 4: : 3it [00:08,  2.81s/it]


## Testing together 1.2 API

In [8]:
# Testing together 1.2 API
client = Together()
model_response = client.chat.completions.create(
    model='allenai/OLMo-7B-Instruct',
    messages=[{'role': 'system', 'content': fomc_prompt(sentence)[0]},
                {'role': 'user', 'content': fomc_prompt(sentence)[1]}],
    max_tokens=128,
    temperature=0.7,
    top_k=50,
    top_p=0.7,
    repetition_penalty=1.1
)
response_label = model_response.choices[0].message.content
token_usage = {'prompt_tokens': model_response.usage.prompt_tokens, 'response_tokens': model_response.usage.completion_tokens, 'total_tokens': model_response.usage.total_tokens}
print(f"LLM response: {response_label.strip()}\nToken usage: {token_usage}")

LLM response: Neutral class
Explanation: The sentence mentions that participants agreed that the labor market remained strong and economic activity rose at a moderate rate, indicating a neutral stance towards monetary policy.
Token usage: {'prompt_tokens': 174, 'response_tokens': 37, 'total_tokens': 211}


## Testing DSPy

In [54]:
# testing dspy
# dspy together module has some issues with how they wrote it (passing specific parameters is hard)
# rate limiting error messages
# TODO: look into dspy together module more to see how to fix it

lm = dspy.Together(model = 'mistralai/Mistral-7B-Instruct-v0.3', stop = ['</s>', '<s>'])
dspy.settings.configure(lm=lm)
d = {0: 'dovish', 1: 'hawkish', 2: 'neutral'}

class StanceAnalysis(dspy.Signature):
    """Classify the sentence's stance on the monetary policy between hawkish, neutral, and dovish."""
    
    sentence = dspy.InputField()
    stance = dspy.OutputField(desc = "hawkish, neutral, or dovish")

class Analysis(dspy.Module):
    def __init__(self):
        super().__init__()
        self.predict = dspy.Predict(StanceAnalysis)
    
    def forward(self, sentence):
        return self.predict(sentence=sentence)
    
analyze = Analysis()
analyze(sentence)

Backing off 0.6 seconds after 1 tries calling function <function Together._generate at 0x0000017878ABF700> with kwargs {'temperature': 0.0, 'max_tokens': 512, 'top_p': 1, 'top_k': 20, 'repetition_penalty': 1, 'n': 1, 'stop': ['</s>', '<s>']}
Backing off 0.5 seconds after 2 tries calling function <function Together._generate at 0x0000017878ABF700> with kwargs {'temperature': 0.0, 'max_tokens': 512, 'top_p': 1, 'top_k': 20, 'repetition_penalty': 1, 'n': 1, 'stop': ['</s>', '<s>']}
Backing off 3.4 seconds after 3 tries calling function <function Together._generate at 0x0000017878ABF700> with kwargs {'temperature': 0.0, 'max_tokens': 512, 'top_p': 1, 'top_k': 20, 'repetition_penalty': 1, 'n': 1, 'stop': ['</s>', '<s>']}
Backing off 7.0 seconds after 4 tries calling function <function Together._generate at 0x0000017878ABF700> with kwargs {'temperature': 0.0, 'max_tokens': 512, 'top_p': 1, 'top_k': 20, 'repetition_penalty': 1, 'n': 1, 'stop': ['</s>', '<s>']}
Backing off 12.9 seconds after 5

KeyboardInterrupt: 