In [13]:
import pandas as pd
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import re
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments
from trl import SFTTrainer, SFTConfig
import torch
from peft import LoraConfig, get_peft_model
from datasets import Dataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from statistics import mean
from tqdm import tqdm
import random

In [14]:


# Run on the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Hugging Face hub id of the checkpoint whose tokenizer is loaded below.
model_id = 'mistralai/Mistral-7B-Instruct-v0.3'

tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse the end-of-sequence token as the padding token, so padded positions
# are indistinguishable from a real EOS by token id alone.
tokenizer.pad_token_id = tokenizer.eos_token_id
#tokenizer.add_special_tokens({'additional_special_tokens':['<INFORMATION>','</INFORMATION>','<PERCEPTION>', '</PERCEPTION>', '<BACKGROUND>', '</BACKGROUND>']})

# bnb_config = BitsAndBytesConfig(
#     load_in_4bit=True,
#     bnb_4bit_quant_type="nf4",
#     bnb_4bit_compute_dtype=torch.bfloat16,
#     bnb_4bit_use_double_quant=True,
#     bnb_4bit_quant_storage=torch.bfloat16,
# )

# LMmodel = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     quantization_config = bnb_config,
#     torch_dtype = torch.bfloat16,
#     device_map = 'auto'
# )
# LMmodel.resize_token_embeddings(len(tokenizer))

# peft_config = LoraConfig(target_modules=[ "v_proj", "q_proj", "up_proj", "o_proj", "k_proj", "down_proj", "gate_proj" ], inference_mode=False, r=4, lora_alpha=32, lora_dropout=0.1)

# LMmodel = get_peft_model(LMmodel, peft_config)

# LMmodel.print_trainable_parameters()

In [2]:
import json
# Load the raw corpus. Later cells read data[<paper key>]['x'] (sentence
# entries) and data[<paper key>]['y'] (citation annotations).
# The encoding is pinned to UTF-8 so the result does not depend on the
# platform's locale default.
with open('../data/multicite/full_raw.json', 'r', encoding='utf-8') as file:
    data = json.load(file)

In [30]:
def label_mapping(label):
    """Translate citation-function labels between their representations.

    Two lookups are supported:

    * raw annotation tag -> class name, e.g. ``'@BACK@'`` -> ``'BACKGROUND'``.
      Tags are matched case-insensitively, so both ``'@FUT@'`` and the
      historical lower-case ``'@fut@'`` resolve to ``'FUTURE'``.
    * class name -> integer class index, e.g. ``'BACKGROUND'`` -> ``0``
      (exact, case-sensitive match).

    Args:
        label: A raw tag or a class name.

    Returns:
        The class name (str) for a tag, the class index (int) for a class
        name, or ``None`` when ``label`` is not a string or is unknown.
    """
    tag_to_name = {
        '@BACK@': 'BACKGROUND',
        '@MOT@': 'MOTIVATION',
        '@USE@': 'USE',
        '@EXT@': 'EXTENDS',
        '@SIM@': 'SIMILARITY',
        '@DIF@': 'DIFFERENCES',
        '@FUT@': 'FUTURE',
    }
    name_to_index = {
        'BACKGROUND': 0,
        'MOTIVATION': 1,
        'USE': 2,
        'EXTENDS': 3,
        'SIMILARITY': 4,
        'DIFFERENCES': 5,
        'FUTURE': 6,
    }

    if not isinstance(label, str):
        return None

    # Tags are normalised to upper case so inconsistent casing in the raw
    # annotations cannot silently fall through to None.
    tag = label.upper()
    if tag in tag_to_name:
        return tag_to_name[tag]
    return name_to_index.get(label)
    
    
def seperate_segments(text_json):
    """Split a paper's sentences into sections delimited by dashed lines.

    Args:
        text_json: Iterable of dicts, each with a ``'sent_id'`` whose last
            ``'-'``-separated part is an integer, and a ``'text'``.

    Returns:
        A ``(positions, sections)`` tuple. ``sections`` is a list of lists
        of sentence texts, one inner list per section, in sentence-number
        order. ``positions`` maps every sentence id to its
        ``(section index, index within section)`` pair, or to ``-1`` when
        the sentence is itself a separator line.
    """
    texts_by_id = {entry['sent_id']: entry['text'] for entry in text_json}

    # Order the ids by their trailing sentence number.
    ordered_ids = list(texts_by_id)
    ordered_ids.sort(key=lambda sent_id: int(sent_id.split('-')[-1]))

    positions = {}
    sections = [[]]
    for sent_id in ordered_ids:
        sentence = texts_by_id[sent_id]
        if re.match(r'----------------------------------', sentence) is not None:
            # A separator closes the current section and opens a new one.
            sections.append([])
            positions[sent_id] = -1
            continue
        open_section = sections[-1]
        positions[sent_id] = (len(sections) - 1, len(open_section))
        open_section.append(sentence)
    return positions, sections

def get_cit_dict(label_json):
    """Group citation annotations by citing sentence.

    Args:
        label_json: Mapping ``label -> {'gold_contexts': [...],
            'cite_sentences': [...]}``. Labels whose two lists differ in
            length are ignored.

    Returns:
        Mapping ``sentence id -> {'label': [labels...], 'context': [...]}``.
        When several labels share a citing sentence, their labels are
        appended in iteration order and their contexts are merged into a
        duplicate-free list (in unspecified order).
    """
    merged = {}
    for label in label_json:
        annotation = label_json[label]
        gold_contexts = annotation['gold_contexts']
        cite_sentences = annotation['cite_sentences']
        if len(gold_contexts) != len(cite_sentences):
            continue

        for sent_id, context in zip(cite_sentences, gold_contexts):
            if sent_id not in merged:
                merged[sent_id] = {'label': [label], 'context': context}
                continue
            entry = merged[sent_id]
            entry['label'].append(label)
            entry['context'] = list(set(entry['context'] + context))
    return merged

# Build, for every paper, a list of training samples. Each sample holds a
# window of at most CONTEXT_WINDOW sentences from one section (with the citing
# sentence's citation replaced by #AUTHOR_TAG), the mapped labels, and the
# window-relative names of the gold context sentences.
CONTEXT_WINDOW = 5          # maximum number of sentences in one sample
MIN_SECTION_SENTENCES = 3   # sections shorter than this are skipped
# Dedicated, seeded generator so the random window extension is reproducible
# under "Restart Kernel -> Run All" without touching the global random state.
window_rng = random.Random(0)

data_json = {}
for big_key in tqdm(data.keys()):
    data_json[big_key] = []
    sent_lookup, sent_arr = seperate_segments(data[big_key]['x'])
    cit_dict = get_cit_dict(data[big_key]['y'])

    for key, cit_entry in cit_dict.items():

        # Skip citing sentences that are unknown or that are separator lines
        # (seperate_segments maps separators to -1, which cannot be unpacked).
        position = sent_lookup.get(key)
        if position is None or position == -1:
            continue

        # set basic data
        labels, context = cit_entry['label'], cit_entry['context']
        sec_id, citing_sent_id = position
        section = sent_arr[sec_id]
        heading = section[0]  # the first sentence of a section is used as its heading

        # Resolve the gold context sentences. The sample is dropped when the
        # context is empty, references an unknown id or a separator, or leaves
        # the citing sentence's section: in-section indices from another
        # section would select the wrong sentences here.
        sorted_context = sorted(context, key=lambda x: int(x.split('-')[-1]))
        context_positions = [sent_lookup.get(context_sent, -1) for context_sent in sorted_context]
        if not context_positions:
            continue
        if any(pos == -1 or pos[0] != sec_id for pos in context_positions):
            continue
        context_sent_ids = [pos[1] for pos in context_positions]

        # add all context sentences and the sentences enclosed by them
        input_arr = list(range(context_sent_ids[0], context_sent_ids[-1] + 1))
        if len(input_arr) > CONTEXT_WINDOW or len(input_arr) == 0 or len(section) < MIN_SECTION_SENTENCES:
            continue

        # Randomly extend the window to the front / back until it holds
        # CONTEXT_WINDOW sentences or every non-heading sentence of the section.
        # Choosing uniformly among the directions that are still possible gives
        # the same distribution as retrying a coin flip, without idle iterations.
        target_len = min(CONTEXT_WINDOW, len(section) - 1)
        while len(input_arr) < target_len:
            prev_idx = input_arr[0] - 1
            next_idx = input_arr[-1] + 1
            directions = []
            if prev_idx > 0:  # index 0 is the heading and is never added
                directions.append('front')
            if next_idx < len(section):
                directions.append('back')
            if not directions:
                break
            if window_rng.choice(directions) == 'front':
                input_arr.insert(0, prev_idx)
            else:
                input_arr.append(next_idx)

        # Replace the <span> citation tags: the citing sentence gets the
        # #AUTHOR_TAG marker, every other sentence keeps the tag's inner text.
        input_text = f'Section Heading: {heading}\n\n '
        for i, idx in enumerate(input_arr):
            sent = section[idx]
            if idx == citing_sent_id:
                clean_sent = re.sub(r'<span.*?>(.*?)<\/span>','#AUTHOR_TAG', sent)
            else:
                clean_sent = re.sub(r'<span.*?>(.*?)<\/span>',r'\1', sent)
            input_text += f'sent{i}: {clean_sent}\n '

        trainin_sample = {
            "x": input_text,
            "y" : {
                'labels': [label_mapping(label) for label in labels],
                'context': [f'sent{input_arr.index(ctx_idx)}' for ctx_idx in context_sent_ids]
            },
        }
        data_json[big_key].append(trainin_sample)

  0%|          | 0/1193 [00:00<?, ?it/s]

 36%|███▌      | 428/1193 [00:00<00:00, 2142.56it/s]

1
4
2
4
3
4
10
24
11
24
12
24
13
24
14
24
9
14
10
14
11
14
12
14
13
14
6
24
7
24
8
24
9
24
10
24
10
24
11
24
12
24
13
24
14
24
7
14
8
14
9
14
10
14
11
14
9
24
10
24
11
24
12
24
13
24
9
14
10
14
11
14
12
14
13
14
7
24
8
24
9
24
10
24
11
24
15
24
16
24
17
24
18
24
19
24
1
7
2
7
3
7
4
7
5
7
1
12
2
12
3
12
4
12
5
12
4
12
5
12
6
12
7
12
8
12
1
4
2
4
3
4
2
7
3
7
4
7
5
7
6
7
5
12
6
12
7
12
8
12
9
12
1
7
2
7
3
7
4
7
5
7
2
7
3
7
4
7
5
7
6
7
1
14
2
14
3
14
4
14
5
14
3
14
4
14
5
14
6
14
7
14
7
12
8
12
9
12
10
12
11
12
1
10
2
10
3
10
4
10
5
10
1
12
2
12
3
12
4
12
5
12
3
12
4
12
5
12
6
12
7
12
1
7
2
7
3
7
4
7
5
7
1
4
2
4
3
4
5
10
6
10
7
10
8
10
9
10
1
14
2
14
3
14
4
14
5
14
1
12
2
12
3
12
4
12
5
12
9
14
10
14
11
14
12
14
13
14
1
14
2
14
3
14
4
14
5
14
1
6
2
6
3
6
4
6
5
6
1
18
2
18
3
18
4
18
5
18
3
18
4
18
5
18
6
18
7
18
8
14
9
14
10
14
11
14
12
14
1
9
2
9
3
9
4
9
5
9
2
18
3
18
4
18
5
18
6
18
3
12
4
12
5
12
6
12
7
12
6
12
7
12
8
12
9
12
10
12
1
6
2
6
3
6
4
6
5
6
1
17
2
17
3
17
4
17
5
17
4
17
5
17
6


 86%|████████▋ | 1031/1193 [00:00<00:00, 2744.88it/s]

10
23
11
23
12
23
15
23
16
23
17
23
18
23
19
23
1
11
2
11
3
11
4
11
5
11
2
11
3
11
4
11
5
11
6
11
1
6
2
6
3
6
4
6
5
6
1
6
2
6
3
6
4
6
5
6
1
6
2
6
3
6
4
6
5
6
1
9
2
9
3
9
4
9
5
9
8
14
9
14
10
14
11
14
12
14
11
25
12
25
13
25
14
25
15
25
1
9
2
9
3
9
4
9
5
9
9
14
10
14
11
14
12
14
13
14
1
5
2
5
3
5
4
5
17
22
18
22
19
22
20
22
21
22
6
25
7
25
8
25
9
25
10
25
1
14
2
14
3
14
4
14
5
14
1
16
2
16
3
16
4
16
5
16
11
16
12
16
13
16
14
16
15
16
2
7
3
7
4
7
5
7
6
7
1
9
2
9
3
9
4
9
5
9
1
5
2
5
3
5
4
5
1
24
2
24
3
24
4
24
5
24
1
17
2
17
3
17
4
17
5
17
1
3
2
3
15
20
16
20
17
20
18
20
19
20
9
14
10
14
11
14
12
14
13
14
7
23
8
23
9
23
10
23
11
23
1
7
2
7
3
7
4
7
5
7
4
20
5
20
6
20
7
20
8
20
7
20
8
20
9
20
10
20
11
20
1
14
2
14
3
14
4
14
5
14
1
4
2
4
3
4
3
13
4
13
5
13
6
13
7
13
7
12
8
12
9
12
10
12
11
12
6
21
7
21
8
21
9
21
10
21
1
14
2
14
3
14
4
14
5
14
5
10
6
10
7
10
8
10
9
10
4
14
5
14
6
14
7
14
8
14
5
21
6
21
7
21
8
21
9
21
12
21
13
21
14
21
15
21
16
21
1
5
2
5
3
5
4
5
19
30
20
30
21
30
22
30
23
30


100%|██████████| 1193/1193 [00:00<00:00, 2404.23it/s]

21
8
21
9
21
10
21
11
21
16
21
17
21
18
21
19
21
20
21
3
22
4
22
5
22
6
22
7
22
9
22
10
22
11
22
12
22
13
22
3
13
4
13
5
13
6
13
7
13
1
17
2
17
3
17
4
17
5
17
2
17
3
17
4
17
5
17
6
17
12
17
13
17
14
17
15
17
16
17
11
22
12
22
13
22
14
22
15
22
12
22
13
22
14
22
15
22
16
22
14
22
15
22
16
22
17
22
18
22
8
13
9
13
10
13
11
13
12
13
7
13
8
13
9
13
10
13
11
13
4
14
5
14
6
14
7
14
8
14
7
14
8
14
9
14
10
14
11
14
3
28
4
28
5
28
6
28
7
28
7
28
8
28
9
28
10
28
11
28
9
28
10
28
11
28
12
28
13
28
14
28
15
28
16
28
17
28
18
28
1
8
2
8
3
8
4
8
5
8
1
18
2
18
3
18
4
18
5
18
4
9
5
9
6
9
7
9
8
9
2
7
3
7
4
7
5
7
6
7
2
7
3
7
4
7
5
7
6
7
1
9
2
9
3
9
4
9
5
9
1
9
2
9
3
9
4
9
5
9
1
18
2
18
3
18
4
18
5
18
1
18
2
18
3
18
4
18
5
18
5
18
6
18
7
18
8
18
9
18
10
18
11
18
12
18
13
18
14
18
12
18
13
18
14
18
15
18
16
18
4
18
5
18
6
18
7
18
8
18
1
7
2
7
3
7
4
7
5
7
5
29
6
29
7
29
8
29
9
29
2
17
3
17
4
17
5
17
6
17
20
29
21
29
22
29
23
29
24
29
20
35
21
35
22
35
23
35
24
35
5
13
6
13
7
13
8
13
9
13
1
30
2
30
3
30
4
3




In [None]:
def get_fine_tune_prompt( 
    input_str: str,
    label_str: str,
    tokenizer,
    test: bool = False
) -> dict:
    """Build the few-shot chat prompt for one excerpt and tokenize it.

    The conversation consists of the task instructions, one worked example
    with a short explanation, and finally ``input_str`` as the last user
    turn. For training the expected answer is appended as the final
    assistant turn; for testing the conversation ends on the user turn.

    Args:
        input_str: Excerpt containing the citation marked as #AUTHOR_TAG.
        label_str: Expected assistant answer (a JSON string). Ignored when
            ``test`` is true.
        tokenizer: Object exposing ``apply_chat_template(messages)``.
        test: When true, the assistant answer is left out.

    Returns:
        ``{'input_ids': <result of tokenizer.apply_chat_template(messages)>}``.
    """
    # NOTE(review): the class list below (COMPARE_CONTRAST, EXTENSION, ...)
    # differs from the class names produced by label_mapping (SIMILARITY,
    # DIFFERENCES, EXTENDS, ...) - confirm which label scheme is intended.
    usr_msg1 = "You are given a excerpt from a scientific text with one citation marked as #AUTHOR_TAG. Further you are given a list of citation functions." \
        " Your task is it to find the citation function class of the marked citation."\
        ' To do this, first, find the marked citation, second, identify which function class fits the best for the marked citation by considering the surrounding words, third, reply with a short valid json only containing the label e.g.: {"label": "USE"}.'\
        ' Make sure the output is a valid json with "" instead of \'\' '\
        """\n\nList Citation Function Types\n
        BACKGROUND: The cited paper provides relevant background information or is part of the body of literature. \n
        USE: The citing paper uses the methodology or tools created by the cited paper. \n
        COMPARE_CONTRAST: The citing paper expresses similarities to or or differences from, or disagrees with, the cited paper.\n
        MOTIVATION: The citing paper is directly motivated by the cited paper. \n
        EXTENSION: The citing paper extends the methods, tools, or data of the cited paper.\n
        FUTURE: The cited paper is a potential avenue for future work. \n
        """\
        "\n\n" \
        "Are the instructions clear to you?"
    
    asst_msg1 = "Yes, the instructions are clear to me. I will determine the citation function class of the marked citation (#AUTHOR_TAG) based on the provided citation function types and respond in json format."
    
    usr_msg2 = "In a similar vain to #AUTHOR_TAG and Buchholz et al. ( 1999 ) , the method extends an existing flat shallow-parsing method to handle composite structures ."

    asst_msg2 = '{"label": "COMPARE_CONTRAST"}'

    usr_msg3 = "Give a brief explanation of why your answer is correct."

    # The trailing space keeps the two sentences from being glued together
    # when the adjacent string literals are concatenated.
    asst_msg3 = "The marked citation (#AUTHOR_TAG) is compared to the method proposed in the citing paper and judged as similar. "\
                "Therefore, the citation function class is 'COMPARE_CONTRAST'"
    
    usr_msg4 = "Great! I am now going to give you another excerpt. Please detect the function class in it " \
                "according to the previous instructions. Do not include an explanation in your answer."
    
    # NOTE(review): "user utterance" looks carried over from another task;
    # the surrounding prompt speaks of excerpts - confirm the wording.
    asst_msg4 = "Sure! Please give me the user utterance."

    messages = [
        {"role": "user", "content": usr_msg1},
        {"role": "assistant", "content": asst_msg1},
        {"role": "user", "content": usr_msg2},
        {"role": "assistant", "content": asst_msg2},
        {"role": "user", "content": usr_msg3},
        {"role": "assistant", "content": asst_msg3},
        {"role": "user", "content": usr_msg4},
        {"role": "assistant", "content": asst_msg4},
        {"role": "user", "content": input_str},
    ]
    # Training samples end with the gold answer; test prompts stop at the
    # user turn so the model has to produce the answer itself.
    if not test:
        messages.append({"role": "assistant", "content": label_str})
    
    encoded_input_ids = tokenizer.apply_chat_template(messages)

    return {'input_ids': encoded_input_ids}

# res = LMmodel.generate(torch.tensor([get_fine_tune_prompt('','',tokenizer)['input_ids']]), max_new_tokens=512)
# tokenizer.decode(res[0]).split('[/INST]')[-1]

In [None]:
# prepare data
def create_clean_data(df):
    """Turn a raw split into a frame of (paragraph, label, target JSON) rows.

    For every row the context paragraph is joined into one string. When
    the paragraph is longer than 400 tokens it is cut to its first 400
    tokens if those still contain ``#AUTHOR_TAG``, otherwise to its last
    400 tokens if those contain the marker; rows where neither cut keeps
    the marker are dropped.

    Args:
        df: DataFrame with the columns ``'cite_context_paragraph'`` (the
            string form of a list of sentences) and
            ``'citation_class_label'``.

    Returns:
        DataFrame with the columns ``'par'``, ``'label'`` and ``'json'``,
        where ``'json'`` is the answer string ``{"label": "<name>"}``.

    Raises:
        TypeError: If ``label_mapping`` does not return a string for a
            row's label, so no class name can be written into the answer.
    """
    max_tokens = 400
    rows = []
    for _, row in df.iterrows():
        # NOTE(security): eval() executes whatever the column contains. This is
        # only acceptable for trusted data files; if the column always holds a
        # plain list literal, ast.literal_eval is the safe replacement.
        par = ' '.join(eval(row['cite_context_paragraph']))
        par_token = tokenizer.encode(par)
        if len(par_token) > max_tokens:
            head_text = tokenizer.decode(par_token[:max_tokens])
            if '#AUTHOR_TAG' in head_text:
                par = head_text
            else:
                tail_text = tokenizer.decode(par_token[len(par_token) - max_tokens:])
                if '#AUTHOR_TAG' not in tail_text:
                    continue
                par = tail_text

        label = row['citation_class_label']
        label_name = label_mapping(label)
        if not isinstance(label_name, str):
            raise TypeError(
                f'label_mapping({label!r}) returned {label_name!r}; '
                'a class-name string is required to build the target JSON'
            )
        # Named label_json so the json module is not shadowed inside the function.
        label_json = '{"label": "' + label_name + '"}'
        rows.append([par, label, label_json])

    # Build the frame once instead of growing it row by row.
    return pd.DataFrame(rows, columns=['par', 'label', 'json'])

# NOTE(review): `train` and `test` are not defined in any visible cell of this
# notebook - confirm where the raw splits are loaded.
train_df = create_clean_data(train)
test_df = create_clean_data(test)


def _encode_split(split_ds, is_test):
    """Replace every row of ``split_ds`` by its tokenized chat prompt.

    All original columns are removed; only what get_fine_tune_prompt returns
    (``input_ids``) is kept. With ``is_test`` true the prompt is built without
    the assistant answer.
    """
    return split_ds.map(
        lambda row: get_fine_tune_prompt(row['par'], row['json'], tokenizer, is_test),
        batched=False,
        remove_columns=split_ds.column_names,
    )


# Convert each cleaned frame to a Dataset and tokenize it row by row.
train_ds = _encode_split(Dataset.from_pandas(train_df), False)
test_ds = _encode_split(Dataset.from_pandas(test_df), True)
        

In [None]:
import seaborn as sns
# Distribution of prompt lengths (in tokens) over the training set.
length = [len(item['input_ids']) for item in train_ds]
sns.histplot(data=length)

In [None]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union

@dataclass
class CustomDataCollatorWithPadding:
    """Pad a batch and build labels that only score the final assistant answer.

    The batch is padded with ``tokenizer.pad``. ``labels`` start as a copy of
    ``input_ids``; every padding position and every token up to and including
    the last ``[/INST]`` marker is then set to ``-100``, leaving only the
    tokens after that marker as training targets.

    Attributes mirror the arguments forwarded to ``tokenizer.pad``.
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``features`` into a padded batch with a ``labels`` entry.

        Args:
            features: List of examples, each holding at least ``input_ids``.

        Returns:
            The padded batch with an added ``labels`` tensor of the same shape
            as ``input_ids``.
        """
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        labels = batch["input_ids"].clone()

        # Exclude padding from the loss. The attention mask is used instead of
        # comparing against pad_token_id because the notebook sets the pad
        # token to EOS, and a genuine EOS must stay a training target.
        if "attention_mask" in batch:
            labels[batch["attention_mask"] == 0] = -100

        marker = "[/INST]"
        # Compute loss mask for the prompt part of every example
        for i in range(batch['input_ids'].shape[0]):

            # Decode the training input. The first token is skipped because the
            # tokenizer call below adds a BOS token again.
            # NOTE(review): this assumes position 0 holds BOS, i.e. the batch
            # is not left-padded - confirm the tokenizer's padding side.
            text_content = self.tokenizer.decode(batch['input_ids'][i][1:])

            # The prompt ends at the last user message, which ends in [/INST]
            marker_pos = text_content.rfind(marker)
            if marker_pos == -1:
                # Without the marker the prompt/answer boundary is unknown;
                # keep the whole example out of the loss rather than training
                # on prompt tokens.
                labels[i, :] = -100
                continue
            prompt_text = text_content[:marker_pos + len(marker)]

            # Re-tokenize the prompt text only to find how many tokens it spans
            prompt_text_tokenized = self.tokenizer(
                prompt_text,
                return_overflowing_tokens=False,
                return_length=False,
            )
            prompt_tok_idx = len(prompt_text_tokenized['input_ids'])

            # Mask every prompt token; slicing clamps to the row length, so a
            # re-tokenization that comes out longer than the row cannot raise.
            labels[i, :prompt_tok_idx] = -100

        batch["labels"] = labels
        return batch

In [None]:
# Single source for the sequence limit used by both the collator and SFTConfig.
max_seq_length = 1024

# Collator that pads to the longest example and masks the prompt in the labels.
collator_options = {
    "tokenizer": tokenizer,
    "padding": "longest",
    "max_length": max_seq_length,
    "return_tensors": "pt",
}
data_collator = CustomDataCollatorWithPadding(**collator_options)

# Fine-tuning hyper-parameters.
sft_options = {
    "output_dir": "./tmp",
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "dataset_text_field": "input_ids",
    "max_seq_length": max_seq_length,
    "learning_rate": 1e-4,
    "max_steps": 2500,
    "warmup_ratio": 0.1,
    "weight_decay": 0.01,
}
training_arguments = SFTConfig(**sft_options)

# NOTE(review): LMmodel is only created by the commented-out model-loading code
# in the setup cell, so this cell raises NameError on a fresh kernel.
trainer = SFTTrainer(
    model=LMmodel,
    train_dataset=train_ds,
    tokenizer=tokenizer,
    args=training_arguments,
    # Using custom data collator inside SFTTrainer
    data_collator=data_collator,
)

In [None]:
import json
def is_json(myjson):
    """Return True when ``myjson`` can be parsed by ``json.loads``.

    Args:
        myjson: Candidate JSON document. Values that ``json.loads`` cannot
            accept at all (for example ``None``) count as invalid instead of
            raising.

    Returns:
        True if parsing succeeds, False otherwise.
    """
    try:
        json.loads(myjson)
    except (ValueError, TypeError):
        # ValueError covers malformed JSON (JSONDecodeError is a subclass);
        # TypeError covers inputs that are not str/bytes/bytearray.
        return False
    return True

In [None]:
from transformers import logging
logging.set_verbosity_error()

def evaluate(model, test_ds, data_collator):
    """Generate an answer for every test prompt and collect the predictions.

    For each example the model generates up to 20 new tokens; the text after
    the last ``[/INST]`` (with ``</s>`` removed) is taken as the answer. When
    the answer is valid JSON holding a known class name under ``'label'``,
    that name is mapped to its class index; otherwise the prediction is None.

    Args:
        model: Model exposing ``generate``.
        test_ds: Iterable of tokenized test examples.
        data_collator: Collator turning one example into a batch that
            contains ``input_ids``.

    Returns:
        DataFrame made of the leading rows of the global ``test_df`` joined
        column-wise with the columns ``'pred'``, ``'label_pred'`` and
        ``'is_valid_json'``.
    """
    records = []
    # generate response
    for input_data in tqdm(test_ds):
        res = model.generate(data_collator([input_data])['input_ids'].to(DEVICE), max_new_tokens=20)
        answ = tokenizer.decode(res[0]).split('[/INST]')[-1]
        clean_answ = re.sub('</s>', '', answ)
        is_valid_json = is_json(clean_answ)
        label = None
        if is_valid_json:
            parsed = json.loads(clean_answ)
            # Valid JSON is not necessarily an object with a known label; only
            # an integer class index counts as a usable prediction.
            mapped = label_mapping(parsed.get('label')) if isinstance(parsed, dict) else None
            if isinstance(mapped, int):
                label = mapped
        records.append([clean_answ, label, is_valid_json])

    # dtype=object keeps missing predictions as None next to integer indices.
    eval_df = pd.DataFrame(records, columns=['pred', 'label_pred', 'is_valid_json'], dtype=object)

    # evaluate response: align the predictions with the first rows of test_df
    eval_df = pd.concat([test_df.loc[: len(eval_df) - 1], eval_df], axis=1)
    return eval_df

def calculate_metrics(eval_df):
    """Compute the valid-JSON rate and the micro/macro F1 of the predictions.

    Args:
        eval_df: DataFrame with the columns ``'is_valid_json'``, ``'label'``
            and ``'label_pred'``. It is not modified.

    Returns:
        Tuple ``(valid_json, micro_f1, macro_f1)``. ``valid_json`` is the
        fraction of rows with a valid JSON answer; the F1 scores are computed
        on the rows without missing values and are 0.0 when no such row
        exists.
    """
    if len(eval_df) == 0:
        return 0.0, 0.0, 0.0

    valid_json = sum(eval_df['is_valid_json']) / len(eval_df)

    # Work on a copy: dropping rows in place would silently shrink the frame
    # the caller goes on to display.
    scored = eval_df.dropna()
    if len(scored) == 0:
        return valid_json, 0.0, 0.0

    y_true = [int(no) for no in scored['label']]
    y_pred = [int(no) for no in scored['label_pred']]
    macro_f1 = f1_score(y_true, y_pred, average='macro')
    micro_f1 = f1_score(y_true, y_pred, average='micro')
    return valid_json, micro_f1, macro_f1

        

In [None]:
# Evaluation in the cell that precedes trainer.train(): score the model held by
# the trainer on the test prompts before any training step has been run.
eval_df = evaluate(trainer.model, test_ds, data_collator)
metric_res = calculate_metrics(eval_df)
print(metric_res)
# Last expression of the cell: display the evaluation frame.
eval_df

In [None]:
# Run the fine-tuning with the trainer configured above.
trainer.train()

In [None]:
# Same evaluation as before, in the cell that follows trainer.train(), so the
# two results can be compared.
eval_df = evaluate(trainer.model, test_ds, data_collator)
metric_res = calculate_metrics(eval_df)
print(metric_res)
# Show only the first rows of the evaluation frame.
eval_df.head()