In [28]:
import pandas as pd
from dict import depression_labels, newsgroups_labels, emotion_labels, news_labels, yahoo_labels, rating_labels
import os
import sys

In [29]:
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

In [30]:
from utils import get_dataset, get_model_by_tag, get_prompt_template
from LLMAnnotator import LLMAnnotator

In [31]:
os.environ['OPENAI_API_KEY_CLARIN']=os.getenv("OPENAI_API_KEY_CLARIN")
os.environ["LANGCHAIN_API_KEY"]=os.getenv("LANGCHAIN_API_KEY")
os.environ["LANGCHAIN_TRACING_V2"]="true"
os.environ["LANGCHAIN_PROJECT"]=os.getenv("LANGCHAIN_PROJECT")

token = os.getenv("OPENAI_API_KEY_CLARIN")

In [32]:
dataset = '20_newsgroups'
# exp = 5
num_exp = 80
model_tag = 'llama3'
# path = f'results/cot_random_samples_cohere_temp0.3/cohere/{dataset}_cohere_cot_random_samples_cohere_cohere_random42_{num_exp}_temp0.3_exp{exp}.csv'
# df = pd.read_csv(path)
selected_samples = 'cot_random_samples_cohere'
tempp = 0


In [33]:
def get_dict(dataset):
    if(dataset == 'ag_news'):
        return news_labels
    if(dataset == 'yahoo'):
        return yahoo_labels
    if(dataset == '20_newsgroups'):
        return newsgroups_labels
    if(dataset == 'social_media'):
        return depression_labels
    if(dataset == 'go_emotions'):
        return emotion_labels
    if(dataset == 'sst5'):
        return rating_labels
    

In [34]:

def calculate_accuracy(df):
    # Bierzemy tylko wiersze, gdzie obie kolumny nie są NaN i was_in_selected_samples jest False
    valid = df[df['was_in_selected_samples'] == False]

    # Zastępujemy wartości NaN pustym ciągiem (lub inną wartością)
    valid['original_label'] = valid['original_label'].fillna('')
    valid['extraction_complete'] = valid['extraction_complete'].fillna('')

    # Liczymy ile jest zgodnych wartości
    matches = (valid['original_label'] == valid['extraction_complete']).sum()

    # Dokładność = liczba zgodnych / liczba wierszy do porównania
    accuracy = matches / len(valid) if len(valid) > 0 else 0

    return accuracy


In [35]:
labels = get_dict(dataset)
labels

{'rec.sport.hockey': 0,
 'soc.religion.christian': 1,
 'rec.sport.baseball': 2,
 'rec.motorcycles': 3,
 'sci.crypt': 4,
 'rec.autos': 5,
 'sci.med': 6,
 'sci.space': 7,
 'comp.os.ms-windows.misc': 8,
 'comp.sys.ibm.pc.hardware': 9,
 'sci.electronics': 10,
 'comp.windows.x': 11,
 'comp.graphics': 12,
 'misc.forsale': 13,
 'comp.sys.mac.hardware': 14,
 'talk.politics.mideast': 15,
 'talk.politics.guns': 16,
 'alt.atheism': 17,
 'talk.politics.misc': 18,
 'talk.religion.misc': 19}

In [36]:
def find_non_matching_labels_df(label_dict, df):
    keys_set = set(label_dict.keys())
    mask = (~df['extracted_label'].isin(keys_set)) & (df['extracted_label'].notna()) & (df['extracted_label'] != '')
    filtered = df.loc[mask].copy()
    filtered['indeks'] = filtered.index
    return filtered.reset_index(drop=True)


In [37]:
def fill_missing_labels(df, label_dict, new_values):
    keys_set = set(label_dict.keys())
    
    # Znajdź indeksy wierszy, gdzie extracted_label nie jest w kluczach i jest puste lub inne
    mask = (~df['extracted_label'].isin(keys_set)) | (df['extracted_label'].isna()) | (df['extracted_label'] == '')
    indices_to_fill = df.index[mask]

    n_fill = min(len(indices_to_fill), len(new_values))

    # Podstawiamy nowe wartości do kolumny 'extracted_label_by_llm'
    for i in range(n_fill):
        df.at[indices_to_fill[i], 'extracted_label_by_llm'] = new_values[i]

    # Tworzymy kolumnę 'extraction_complete':
    # Jeśli 'extracted_label' jest w keys_set, to bierzemy ją,
    # w przeciwnym razie jeśli 'extracted_label_by_llm' jest w keys_set, to ją bierzemy,
    # inaczej NaN
    def choose_label(row):
        if row['extracted_label'] in keys_set:
            return row['extracted_label']
        elif 'extracted_label_by_llm' in row and row['extracted_label_by_llm'] in keys_set:
            return row['extracted_label_by_llm']
        else:
            return pd.NA

    df['extraction_complete'] = df.apply(choose_label, axis=1)

    return df



def generate_task_string(category_dict):
    categories = ', '.join(category_dict.keys())
    prompt = (
        "Task: Determine which category was chosen as the final one. Possible categories:\n\n"
        f"{categories}\n\n"
        'Text for you: "{text}"\n\n'
        "Return only one category. No explanations."
    )
    return prompt


In [38]:
prompt = generate_task_string(labels)

In [39]:
prompt

'Task: Determine which category was chosen as the final one. Possible categories:\n\nrec.sport.hockey, soc.religion.christian, rec.sport.baseball, rec.motorcycles, sci.crypt, rec.autos, sci.med, sci.space, comp.os.ms-windows.misc, comp.sys.ibm.pc.hardware, sci.electronics, comp.windows.x, comp.graphics, misc.forsale, comp.sys.mac.hardware, talk.politics.mideast, talk.politics.guns, alt.atheism, talk.politics.misc, talk.religion.misc\n\nText for you: "{text}"\n\nReturn only one category. No explanations.'

In [40]:

from LLMAnnotator import LLMAnnotator
from utils import get_dataset, get_model_by_tag, get_prompt_template


for i in [1]:
    print(f"Iteration {i}")
    path  = f'results/{selected_samples}_temp{tempp}/{model_tag}/{dataset}_{model_tag}_{selected_samples}_{model_tag}_random42_{num_exp}_temp{tempp}_exp{i}.csv'
    temp = 0
    model = get_model_by_tag(model_tag, token)  
    df_ori = pd.read_csv(path)
    wrong = find_non_matching_labels_df(labels, df_ori)
    dataset_for_annotation = wrong
    prompt_txt = prompt
    # output_path = f'./results/cot_random_samples_temp{temp}/{model_tag}_extraction/{model_tag}_{model_tag}_{prompt}_{selected_samples}_temp{temp}_exp{i}.csv'
    # os.makedirs(f'./results/{prompt}_temp{temp}/{model_tag}', exist_ok=True)
    text_col = "output"
    label_col = "original_label"
    output_col = "extracted_by_llm"
    
    print('Len:', len(wrong))
    if(len(wrong) != 0):
        
        annotator = LLMAnnotator(
            model=model,
            dataset=dataset_for_annotation,
            examples_for_prompt=pd.DataFrame(),
            prompt_template=prompt,
            column_text=text_col,
            column_label=label_col,
            column_output=output_col
        )
        
        new_val_list = annotator.get_results()
        
        new_val_df = pd.DataFrame(new_val_list)   
        new_vals = new_val_df['extracted_by_llm'].tolist()  

        df = fill_missing_labels(df_ori, labels, new_vals)
        
        output_path_new = 'extracted/' + path

        folder = os.path.dirname(output_path_new)
        if not os.path.exists(folder):
            os.makedirs(folder)
            
        print(f"Accuracy: {calculate_accuracy(df)}")

        df.to_csv(output_path_new, index=False)
        
    else:
        
        df = fill_missing_labels(df_ori, labels, [])
        output_path_new = 'extracted/' + path
        folder = os.path.dirname(output_path_new)
        if not os.path.exists(folder):
            os.makedirs(folder)
            
        print(f"Accuracy: {calculate_accuracy(df)}")

        df.to_csv(output_path_new, index=False)



Iteration 1
Len: 121
Nr: 0 Predicted label: misc.forsale
Nr: 1 Predicted label: misc.forsale
Nr: 2 Predicted label: soc.religion.christian
Nr: 3 Predicted label: talk.politics.misc
Nr: 4 Predicted label: misc.forsale
Nr: 5 Predicted label: soc.religion.christian
Nr: 6 Predicted label: misc.forsale
Nr: 7 Predicted label: soc.religion.christian
Nr: 8 Predicted label: talk.religion.misc
Nr: 9 Predicted label: talk.politics.misc
Nr: 10 Predicted label: misc.forsale
Nr: 11 Predicted label: misc.forsale
Nr: 12 Predicted label: misc.forsale
Nr: 13 Predicted label: soc.religion.christian
Nr: 14 Predicted label: talk.religion.misc
Nr: 15 Predicted label: soc.religion.christian
Nr: 16 Predicted label: misc.forsale
Nr: 17 Predicted label: soc.religion.christian
Nr: 18 Predicted label: talk.religion.misc
Nr: 19 Predicted label: soc.religion.christian
Nr: 20 Predicted label: soc.religion.christian
Nr: 21 Predicted label: talk.politics.misc
Nr: 22 Predicted label: misc.forsale
Nr: 23 Predicted label

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid['original_label'] = valid['original_label'].fillna('')
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  valid['extraction_complete'] = valid['extraction_complete'].fillna('')


In [41]:
df

Unnamed: 0.1,Unnamed: 0,text,output,logprobs,top_logprobs,original_label,was_in_selected_samples,extracted_label,error,extracted_label_by_llm,extraction_complete
0,0,article ashish arora writes excerpts netnewssc...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.176...","[{'token': 'To', 'logprob': -0.176899492740631...",talk.politics.misc,True,talk.politics.misc,,,talk.politics.misc
1,1,gateway telepath modem month actually one woul...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.233...","[{'token': 'To', 'logprob': -0.233964130282402...",comp.sys.ibm.pc.hardware,False,comp.sys.ibm.pc.hardware,,,comp.sys.ibm.pc.hardware
2,2,anybody provide advice concerning following tw...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.225...","[{'token': 'To', 'logprob': -0.225865259766578...",sci.med,False,sci.med,,,sci.med
3,3,article mike silverman writes anybody know goi...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.148...","[{'token': 'To', 'logprob': -0.148368075489997...",rec.sport.baseball,False,rec.sport.baseball,,,rec.sport.baseball
4,4,article stich christian e writes installed mot...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.219...","[{'token': 'To', 'logprob': -0.219457104802131...",sci.electronics,False,comp.sys.ibm.pc.hardware,,,comp.sys.ibm.pc.hardware
...,...,...,...,...,...,...,...,...,...,...,...
1995,645,david sternlight writes article karl barrus wr...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.181...","[{'token': 'To', 'logprob': -0.181033700704574...",sci.crypt,False,sci.crypt,,,sci.crypt
1996,646,hello im looking information alphanumeric page...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.183...","[{'token': 'To', 'logprob': -0.183513775467872...",sci.electronics,False,sci.electronics,,,sci.electronics
1997,647,john r daker writes would like offocially nomi...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.148...","[{'token': 'To', 'logprob': -0.148061156272888...",rec.motorcycles,False,rec.motorcycles,,,rec.motorcycles
1998,648,article writes looking information concerning ...,To determine the most suitable category for th...,"{'content': [{'token': 'To', 'logprob': -0.268...","[{'token': 'To', 'logprob': -0.268602669239044...",sci.space,False,sci.space,,,sci.space
