In [None]:
import ast
import os
import pickle
from typing import Dict, List, Tuple

import numpy as np
import pandas as pd
import torch
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel

class ThemeClassifier:
    """Nearest-neighbour theme classifier built on sentence embeddings.

    Every verse of every theme is embedded with a transformer model
    (attention-masked mean pooling over the last hidden state). A new verse
    is scored against each theme by its highest cosine similarity to any of
    that theme's stored verse embeddings.
    """

    def __init__(self, model_name="sentence-transformers/paraphrase-multilingual-mpnet-base-v2"):
        """Load the tokenizer and model, moving the model to GPU when available.

        Args:
            model_name: Identifier passed to ``AutoTokenizer.from_pretrained``
                and ``AutoModel.from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(self.device, flush=True)
        self.model = self.model.to(self.device)

        # Maps theme -> 2-D array holding one embedding row per verse.
        self.theme_embeddings: Dict[str, np.ndarray] = {}

    def get_embedding(self, verses: List[str]) -> np.ndarray:
        """Embed verses with masked mean pooling, processing them in batches.

        Args:
            verses: Texts to embed. Each is truncated to 128 tokens.

        Returns:
            Array with one row per verse, in input order. An empty input
            yields an empty array of shape ``(0, hidden_size)`` instead of
            raising.
        """
        verses_len = len(verses)
        # Progress output: number of verses about to be embedded.
        print(verses_len)
        print("==============")

        if verses_len == 0:
            # torch.cat() raises on an empty list, so short-circuit with a
            # correctly shaped empty result.
            return np.empty((0, self.model.config.hidden_size), dtype=np.float32)

        # Small inputs use small batches; large inputs use batches of 100.
        batch_size = 10 if verses_len < 100 else 100

        all_embeddings = []
        for start_idx in range(0, verses_len, batch_size):
            # Slicing clamps at the end, so the last batch may be shorter.
            batch_verses = verses[start_idx:start_idx + batch_size]

            encoded_input = self.tokenizer(batch_verses, padding=True, truncation=True,
                                           max_length=128, return_tensors='pt')
            encoded_input = {k: v.to(self.device) for k, v in encoded_input.items()}

            with torch.no_grad():
                model_output = self.model(**encoded_input)

            token_states = model_output.last_hidden_state
            # Broadcast the attention mask over the hidden dimension so that
            # padding positions contribute nothing to the sum.
            input_mask_expanded = encoded_input['attention_mask'].unsqueeze(-1).expand(
                token_states.size()).float()

            # Mean over real tokens; the clamp guards against division by
            # zero for a row whose mask is all zeros.
            batch_embeddings = torch.sum(token_states * input_mask_expanded, 1) / \
                torch.clamp(input_mask_expanded.sum(1), min=1e-9)

            all_embeddings.append(batch_embeddings.cpu())

        return torch.cat(all_embeddings, dim=0).numpy()

    def process_and_save_embeddings(self, df: pd.DataFrame, save_path: str = 'content/theme_embeddings.pkl'):
        """Embed every verse of every theme in ``df`` and pickle the result.

        Args:
            df: Frame with a ``'poem theme'`` column and a ``'poem verses'``
                column whose cells are either a list of verse strings or the
                string representation of such a list.
            save_path: Destination pickle file. Missing parent directories
                are created.

        Returns:
            The theme -> embeddings mapping that was written to disk.

        Raises:
            ValueError, SyntaxError: If a string cell is not a valid Python
                literal.
        """
        # Rows with a missing theme are skipped: NaN never compares equal to
        # itself, so such a "theme" would select no verses at all.
        for theme in df['poem theme'].dropna().unique():
            theme_verses = []

            for verses in df.loc[df['poem theme'] == theme, 'poem verses']:
                if isinstance(verses, str):
                    # literal_eval only parses literals, so a crafted CSV cell
                    # cannot execute arbitrary code the way eval() would.
                    verses = ast.literal_eval(verses)
                theme_verses.extend(verses)

            self.theme_embeddings[theme] = self.get_embedding(theme_verses)

        # Create the target directory up front so the embedding work is not
        # lost to a FileNotFoundError at the very end.
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        with open(save_path, 'wb') as f:
            pickle.dump(self.theme_embeddings, f)

        return self.theme_embeddings

    def load_embeddings(self, load_path: str = 'theme_embeddings.pkl'):
        """Replace the stored theme embeddings with those pickled at ``load_path``.

        SECURITY: unpickling can execute arbitrary code. Only load files that
        were produced by ``process_and_save_embeddings`` or otherwise come
        from a trusted source.

        Returns:
            The loaded theme -> embeddings mapping.
        """
        with open(load_path, 'rb') as f:
            self.theme_embeddings = pickle.load(f)
        return self.theme_embeddings

    def classify_verse(self, verse: str, top_k: int = 1) -> List[Tuple[str, float]]:
        """Rank themes by similarity to ``verse``.

        A theme's score is the maximum cosine similarity between the verse
        and any stored verse embedding of that theme.

        Args:
            verse: Text to classify.
            top_k: Maximum number of themes to return.

        Returns:
            Up to ``top_k`` ``(theme, score)`` tuples, best match first.
        """
        verse_embedding = self.get_embedding([verse])

        theme_scores = []
        for theme, theme_embeddings in self.theme_embeddings.items():
            if len(theme_embeddings) == 0:
                # A theme without stored verses has nothing to compare against.
                continue
            similarities = cosine_similarity(verse_embedding, theme_embeddings)[0]
            # float() turns the NumPy scalar into the plain float promised by
            # the return annotation.
            theme_scores.append((theme, float(np.max(similarities))))

        return sorted(theme_scores, key=lambda x: x[1], reverse=True)[:top_k]

In [None]:
# Single location for the pickle that is written and then read back below
embeddings_path = 'theme_embeddings.pkl'

# Read the poem dataset, then build the classifier (loads tokenizer and model)
df = pd.read_csv('./ashaar.csv')
classifier = ThemeClassifier()

# Embed every verse per theme and persist the result
print("Processing and saving embeddings...")
classifier.process_and_save_embeddings(df, embeddings_path)

# In a later session the saved embeddings can be restored without re-embedding
print("\nLoading embeddings and classifying new verse...")
classifier.load_embeddings(embeddings_path)

In [None]:
# Verse to classify (blank placeholder) and up to ten best-matching themes
new_verse = ""
predictions = classifier.classify_verse(new_verse, top_k=10)

# Report each match as three lines: theme, score, separator
for theme, confidence in predictions:
    print(f"Theme: {theme}", f"Confidence Score: {confidence:.4f}", "---", sep="\n")