<a href="https://colab.research.google.com/github/joepareti54/joepareti54/blob/main/lm_rag_gpt2_test5LARGE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Install necessary libraries
!pip install pymupdf sentence-transformers faiss-gpu transformers
import fitz
import os
import numpy as np
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AutoTokenizer, AutoModelForCausalLM
from google.colab import drive
import warnings
import re
import json
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime
from tqdm import tqdm
import gc
import pickle

class EnhancedFinanceNewsProcessor:
    def __init__(self, directory_path: str):
        self.directory_path = directory_path
        self.documents: List[str] = []
        self.document_metadata: List[Dict[str, Any]] = []
        self.embed_model: Optional[SentenceTransformer] = None
        self.tokenizer: Optional[AutoTokenizer] = None
        self.model: Optional[AutoModelForCausalLM] = None
        self.device: Optional[torch.device] = None
        self.index: Optional[faiss.IndexFlatL2] = None

        # Configuration
        self.chunk_size = 500  # Document processing chunk size
        self.embedding_batch_size = 100  # Embedding creation batch size
        self.max_context_tokens = 300  # Maximum context tokens for generation
        self.cache_dir = 'finance_news_cache'
        self.create_cache_dir()

    def create_cache_dir(self) -> None:
        """Create cache directory if it doesn't exist."""
        if not os.path.exists(self.cache_dir):
            os.makedirs(self.cache_dir)

    def get_cache_paths(self) -> Tuple[str, str, str]:
        """Get paths for cache files."""
        embeddings_cache = os.path.join(self.cache_dir, 'embeddings.faiss')
        documents_cache = os.path.join(self.cache_dir, 'documents.pkl')
        metadata_cache = os.path.join(self.cache_dir, 'metadata.json')
        return embeddings_cache, documents_cache, metadata_cache

    def init_models(self) -> None:
        """Initialize NLP models with improved error handling."""
        print("\nInitializing models...")
        try:
            # Initialize embedding model
            self.embed_model = SentenceTransformer('all-MiniLM-L6-v2')

            # Initialize language model and tokenizer
            model_name = 'gpt2'  # Can be changed to other models
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(model_name)

            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Set up device
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
            self.model.to(self.device)
            print(f"Using device: {self.device}")

        except Exception as e:
            raise RuntimeError(f"Error initializing models: {str(e)}")

    def clean_text(self, text: str) -> str:
        """Enhanced text cleaning with better preservation of important content."""
        # Remove special characters while preserving essential punctuation
        text = re.sub(r'[^\w\s.,!?;:()\-\'\"$%]+', ' ', text)

        # Fix common spacing issues
        text = re.sub(r'\s+', ' ', text)
        text = re.sub(r'\s([.,!?])', r'\1', text)

        # Fix common OCR issues
        text = text.replace('|', 'I').replace('1', 'l')

        # Remove repeated punctuation
        text = re.sub(r'([.,!?])\1+', r'\1', text)

        return text.strip()

    def extract_metadata(self, filename: str, text: str) -> Dict[str, Any]:
        """Enhanced metadata extraction with more features."""
        # Extract date from filename or content
        date_patterns = [
            r'\d{4}-\d{2}-\d{2}',
            r'\d{2}/\d{2}/\d{4}',
            r'(?:January|February|March|April|May|June|July|August|September|October|November|December)\s+\d{1,2},\s+\d{4}'
        ]

        date = None
        for pattern in date_patterns:
            date_match = re.search(pattern, filename) or re.search(pattern, text[:200])
            if date_match:
                date = date_match.group(0)
                break

        # Extract title
        first_lines = text.split('\n')[:3]
        title = next((line for line in first_lines if len(line.split()) > 3), filename)

        # Extract keywords
        keywords = set()
        important_words = re.findall(r'\b[A-Z][a-z]{2,}\b', text[:500])
        keywords.update(important_words[:10])

        return {
            'filename': filename,
            'date': date,
            'title': title[:100],
            'length': len(text),
            'keywords': list(keywords),
            'processed_date': datetime.now().isoformat()
        }

    def extract_text_from_pdf(self, pdf_path: str) -> Optional[str]:
        """Enhanced PDF text extraction with better handling of complex PDFs."""
        try:
            doc = fitz.open(pdf_path)
            text_parts = []

            for page in doc:
                # Get text and remove headers/footers
                text = page.get_text()
                lines = text.split('\n')

                # Remove headers and footers (usually first and last lines)
                if len(lines) > 2:
                    lines = lines[1:-1]

                cleaned_text = '\n'.join(lines)
                if cleaned_text.strip():
                    text_parts.append(cleaned_text)

            doc.close()
            full_text = ' '.join(text_parts)
            return self.clean_text(full_text)

        except Exception as e:
            print(f"Error extracting text from {pdf_path}: {str(e)}")
            return None

    def load_documents_in_chunks(self, total_limit: int = 2000) -> bool:
        """Load documents in chunks with progress tracking."""
        print("\nLoading documents in chunks...")

        files = [f for f in os.listdir(self.directory_path) if f.lower().endswith('.pdf')]
        total_loaded = 0
        chunk_start = 0

        with tqdm(total=min(total_limit, len(files)), desc="Loading documents") as pbar:
            while chunk_start < len(files) and total_loaded < total_limit:
                chunk_end = min(chunk_start + self.chunk_size, len(files))
                chunk_files = files[chunk_start:chunk_end]

                for filename in chunk_files:
                    if total_loaded >= total_limit:
                        break

                    try:
                        pdf_path = os.path.join(self.directory_path, filename)
                        text = self.extract_text_from_pdf(pdf_path)

                        if text and len(text.split()) > 50:  # Minimum word count threshold
                            self.documents.append(text)
                            metadata = self.extract_metadata(filename, text)
                            self.document_metadata.append(metadata)
                            total_loaded += 1
                            pbar.update(1)

                            # Periodic garbage collection
                            if total_loaded % 500 == 0:
                                gc.collect()

                    except Exception as e:
                        print(f"\nError processing {filename}: {str(e)}")

                chunk_start = chunk_end

        print(f"\nTotal documents loaded: {total_loaded}")
        return total_loaded > 0

    def create_embeddings_in_chunks(self) -> None:
        """Create embeddings in chunks with progress tracking and memory management."""
        print("\nCreating embeddings in chunks...")

        all_embeddings = []
        total_chunks = (len(self.documents) + self.embedding_batch_size - 1) // self.embedding_batch_size

        with tqdm(total=total_chunks, desc="Creating embeddings") as pbar:
            for i in range(0, len(self.documents), self.embedding_batch_size):
                chunk = self.documents[i:i+self.embedding_batch_size]

                # Create embeddings for chunk
                chunk_embeddings = self.embed_model.encode(chunk, show_progress_bar=False)
                all_embeddings.append(chunk_embeddings)

                pbar.update(1)

                # Periodic garbage collection
                if i % (self.embedding_batch_size * 5) == 0:
                    gc.collect()

        # Combine all embeddings
        combined_embeddings = np.vstack(all_embeddings)

        # Create FAISS index with improved parameters
        dimension = combined_embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(combined_embeddings.astype('float32'))

        print("\nEmbeddings created and indexed successfully")

    def save_cache(self) -> None:
        """Save all data to cache with error handling."""
        try:
            embeddings_cache, documents_cache, metadata_cache = self.get_cache_paths()

            # Save FAISS index
            if self.index is not None:
                faiss.write_index(self.index, embeddings_cache)

            # Save documents
            with open(documents_cache, 'wb') as f:
                pickle.dump(self.documents, f)

            # Save metadata
            with open(metadata_cache, 'w', encoding='utf-8') as f:
                json.dump(self.document_metadata, f, ensure_ascii=False, indent=2)

            print("\nCache saved successfully")

        except Exception as e:
            print(f"Error saving cache: {str(e)}")

    def load_cache(self) -> bool:
        """Load cached data with validation."""
        try:
            embeddings_cache, documents_cache, metadata_cache = self.get_cache_paths()

            # Check if all cache files exist
            if not all(os.path.exists(f) for f in [embeddings_cache, documents_cache, metadata_cache]):
                return False

            # Load FAISS index
            self.index = faiss.read_index(embeddings_cache)

            # Load documents
            with open(documents_cache, 'rb') as f:
                self.documents = pickle.load(f)

            # Load metadata
            with open(metadata_cache, 'r', encoding='utf-8') as f:
                self.document_metadata = json.load(f)

            # Validate loaded data
            if not (len(self.documents) == len(self.document_metadata) and
                   self.index.ntotal == len(self.documents)):
                raise ValueError("Cache validation failed: inconsistent data lengths")

            print(f"\nSuccessfully loaded {len(self.documents)} documents from cache")
            return True

        except Exception as e:
            print(f"Error loading cache: {str(e)}")
            return False

    def retrieve_and_generate(self, query: str, k: int = 5) -> str:
        """Enhanced retrieval and response generation."""
        print(f"\nProcessing query: {query}")

        # Generate query embedding and retrieve similar documents
        query_embedding = self.embed_model.encode([query])[0]
        distances, indices = self.index.search(
            np.array([query_embedding]).astype('float32'),
            min(k, len(self.documents))
        )

        # Build context from retrieved documents
        query_terms = set(query.lower().split())
        retrieved_texts = []
        total_tokens = 0

        print("\nRetrieved relevant documents:")
        for i, idx in enumerate(indices[0]):
            if distances[0][i] > 1.5:  # Relevance threshold
                continue

            metadata = self.document_metadata[idx]
            print(f"{i+1}. Score: {distances[0][i]:.4f}")
            print(f"Title: {metadata['title']}")
            print(f"Date: {metadata.get('date', 'Unknown')}\n")

            # Extract relevant context
            text = self.documents[idx]
            sentences = re.split(r'[.!?]+', text)
            relevant_sentences = []

            for sentence in sentences:
                sentence = sentence.strip()
                if not sentence:
                    continue

                # Check relevance using both query terms and semantic similarity
                term_match = any(term in sentence.lower() for term in query_terms)

                if term_match:
                    sentence_tokens = self.tokenizer.encode(sentence)
                    if total_tokens + len(sentence_tokens) <= self.max_context_tokens:
                        relevant_sentences.append(sentence)
                        total_tokens += len(sentence_tokens)

            if relevant_sentences:
                retrieved_texts.append(' '.join(relevant_sentences))

        if not retrieved_texts:
            return "No relevant information found for this query."

        context = " ".join(retrieved_texts)

        # Generate response with improved prompt
        prompt = (
            f"Based on recent financial news articles, provide a detailed analysis about {query}. "
            f"Focus on key developments, their significance, and potential implications.\n\n"
            f"Source information:\n{context}\n\n"
            "Analysis:"
        )

        try:
            input_ids = self.tokenizer.encode(
                prompt,
                truncation=True,
                max_length=512,
                padding=False,
                return_tensors='pt'
            ).to(self.device)

            attention_mask = torch.ones_like(input_ids)

            outputs = self.model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_new_tokens=200,
                num_beams=4,
                no_repeat_ngram_size=3,
                pad_token_id=self.tokenizer.eos_token_id,
                early_stopping=True,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                length_penalty=1.0,
                repetition_penalty=1.2
            )

            generated_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Extract and clean response
            if "Analysis:" in generated_text:
                response = generated_text.split("Analysis:")[-1].strip()
            else:
                response = generated_text.strip()

            # Clean up response
            response = re.sub(r'\s+', ' ', response).strip()
            if not response.endswith(('.', '!', '?')):
                response += '.'

            return response

        except Exception as e:
            print(f"Error during generation: {str(e)}")
            return f"Error generating response: {str(e)}"

def main():
    """Main execution function with improved error handling and user interaction."""
    try:
        # Mount Google Drive
        drive.mount('/content/drive')

        # Initialize processor
        directory_path = '/content/drive/My Drive/All_Finance_PDF_files_old/'
        processor = EnhancedFinanceNewsProcessor(directory_path)

        # Try loading from cache first
        cache_loaded = processor.load_cache()

        if not cache_loaded:
            print("\nCache not found or invalid. Processing documents...")
            processor.init_models()

            if not processor.load_documents_in_chunks(total_limit=2000):
                print("Failed to load documents. Exiting.")
                return

            processor.create_embeddings_in_chunks()
            processor.save_cache()

        # Interactive query loop with improved user interaction
        print("\nFinance News Analysis System")
        print("Enter your queries (type 'quit' to exit)")
        print("Type 'help' for usage instructions")

        while True:
            try:
                query = input("\nQuery: ").strip()

                if query.lower() == 'quit':
                    print("\nExiting system. Thank you!")
                    break

                if query.lower() == 'help':
                    print("\nUsage Instructions:")
                    print("- Enter your query about financial news and press Enter")
                    print("- The system will analyze relevant documents and provide a response")
                    print("- Type 'quit' to exit")
                    print("- Type 'help' to see these instructions again")
                    continue

                if not query:
                    print("Please enter a valid query")
                    continue

                response = processor.retrieve_and_generate(query)
                print(f"\nResponse:\n{response}")

            except KeyboardInterrupt:
                print("\nExiting...")
                break
            except Exception as e:
                print(f"Error processing query: {str(e)}")
                print("Please try another query or type 'quit' to exit")

    except Exception as e:
        print(f"Fatal error: {str(e)}")

if __name__ == "__main__":
    warnings.filterwarnings('ignore')
    main()

Collecting pymupdf
  Downloading pymupdf-1.25.1-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (3.4 kB)
Collecting faiss-gpu
  Downloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.4 kB)
Downloading pymupdf-1.25.1-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (20.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m22.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading faiss_gpu-1.7.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (85.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.5/85.5 MB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-gpu, pymupdf
Successfully installed faiss-gpu-1.7.2 pymupdf-1.25.1
Mounted at /content/drive

Cache not found or invalid. Processing documents...

Initializing models...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Using device: cuda

Loading documents in chunks...


Loading documents: 100%|██████████| 2000/2000 [04:21<00:00,  7.64it/s]



Total documents loaded: 2000

Creating embeddings in chunks...


Creating embeddings: 100%|██████████| 20/20 [00:21<00:00,  1.06s/it]



Embeddings created and indexed successfully

Cache saved successfully

Finance News Analysis System
Enter your queries (type 'quit' to exit)
Type 'help' for usage instructions

Query: what is the impact of real estate business on the chinese economy

Processing query: what is the impact of real estate business on the chinese economy

Retrieved relevant documents:
1. Score: 0.6180
Title: non-personal use or to order multiple copies, please contact Dow Jones Reprints at l-800-843-0008 or
Date: None

2. Score: 0.6928
Title: non-personal use or to order multiple copies, please contact Dow Jones Reprints at l-800-843-0008 or
Date: None

3. Score: 0.7223
Title: https: www.djreprints.com. https: www.wsj.com articles beyond-evergrande-chinas-property-market-face
Date: None

4. Score: 0.8488
Title: https: www.djreprints.com. https: www.wsj.com articles chinas-economy-takes-a-deeper-hit-as-retail-s
Date: None

5. Score: 0.8509
Title: https: www.djreprints.com. https: www.wsj.com articles evergr