# In [2]:
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from typing import List, Dict, Any, Tuple, Optional
from pathlib import Path
import pandas as pd

# In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so
# this cell still runs (slowly) on machines without CUDA instead of failing.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single checkpoint id so the model and its tokenizer can never drift apart.
MODEL_NAME = "dicta-il/dictalm2.0-instruct"

# Weights are loaded in bfloat16 and placed directly on `device`.
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16, device_map=device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# In [None]:
@dataclass
class ResponseData:
    """Record pairing one conversation with the model response obtained for it.

    Attributes:
        file_name: Identifier of the source the conversation belongs to.
        conv: Text of the conversation.
        response: Text of the response.
    """

    # Plain string fields; no defaults, so all three must be supplied.
    file_name: str
    conv: str
    response: str

# In [None]:
class ConversationProcessor:
    """Run a yes/no model query over conversation rows and save the results.

    For every row flagged by ``contains_calls_to_order`` or ``contains_cut``,
    the rows inside a small positional window that share the same
    ``session_id`` are sent to the model. Each answer is classified as
    yes (1), no (0) or undecided (None), stored in the ``dicta_answer``
    column of the row that was queried, and the collected responses are
    written to CSV files in ``output_dir``.
    """

    def __init__(self, output_dir):
        """Remember the output directory and make sure it exists.

        Args:
            output_dir: Directory (str or path-like) that receives the CSVs.
        """
        self.output_dir = Path(output_dir)
        self.setup_output_directory()

    def setup_output_directory(self):
        """Create the output directory, including parents, if it is missing."""
        self.output_dir.mkdir(parents=True, exist_ok=True)

    # The four helpers below use no instance state. They are static so that
    # both ``self.helper(...)`` and ``ConversationProcessor.helper(...)`` work;
    # without the decorator the ``self.`` calls in this class raise TypeError.

    @staticmethod
    def get_surrounding_indices(current_index: int, df_length: int, window: int = 3) -> range:
        """Return the positions within ``window`` of ``current_index``.

        The range is clipped to ``[0, df_length)`` and includes
        ``current_index`` itself.
        """
        return range(
            max(0, current_index - window),
            min(df_length, current_index + window + 1)
        )

    @staticmethod
    def process_model_response(response: str, combine_words: List[str]) -> Tuple[str, str]:
        """Strip markers and the leading prompt words from a model response.

        Args:
            response: Raw text returned by the model.
            combine_words: Words of the prompt; only its length is used, as
                the number of leading words to drop from the response.

        Returns:
            ``(first_word, response_string)``: the first four remaining words
            concatenated without separators, and all remaining words joined
            by single spaces.
        """
        # str.split() only yields strings, so filtering the markers is enough.
        kept = [
            word for word in response.strip().split()
            if word not in {"[INST]", "[/INST]"}
        ]
        filtered_response = kept[len(combine_words):]

        # NOTE(review): joining without a separator means a later substring
        # match can span a word boundary; kept to preserve the returned value.
        first_word = ''.join(filtered_response[:4])
        response_string = ' '.join(filtered_response)

        return first_word, response_string

    @staticmethod
    def classify_response(first_word: str) -> Optional[int]:
        """Map the start of a response to 1 (yes), 0 (no) or None.

        The "yes" token is tested first, so text containing both tokens is
        classified as 1.
        """
        if "כן" in first_word:
            return 1
        if "לא" in first_word:
            return 0
        return None

    @staticmethod
    def create_response_data(session_id, conversation, response) -> "ResponseData":
        """Bundle a session id, conversation and response into ResponseData."""
        return ResponseData(
            file_name=session_id,
            conv=conversation,
            response=response
        )

    def save_results(
        self,
        df: pd.DataFrame,
        yes_responses: "List[ResponseData]",
        no_responses: "List[ResponseData]"
    ) -> None:
        """Write the yes/no response lists and the updated frame to CSV.

        A response file is only written when its list is non-empty; the
        updated DataFrame is always written.
        """
        columns = ['file_name', 'conv', 'response']

        for responses, csv_name in (
            (yes_responses, 'yes_res_call_cut.csv'),
            (no_responses, 'no_res_call_cut.csv'),
        ):
            if not responses:
                continue
            pd.DataFrame([vars(r) for r in responses]).to_csv(
                self.output_dir / csv_name,
                columns=columns,
                index=False
            )

        df.to_csv(
            self.output_dir / 'updated_original_df_call_cut.csv',
            index=False
        )

    def process_conversation(
        self,
        row: pd.Series,
        query_model: callable,
        example: str
    ) -> "Tuple[ResponseData, Optional[int]]":
        """Query the model for one row and classify its answer.

        Args:
            row: Row providing the ``conversation`` and ``session_id`` values.
            query_model: Callable invoked with the conversation text; its
                return value must be a string.
            example: Example text; used only when counting prompt words.

        Returns:
            The ResponseData for the row and its classification
            (1, 0 or None).
        """
        conv = row['conversation']

        # The prompt is rebuilt here only to count its words, which are then
        # dropped from the front of the response. The model itself is called
        # with the bare conversation.
        # NOTE(review): presumably query_model builds this same prompt and the
        # model echoes it back -- confirm against query_model.
        prompt = f'השיחה: "{conv}"\nשאלה: האם המשפט נאמר באופן לא מכבד?\nתשובה: רק כן או לא.'
        asist = "הבנתי. דוגמה ברורה."
        combine = f"{prompt}{asist}{example}"
        combine_words = combine.strip().split()

        response = query_model(conv).strip()

        first_word, response_string = self.process_model_response(response, combine_words)
        classification = self.classify_response(first_word)

        response_data = self.create_response_data(
            session_id=str(row['session_id']),
            conversation=conv,
            response=response_string
        )

        return response_data, classification

    def process_responses(
        self,
        df,
        query_model,
        example
    ):
        """Classify the rows around every flagged row and save the results.

        Args:
            df: DataFrame with ``session_id``, ``conversation``,
                ``contains_calls_to_order`` and ``contains_cut`` columns.
                It is modified in place: ``dicta_answer`` is set on every row
                that was queried. Index labels are assumed to be unique.
            query_model: Callable passed through to ``process_conversation``.
            example: Example text passed through to ``process_conversation``.

        Returns:
            The same ``df`` object, after the update.

        A row that lies inside several windows is queried once per window.
        Errors while processing a row, and errors while saving, are printed
        and do not abort the run.
        """
        yes_responses = []
        no_responses = []

        flagged = (
            (df['contains_calls_to_order'] == 1) |
            (df['contains_cut'] == 1)
        )
        # Work with positions, not index labels: the window arithmetic and
        # ``iloc`` below are positional, so labels would only be correct for a
        # default RangeIndex.
        flagged_positions = [pos for pos, hit in enumerate(flagged) if hit]
        n_rows = len(df)

        for pos in flagged_positions:
            session_id = df.iloc[pos]['session_id']

            for idx in self.get_surrounding_indices(pos, n_rows):
                row_to_check = df.iloc[idx]

                # Neighbours from another session are not part of this call.
                if row_to_check['session_id'] != session_id:
                    continue

                # Record the answer on the row that was actually queried, not
                # on the flagged row that opened the window.
                label = df.index[idx]
                try:
                    response_data, classification = self.process_conversation(
                        row_to_check, query_model, example
                    )

                    if classification == 1:
                        yes_responses.append(response_data)
                    elif classification == 0:
                        no_responses.append(response_data)

                    df.at[label, 'dicta_answer'] = classification

                except Exception as e:
                    # Best effort: one failing row must not stop the batch.
                    print(f"Error processing row {label}: {str(e)}")
                    df.at[label, 'dicta_answer'] = None

        try:
            self.save_results(df, yes_responses, no_responses)
            print("Processing completed. Results saved successfully.")
        except Exception as e:
            print(f"Error saving results: {str(e)}")

        return df