In [None]:
# telugu_metaphor_model_fixed.py
import os
import re
import json
import warnings
from collections import Counter
import importlib.util
from typing import Optional, Dict, Any

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, PreTrainedModel, PretrainedConfig
from transformers.modeling_outputs import SequenceClassifierOutput
from safetensors.torch import save_file, load_file

try:
    import spacy
except Exception:
    spacy = None

from sklearn.model_selection import train_test_split

warnings.filterwarnings("ignore")

# ===========================
# 1. CUSTOM CONFIG
# ===========================

class TeluguMetaphorConfig(PretrainedConfig):
    """Hyper-parameter container for the Telugu metaphor detection model.

    Each constructor argument is stored unchanged as an attribute of the same
    name; any remaining keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    Args:
        muril_model_name: Identifier/path of the MuRIL checkpoint.
        bilstm_hidden_dim: Hidden size (per direction) of the BiLSTM.
        bilstm_num_layers: Number of stacked BiLSTM layers.
        bilstm_dropout: Dropout probability used by the BiLSTM encoder.
        meta_hidden_dim: Width of the first layer of the meta-classifier.
        num_classes: Number of output classes.
        max_length: Token sequence length setting.
        cmt_feature_dim: Length of the CMT feature vector.
        syntactic_feature_dim: Length of the syntactic feature vector.
        invert_output: When true, the model swaps its two logit columns.
        **kwargs: Passed through to ``PretrainedConfig``.
    """
    model_type = "telugu_metaphor_detector"

    def __init__(
        self,
        muril_model_name: str = "google/muril-base-cased",
        bilstm_hidden_dim: int = 256,
        bilstm_num_layers: int = 2,
        bilstm_dropout: float = 0.3,
        meta_hidden_dim: int = 512,
        num_classes: int = 2,
        max_length: int = 64,
        cmt_feature_dim: int = 33,
        syntactic_feature_dim: int = 20,
        invert_output: bool = False,
        **kwargs
    ):
        # The base class consumes the generic HF options first.
        super().__init__(**kwargs)

        # Model-specific settings, assigned in declaration order.
        own_settings = {
            "muril_model_name": muril_model_name,
            "bilstm_hidden_dim": bilstm_hidden_dim,
            "bilstm_num_layers": bilstm_num_layers,
            "bilstm_dropout": bilstm_dropout,
            "meta_hidden_dim": meta_hidden_dim,
            "num_classes": num_classes,
            "max_length": max_length,
            "cmt_feature_dim": cmt_feature_dim,
            "syntactic_feature_dim": syntactic_feature_dim,
            "invert_output": invert_output,
        }
        for setting_name, setting_value in own_settings.items():
            setattr(self, setting_name, setting_value)


# Register custom model with Transformers
def register_custom_model():
    """Register the custom config/model pair with the Transformers Auto classes.

    Registration is best-effort: any failure (Transformers missing, the model
    classes not defined yet, an incompatible Transformers version, ...) is
    reported on stdout instead of being raised, so callers can continue and
    retry later.

    Returns:
        None.
    """
    try:
        from transformers import AutoConfig, AutoModelForSequenceClassification
        AutoConfig.register("telugu_metaphor_detector", TeluguMetaphorConfig, exist_ok=True)
        AutoModelForSequenceClassification.register(TeluguMetaphorConfig, TeluguMetaphorModel, exist_ok=True)
    except Exception as reg_error:
        # Stay non-fatal, but do not hide the reason: a silently skipped
        # registration is very hard to diagnose when loading fails later.
        print(f"Custom model registration skipped: {reg_error}")


# ===========================
# 2. MODEL COMPONENTS
# ===========================

class TeluguCMTAnalyzer:
    """Rule-based Conceptual Metaphor Theory (CMT) feature extractor.

    Produces one binary flag per simile regex, one per indicator word and
    three extra flags (5 + 25 + 3 = 33 values in total).
    """
    def __init__(self):
        # NOTE(review): the literals below look mis-encoded (mojibake) in this
        # copy of the file -- confirm they are valid Telugu in the real source.
        # Regexes for simile markers.
        self.simile_patterns = [
            r'\b‡∞≤‡∞æ\b', r'\b‡∞≤‡∞æ‡∞ó‡∞æ\b', r'\b‡∞µ‡∞Ç‡∞ü‡∞ø\b', r'\b‡∞Æ‡∞æ‡∞¶‡∞ø‡∞∞‡∞ø‡∞ó‡∞æ\b', r'\b‡∞∞‡±Ç‡∞™‡∞Ç‡∞≤‡±ã\b'
        ]
        # Words whose presence (as a substring) is treated as a metaphor cue.
        self.metaphor_indicators = [
            '‡∞ó‡±Å‡∞Ç‡∞°‡±Ü', '‡∞π‡±É‡∞¶‡∞Ø‡∞Ç', '‡∞®‡∞¶‡∞ø', '‡∞µ‡∞æ‡∞®', '‡∞Æ‡±á‡∞ò‡∞Ç', '‡∞™‡∞∞‡±ç‡∞µ‡∞§‡∞Ç', '‡∞ö‡±Ü‡∞ü‡±ç‡∞ü‡±Å', '‡∞Ü‡∞ï‡∞æ‡∞∂‡∞Ç',
            '‡∞∏‡∞Æ‡±Å‡∞¶‡±ç‡∞∞‡∞Ç', '‡∞ö‡∞Ç‡∞¶‡±ç‡∞∞‡±Å‡∞≤‡±Å', '‡∞ö‡∞Ç‡∞¶‡±ç‡∞∞‡±Å‡∞°‡±Å', '‡∞∏‡±Ç‡∞∞‡±ç‡∞Ø‡±Å‡∞°‡±Å', '‡∞¶‡±Ä‡∞™‡∞Ç', '‡∞™‡∞æ‡∞Æ‡±Å', '‡∞™‡∞æ‡∞Æ‡±Å‡∞≤‡∞æ',
            '‡∞µ‡±à‡∞¶‡±ç‡∞Ø‡±Å‡∞°‡±Å', '‡∞Ø‡±Å‡∞¶‡±ç‡∞ß‡∞Ç', '‡∞ï‡∞®‡±ç‡∞®‡±Å‡∞≤‡±Å', '‡∞µ‡∞ø‡∞¶‡±ç‡∞Ø', '‡∞Ö‡∞ó‡±ç‡∞®‡∞ø', '‡∞§‡±Å‡∞´‡∞æ‡∞®‡±Å', '‡∞µ‡∞∞‡±ç‡∞∑‡∞™‡±Å',
            '‡∞ö‡∞ø‡∞®‡±Å‡∞ï‡±Å', '‡∞µ‡∞∏‡∞Ç‡∞§‡∞Ç', '‡∞µ‡±Ü‡∞≤‡±Å‡∞ó‡±Å'
        ]

    def extract_features(self, text):
        """Return a float32 tensor of 0/1 flags describing ``text``.

        Layout: simile-pattern hits, then indicator-word hits, then two fixed
        phrase checks, then a flag for a "word word word." shape in a text of
        at most five whitespace-separated tokens.
        """
        simile_flags = [
            int(re.search(simile_rx, text) is not None)
            for simile_rx in self.simile_patterns
        ]
        indicator_flags = [int(cue in text) for cue in self.metaphor_indicators]

        # Three words followed by a literal full stop, in a short sentence.
        short_three_word = (
            re.search(r'\w+\s+\w+\s+\w+\.', text) is not None
            and len(text.split()) <= 5
        )
        extra_flags = [
            int('‡∞∑‡∞¨‡±ç‡∞¶‡∞Ç‡∞ó‡∞æ' in text),
            int('‡∞™‡∞≤‡∞ø‡∞ï‡∞ø‡∞® ‡∞µ‡∞æ‡∞ï‡±ç‡∞Ø‡∞Ç' in text),
            int(short_three_word),
        ]

        return torch.tensor(simile_flags + indicator_flags + extra_flags, dtype=torch.float32)


class SyntacticFeatureExtractor:
    """Extract dependency/POS count features from a spaCy-style pipeline.

    Args:
        nlp: Callable mapping a string to an iterable, sized document of
            tokens exposing ``dep_``, ``pos_`` and ``children``; ``None``
            disables extraction (all-zero vectors are returned).
        feature_dim: Length of the returned vector. The eight computed counts
            are zero-padded (or truncated) to this size. Defaults to 20, the
            width that was previously hard-coded.
    """
    def __init__(self, nlp, feature_dim=20):
        self.nlp = nlp
        self.feature_dim = feature_dim

    def extract_features(self, text):
        """Return a float32 tensor of length ``feature_dim`` for ``text``.

        Slots 0-7 hold, in order: counts of ``nsubj``, ``dobj``, ``advmod``
        and ``compound`` dependencies, the number of verbs without an
        ``nsubj`` child, the token count, the verb count and the noun count.
        Remaining slots are zero.
        """
        if self.nlp is None:
            return torch.zeros(self.feature_dim, dtype=torch.float32)

        doc = self.nlp(text)

        # One pass per attribute instead of one pass per counted label.
        dep_counts = Counter(tok.dep_ for tok in doc)
        pos_counts = Counter(tok.pos_ for tok in doc)
        subjectless_verbs = sum(
            1 for tok in doc
            if tok.pos_ == 'VERB' and not any(child.dep_ == 'nsubj' for child in tok.children)
        )

        features = [
            dep_counts['nsubj'],
            dep_counts['dobj'],
            dep_counts['advmod'],
            dep_counts['compound'],
            subjectless_verbs,
            len(doc),
            pos_counts['VERB'],
            pos_counts['NOUN'],
        ]
        # Pad with zeros up to the requested width, then cut to exactly that width.
        features += [0] * (self.feature_dim - len(features))
        return torch.tensor(features[:self.feature_dim], dtype=torch.float32)


class MuRILBiLSTMEncoder(nn.Module):
    """Bidirectional LSTM that summarises a sequence of token embeddings.

    Args:
        embedding_dim: Size of each input embedding.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability, applied between LSTM layers (only when
            ``num_layers > 1``) and on the final summary vector.
    """
    def __init__(self, embedding_dim=768, hidden_dim=256, num_layers=2, dropout=0.3):
        super().__init__()
        self.lstm = nn.LSTM(
            embedding_dim, hidden_dim, num_layers=num_layers,
            batch_first=True, bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)

    def forward(self, x, attention_mask=None):
        """Encode ``x`` (batch-first) into one vector per sequence.

        Args:
            x: Batch-first tensor of embeddings.
            attention_mask: Optional per-token mask; its row sums are used as
                sequence lengths. Without it every sequence is treated as
                full length.
                NOTE(review): using the row sum as a length assumes padding
                sits at the end of each row -- confirm for the tokenizer used.

        Returns:
            Tensor of shape ``(batch, 2 * hidden_dim)``: the last layer's
            final forward and backward hidden states, concatenated,
            layer-normalised and passed through dropout.
        """
        if attention_mask is not None:
            # pack_padded_sequence rejects zero-length sequences, so a row
            # that is entirely padding is treated as having length 1.
            lengths = attention_mask.sum(dim=1).clamp(min=1).cpu().long()
        else:
            lengths = torch.full((x.size(0),), x.size(1), dtype=torch.long)

        packed = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
        _, (hn, _) = self.lstm(packed)

        # hn[-2] / hn[-1] are the top layer's forward / backward final states.
        combined = torch.cat([hn[-2], hn[-1]], dim=1)
        combined = self.layer_norm(combined)
        return self.dropout(combined)


class AttentionPooling(nn.Module):
    """Pool a sequence into one vector using learned attention weights.

    Args:
        hidden_dim: Size of each input vector; the scorer is a two-layer MLP
            (``hidden_dim -> hidden_dim // 2 -> 1``) with a tanh in between.
    """
    def __init__(self, hidden_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over dimension 1.

        Args:
            x: Tensor indexed as ``(batch, position, feature)``.
            mask: Optional ``(batch, position)`` tensor; positions where it
                equals 0 receive (numerically) zero weight.

        Returns:
            Tensor of shape ``(batch, feature)``.
        """
        scores = self.attention(x).squeeze(-1)
        if mask is not None:
            # Use the dtype's own minimum instead of a fixed -1e9, which is
            # not representable in float16 and makes masked_fill fail there.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
        return (x * weights).sum(dim=1)


class MetaClassifier(nn.Module):
    """MLP head mapping the fused feature vector to class logits.

    Three ``Linear -> LayerNorm -> ReLU -> Dropout`` stages narrow the width
    from ``hidden_dim`` to ``hidden_dim // 2`` to ``hidden_dim // 4`` (dropout
    0.4, 0.3, 0.2), followed by a final ``Linear`` to ``num_classes``.
    """
    def __init__(self, input_dim, hidden_dim=512, num_classes=2):
        super().__init__()
        half_dim = hidden_dim // 2
        quarter_dim = hidden_dim // 4

        # (in_features, out_features, dropout probability) for each stage.
        stage_specs = (
            (input_dim, hidden_dim, 0.4),
            (hidden_dim, half_dim, 0.3),
            (half_dim, quarter_dim, 0.2),
        )

        stack = []
        for fan_in, fan_out, drop_p in stage_specs:
            stack.extend([
                nn.Linear(fan_in, fan_out),
                nn.LayerNorm(fan_out),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ])
        stack.append(nn.Linear(quarter_dim, num_classes))

        # Same module order as before, so state-dict keys stay mlp.0 ... mlp.12.
        self.mlp = nn.Sequential(*stack)

    def forward(self, x):
        """Return unnormalised class scores for ``x``."""
        logits = self.mlp(x)
        return logits


# ===========================
# 3. MAIN MODEL (HF Compatible)
# ===========================

class TeluguMetaphorModel(PreTrainedModel):
    """HuggingFace-compatible Telugu metaphor detection model.

    A MuRIL encoder (frozen at construction time) feeds two summaries of the
    token embeddings -- a BiLSTM encoding and an attention-pooled vector --
    which are concatenated with standardised CMT and syntactic feature
    vectors and classified by ``MetaClassifier``.
    """
    config_class = TeluguMetaphorConfig

    def __init__(self, config: TeluguMetaphorConfig):
        """Build all sub-modules from ``config``.

        Note that this loads the pretrained MuRIL weights via
        ``AutoModel.from_pretrained`` every time the model is constructed.
        """
        super().__init__(config)
        self.config = config

        # Load MuRIL, preferring safetensors weights. Catch Exception rather
        # than using a bare except so Ctrl-C / SystemExit are not swallowed.
        try:
            self.muril = AutoModel.from_pretrained(config.muril_model_name, use_safetensors=True)
        except Exception:
            # Fallback without forcing safetensors
            self.muril = AutoModel.from_pretrained(config.muril_model_name)

        # Take the embedding width from the loaded encoder instead of assuming
        # 768, so checkpoints with another hidden size work too.
        encoder_dim = getattr(self.muril.config, "hidden_size", 768)

        # BiLSTM encoder
        self.bilstm_encoder = MuRILBiLSTMEncoder(
            embedding_dim=encoder_dim,
            hidden_dim=config.bilstm_hidden_dim,
            num_layers=config.bilstm_num_layers,
            dropout=config.bilstm_dropout
        )

        # Attention pooling over the raw encoder embeddings
        self.attention_pooling = AttentionPooling(encoder_dim)

        # Meta classifier: BiLSTM summary + pooled embedding + both feature vectors
        meta_input_dim = (
            (config.bilstm_hidden_dim * 2)
            + encoder_dim
            + config.cmt_feature_dim
            + config.syntactic_feature_dim
        )
        self.meta_classifier = MetaClassifier(meta_input_dim, config.meta_hidden_dim, config.num_classes)

        # Feature normalization stats. Buffers are part of the state dict, so
        # they are saved with and restored from checkpoints.
        self.register_buffer('cmt_mean', torch.zeros(config.cmt_feature_dim))
        self.register_buffer('cmt_std', torch.ones(config.cmt_feature_dim))
        self.register_buffer('synt_mean', torch.zeros(config.syntactic_feature_dim))
        self.register_buffer('synt_std', torch.ones(config.syntactic_feature_dim))

        # Freeze MuRIL by default
        for param in self.muril.parameters():
            param.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        cmt_features: torch.Tensor,
        synt_features: torch.Tensor,
        labels: Optional[torch.Tensor] = None
    ):
        """Classify a batch.

        Args:
            input_ids: Token ids passed to MuRIL.
            attention_mask: Attention mask passed to MuRIL, the BiLSTM encoder
                and the attention pooling.
            cmt_features: Raw CMT features, one row per example.
            synt_features: Raw syntactic features, one row per example.
            labels: Optional class indices; when given, an unweighted
                cross-entropy loss is returned as well.

        Returns:
            ``SequenceClassifierOutput`` with ``logits`` and, if ``labels``
            was supplied, ``loss``.
        """
        # Get MuRIL embeddings
        muril_outputs = self.muril(input_ids=input_ids, attention_mask=attention_mask)
        muril_embeddings = muril_outputs.last_hidden_state

        # Two summaries of the same embeddings
        bilstm_features = self.bilstm_encoder(muril_embeddings, attention_mask)
        muril_pooled = self.attention_pooling(muril_embeddings, attention_mask)

        # Standardise the hand-crafted features; epsilon guards against zero std
        cmt_norm = (cmt_features - self.cmt_mean.unsqueeze(0)) / (self.cmt_std.unsqueeze(0) + 1e-8)
        synt_norm = (synt_features - self.synt_mean.unsqueeze(0)) / (self.synt_std.unsqueeze(0) + 1e-8)

        # Combine all features
        combined = torch.cat([bilstm_features, muril_pooled, cmt_norm, synt_norm], dim=1)

        # Classify
        logits = self.meta_classifier(combined)

        # Apply inversion if configured: reverses the order of the class columns.
        # The loss below is computed on the already-flipped logits.
        if self.config.invert_output:
            logits = torch.flip(logits, dims=[1])

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None
        )

    def update_normalization_stats(self, cmt_mean, cmt_std, synt_mean, synt_std):
        """Overwrite the stored feature mean/std buffers in place."""
        self.cmt_mean.copy_(cmt_mean)
        self.cmt_std.copy_(cmt_std)
        self.synt_mean.copy_(synt_mean)
        self.synt_std.copy_(synt_std)


# ===========================
# 4. DATASET CLASS
# ===========================

class TeluguMetaphorDataset(Dataset):
    """Dataset of ``(text, label)`` pairs for Telugu metaphor detection.

    Labels are normalised to 0/1 once at construction; tokenisation and
    feature extraction happen lazily in ``__getitem__``.

    Args:
        data: Sequence of ``(text, label)`` pairs.
        tokenizer: Callable tokenizer invoked with ``max_length``,
            ``padding``, ``truncation`` and ``return_tensors`` keywords.
        cmt_analyzer: Object with ``extract_features(text)``.
        synt_extractor: Object with ``extract_features(text)``.
        max_len: Padded/truncated token length.
    """
    # Accepted string spellings (compared after strip + lower-casing).
    _POSITIVE_LABELS = ('metaphor', 'metaphorical', 'fig', 'figurative', '1')
    _NEGATIVE_LABELS = ('normal', 'non-metaphor', 'nonmetaphor', 'literal',
                        'not_metaphor', 'not-metaphor', '0')

    def __init__(self, data, tokenizer, cmt_analyzer, synt_extractor, max_len=64):
        self.raw_data = data[:]
        self.tokenizer = tokenizer
        self.cmt_analyzer = cmt_analyzer
        self.synt_extractor = synt_extractor
        self.max_len = max_len
        self.data = [(text, self._map_label(label)) for text, label in self.raw_data]

    def _map_label(self, label):
        """Map a raw label to 1 (metaphor) or 0 (anything else).

        Numeric labels -- Python ints/bools/floats and NumPy scalars -- map to
        1 only when equal to 1. Previously only built-in ``int`` was
        recognised, so e.g. ``np.int64(1)`` or ``1.0`` silently became 0.
        Strings are matched case-insensitively against the known spellings;
        unknown values fall back to 0.
        """
        if isinstance(label, (int, float, np.integer, np.floating)):
            return 1 if label == 1 else 0
        if isinstance(label, str):
            lab = label.strip().lower()
            if lab in self._POSITIVE_LABELS:
                return 1
            if lab in self._NEGATIVE_LABELS:
                return 0
        return 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the tokenised example and its feature vectors as a dict."""
        text, label_idx = self.data[idx]

        # Keep only characters in U+0C00-U+0C7F plus whitespace; fall back to
        # the raw text when nothing survives the cleaning.
        # NOTE(review): this also removes '.', while the CMT analyzer defined
        # in this file has a feature regex requiring a literal '.'; on cleaned
        # text that feature looks like it can never fire -- confirm intent.
        text_clean = re.sub(r'[^\u0C00-\u0C7F\s]', '', text).strip()
        if not text_clean:
            text_clean = text

        encoding = self.tokenizer(
            text_clean, max_length=self.max_len, padding='max_length',
            truncation=True, return_tensors='pt'
        )

        cmt_feat = self.cmt_analyzer.extract_features(text_clean)
        synt_feat = self.synt_extractor.extract_features(text_clean)

        return {
            'text': text_clean,
            # squeeze(0) drops the leading size-1 dimension of the tokenizer output
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'cmt_features': cmt_feat,
            'synt_features': synt_feat,
            'label': label_idx
        }


def collate_fn(batch):
    """Merge a list of per-example dicts into one batch dict.

    Tensor fields are stacked along a new leading dimension, texts are kept
    as a plain list under ``'texts'`` and labels become a ``torch.long``
    tensor under ``'labels'``.
    """
    collated = {'texts': [sample['text'] for sample in batch]}

    for field in ('input_ids', 'attention_mask', 'cmt_features', 'synt_features'):
        collated[field] = torch.stack([sample[field] for sample in batch])

    label_values = [sample['label'] for sample in batch]
    collated['labels'] = torch.tensor(label_values, dtype=torch.long)
    return collated


# ===========================
# 5. TRAINING PIPELINE
# ===========================

class TeluguMetaphorPipeline:
    """Training and inference pipeline for the Telugu metaphor detector.

    Owns the tokenizer, the model, the two hand-crafted feature extractors
    and the checkpoint directory name, and provides train / predict /
    save / load entry points.
    """
    def __init__(self, model_path='google/muril-base-cased', device=None):
        """Create tokenizer, model and feature extractors.

        Args:
            model_path: Tokenizer/encoder checkpoint identifier or path.
            device: ``torch.device`` to run on; defaults to CUDA when
                available, otherwise CPU.
        """
        print("Initializing Telugu Metaphor Detection Pipeline...")

        # Register custom model type with Transformers
        register_custom_model()

        self.device = device or (torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'))
        print(f"Using device: {self.device}")

        # Load tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Create config and model
        self.config = TeluguMetaphorConfig(muril_model_name=model_path)
        self.model = TeluguMetaphorModel(self.config)
        self.model.to(self.device)

        print("‚úÖ Model initialized successfully")

        # Load spaCy (optional). Without it the syntactic features are zeros.
        # NOTE(review): "xx_ent_wiki_sm" is, as far as I know, an NER-only
        # pipeline, so dep_/pos_ may be empty even when it loads -- confirm.
        self.nlp = None
        if spacy is not None:
            try:
                self.nlp = spacy.load("xx_ent_wiki_sm")
                print("‚úÖ spaCy model loaded")
            except Exception:
                # Exception (not a bare except) so Ctrl-C still interrupts.
                print("‚ö†Ô∏è  spaCy model not available (syntactic features will be zeros)")

        # Analyzers
        self.cmt_analyzer = TeluguCMTAnalyzer()
        self.synt_extractor = SyntacticFeatureExtractor(self.nlp)

        self.checkpoint_path = 'telugu_metaphor_best'

    def compute_feature_stats(self, dataset, batch_size=64):
        """Compute per-feature mean/std over ``dataset`` and store them on the model."""
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
        cmt_list, synt_list = [], []

        for batch in loader:
            cmt_list.append(batch['cmt_features'].numpy())
            synt_list.append(batch['synt_features'].numpy())

        cmt_all = np.vstack(cmt_list)
        synt_all = np.vstack(synt_list)

        # The small constant keeps constant features from producing a zero std.
        cmt_mean = torch.tensor(cmt_all.mean(axis=0), dtype=torch.float32)
        cmt_std = torch.tensor(cmt_all.std(axis=0) + 1e-6, dtype=torch.float32)
        synt_mean = torch.tensor(synt_all.mean(axis=0), dtype=torch.float32)
        synt_std = torch.tensor(synt_all.std(axis=0) + 1e-6, dtype=torch.float32)

        self.model.update_normalization_stats(cmt_mean, cmt_std, synt_mean, synt_std)
        print("‚úÖ Computed feature normalization stats")

    def _forward_batch(self, batch):
        """Run the model on one collated batch (no labels passed) and return its output."""
        return self.model(
            input_ids=batch['input_ids'].to(self.device),
            attention_mask=batch['attention_mask'].to(self.device),
            cmt_features=batch['cmt_features'].to(self.device),
            synt_features=batch['synt_features'].to(self.device)
        )

    def _run_epoch(self, loader, criterion, optimizer=None, clip_params=None):
        """Run one pass over ``loader``; train when ``optimizer`` is given, else evaluate.

        Returns:
            ``(mean_loss, accuracy_percent)`` for the pass.
        """
        training = optimizer is not None
        self.model.train(training)
        loss_sum, n_correct, n_seen = 0.0, 0, 0

        with torch.set_grad_enabled(training):
            for batch in loader:
                labels = batch['labels'].to(self.device)
                if training:
                    optimizer.zero_grad()

                outputs = self._forward_batch(batch)
                # Use the class-weighted criterion rather than the model's own
                # unweighted loss, so the computed class weights take effect.
                loss = criterion(outputs.logits, labels)

                if training:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(clip_params, max_norm=1.0)
                    optimizer.step()

                loss_sum += loss.item()
                predicted = torch.argmax(outputs.logits, dim=1)
                n_seen += labels.size(0)
                n_correct += (predicted == labels).sum().item()

        mean_loss = loss_sum / max(len(loader), 1)
        accuracy = 100.0 * n_correct / max(n_seen, 1)
        return mean_loss, accuracy

    def train(self, train_data, val_data, epochs=20, batch_size=16, lr=3e-4):
        """Train the non-MuRIL parameters and keep the best-validation checkpoint.

        Args:
            train_data: Sequence of ``(text, label)`` training pairs.
            val_data: Sequence of ``(text, label)`` validation pairs.
            epochs: Maximum number of epochs.
            batch_size: Batch size for both loaders.
            lr: AdamW learning rate.

        Training stops early after 7 epochs without validation-accuracy
        improvement. Afterwards the best checkpoint is reloaded (when one was
        written) and the label-inversion check runs on the validation set.
        """
        print("\nüîÑ Preparing datasets...")
        # Tokenise with the configured length instead of the dataset default.
        max_len = getattr(self.config, 'max_length', 64)
        train_dataset = TeluguMetaphorDataset(
            train_data, self.tokenizer, self.cmt_analyzer, self.synt_extractor, max_len=max_len
        )
        val_dataset = TeluguMetaphorDataset(
            val_data, self.tokenizer, self.cmt_analyzer, self.synt_extractor, max_len=max_len
        )

        self.compute_feature_stats(train_dataset)

        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

        # Class weights: inverse class frequency, rescaled to sum to 2
        all_labels = [lbl for _, lbl in train_dataset.data]
        counts = Counter(all_labels)
        weights = torch.tensor([1.0/max(counts.get(0,1),1), 1.0/max(counts.get(1,1),1)], dtype=torch.float32).to(self.device)
        weights = weights * (2.0 / weights.sum())

        print(f"üìä Class weights: Normal={weights[0]:.3f}, Metaphor={weights[1]:.3f}")

        criterion = nn.CrossEntropyLoss(weight=weights)

        # Optimizer (exclude MuRIL parameters)
        trainable_params = [p for n, p in self.model.named_parameters() if 'muril' not in n and p.requires_grad]
        optimizer = optim.AdamW(trainable_params, lr=lr, weight_decay=0.01)
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3)

        # Start below any reachable accuracy so the first epoch always writes a
        # checkpoint; with 0.0 a run stuck at 0% never saved and the reload
        # below failed (or picked up a stale checkpoint from an earlier run).
        best_val_acc = -1.0
        checkpoint_saved = False
        patience_counter = 0

        print("\nüöÄ Starting training...")
        for epoch in range(epochs):
            train_loss, train_acc = self._run_epoch(
                train_loader, criterion, optimizer=optimizer, clip_params=trainable_params
            )
            val_loss, val_acc = self._run_epoch(val_loader, criterion)
            scheduler.step(val_acc)

            print(f'Epoch [{epoch+1}/{epochs}] Train Loss: {train_loss:.4f} '
                  f'Train Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} Val Acc: {val_acc:.2f}%')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                self.save_model(self.checkpoint_path)
                checkpoint_saved = True
                print(f" ‚Üí ‚úÖ New best model saved! (Acc: {val_acc:.2f}%)")
            else:
                patience_counter += 1

            if patience_counter >= 7:
                print(f"\n‚èπÔ∏è  Early stopping after {epoch+1} epochs")
                break

        best_val_acc = max(best_val_acc, 0.0)
        print(f"\n‚úÖ Training complete! Best val acc: {best_val_acc:.2f}%")
        if checkpoint_saved:
            self.load_model(self.checkpoint_path)
        self._detect_and_fix_inversion(val_loader)

    def _detect_and_fix_inversion(self, val_loader):
        """Enable ``invert_output`` when swapped predictions score clearly better.

        Compares validation accuracy of the predictions with that of the
        flipped (``1 - pred``) predictions; if the flipped ones win by more
        than 5 percentage points the flag is set and the checkpoint re-saved.
        """
        self.model.eval()
        total, correct, inverted_correct = 0, 0, 0

        with torch.no_grad():
            for batch in val_loader:
                outputs = self._forward_batch(batch)
                preds = torch.argmax(outputs.logits, dim=-1)
                labels = batch['labels'].to(self.device)

                total += labels.size(0)
                correct += (preds == labels).sum().item()
                inverted_correct += ((1 - preds) == labels).sum().item()

        acc = 100.0 * correct / max(total, 1)
        inv_acc = 100.0 * inverted_correct / max(total, 1)

        print(f"\nüîç Label Inversion Check:")
        print(f"   Normal: {acc:.2f}% | Inverted: {inv_acc:.2f}%")

        if inv_acc > acc + 5.0:
            print("   ‚ö†Ô∏è  INVERSION DETECTED! Activating correction.")
            self.config.invert_output = True
            self.model.config.invert_output = True
            self.save_model(self.checkpoint_path)
        else:
            print("   ‚úÖ No inversion detected.")

    def predict(self, text):
        """Classify one sentence.

        Returns:
            Dict with the original ``text``, ``prediction`` (``'METAPHOR'`` or
            ``'NORMAL'``) and ``confidence`` (softmax probability of the
            predicted class).
        """
        self.model.eval()

        # Same cleaning as the dataset: keep U+0C00-U+0C7F and whitespace,
        # falling back to the raw text when nothing is left.
        text_clean = re.sub(r'[^\u0C00-\u0C7F\s]', '', text).strip() or text

        # Use the configured sequence length (64 by default) so inference
        # matches the length the model was configured with.
        max_len = getattr(self.config, 'max_length', 64)
        encoding = self.tokenizer(text_clean, max_length=max_len, padding='max_length', truncation=True, return_tensors='pt')
        cmt_feat = self.cmt_analyzer.extract_features(text_clean).unsqueeze(0)
        synt_feat = self.synt_extractor.extract_features(text_clean).unsqueeze(0)

        with torch.no_grad():
            outputs = self.model(
                input_ids=encoding['input_ids'].to(self.device),
                attention_mask=encoding['attention_mask'].to(self.device),
                cmt_features=cmt_feat.to(self.device),
                synt_features=synt_feat.to(self.device)
            )
            probs = torch.softmax(outputs.logits, dim=-1)
            pred_class = torch.argmax(probs, dim=-1).item()
            confidence = probs[0, pred_class].item()

        label = 'METAPHOR' if pred_class == 1 else 'NORMAL'
        return {'text': text, 'prediction': label, 'confidence': confidence}

    def save_model(self, save_directory):
        """Save model and tokenizer to ``save_directory`` using safetensors."""
        os.makedirs(save_directory, exist_ok=True)

        # Save model weights using safetensors
        self.model.save_pretrained(save_directory, safe_serialization=True)
        self.tokenizer.save_pretrained(save_directory)

        # Ensure only safetensors exists (remove .bin if created)
        bin_path = os.path.join(save_directory, 'pytorch_model.bin')
        if os.path.exists(bin_path):
            os.remove(bin_path)

        print(f"‚úÖ Model saved to {save_directory} (safetensors format)")

    def load_model(self, load_directory):
        """Rebuild the model from ``config.json`` + ``model.safetensors``.

        Raises:
            FileNotFoundError: If either file is missing in ``load_directory``.
        """
        # CRITICAL: Register custom model BEFORE any loading attempt
        from transformers import AutoConfig, AutoModelForSequenceClassification
        try:
            AutoConfig.register("telugu_metaphor_detector", TeluguMetaphorConfig, exist_ok=True)
            AutoModelForSequenceClassification.register(TeluguMetaphorConfig, TeluguMetaphorModel, exist_ok=True)
            print("‚úÖ Custom model type registered")
        except Exception as reg_error:
            print(f"‚ö†Ô∏è  Registration note: {reg_error}")

        # Manual loading approach (most reliable)
        config_path = os.path.join(load_directory, 'config.json')
        weights_path = os.path.join(load_directory, 'model.safetensors')

        if not os.path.exists(config_path):
            raise FileNotFoundError(f"Config file not found: {config_path}")

        if not os.path.exists(weights_path):
            raise FileNotFoundError(f"Safetensors file not found: {weights_path}")

        # Load config
        self.config = TeluguMetaphorConfig.from_pretrained(load_directory)
        print(f"‚úÖ Config loaded: model_type={self.config.model_type}")

        # Create model from config
        self.model = TeluguMetaphorModel(self.config)

        # Load weights from safetensors; strict=False reports mismatches
        # instead of raising, so they are printed below.
        state_dict = load_file(weights_path)
        missing_keys, unexpected_keys = self.model.load_state_dict(state_dict, strict=False)

        if missing_keys:
            print(f"‚ö†Ô∏è  Missing keys: {missing_keys}")
        if unexpected_keys:
            print(f"‚ö†Ô∏è  Unexpected keys: {unexpected_keys}")

        self.model.to(self.device)
        self.model.eval()
        print(f"‚úÖ Model loaded from {load_directory}")


# ===========================
# 6. DATASET LOADING
# ===========================

def _pairs_from_items(items):
    """Return ``(text, label)`` tuples from dict items or 2+-element sequences."""
    pairs = []
    for item in items:
        if isinstance(item, dict) and 'text' in item and 'label' in item:
            pairs.append((item['text'], item['label']))
        elif isinstance(item, (list, tuple)) and len(item) >= 2:
            pairs.append((item[0], item[1]))
    return pairs


def load_dataset_from_file(filepath):
    """Load ``(text, label)`` pairs from a .json, .csv, .tsv, .txt or .py file.

    Formats:
        * json: a list of ``{"text": ..., "label": ...}`` dicts or 2+-element lists.
        * csv/tsv: optional header; when a ``text`` column is named, it and
          the ``label`` column (or the column right after ``text``) are used,
          otherwise columns 0 and 1 and the first row counts as data.
        * txt: one ``text<TAB>label`` or ``text,label`` pair per line.
        * py: the module is executed and the first of the variables ``data``,
          ``dataset``, ``telugu_metaphor_data``, ``test_sentences`` is read.

    Files are read as ``utf-8-sig`` so a leading byte-order mark does not
    corrupt the first header cell / first text or break JSON parsing.

    Returns:
        List of ``(text, label)`` tuples, or ``[]`` when anything goes wrong
        (the error is printed rather than raised).
    """
    import csv
    import sys

    data = []
    ext = os.path.splitext(filepath)[1].lower().lstrip('.')

    try:
        if ext == 'json':
            with open(filepath, 'r', encoding='utf-8-sig') as f:
                data.extend(_pairs_from_items(json.load(f)))

        elif ext in ('csv', 'tsv'):
            delimiter = '\t' if ext == 'tsv' else ','
            with open(filepath, 'r', encoding='utf-8-sig', newline='') as f:
                reader = csv.reader(f, delimiter=delimiter)
                header = next(reader, None)

                text_idx, label_idx = 0, 1
                header_lower = [h.strip().lower() for h in header] if header else []
                if 'text' in header_lower:
                    text_idx = header_lower.index('text')
                    label_idx = header_lower.index('label') if 'label' in header_lower else text_idx + 1
                elif header and len(header) >= 2:
                    # No recognisable header: the first row is data.
                    data.append((header[0], header[1]))

                # Skip rows too short for the selected columns; indexing them
                # used to raise IndexError and discard the whole file.
                highest_idx = max(text_idx, label_idx)
                for row in reader:
                    if len(row) > highest_idx:
                        data.append((row[text_idx], row[label_idx]))

        elif ext == 'txt':
            with open(filepath, 'r', encoding='utf-8-sig') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    parts = line.split('\t') if '\t' in line else line.split(',', 1)
                    if len(parts) >= 2:
                        data.append((parts[0].strip(), parts[1].strip()))

        elif ext == 'py':
            # SECURITY: this imports and therefore EXECUTES the given file.
            # Only use it with dataset modules you trust.
            module_name = os.path.splitext(os.path.basename(filepath))[0]
            spec = importlib.util.spec_from_file_location(module_name, filepath)
            module = importlib.util.module_from_spec(spec)
            sys.modules[module_name] = module
            spec.loader.exec_module(module)

            # Only the first variable name that exists is considered.
            for var_name in ['data', 'dataset', 'telugu_metaphor_data', 'test_sentences']:
                if hasattr(module, var_name):
                    loaded_data = getattr(module, var_name)
                    if isinstance(loaded_data, list):
                        data.extend(_pairs_from_items(loaded_data))
                    break

        else:
            raise ValueError(f"Unsupported file format: {ext}. Use .json, .csv, .tsv, .txt, or .py")

        if not data:
            raise ValueError(f"No valid data found in {filepath}")

        print(f"‚úÖ Loaded {len(data)} samples from {filepath}")
        return data

    except Exception as e:
        # Deliberate best-effort boundary: callers check for an empty list.
        print(f"‚ùå Error loading dataset: {e}")
        return []


# ===========================
# 7. INTERACTIVE INTERFACE
# ===========================

def interactive_prediction_mode(pipeline):
    """Prompt for sentences on stdin and print the pipeline's prediction.

    The loop ends on ``quit`` / ``exit`` / ``q``, on Ctrl-C, or when stdin is
    exhausted. Errors raised while predicting are printed and the loop
    continues.

    Args:
        pipeline: Object with ``predict(text)`` returning a dict with the
            keys ``text``, ``prediction`` and ``confidence``.
    """
    print("\n" + "="*60)
    print("üéØ TELUGU METAPHOR DETECTOR - INTERACTIVE MODE")
    print("="*60)
    print("\nEnter Telugu sentences to check if they contain metaphors.")
    print("Commands: 'quit' or 'exit' to stop\n")

    while True:
        try:
            user_input = input("Enter a Telugu sentence: ").strip()

            if not user_input:
                print("‚ö†Ô∏è  Please enter a sentence.\n")
                continue

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("\nüëã Exiting interactive mode. Goodbye!")
                break

            # Make prediction
            result = pipeline.predict(user_input)

            # Display results
            print("\n" + "-"*60)
            print(f"üìù Sentence: {result['text']}")
            print(f"üîÆ Prediction: {result['prediction']}")
            print(f"üìä Confidence: {result['confidence']*100:.2f}%")

            # Visual indicator
            if result['prediction'] == 'METAPHOR':
                print("üé≠ This sentence contains metaphorical language!")
            else:
                print("üìñ This sentence is literal/normal.")
            print("-"*60 + "\n")

        except (KeyboardInterrupt, EOFError):
            # EOFError means stdin is closed/exhausted; treating it like any
            # other error would re-prompt forever, so exit like Ctrl-C does.
            print("\n\nüëã Exiting interactive mode. Goodbye!")
            break
        except Exception as e:
            print(f"‚ùå Error: {e}\n")


def batch_prediction_from_file(pipeline, filepath):
    """Predict every non-empty line of a text file and print a summary.

    Args:
        pipeline: Object with ``predict(text)`` returning a dict with the
            keys ``text``, ``prediction`` and ``confidence``.
        filepath: Path of a text file with one sentence per line (read as
            ``utf-8-sig`` so a leading byte-order mark is ignored).

    Returns:
        List of prediction dicts in file order; ``[]`` when the file holds no
        sentences or when any error occurs (the error is printed).
    """
    print(f"\nüìÇ Loading sentences from {filepath}...")

    try:
        with open(filepath, 'r', encoding='utf-8-sig') as f:
            sentences = [line.strip() for line in f if line.strip()]

        print(f"‚úÖ Found {len(sentences)} sentences\n")

        if not sentences:
            # Nothing to classify; the percentage summary below would divide
            # by zero and be misreported as a file-reading error.
            print("No sentences to classify.")
            return []

        print("="*60)

        results = []
        for i, sentence in enumerate(sentences, 1):
            result = pipeline.predict(sentence)
            results.append(result)

            print(f"\n[{i}/{len(sentences)}]")
            print(f"üìù Sentence: {result['text']}")
            print(f"üîÆ Prediction: {result['prediction']} (Confidence: {result['confidence']*100:.2f}%)")

        print("\n" + "="*60)
        print(f"‚úÖ Completed predictions for {len(sentences)} sentences")

        # Summary
        metaphor_count = sum(1 for r in results if r['prediction'] == 'METAPHOR')
        normal_count = len(results) - metaphor_count
        print(f"\nüìä Summary:")
        print(f"   Metaphors: {metaphor_count} ({metaphor_count/len(results)*100:.1f}%)")
        print(f"   Normal: {normal_count} ({normal_count/len(results)*100:.1f}%)")

        return results

    except Exception as e:
        print(f"‚ùå Error reading file: {e}")
        return []


# ===========================
# 8. MAIN EXECUTION
# ===========================

def _prompt_positive_int(prompt, default):
    """Ask for a positive integer; return *default* on blank or invalid input."""
    raw = input(prompt).strip()
    if not raw:
        return default
    try:
        value = int(raw)
    except ValueError:
        value = 0  # routed into the "invalid" branch below
    if value < 1:
        print(f"Invalid value '{raw}'; using default {default}.")
        return default
    return value


def main():
    """Run the interactive console menu.

    Offers three modes:
      1. Train a new model from a dataset file, then optionally predict.
      2. Load a saved model directory and predict interactively or from a file.
      3. Predict interactively with the freshly constructed (untrained) model.

    All input is read with ``input()``; the function returns ``None`` and
    reports problems by printing a message rather than raising.
    """
    print("\n" + "="*70)
    print("üé≠ TELUGU METAPHOR DETECTION SYSTEM")
    print("="*70)

    # Register custom model type early
    register_custom_model()

    # Menu
    print("\nWhat would you like to do?")
    print("1. Train a new model")
    print("2. Load existing model and predict")
    print("3. Interactive prediction mode (no training)")

    choice = input("\nEnter your choice (1/2/3): ").strip()

    # Reject a bad choice before building the pipeline, so a typo does not
    # pay for pipeline/model initialisation first.
    if choice not in ('1', '2', '3'):
        print("‚ùå Invalid choice. Exiting.")
        return

    # Initialize pipeline
    pipeline = TeluguMetaphorPipeline()

    if choice == '1':
        # Training mode
        print("\nüìö TRAINING MODE")
        print("-" * 70)

        dataset_path = input("Enter path to your dataset file (.json, .csv, .txt, .py): ").strip()

        if not os.path.exists(dataset_path):
            print(f"‚ùå File not found: {dataset_path}")
            return

        # Load dataset
        data = load_dataset_from_file(dataset_path)

        if not data:
            print("‚ùå No data loaded. Exiting.")
            return

        # Split data
        train_data, val_data = train_test_split(data, test_size=0.2, random_state=42)
        print(f"üìä Split: {len(train_data)} training, {len(val_data)} validation samples")

        # Training parameters: a non-numeric or non-positive entry falls back
        # to the default instead of aborting with a ValueError traceback.
        epochs = _prompt_positive_int("Enter number of epochs (default 20): ", 20)
        batch_size = _prompt_positive_int("Enter batch size (default 16): ", 16)

        # Train
        pipeline.train(train_data, val_data, epochs=epochs, batch_size=batch_size)

        # Ask for interactive mode
        continue_pred = input("\nüéØ Enter interactive prediction mode? (y/n): ").strip().lower()
        if continue_pred == 'y':
            interactive_prediction_mode(pipeline)

    elif choice == '2':
        # Load existing model
        print("\nüìÇ LOAD MODEL MODE")
        print("-" * 70)

        model_path = input("Enter path to saved model directory (default: telugu_metaphor_best): ").strip()
        model_path = model_path or 'telugu_metaphor_best'

        if not os.path.exists(model_path):
            print(f"‚ùå Model directory not found: {model_path}")
            return

        # Guard only the load itself, so "Error loading model" is never
        # printed for a problem that occurs later during prediction.
        try:
            pipeline.load_model(model_path)
        except Exception as e:
            print(f"‚ùå Error loading model: {e}")
            return
        print("\n‚úÖ Model loaded successfully!")

        # Prediction mode
        print("\nChoose prediction mode:")
        print("1. Interactive (one sentence at a time)")
        print("2. Batch (from text file)")

        pred_choice = input("\nEnter choice (1/2): ").strip()

        if pred_choice == '1':
            interactive_prediction_mode(pipeline)
        elif pred_choice == '2':
            file_path = input("Enter path to text file with sentences: ").strip()
            batch_prediction_from_file(pipeline, file_path)
        else:
            print("‚ùå Invalid choice")

    else:
        # choice == '3': interactive mode without training
        print("\nüéØ INTERACTIVE MODE (Using untrained model)")
        print("-" * 70)
        print("‚ö†Ô∏è  Warning: Model is not trained. Predictions may be random.")
        print("    For meaningful predictions, please train the model first (option 1)")

        continue_anyway = input("\nContinue anyway? (y/n): ").strip().lower()
        if continue_anyway == 'y':
            interactive_prediction_mode(pipeline)

# Start the interactive menu only when this module is the entry point
# (script or notebook cell); importing the module does not call main().
if __name__ == "__main__":
    main()


üé≠ TELUGU METAPHOR DETECTION SYSTEM

What would you like to do?
1. Train a new model
2. Load existing model and predict
3. Interactive prediction mode (no training)

Enter your choice (1/2/3): 1
Initializing Telugu Metaphor Detection Pipeline...
Using device: cuda


tokenizer_config.json:   0%|          | 0.00/206 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/411 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/113 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/953M [00:00<?, ?B/s]

‚úÖ Model initialized successfully
‚ö†Ô∏è  spaCy model not available (syntactic features will be zeros)

üìö TRAINING MODE
----------------------------------------------------------------------
Enter path to your dataset file (.json, .csv, .txt, .py): /content/telugu_metaphor_dataset.py
‚úÖ Loaded 1000 samples from /content/telugu_metaphor_dataset.py
üìä Split: 800 training, 200 validation samples
Enter number of epochs (default 20): 20
Enter batch size (default 16): 16

üîÑ Preparing datasets...
‚úÖ Computed feature normalization stats
üìä Class weights: Normal=0.997, Metaphor=1.002

üöÄ Starting training...
Epoch [1/20] Train Loss: 0.7376 Train Acc: 52.38% | Val Loss: 0.6621 Val Acc: 67.50%
‚úÖ Model saved to telugu_metaphor_best (safetensors format)
 ‚Üí ‚úÖ New best model saved! (Acc: 67.50%)
Epoch [2/20] Train Loss: 0.6898 Train Acc: 56.88% | Val Loss: 0.6061 Val Acc: 59.50%
Epoch [3/20] Train Loss: 0.2146 Train Acc: 92.50% | Val Loss: 0.0359 Val Acc: 99.00%
‚úÖ Model saved t

In [None]:
from google.colab import files
import shutil

# Saved-model directory and the archive it gets packed into
folder_path = "/content/telugu_metaphor_best"
zip_path = "/content/telugu_metaphor_model_1.zip"

# make_archive adds the ".zip" extension itself, so it is given the
# target path with the extension removed.
shutil.make_archive(
    base_name=zip_path.replace(".zip", ""),
    format='zip',
    root_dir=folder_path,
)

# Hand the finished archive to Colab's download helper
files.download(zip_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>