# In-Context Learning on the AbstRCT Dataset with Llama 3

## Libraries

In [1]:
import ast
import torch
import random
import numpy as np
import pandas as pd
import torch.nn.functional as F


from tqdm import tqdm
from operator import itemgetter
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

## Tokenizers and Models

In [2]:
# Tokenizer and encoder used to embed abstract titles; both come from the same checkpoint.
embedding_model_id = "dmis-lab/biobert-v1.1"
embedding_tokenizer = AutoTokenizer.from_pretrained(embedding_model_id)
embedding_model = AutoModel.from_pretrained(embedding_model_id)

In [205]:
# Hub identifier of the checkpoint loaded below for both the inference tokenizer
# and the generation model.
model_id = "unsloth/llama-3-70b-Instruct-bnb-4bit"

In [206]:
# Left padding keeps the prompt adjacent to the generated continuation for
# decoder-only generation. `padding` is an option of the encode call, not of
# tokenizer construction, so only `padding_side` is passed here.
inference_tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side='left')
# Reuse EOS as the padding token.
inference_tokenizer.pad_token = inference_tokenizer.eos_token
# Token ids that end generation: the regular EOS and the end-of-turn marker.
terminators = [
    inference_tokenizer.eos_token_id,
    inference_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [207]:
# Directory where downloaded checkpoint files are cached.
# NOTE(review): machine-specific absolute path — adjust when running elsewhere.
model_cache_dir = '/home/umushtaq/scratch/am_work/in_context_learning/model_downloads'

# device_map="auto" lets the loader decide where the weights are placed.
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir=model_cache_dir,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

## Dataset

In [208]:
# Read the `neo` train/test splits and the `gla` and `mix` test splits.
neo_train_df, neo_test_df, gla_test_df, mix_test_df = (
    pd.read_csv(csv_path)
    for csv_path in (
        "dataset/neo/train.csv",
        "dataset/neo/test.csv",
        "dataset/gla/test.csv",
        "dataset/mix/test.csv",
    )
)

In [209]:
def get_title_line(x, df):
    """Return the first sentence of the first non-missing text row of a document.

    Parameters
    ----------
    x : row-like
        Object exposing a ``doc_id`` attribute (for example a row handed over
        by ``DataFrame.apply(..., axis=1)``).
    df : pandas.DataFrame
        Frame with ``doc_id`` and ``text`` columns in which the document is
        looked up.

    Returns
    -------
    str
        The part of the document's first non-missing ``text`` value that
        precedes the first ``"."``. An empty string is returned when the
        document has no non-missing text at all (previously this case raised
        ``UnboundLocalError``).
    """
    # Keep the frame's row order so "first" means first occurrence in df.
    doc_texts = df.loc[df["doc_id"] == x.doc_id, "text"].dropna()
    if doc_texts.empty:
        return ""

    return doc_texts.iloc[0].split(".")[0]

In [210]:
# Store, on every row, the title line of the document that row belongs to.
neo_train_df["abstract_title"] = neo_train_df.apply(lambda x: get_title_line(x, neo_train_df), axis=1)

In [211]:
# Store, on every row, the title line of the document that row belongs to.
neo_test_df["abstract_title"] = neo_test_df.apply(lambda x: get_title_line(x, neo_test_df), axis=1)

In [212]:
# Store, on every row, the title line of the document that row belongs to.
gla_test_df["abstract_title"] = gla_test_df.apply(lambda x: get_title_line(x, gla_test_df), axis=1)

In [213]:
# Store, on every row, the title line of the document that row belongs to.
mix_test_df["abstract_title"] = mix_test_df.apply(lambda x: get_title_line(x, mix_test_df), axis=1)

## Get title embeddings using BioBERT

In [214]:
# Embed every distinct training title once; maps title -> numpy vector.
# Errors propagate immediately: retrying a deterministic failure in an
# unbounded loop would hang the notebook.
title_embed_d = {}
for title in tqdm(neo_train_df.abstract_title.unique()):
    # Cap the input at 512 tokens so an over-long title cannot overflow the
    # encoder (assumes the usual 512-position BERT limit — TODO confirm).
    inputs = embedding_tokenizer(title, return_tensors="pt", truncation=True, max_length=512)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        output = embedding_model(**inputs)
    # Second model output, first (only) batch item; presumably the pooled
    # sentence representation for BERT-style encoders — verify.
    embedding = output[1][0].squeeze()
    title_embed_d[title] = embedding.detach().numpy()

100%|██████████| 350/350 [00:32<00:00, 10.78it/s]


In [215]:
# Look up each row's title in the dictionary built in the previous cell.
neo_train_df['title_embedding'] = neo_train_df.abstract_title.apply(lambda x: title_embed_d[x])

In [216]:
# Embed every distinct title of the neo test split; maps title -> numpy vector.
# Errors propagate immediately: retrying a deterministic failure in an
# unbounded loop would hang the notebook.
title_embed_d = {}
for title in tqdm(neo_test_df.abstract_title.unique()):
    # Cap the input at 512 tokens so an over-long title cannot overflow the
    # encoder (assumes the usual 512-position BERT limit — TODO confirm).
    inputs = embedding_tokenizer(title, return_tensors="pt", truncation=True, max_length=512)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        output = embedding_model(**inputs)
    # Second model output, first (only) batch item; presumably the pooled
    # sentence representation for BERT-style encoders — verify.
    embedding = output[1][0].squeeze()
    title_embed_d[title] = embedding.detach().numpy()

100%|██████████| 99/99 [00:09<00:00, 10.47it/s]


In [217]:
# Look up each row's title in the dictionary built in the previous cell.
neo_test_df['title_embedding'] = neo_test_df.abstract_title.apply(lambda x: title_embed_d[x])

In [218]:
# Embed every distinct title of the gla test split; maps title -> numpy vector.
# Errors propagate immediately: retrying a deterministic failure in an
# unbounded loop would hang the notebook.
title_embed_d = {}
for title in tqdm(gla_test_df.abstract_title.unique()):
    # Cap the input at 512 tokens so an over-long title cannot overflow the
    # encoder (assumes the usual 512-position BERT limit — TODO confirm).
    inputs = embedding_tokenizer(title, return_tensors="pt", truncation=True, max_length=512)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        output = embedding_model(**inputs)
    # Second model output, first (only) batch item; presumably the pooled
    # sentence representation for BERT-style encoders — verify.
    embedding = output[1][0].squeeze()
    title_embed_d[title] = embedding.detach().numpy()

100%|██████████| 98/98 [00:08<00:00, 11.00it/s]


In [219]:
# Look up each row's title in the dictionary built in the previous cell.
gla_test_df['title_embedding'] = gla_test_df.abstract_title.apply(lambda x: title_embed_d[x])

In [220]:
# Embed every distinct title of the mix test split; maps title -> numpy vector.
# Errors propagate immediately: retrying a deterministic failure in an
# unbounded loop would hang the notebook.
title_embed_d = {}
for title in tqdm(mix_test_df.abstract_title.unique()):
    # Cap the input at 512 tokens so an over-long title cannot overflow the
    # encoder (assumes the usual 512-position BERT limit — TODO confirm).
    inputs = embedding_tokenizer(title, return_tensors="pt", truncation=True, max_length=512)
    # Inference only, so skip autograd bookkeeping.
    with torch.no_grad():
        output = embedding_model(**inputs)
    # Second model output, first (only) batch item; presumably the pooled
    # sentence representation for BERT-style encoders — verify.
    embedding = output[1][0].squeeze()
    title_embed_d[title] = embedding.detach().numpy()

100%|██████████| 100/100 [00:09<00:00, 10.91it/s]


In [221]:
# Look up each row's title in the dictionary built in the previous cell.
mix_test_df['title_embedding'] = mix_test_df.abstract_title.apply(lambda x: title_embed_d[x])

In [222]:
def get_abstract_texts(df):
    """Add an ``abstract_text`` column holding each document's full text.

    All ``text`` values sharing a ``doc_id`` are joined with single spaces
    (missing values count as empty strings) and stripped of surrounding
    whitespace; the result is merged back so every row carries the text of its
    own document. A new frame is returned; ``df`` itself is not modified.
    """
    per_doc = (
        df.fillna('')
        .groupby('doc_id')['text']
        .agg(' '.join)
        .reset_index()
    )
    per_doc['text'] = per_doc['text'].str.strip()

    # The joined text arrives as `text_concatenated` because of the suffix rule.
    merged = df.merge(per_doc, on='doc_id', suffixes=('', '_concatenated'))
    return merged.rename(columns={'text_concatenated': 'abstract_text'})

In [223]:
# Attach the full abstract text to every row of each split.
# NOTE(review): each name is rebound to the transformed frame, so executing
# this cell twice would merge a second time — confirm before re-running.
neo_train_df = get_abstract_texts(neo_train_df)
neo_test_df = get_abstract_texts(neo_test_df)
gla_test_df = get_abstract_texts(gla_test_df)
mix_test_df = get_abstract_texts(mix_test_df)

In [224]:
def process_majorclaims(x):
    """Return the row's ``aty`` label with ``'MajorClaim'`` mapped to ``'Claim'``.

    Any other label is returned unchanged.
    """
    return 'Claim' if x.aty == 'MajorClaim' else x.aty

In [225]:
# Fold MajorClaim into Claim in every split (the `aty` column is overwritten).
for split_df in (neo_train_df, neo_test_df, gla_test_df, mix_test_df):
    split_df['aty'] = split_df.apply(process_majorclaims, axis=1)

In [226]:
neo_train_df[neo_train_df.aty != 'none'].aty.value_counts(normalize=True)

aty
Premise    0.680052
Claim      0.319948
Name: proportion, dtype: float64

## Get K neighbours

In [227]:
def get_k_neighbours(k, title, test_df, train_df=neo_train_df):
    """Return the ``k`` training titles most similar to ``title``.

    Parameters
    ----------
    k : int
        Number of neighbours to return.
    title : str
        Title to look up in ``test_df.abstract_title``.
    test_df : pandas.DataFrame
        Frame with ``abstract_title`` and ``title_embedding`` columns that
        contains ``title``.
    train_df : pandas.DataFrame
        Frame with the same two columns providing the candidate titles.
        The default is the ``neo_train_df`` object that exists when this
        function is defined.

    Returns
    -------
    list[tuple[str, float]]
        ``(title, cosine similarity)`` pairs, most similar first, at most
        ``k`` of them.

    Raises
    ------
    IndexError
        If ``title`` does not occur in ``test_df``.
    """
    matches = test_df.loc[test_df.abstract_title == title, "title_embedding"]
    if matches.empty:
        raise IndexError(f"title not found in test_df: {title!r}")
    # Build the query tensor once instead of once per candidate.
    query = torch.tensor(matches.values[0])

    # One row per distinct title, keeping the first occurrence and row order
    # so ties in the ranking below resolve as first-seen-first.
    candidates = train_df.drop_duplicates(subset="abstract_title")

    dist_l = [
        (cand_title, F.cosine_similarity(query, torch.tensor(cand_emb), dim=0).item())
        for cand_title, cand_emb in zip(candidates.abstract_title, candidates.title_embedding)
    ]

    # Highest similarity first; the sort is stable.
    dist_l.sort(key=itemgetter(1), reverse=True)

    return dist_l[0: k]

In [228]:
get_k_neighbours(5, neo_test_df.iloc[100]["abstract_title"], neo_test_df)

[(' Research has shown that self-directed stress management training improves mental well-being in patients undergoing chemotherapy',
  0.9953505992889404),
 (' Numerous studies have examined the comorbidity of depression with cancer, and some have indicated that depression may be associated with cancer progression or survival',
  0.994724452495575),
 (' Behavioral symptoms are common in breast cancer survivors, including disturbances in energy, sleep, and mood, though few risk factors for these negative outcomes have been identified',
  0.9935887455940247),
 (' Lymphoma patients commonly experience declines in physical functioning and quality of life (QoL) that may be reversed with exercise training',
  0.9935306310653687),
 (' Few intervention studies have been conducted to help couples manage the effects of prostate cancer and maintain their quality of life',
  0.992271900177002)]

## Prepare Prompts

In [229]:
def prepare_similar_example_prompts(title, df, k=0, seed=33):
    """
    Create the part of the prompt made of k examples from the train set whose
    title is most similar to a given title.

    Parameters
    ----------
    title : str
        Title of the abstract the examples are selected for.
    df : pandas.DataFrame
        Frame containing ``title`` (passed on to ``get_k_neighbours``).
    k : int
        Number of examples to include. With ``k=0`` the result is ``''`` and
        no neighbour search is performed.
    seed : int
        Seed for the module-level ``random`` generator, set on every call.

    Returns
    -------
    str
        For each example: a ``## Example`` header, the abstract, its numbered
        arguments, and the expected JSON result.
    """

    random.seed(seed)

    if k == 0:
        # Nothing to sample; avoid computing similarities to every train title.
        return ''

    neighbours_l = get_k_neighbours(2*k, title, df) # Fetch the 2*k closest neighbors
    # Only keep k of them (or all of them if fewer than k are available)
    sampled_neighbours_l = random.sample(neighbours_l, min(k, len(neighbours_l)))

    prompt = ''
    for example_idx, (neighbour_title, _similarity) in enumerate(sampled_neighbours_l):
        prompt += f'## Example {example_idx+1}\n'

        example_df = neo_train_df[neo_train_df.abstract_title == neighbour_title]
        example_df = example_df[example_df.aty != 'none'].reset_index()

        # The abstract header is only written when the example has arguments.
        if not example_df.empty:
            prompt += f'# Abstract:\n{example_df.iloc[0].abstract_text}\n\n# Arguments:\n'

        class_l = []
        for arg_no, (arg_text, arg_class) in enumerate(zip(example_df.text, example_df.aty), start=1):
            prompt += f'Argument {arg_no}={arg_text}\n'
            class_l.append(arg_class)

        prompt += '\n# Result:\n'
        prompt += '{' + ', '.join(f'"Argument {n}": "{c}"' for n, c in enumerate(class_l, start=1)) + '}'
        prompt += '\n\n'

    return prompt

In [230]:
print(prepare_similar_example_prompts(neo_test_df.iloc[14]["abstract_title"], neo_test_df, k=0))




## Prepare Test Set Prompts

In [231]:
# Definition of the Claim class inserted into the system prompt.
claim_fulldesc = "A claim in the abstract of an RCT is a statement or conclusion about the findings of the study."

In [232]:
claim_fulldesc

'A claim in the abstract of an RCT is a statement or conclusion about the findings of the study.'

In [233]:
# Definition of the Premise class inserted into the system prompt.
premise_fulldesc = "A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim."

In [234]:
premise_fulldesc

'A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.'

In [235]:
# Class prior inserted into the system prompt. The figures are hard-coded and
# match the normalized `aty` value counts displayed for neo_train_df earlier in
# this notebook; update them by hand if the training data changes.
proportion_desc = "68.0052% of examples are of type Premise and 31.9948% of type Claim."

In [236]:
proportion_desc

'68.0052% of examples are of type Premise and 31.9948% of type Claim.'

In [281]:
%%time

#experiment_df = df[df.split == 'TEST']
# experiment_df = neo_test_df
# experiment_df = gla_test_df
experiment_df = mix_test_df

target_l = []

# Pre-prepare all examples from each document
title_l = []
essay_l = []
example_l = []
buffer_l = []
label_l = []
for e in experiment_df.iterrows():

    if e[1].span_pos == 1:
    #if e[0] == 0:
        essay_l.append(e[1].abstract_text)
        title_l.append(e[1].abstract_title)
    
    if e[1].span_pos == 1 and len(buffer_l) > 0:
    #if e[0] == 0 and len(buffer_l) > 0:
        example_l.append(buffer_l)
        target_l.append(label_l)
        buffer_l = []
        label_l = []
    if e[1].aty != 'none':
        buffer_l.append(e[1].text)
        label_l.append(e[1].aty)

# Flush the buffers
example_l.append(buffer_l)
target_l.append(label_l)


sys_msg_l_by_seed_d = {}
task_msg_l = []

for i in range(len(example_l)):

    # Prepare numbered list of ACs in this example
    other_acs_prompt = ''
    for j, s in enumerate(example_l[i]):
        # other_acs_prompt += f'Argument {j + 1}={s}\n'
        other_acs_prompt += f'Argument {j + 1}={s}\n'

    for seed in [42]:
        sys_msg = {"role":"system", "content": "### Task description: You are an expert biomedical assistant that takes 1) an abstract text and 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. " + proportion_desc + " You must absolutely not generate any text or explanation other than the following JSON format {\"Argument 1\": <predicted class for Argument 1 (str)>, ..., \"Argument n\": <predicted class for Argument n (str)>}\n\n### Class definitions:" + " Claim = " + claim_fulldesc + " Premise = " + premise_fulldesc + "" + prepare_similar_example_prompts(title_l[i], experiment_df, k=0, seed=seed)}  # Sample by similar title

        try:
            sys_msg_l_by_seed_d[seed].append(sys_msg)
        except KeyError:
            sys_msg_l_by_seed_d[seed] = [sys_msg]

    task_msg = {"role":"user", "content": f"# Abstract:\n{essay_l[i]}\n\n# Arguments:\n{other_acs_prompt}\n\n# Result:\n"}
    
    task_msg_l.append(task_msg)

CPU times: user 4.52 s, sys: 6.3 ms, total: 4.53 s
Wall time: 4.56 s


In [282]:
len(sys_msg_l_by_seed_d[42])

100

In [283]:
print(sys_msg_l_by_seed_d[42][0]['content'])

### Task description: You are an expert biomedical assistant that takes 1) an abstract text and 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}

### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.


In [284]:
len(task_msg_l)

100

In [285]:
print(task_msg_l[0]['content'])

# Abstract:
Few controlled clinical trials exist to support oral combination therapy in pulmonary arterial hypertension (PAH). Patients with PAH (idiopathic [IPAH] or associated with connective tissue disease [APAH-CTD]) taking bosentan (62.5 or 125 mg twice daily at a stable dose for ≥3 months) were randomized (1:1) to sildenafil (20 mg, 3 times daily; n = 50) or placebo (n = 53). The primary endpoint was change from baseline in 6-min walk distance (6MWD) at week 12, assessed using analysis of covariance. Patients could continue in a 52-week extension study. An analysis of covariance main-effects model was used, which included categorical terms for treatment, baseline 6MWD (<325 m; ≥325 m), and baseline aetiology; sensitivity analyses were subsequently performed. In sildenafil versus placebo arms, week-12 6MWD increases were similar (least squares mean difference [sildenafil-placebo], -2.4 m [90% CI: -21.8 to 17.1 m]; P = 0.6); mean ± SD changes from baseline were 26.4 ± 45.7 versus 1

In [286]:
# Pair each system message built for seed 42 with the user message of the
# same abstract, giving one two-message conversation per abstract.
prepared_sys_task_msg_l = [
    [sys_msg, task_msg_l[idx]]
    for idx, sys_msg in enumerate(sys_msg_l_by_seed_d[42])
]

In [287]:
len(prepared_sys_task_msg_l)

100

In [288]:
print(prepared_sys_task_msg_l[0])

[{'role': 'system', 'content': '### Task description: You are an expert biomedical assistant that takes 1) an abstract text and 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}\n\n### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.'}, {'role': 'user', 'content': '# Abstract:\nFew controlled clinical trials exist to support oral combination therapy in pulmonary arterial hypertension (PAH). Patients with PAH (idiopathic [IPAH] or associated with connective 

## Generate Predictions

In [289]:
# Generate one completion per conversation and keep only the decoded
# continuation (the prompt tokens are sliced off before decoding).
outputs_l = []

for messages in tqdm(prepared_sys_task_msg_l):

    input_ids = inference_tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        padding=True,
        truncation=True,
        return_tensors="pt",
    ).to(generation_model.device)

    outputs = generation_model.generate(
        input_ids=input_ids,
        max_new_tokens=1024,
        pad_token_id=inference_tokenizer.eos_token_id,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.1,
        top_p=0.9,
    )

    # Everything past the prompt length is newly generated text.
    prompt_len = input_ids.shape[-1]
    completion = inference_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    outputs_l.append(completion)

100%|██████████| 100/100 [12:34<00:00,  7.54s/it]


In [290]:
outputs_l

['{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Claim"}',
 '{"Argument 1": "Claim", "Argument 2": "Claim", "Argument 3": "Claim", "Argument 4": "Claim", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Premise", "Argument 8": "Claim", "Argument 9": "Premise"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Claim"}',
 '{"Argument 1": "Claim", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Claim", "Argument 5": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Claim", "Argument 4": "Claim", "Argument 5": "Premise", "Argument 6": "Claim", "Argument 7": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Claim", "Argument 3": "Premise", "Argument 4": "Claim", "Argument 5": "Premise"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise",

In [291]:
len(outputs_l)

100

In [292]:
def _parse_prediction(raw_output):
    """Turn one model reply into the list of predicted classes, in reply order.

    Raises ``ValueError`` (showing the offending reply) when the reply is not
    a dict literal.
    """
    try:
        parsed = ast.literal_eval(raw_output.strip())
    except (ValueError, SyntaxError) as err:
        raise ValueError(f"model reply is not a valid dict literal: {raw_output!r}") from err
    if not isinstance(parsed, dict):
        raise ValueError(f"model reply is not a dict: {raw_output!r}")
    return list(parsed.values())


preds_l = [_parse_prediction(output) for output in outputs_l]

# A reply with a missing or extra argument would shift every later label once
# the lists are flattened, so check each abstract against its gold labels.
mismatched_idx = [
    idx for idx, (pred, gold) in enumerate(zip(preds_l, target_l))
    if len(pred) != len(gold)
]
if len(preds_l) != len(target_l) or mismatched_idx:
    raise ValueError(
        f"prediction/gold mismatch: {len(preds_l)} replies for {len(target_l)} abstracts; "
        f"abstracts with a wrong number of predictions: {mismatched_idx}"
    )

In [293]:
preds_l

[['Premise', 'Premise', 'Premise', 'Claim'],
 ['Claim',
  'Claim',
  'Claim',
  'Claim',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Premise'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Claim', 'Claim', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Claim', 'Premise', 'Claim', 'Premise'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Claim',
  'Claim',
  'Claim',
  'Premise',
  'Claim'],
 ['Claim', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Premise', 'Claim', 'Claim', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim'],
 ['

In [294]:
len(preds_l)

100

In [295]:
# Gold labels per abstract, in order of first appearance (sort=False).
# Read from experiment_df — the split the prompts were built from — so the
# labels stay in step with the predictions when another split is selected.
grouped_df = experiment_df[experiment_df.aty != 'none'].groupby('doc_id', sort=False)['aty'].apply(list).reset_index()

In [296]:
grouped_df

Unnamed: 0,doc_id,aty
0,28874133,"[Premise, Premise, Premise, Claim]"
1,11346336,"[Premise, Premise, Premise, Premise, Premise, ..."
2,28370419,"[Premise, Premise, Premise, Premise, Premise, ..."
3,29596950,"[Premise, Premise, Premise, Premise, Claim]"
4,29633159,"[Claim, Claim, Premise, Premise, Premise, Clai..."
...,...,...
95,29625440,"[Premise, Premise, Claim]"
96,23285843,"[Premise, Premise, Premise, Claim, Claim]"
97,22814041,"[Premise, Premise, Premise, Premise, Claim, Cl..."
98,29329305,"[Claim, Premise, Premise, Premise, Premise, Cl..."


In [297]:
# One list of gold labels per abstract, in the row order of grouped_df.
grounds_l = grouped_df['aty'].tolist()

In [298]:
grounds_l

[['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Claim', 'Claim', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise',

In [299]:
len(grounds_l)

100

In [300]:
# Flatten the per-abstract lists into one label sequence each.
# Entries that are already plain labels are kept as they are, so re-running
# this cell does not split label strings into single characters.
preds_l = [
    label
    for entry in preds_l
    for label in (entry if isinstance(entry, list) else [entry])
]
grounds_l = [
    label
    for entry in grounds_l
    for label in (entry if isinstance(entry, list) else [entry])
]

In [301]:
len(preds_l), len(grounds_l)

(609, 609)

In [302]:
# Per-class precision/recall/F1 of the predictions against the gold labels.
report = classification_report(y_true=grounds_l, y_pred=preds_l, digits=4)
print(report)

              precision    recall  f1-score   support

       Claim     0.5921    0.7736    0.6708       212
     Premise     0.8554    0.7154    0.7791       397

    accuracy                         0.7356       609
   macro avg     0.7237    0.7445    0.7250       609
weighted avg     0.7637    0.7356    0.7414       609

