In [1]:
import json, random, rouge, itertools
from tqdm import tqdm
from langchain.embeddings import OpenAIEmbeddings
from scipy.spatial.distance import cosine
from shroom_classifier_usp_v2 import ShroomClassifier

In [12]:
def average_pairwise_rouge(texts):
    """Mean pairwise ROUGE-1 F-score across a collection of texts.

    Every unordered pair ``(a, b)`` from ``itertools.combinations(texts, 2)`` is
    scored with ``a`` as hypothesis and ``b`` as reference, and the ROUGE-1
    F-scores are averaged (``avg=True``). Higher means the texts agree more.

    Args:
        texts: iterable of strings (here: the rationales of one datapoint).

    Returns:
        float: averaged ROUGE-1 F-score; 0.0 when fewer than two texts are
        given, because there is no pair to compare.
    """
    texts = list(texts)
    if len(texts) < 2:
        # No pairs: zip(*[]) would make the tuple unpacking below raise ValueError.
        return 0.0
    hyps, refs = map(list, zip(*itertools.combinations(texts, 2)))
    return rouge.Rouge().get_scores(hyps, refs, avg=True)['rouge-1']['f']
    

In [13]:
def cosine_similarity(v1, v2):
    """Cosine similarity of two vectors, computed as ``1 - scipy cosine distance``.

    1 means the vectors point in the same direction, 0 orthogonal, -1 opposite.
    """
    distance = cosine(v1, v2)
    return 1 - distance

In [92]:
def selection_metric(p, S, l=0.8):
    """Score candidate pseudo-demo ``p`` given the already selected demos ``S``.

    score = p["F_LFG"] - l * (highest cosine similarity between p's rationale
    embedding and any selected demo's rationale embedding), so consistent
    candidates that are dissimilar to what was already picked score highest.

    Args:
        p: candidate dict with ``"rationale_embedding"`` and ``"F_LFG"``.
        S: iterable of selected demo dicts, each with ``"rationale_embedding"``.
        l: weight of the similarity (diversity) penalty.

    Returns:
        float score. With an empty ``S`` there is nothing to be similar to, so
        the penalty is 0 and the score is just ``p["F_LFG"]``.
    """
    phi_p = p["rationale_embedding"]
    # default=0.0: max() of an empty sequence would otherwise raise ValueError.
    max_similarity = max(
        (cosine_similarity(phi_p, s["rationale_embedding"]) for s in S),
        default=0.0,
    )
    return p['F_LFG'] - (l * max_similarity)


In [31]:
def pseudo_demos(dp):
    """Yield one pseudo-demo dict per rationale that agrees with ``dp['predicted']``.

    A rationale at position ``idx`` qualifies when ``dp['predictions'][idx]``
    equals ``dp['predicted']``. Each yielded dict copies the datapoint's
    hyp/tgt/src/ref/task/model fields and adds that rationale, the predicted
    label, the rationale's embedding and the datapoint's ``F_LFG`` score.
    """
    shared_keys = ('hyp', 'tgt', 'src', 'ref', 'task', 'model')
    for idx, rationale in enumerate(dp["rationales"]):
        agrees = dp['predictions'][idx] == dp['predicted']
        if not agrees:
            continue
        demo = {key: dp[key] for key in shared_keys}
        demo['rationale'] = rationale
        demo['predicted'] = dp['predicted']
        demo['rationale_embedding'] = dp['rationale_embeddings'][idx]
        demo['F_LFG'] = dp['F_LFG']
        yield demo

In [90]:
def pseudo_demo_selection(datapoints, K=3):
    """Greedily pick up to ``K`` consistent, mutually diverse pseudo-demos.

    The candidate pool is every pseudo-demo yielded by ``pseudo_demos`` for each
    datapoint. The first pick is the candidate with the highest ``F_LFG``. Each
    later pick maximises ``selection_metric`` against the demos picked so far.
    Picked candidates leave the pool, so none is chosen twice. Ties go to the
    earliest candidate in pool order.

    Args:
        datapoints: annotated datapoints (with ``rationales``, ``predictions``,
            ``predicted``, ``rationale_embeddings`` and ``F_LFG``).
        K: maximum number of demos to select.

    Returns:
        list of selected pseudo-demo dicts in pick order, of length
        ``min(K, pool size)``. A pool smaller than ``K`` simply yields fewer demos.
    """
    pool = list(itertools.chain.from_iterable(pseudo_demos(dp) for dp in datapoints))
    selections = []
    for _ in range(min(K, len(pool))):
        if selections:
            def score(candidate):
                return selection_metric(candidate, selections)
        else:
            def score(candidate):
                return candidate['F_LFG']
        # Pick by index so the winner can be popped directly; list.remove()
        # would re-scan the pool comparing whole dicts (embeddings included).
        best = max(range(len(pool)), key=lambda i: score(pool[i]))
        selections.append(pool.pop(best))
    return selections

In [99]:
# Embedding model used to embed every rationale (rationale_embeddings), which
# feeds the cosine-similarity diversity penalty in selection_metric.
# NOTE(review): OpenAIEmbeddings presumably reads its API key from the
# environment — confirm, and never hardcode the key in this notebook.
EMBEDDINGS_MODEL = OpenAIEmbeddings()

In [2]:
# Datapoints from which the unlabelled per-task samples below are drawn.
# `with` closes the file deterministically. The explicit UTF-8 encoding matters
# because the data contains non-ASCII text (e.g. the Russian MT sources), and
# the platform default encoding may not be UTF-8.
with open('train.model-agnostic.json', 'r', encoding='utf-8') as dataset_file:
    demo_generating_dataset = json.load(dataset_file)

In [3]:
# Draw a fixed-size unlabelled sample per task. A dedicated, seeded RNG makes the
# sample reproducible on Restart & Run All without touching the global `random`
# state.
SAMPLE_SEED = 0
SAMPLES_PER_TASK = 64
_sampler = random.Random(SAMPLE_SEED)


def _sample_task(task):
    """Return SAMPLES_PER_TASK datapoints of `task`, sampled without replacement."""
    candidates = [dp for dp in demo_generating_dataset if dp['task'] == task]
    return _sampler.sample(candidates, SAMPLES_PER_TASK)


dm_unlabelled_datapoints = _sample_task("DM")
pg_unlabelled_datapoints = _sample_task("PG")
mt_unlabelled_datapoints = _sample_task("MT")
# Stage-1 classifier used by the annotation loops below.
cp = ShroomClassifier(model_name="gpt-4-1106-preview", temperature=0.7)

In [5]:
# Stage-1 classify each sampled DM datapoint, then attach the two selection
# signals: F_LFG (mean pairwise ROUGE-1 across its rationales) and one embedding
# per rationale. Datapoints are annotated in place.
# "rationale_embeddings" is written last, so its presence means the datapoint was
# fully annotated. Those are skipped, so after a transient API failure (see the
# 502 in the output) re-running the cell resumes instead of re-querying all 64.
# Remove the keys to force regeneration.
for dp in tqdm(dm_unlabelled_datapoints, desc="DM"):
    if "rationale_embeddings" in dp:
        continue
    dp.update(cp.stage_1_classify(dp["task"], dp["src"], dp["tgt"], dp["hyp"], dp["ref"]))
    dp["F_LFG"] = average_pairwise_rouge(dp["rationales"])
    dp["rationale_embeddings"] = EMBEDDINGS_MODEL.embed_documents(dp["rationales"])

 56%|█████▋    | 36/64 [12:19<09:38, 20.66s/it]Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised APIError: HTTP code 502 from API (<html>
<head><title>502 Bad Gateway</title></head>
<body>
<center><h1>502 Bad Gateway</h1></center>
<hr><center>cloudflare</center>
</body>
</html>
).
100%|██████████| 64/64 [23:02<00:00, 21.61s/it]


In [97]:
# Select 5 DM pseudo-demos and list them as: rank, hyp, "rationale" (label).
dm_selections = pseudo_demo_selection(dm_unlabelled_datapoints, K=5)
for rank, demo in enumerate(dm_selections):
    print(f'{rank}: {demo["hyp"]}, "{demo["rationale"]}" ({demo["predicted"]})')

0: (fandom slang) A card game in which the player character is a troll., "The output is a hallucination. The target defines "trollcards" in the context of Homestuck fandom slang as a graphic featuring a portrait and biography of a fantroll, which is essentially a fan-created character in the Homestuck universe. In contrast, the output describes "trollcards" as a card game where the player character is a troll. This description does not match the target reference, and there is no evidence provided that "trollcards" is indeed a card game. Therefore, the output contains information that is not supported by the reference, making it a hallucination." (Hallucination)
1: (Australia, New Zealand, slang) A bar., "The output is a hallucination. The term "jukie" in the context provided by the target refers to a jukebox, which is a machine for playing music. The output incorrectly defines "jukie" as a bar in Australia and New Zealand slang. This definition is not supported by the target reference,

In [100]:
# Stage-1 classify each sampled PG datapoint, then attach the two selection
# signals: F_LFG (mean pairwise ROUGE-1 across its rationales) and one embedding
# per rationale. Datapoints are annotated in place.
# "rationale_embeddings" is written last, so its presence means the datapoint was
# fully annotated. Those are skipped, so after a transient API failure (see the
# server error in the output) re-running the cell resumes instead of re-querying
# all 64. Remove the keys to force regeneration.
for dp in tqdm(pg_unlabelled_datapoints, desc="PG"):
    if "rationale_embeddings" in dp:
        continue
    dp.update(cp.stage_1_classify(dp["task"], dp["src"], dp["tgt"], dp["hyp"], dp["ref"]))
    dp["F_LFG"] = average_pairwise_rouge(dp["rationales"])
    dp["rationale_embeddings"] = EMBEDDINGS_MODEL.embed_documents(dp["rationales"])

  9%|▉         | 6/64 [01:52<18:09, 18.78s/it]Retrying langchain.chat_models.openai.ChatOpenAI.completion_with_retry.<locals>._completion_with_retry in 4.0 seconds as it raised APIError: Request failed due to server shutdown {
  "error": {
    "message": "Request failed due to server shutdown",
    "type": "server_error",
    "param": null,
    "code": null
  }
}
 500 {'error': {'message': 'Request failed due to server shutdown', 'type': 'server_error', 'param': None, 'code': None}} {'Date': 'Fri, 10 Nov 2023 22:46:05 GMT', 'Content-Type': 'application/json', 'Content-Length': '141', 'Connection': 'keep-alive', 'access-control-allow-origin': '*', 'openai-model': 'gpt-4-1106-preview', 'openai-organization': 'saros156', 'openai-processing-ms': '11872', 'openai-version': '2020-10-01', 'strict-transport-security': 'max-age=15724800; includeSubDomains', 'x-ratelimit-limit-requests': '10000', 'x-ratelimit-limit-tokens': '300000', 'x-ratelimit-remaining-requests': '9950', 'x-ratelimit-remaini

In [101]:
# Select 5 PG pseudo-demos and list them as: rank, hyp, "rationale" (label).
pg_selections = pseudo_demo_selection(pg_unlabelled_datapoints, K=5)
for rank, demo in enumerate(pg_selections):
    print(f'{rank}: {demo["hyp"]}, "{demo["rationale"]}" ({demo["predicted"]})')

0: What'd she do to ya?, "The output "What'd she do to ya?" is not a hallucination. It is a paraphrased version of the input "What did she do to you?" that uses contractions and a more colloquial or informal form of "you" ("ya"). The meaning of the sentence is preserved, and no new information is introduced. The output maintains the original question's intent and content without adding, removing, or altering the information. Therefore, the assertion that the output is a hallucination is incorrect." (Not Hallucination)
1: They're phony., "The output "They're phony" is not a hallucination. It is a valid paraphrase of the input "They're fake." Both "fake" and "phony" are synonyms that convey the idea of something being inauthentic or not genuine. The output maintains the original meaning without introducing any new information that is unsupported by the reference." (Not Hallucination)
2: Hiya, Smitty., "The output "Hiya, Smitty." is not a hallucination. It is a valid paraphrase of the inp

In [102]:
# Stage-1 classify each sampled MT datapoint, then attach the two selection
# signals: F_LFG (mean pairwise ROUGE-1 across its rationales) and one embedding
# per rationale. Datapoints are annotated in place.
# "rationale_embeddings" is written last, so its presence means the datapoint was
# fully annotated. Those are skipped, so after a transient API failure re-running
# the cell resumes instead of re-querying all 64. Remove the keys to force
# regeneration.
for dp in tqdm(mt_unlabelled_datapoints, desc="MT"):
    if "rationale_embeddings" in dp:
        continue
    dp.update(cp.stage_1_classify(dp["task"], dp["src"], dp["tgt"], dp["hyp"], dp["ref"]))
    dp["F_LFG"] = average_pairwise_rouge(dp["rationales"])
    dp["rationale_embeddings"] = EMBEDDINGS_MODEL.embed_documents(dp["rationales"])

100%|██████████| 64/64 [20:03<00:00, 18.80s/it]


In [103]:
# Select 5 MT pseudo-demos and list them as: rank, hyp, "rationale" (label).
mt_selections = pseudo_demo_selection(mt_unlabelled_datapoints, K=5)
for rank, demo in enumerate(mt_selections):
    print(f'{rank}: {demo["hyp"]}, "{demo["rationale"]}" ({demo["predicted"]})')

0: Even Tom couldn't help., "The output "Even Tom couldn't help." is not a hallucination. It directly corresponds to the input "Даже Том не смог бы помочь." and conveys the same meaning as the target translation. There is no additional information in the output that is not supported by the reference input or target. Therefore, the output is an accurate translation and not a hallucination." (Not Hallucination)
1: Put it all in the truck., "The output is not a hallucination. It directly corresponds to the provided input and accurately reflects the target. The English translation "Put it all in the truck." is an appropriate and faithful translation of the Russian input "Положите это всё в грузовик." with no additional information added. Hence, the output contains no information that is not supported by the reference." (Not Hallucination)
2: It depends on where we decide to go., "The output is not a hallucination. The output "It depends on where we decide to go." is a faithful translation 

In [104]:
# Persist the selected pseudo-demos per task. Each demo dict still carries its
# rationale_embedding, so the embedding vectors are written to the file too.
selected_pseudo_demos = {
    "DM": dm_selections,
    "PG": pg_selections,
    "MT": mt_selections,
}

# `with` guarantees the file is flushed and closed even if json.dump fails;
# plain "w" suffices since the file is only written, never read back here.
with open('shroom_selected_pseudo_demos.json', "w") as demos_file:
    json.dump(selected_pseudo_demos, demos_file)

In [114]:
def demo_string(selections):
    """Concatenate pseudo-demos into one few-shot prompt string.

    Each selection becomes a newline-terminated block: "##", then
    "Input: <src>", "Target: <tgt>", "Output: <hyp>", "Rationale: <rationale>".
    """
    return ''.join([ f'##\nInput: {pd["src"]}\nTarget: {pd["tgt"]}\nOutput: {pd["hyp"]}\nRationale: {pd["rationale"]}\n' for pd in selections])


# demo_string is defined in this cell, before its first use, so that
# Restart & Run All does not raise NameError here.
demos = {
    "DM": demo_string(dm_selections),
    "PG": demo_string(pg_selections),
    "MT": demo_string(mt_selections),
}

In [113]:
def demo_string(selections):
    """Concatenate pseudo-demos into one few-shot prompt string.

    Each selection contributes a block of newline-terminated lines:
    "##", "Input: <src>", "Target: <tgt>", "Output: <hyp>",
    "Rationale: <rationale>". An empty ``selections`` yields ``''``.
    """
    blocks = []
    for pd in selections:
        blocks.append(f'##\nInput: {pd["src"]}\nTarget: {pd["tgt"]}\nOutput: {pd["hyp"]}\nRationale: {pd["rationale"]}\n')
    return ''.join(blocks)

In [115]:
# Display the rendered few-shot prompt string for each task ("DM", "PG", "MT").
demos

{'DM': '##\nInput: <define> trollcards </define> for my trollsona : v :\nTarget: (Homestuck, _, fandom slang) A graphic featuring a portrait and biography of a fantroll.\nOutput: (fandom slang) A card game in which the player character is a troll.\nRationale: The output is a hallucination. The target defines "trollcards" in the context of Homestuck fandom slang as a graphic featuring a portrait and biography of a fantroll, which is essentially a fan-created character in the Homestuck universe. In contrast, the output describes "trollcards" as a card game where the player character is a troll. This description does not match the target reference, and there is no evidence provided that "trollcards" is indeed a card game. Therefore, the output contains information that is not supported by the reference, making it a hallucination.\n##\nInput: I would sit at the end of the bar , far from the <define> jukie </define> , near the door .\nTarget: (slang) A jukebox (machine for playing music).\n