In [1]:
# Cell 1: Install required packages
!pip install -q transformers torch torchvision torchaudio accelerate bitsandbytes
!pip install -q sentence-transformers
!pip install -q chromadb
!pip install -q fastapi uvicorn pyngrok
!pip install -q pandas numpy
!pip install -q python-dotenv airtable-python-wrapper
!pip install -q huggingface_hub


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.3/61.3 MB[0m [31m11.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.8/19.8 MB[0m [31m48.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m284.2/284.2 kB[0m [31m26.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.9/1.9 MB[0m [31m73.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m103.3/103.3 kB[0m [31m10.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.5/16.5 MB[0m [31m50.8 MB/s[0m eta [36

In [2]:
# Cell 2: Import libraries and setup
import os
import json
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Optional
import torch
from sentence_transformers import SentenceTransformer
import chromadb
from chromadb.config import Settings
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import uvicorn
from pyngrok import ngrok
import asyncio
from datetime import datetime
import re
from airtable import Airtable
import gc
from huggingface_hub import login

# Check if GPU is available
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB" if torch.cuda.is_available() else "No GPU")


Using device: cuda
GPU Memory: 14.7 GB


In [3]:
# Airtable configuration
from huggingface_hub import login
from google.colab import userdata
HF_TOKEN = userdata.get("HF_TOKEN")
os.environ["HUGGINGFACEHUB_API_TOKEN"] = HF_TOKEN
login(HF_TOKEN)
os.environ['AIRTABLE_API_KEY'] = userdata.get("AIRTABLE_API_KEY")
os.environ['AIRTABLE_BASE_ID'] = 'app8snf0kPAMzIJEa'
os.environ['AIRTABLE_TABLE_NAME'] = 'Traiff'

In [4]:
# Cell 4: Data Models
class QueryRequest(BaseModel):
    question: str
    context: Optional[Dict[str, Any]] = None
    top_k: int = 5

class QueryResponse(BaseModel):
    answer: str
    sources: List[Dict[str, Any]]
    confidence: float

class AirtableData(BaseModel):
    records: List[Dict[str, Any]]

In [5]:
# Cell 5: Llama Model Setup
class LlamaGenerator:
    def __init__(self):
        self.model_name = "meta-llama/Llama-3.1-8B-Instruct"
        self.tokenizer = None
        self.model = None
        self.pipeline = None
        self.max_length = 1024  # Adjust based on your needs

        print("Loading Llama 3.1 8B Instruct model...")
        self._load_model()

    def _load_model(self):
        try:
            # Load tokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True
            )

            # Add padding token if not present
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            # Load model with optimizations for GPU memory
            if torch.cuda.is_available():
                # Use 4-bit quantization to fit in GPU memory
                from transformers import BitsAndBytesConfig

                quantization_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4"
                )

                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    trust_remote_code=True,
                    torch_dtype=torch.float16
                )
            else:
                # CPU fallback
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.model_name,
                    trust_remote_code=True,
                    torch_dtype=torch.float32
                )

            # Create pipeline
            self.pipeline = pipeline(
                "text-generation",
                model=self.model,
                tokenizer=self.tokenizer,
                device_map="auto" if torch.cuda.is_available() else None,
                max_length=self.max_length,
                do_sample=True,
                temperature=0.1,
                top_p=0.9,
                pad_token_id=self.tokenizer.eos_token_id
            )

            print("✅ Llama 3.1 8B model loaded successfully!")

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("Falling back to a smaller model...")
            # Fallback to a smaller model if 8B doesn't fit
            try:
                self._load_fallback_model()
            except Exception as fallback_error:
                print(f"❌ Fallback also failed: {fallback_error}")
                raise

    def _load_fallback_model(self):
        """Fallback to a smaller model if 8B doesn't fit"""
        fallback_model = "microsoft/DialoGPT-medium"
        print(f"Loading fallback model: {fallback_model}")

        self.tokenizer = AutoTokenizer.from_pretrained(fallback_model)
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            fallback_model,
            device_map="auto" if torch.cuda.is_available() else None
        )

        self.pipeline = pipeline(
            "text-generation",
            model=self.model,
            tokenizer=self.tokenizer,
            max_length=512,
            do_sample=True,
            temperature=0.1
        )

    def generate_response(self, prompt: str) -> str:
        try:
            # Format prompt for Llama 3.1 Instruct
            formatted_prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a customs classification expert with deep knowledge of HTS codes, trade regulations, and tariff classifications. You analyze tariff data and provide clear, accurate explanations for classification decisions. Be specific and reference the provided examples when relevant.<|eot_id|><|start_header_id|>user<|end_header_id|>

{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

            # Generate response
            with torch.no_grad():
                response = self.pipeline(
                    formatted_prompt,
                    max_new_tokens=400,
                    num_return_sequences=1,
                    return_full_text=False,
                    clean_up_tokenization_spaces=True
                )

            # Extract generated text
            generated_text = response[0]['generated_text'].strip()

            # Clean up the response
            if generated_text.endswith('<|eot_id|>'):
                generated_text = generated_text[:-10].strip()

            return generated_text

        except Exception as e:
            print(f"Error generating response: {e}")
            return f"I apologize, but I encountered an error while processing your question: {str(e)}"

    def cleanup(self):
        """Free up GPU memory"""
        if self.model:
            del self.model
        if self.pipeline:
            del self.pipeline
        torch.cuda.empty_cache()
        gc.collect()

In [6]:
# Cell 7: Initialize RAG Service and Load Data
print("🚀 Initializing Tariff RAG Service with Llama 3.1...")

# Initialize ChromaDB (local storage in Colab) with the new recommended approach
from chromadb.config import Settings
import chromadb

class TariffRAGService:
    def __init__(self):
        # Initialize embedding model (runs on GPU if available)
        print("Loading embedding model...")
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)

        # Initialize ChromaDB with the new API
        self.chroma_client = chromadb.PersistentClient(path="/content/chroma_db")

        # Create or get collection for tariff data
        try:
            self.collection = self.chroma_client.get_collection("tariff_classification")
            print(f"✅ Loaded existing collection with {self.collection.count()} documents")
        except:
            self.collection = self.chroma_client.create_collection(
                name="tariff_classification",
                metadata={"description": "Tariff classification data with HTS codes and rationales"}
            )
            print("✅ Created new collection 'tariff_classification'")


        # Initialize Llama model
        print("Initializing Llama model...")
        # Re-initialize LlamaGenerator here to ensure correct token usage
        self.llama = LlamaGenerator()


        # Initialize Airtable connection
        if os.environ.get('AIRTABLE_API_KEY'):
            self.airtable = Airtable(
                os.environ['AIRTABLE_BASE_ID'],
                os.environ['AIRTABLE_TABLE_NAME'],
                api_key=os.environ['AIRTABLE_API_KEY']
            )
        else:
            self.airtable = None
            print("Airtable not configured - using sample data only")

        print(f"✅ Initialized Tariff RAG service.") # Count will be printed after loading data


    def create_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Create embeddings using sentence transformer"""
        embeddings = self.embedding_model.encode(texts, convert_to_tensor=True)
        return embeddings.cpu().numpy().tolist()

    def process_tariff_record(self, record: Dict[str, Any]) -> Dict[str, str]:
        """Convert tariff record to searchable text content"""
        fields = record.get('fields', record)  # Handle both Airtable format and direct fields

        # Create comprehensive text for embedding
        content_parts = []

        # Part information
        if fields.get('manufacture_part_number'):
            content_parts.append(f"Part Number: {fields['manufacture_part_number']}")

        if fields.get('original_part_description'):
            content_parts.append(f"Original Description: {fields['original_part_description']}")

        if fields.get('final_generated_part_description'):
            content_parts.append(f"Final Description: {fields['final_generated_part_description']}")

        # Classification information
        if fields.get('final_hts_code'):
            hts_code = fields['final_hts_code']
            content_parts.append(f"HTS Code: {hts_code}")

            # Extract chapter from HTS code
            chapter = hts_code.split('.')[0] if '.' in hts_code else hts_code[:2]
            content_parts.append(f"Chapter: {chapter}")

        if fields.get('rationale'):
            content_parts.append(f"Classification Rationale: {fields['rationale']}")

        if fields.get('chapters_used'):
            content_parts.append(f"Chapters Considered: {fields['chapters_used']}")

        # Technical specifications
        if fields.get('material_type'):
            content_parts.append(f"Material Type: {fields['material_type']}")

        if fields.get('rohs_compliance'):
            content_parts.append(f"RoHS Compliance: {fields['rohs_compliance']}")

        if fields.get('eccn'):
            content_parts.append(f"ECCN: {fields['eccn']}")

        # Supply chain information
        if fields.get('country_of_origin'):
            content_parts.append(f"Country of Origin: {fields['country_of_origin']}")

        if fields.get('distributor'):
            content_parts.append(f"Distributor: {fields['distributor']}")

        return {
            'content': ' | '.join(content_parts),
            'summary': f"{fields.get('manufacture_part_number', 'Unknown Part')} - {fields.get('final_hts_code', 'No HTS')} - {fields.get('original_part_description', 'No description')[:100]}"
        }

    def load_airtable_data(self):
        """Load data from Airtable"""
        if not self.airtable:
            print("Airtable not configured, using sample data")
            return self.get_sample_data()

        try:
            records = self.airtable.get_all()
            print(f"✅ Loaded {len(records)} records from Airtable")
            return records
        except Exception as e:
            print(f"⚠️ Error loading from Airtable: {e}")
            return self.get_sample_data()

    def get_sample_data(self):
        """Sample data based on your CSV structure"""
        return [
            {
                'id': 'sample_1',
                'fields': {
                    'manufacture_part_number': 'ABC-123-XYZ',
                    'original_part_description': 'High precision optical lens assembly for telescopic sights',
                    'final_hts_code': '9013.10.10',
                    'rationale': 'This item is classified under 9013.10.10 as telescopic sights for fitting to arms. The optical lens assembly is specifically designed for telescopic sights used on firearms, meeting the criteria for this HTS classification.',
                    'confidence': 0.95,
                    'chapters_used': '90, 85, 84',
                    'country_of_origin': 'Germany',
                    'rohs_compliance': 'Yes',
                    'eccn': 'EAR99',
                    'distributor': 'Optical Components Inc',
                    'material_type': 'Glass/Metal',
                    'lead_time': '6-8 weeks',
                    'coo_source': 'Manufacturer Certificate',
                    'final_generated_part_description': 'Precision optical lens assembly designed for telescopic sights, featuring multi-coated glass elements in metal housing',
                    'generated_at': '2024-01-15T10:30:00Z'
                }
            },
            {
                'id': 'sample_2',
                'fields': {
                    'manufacture_part_number': 'DEF-456-LAS',
                    'original_part_description': 'Industrial laser cutting system controller',
                    'final_hts_code': '9013.20.00',
                    'rationale': 'Classified as lasers under 9013.20.00. This industrial laser system is used for cutting applications and falls under the laser equipment category, excluding laser diodes.',
                    'confidence': 0.88,
                    'chapters_used': '90, 84, 85',
                    'country_of_origin': 'Japan',
                    'rohs_compliance': 'Yes',
                    'eccn': '6A005',
                    'distributor': 'Industrial Laser Systems',
                    'material_type': 'Electronic/Metal',
                    'lead_time': '10-12 weeks',
                    'coo_source': 'Certificate of Origin',
                    'final_generated_part_description': 'Industrial laser cutting system controller with precision beam control and safety systems',
                    'generated_at': '2024-01-15T11:45:00Z'
                }
            }
        ]

    def index_tariff_data(self, records: List[Dict[str, Any]]):
        """Index tariff classification data for semantic search"""
        texts = []
        metadatas = []
        ids = []

        for record in records:
            processed = self.process_tariff_record(record)

            texts.append(processed['content'])

            # Create comprehensive metadata
            fields = record.get('fields', record)
            metadata = {
                'id': record.get('id', f"record_{len(ids)}"),
                'part_number': fields.get('manufacture_part_number', ''),
                'hts_code': fields.get('final_hts_code', ''),
                'description': fields.get('original_part_description', ''),
                'rationale': fields.get('rationale', ''),
                'confidence': fields.get('confidence', 0.0),
                'country_origin': fields.get('country_of_origin', ''),
                'material_type': fields.get('material_type', ''),
                'chapter': self.extract_chapter(fields.get('final_hts_code', '')),
                'summary': processed['summary']
            }

            metadatas.append(metadata)
            ids.append(f"tariff_{record.get('id', len(ids))}")

        # Create embeddings and add to collection
        if texts:
            embeddings = self.create_embeddings(texts)

            # Clear existing data and add new - using add will update if ID exists or add if not
            # Consider clearing if you want a full refresh, but add is generally safer for updates
            # For simplicity and to match original intent of full refresh, let's clear and re-add
            try:
                 self.chroma_client.delete_collection("tariff_classification")
                 print("Cleared existing collection.")
            except Exception as e:
                 print(f"Could not clear collection (might not exist): {e}")

            # Re-create collection after deletion
            self.collection = self.chroma_client.create_collection(
                name="tariff_classification",
                metadata={"description": "Tariff classification data"}
            )
            print("Re-created collection 'tariff_classification'.")


            self.collection.add(
                embeddings=embeddings,
                documents=texts,
                metadatas=metadatas,
                ids=ids
            )

            # Persist the changes - not needed with PersistentClient
            # self.chroma_client.persist()

            print(f"✅ Indexed {len(texts)} tariff records successfully")

    def extract_chapter(self, hts_code: str) -> str:
        """Extract chapter number from HTS code"""
        if not hts_code:
            return ''

        # Extract first 2 digits (chapter)
        match = re.match(r'^(\d{2})', hts_code.replace('.', ''))
        return match.group(1) if match else ''

    def retrieve_relevant_records(self, query: str, top_k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve relevant tariff records using semantic search"""
        query_embedding = self.create_embeddings([query])[0]

        results = self.collection.query(
            query_embeddings=[query_embedding],
            n_results=min(top_k, self.collection.count()),
            include=["documents", "metadatas", "distances"]
        )

        relevant_records = []
        # Check if results and documents exist and are not empty
        if results and results.get('documents') and results['documents'][0]:
            for i in range(len(results['documents'][0])):
                relevant_records.append({
                    'content': results['documents'][0][i],
                    'metadata': results['metadatas'][0][i],
                    'similarity': 1 - results['distances'][0][i]  # Convert distance to similarity
                })

        return relevant_records

    async def generate_answer(self, question: str, context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Generate answer using Llama with tariff classification context"""

        # Retrieve relevant tariff records
        relevant_records = self.retrieve_relevant_records(question, top_k=3)

        # Build context for the prompt
        context_str = ""
        if context and context.get('value'):
            context_str += f"Current selected value: {context['value']}\n"
            context_str += f"Field: {context.get('fieldName', 'Unknown')}\n\n"

        # Build knowledge context from retrieved records
        knowledge_context = "Relevant tariff classification examples from your data:\n\n"
        if relevant_records:
            for i, record in enumerate(relevant_records, 1):
                metadata = record['metadata']
                knowledge_context += f"Example {i}:\n"
                knowledge_context += f"  Part: {metadata.get('part_number', 'N/A')}\n"
                knowledge_context += f"  HTS Code: {metadata.get('hts_code', 'N/A')}\n"
                knowledge_context += f"  Description: {metadata.get('description', 'N/A')}\n"
                knowledge_context += f"  Rationale: {metadata.get('rationale', 'N/A')}\n"
                knowledge_context += f"  Similarity: {record['similarity']:.3f}\n\n"
        else:
            knowledge_context += "No relevant examples found in the tariff data.\n\n"


        # Create the prompt
        prompt = f"""You are analyzing tariff classification data. Answer the following question based on the provided context and knowledge base.

{context_str}

{knowledge_context}

Question: {question}

Please provide a detailed answer based on the tariff classification examples above. When discussing HTS codes, explain:
1. Why specific classifications were chosen
2. What chapter the item belongs to and why
3. Key characteristics that determine the classification
4. Any alternative classifications that were considered

Be specific and reference the examples when relevant."""

        try:
            # Generate response using Llama
            answer = self.llama.generate_response(prompt)

            # Calculate confidence based on similarity scores
            avg_similarity = np.mean([record['similarity'] for record in relevant_records]) if relevant_records else 0.5

            return {
                'answer': answer,
                'sources': relevant_records,
                'confidence': float(avg_similarity)
            }

        except Exception as e:
            print(f"❌ Error generating answer: {e}")
            return {
                'answer': f"I apologize, but I encountered an error while processing your question: {str(e)}. However, I found {len(relevant_records)} relevant examples in your tariff data.",
                'sources': relevant_records,
                'confidence': 0.0
            }


# Re-run the LlamaGenerator initialization before creating the service
print("Initializing Llama model before RAG service...")
llama_generator = LlamaGenerator()

rag_service = TariffRAGService()
rag_service.llama = llama_generator # Assign the initialized LlamaGenerator instance

# Load and index your tariff data
print("📊 Loading tariff classification data...")
tariff_records = rag_service.load_airtable_data()
rag_service.index_tariff_data(tariff_records)

print(f"✅ RAG service ready with {rag_service.collection.count()} tariff records and Llama 3.1 8B!")

🚀 Initializing Tariff RAG Service with Llama 3.1...
Initializing Llama model before RAG service...
Loading Llama 3.1 8B Instruct model...


tokenizer_config.json:   0%|          | 0.00/55.4k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/296 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/855 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/1.17G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/184 [00:00<?, ?B/s]

Device set to use cuda:0


✅ Llama 3.1 8B model loaded successfully!
Loading embedding model...


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

✅ Created new collection 'tariff_classification'
Initializing Llama model...
Loading Llama 3.1 8B Instruct model...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use cuda:0


✅ Llama 3.1 8B model loaded successfully!
✅ Initialized Tariff RAG service.
📊 Loading tariff classification data...
✅ Loaded 5 records from Airtable
Cleared existing collection.
Re-created collection 'tariff_classification'.
✅ Indexed 5 tariff records successfully
✅ RAG service ready with 5 tariff records and Llama 3.1 8B!


In [7]:
# Cell 8: FastAPI Application
app = FastAPI(title="Tariff Classification RAG Service with Llama 3.1", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with your frontend URL
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    return {
        "message": "Tariff Classification RAG Service with Llama 3.1 is running",
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "records_count": rag_service.collection.count(),
        "device": device,
        "sample_questions": [
            "Why is this classified in Chapter 90?",
            "What are the key characteristics for this HTS classification?",
            "Are there similar parts with different classifications?",
            "What documentation is required for this classification?",
            "How confident are we in this classification?"
        ]
    }

@app.post("/ask", response_model=QueryResponse)
async def ask_question(request: QueryRequest):
    try:
        result = await rag_service.generate_answer(request.question, request.context)
        return QueryResponse(**result)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/refresh-data")
async def refresh_airtable_data():
    """Reload data from Airtable"""
    try:
        tariff_records = rag_service.load_airtable_data()
        rag_service.index_tariff_data(tariff_records)
        return {
            "message": f"Data refreshed successfully",
            "records_count": len(tariff_records)
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/search/{query}")
async def search_records(query: str, top_k: int = 5):
    """Search tariff records"""
    try:
        results = rag_service.retrieve_relevant_records(query, top_k)
        return {"query": query, "results": results}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/stats")
async def get_stats():
    """Get statistics about the knowledge base"""
    try:
        count = rag_service.collection.count()

        # Get some sample metadata for stats
        if count > 0:
            sample = rag_service.collection.get(limit=count, include=["metadatas"])

            chapters = [m.get('chapter', '') for m in sample['metadatas'] if m.get('chapter')]
            unique_chapters = list(set(chapters))

            hts_codes = [m.get('hts_code', '') for m in sample['metadatas'] if m.get('hts_code')]
            unique_hts_codes = len(set(hts_codes))

            return {
                "total_records": count,
                "unique_chapters": unique_chapters,
                "unique_hts_codes": unique_hts_codes,
                "model": "meta-llama/Llama-3.1-8B-Instruct",
                "device": device,
                "sample_parts": [m.get('part_number', 'N/A') for m in sample['metadatas'][:5]]
            }
        else:
            return {"total_records": 0, "message": "No data indexed"}

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/cleanup")
async def cleanup_gpu():
    """Free up GPU memory"""
    try:
        rag_service.llama.cleanup()
        torch.cuda.empty_cache()
        return {"message": "GPU memory cleaned up"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

# Cell 9: Start the server with ngrok
import nest_asyncio
nest_asyncio.apply()

# Kill any existing ngrok processes
!pkill -f ngrok

# Start ngrok tunnel
ngrok.set_auth_token(userdata.get("NGROCK"))
public_url = ngrok.connect(8000)
print("=" * 70)
print("🚀 Tariff Classification RAG Service with Llama 3.1 is available at:")
print(f"📡 {public_url}")
print(f"📚 Documentation: {public_url}/docs")
print(f"📊 Stats: {public_url}/stats")
print("=" * 70)
print("💡 Features:")
print("  - Llama 3.1 8B Instruct (100% free)")
print("  - GPU-accelerated embeddings")
print("  - Your Airtable tariff data")
print("  - Semantic search with ChromaDB")
print("=" * 70)
print("📋 COPY THIS URL TO YOUR LOCAL BACKEND:")
print(f"🔗 {public_url}")
print("=" * 70)

# Run the server
uvicorn.run(app, host="0.0.0.0", port=8000)

🚀 Tariff Classification RAG Service with Llama 3.1 is available at:
📡 NgrokTunnel: "https://e15801c4083b.ngrok-free.app" -> "http://localhost:8000"
📚 Documentation: NgrokTunnel: "https://e15801c4083b.ngrok-free.app" -> "http://localhost:8000"/docs
📊 Stats: NgrokTunnel: "https://e15801c4083b.ngrok-free.app" -> "http://localhost:8000"/stats
💡 Features:
  - Llama 3.1 8B Instruct (100% free)
  - GPU-accelerated embeddings
  - Your Airtable tariff data
  - Semantic search with ChromaDB
📋 COPY THIS URL TO YOUR LOCAL BACKEND:
🔗 NgrokTunnel: "https://e15801c4083b.ngrok-free.app" -> "http://localhost:8000"


INFO:     Started server process [440]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8000 (Press CTRL+C to quit)


INFO:     155.33.132.13:0 - "OPTIONS /ask HTTP/1.1" 200 OK
INFO:     155.33.132.13:0 - "POST /ask HTTP/1.1" 200 OK


INFO:     Shutting down
INFO:     Waiting for application shutdown.
INFO:     Application shutdown complete.
INFO:     Finished server process [440]
