In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# import os
# os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

import os
hf_token = os.getenv('HUGGINGFACE_TOKEN')
# Загрузка модели (например, Gemma 9B)
model_name = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=token, torch_dtype=torch.bfloat16, device_map="auto")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [2]:
prompt = ''' You are an expert in natural language processing and query classification. Your task is to label each given query as either "retrieval needed" (1) or "retrieval not needed" (0). 

Instructions:
1. Analyze the query carefully.
2. If the query requires specific, detailed, or specialized external information (e.g., fact-checking, detailed data, or precise terminology), label it as "1" (retrieval needed).
3. If the query is generic, abstract, or related to tasks like summarization, paraphrasing, or general knowledge that can be answered by a language model without external data, label it as "0" (retrieval not needed).
4. Provide your answer in the following format: 
   <label><0 or 1></label>
   <explanation><one or two sentences explaining your reasoning></explanation>

Examples:
Query: When did Virgin Australia start operating?
<label>1</label>
<explanation>This query asks for a specific historical fact (a start date) that is unlikely to be deduced purely from context and requires external factual data.</explanation>

Query: Which is a species of fish? Tope or Rope
<label>1</label>
<explanation>The query requires knowing the correct fish species between two options, which is a precise factual lookup rather than a generic or inferable answer.</explanation>

Query: Alice's parents have three daughters: Amy, Jessy, and what’s the name of the third daughter?
<label>0</label>
<explanation>This is a riddle where the answer is directly given in the query ("Alice"), making external retrieval unnecessary.</explanation>

Query: What individual has won the most Olympic gold medals in the history of the games?
<label>1</label>
<explanation>This is a specific fact-based question that requires up-to-date and precise information about Olympic records, which typically necessitates an external data source.</explanation>

Query: Which Dutch artist painted “Girl with a Pearl Earring”?
<label>0</label>
<explanation>This is a well-known art fact that most language models can answer from general knowledge without needing to retrieve external information.</explanation>
   
Now, please label the following query:

{query}
'''

In [3]:
# Функция для разметки
import re 

def label_query_batch(queries, batch_size=8):
    results = []
    
    for i in range(0, len(queries), batch_size):
        batch_queries = queries[i:i + batch_size]
        
        # Prepare batch inputs
        batch_inputs = []
        for query in batch_queries:
            input_text = prompt.format(query=query)
            messages = [{"role": "user", "content": input_text}]
            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            batch_inputs.append(input_text)
        
        # Tokenize batch
        inputs = tokenizer(batch_inputs, add_special_tokens=False, padding=True, return_tensors="pt").to("cuda")
        
        # Generate for batch
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            temperature=0.1,
            do_sample=False,
            max_new_tokens=45,
            pad_token_id=tokenizer.pad_token_id
        )
        
        # Process batch outputs
        for output in outputs:
            response = tokenizer.decode(output)
            try:
                # Находим часть ответа после последнего слова "model"
                parts = re.split(r'<start_of_turn> *model', response, flags=re.IGNORECASE)
                response = parts[-1] if len(parts) > 1 else response
                #print(response)
                
                # Ищем метку в формате <label>X</label>
                label_match = re.search(r'<label>(\d)</label>', response)
                # Ищем объяснение в формате <explanation>text</explanation>
                explanation_match = re.search(r'<explanation>(.*?)</explanation>', response)
                
                label = int(label_match.group(1)) if label_match else -1
                explanation = explanation_match.group(1).strip() if explanation_match else ""
                
                # Поскольку confidence не используется в новом формате, установим его в 1.0
                confidence = 1.0 if label != -1 else 0.0
            except:
                label, confidence, explanation = -1, 0.0, ""
            
            results.append((label, confidence, explanation))
    
        # Clear CUDA cache after processing each batch
        torch.cuda.empty_cache()
        # Clear variables
        del inputs, outputs
    return results

In [4]:
import pandas as pd
from tqdm import tqdm
from datasets import load_dataset

def evaluate_model(df, batch_size=8):
    #queries = [str(i) + ' '+ str(c) for i, c in zip(df['instruction'].astype(str).tolist(), df['input'].astype(str).tolist())]
    queries = df['prompt'].astype(str).tolist()
    results = []
    
    for i in tqdm(range(0, len(queries), batch_size)):
        batch_queries = queries[i:i + batch_size]
        batch_results = label_query_batch(batch_queries, batch_size)
        
        # Modified this section to include instruction and context
        for j, ((pred_label, confidence, reasoning), query) in enumerate(zip(batch_results, batch_queries)):
            idx = i + j
            results.append({
                'query': query,
                'instruction': df.iloc[idx]['prompt'],
                #'context': df.iloc[idx]['input'],
                'predicted_label': pred_label,
                'confidence': confidence,
                'reasoning': reasoning
            })
        
        if i % 100 == 0:
            pd.DataFrame(results).to_csv(f'gemma_oasst_{i}.csv', index=False)
    
    return pd.DataFrame(results)


#dataset = load_dataset("databricks/databricks-dolly-15k")

df = pd.read_csv('./datasets/oasst_suggest.csv')

#df = df.sample(frac=1, random_state=42).reset_index(drop=True)

# Calculate character length and filter
df['total_chars'] = df['prompt'].fillna('').str.len()# + df['input'].fillna('').str.len()
df = df[df['total_chars'] <= 2000]
#df.to_csv('dolly_15k.csv', index=False)

# Запускаем оценку
results_df = evaluate_model(df, 2)

# Выводим первые несколько результатов
print("\nПримеры предсказаний:")
print(results_df.head())

The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
100%|██████████| 1500/1500 [1:10:57<00:00,  2.84s/it]


Примеры предсказаний:
                                               query  \
0  Can you recommend a fun DIY project that can b...   
1  Hello Assistant! Could you help me out in maki...   
2  What are the differences between the Lindy hop...   
3  Could you describe the easiest way to draw a c...   
4  I want you to act as a Linux terminal. I will ...   

                                         instruction  predicted_label  \
0  Can you recommend a fun DIY project that can b...                0   
1  Hello Assistant! Could you help me out in maki...                1   
2  What are the differences between the Lindy hop...                1   
3  Could you describe the easiest way to draw a c...                0   
4  I want you to act as a Linux terminal. I will ...               -1   

   confidence                                          reasoning  
0         1.0  This query seeks a suggestion or recommendatio...  
1         1.0                                                     





In [10]:
pd.DataFrame(results_df).to_csv(f'gemma_alpaca_{3000}.csv', index=False)

In [9]:
results_df['predicted_label'].value_counts()
# for i in range(len(results_df[:100])):
#     print(results_df['query'][i])
#     print(results_df['predicted_label'][i])
#     print(results_df['confidence'][i])
#     print(results_df['reasoning'][i])
#     print('-'*100)


predicted_label
 0    2013
 1     986
-1       1
Name: count, dtype: int64

In [11]:
results_df.to_csv('gemma_results.csv', index=False)

In [5]:
#вывести все ошибки
errors = results_df[results_df['true_label'] != results_df['predicted_label']]

for idx, row in errors.iterrows():
    print(f'query: {row["query"]}, \n true_label: {row["true_label"]}, pred_label: {row["predicted_label"]}, confidence: {row["confidence"]}, reasoning: {row["reasoning"]}')


KeyError: 'true_label'

In [11]:
len(errors)

30

In [7]:
# Фильтруем результаты с confidence = 1.0
high_conf_results = results_df[results_df['confidence'] > 0.95]

# Считаем метрики
correct = sum(high_conf_results['true_label'] == high_conf_results['predicted_label'])
total = len(high_conf_results)
accuracy = correct / total if total > 0 else 0

# Подсчет полноты и точности
true_positives = sum((high_conf_results['true_label'] == 1) & (high_conf_results['predicted_label'] == 1))
total_actual_positives = sum(high_conf_results['true_label'] == 1)
total_predicted_positives = sum(high_conf_results['predicted_label'] == 1)

recall = true_positives / total_actual_positives if total_actual_positives > 0 else 0
precision = true_positives / total_predicted_positives if total_predicted_positives > 0 else 0

print(f"Количество предсказаний с уверенностью 1.0: {total}")
print(f"Точность для уверенных предсказаний: {accuracy:.2%}")
print(f"Полнота: {recall:.2%}")
print(f"Точность: {precision:.2%}")

Количество предсказаний с уверенностью 1.0: 12
Точность для уверенных предсказаний: 91.67%
Полнота: 50.00%
Точность: 100.00%


In [8]:
high_conf_results

Unnamed: 0,query,true_label,predicted_label,confidence,reasoning
1,WFSD-LP (107.9 FM) is a low-power FM radio sta...,0,0,1.0,"The question asks for a specific detail (""What..."
5,The Parliamentary Commissioner for Standards i...,0,0,1.0,The question explicitly asks for information (...
6,Rraboshtë is a village located in the former K...,0,0,1.0,"The question asks for specific details (""what ..."
17,The third century AD showed some remarkable de...,0,0,1.0,The question asks for a specific detail from t...
19,nan Is the following statement true or false: ...,1,1,1.0,"This is a question about a general, widely kno..."
21,The history of the Marine Corps began when two...,0,0,1.0,The question explicitly asks for information (...
35,The 2009 L'Aquila earthquake occurred in the r...,0,0,1.0,The question explicitly asks for information (...
42,Husinec is located about 6 kilometres (4 mi) n...,0,0,1.0,The question specifically asks for information...
70,"""The Fox in the Attic"" was originally publishe...",0,0,1.0,This question requires retrieving specific pub...
73,Rhodes Scholarships are international postgrad...,0,0,1.0,The question explicitly asks for information (...
