In [None]:
from unsloth import FastLanguageModel
from unsloth import is_bfloat16_supported
from unsloth import unsloth_train

from datasets import Dataset
from vllm import SamplingParams
from transformers import TrainingArguments, TrainerCallback
from transformers import TextStreamer
from trl import SFTTrainer

In [None]:
import random
import json
import re
import itertools
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

from sampling.sampling_fallacies_detection import load_all_dataset, get_prt, get_spl, get_nb_element

In [None]:
# Sequence length passed to the model loader and the trainer
# (RoPE scaling is handled internally, so any value may be chosen).
max_seq_length = 2048
# None requests auto detection; Float16 suits Tesla T4 / V100, Bfloat16 suits Ampere+.
dtype = None
# 4-bit quantization reduces memory usage; it can be switched off with False.
load_in_4bit = True
# Dataset name -> location of its JSONL file.
paths = dict(
    cocolofa='./Data_jsonl/cocolofa.jsonl',
    mafalda='./Data_jsonl/mafalda.jsonl',
)

In [None]:
# Load the base checkpoint and its tokenizer through Unsloth, using the
# sequence length, dtype and 4-bit settings defined in the configuration cell.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name= "meta-llama/meta-Llama-3.1-8B-Instruct",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
    # NOTE(review): presumably enables Unsloth's vLLM-backed generation path,
    # with gpu_memory_utilization capping its share of GPU memory - confirm.
    fast_inference=True,
    gpu_memory_utilization=0.6
)
# Attach LoRA adapters to the loaded model; `model` is rebound to the result.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # LoRA rank: any number > 0 (suggested 8, 16, 32, 64, 128)
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Any value is supported, but 0 is the optimized setting
    bias = "none",    # Any value is supported, but "none" is the optimized setting
    # "unsloth" checkpointing is advertised as using 30% less VRAM and fitting 2x larger batches.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # Rank-stabilized LoRA is supported but not used here
    loftq_config = None, # LoftQ is supported but not used here
)

In [None]:
# System message handed to get_prt below: it describes the classification task
# and asks for the answer to be wrapped as <|ANSWER|>...<|ANSWER|>, the same
# delimiter that format_output splits on.
SYSTEM_PROMPT = 'You are an expert in argumentation. Your task is to determine the type of fallacy in the given [SENTENCE]. The fallacy would be in the [FALLACY] Set. Utilize the [TITLE] and the [FULL TEXT] as context to support your decision.\nYour answer must be in the following format with only the fallacy in the answer section:\n<|ANSWER|><answer><|ANSWER|>.'

# Sample size requested for training; validation and test request 20% of it.
n_sample = 1000
# NOTE(review): load_all_dataset / get_prt / get_spl come from the project's
# sampling module and are not visible here. Judging by the names bound below,
# they presumably load the corpora plus the label set, build train/val/test
# prompt frames, and draw a sample together with oversampling counts - confirm.
data, fallacies = load_all_dataset(paths)
df_train, df_val, df_test = get_prt(data, fallacies, SYSTEM_PROMPT)
sample_train, over_train = get_spl(df_train, fallacies, n_sample=n_sample)
# n_sample*0.2 evaluates to the float 200.0 rather than an int.
# NOTE(review): confirm get_spl accepts a non-integer n_sample.
sample_val, over_val = get_spl(df_val, fallacies, n_sample=n_sample*0.2)
sample_test, over_test = get_spl(df_test, fallacies, n_sample=n_sample*0.2)

In [None]:
def formatting_prompt(data: dict):
    """Render the chat messages stored under ``data['prompt']`` as text.

    The module-level ``tokenizer`` applies its chat template without
    tokenizing and without appending a generation prompt.

    Args:
        data: Mapping holding the chat messages under the ``'prompt'`` key
            (a whole batch of them when used with ``Dataset.map(batched=True)``).

    Returns:
        A dict with the single key ``'text'`` holding the rendered template.
    """
    rendered = tokenizer.apply_chat_template(
        data.get('prompt'),
        tokenize=False,
        add_generation_prompt=False,
    )
    return {'text': rendered}

# Sampling settings (vLLM SamplingParams): temperature, nucleus threshold and
# the maximum number of generated tokens.
SAMPLING_PARAMS = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=128)

# Turn each sampled frame into a Dataset and add the rendered 'text' column,
# processing train, validation and test in that order.
data_train, data_val, data_test = [
    Dataset.from_pandas(split_frame).map(formatting_prompt, batched=True)
    for split_frame in (sample_train, sample_val, sample_test)
]

In [None]:
# def gen(txt, model, sampling_params):
#     output = model.fast_generate(
#         txt,
#         sampling_params = sampling_params,
#     )[0].outputs[0].text
    
#     return output

# def format_output(answer: str, fallacies: set) -> list:
#     s = '<[|]ANSWER[|]>'
#     tmp = re.split(s, answer)
#     pred= [i for i in tmp if i in fallacies]
#     return pred

# def zero_shot_gen(data: list[str], model, fallacies: set, sampling_params) -> list:
#     res = []
#     for i in data:
#         out = gen(i, model, sampling_params)
#         pred = format_output(out, fallacies)
#         res.append(pred)
#     return res

In [None]:
def gen(p, model, text_streamer):
    """Generate a completion for one chat prompt.

    The module-level ``tokenizer`` renders ``p`` with its chat template
    (generation prompt appended) into a PyTorch tensor, which is moved to
    the ``'cuda'`` device and passed to ``model.generate``.

    Args:
        p: Chat messages for a single prompt.
        model: Object exposing ``generate``.
        text_streamer: Streamer forwarded to ``model.generate``.

    Returns:
        Whatever ``model.generate`` returns for at most 128 new tokens.
    """
    input_ids = tokenizer.apply_chat_template(
        p,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    input_ids = input_ids.to('cuda')
    # The EOS token id doubles as the padding id for generation.
    return model.generate(
        input_ids,
        streamer=text_streamer,
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )

def format_output(answer: list, labels: set) -> list:
    """Extract the known labels found between ``<|ANSWER|>`` delimiters.

    Only the first element of ``answer`` is inspected. It is split on the
    ``<|ANSWER|>`` marker and every piece that is a member of ``labels`` is
    kept, in order of appearance. A piece that only differs from a label by
    surrounding whitespace (e.g. a trailing newline emitted by the model) is
    accepted too and returned in its stripped form.

    Args:
        answer: Decoded model outputs; only ``answer[0]`` is used.
        labels: Collection of valid label strings.

    Returns:
        The list of matched labels; empty when nothing matches or when
        ``answer`` is empty.
    """
    if not answer:
        # Nothing was decoded, so there is nothing to parse.
        return []

    pieces = re.split(r'<[|]ANSWER[|]>', answer[0])
    pred = []
    for piece in pieces:
        if piece in labels:
            # Exact match first, so labels that contain whitespace still work.
            pred.append(piece)
            continue
        stripped = piece.strip()
        if stripped in labels:
            pred.append(stripped)
    return pred

def zero_shot_gen(
    data:Dataset,
    model,
    labels:set,
    text_streamer: TextStreamer
) -> list:
    """Generate and parse one prediction list per prompt in ``data``.

    Each entry of ``data['prompt']`` is sent through ``gen``, the output is
    decoded with the module-level ``tokenizer`` and parsed by
    ``format_output``.

    Args:
        data: Object whose ``'prompt'`` entry holds the prompts to run.
        model: Model forwarded to ``gen``.
        labels: Valid labels forwarded to ``format_output``.
        text_streamer: Streamer forwarded to ``gen``.

    Returns:
        A list with, for every prompt, the list of labels found in its output.
    """
    return [
        format_output(
            tokenizer.batch_decode(gen(conversation, model, text_streamer)),
            labels,
        )
        for conversation in data['prompt']
    ]

In [None]:
# class custom_validation_callback(TrainerCallback):
#     def __init__(self, data, sampling_params, fallacies, n_step=10):
#         super().__init__()
#         self.val_dataset = data
#         self.sampling_params = sampling_params
#         self.fallacies = fallacies
#         self.n_step=n_step
#     def on_step_end(self, args, state, control, **kwargs):
#         if state.global_step % self.n_step == 0 and state.global_step > 0 :
#             model.save_lora('sft_save_lora')
#             FastLanguageModel.for_inference(model)
#             pred = zero_shot_gen(
#                 data=self.val_dataset['text'],
#                 model=model,
#                 fallacies=self.fallacies,
#                 sampling_params=self.sampling_params
#             )
#             tmp_pred = [i if i != [] else ['Failed'] for i in pred]
#             d = pd.DataFrame().from_records(tmp_pred)
#             d['truth_label'] = self.val_dataset['answer']
#             d['step'] = np.full((len(d['truth_label']),), state.global_step)
#             try:
#                 d.to_csv(
#                     './validation_res.csv',
#                     index=False,
#                     mode='a',
#                     header=['pred', 'truth_label', 'step']
#                 )
#             except FileNotFoundError:
#                 d.to_csv('./validation_res.csv', index=False, header=['pred', 'truth_label'])
#         return super().on_step_end(args, state, control, **kwargs)

# class custom_test_callback(TrainerCallback):
#     def __init__(self, data, sampling_params, fallacies):
#         super().__init__()
#         self.test_dataset = data
#         self.sampling_params = sampling_params
#         self.fallacies = fallacies
#     def on_train_end(self, args, state, control, **kwargs):
#         model.save_lora('sft_save_lora')
#         FastLanguageModel.for_inference(model)
#         pred = zero_shot_gen(
#             data=self.test_dataset['text'],
#             model=model,
#             fallacies=self.fallacies,
#             sampling_params=self.sampling_params
#         )
#         tmp_pred = [i if i != [] else ['Failed'] for i in pred]
#         d = pd.DataFrame().from_records(tmp_pred)
#         d['truth_label'] = self.test_dataset['answer']
#         try:
#             d.to_csv('./test_res.csv', index=False, mode='a', header=['pred','truth_label'])
#         except FileNotFoundError:
#             d.to_csv('./test_res.csv', index=False, header=['pred', 'truth_label'])
#         return super().on_train_end(args, state, control, **kwargs)

In [None]:
# Directory passed to TrainingArguments as output_dir.
# NOTE(review): the name reads like "1 epoch, 1000 samples" while
# num_train_epochs below is 1.5 - confirm which one is intended.
outputs_dir = './outputs/1e1000spl'
training_args = TrainingArguments(
    per_device_train_batch_size = 4, # alternative value noted by the author: 2
    per_device_eval_batch_size= 4,
    gradient_accumulation_steps = 8, # alternative value noted by the author: 4
    eval_accumulation_steps= 8,
    warmup_steps = 5,
    num_train_epochs = 1.5, # One and a half passes over the training data.
    # max_steps = 60,
    learning_rate = 2e-4,
    # Exactly one of fp16 / bf16 is enabled, depending on bf16 support.
    fp16 = not is_bfloat16_supported(),
    bf16 = is_bfloat16_supported(),
    logging_steps = 1,
    optim = "adamw_8bit",
    weight_decay = 0.01,
    lr_scheduler_type = "linear",
    seed = 3407,
    output_dir = outputs_dir,
    report_to = "tensorboard", # Reporting backend; other integrations (e.g. WandB) can be named here
    # Step-based evaluation, every 4 steps.
    eval_strategy="steps",
    eval_steps=4,
)
# Supervised fine-tuning on the rendered 'text' column, with data_val used
# as the evaluation set.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = data_train,
    eval_dataset=data_val,
    # formatting_func=formatting_prompt,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 2,
    packing = False, # Packing can make training 5x faster for short sequences; disabled here.
    args = training_args,
)
# trainer.get_train_dataloader().shuffle = False
# trainer.get_eval_dataloader().shuffle = False
# trainer.train()
# Training is launched through Unsloth's wrapper instead of trainer.train().
unsloth_train(trainer)

In [None]:
def get_precision_recall(data):
    """Score how often the prediction is contained in the reference label.

    A row counts as a hit when ``row['pred'] in row['lbl']``. Whether the
    resulting ratio is a precision or a recall depends solely on how the
    caller filtered ``data`` beforehand.

    Args:
        data: DataFrame with a ``'pred'`` and a ``'lbl'`` column.

    Returns:
        A tuple ``(score, hits, misses)`` where ``score`` is
        ``hits / len(data)``. ``(0, 0, 0)`` is returned for an empty frame.

    Raises:
        KeyError: If a non-empty ``data`` lacks the ``'pred'`` or ``'lbl'``
            column.
    """
    total = len(data)
    if total == 0:
        # Nothing to score; also avoids dividing by zero.
        return 0, 0, 0

    tp = sum(
        1
        for predicted, reference in zip(data['pred'], data['lbl'])
        if predicted in reference
    )
    if tp == 0:
        # Keep the plain integer zero so downstream zero checks stay simple.
        return 0, 0, total

    return tp / total, tp, total - tp

def get_f1(precision, recall):
    """Return the harmonic mean of ``precision`` and ``recall``.

    Args:
        precision: Precision value.
        recall: Recall value.

    Returns:
        ``2 * precision * recall / (precision + recall)``, or ``0`` when the
        two values sum to zero.
    """
    denominator = precision + recall
    # Explicit check instead of catching ZeroDivisionError: NumPy scalars do
    # not raise on 0/0, they produce nan, which would slip through silently.
    if denominator == 0:
        return 0
    return 2 * ((precision * recall) / denominator)

def get_metrics(data, labels):
    """Compute per-label and overall F1 / precision / recall.

    For every label:
      * precision is scored on the rows whose ``'pred'`` equals the label;
      * recall is scored on the rows whose ``'lbl'`` contains the label.
    Both use ``get_precision_recall``, i.e. a row is a hit when its
    prediction is contained in its reference label.

    The overall scores are micro-averages built from the summed hit and miss
    counts of all labels.

    Args:
        data: DataFrame with a ``'pred'`` and a ``'lbl'`` column.
        labels: Iterable of label names to evaluate.

    Returns:
        Dict mapping each label to ``(f1, precision, recall)``, plus the key
        ``'score_all_data'`` holding the overall triple.
    """
    tp_preci = 0
    tp_rec = 0
    fn = 0
    fp = 0
    res = {}
    for label in labels:
        df_on_pred = data[data['pred'] == label]
        # Boolean mask of the rows whose reference contains the label; the
        # explicit bool dtype keeps the indexing valid for an empty frame.
        has_label = data['lbl'].map(lambda reference: label in reference).astype(bool)
        df_on_label = data.loc[has_label]

        precision, tp_p, fp_p = get_precision_recall(df_on_pred)
        recall, tp_r, fn_r = get_precision_recall(df_on_label)
        res[label] = (get_f1(precision, recall), precision, recall)

        fn += fn_r
        fp += fp_p
        tp_preci += tp_p
        tp_rec += tp_r

    # Guard the micro-averages: when no prediction matches any label (for
    # instance every output was unparsable) both counts are zero.
    predicted_total = tp_preci + fp
    relevant_total = tp_rec + fn
    precision_all_data = tp_preci / predicted_total if predicted_total else 0
    recall_all_data = tp_rec / relevant_total if relevant_total else 0
    f1_all_data = get_f1(precision_all_data, recall_all_data)
    res['score_all_data'] = (f1_all_data, precision_all_data, recall_all_data)
    return res

# Put the model in Unsloth's inference mode and stream generated text as it
# is produced.
FastLanguageModel.for_inference(model)
text_streamer = TextStreamer(tokenizer)

# Only the first 10 test rows are evaluated. Slice once so that predictions,
# dataset names and reference labels all come from the same rows.
test_subset = data_test[:10]
pred = zero_shot_gen(
    data=test_subset,
    model=model,
    labels=fallacies,
    text_streamer=text_streamer
)

names_dataset = test_subset['datasets']
true_labels = test_subset['answer']

# Samples without any parsable label are marked as 'Failed'.
tmp_pred = [i if i != [] else ['Failed'] for i in pred]
# Keep exactly one prediction per sample (the first match). Flattening every
# match would yield more predictions than labels whenever an output contains
# several labels, and building the DataFrame below would then fail.
pred_flat = [candidates[0] for candidates in tmp_pred]

d_res = {'names': names_dataset, 'pred': pred_flat, 'lbl': true_labels}
df_res = pd.DataFrame(data=d_res)
metric = get_metrics(df_res, fallacies)

In [None]:
def plot_stat_sample(
    sample: pd.DataFrame,
    over: dict,
    lst_labels: set[str],
    n_sample: int,
    savefile=None
) -> None:
    """Draw stacked bars describing the sample, per label and per dataset.

    For every dataset name found in ``sample['datasets']`` three stacked
    segments are drawn for each label: a base count, the oversampled count
    taken from ``over``, and the remainder of the per-label count.

    Args:
        sample: Frame with a ``'datasets'`` and an ``'answer'`` column.
        over: Mapping turned into a frame whose columns are indexed by
            dataset name. NOTE(review): presumably the oversampling counts
            returned by the sampler - confirm.
        lst_labels: Label names; their sorted order is used for the x ticks.
        n_sample: Requested sample size, used to derive the per-label quota
            ``n_sample / len(lst_labels)``.
        savefile: Optional destination; when given the figure is written
            there as PNG.
    """
    width = 0.3
    nb_per_lbl = {}
    fig, ax = plt.subplots()
    d = {}
    x = np.arange(len(lst_labels))
    mult = 0
    dataset_names = sample['datasets'].value_counts().index.to_list()
    for name in dataset_names:
        data = sample[sample['datasets'] == name]['answer']
        # NOTE(review): get_nb_element presumably returns a count per label.
        nb_per_lbl.update({
            name: get_nb_element(data.to_list())
        })
    df_nb_lbl = pd.DataFrame.from_dict(nb_per_lbl).fillna(0).sort_index()
    df_over = pd.DataFrame.from_dict(over).fillna(0).sort_index()
    print(nb_per_lbl)
    print(df_over)
    for name in dataset_names:
        # Per-label quota when the sample is spread evenly over the labels.
        nb_element = n_sample / len(lst_labels)
        df_nb_lbl.index.names = [None]
        # Pair each oversample count with the quota when the label count
        # exceeds 15, otherwise with the label count itself.
        tmp = [
            (over_count, nb_element)
            if label_count > 15 else (over_count, label_count)
            for over_count, label_count in zip(
                df_over[name].values, df_nb_lbl[name].values
            )
        ]
        # Base segment: quota minus the oversampled part when oversampling
        # happened, otherwise the second element of the pair.
        nb_lbl = [
            nb_element - pair[0]
            if pair[0] > 0 else pair[1]
            for pair in tmp
        ]
        nb_lbl = pd.Series(data=nb_lbl, index=df_over[name].index)
        # nb_lbl_list = df_nb_lbl[name].values - df_over[name].values - nb_element
        # Remainder: label count minus the oversampled and base segments.
        nb_lbl_list = df_nb_lbl[name].values - df_over[name].values - nb_lbl.values
        # nb_lbl_list[nb_lbl_list < 0] = 0
        sr_lbl_list = pd.Series(data=nb_lbl_list, index=df_over[name].index)
        d.update({
            name: (
                nb_lbl,
                df_over[name],
                sr_lbl_list
            )
        })
    for k, v in d.items():
        # Each dataset gets its own bar position, shifted by one bar width.
        p = ax.bar(x + width * mult, v[0], width, label=k)
        ax.bar_label(p, label_type='center')
        p = ax.bar(
            x + width * mult, v[1], width, label=f'oversample {k}', bottom=v[0]
        )
        ax.bar_label(p, label_type='center')
        p = ax.bar(
            x + width * mult,
            v[2],
            width,
            label=f'labels in list {k}',
            bottom=v[1]+v[0]
        )
        ax.bar_label(p, label_type='center')
        mult += 1
    ax.set_xticks(x + width, sorted(lst_labels), rotation=90)
    ax.legend(loc='upper right')
    ax.set_title('Sampled data')
    fig.set_size_inches(20, 10)
    # Write the file before showing: after plt.show() the pyplot-level
    # savefig may no longer target this figure and can produce a blank image.
    if savefile is not None:
        fig.savefig(savefile, format='png')
    plt.show()

def plot_metric(
    metric: dict,
    columns: list[str]=['f1', 'precision', 'recall'],
    title='',
    savefile=None,
):
    """Draw a grouped bar chart of the scores stored in ``metric``.

    Args:
        metric: Mapping from a row name to a sequence of scores, one per
            entry of ``columns``.
        columns: Names given to the score positions. The default list is
            only read, never modified.
        title: Title of the chart.
        savefile: Optional destination; when given the figure is written
            there as PNG.
    """
    # rand_mark = pd.Series(np.full((len(metric),), 1/(len(metric)-1)))
    df_metric = pd.DataFrame.from_dict(
        metric,
        orient='index',
        columns=columns
    )
    fig, ax = plt.subplots(1, 1)
    df_metric.plot(
        ax=ax,
        kind='bar',
        figsize=(20,10),
        title=title,
    )
    # rand_mark.plot(ax=ax, color='red', linestyle='dashed')
    plt.xticks(rotation=90)
    # Write the file before showing: after plt.show() the pyplot-level
    # savefig may no longer target this figure and can produce a blank image.
    if savefile is not None:
        fig.savefig(savefile, format='png')
    plt.show()
    
# Names of the three scores stored for every label, in tuple order.
columns = ['f1', 'precision', 'recall']

# Composition of the training sample, then the evaluation scores.
plot_stat_sample(
    sample=sample_train,
    over=over_train,
    lst_labels=fallacies,
    n_sample=n_sample,
)
plot_metric(metric, columns=columns, title='Scores on all sampled data')