# LSR: Latent Structure Refinement

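Paper: "Reasoning with Latent Structure Refinement for Document-Level Relation Extraction" (ACL 2020) — https://aclanthology.org/2020.acl-main.141.pdf

Code (sgnlp implementation): https://github.com/aisingapore/sgnlp/tree/main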

## LSR Relation Extraction (inference with a pretrained checkpoint)

Status: currently not working.

In [None]:
import pandas as pd
from tqdm import tqdm
import torch
import re
import spacy
import numpy as np
from collections import defaultdict
import os
import traceback

# Install required packages if not already installed
# !pip install sgnlp spacy
# !python -m spacy download en_core_web_sm

from sgnlp.sgnlp.models.lsr import LsrModel, LsrConfig, LsrPreprocessor, LsrPostprocessor

# --- Device selection ---
# Use the first CUDA GPU when one is available, otherwise fall back to CPU.
device_to_use = 0 if torch.cuda.is_available() else -1
device = torch.device(f'cuda:{device_to_use}' if device_to_use >= 0 else 'cpu')
print(f"Using device: {device}")

# --- Load NER model for entity recognition ---
# spaCy pipeline used below for sentence splitting, tokenization and NER.
nlp = spacy.load("en_core_web_sm")

# --- Load LSR model and preprocessor ---
print("Loading LSR model files...")

# Vocabulary / label files are read from the /content/ directory.
rel2id_path = '/content/rel2id.json'
word2id_path = '/content/word2id.json'
ner2id_path = '/content/ner2id.json'
rel_info_path = '/content/rel_info.json'

# Score threshold handed to LsrPostprocessor (its pred_threshold argument).
PRED_THRESHOLD = 0.3

# Initialize preprocessor, postprocessor and model. Every function below reads
# these globals, so a load failure is re-raised rather than swallowed;
# otherwise each later extraction call would fail with a NameError instead.
try:
    preprocessor = LsrPreprocessor(
        rel2id_path=rel2id_path,
        word2id_path=word2id_path,
        ner2id_path=ner2id_path
    )

    postprocessor = LsrPostprocessor.from_file_paths(
        rel2id_path=rel2id_path,
        rel_info_path=rel_info_path,
        pred_threshold=PRED_THRESHOLD
    )

    # Load model weights from a local checkpoint.
    config_path = "/content/config.json"
    model_path = "/content/pytorch_model.bin"

    config = LsrConfig.from_json_file(config_path)
    model = LsrModel(config)
    # map_location lets a checkpoint saved on GPU load on the selected device.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    print("LSR model loaded successfully.")
except Exception as e:
    print(f"Error loading LSR model: {e}")
    raise

# Lookup table from spaCy NER labels to the entity types used by LSR.
# NOTE(review): the target types should match the keys in ner2id.json — confirm.
_SPACY_TO_LSR_TYPE = {
    'PERSON': 'PER',
    'ORG': 'ORG',
    'GPE': 'LOC',
    'LOC': 'LOC',
    'FAC': 'LOC',
    'DATE': 'TIME',
    'TIME': 'TIME',
    'CARDINAL': 'NUM',
    'MONEY': 'NUM',
    'PERCENT': 'NUM',
    'QUANTITY': 'NUM',
    'NORP': 'MISC',
    'PRODUCT': 'MISC',
    'EVENT': 'MISC',
    'WORK_OF_ART': 'MISC',
    'LAW': 'MISC',
    'LANGUAGE': 'MISC',
}


def map_entity_type(ent_type):
    """Translate a spaCy entity label into the corresponding LSR entity type.

    Labels that are not in the lookup table fall back to 'MISC'.
    """
    try:
        return _SPACY_TO_LSR_TYPE[ent_type]
    except KeyError:
        return 'MISC'

def get_sentences(text):
    """Split *text* into sentences using the module-level spaCy pipeline.

    Returns a list with one string per sentence, surrounding whitespace removed.
    """
    return [span.text.strip() for span in nlp(text).sents]

def extract_entities(text):
    """Collect spaCy entity mentions in *text* as an LSR ``vertexSet``.

    Each mention is a dict with ``name`` (mention text), ``pos`` (spaCy token
    [start, end) offsets relative to its sentence), ``sent_id`` (sentence
    index) and ``type`` (mapped via map_entity_type). Mentions with exactly
    the same text are grouped into one cluster, and clusters are returned in
    order of first appearance. Mentions shorter than 2 characters after
    trimming are skipped.
    """
    clusters = {}  # mention text -> list of mention dicts, insertion-ordered

    for sent_id, sentence in enumerate(nlp(text).sents):
        for span in sentence.ents:
            # Skip mentions that are too short once surrounding whitespace is trimmed.
            if len(span.text.strip()) < 2:
                continue

            mention = {
                "name": span.text,
                "pos": [span.start - sentence.start, span.end - sentence.start],
                "sent_id": sent_id,
                "type": map_entity_type(span.label_),
            }
            # Simple coreference: identical surface text means the same entity.
            clusters.setdefault(span.text, []).append(mention)

    return list(clusters.values())

def prepare_lsr_input(text):
    """Build a DocRED-style instance for the LSR model from raw *text*.

    The returned dict has:
      - ``vertexSet``: entity clusters from extract_entities,
      - ``labels``: empty, since no gold relations exist at extraction time,
      - ``sents``: one list of token strings per sentence.

    Only the first 5000 characters of *text* are used. Returns None when the
    text yields no sentences or fewer than two entity clusters (a relation
    needs a head and a distinct tail).
    """
    # Limit text length to prevent issues.
    text = text[:5000]

    # Tokenize with the same spaCy pipeline extract_entities uses, keeping every
    # token. Mention ``pos`` values are spaCy token offsets within a sentence,
    # so ``sents`` must use spaCy tokens too. Whitespace splitting keeps
    # punctuation attached ("Paris,"), which shifts later indices and
    # misaligns mentions with the words the model sees.
    tokenized_sents = [[token.text for token in sent] for sent in nlp(text).sents]
    if not tokenized_sents:
        return None

    # Extract entities with mapped types.
    vertex_set = extract_entities(text)

    # At least two entity clusters are needed to form any (head, tail) pair.
    if len(vertex_set) < 2:
        return None

    return {
        "vertexSet": vertex_set,
        "labels": [],  # No pre-defined labels when extracting
        "sents": tokenized_sents,
    }

def _lsr_prediction(output):
    """Return the relation scores from an LsrModel forward() result, or None."""
    # sgnlp's LsrModel output exposes the scores as ``prediction``; dict-style
    # access is a fallback for other output containers.
    prediction = getattr(output, "prediction", None)
    if prediction is None and isinstance(output, dict):
        prediction = output.get("prediction", output.get("logits"))
    return prediction


def _lsr_relations(doc_result):
    """Return the list of relation dicts from one postprocessed document."""
    # NOTE(review): assumes sgnlp's per-document result is a dict holding a
    # "relations" list; a bare list of relation dicts is accepted too — confirm.
    if isinstance(doc_result, dict):
        return doc_result.get("relations", [])
    return doc_result if isinstance(doc_result, list) else []


def _first_key(mapping, *keys, default=None):
    """Return the value of the first of *keys* present in *mapping*, else *default*."""
    for key in keys:
        if key in mapping:
            return mapping[key]
    return default


def _cluster_index(vertex_set, idx):
    """Return *idx* as an int if it addresses a non-empty cluster, else None."""
    if idx is None:
        return None
    try:
        idx = int(idx)
    except (TypeError, ValueError):
        return None
    if 0 <= idx < len(vertex_set) and vertex_set[idx]:
        return idx
    return None


def extract_triplets_lsr(text):
    """Extract relation triplets from *text* with the LSR model.

    Runs prepare_lsr_input -> preprocessor -> model -> postprocessor using the
    module-level ``preprocessor``, ``model``, ``postprocessor`` and ``device``.

    Returns a list of dicts with keys ``head``, ``head_type``, ``type`` (the
    relation), ``tail`` and ``tail_type``. Names and types come from the first
    mention of each entity cluster. Any exception is printed with its
    traceback, and the triplets collected so far are returned.
    """
    triplets = []

    try:
        instance = prepare_lsr_input(text)
        if not instance:
            return triplets

        model_inputs = preprocessor([instance])
        if not model_inputs:
            return triplets

        # Move tensors to the model's device; non-tensor entries pass through.
        model_inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in model_inputs.items()
        }

        with torch.no_grad():
            # The preprocessor returns forward() arguments keyed by name, so the
            # dict must be unpacked. Passing it positionally binds the whole dict
            # to the first parameter.
            output = model(**model_inputs)

        prediction = _lsr_prediction(output)
        if prediction is None:
            return triplets

        # sgnlp's documented call order is postprocessor(prediction, instances).
        # Passing a CPU tensor keeps it independent of where the model ran.
        processed = postprocessor(prediction.detach().cpu(), [instance])
        if not processed:
            return triplets

        vertex_set = instance["vertexSet"]
        for pred in _lsr_relations(processed[0]):
            # Accept both sgnlp's subject/object keys and h/t-style keys.
            head_idx = _cluster_index(vertex_set, _first_key(pred, "subject_idx", "h_idx"))
            tail_idx = _cluster_index(vertex_set, _first_key(pred, "object_idx", "t_idx"))
            if head_idx is None or tail_idx is None:
                continue

            head = vertex_set[head_idx][0]
            tail = vertex_set[tail_idx][0]
            triplets.append({
                'head': head["name"],
                'head_type': head.get("type", "unknown"),
                'type': _first_key(pred, "relation", "pred_rel", default='unknown'),
                'tail': tail["name"],
                'tail_type': tail.get("type", "unknown"),
            })

    except Exception as e:
        # Best-effort per chunk: report the failure and keep going.
        print(f"Error extracting triplets with LSR: {e}")
        traceback.print_exc()

    return triplets

# --- Main processing function ---
# Column order of the output CSV; also used so an empty result still has a header.
_OUTPUT_COLUMNS = [
    "DOCUMENT", "SUBLABEL", "MODEL", "HEAD", "RELATION", "TAIL", "HEAD_TYPE", "TAIL_TYPE",
]


def _chunk_text(text, max_chunk_length):
    """Split *text* into consecutive pieces of at most *max_chunk_length* chars.

    Each cut is moved back to the nearest whitespace so that words (and
    multi-word entity names) are not split across chunks. A run with no
    whitespace is hard-cut at the limit. Concatenating the chunks gives back
    *text* exactly.
    """
    chunks = []
    start = 0
    length = len(text)
    while start < length:
        end = min(start + max_chunk_length, length)
        if end < length:
            cut = end
            while cut > start and not text[cut].isspace():
                cut -= 1
            if cut > start:  # found whitespace inside the window
                end = cut
        chunks.append(text[start:end])
        start = end
    return chunks


def process_data(input_csv, output_csv, debug=False):
    """Run LSR triplet extraction over every document in *input_csv*.

    Reads a CSV with ``ID`` and ``Text`` columns and extracts triplets chunk by
    chunk with extract_triplets_lsr. It writes one row per triplet to
    *output_csv* with columns DOCUMENT, SUBLABEL (1-based index within the
    document), MODEL, HEAD, RELATION, TAIL, HEAD_TYPE and TAIL_TYPE. The header
    is written even when no triplets are found.

    An error on one document is printed (with traceback when ``debug`` is
    True) and processing moves on to the next document. Any other error is
    printed with its traceback. With ``debug=True``, skipped documents,
    per-chunk progress and per-document triplet counts are also printed.
    """
    print(f"Loading data from {input_csv}...")

    try:
        df = pd.read_csv(input_csv)
        all_triplets = []

        for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing documents with LSR"):
            try:
                document = row['ID']
                text = row['Text']

                if not isinstance(text, str) or not text.strip():
                    if debug:
                        print(f"Skipping document {document}: Invalid text")
                    continue

                # Small chunks keep each LSR input short; cuts land on whitespace
                # so entity names are not split between chunks.
                max_chunk_length = 500
                chunks = _chunk_text(text, max_chunk_length)

                document_triplets = []
                for chunk_idx, chunk in enumerate(chunks):
                    if debug and chunk_idx > 0:
                        print(f"Processing chunk {chunk_idx+1}/{len(chunks)} for document {document}")
                    document_triplets.extend(extract_triplets_lsr(chunk))

                # SUBLABEL numbers the triplets within this document, starting at 1.
                for i, triplet in enumerate(document_triplets, 1):
                    all_triplets.append({
                        "DOCUMENT": document,
                        "SUBLABEL": i,
                        "MODEL": "lsr",
                        "HEAD": triplet.get('head', ''),
                        "RELATION": triplet.get('type', ''),
                        "TAIL": triplet.get('tail', ''),
                        "HEAD_TYPE": triplet.get('head_type', ''),
                        "TAIL_TYPE": triplet.get('tail_type', '')
                    })

                if debug and document_triplets:
                    print(f"Found {len(document_triplets)} triplets in document {document}")

            except Exception as e:
                print(f"Error processing document at index {index}: {e}")
                if debug:
                    traceback.print_exc()

        # Explicit columns keep the CSV header even when nothing was extracted.
        result_df = pd.DataFrame(all_triplets, columns=_OUTPUT_COLUMNS)
        result_df.to_csv(output_csv, index=False)
        print(f"Processing complete. Found {len(all_triplets)} triplets. Saved to {output_csv}")

    except Exception as e:
        print(f"Error in process_data: {e}")
        traceback.print_exc()

# Entry point: run LSR extraction over the default input file.
if __name__ == "__main__":
    input_path, output_path = 'JuanRana_split.csv', 'lsr_triplets_output.csv'
    # debug=True prints skipped documents, chunk progress and full tracebacks.
    process_data(input_path, output_path, debug=True)