In [1]:
import json
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from transformers import BartTokenizer, BartForConditionalGeneration
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import nltk
import re
from tqdm import tqdm
import os

# Download NLTK resources
nltk.download('punkt', quiet=True)
nltk.download('stopwords', quiet=True)
from nltk.corpus import stopwords

def load_qtl_data(file_path):
    """
    Load QTL data from JSON file with robust error handling
    """
    papers = []
    
    try:
        # First try to load as a JSON array
        with open(file_path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
                if isinstance(data, list):
                    papers = data
                    print(f"Successfully loaded {len(papers)} papers as JSON array")
                    return papers
            except json.JSONDecodeError:
                # Not a valid JSON array, continue to line-by-line parsing
                pass
                
        # Try line-by-line JSON objects
        with open(file_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                line = line.strip()
                if not line:  # Skip empty lines
                    continue
                try:
                    # Handle case where commas might be missing between objects
                    if line.endswith('},') or line.endswith('}'):
                        if not line.startswith('{'):
                            line = '{' + line
                        paper = json.loads(line.rstrip(','))
                        papers.append(paper)
                except json.JSONDecodeError as e:
                    print(f"Error parsing line {i+1}: {e}")
                    # Try to fix common JSON issues
                    try:
                        if line.endswith(','):
                            line = line[:-1]
                        if not line.startswith('{'):
                            line = '{' + line
                        if not line.endswith('}'):
                            line = line + '}'
                        paper = json.loads(line)
                        papers.append(paper)
                        print(f"Fixed and loaded line {i+1}")
                    except:
                        print(f"Skipping problematic line {i+1}")
                        
            print(f"Successfully loaded {len(papers)} papers line by line")
    except FileNotFoundError:
        print(f"File not found: {file_path}")
        # Create a small sample dataset for testing
        papers = [
            {
                "PMID": "17179536", 
                "Journal": "J Anim Sci. 2007", 
                "Title": "Variance component analysis of quantitative trait loci for pork carcass composition", 
                "Abstract": "In a previous study, QTL for carcass composition and meat quality were identified...", 
                "Category": "1"
            },
            {
                "PMID": "17177700", 
                "Journal": "J Anim Breed Genet", 
                "Title": "Single nucleotide polymorphism identification in porcine genes", 
                "Abstract": "Pituitary adenylate cyclase-activating polypeptide is a neuropeptide with diverse biological actions...", 
                "Category": "0"
            }
        ]
        print("Created a small sample dataset for testing")
    
    if len(papers) == 0:
        # Create a small sample dataset as fallback
        papers = [
            {
                "PMID": "17179536", 
                "Journal": "J Anim Sci. 2007", 
                "Title": "Variance component analysis of quantitative trait loci for pork carcass composition", 
                "Abstract": "In a previous study, QTL for carcass composition and meat quality were identified...", 
                "Category": "1"
            },
            {
                "PMID": "17177700", 
                "Journal": "J Anim Breed Genet", 
                "Title": "Single nucleotide polymorphism identification in porcine genes", 
                "Abstract": "Pituitary adenylate cyclase-activating polypeptide is a neuropeptide with diverse biological actions...", 
                "Category": "0"
            }
        ]
        print("No papers found. Created a small sample dataset.")
    
    # Verify data structure and remove PMID
    cleaned_papers = []
    for paper in papers:
        # Skip malformed papers
        if not isinstance(paper, dict):
            continue
            
        # Ensure all required fields exist
        if not all(key in paper for key in ['Title', 'Abstract']):
            continue
            
        # Create a copy without PMID
        paper_copy = {k: v for k, v in paper.items() if k != 'PMID'}
        cleaned_papers.append(paper_copy)
    
    print(f"Cleaned and prepared {len(cleaned_papers)} valid papers")
    return cleaned_papers

def load_test_data(file_path):
    """
    Load test data from TSV file with robust error handling
    """
    try:
        # Try pandas read_csv
        df = pd.read_csv(file_path, sep='\t')
        test_data = []
        original_titles = []
        
        for _, row in df.iterrows():
            # Skip rows with missing data
            if pd.isna(row['Abstract']) or pd.isna(row['Title']):
                continue
            
            # Include PMID for reference (though we won't use it for prediction)
            test_item = {
                'Abstract': row['Abstract'],
                'Label': row['Label'] if 'Label' in row else 0
            }
            
            # Save PMID separately for reference
            if 'PMID' in row and not pd.isna(row['PMID']):
                test_item['PMID'] = str(row['PMID'])
                
            test_data.append(test_item)
            original_titles.append(row['Title'])
        
        print(f"Successfully loaded {len(test_data)} test samples")
        return test_data, original_titles
        
    except Exception as e:
        print(f"Error loading test data: {e}")
        
        # Create sample test data as fallback
        test_data = [
            {
                'Abstract': "Porcine circovirus type 3 is regularly reported in association with various clinical presentations...",
                'Label': 0
            },
            {
                'Abstract': "This study investigated using imputed genotypes from non-genotyped animals...",
                'Label': 0
            }
        ]
        
        original_titles = [
            "Detection of porcine circovirus type 3 DNA in serum and semen samples",
            "Imputation of non-genotyped F1 dams to improve genetic gain in swine"
        ]
        
        print("Created a small sample test dataset")
        return test_data, original_titles

def extract_key_phrases(text):
    """Extract important phrases from the abstract"""
    # Get scientific terms (capitalized terms)
    scientific_terms = re.findall(r'\b[A-Z][a-zA-Z0-9]+([-\/][a-zA-Z0-9]+)*\b', text)
    
    # Get method terms
    method_terms = re.findall(r'\b(analysis|study|investigation|evaluation|detection|identification|characterization|assessment|method|approach|technique|mapping|imputation)\b', text.lower())
    
    # Get species or organisms
    species_terms = re.findall(r'\b(pig|porcine|swine|boar|piglet|animal|livestock|cattle|sheep|chicken|goat)\b', text.lower())
    
    # Get important biomedical terms
    biomedical_terms = re.findall(r'\b(gene|protein|genetic|genomic|DNA|RNA|SNP|QTL|marker|chromosome|trait|allele|variant|phenotype|genotype)\b', text.lower())
    
    return {
        'scientific_terms': list(set(scientific_terms)),
        'method_terms': list(set(method_terms)),
        'species_terms': list(set(species_terms)),
        'biomedical_terms': list(set(biomedical_terms))
    }

def generate_rule_based_title(abstract):
    """Generate a title using linguistic patterns common in scientific papers"""
    # Extract key information
    key_phrases = extract_key_phrases(abstract)
    
    # Get first and last sentences
    sentences = nltk.sent_tokenize(abstract)
    first_sentence = sentences[0] if sentences else ""
    
    # Extract main focus from first sentence
    focus_match = re.search(r'(investigat|stud|examin|analyz|assess|determin|evaluat|develop|identif|characteriz)[a-z]* (the|of|how|whether) ([^\.]+)', first_sentence, re.IGNORECASE)
    focus_phrase = focus_match.group(3) if focus_match else ""
    
    # Generate title based on patterns
    if key_phrases['scientific_terms'] and key_phrases['method_terms'] and key_phrases['species_terms']:
        # Pattern: "Analysis of [Scientific Term] in [Species]"
        method = key_phrases['method_terms'][0].capitalize()
        term = key_phrases['scientific_terms'][0]
        species = key_phrases['species_terms'][0]
        return f"{method} of {term} in {species}"
        
    elif key_phrases['scientific_terms'] and key_phrases['biomedical_terms']:
        # Pattern: "[Scientific Term] and its effect on [Biomedical Term]"
        term = key_phrases['scientific_terms'][0]
        bio_term = key_phrases['biomedical_terms'][0]
        return f"{term} and its association with {bio_term}"
        
    elif focus_phrase:
        # Extract a concise version of the focus phrase
        words = focus_phrase.split()
        if len(words) > 8:
            focus_phrase = " ".join(words[:8])
        return f"{focus_phrase.capitalize()}"
        
    else:
        # Extract key phrases from first sentence
        words = first_sentence.split()
        if len(words) > 10:
            return " ".join(words[:10])
        else:
            return first_sentence
            
def train_bart_model(train_data, val_data=None, model_name="facebook/bart-base", epochs=3):
    """Fine-tune BART model for title generation"""
    # Load pre-trained model and tokenizer
    tokenizer = BartTokenizer.from_pretrained(model_name)
    model = BartForConditionalGeneration.from_pretrained(model_name)
    
    # Check for valid training data
    if len(train_data) == 0:
        print("No training data available. Returning base model.")
        return model, tokenizer
        
    # Prepare data
    train_inputs = []
    train_targets = []
    
    for paper in train_data:
        if 'Abstract' in paper and 'Title' in paper:
            train_inputs.append(paper['Abstract'])
            train_targets.append(paper['Title'])
    
    if len(train_inputs) == 0:
        print("No valid training examples found. Returning base model.")
        return model, tokenizer
        
    print(f"Training with {len(train_inputs)} examples")
    
    # Tokenize inputs and targets
    max_input_length = min(512, max([len(tokenizer.encode(text)) for text in train_inputs]))
    max_target_length = min(128, max([len(tokenizer.encode(text)) for text in train_targets]))
    
    print(f"Max input length: {max_input_length}, Max target length: {max_target_length}")
    
    train_encodings = tokenizer(train_inputs, truncation=True, padding='max_length', 
                               max_length=max_input_length, return_tensors='pt')
    target_encodings = tokenizer(train_targets, truncation=True, padding='max_length', 
                                max_length=max_target_length, return_tensors='pt')
    
    # Create dataset
    train_dataset = torch.utils.data.TensorDataset(
        train_encodings.input_ids,
        train_encodings.attention_mask,
        target_encodings.input_ids
    )
    
    # Set batch size based on data size
    batch_size = min(4, len(train_dataset))
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    
    # Training settings
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    
    # Training loop
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            labels = batch[2].to(device)
            
            # Clear gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            
            loss = outputs.loss
            total_loss += loss.item()
            
            # Backward pass
            loss.backward()
            optimizer.step()
        
        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")
    
    return model, tokenizer

def generate_titles(model, tokenizer, test_data, method='bart'):
    """Generate titles using the specified method"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()
    
    generated_titles = []
    
    for paper in tqdm(test_data, desc="Generating titles"):
        abstract = paper['Abstract']
        
        if method == 'rule':
            # Rule-based title generation
            title = generate_rule_based_title(abstract)
            generated_titles.append(title)
            
        elif method == 'bart':
            # Generate with BART
            input_ids = tokenizer(abstract, return_tensors='pt', truncation=True, 
                                 max_length=512).input_ids.to(device)
            
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=input_ids,
                    max_length=64,
                    num_beams=4,
                    length_penalty=2.0,
                    early_stopping=True,
                    no_repeat_ngram_size=2
                )
            
            title = tokenizer.decode(outputs[0], skip_special_tokens=True)
            generated_titles.append(title)
            
        elif method == 'hybrid':
            # Generate both and select the better one
            rule_title = generate_rule_based_title(abstract)
            
            input_ids = tokenizer(abstract, return_tensors='pt', truncation=True, 
                                 max_length=512).input_ids.to(device)
            
            with torch.no_grad():
                outputs = model.generate(
                    input_ids=input_ids,
                    max_length=64,
                    num_beams=4,
                    length_penalty=2.0,
                    early_stopping=True
                )
            
            bart_title = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Choose based on quality heuristics
            key_phrases = extract_key_phrases(abstract)
            bart_quality = sum(1 for term in key_phrases['scientific_terms'] if term.lower() in bart_title.lower())
            rule_quality = sum(1 for term in key_phrases['scientific_terms'] if term.lower() in rule_title.lower())
            
            if len(bart_title.split()) >= 4 and (bart_quality >= rule_quality or len(bart_title) < 100):
                generated_titles.append(bart_title)
            else:
                generated_titles.append(rule_title)
    
    return generated_titles

def evaluate_titles(generated_titles, reference_titles):
    """Calculate BLEU and ROUGE scores"""
    if len(generated_titles) == 0 or len(reference_titles) == 0:
        print("No titles to evaluate")
        return {
            'bleu': 0.0,
            'rouge2': 0.0,
            'rougeL': 0.0
        }
    
    # Calculate BLEU scores
    smoothing = SmoothingFunction().method1
    bleu_scores = []
    
    for gen, ref in zip(generated_titles, reference_titles):
        gen_tokens = gen.lower().split()
        ref_tokens = [ref.lower().split()]
        
        if len(gen_tokens) == 0:
            bleu_scores.append(0.0)
            continue
            
        try:
            bleu = sentence_bleu(ref_tokens, gen_tokens, smoothing_function=smoothing)
            bleu_scores.append(bleu)
        except Exception as e:
            print(f"BLEU calculation error: {e}")
            bleu_scores.append(0.0)
    
    avg_bleu = sum(bleu_scores) / len(bleu_scores) if bleu_scores else 0.0
    
    # Calculate ROUGE scores
    scorer = rouge_scorer.RougeScorer(['rouge2', 'rougeL'], use_stemmer=True)
    rouge2_scores = []
    rougeL_scores = []
    
    for gen, ref in zip(generated_titles, reference_titles):
        try:
            scores = scorer.score(ref, gen)
            rouge2_scores.append(scores['rouge2'].fmeasure)
            rougeL_scores.append(scores['rougeL'].fmeasure)
        except Exception as e:
            print(f"ROUGE calculation error: {e}")
            rouge2_scores.append(0.0)
            rougeL_scores.append(0.0)
    
    avg_rouge2 = sum(rouge2_scores) / len(rouge2_scores) if rouge2_scores else 0.0
    avg_rougeL = sum(rougeL_scores) / len(rougeL_scores) if rougeL_scores else 0.0
    
    return {
        'bleu': avg_bleu,
        'rouge2': avg_rouge2,
        'rougeL': avg_rougeL
    }

def main():
    # File paths - adjust as needed
    qtl_json_path = "QTL_text.json"
    test_tsv_path = "test_unlabeled.tsv"
    
    # Load data with robust error handling
    print("Loading QTL data...")
    qtl_data = load_qtl_data(qtl_json_path)
    
    # Split data if we have enough samples
    if len(qtl_data) > 5:
        train_data, val_data = train_test_split(qtl_data, test_size=0.2, random_state=42)
        print(f"Split data: {len(train_data)} training, {len(val_data)} validation")
    else:
        # Use all data for training if small dataset
        train_data = qtl_data
        val_data = qtl_data
        print(f"Using all {len(qtl_data)} samples for both training and validation")
    
    # Load test data
    print("Loading test data...")
    test_data, original_titles = load_test_data(test_tsv_path)
    
    # Extract PMIDs from test data for tracking purposes
    # (We'll only use these for saving results, not for prediction)
    test_pmids = []
    for paper in test_data:
        if 'PMID' in paper:
            test_pmids.append(paper['PMID'])
        else:
            test_pmids.append('Unknown')
    
    # Generate titles using rule-based approach
    print("Generating titles with rule-based approach...")
    rule_based_titles = []
    for paper in test_data:
        title = generate_rule_based_title(paper['Abstract'])
        rule_based_titles.append(title)
    
    # Evaluate rule-based approach
    print("Evaluating rule-based approach...")
    rule_metrics = evaluate_titles(rule_based_titles, original_titles)
    
    # Train BART model with safeguards
    print("Training BART model...")
    try:
        model, tokenizer = train_bart_model(train_data, val_data, epochs=3)
        
        # Generate titles with BART
        print("Generating titles with BART...")
        bart_titles = generate_titles(model, tokenizer, test_data, method='bart')
        
        # Evaluate BART approach
        print("Evaluating BART approach...")
        bart_metrics = evaluate_titles(bart_titles, original_titles)
        
        # Generate titles with hybrid approach
        print("Generating titles with hybrid approach...")
        hybrid_titles = generate_titles(model, tokenizer, test_data, method='hybrid')
        
        # Evaluate hybrid approach
        print("Evaluating hybrid approach...")
        hybrid_metrics = evaluate_titles(hybrid_titles, original_titles)
        
        # Print results
        print("\n=== RESULTS ===")
        print("\nRULE-BASED APPROACH:")
        print(f"BLEU: {rule_metrics['bleu']:.4f}")
        print(f"ROUGE-2: {rule_metrics['rouge2']:.4f}")
        print(f"ROUGE-L: {rule_metrics['rougeL']:.4f}")
        
        print("\nBART APPROACH:")
        print(f"BLEU: {bart_metrics['bleu']:.4f}")
        print(f"ROUGE-2: {bart_metrics['rouge2']:.4f}")
        print(f"ROUGE-L: {bart_metrics['rougeL']:.4f}")
        
        print("\nHYBRID APPROACH:")
        print(f"BLEU: {hybrid_metrics['bleu']:.4f}")
        print(f"ROUGE-2: {hybrid_metrics['rouge2']:.4f}")
        print(f"ROUGE-L: {hybrid_metrics['rougeL']:.4f}")
        
        # Determine best approach based on BLEU score
        approaches = {
            'Rule-based': rule_metrics['bleu'],
            'BART': bart_metrics['bleu'],
            'Hybrid': hybrid_metrics['bleu']
        }
        best_approach = max(approaches.items(), key=lambda x: x[1])[0]
        
        print(f"\nBest approach based on BLEU score: {best_approach}")
        
        # Print sample predictions
        best_titles = {
            'Rule-based': rule_based_titles,
            'BART': bart_titles,
            'Hybrid': hybrid_titles
        }[best_approach]
        
        # Save all generated titles to CSV files for comparison
        results_df = pd.DataFrame({
            'PMID': [paper.get('PMID', 'Unknown') for paper in test_data],
            'Original_Title': original_titles,
            'Rule_Based_Title': rule_based_titles,
            'BART_Title': bart_titles,
            'Hybrid_Title': hybrid_titles,
            'BLEU_Score': [sentence_bleu([ref.lower().split()], gen.lower().split(), smoothing_function=SmoothingFunction().method1) 
                           for ref, gen in zip(original_titles, best_titles)]
        })
        
        # Save to CSV
        results_df.to_csv('generated_titles_comparison.csv', index=False)
        print(f"\nSaved all generated titles to 'generated_titles_comparison.csv'")
        
        # Also save individual approach results
        for approach_name, titles in [('rule_based', rule_based_titles), 
                                     ('bart', bart_titles), 
                                     ('hybrid', hybrid_titles)]:
            approach_df = pd.DataFrame({
                'PMID': [paper.get('PMID', 'Unknown') for paper in test_data],
                'Original_Title': original_titles,
                'Generated_Title': titles
            })
            approach_df.to_csv(f'{approach_name}_titles.csv', index=False)
            print(f"Saved {approach_name} titles to '{approach_name}_titles.csv'")
        
        print(f"\nSample predictions from {best_approach} approach:")
        for i in range(min(5, len(test_data))):
            print(f"\nAbstract (truncated): {test_data[i]['Abstract'][:150]}...")
            print(f"Ground truth: {original_titles[i]}")
            print(f"Predicted: {best_titles[i]}")
            
    except Exception as e:
        print(f"Error during model training or evaluation: {e}")
        print("Falling back to rule-based approach only")
        
        # Print rule-based results
        print("\n=== RESULTS ===")
        print("\nRULE-BASED APPROACH:")
        print(f"BLEU: {rule_metrics['bleu']:.4f}")
        print(f"ROUGE-2: {rule_metrics['rouge2']:.4f}")
        print(f"ROUGE-L: {rule_metrics['rougeL']:.4f}")
        
        # Print sample predictions
        print("\nSample predictions from Rule-based approach:")
        for i in range(min(5, len(test_data))):
            print(f"\nAbstract (truncated): {test_data[i]['Abstract'][:150]}...")
            print(f"Ground truth: {original_titles[i]}")
            print(f"Predicted: {rule_based_titles[i]}")

if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Loading QTL data...
Successfully loaded 11278 papers as JSON array
Split data: 9022 training, 2256 validation
Loading test data...
Successfully loaded 1097 test samples
Generating titles with rule-based approach...
Evaluating rule-based approach...
Training BART model...
Training with 9022 examples
Max input length: 512, Max target length: 77


Epoch 1/3: 100%|██████████| 2256/2256 [03:44<00:00, 10.06it/s]


Epoch 1/3, Loss: 0.9111


Epoch 2/3: 100%|██████████| 2256/2256 [03:44<00:00, 10.04it/s]


Epoch 2/3, Loss: 0.6501


Epoch 3/3: 100%|██████████| 2256/2256 [03:44<00:00, 10.04it/s]


Epoch 3/3, Loss: 0.5456
Generating titles with BART...


Generating titles: 100%|██████████| 1097/1097 [03:25<00:00,  5.35it/s]


Evaluating BART approach...
Generating titles with hybrid approach...


Generating titles: 100%|██████████| 1097/1097 [04:49<00:00,  3.79it/s]


Evaluating hybrid approach...

=== RESULTS ===

RULE-BASED APPROACH:
BLEU: 0.0147
ROUGE-2: 0.0399
ROUGE-L: 0.1770

BART APPROACH:
BLEU: 0.1293
ROUGE-2: 0.2616
ROUGE-L: 0.4205

HYBRID APPROACH:
BLEU: 0.1297
ROUGE-2: 0.2616
ROUGE-L: 0.4209

Best approach based on BLEU score: Hybrid

Saved all generated titles to 'generated_titles_comparison.csv'
Saved rule_based titles to 'rule_based_titles.csv'
Saved bart titles to 'bart_titles.csv'
Saved hybrid titles to 'hybrid_titles.csv'

Sample predictions from Hybrid approach:

Abstract (truncated): Porcine circovirus type 3 (PCV3) is regularly reported in association with various clinical presentations, including porcine dermatitis and nephropath...
Ground truth: Detection of porcine circovirus type 3 DNA in serum and semen samples of boars from a German boar stud.
Predicted: Detection of porcine circovirus type 3 DNA in boar semen from a German stud supplying semen for artificial insemination.

Abstract (truncated): This study investigated using