In [2]:
# %pip install -r requirements.txt

In [3]:
import os
import json
import pandas as pd
import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import WordNetLemmatizer
from utils.vector_db import VectorDB
from chromadb import EmbeddingFunction
from tqdm import tqdm


import torch
from transformers import BertTokenizer, BertModel

In [4]:
# Select the compute device: CUDA when PyTorch can see a GPU, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Record the GPU details in the cell output.
# NOTE: `total_memory` is the device's total capacity, although the label says "Available".
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Available GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")

# Load the pre-trained BERT tokenizer and encoder, and move the encoder to `device`.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained('bert-base-uncased').to(device)

# On GPU, cast the encoder weights to FP16. This is a plain half-precision cast used
# for embedding extraction (inference); no training happens with this `model` object
# in the cells shown here.
if torch.cuda.is_available():
    model.half()  # Use FP16 for faster inference
    print("Model loaded with FP16 precision for faster GPU processing")
else:
    print("GPU not available, using CPU")

# Confirm where the encoder's parameters actually live.
print(f"Model loaded on: {next(model.parameters()).device}")

Using device: cuda
GPU: NVIDIA GeForce RTX 3070
CUDA Version: 12.6
Available GPU memory: 8.0 GB
Model loaded with FP16 precision for faster GPU processing
Model loaded on: cuda:0
Model loaded with FP16 precision for faster GPU processing
Model loaded on: cuda:0


In [5]:
# nltk.download('punkt_tab')
# nltk.download('stopwords')
# nltk.download('wordnet')
# nltk.download('omw-1.4')

In [6]:
def read_text(file_path):
    """Return the entire contents of ``file_path``, decoded as UTF-8."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        contents = handle.read()
    return contents

def load_or_create_paired_df(data_dir, csv_path, has_real=True):
    """
    Load the paired-text dataframe from ``csv_path``, building and caching it on first use.

    If ``csv_path`` exists it is read and returned as-is. Otherwise every ``article_*``
    folder in ``data_dir`` is read (``file_1.txt`` and ``file_2.txt``) in sorted
    folder-name order, the result is written to ``csv_path`` without the index, and the
    freshly built dataframe is returned.

    Args:
        data_dir: Directory holding one ``article_<id>`` sub-folder per text pair.
        csv_path: Location of the cached CSV.
        has_real: When True, add a ``real`` column looked up by article id in
            ``<parent_of_data_dir>/train.csv`` (columns ``id`` and ``real``). Ids that
            are missing from that file get NaN.

    Returns:
        pandas.DataFrame with columns ``text_1``, ``text_2`` and, if ``has_real``, ``real``.
    """
    if os.path.exists(csv_path):
        return pd.read_csv(csv_path)

    real_by_id = {}
    if has_real:
        # Resolve the label file next to data_dir (e.g. "data/train" -> "data/train.csv"),
        # as the docstring promises, instead of a path hard-coded relative to the CWD.
        labels_path = os.path.join(os.path.dirname(os.path.normpath(data_dir)), "train.csv")
        real_df = pd.read_csv(labels_path)
        # One id -> real mapping replaces a full-frame filter per article.
        # drop_duplicates keeps the first occurrence of an id.
        real_by_id = real_df.drop_duplicates("id").set_index("id")["real"].to_dict()

    # Only article_* folders are data; stray folders such as ".ipynb_checkpoints"
    # would otherwise break the id parsing or the file reads below.
    article_dirs = sorted(
        name for name in os.listdir(data_dir)
        if name.startswith("article_") and os.path.isdir(os.path.join(data_dir, name))
    )

    rows = []
    for article_dir in article_dirs:
        article_path = os.path.join(data_dir, article_dir)
        row = {
            "text_1": read_text(os.path.join(article_path, "file_1.txt")),
            "text_2": read_text(os.path.join(article_path, "file_2.txt")),
        }

        if has_real:
            article_id = int(article_dir.split("_")[1])
            row["real"] = real_by_id.get(article_id, np.nan)

        rows.append(row)

    df = pd.DataFrame(rows)
    df.to_csv(csv_path, index=False)

    return df

# Usage: raw article folders and the CSV caches built from them.
train_data_dir = "data/train"
test_data_dir  = "data/test"
train_csv = "data/stored_train_data.csv"
test_csv  = "data/stored_test_data.csv"

# Training pairs carry the `real` label column; test pairs are loaded without it.
# Each call reads its cached CSV when that file already exists.
paired_df = load_or_create_paired_df(train_data_dir, train_csv, has_real=True)
test_df   = load_or_create_paired_df(test_data_dir,  test_csv,  has_real=False)

In [7]:
# Lazily filled cache so the stop-word list and lemmatizer are built once per kernel,
# not once per cleaned document.
_CLEAN_TEXT_CACHE = {}


def _clean_text_resources():
    """Return the (stop-word set, lemmatizer) pair, creating it on first use only."""
    if not _CLEAN_TEXT_CACHE:
        _CLEAN_TEXT_CACHE["stop_words"] = frozenset(stopwords.words('english'))
        _CLEAN_TEXT_CACHE["lemmatizer"] = WordNetLemmatizer()
    return _CLEAN_TEXT_CACHE["stop_words"], _CLEAN_TEXT_CACHE["lemmatizer"]


def clean_text(text):
    """
    Normalise ``text`` for embedding: lowercase, tokenize, filter, lemmatize, re-join.

    Args:
        text: Raw document text. Any non-string value (e.g. NaN from a CSV) is treated
            as missing.

    Returns:
        A single space-separated string of lemmatized tokens, keeping only alphanumeric
        tokens that are not English stop words. Returns "" for non-string input.
    """
    if not isinstance(text, str):
        return ""

    stop_words, lemmatizer = _clean_text_resources()

    # Tokenize the lowercased text, then drop punctuation-like tokens and stop words.
    tokens = word_tokenize(text.lower())
    kept_tokens = [word for word in tokens if word.isalnum() and word not in stop_words]

    # Lemmatize what is left and join it back into one cleaned string.
    return ' '.join(lemmatizer.lemmatize(word) for word in kept_tokens)


def clean_df(df):
    """Add ``cleaned_text_1`` / ``cleaned_text_2`` columns to ``df`` in place and return it.

    Each new column is ``clean_text`` applied element-wise to the matching ``text_*`` column.
    """
    column_pairs = (('text_1', 'cleaned_text_1'), ('text_2', 'cleaned_text_2'))
    for source_col, target_col in column_pairs:
        df[target_col] = df[source_col].apply(clean_text)
    return df

# Add the cleaned-text columns to the training pairs, then preview the first rows.
paired_df = clean_df(paired_df)
paired_df.head()

Unnamed: 0,text_1,text_2,real,cleaned_text_1,cleaned_text_2
0,The VIRSA (Visible Infrared Survey Telescope A...,The China relay network has released a signifi...,1,virsa visible infrared survey telescope array ...,china relay network released significant amoun...
1,China\nThe goal of this project involves achie...,The project aims to achieve an accuracy level ...,2,china goal project involves achieving accuracy...,project aim achieve accuracy level dex analyzi...
2,Scientists can learn about how galaxies form a...,Dinosaur eggshells offer clues about what dino...,1,scientist learn galaxy form evolve two method ...,dinosaur eggshell offer clue dinosaur ate long...
3,China\nThe study suggests that multiple star s...,The importance for understanding how stars evo...,2,china study suggests multiple star system play...,importance understanding star evolve led resea...
4,Dinosaur Rex was excited about his new toy set...,Analyzing how fast stars rotate within a galax...,2,dinosaur rex excited new toy set many dinosaur...,analyzing fast star rotate within galaxy compa...


In [8]:
# Apply the same cleaning to the test pairs.
test_df = clean_df(test_df)

In [9]:
def extract_bert_embeddings(text, device=None):
    """
    Encode ``text`` with the notebook-level BERT ``model`` and return its token embeddings.

    Args:
        text: Input passed straight to ``tokenizer``; anything beyond 512 tokens is truncated.
        device: ``torch.device`` or device string to run on. Defaults to the device that
            holds ``model``'s parameters.

    Returns:
        The encoder's ``last_hidden_state`` for the tokenized input, moved to the CPU.
    """
    # Accept both torch.device objects and plain strings such as "cuda".
    device = next(model.parameters()).device if device is None else torch.device(device)

    # Tokenize input text and move every tensor to the target device.
    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        if device.type == 'cuda':
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast().
            with torch.autocast(device_type='cuda'):
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

    # Move back to the CPU so callers never hold on to GPU memory.
    return outputs.last_hidden_state.cpu()

class MyEmbeddingFunction(EmbeddingFunction):
    """Embedding function that maps each input text to the encoder's first-token ([CLS]) vector."""

    def __init__(self, model, tokenizer, device=None):
        """
        Args:
            model: Encoder whose output exposes ``last_hidden_state``.
            tokenizer: Tokenizer matching ``model``.
            device: ``torch.device`` or device string; defaults to the device of
                ``model``'s parameters.
        """
        self.model = model
        self.tokenizer = tokenizer
        # Accept both torch.device objects and plain strings such as "cuda".
        self.device = torch.device(device) if device is not None else next(model.parameters()).device

    def __call__(self, input: list) -> list:
        """
        Embed a list of strings.

        Args:
            input: List of texts; each is truncated to 512 tokens.

        Returns:
            One embedding per input text, each a plain list of floats, in input order.
            An empty ``input`` yields an empty list.
        """
        embeddings = []

        # Larger batches on GPU for better utilization, small ones on CPU.
        batch_size = 16 if self.device.type == 'cuda' else 4

        for start in range(0, len(input), batch_size):
            batch_texts = input[start:start + batch_size]

            inputs = self.tokenizer(
                batch_texts,
                return_tensors='pt',
                truncation=True,
                padding=True,
                max_length=512
            )
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                if self.device.type == 'cuda':
                    # torch.autocast replaces the deprecated torch.cuda.amp.autocast().
                    with torch.autocast(device_type='cuda'):
                        outputs = self.model(**inputs)
                else:
                    outputs = self.model(**inputs)

            # Position 0 of every sequence is the [CLS] token; use it as the text embedding.
            cls_embeddings = outputs.last_hidden_state[:, 0, :].cpu().numpy()
            embeddings.extend(cls_embeddings.tolist())

        return embeddings

# Smoke-test the embedding helper on one short sentence and report the resulting
# tensor shape together with the device the encoder lives on.
sample_embedding = extract_bert_embeddings("Sample text for embedding.")
print(f"Sample embedding shape: {sample_embedding.shape}")
print(f"Model device: {next(model.parameters()).device}")

# Release cached GPU memory blocks held by PyTorch's allocator (CUDA only).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

  with torch.cuda.amp.autocast():  # Use automatic mixed precision


Sample embedding shape: torch.Size([1, 9, 768])
Model device: cuda:0


In [10]:
# documents = []
# for idx, row in paired_df.iterrows():
#     if str(row['cleaned_text_1']).strip():
#         documents.append({
#             "id": f"{idx}_1",
#             "content": row['cleaned_text_1'],
#             "metadata": {"real": row["real"] == 1}
#         })
#     if str(row['cleaned_text_2']).strip():
#         documents.append({
#             "id": f"{idx}_2",
#             "content": row['cleaned_text_2'],
#             "metadata": {"real": row["real"] == 2}
#         })

# # Delete the existing collection if it exists (to fix dimension mismatch)
# rebuild_collection = False
# if rebuild_collection:
#     vector_db_tmp = VectorDB(
#         collection_name="impostor_hunt_texts",
#         embedding_length=384,
#         working_dir=os.getcwd()
#     )
#     vector_db_tmp.delete_collection()

# embedding_function = MyEmbeddingFunction(model, tokenizer)


# # Initialize VectorDB (embedding_function can be left as None to use default)
# vector_db = VectorDB(
#     collection_name="impostor_hunt_texts",
#     embedding_length=768,
#     working_dir=os.getcwd(),
#     documents=documents,
#     dont_add_if_collection_exist=not rebuild_collection
# )

# vector_db.search("""ChromeDriver music player
# This study focused on identifying any non-spherical shapes within specific types of celestial bodies (music music) using various techniques like comparing how they look from different directions and analyzing their changes in sound pressure vs time .
# The extent to which these artists' images show evidence for an overall shape rather than individual tracks was found across multiple tracks:
# Two specific songs had clearly visible distortions due to their complex structure compared to others playing just simple beats
# This research found that while most recordings showed a relatively simple structure (like when you only see one instrument rather than an entire grand orchestra), some featured noticeable deviations from those expectations (like if there were multiple instruments playing at once). These results suggest there may be a correlation between how musicians program their compositions and how much curvature they chose for their soundscape — it seems as though tracks with more intricate arrangements tend towards greater complexity!
# Please note: This is just an example response based on your input text as I am not able access real world information such as music information or even what "music music" means without further context!
# Let me know if you want me to try working through some real world examples instead? I can also provide alternative ways I could rephrase your initial statement!""")

In [11]:
# --- Late Chunking for 'real' and 'not real' groups ---
# Split every training pair into two documents (ids "<row>_1" and "<row>_2") and sort
# them into `real_docs` / `not_real_docs` according to the pair's `real` label.
# The same information is stored per document in metadata["real"].
real_docs = []
not_real_docs = []
for idx, row in paired_df.iterrows():
    text_1 = row['cleaned_text_1']
    text_2 = row['cleaned_text_2']
    # Only process if text_1 is a string and not empty
    if isinstance(text_1, str) and text_1.strip():
        doc = {
            "id": f"{idx}_1",
            "content": text_1,
            "metadata": {"real": row["real"] == 1}
        }
        if row["real"] == 1:
            real_docs.append(doc)
        else:
            not_real_docs.append(doc)
    # Only process if text_2 is a string and not empty
    if isinstance(text_2, str) and text_2.strip():
        doc = {
            "id": f"{idx}_2",
            "content": text_2,
            "metadata": {"real": row["real"] == 2}
        }
        if row["real"] == 2:
            real_docs.append(doc)
        else:
            not_real_docs.append(doc)

# Delete the existing collection if it exists (to fix dimension mismatch)
# NOTE(review): this branch deletes "impostor_hunt_texts" (embedding_length=384), while
# the collection opened below is "impostor_hunt_texts_real" (embedding_length=768).
# Confirm which collection a rebuild is meant to drop before enabling the flag.
rebuild_collection = False
if rebuild_collection:
    vector_db_tmp = VectorDB(
        collection_name="impostor_hunt_texts",
        embedding_length=384,
        working_dir=os.getcwd()
    )
    vector_db_tmp.delete_collection()


# Open the collection that is searched later on. Despite the `_real` suffix, both the
# real and the not-real documents are added to this one collection when rebuilding;
# metadata["real"] is what tells them apart.
vector_db_real = VectorDB(
    collection_name="impostor_hunt_texts_real",
    embedding_length=768,
    working_dir=os.getcwd(),
    # embedding_function=embedding_function
)

# Documents are only (re-)added when a rebuild is requested; with the flag off the
# collection is presumably expected to exist already from an earlier run (confirm).
if rebuild_collection:
    vector_db_real.add_documents_with_late_chunking(real_docs, chunk_size=1500, chunk_overlap=200, max_context=8192)
    vector_db_real.add_documents_with_late_chunking(not_real_docs, chunk_size=1500, chunk_overlap=200, max_context=8192)

# Module-level default for how many results each search requests.
search_limit = 20

# count real/fake
def count_real_fake(results, search_limit):
    """Return the number of results flagged as real, divided by ``search_limit``.

    Args:
        results: Search hits, each a mapping with a boolean at ``['metadata']['real']``.
        search_limit: Number of hits that was requested. This is the denominator, not
            ``len(results)``, so a short result list lowers the score.

    Returns:
        ``real_hits / search_limit`` as a float.
    """
    flagged_real = [hit for hit in results if hit['metadata']['real']]
    return len(flagged_real) / search_limit


In [12]:
def get_cls_embedding(text, device=None):
    """
    Return the first-token ([CLS]) embedding of ``text`` from the notebook-level BERT ``model``.

    Args:
        text: Input passed to ``tokenizer``; anything beyond 512 tokens is truncated.
        device: ``torch.device`` or device string; defaults to the device of ``model``'s
            parameters.

    Returns:
        numpy array with the [CLS] vector, with singleton dimensions squeezed out.
    """
    # Accept both torch.device objects and plain strings such as "cuda".
    device = next(model.parameters()).device if device is None else torch.device(device)

    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        if device.type == 'cuda':
            # torch.autocast replaces the deprecated torch.cuda.amp.autocast().
            with torch.autocast(device_type='cuda'):
                outputs = model(**inputs)
        else:
            outputs = model(**inputs)

    # Position 0 of the sequence is the [CLS] token.
    return outputs.last_hidden_state[:, 0, :].squeeze().cpu().numpy()

def _embed_cls_in_batches(texts, batch_size, device, desc):
    """Return one [CLS] embedding (numpy row) per entry of ``texts``, encoded batch by batch."""
    cls_rows = []
    for start in tqdm(range(0, len(texts), batch_size), desc=desc):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(
            batch_texts,
            return_tensors='pt',
            truncation=True,
            padding=True,
            max_length=512
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            if device.type == 'cuda':
                # torch.autocast replaces the deprecated torch.cuda.amp.autocast().
                with torch.autocast(device_type='cuda'):
                    outputs = model(**inputs)
            else:
                outputs = model(**inputs)

        cls_rows.extend(outputs.last_hidden_state[:, 0, :].cpu().numpy())
    return cls_rows


def get_features_gpu_optimized(df, vector_db_real, search_limit=20, batch_size=8):
    """
    GPU-optimized feature extraction with batching.

    For every row whose ``text_1`` and ``text_2`` are both strings, the feature vector is
    ``[emb1, emb2, score1, score2, emb1 - emb2]``, where ``emb*`` are the [CLS] embeddings
    of the cleaned texts and ``score*`` are ``count_real_fake`` scores of a vector-DB
    search on the raw texts.

    Args:
        df: DataFrame with ``text_1``, ``text_2``, ``cleaned_text_1``, ``cleaned_text_2``
            and optionally ``real``.
        vector_db_real: Object exposing ``search(text, limit=...)``.
        search_limit: Number of search results requested per text.
        batch_size: Number of texts encoded per forward pass.

    Returns:
        ``(features, labels)`` as numpy arrays; ``labels`` is empty when ``df`` has no
        ``real`` column.
    """
    device = next(model.parameters()).device

    # Keep the rows themselves rather than their index labels: looking labels up again
    # with positional .iloc returns the wrong row whenever the index is not 0..n-1.
    valid_rows = [
        row for _, row in df.iterrows()
        if isinstance(row['text_1'], str) and isinstance(row['text_2'], str)
    ]

    print(f"Processing {len(valid_rows)} valid text pairs...")

    # Batch process embeddings for better GPU utilization.
    all_emb1 = _embed_cls_in_batches(
        [row['cleaned_text_1'] for row in valid_rows], batch_size, device,
        "Processing text_1 embeddings",
    )
    all_emb2 = _embed_cls_in_batches(
        [row['cleaned_text_2'] for row in valid_rows], batch_size, device,
        "Processing text_2 embeddings",
    )

    # Clear GPU cache after batch processing
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    has_labels = 'real' in df.columns
    features = []
    labels = []

    # Now process RAG scores and combine features (total= lets tqdm show real progress).
    for i, row in tqdm(enumerate(valid_rows), total=len(valid_rows),
                       desc="Extracting RAG scores and combining features"):
        emb1 = all_emb1[i]
        emb2 = all_emb2[i]

        score1 = count_real_fake(vector_db_real.search(row['text_1'], limit=search_limit), search_limit)
        score2 = count_real_fake(vector_db_real.search(row['text_2'], limit=search_limit), search_limit)

        features.append(np.concatenate([emb1, emb2, [score1, score2], emb1 - emb2]))

        if has_labels:
            labels.append(1 if row['real'] == 1 else 2)

    return np.array(features), np.array(labels)

# Keep the original function as backup
def get_features(df, vector_db_real, search_limit=20):
    """Row-by-row feature extraction (unbatched counterpart of the GPU-optimized version).

    Rows whose ``text_1`` or ``text_2`` is not a string are skipped. Each remaining row
    yields ``[emb1, emb2, score1, score2, emb1 - emb2]`` and, when the row has a ``real``
    entry, a label of 1 (``real == 1``) or 2 (anything else).

    Returns:
        ``(features, labels)`` as numpy arrays.
    """
    feature_rows, label_values = [], []
    progress = tqdm(df.iterrows(), total=len(df), desc="Extracting features")
    for _, pair in progress:
        cleaned_1, cleaned_2 = pair['cleaned_text_1'], pair['cleaned_text_2']
        raw_1, raw_2 = pair['text_1'], pair['text_2']

        # Both raw texts must be strings; otherwise the pair is ignored.
        if not (isinstance(raw_1, str) and isinstance(raw_2, str)):
            continue

        vec_1 = get_cls_embedding(cleaned_1)
        vec_2 = get_cls_embedding(cleaned_2)

        hits_1 = vector_db_real.search(raw_1, limit=search_limit)
        rag_1 = count_real_fake(hits_1, search_limit)
        hits_2 = vector_db_real.search(raw_2, limit=search_limit)
        rag_2 = count_real_fake(hits_2, search_limit)

        feature_rows.append(np.concatenate([vec_1, vec_2, [rag_1, rag_2], vec_1 - vec_2]))

        if 'real' in pair:
            label_values.append(1 if pair['real'] == 1 else 2)

    return np.array(feature_rows), np.array(label_values)

In [13]:
# --- Prepare train/test features ---
# Uses the row-by-row extractor. The test frame is loaded without a `real` column, so
# its labels array is discarded.
X_train, y_train = get_features(paired_df, vector_db_real, search_limit=20)
X_test, _ = get_features(test_df, vector_db_real, search_limit=20)

  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
Extracting features: 100%|██████████| 95/95 [00:11<00:00,  8.46it/s]
Extracting features: 100%|██████████| 95/95 [00:11<00:00,  8.46it/s]
Extracting features: 100%|██████████| 1068/1068 [02:04<00:00,  8.61it/s]
Extracting features: 100%|██████████| 1068/1068 [02:04<00:00,  8.61it/s]


In [15]:
# Fix BERTClassifier training issues
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import accuracy_score, classification_report
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

print("Setting up BERTClassifier training...")

class BERTClassifier(nn.Module):
    """
    Text-pair classifier built on a shared BERT encoder.

    Each text of the pair is encoded separately by the same encoder; the two pooled
    outputs are concatenated, passed through dropout and mapped to ``num_classes`` logits.
    """
    def __init__(self, model_name='bert-base-uncased', num_classes=2, dropout=0.3, freeze_bert=False):
        super().__init__()
        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        # Two pooled vectors are concatenated, hence twice the encoder's hidden size.
        pair_width = self.bert.config.hidden_size * 2
        self.classifier = nn.Linear(pair_width, num_classes)

        if freeze_bert:
            # Exclude every encoder weight from gradient computation.
            self.bert.requires_grad_(False)

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        """Return the class logits for a batch of text pairs."""
        sides = ((input_ids_1, attention_mask_1), (input_ids_2, attention_mask_2))
        # Encode text 1, then text 2, keeping only the encoder's pooled output for each.
        pooled_pair = [
            self.bert(input_ids=ids, attention_mask=mask).pooler_output
            for ids, mask in sides
        ]
        joined = self.dropout(torch.cat(pooled_pair, dim=1))
        return self.classifier(joined)

class TextPairDataset(torch.utils.data.Dataset):
    """Dataset for text pair classification.

    Item ``i`` is a dict with the tokenized form of ``texts_1[i]`` and ``texts_2[i]``
    (``input_ids_*`` / ``attention_mask_*``) plus ``labels`` as a ``torch.long`` scalar.
    """
    def __init__(self, texts_1, texts_2, labels, tokenizer, max_length=128):
        """
        Args:
            texts_1: First text of every pair.
            texts_2: Second text of every pair (same length as ``texts_1``).
            labels: Integer class label per pair.
            tokenizer: Callable accepting ``(text, max_length=, truncation=, padding=,
                return_tensors=)`` and returning ``input_ids`` / ``attention_mask`` tensors.
            max_length: Length every text is padded/truncated to.
        """
        self.texts_1 = texts_1
        self.texts_2 = texts_2
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts_1)

    def _encode(self, text):
        """Tokenize one text and return (input_ids, attention_mask) without the batch axis."""
        tokens = self.tokenizer(str(text), max_length=self.max_length,
                                truncation=True, padding='max_length', return_tensors='pt')
        # squeeze(0) drops only the leading batch axis. A bare squeeze() would also
        # collapse the sequence axis when max_length == 1 and break batch collation.
        return tokens['input_ids'].squeeze(0), tokens['attention_mask'].squeeze(0)

    def __getitem__(self, idx):
        input_ids_1, attention_mask_1 = self._encode(self.texts_1[idx])
        input_ids_2, attention_mask_2 = self._encode(self.texts_2[idx])

        return {
            'input_ids_1': input_ids_1,
            'attention_mask_1': attention_mask_1,
            'input_ids_2': input_ids_2,
            'attention_mask_2': attention_mask_2,
            'labels': torch.tensor(self.labels[idx], dtype=torch.long)
        }

def train_bert_classifier_fixed(train_df, device, num_epochs=2, batch_size=8, max_length=128):
    """
    Train a ``BERTClassifier`` with a frozen encoder on the cleaned text pairs of ``train_df``.

    Args:
        train_df: DataFrame with ``cleaned_text_1``, ``cleaned_text_2`` and ``real`` columns.
        device: torch device the classifier and every batch are moved to.
        num_epochs: Number of passes over the training pairs.
        batch_size: Pairs per optimisation step.
        max_length: Token length every text is padded/truncated to.

    Returns:
        Tuple ``(bert_classifier, bert_tokenizer)``. The classifier is left in train mode.

    Raises:
        ValueError: If ``train_df`` yields no usable training pair.
    """
    print("Preparing training data...")

    texts_1, texts_2, labels = [], [], []
    for _, row in train_df.iterrows():
        # A pair needs both texts and a known label; an unknown (NaN) `real` must not
        # silently be trained as "text_2 is real".
        if pd.isna(row['cleaned_text_1']) or pd.isna(row['cleaned_text_2']) or pd.isna(row['real']):
            continue
        texts_1.append(str(row['cleaned_text_1'])[:500])  # Truncate for safety
        texts_2.append(str(row['cleaned_text_2'])[:500])  # Truncate for safety
        # Convert labels: if real == 1, then text_1 is real (label 0), else text_2 is real (label 1)
        labels.append(0 if row['real'] == 1 else 1)

    print(f"Prepared {len(texts_1)} training samples")
    if not texts_1:
        # Fail early with a clear message instead of a sampler error from the DataLoader.
        raise ValueError("No usable training pairs found in train_df")

    # A tokenizer is loaded here and returned so prediction uses the same one as training.
    bert_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

    print("Initializing BERTClassifier...")
    bert_classifier = BERTClassifier(num_classes=2, dropout=0.3, freeze_bert=True).to(device)  # Freeze BERT for faster training

    dataset = TextPairDataset(texts_1, texts_2, labels, bert_tokenizer, max_length)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Only hand the optimizer parameters that can actually receive gradients
    # (the encoder is frozen, so this is the classification head).
    trainable_params = [p for p in bert_classifier.parameters() if p.requires_grad]
    optimizer = optim.AdamW(trainable_params, lr=2e-5, weight_decay=0.01)
    criterion = nn.CrossEntropyLoss()

    print(f"Starting training for {num_epochs} epochs with batch size {batch_size}...")

    bert_classifier.train()
    for epoch in range(num_epochs):
        total_loss = 0
        num_batches = 0
        skipped_batches = 0

        print(f"\nEpoch {epoch+1}/{num_epochs}")

        for batch_idx, batch in enumerate(tqdm(dataloader, desc="Training")):
            try:
                # Move batch to device
                input_ids_1 = batch['input_ids_1'].to(device)
                attention_mask_1 = batch['attention_mask_1'].to(device)
                input_ids_2 = batch['input_ids_2'].to(device)
                attention_mask_2 = batch['attention_mask_2'].to(device)
                labels_batch = batch['labels'].to(device)

                # Forward pass
                optimizer.zero_grad()
                outputs = bert_classifier(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2)
                loss = criterion(outputs, labels_batch)

                # Backward pass
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                # Print progress every 10 batches
                if batch_idx % 10 == 0:
                    print(f"Batch {batch_idx}, Loss: {loss.item():.4f}")

                # Clear cache periodically (only meaningful when CUDA is in use)
                if batch_idx % 20 == 0 and torch.cuda.is_available():
                    torch.cuda.empty_cache()

            except Exception as e:
                # Best effort: a failing batch is reported and skipped, but counted so
                # the epoch summary cannot hide it.
                skipped_batches += 1
                print(f"Error in batch {batch_idx}: {e}")
                continue

        avg_loss = total_loss / num_batches if num_batches > 0 else 0
        print(f'Epoch {epoch+1} completed - Average Loss: {avg_loss:.4f}')
        if skipped_batches:
            print(f"WARNING: {skipped_batches} batch(es) failed and were skipped in epoch {epoch+1}"
                  + (" - no batch was trained, the loss above is a placeholder" if num_batches == 0 else ""))

    print("Training completed!")
    return bert_classifier, bert_tokenizer

# Fixed RAG scorer to work with your vector_db
class RAGScorer:
    """Scores texts from the vector-DB search results returned for them."""
    def __init__(self, vector_db_real):
        self.vector_db_real = vector_db_real

    def get_scores(self, text_1, text_2, search_limit=20):
        """Return ``(score_1, score_2)``: one search-based score per text of the pair."""
        scores = []
        for query in (text_1, text_2):
            hits = self.vector_db_real.search(query, limit=search_limit)
            scores.append(self.count_real_fake_fixed(hits, search_limit))
        return tuple(scores)

    def count_real_fake_fixed(self, search_results, search_limit):
        """Score one result list; a missing or empty list gives the neutral value 0.5."""
        if search_results:
            # Delegate to the notebook-level helper for non-empty result lists.
            return count_real_fake(search_results, search_limit)
        return 0.5

# Clear GPU memory before starting and report how much is still allocated by tensors.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print(f"GPU Memory before training: {torch.cuda.memory_allocated()/1024**3:.1f}GB")

# Train the BERTClassifier on the training pairs.
# These arguments override the function defaults (num_epochs=2, batch_size=8).
print("Training BERTClassifier...")
trained_bert_classifier, bert_tokenizer = train_bert_classifier_fixed(paired_df, device, num_epochs=100, batch_size=4)

# Initialize RAG scorer on the vector DB collection opened in an earlier cell.
print("Initializing RAG scorer...")
rag_scorer = RAGScorer(vector_db_real)

# Simple ensemble prediction function
def bert_rag_ensemble_predict(test_df, bert_model, rag_scorer, tokenizer, device, alpha=0.6, search_limit=20, max_length=128):
    """
    Ensemble prediction with BERT + RAG.

    For every row, the classifier's softmax probabilities are blended with the RAG
    scores as ``alpha * bert + (1 - alpha) * rag``; the text with the higher blended
    score is predicted to be the real one (ties go to text 1).

    Args:
        test_df: DataFrame with ``cleaned_text_1`` / ``cleaned_text_2`` columns.
        bert_model: Trained pair classifier; it is switched to eval mode.
        rag_scorer: Object exposing ``get_scores(text_1, text_2, search_limit)``.
        tokenizer: Tokenizer matching ``bert_model``.
        device: torch device the inputs are moved to.
        alpha: Weight of the BERT probability in the blend.
        search_limit: Number of search results requested per text by the RAG scorer.
        max_length: Token length every text is padded/truncated to.

    Returns:
        DataFrame with columns ``id`` (the row's index label) and ``real_text_id`` (1 or 2).
        If the blended prediction fails for a row, a RAG-only prediction is used; if that
        fails too, the row defaults to 1. Both failures are printed.
    """
    results = []
    bert_model.eval()

    print(f"Making predictions on {len(test_df)} samples...")

    with torch.no_grad():
        for idx, row in tqdm(test_df.iterrows(), total=len(test_df), desc="BERT+RAG Prediction"):
            # Extracted before the try block so the fallback below always has them.
            text_1 = str(row.get('cleaned_text_1', ''))[:500]
            text_2 = str(row.get('cleaned_text_2', ''))[:500]

            try:
                # BERT prediction
                tokens_1 = tokenizer(text_1, max_length=max_length, truncation=True, padding='max_length', return_tensors='pt')
                tokens_2 = tokenizer(text_2, max_length=max_length, truncation=True, padding='max_length', return_tensors='pt')

                outputs = bert_model(
                    tokens_1['input_ids'].to(device),
                    tokens_1['attention_mask'].to(device),
                    tokens_2['input_ids'].to(device),
                    tokens_2['attention_mask'].to(device),
                )
                # Plain Python floats: the blend and comparison below then run on the
                # host instead of relying on the truth value of a device tensor.
                bert_probs = torch.softmax(outputs, dim=1)[0].tolist()

                # RAG prediction
                rag_score1, rag_score2 = rag_scorer.get_scores(text_1, text_2, search_limit)

                # Ensemble combination
                combined_score_1 = alpha * bert_probs[0] + (1 - alpha) * rag_score1
                combined_score_2 = alpha * bert_probs[1] + (1 - alpha) * rag_score2

                predicted_real = 1 if combined_score_1 >= combined_score_2 else 2

            except Exception as e:
                print(f"Error processing sample {idx}: {e}")
                # Fallback to RAG-only prediction
                try:
                    rag_score1, rag_score2 = rag_scorer.get_scores(text_1, text_2, search_limit)
                    predicted_real = 1 if rag_score1 >= rag_score2 else 2
                except Exception as rag_error:
                    # Last resort: keep one row per sample, but say that it is a default.
                    print(f"RAG fallback failed for sample {idx}: {rag_error}")
                    predicted_real = 1  # Default prediction

            results.append({'id': idx, 'real_text_id': predicted_real})

    return pd.DataFrame(results)

# Make predictions on the test pairs with the trained classifier and the RAG scorer.
# alpha=0.6 weights the BERT probability at 0.6 and the RAG score at 0.4.
print("Generating ensemble predictions...")
final_predictions = bert_rag_ensemble_predict(
    test_df, 
    trained_bert_classifier, 
    rag_scorer, 
    bert_tokenizer, 
    device, 
    alpha=0.6
)

# Save results (columns: id, real_text_id) without the DataFrame index.
final_predictions.to_csv("bert_rag_ensemble_predictions.csv", index=False)
print("Predictions saved to 'bert_rag_ensemble_predictions.csv'")
print(final_predictions.head())

# Monitor GPU usage: peak memory occupied by tensors so far.
if torch.cuda.is_available():
    print(f"Peak GPU Memory: {torch.cuda.max_memory_allocated()/1024**3:.1f}GB")

BERT+RAG Prediction: 100%|██████████| 1068/1068 [01:09<00:00, 15.40it/s]

Predictions saved to 'bert_rag_ensemble_predictions.csv'
   id  real_text_id
0   0             2
1   1             2
2   2             1
3   3             1
4   4             2
Peak GPU Memory: 1.1GB



