In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# import os
# os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'

import os
hf_token = os.getenv('HUGGINGFACE_TOKEN')
# Загрузка модели (например, Gemma 9B)
model_name = "google/gemma-2-9b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=token, torch_dtype=torch.bfloat16, device_map="auto")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [2]:
prompt = 'Task: Classify the following question/query by human into one of two categories:  \
- **Class 1**: Requires retrieval from a knowledge base (e.g., specific documents, paragraphs, or external references).  \
- **Class 0**: Can be handled by an LLM without retrieval, including general NLP tasks (classification, summarization, etc.) or generic knowledge.  \
\
**Rules**:  \
1. Assign **Class 0** if:  \
   - The question explicitly references a source (e.g., "according to this paragraph", "given a reference text").  \
   - It asks about details from a specific document, entity, or context (e.g., "What city is WFSD Radio licensed to?").  \
   - Requires cross-referencing information (e.g., comparing definitions, analyzing a specific case).  \
\
2. Assign **Class 1** if:  \
   - The request is a general NLP task (e.g., text classification, summarization, paraphrasing).  \
   - The question is generic, open-ended, or opinion-based (e.g., "How to start a hobby?", "Is X a good career?").  \
   - It asks for lists, definitions, or widely known facts (e.g., "What is WordPress?", "Name 10 best games").  \
\
**Examples**:  \
| Question/Request | Class | Explanation |  \
|------------------|-------|-------------|  \
| "Classify this tweet as positive or negative." | 1 | Text classification task. |  \
| "Summarize the following article in 3 sentences." | 1 | Summarization task. |  \
| "Paraphrase this sentence." | 1 | General NLP request. |  \
| "What was Britain called before it was Britain?" | 1 | General historical fact. |  \
| "Given this paragraph about XGBoost, what are its advantages?" | 0 | Requires document retrieval. |  \
\
**Input**:  \
{query}  \
\
**Output Format**:  \
- Class: [0/1]  \
- Confidence: [0.0-1.0]  \
- Reasoning: [1-2 sentences explaining the decision]  \
\
Think step by step. Solution of this task is very important for my job.\
**Response**: '

In [3]:
# Функция для разметки
import re 

def label_query_batch(queries, batch_size=8):
    results = []
    
    for i in range(0, len(queries), batch_size):
        batch_queries = queries[i:i + batch_size]
        
        # Prepare batch inputs
        batch_inputs = []
        for query in batch_queries:
            input_text = prompt.format(query=query)
            messages = [{"role": "user", "content": input_text}]
            input_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            batch_inputs.append(input_text)
        
        # Tokenize batch
        inputs = tokenizer(batch_inputs, add_special_tokens=False, padding=True, return_tensors="pt").to("cuda")
        
        # Generate for batch
        outputs = model.generate(
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            temperature=0.1,
            do_sample=False,
            max_new_tokens=25,
            pad_token_id=tokenizer.pad_token_id
        )
        
        # Process batch outputs
        for output in outputs:
            response = tokenizer.decode(output)
            try:
                parts = re.split(r'\bmodel\b', response, flags=re.IGNORECASE)
                response = parts[-1] if len(parts) > 1 else response

                class_match = re.search(r'Class:\s*(\d)', response)
                confidence_match = re.search(r'Confidence:\s*(\d*\.?\d+)', response)
                label = int(class_match.group(1)) if class_match else -1
                confidence = float(confidence_match.group(1)) if confidence_match else 0.0
                reasoning_match = re.search(r'Reasoning:\s*(.+)', response)
                reasoning = reasoning_match.group(1) if reasoning_match else ""
            except:
                label, confidence, reasoning = -1, 0.0, ""
            
            results.append((label, confidence, reasoning))
    
        # Clear CUDA cache after processing each batch
        torch.cuda.empty_cache()
        # Clear variables
        del inputs, outputs
    return results

In [4]:
import pandas as pd
from tqdm import tqdm

def evaluate_model(df, batch_size=8):
    queries = df['prompt'].astype(str).tolist()
    results = []
    
    # Process in batches with tqdm
    for i in tqdm(range(0, len(queries[:1000]), batch_size)):
        # print(i)
        batch_queries = queries[i:i + batch_size]
        batch_results = label_query_batch(batch_queries, batch_size)
        
        for query, (pred_label, confidence, reasoning) in zip(batch_queries, batch_results):
            results.append({
                'query': query,
                'predicted_label': pred_label,
                'confidence': confidence,
                'reasoning': reasoning
            })
    
    return pd.DataFrame(results)

balanced_df = pd.read_csv('../oasst_quality_with_suggestions.csv')
# Запускаем оценку
results_df = evaluate_model(balanced_df, 2)

# Выводим первые несколько результатов
print("\nПримеры предсказаний:")
print(results_df.head())

The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.
100%|██████████| 500/500 [14:02<00:00,  1.69s/it]


Примеры предсказаний:
                                               query  predicted_label  \
0  Can you explain contrastive learning in machin...                1   
1  I want to start doing astrophotography as a ho...                1   
2  Can you give me an example of a python script ...                1   
3  How can I learn to optimize my webpage for sea...                1   
4  Listened to Dvorak's "The New World" symphony,...                1   

   confidence                                          reasoning  
0        0.95    This question asks for a simplified explanation  
1        0.95                           This is a general, open-  
2        0.95               This request asks for a code example  
3        0.95           This question asks for general advice on  
4        0.95  This question asks for recommendations and com...  





In [10]:
results_df['predicted_label'].value_counts()
for i in range(len(results_df[:100])):
    print(results_df['query'][i])
    print(results_df['predicted_label'][i])
    print(results_df['confidence'][i])
    print(results_df['reasoning'][i])
    print('-'*100)


Can you explain contrastive learning in machine learning in simple terms for someone new to the field of ML?
1
0.95
This question asks for a simplified explanation
----------------------------------------------------------------------------------------------------
I want to start doing astrophotography as a hobby, any suggestions what could i do?
1
0.95
This is a general, open-
----------------------------------------------------------------------------------------------------
Can you give me an example of a python script that opens an api point and serves a string?
1
0.95
This request asks for a code example
----------------------------------------------------------------------------------------------------
How can I learn to optimize my webpage for search engines?
1
0.95
This question asks for general advice on
----------------------------------------------------------------------------------------------------
Listened to Dvorak's "The New World" symphony, liked it much. What compose

In [11]:
results_df.to_csv('gemma_results.csv', index=False)

In [5]:
#вывести все ошибки
errors = results_df[results_df['true_label'] != results_df['predicted_label']]

for idx, row in errors.iterrows():
    print(f'query: {row["query"]}, \n true_label: {row["true_label"]}, pred_label: {row["predicted_label"]}, confidence: {row["confidence"]}, reasoning: {row["reasoning"]}')


KeyError: 'true_label'

In [11]:
len(errors)

30

In [7]:
# Фильтруем результаты с confidence = 1.0
high_conf_results = results_df[results_df['confidence'] > 0.95]

# Считаем метрики
correct = sum(high_conf_results['true_label'] == high_conf_results['predicted_label'])
total = len(high_conf_results)
accuracy = correct / total if total > 0 else 0

# Подсчет полноты и точности
true_positives = sum((high_conf_results['true_label'] == 1) & (high_conf_results['predicted_label'] == 1))
total_actual_positives = sum(high_conf_results['true_label'] == 1)
total_predicted_positives = sum(high_conf_results['predicted_label'] == 1)

recall = true_positives / total_actual_positives if total_actual_positives > 0 else 0
precision = true_positives / total_predicted_positives if total_predicted_positives > 0 else 0

print(f"Количество предсказаний с уверенностью 1.0: {total}")
print(f"Точность для уверенных предсказаний: {accuracy:.2%}")
print(f"Полнота: {recall:.2%}")
print(f"Точность: {precision:.2%}")

Количество предсказаний с уверенностью 1.0: 12
Точность для уверенных предсказаний: 91.67%
Полнота: 50.00%
Точность: 100.00%


In [8]:
high_conf_results

Unnamed: 0,query,true_label,predicted_label,confidence,reasoning
1,WFSD-LP (107.9 FM) is a low-power FM radio sta...,0,0,1.0,"The question asks for a specific detail (""What..."
5,The Parliamentary Commissioner for Standards i...,0,0,1.0,The question explicitly asks for information (...
6,Rraboshtë is a village located in the former K...,0,0,1.0,"The question asks for specific details (""what ..."
17,The third century AD showed some remarkable de...,0,0,1.0,The question asks for a specific detail from t...
19,nan Is the following statement true or false: ...,1,1,1.0,"This is a question about a general, widely kno..."
21,The history of the Marine Corps began when two...,0,0,1.0,The question explicitly asks for information (...
35,The 2009 L'Aquila earthquake occurred in the r...,0,0,1.0,The question explicitly asks for information (...
42,Husinec is located about 6 kilometres (4 mi) n...,0,0,1.0,The question specifically asks for information...
70,"""The Fox in the Attic"" was originally publishe...",0,0,1.0,This question requires retrieving specific pub...
73,Rhodes Scholarships are international postgrad...,0,0,1.0,The question explicitly asks for information (...
