In [1]:
from huggingface_hub import InferenceClient
from tqdm import tqdm
import os
import json



In [2]:
from concurrent.futures import ThreadPoolExecutor, as_completed

generate_kwargs = dict(
    temperature=1.0,
    max_new_tokens=4096,
    top_p=0.95,
    repetition_penalty=1.0,
    do_sample=True,
)

In [3]:
client = InferenceClient(
    "https://mixtral.us-west-2.mesolitica.com", timeout = 120
)


def format_prompt(message, history):
  prompt = "<s>"
  for user_prompt, bot_response in history:
    prompt += f"[INST] {user_prompt} [/INST]"
    prompt += f" {bot_response}</s> "
  prompt += f"[INST] {message} [/INST]"
  return prompt

In [4]:
from datasets import load_dataset

dataset = load_dataset("open-phi/textbooks")

Found cached dataset parquet (/home/husein/.cache/huggingface/datasets/open-phi___parquet/open-phi--textbooks-b1f9998a547cd367/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

In [5]:
dataset['train'][0].keys()

dict_keys(['topic', 'model', 'concepts', 'outline', 'markdown', 'field', 'subfield', 'rag'])

In [6]:
def partition(text, size = 2000):
    splitted = text.split()
    return [' '.join(splitted[i: i + size]) for i in range(0, len(splitted), size) if len(splitted[i: i + size]) > 500]

In [7]:
partitions = []
for i in range(len(dataset['train'])):
    partitions.extend(partition(dataset['train'][i]['markdown']))
    
len(partitions)

28809

In [8]:
partitions = partitions * 3

In [13]:
prompts = []
for t in partitions:
    prompt = f'-------\n{t}\n-------\n\ngenerate factually incorrect and confusing questions ONLY to trick people based on context above'
    prompts.extend([(t, prompt)] * 1)
    
len(prompts)

86427

In [17]:
!mkdir mixtral-rag-question-factually-wrong-textbook
# !rm mixtral-rag-question-factually-wrong-textbook/*.json

mkdir: cannot create directory ‘mixtral-rag-question-factually-wrong-textbook’: File exists


In [15]:
def answer(q, i):
    filename = f'mixtral-rag-question-factually-wrong-textbook/{i}.json'
    if os.path.exists(filename):
        return
    
    
    for _ in range(3):
        try:
            prompt = q[1]
            formatted_prompt = format_prompt(prompt, [])
            stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=False, details=True, return_full_text=False)
            output = stream.generated_text
            splitted = output.split('\n')
            splitted = [s for s in splitted if len(s) > 3]
            
            if len(splitted) < 4:
                continue
                
            with open(filename, 'w') as fopen:
                json.dump((q[0], output), fopen)
            break
        except Exception as e:
            # print(e)
            pass

In [16]:
answer(prompts[0], 0)

In [18]:
def consumer(queue, name):
    while True:
        if queue.qsize() == 0:
            break
        item = queue.get()
        answer(*item)
    print(f'consumer {name} done')

In [19]:
urls = [(q, no) for no, q in enumerate(prompts)]

In [20]:
from threading import Thread
from queue import Queue

queue = Queue()
for u in urls:
    queue.put(u)
    
ori_size = queue.qsize()

In [21]:
max_worker = 256
consumers = [Thread(target=consumer, args=(queue,i)) for i in range(max_worker)]
for i in range(len(consumers)):
    consumers[i].start()
    
pbar = tqdm(total=ori_size)
last_size = 0
while True:
    size = queue.qsize()
    if size == 0:
        break
    left = ori_size - size
    minus = left - last_size
    if minus > 0:
        pbar.update(minus)
        last_size += minus

pbar.close()

 72%|████████████████████████████████████████████████████████████████████████████▎                             | 62259/86427 [7:19:58<3:09:20,  2.13it/s]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

