In [8]:
import json

with open("labeled_data.json", "r") as f:
    labeled_data = json.load(f)

with open("documents.json", "r") as f:
    documents = json.load(f)

In [2]:
from sklearn.model_selection import train_test_split

# Split into train/val (80%/20%)
train_data, val_data = train_test_split(labeled_data, test_size=0.2, random_state=42)

print(f"Train samples: {len(train_data)}, Val samples: {len(val_data)}")

Train samples: 2352, Val samples: 588


In [3]:
# Discriminator

import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, vocab_size, embed_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim * 2, 64),  # Concatenated query+doc embeddings
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, query, doc):
        query_embed = self.embedding(query).mean(dim=1)  # Average query embeddings
        doc_embed = self.embedding(doc).mean(dim=1)      # Average doc embeddings
        combined = torch.cat([query_embed, doc_embed], dim=1)
        return self.fc(combined)

In [4]:
!pip install rank_bm25

Collecting rank_bm25
  Downloading rank_bm25-0.2.2-py3-none-any.whl.metadata (3.2 kB)
Downloading rank_bm25-0.2.2-py3-none-any.whl (8.6 kB)
Installing collected packages: rank_bm25
Successfully installed rank_bm25-0.2.2


In [10]:
from rank_bm25 import BM25Okapi
import torch
import torch.nn as nn

# --- Define Models ---
class Discriminator(nn.Module):
    def __init__(self, vocab_size, embed_dim=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Sequential(
            nn.Linear(embed_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, query, doc):
        query_embed = self.embedding(query).mean(dim=1)
        doc_embed = self.embedding(doc).mean(dim=1)
        combined = torch.cat([query_embed, doc_embed], dim=1)
        return self.fc(combined)

class Generator(nn.Module):
    def __init__(self, doc_ids):  # <-- Requires doc_ids
        super().__init__()
        self.doc_ids = doc_ids

    def forward(self, query, top_k=5):
        scores = bm25.get_scores(query.split())
        noise = torch.randn(len(scores))
        noisy_scores = scores + noise.numpy()
        top_indices = noisy_scores.argsort()[-top_k:][::-1]
        return [self.doc_ids[i] for i in top_indices]


doc_ids = list(documents.keys())  # Get all document IDs
corpus = list(documents.values())
tokenized_corpus = [doc.split() for doc in corpus]
bm25 = BM25Okapi(tokenized_corpus)  # Initialize BM25

discriminator = Discriminator(vocab_size=10000)
generator = Generator(doc_ids)  # Now correctly initialized with doc_ids


In [12]:
import torch.optim as optim
from tqdm import tqdm
import numpy as np

# Hyperparameters
NUM_EPOCHS = 10
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Optimizers
d_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

# Loss function
criterion = nn.BCELoss()

# Training loop
for epoch in range(NUM_EPOCHS):
    discriminator.train()

    # Shuffle data
    np.random.shuffle(labeled_data)

    total_d_loss = 0
    total_g_loss = 0

    for i in tqdm(range(0, len(labeled_data), BATCH_SIZE), desc=f"Epoch {epoch+1}"):
        batch = labeled_data[i:i+BATCH_SIZE]

        # Prepare batch data
        real_queries = []
        real_docs = []
        real_labels = []
        fake_docs = []

        for query, doc_id, label in batch:
            # Real data
            real_queries.append(query)
            real_docs.append(doc_id)
            real_labels.append(label)

            # Fake data from generator
            fake_doc_id = generator(query)[0]
            fake_docs.append(fake_doc_id)

        # Convert to tensors with proper shape [batch_size, 1]
        real_query_tensors = torch.LongTensor([[hash(q) % 10000] for q in real_queries])
        real_doc_tensors = torch.LongTensor([[hash(d) % 10000] for d in real_docs])
        fake_doc_tensors = torch.LongTensor([[hash(d) % 10000] for d in fake_docs])
        real_labels = torch.FloatTensor(real_labels).unsqueeze(1)  # [batch_size, 1]

        # Train Discriminator
        d_optimizer.zero_grad()

        # Real data loss - ensure inputs are [batch_size, 1]
        real_outputs = discriminator(real_query_tensors, real_doc_tensors)
        real_loss = criterion(real_outputs, real_labels)

        # Fake data loss
        fake_outputs = discriminator(real_query_tensors, fake_doc_tensors)
        fake_loss = criterion(fake_outputs, torch.zeros_like(fake_outputs))

        d_loss = real_loss + fake_loss
        d_loss.backward()
        d_optimizer.step()
        total_d_loss += d_loss.item()

        # Train Generator (REINFORCE)
        sampled_docs = []
        for query in real_queries:
            sampled_docs.append(generator(query)[0])

        # Get rewards from discriminator
        with torch.no_grad():
            query_tensors = torch.LongTensor([[hash(q) % 10000] for q in real_queries])
            doc_tensors = torch.LongTensor([[hash(d) % 10000] for d in sampled_docs])
            rewards = discriminator(query_tensors, doc_tensors)

        # Generator loss
        g_loss = -torch.mean(torch.log(rewards + 1e-8))
        total_g_loss += g_loss.item()

    print(f"Epoch {epoch+1} | D Loss: {total_d_loss/len(labeled_data):.4f} | G Loss: {total_g_loss/len(labeled_data):.4f}")

# Evaluation function
def evaluate(discriminator, test_data):
    discriminator.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for query, doc_id, label in test_data:
            # Ensure input shape is [1, 1]
            query_tensor = torch.LongTensor([[hash(query) % 10000]])
            doc_tensor = torch.LongTensor([[hash(doc_id) % 10000]])

            output = discriminator(query_tensor, doc_tensor)
            predicted = 1 if output.item() > 0.5 else 0
            correct += (predicted == label)
            total += 1

    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.4f}")
    return accuracy

# Evaluate on test data
test_accuracy = evaluate(discriminator, val_data)

Epoch 1: 100%|██████████| 92/92 [00:11<00:00,  8.31it/s]


Epoch 1 | D Loss: 0.0263 | G Loss: 0.0429


Epoch 2: 100%|██████████| 92/92 [00:09<00:00,  9.63it/s]


Epoch 2 | D Loss: 0.0215 | G Loss: 0.0573


Epoch 3: 100%|██████████| 92/92 [00:10<00:00,  8.54it/s]


Epoch 3 | D Loss: 0.0209 | G Loss: 0.0590


Epoch 4: 100%|██████████| 92/92 [00:10<00:00,  8.72it/s]


Epoch 4 | D Loss: 0.0203 | G Loss: 0.0581


Epoch 5: 100%|██████████| 92/92 [00:10<00:00,  8.76it/s]


Epoch 5 | D Loss: 0.0199 | G Loss: 0.0594


Epoch 6: 100%|██████████| 92/92 [00:10<00:00,  8.74it/s]


Epoch 6 | D Loss: 0.0194 | G Loss: 0.0612


Epoch 7: 100%|██████████| 92/92 [00:10<00:00,  8.62it/s]


Epoch 7 | D Loss: 0.0191 | G Loss: 0.0606


Epoch 8: 100%|██████████| 92/92 [00:09<00:00,  9.41it/s]


Epoch 8 | D Loss: 0.0186 | G Loss: 0.0627


Epoch 9: 100%|██████████| 92/92 [00:10<00:00,  8.68it/s]


Epoch 9 | D Loss: 0.0182 | G Loss: 0.0641


Epoch 10: 100%|██████████| 92/92 [00:10<00:00,  8.83it/s]

Epoch 10 | D Loss: 0.0177 | G Loss: 0.0650
Test Accuracy: 0.8129





In [13]:
import torch
import os
from datetime import datetime

# Create export directory
export_dir = "irgan_model"
os.makedirs(export_dir, exist_ok=True)

# Save model state dicts
torch.save({
    'discriminator_state_dict': discriminator.state_dict(),
    'generator_state_dict': generator.state_dict(),
    'doc_ids': doc_ids,  # Save document IDs for generator
    'vocab_size': 10000,  # Save vocabulary size
    'timestamp': datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
}, os.path.join(export_dir, 'irgan_model.pth'))

# Save complete models (optional)
torch.save(discriminator, os.path.join(export_dir, 'discriminator_full.pth'))
torch.save(generator, os.path.join(export_dir, 'generator_full.pth'))

print(f"Models saved to {export_dir}")

Models saved to irgan_model


In [16]:
# Model evaluation

import numpy as np
from sklearn.metrics import precision_recall_fscore_support, average_precision_score

def detailed_evaluate(discriminator, test_data, doc_texts, top_k=5):
    discriminator.eval()
    results = {
        'accuracy': 0,
        'precision': [],
        'recall': [],
        'f1': [],
        'average_precision': [],
        'top_k_accuracy': 0,
        'examples': []
    }

    total = 0
    top_k_correct = 0

    all_labels = []
    all_scores = []

    with torch.no_grad():
        for query, true_doc_id, label in tqdm(test_data, desc="Evaluating"):
            # Score all documents for this query
            scores = []
            for doc_id in doc_ids:
                # Tokenize (replace with your actual tokenization)
                query_tensor = torch.LongTensor([[hash(query) % 10000]])
                doc_tensor = torch.LongTensor([[hash(doc_id) % 10000]])

                score = discriminator(query_tensor, doc_tensor).item()
                scores.append((doc_id, score))

            # Sort by score descending
            sorted_docs = sorted(scores, key=lambda x: x[1], reverse=True)

            # Top-k accuracy
            top_k_ids = [doc_id for doc_id, _ in sorted_docs[:top_k]]
            if true_doc_id in top_k_ids:
                top_k_correct += 1

            # For binary metrics
            pred_label = 1 if sorted_docs[0][0] == true_doc_id else 0
            results['accuracy'] += (pred_label == label)

            # Store for aggregate metrics
            all_labels.append(label)
            all_scores.append(sorted_docs[0][1])

            # Save some examples
            if len(results['examples']) < 5:  # Save first 5 examples
                results['examples'].append({
                    'query': query,
                    'true_doc': true_doc_id,
                    'true_doc_text': doc_texts.get(true_doc_id, ""),
                    'predicted_top1': sorted_docs[0][0],
                    'predicted_top1_text': doc_texts.get(sorted_docs[0][0], ""),
                    'top_k': top_k_ids
                })

            total += 1

    # Calculate metrics
    results['accuracy'] /= total
    results['top_k_accuracy'] = top_k_correct / total

    # Precision/Recall/F1
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_labels,
        [1 if score > 0.5 else 0 for score in all_scores],
        average='binary'
    )
    results['precision'] = precision
    results['recall'] = recall
    results['f1'] = f1

    # Average Precision
    results['average_precision'] = average_precision_score(all_labels, all_scores)

    return results

# Run evaluation
doc_texts = {doc_id: text for doc_id, text in zip(doc_ids, corpus)}  # Create doc_id:text mapping
eval_results = detailed_evaluate(discriminator, val_data, doc_texts, top_k=5)

# Print results
print("\n=== Evaluation Results ===")
print(f"Accuracy: {eval_results['accuracy']:.4f}")
print(f"Top-{5} Accuracy: {eval_results['top_k_accuracy']:.4f}")
print(f"Precision: {eval_results['precision']:.4f}")
print(f"Recall: {eval_results['recall']:.4f}")
print(f"F1 Score: {eval_results['f1']:.4f}")
print(f"Average Precision: {eval_results['average_precision']:.4f}")

# Print example predictions
print("\n=== Example Predictions ===")
for i, example in enumerate(eval_results['examples']):
    print(f"\nExample {i+1}:")
    print(f"Query: {example['query']}")
    print(f"True Document: {example['true_doc']}")
    print(f"True Doc Text: {example['true_doc_text'][:100]}...")
    print(f"Predicted Top1: {example['predicted_top1']}")
    print(f"Predicted Text: {example['predicted_top1_text'][:100]}...")
    print(f"Top-{5} Predictions: {example['top_k']}")

Evaluating: 100%|██████████| 588/588 [00:32<00:00, 18.16it/s]


=== Evaluation Results ===
Accuracy: 0.7670
Top-5 Accuracy: 0.0221
Precision: 0.2393
Recall: 0.9929
F1 Score: 0.3857
Average Precision: 0.2080

=== Example Predictions ===

Example 1:
Query: What are the welfare issues associated with feeding housed cattle
True Document: https__www_part1
True Doc Text: iv • Contents
28 Factors Affecting Milk Quality 373 (b) Enzootic bovine leukosis 693
R.W.Blowey and ...
Predicted Top1: https__www_part182
Predicted Text: Neurological Disorders • 897
When the lumbosacral site has been surgically pre- arachnoid space.The ...
Top-5 Predictions: ['https__www_part182', 'https__www_part129', 'https__www_part58', 'https__www_part61', 'https__www_part142']

Example 2:
Query: What are the effects of vitamin A toxicity in calves
True Document: https__www_part245
True Doc Text: 1212 • Index
pregnancy (cont’d) prostaglandin(s) pyaemia 737
establishment 684–6 breeding synchroniz...
Predicted Top1: https__www_part182
Predicted Text: Neurological Disorders • 897
Whe


