# In-Context Learning for AbstRCT dataset with LlaMA 3

## Libraries

In [1]:
import ast
import torch
import random
import numpy as np
import pandas as pd
import torch.nn.functional as F


from tqdm import tqdm
from operator import itemgetter
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

## Tokenizers and Models

In [2]:
embedding_tokenizer = AutoTokenizer.from_pretrained("dmis-lab/biobert-v1.1")
embedding_model = AutoModel.from_pretrained("dmis-lab/biobert-v1.1")

In [3]:
model_id = "unsloth/llama-3-70b-Instruct-bnb-4bit"

In [4]:
inference_tokenizer = AutoTokenizer.from_pretrained(model_id, padding='left', padding_side='left')
inference_tokenizer.pad_token = inference_tokenizer.eos_token
terminators = [
    inference_tokenizer.eos_token_id,
    inference_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    cache_dir = '/home/umushtaq/scratch/am_work/in_context_learning/model_downloads',
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

## Dataset

In [6]:
neo_train_df = pd.read_csv("dataset/neo/train.csv")
neo_test_df = pd.read_csv("dataset/neo/test.csv")
gla_test_df = pd.read_csv("dataset/gla/test.csv")
mix_test_df = pd.read_csv("dataset/mix/test.csv")

In [7]:
def get_title_line(x, df):

    doc_id = x.doc_id
    df_doc = df[df["doc_id"] == doc_id].reset_index()
    
    for index, row in df_doc.iterrows():
        
        if pd.notna(row['text']):
            title_line = row['text'].split(".")[0]
            break

    return title_line

In [8]:
neo_train_df["abstract_title"] = neo_train_df.apply(lambda x: get_title_line(x, neo_train_df), axis=1)

In [9]:
neo_test_df["abstract_title"] = neo_test_df.apply(lambda x: get_title_line(x, neo_test_df), axis=1)

In [10]:
gla_test_df["abstract_title"] = gla_test_df.apply(lambda x: get_title_line(x, gla_test_df), axis=1)

In [11]:
mix_test_df["abstract_title"] = mix_test_df.apply(lambda x: get_title_line(x, mix_test_df), axis=1)

## Get title embeddings using BioBERT

In [12]:
title_embed_d = {}
for title in tqdm(neo_train_df.abstract_title.unique()):
    # print(title)
    while True:
        try:
            inputs = embedding_tokenizer(title, return_tensors="pt")
            output = embedding_model(**inputs)
            embedding = output[1][0].squeeze()
            title_embed_d[title] = embedding.detach().numpy()
            break
        except Exception as e:
            print(e)

100%|██████████| 350/350 [00:32<00:00, 10.80it/s]


In [13]:
neo_train_df['title_embedding'] = neo_train_df.abstract_title.apply(lambda x: title_embed_d[x])

In [14]:
title_embed_d = {}
for title in tqdm(neo_test_df.abstract_title.unique()):
    # print(title)
    while True:
        try:
            inputs = embedding_tokenizer(title, return_tensors="pt")
            output = embedding_model(**inputs)
            embedding = output[1][0].squeeze()
            title_embed_d[title] = embedding.detach().numpy()
            break
        except Exception as e:
            print(e)

100%|██████████| 99/99 [00:09<00:00, 10.43it/s]


In [15]:
neo_test_df['title_embedding'] = neo_test_df.abstract_title.apply(lambda x: title_embed_d[x])

In [16]:
title_embed_d = {}
for title in tqdm(gla_test_df.abstract_title.unique()):
    # print(title)
    while True:
        try:
            inputs = embedding_tokenizer(title, return_tensors="pt")
            output = embedding_model(**inputs)
            embedding = output[1][0].squeeze()
            title_embed_d[title] = embedding.detach().numpy()
            break
        except Exception as e:
            print(e)

100%|██████████| 98/98 [00:08<00:00, 10.99it/s]


In [17]:
gla_test_df['title_embedding'] = gla_test_df.abstract_title.apply(lambda x: title_embed_d[x])

In [18]:
title_embed_d = {}
for title in tqdm(mix_test_df.abstract_title.unique()):
    # print(title)
    while True:
        try:
            inputs = embedding_tokenizer(title, return_tensors="pt")
            output = embedding_model(**inputs)
            embedding = output[1][0].squeeze()
            title_embed_d[title] = embedding.detach().numpy()
            break
        except Exception as e:
            print(e)

100%|██████████| 100/100 [00:09<00:00, 10.81it/s]


In [19]:
mix_test_df['title_embedding'] = mix_test_df.abstract_title.apply(lambda x: title_embed_d[x])

In [20]:
def get_abstract_texts(df):

    concatenated_text = df.fillna('').groupby('doc_id')['text'].agg(' '.join).reset_index()
    concatenated_text['text'] = concatenated_text['text'].str.strip()
    df = df.merge(concatenated_text, on='doc_id', suffixes=('', '_concatenated'))
    df = df.rename(columns={'text_concatenated': 'abstract_text'})

    return df

In [21]:
neo_train_df = get_abstract_texts(neo_train_df)
neo_test_df = get_abstract_texts(neo_test_df)
gla_test_df = get_abstract_texts(gla_test_df)
mix_test_df = get_abstract_texts(mix_test_df)

In [22]:
def process_majorclaims(x):

    aty = x.aty
    if aty == 'MajorClaim':
        aty = 'Claim'

    return aty

In [23]:
neo_train_df['aty'] = neo_train_df.apply(lambda x: process_majorclaims(x), axis=1)
neo_test_df['aty'] = neo_test_df.apply(lambda x: process_majorclaims(x), axis=1)
gla_test_df['aty'] = gla_test_df.apply(lambda x: process_majorclaims(x), axis=1)
mix_test_df['aty'] = mix_test_df.apply(lambda x: process_majorclaims(x), axis=1)

In [24]:
neo_train_df[neo_train_df.aty != 'none'].aty.value_counts(normalize=True)

aty
Premise    0.680052
Claim      0.319948
Name: proportion, dtype: float64

## Get K neighbours

In [25]:
def get_k_neighbours(k, title, test_df, train_df=neo_train_df):

    test_title_embedding = test_df[test_df.abstract_title == title]["title_embedding"].values[0]

    title_embed_d = {}
    for e in train_df.iterrows():
        if e[1].abstract_title not in title_embed_d:
            title_embed_d[e[1].abstract_title] = e[1].title_embedding

    # train_titles = set(df[df.split == 'TRAIN'].title.unique())
    train_titles = set(train_df.abstract_title.unique())

    dist_l = []
    for t, v in title_embed_d.items():
        if t in train_titles:
            # d = cos_sim(title_embed_d[title], v)
            d = F.cosine_similarity(torch.tensor(test_title_embedding), torch.tensor(v), dim=0)
            dist_l.append((t, d.item()))

    sorted_dist_l = sorted(dist_l, key=itemgetter(1), reverse=True)
    
    return sorted_dist_l[0: k]

In [26]:
get_k_neighbours(5, neo_test_df.iloc[100]["abstract_title"], neo_test_df)

[(' Research has shown that self-directed stress management training improves mental well-being in patients undergoing chemotherapy',
  0.9953505992889404),
 (' Numerous studies have examined the comorbidity of depression with cancer, and some have indicated that depression may be associated with cancer progression or survival',
  0.994724452495575),
 (' Behavioral symptoms are common in breast cancer survivors, including disturbances in energy, sleep, and mood, though few risk factors for these negative outcomes have been identified',
  0.9935887455940247),
 (' Lymphoma patients commonly experience declines in physical functioning and quality of life (QoL) that may be reversed with exercise training',
  0.9935306310653687),
 (' Few intervention studies have been conducted to help couples manage the effects of prostate cancer and maintain their quality of life',
  0.992271900177002)]

## Prepare Prompts

In [27]:
def prepare_similar_example_prompts(title, df, k=5, seed=33):
    """
    Create a part of prompt made of k examples in the train set, whose topic is most similar to a given title.
    """

    random.seed(seed)

    neighbours_l = get_k_neighbours(2*k, title, df) # Fetch the 2*k closest neighbors
    # print(neighbours_l)
    sampled_neighbours_l = random.sample(neighbours_l, k) # Only keep k of them
    # bprint(sampled_neighbours_l)

    prompt = ''
    cnt = 0
    for i, (title, dist) in enumerate(sampled_neighbours_l):
        prompt += f'## Example {i+1}\n'

        example_df = neo_train_df[neo_train_df.abstract_title == title]
        example_df = example_df[example_df.aty != 'none'].reset_index()
        
        class_l = []
        for k in example_df.iterrows():
            
            if k[0] == 0:

                prompt += f'# Abstract:\n{example_df.iloc[0].abstract_text}\n\n# Arguments:\n'
                cnt = 0
                
            # prompt += f'Argument {cnt + 1}={k[1].text} - Class={k[1].aty}\n'
            prompt += f'Argument {cnt + 1}={k[1].text}\n'
            class_l.append(k[1].aty)
            cnt += 1
            
        prompt += '\n# Result:\n'
        prompt += '{' + ', '.join([f'"Argument {i + 1}": "{class_l[i]}"' for i in range(len(class_l))]) + '}'
        prompt += '\n\n'

    return prompt

In [28]:
print(prepare_similar_example_prompts(neo_test_df.iloc[14]["abstract_title"], neo_test_df, k=5))

## Example 1
# Abstract:
A single-item linear analogue self-assessment scale for mood was compared with a 28-item adjective checklist for emotional well-being. To confirm its concurrent validity and responsiveness to treatment and recurrence in patients with breast cancer, emotional well-being was assessed every 3 months for 2 years and at 1 and 6 months after recurrence in 1,169 patients who were premenopausal and 960 patients who were postmenopausal. These patients were enrolled in two International Breast Cancer Study Group randomized clinical trials in operable breast cancer conducted from 1986 to 1993. To assess concurrent validity, Pearson's correlation between the linear analogue self-assessment scale and the adjective checklist were calculated for each time-point within each treatment group and for the two assessments after recurrence. Responsiveness to treatment and recurrence were analyzed using paired t tests and the squared ratio of these t tests, an estimate of relative ef

## Prepare Test Set Prompts

In [29]:
claim_fulldesc = "A claim in the abstract of an RCT is a statement or conclusion about the findings of the study."

In [30]:
claim_fulldesc

'A claim in the abstract of an RCT is a statement or conclusion about the findings of the study.'

In [31]:
premise_fulldesc = "A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim."

In [32]:
premise_fulldesc

'A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.'

In [33]:
proportion_desc = "68.0052% of examples are of type Premise and 31.9948% of type Claim."

In [34]:
proportion_desc

'68.0052% of examples are of type Premise and 31.9948% of type Claim.'

### Neo_test

In [35]:
%%time

#experiment_df = df[df.split == 'TEST']
experiment_df = neo_test_df
# experiment_df = gla_test_df
# experiment_df = mix_test_df

target_l = []

# Pre-prepare all examples from each document
title_l = []
essay_l = []
example_l = []
buffer_l = []
label_l = []
for e in experiment_df.iterrows():

    if e[1].span_pos == 1:
    #if e[0] == 0:
        essay_l.append(e[1].abstract_text)
        title_l.append(e[1].abstract_title)
    
    if e[1].span_pos == 1 and len(buffer_l) > 0:
    #if e[0] == 0 and len(buffer_l) > 0:
        example_l.append(buffer_l)
        target_l.append(label_l)
        buffer_l = []
        label_l = []
    if e[1].aty != 'none':
        buffer_l.append(e[1].text)
        label_l.append(e[1].aty)

# Flush the buffers
example_l.append(buffer_l)
target_l.append(label_l)


sys_msg_l_by_seed_d = {}
task_msg_l = []

for i in range(len(example_l)):

    # Prepare numbered list of ACs in this example
    other_acs_prompt = ''
    for j, s in enumerate(example_l[i]):
        # other_acs_prompt += f'Argument {j + 1}={s}\n'
        other_acs_prompt += f'Argument {j + 1}={s}\n'

    for seed in [33, 34, 35, 36, 37]:
        sys_msg = {"role":"system", "content": "### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. " + proportion_desc + " You must absolutely not generate any text or explanation other than the following JSON format {\"Argument 1\": <predicted class for Argument 1 (str)>, ..., \"Argument n\": <predicted class for Argument n (str)>}\n\n### Class definitions:" + " Claim = " + claim_fulldesc + " Premise = " + premise_fulldesc + "\n\n### Examples:\n\n" + prepare_similar_example_prompts(title_l[i], experiment_df, k=5, seed=seed)}  # Sample by similar title

        try:
            sys_msg_l_by_seed_d[seed].append(sys_msg)
        except KeyError:
            sys_msg_l_by_seed_d[seed] = [sys_msg]

    task_msg = {"role":"user", "content": f"# Abstract:\n{essay_l[i]}\n\n# Arguments:\n{other_acs_prompt}\n\n# Result:\n"}
    
    task_msg_l.append(task_msg)

CPU times: user 25 s, sys: 18.4 ms, total: 25 s
Wall time: 25.1 s


In [36]:
len(sys_msg_l_by_seed_d[33])

100

In [37]:
print(sys_msg_l_by_seed_d[33][0]['content'])

### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}

### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.

### Examples:

## Example 1
# Abstract:
The aim of this study was to compare the 2-year functional performance and quality of life in patients with operable squamous cell carcinoma of the esophagus, who have received either surgery or definitive chemoradiation (C

In [38]:
len(task_msg_l)

100

In [39]:
print(task_msg_l[0]['content'])

# Abstract:
To determine the safety and efficacy of oral bexarotene (Targretin capsules; Ligand Pharmaceuticals Incorporated, San Diego, Calif). The effects of 2 randomized doses of 6.5 mg/m(2) per day (with crossover for progression) vs 650 mg/m(2) per day (later modified to 300 mg/m(2) per day) were evaluated in an open-label, multicenter, phase 2 and 3 study conducted between February 1997 and November 1998. Eighteen international cutaneous T-cell lymphoma clinics at academic referral centers. Fifty-eight patients with biopsy-proven stage IA through IIA cutaneous T-cell lymphoma that was refractory to (or patients were intolerant of) treatment or had reached at least a 6-month response plateau under at least 2 forms of prior therapy (median of 3.5 prior therapies). Bexarotene (Targretin capsules) administered once daily with meal for 16 weeks or longer. Primary end point classification of overall response rate of complete and partial remissions determined by either the Physician's G

In [40]:
prepared_sys_task_msg_l = []

for i in range(len(sys_msg_l_by_seed_d[33])):
    prepared_sys_task_msg_l.append([sys_msg_l_by_seed_d[33][i], task_msg_l[i]])

In [41]:
len(prepared_sys_task_msg_l)

100

In [42]:
print(prepared_sys_task_msg_l[0])

[{'role': 'system', 'content': '### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}\n\n### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.\n\n### Examples:\n\n## Example 1\n# Abstract:\nThe aim of this study was to compare the 2-year functional performance and quality of life in patients with operable squamous cell carcinoma of the esophagus, who have received either

## Generate Predictions

In [43]:
outputs_l = []

for i in tqdm(range(len(prepared_sys_task_msg_l))):

    messages = prepared_sys_task_msg_l[i]

    input_ids = inference_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(generation_model.device)

    outputs = generation_model.generate(
    input_ids = input_ids,
    max_new_tokens=1024,
    pad_token_id=inference_tokenizer.eos_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    )
    # inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    outputs_l.append(inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))

  0%|          | 0/100 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
100%|██████████| 100/100 [14:13<00:00,  8.53s/it]


In [44]:
outputs_l

['{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Premise", "Argument 8": "Claim", "Argument 9": "Claim"}',
 '{"Argument 1": "Claim", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Claim", "Argument 8": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Claim", "Argument 8": "Premise", "Argument 9": "Claim"}',
 '{"Argument 1": "Claim", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Claim", "Argument 7": "Premise"}',
 '{"Argument 1": "Claim", "Argument 2": "Claim", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Claim", "Argument 7": "Claim", "Ar

In [45]:
len(outputs_l)

100

In [46]:
preds_l = [list(ast.literal_eval(output).values()) for output in outputs_l]

In [47]:
preds_l

[['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Claim',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Premise',
  'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Premise'],
 ['Claim',
  'Claim',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Claim', 'Claim'],
 ['Pr

In [48]:
len(preds_l)

100

In [49]:
grouped_df = neo_test_df[neo_test_df.aty != 'none'].groupby('doc_id', sort=False)['aty'].apply(list).reset_index()

In [50]:
grouped_df

Unnamed: 0,doc_id,aty
0,11346336,"[Premise, Premise, Premise, Premise, Premise, ..."
1,10210927,"[Claim, Premise, Premise, Premise, Premise, Pr..."
2,12748244,"[Premise, Premise, Premise, Premise, Premise, ..."
3,24028441,"[Claim, Premise, Premise, Premise, Premise, Cl..."
4,12004217,"[Claim, Premise, Premise, Premise, Premise, Pr..."
...,...,...
95,20734132,"[Premise, Premise, Premise, Claim]"
96,23793805,"[Premise, Premise, Premise, Premise, Claim, Cl..."
97,23816967,"[Premise, Premise, Premise, Premise, Premise, ..."
98,20680680,"[Premise, Premise, Premise, Premise, Claim, Cl..."


In [51]:
grounds_l = grouped_df.aty.tolist()

In [52]:
grounds_l

[['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Claim',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise'

In [53]:
len(grounds_l)

100

In [54]:
preds_l = [item for sublist in preds_l for item in sublist]
grounds_l = [item for sublist in grounds_l for item in sublist]

In [55]:
len(preds_l), len(grounds_l)

(691, 691)

In [56]:
print(classification_report(grounds_l, preds_l, digits=4))

              precision    recall  f1-score   support

       Claim     0.9220    0.8105    0.8627       248
     Premise     0.9006    0.9616    0.9301       443

    accuracy                         0.9074       691
   macro avg     0.9113    0.8861    0.8964       691
weighted avg     0.9083    0.9074    0.9059       691



### gla_test

In [57]:
%%time

#experiment_df = df[df.split == 'TEST']
# experiment_df = neo_test_df
experiment_df = gla_test_df
# experiment_df = mix_test_df

target_l = []

# Pre-prepare all examples from each document
title_l = []
essay_l = []
example_l = []
buffer_l = []
label_l = []
for e in experiment_df.iterrows():

    if e[1].span_pos == 1:
    #if e[0] == 0:
        essay_l.append(e[1].abstract_text)
        title_l.append(e[1].abstract_title)
    
    if e[1].span_pos == 1 and len(buffer_l) > 0:
    #if e[0] == 0 and len(buffer_l) > 0:
        example_l.append(buffer_l)
        target_l.append(label_l)
        buffer_l = []
        label_l = []
    if e[1].aty != 'none':
        buffer_l.append(e[1].text)
        label_l.append(e[1].aty)

# Flush the buffers
example_l.append(buffer_l)
target_l.append(label_l)


sys_msg_l_by_seed_d = {}
task_msg_l = []

for i in range(len(example_l)):

    # Prepare numbered list of ACs in this example
    other_acs_prompt = ''
    for j, s in enumerate(example_l[i]):
        # other_acs_prompt += f'Argument {j + 1}={s}\n'
        other_acs_prompt += f'Argument {j + 1}={s}\n'

    for seed in [33, 34, 35, 36, 37]:
        sys_msg = {"role":"system", "content": "### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. " + proportion_desc + " You must absolutely not generate any text or explanation other than the following JSON format {\"Argument 1\": <predicted class for Argument 1 (str)>, ..., \"Argument n\": <predicted class for Argument n (str)>}\n\n### Class definitions:" + " Claim = " + claim_fulldesc + " Premise = " + premise_fulldesc + "\n\n### Examples:\n\n" + prepare_similar_example_prompts(title_l[i], experiment_df, k=5, seed=seed)}  # Sample by similar title

        try:
            sys_msg_l_by_seed_d[seed].append(sys_msg)
        except KeyError:
            sys_msg_l_by_seed_d[seed] = [sys_msg]

    task_msg = {"role":"user", "content": f"# Abstract:\n{essay_l[i]}\n\n# Arguments:\n{other_acs_prompt}\n\n# Result:\n"}
    
    task_msg_l.append(task_msg)

CPU times: user 25 s, sys: 37.8 ms, total: 25 s
Wall time: 25.1 s


In [58]:
len(sys_msg_l_by_seed_d[33])

100

In [59]:
print(sys_msg_l_by_seed_d[33][0]['content'])

### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}

### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.

### Examples:

## Example 1
# Abstract:
Exercise for Health was a randomized, controlled trial designed to evaluate two modes of delivering (face-to-face [FtF] and over-the-telephone [Tel]) an 8-month translational exercise intervention, commencing 6-weeks post-b

In [60]:
len(task_msg_l)

100

In [61]:
print(task_msg_l[0]['content'])

# Abstract:
To compare the longitudinal effects of treatment on intraocular pressure (IOP) and visual field performance in Japanese normal-tension glaucoma (NTG) between latanoprost and timolol.
This is an open-label, randomized, study. A total of 62 NTG patients were prospectively, consecutively enrolled. All study subjects were randomly assigned to 0.005% latanoprost instillation once daily in the morning or 0.5% timolol instillation twice daily for a prospective 3-year follow-up, and underwent a routine ocular examination every month. Automated perimetry was performed every 6 months using Humphrey field analysers. Stereophotographs of optic discs were also obtained every 6 months. Percentage of IOP reduction or the magnitude of IOP reduction showed no intergroup differences either at any time point (13-15%). In the visual field, the estimated rate of change in the MD value (dB/year) was -0.34+/-0.17 (SE) for the latanoprost group, and -0.10+/-0.18 (SE) for the timolol group. The est

In [62]:
prepared_sys_task_msg_l = []

for i in range(len(sys_msg_l_by_seed_d[33])):
    prepared_sys_task_msg_l.append([sys_msg_l_by_seed_d[33][i], task_msg_l[i]])

In [63]:
len(prepared_sys_task_msg_l)

100

In [64]:
print(prepared_sys_task_msg_l[0])

[{'role': 'system', 'content': '### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}\n\n### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.\n\n### Examples:\n\n## Example 1\n# Abstract:\nExercise for Health was a randomized, controlled trial designed to evaluate two modes of delivering (face-to-face [FtF] and over-the-telephone [Tel]) an 8-month translational exercise

## Generate Predictions

In [65]:
outputs_l = []

for i in tqdm(range(len(prepared_sys_task_msg_l))):

    messages = prepared_sys_task_msg_l[i]

    input_ids = inference_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(generation_model.device)

    outputs = generation_model.generate(
    input_ids = input_ids,
    max_new_tokens=1024,
    pad_token_id=inference_tokenizer.eos_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    )
    # inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    outputs_l.append(inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))

100%|██████████| 100/100 [13:12<00:00,  7.92s/it]


In [66]:
outputs_l

['{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Claim", "Argument 7": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Claim", "Argument 8": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Claim", "Argument 7": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Claim", "Argument 6": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4"

In [67]:
len(outputs_l)

100

In [68]:
outputs_l_cleaned = []

for output in outputs_l:
    new_output = output.replace("Here is the result in the required JSON format:\n\n","")
    outputs_l_cleaned.append(new_output)

In [69]:
preds_l = [list(ast.literal_eval(output).values()) for output in outputs_l_cleaned]

In [70]:
preds_l

[['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise

In [71]:
len(preds_l)

100

In [72]:
grouped_df = gla_test_df[gla_test_df.aty != 'none'].groupby('doc_id', sort=False)['aty'].apply(list).reset_index()

In [73]:
grouped_df

Unnamed: 0,doc_id,aty
0,15037889,"[Premise, Premise, Premise, Premise, Premise, ..."
1,11124286,"[Premise, Premise, Premise, Premise, Premise, ..."
2,19427617,"[Claim, Premise, Premise, Premise, Claim]"
3,11336940,"[Premise, Premise, Premise, Premise, Premise, ..."
4,17947826,"[Premise, Premise, Claim]"
...,...,...
95,20545217,"[Premise, Premise, Premise, Claim]"
96,20092592,"[Premise, Premise, Premise, Premise, Claim]"
97,22814041,"[Premise, Premise, Premise, Premise, Claim, Cl..."
98,11813932,"[Premise, Premise, Premise, Premise, Premise, ..."


In [74]:
grounds_l = grouped_df.aty.tolist()

In [75]:
grounds_l

[['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
 

In [76]:
len(grounds_l)

100

In [77]:
preds_l = [item for sublist in preds_l for item in sublist]
grounds_l = [item for sublist in grounds_l for item in sublist]

In [78]:
len(preds_l), len(grounds_l)

(615, 615)

In [79]:
print(classification_report(grounds_l, preds_l, digits=4))

              precision    recall  f1-score   support

       Claim     0.9333    0.8063    0.8652       191
     Premise     0.9178    0.9741    0.9451       424

    accuracy                         0.9220       615
   macro avg     0.9256    0.8902    0.9051       615
weighted avg     0.9226    0.9220    0.9203       615



### mix_test

In [80]:
%%time

#experiment_df = df[df.split == 'TEST']
# experiment_df = neo_test_df
# experiment_df = gla_test_df
experiment_df = mix_test_df

target_l = []

# Pre-prepare all examples from each document
title_l = []
essay_l = []
example_l = []
buffer_l = []
label_l = []
for e in experiment_df.iterrows():

    if e[1].span_pos == 1:
    #if e[0] == 0:
        essay_l.append(e[1].abstract_text)
        title_l.append(e[1].abstract_title)
    
    if e[1].span_pos == 1 and len(buffer_l) > 0:
    #if e[0] == 0 and len(buffer_l) > 0:
        example_l.append(buffer_l)
        target_l.append(label_l)
        buffer_l = []
        label_l = []
    if e[1].aty != 'none':
        buffer_l.append(e[1].text)
        label_l.append(e[1].aty)

# Flush the buffers
example_l.append(buffer_l)
target_l.append(label_l)


sys_msg_l_by_seed_d = {}
task_msg_l = []

for i in range(len(example_l)):

    # Prepare numbered list of ACs in this example
    other_acs_prompt = ''
    for j, s in enumerate(example_l[i]):
        # other_acs_prompt += f'Argument {j + 1}={s}\n'
        other_acs_prompt += f'Argument {j + 1}={s}\n'

    for seed in [33, 34, 35, 36, 37]:
        sys_msg = {"role":"system", "content": "### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. " + proportion_desc + " You must absolutely not generate any text or explanation other than the following JSON format {\"Argument 1\": <predicted class for Argument 1 (str)>, ..., \"Argument n\": <predicted class for Argument n (str)>}\n\n### Class definitions:" + " Claim = " + claim_fulldesc + " Premise = " + premise_fulldesc + "\n\n### Examples:\n\n" + prepare_similar_example_prompts(title_l[i], experiment_df, k=5, seed=seed)}  # Sample by similar title

        try:
            sys_msg_l_by_seed_d[seed].append(sys_msg)
        except KeyError:
            sys_msg_l_by_seed_d[seed] = [sys_msg]

    task_msg = {"role":"user", "content": f"# Abstract:\n{essay_l[i]}\n\n# Arguments:\n{other_acs_prompt}\n\n# Result:\n"}
    
    task_msg_l.append(task_msg)

CPU times: user 24.9 s, sys: 31.6 ms, total: 24.9 s
Wall time: 25.1 s


In [81]:
len(sys_msg_l_by_seed_d[33])

100

In [82]:
print(sys_msg_l_by_seed_d[33][0]['content'])

### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}

### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.

### Examples:

## Example 1
# Abstract:
Malignant pleural effusion (MPE) is a common complication of advanced non-small cell lung cancer (NSCLC). Bevacizumab, a humanized monoclonal antibody against vascular endothelial growth factor (VEGF), has been shown to be 

In [83]:
len(task_msg_l)

100

In [84]:
print(task_msg_l[0]['content'])

# Abstract:
Few controlled clinical trials exist to support oral combination therapy in pulmonary arterial hypertension (PAH). Patients with PAH (idiopathic [IPAH] or associated with connective tissue disease [APAH-CTD]) taking bosentan (62.5 or 125 mg twice daily at a stable dose for ≥3 months) were randomized (1:1) to sildenafil (20 mg, 3 times daily; n = 50) or placebo (n = 53). The primary endpoint was change from baseline in 6-min walk distance (6MWD) at week 12, assessed using analysis of covariance. Patients could continue in a 52-week extension study. An analysis of covariance main-effects model was used, which included categorical terms for treatment, baseline 6MWD (<325 m; ≥325 m), and baseline aetiology; sensitivity analyses were subsequently performed. In sildenafil versus placebo arms, week-12 6MWD increases were similar (least squares mean difference [sildenafil-placebo], -2.4 m [90% CI: -21.8 to 17.1 m]; P = 0.6); mean ± SD changes from baseline were 26.4 ± 45.7 versus 1

In [85]:
prepared_sys_task_msg_l = []

for i in range(len(sys_msg_l_by_seed_d[33])):
    prepared_sys_task_msg_l.append([sys_msg_l_by_seed_d[33][i], task_msg_l[i]])

In [86]:
len(prepared_sys_task_msg_l)

100

In [87]:
print(prepared_sys_task_msg_l[0])

[{'role': 'system', 'content': '### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. 68.0052% of examples are of type Premise and 31.9948% of type Claim. You must absolutely not generate any text or explanation other than the following JSON format {"Argument 1": <predicted class for Argument 1 (str)>, ..., "Argument n": <predicted class for Argument n (str)>}\n\n### Class definitions: Claim = A claim in the abstract of an RCT is a statement or conclusion about the findings of the study. Premise = A premise in the abstract of an RCT is a statement that provides an evidence or proof for a claim.\n\n### Examples:\n\n## Example 1\n# Abstract:\nMalignant pleural effusion (MPE) is a common complication of advanced non-small cell lung cancer (NSCLC). Bevacizumab, a humanized monoclonal antibody against vascular endothelial gro

## Generate Predictions

In [88]:
outputs_l = []

for i in tqdm(range(len(prepared_sys_task_msg_l))):

    messages = prepared_sys_task_msg_l[i]

    input_ids = inference_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(generation_model.device)

    outputs = generation_model.generate(
    input_ids = input_ids,
    max_new_tokens=1024,
    pad_token_id=inference_tokenizer.eos_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    )
    # inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    outputs_l.append(inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))

100%|██████████| 100/100 [13:01<00:00,  7.82s/it]


In [89]:
outputs_l

['{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Premise", "Argument 8": "Claim", "Argument 9": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Premise", "Argument 7": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Claim"}',
 '{"Argument 1": "Claim", "Argument 2": "Claim", "Argument 3": "Premise", "Argument 4": "Premise", "Argument 5": "Premise", "Argument 6": "Claim", "Argument 7": "Claim"}',
 '{"Argument 1": "Claim", "Argument 2": "Premise", "Argument 3": "Premise", "Argument 4": "Claim", "Argument 5": "Claim"}',
 '{"Argument 1": "Premise", "Argument 2": "Premise", "Argument 3": "P

In [90]:
len(outputs_l)

100

In [91]:
preds_l = [list(ast.literal_eval(output).values()) for output in outputs_l]

In [92]:
preds_l

[['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Claim', 'Claim', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim',
  'Claim',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise',

In [93]:
len(preds_l)

100

In [94]:
grouped_df = mix_test_df[mix_test_df.aty != 'none'].groupby('doc_id', sort=False)['aty'].apply(list).reset_index()

In [95]:
grouped_df

Unnamed: 0,doc_id,aty
0,28874133,"[Premise, Premise, Premise, Claim]"
1,11346336,"[Premise, Premise, Premise, Premise, Premise, ..."
2,28370419,"[Premise, Premise, Premise, Premise, Premise, ..."
3,29596950,"[Premise, Premise, Premise, Premise, Claim]"
4,29633159,"[Claim, Claim, Premise, Premise, Premise, Clai..."
...,...,...
95,29625440,"[Premise, Premise, Claim]"
96,23285843,"[Premise, Premise, Premise, Claim, Claim]"
97,22814041,"[Premise, Premise, Premise, Premise, Claim, Cl..."
98,29329305,"[Claim, Premise, Premise, Premise, Premise, Cl..."


In [96]:
grounds_l = grouped_df.aty.tolist()

In [97]:
grounds_l

[['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Claim', 'Claim', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Claim',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Premise',
  'Claim',
  'Claim',
  'Premise',
  'Claim'],
 ['Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim'],
 ['Premise', 'Premise', 'Premise', 'Premise', 'Claim', 'Claim', 'Claim'],
 ['Claim', 'Premise', 'Premise', 'Premise', 'Claim'],
 ['Premise', 'Premise', 'Premise',

In [98]:
len(grounds_l)

100

In [99]:
preds_l = [item for sublist in preds_l for item in sublist]
grounds_l = [item for sublist in grounds_l for item in sublist]

In [100]:
len(preds_l), len(grounds_l)

(609, 609)

In [101]:
print(classification_report(grounds_l, preds_l, digits=4))

              precision    recall  f1-score   support

       Claim     0.9303    0.8821    0.9056       212
     Premise     0.9387    0.9647    0.9516       397

    accuracy                         0.9360       609
   macro avg     0.9345    0.9234    0.9286       609
weighted avg     0.9358    0.9360    0.9355       609

