<a href="https://colab.research.google.com/github/navneetkrc/Deep_learning_experiments/blob/master/GLINER_apps.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required packages
# !pip install gliner pycountry scipy==1.12 gradio==4.31.5 spaces

In [None]:


import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER

# Model configuration
_MODEL = {}  # loaded GLiNER models, keyed by model name (filled by get_model)
_CACHE_DIR = os.environ.get("CACHE_DIR")  # None when the variable is unset
THRESHOLD = 0.3
LABELS = [
    "product",
    "product type",
    "price",
    "memory",
    "feature",
]
# Default demo query: a list of questions joined into one string with ", ".
QUERY = ", ".join(
    [
        "What is the difference between the Samsung Galaxy S23 and S23 Ultra?",
        "What does 'Dynamic AMOLED 2X' mean in a Samsung display?",
        "How do I use the S Pen on my Samsung Galaxy Note or Ultra phone?",
        "What is 'One UI' on Samsung phones?",
        "What is 'Samsung Knox' and why is it important?",
        "What is the difference between Samsung QLED and Neo QLED TVs?",
        "What is 'SmartThings' in Samsung appliances?",
        "What is 'Bespoke' in Samsung refrigerators?",
        "What does 'EcoBubble' mean in Samsung washing machines?",
        "What is the Samsung Galaxy Watch and what are its features?",
        "What are Samsung's different memory and storage devices?",
        "What is the difference between a Samsung soundbar and a home theater system?",
    ]
)
# Model choices offered in the UI.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
    "urchade/gliner_multi-v2.1",
    "urchade/gliner_large-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")

def get_model(model_name: str = None):
    """Return the GLiNER model for *model_name*, loading it on first use.

    Loaded models are kept in the module-level ``_MODEL`` dict so later calls
    reuse them. When ``model_name`` is ``None`` the base model
    ``"urchade/gliner_base"`` is used. If CUDA is available and the model's
    parameters are not already on a CUDA device, the model is moved to
    ``"cuda"``. The elapsed time is printed before returning.
    """
    started_at = datetime.now()
    name = "urchade/gliner_base" if model_name is None else model_name

    # Load once per model name; later calls hit the cache.
    if _MODEL.get(name) is None:
        _MODEL[name] = GLiNER.from_pretrained(name, cache_dir=_CACHE_DIR)

    # Only inspect the parameters' device when a GPU is actually available.
    if torch.cuda.is_available():
        device_type = next(_MODEL[name].parameters()).device.type
        if not device_type.startswith("cuda"):
            _MODEL[name] = _MODEL[name].to("cuda")

    print(f"{datetime.now()} :: get_model :: {datetime.now() - started_at}")

    return _MODEL[name]

def get_country(country_name: str):
    """Fuzzy-match *country_name* using pycountry.

    Returns the result of ``pycountry.countries.search_fuzzy`` or ``None``
    when that call raises ``LookupError``.
    """
    try:
        matches = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        return None
    return matches

# Removed @spaces.GPU decorator since it's not needed in Colab
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Predict entities in *query* with the GLiNER model *model_name*.

    Args:
        model_name: Model identifier handed to ``get_model``.
        query: Text to search for entities.
        labels: Entity labels, either a list or one comma-separated string.
            For a string, surrounding whitespace is stripped from each label
            and blank labels (e.g. from a trailing comma) are dropped.
        threshold: Passed through to ``model.predict_entities``.
        nested_ner: Passed to the model inverted, as ``flat_ner=not nested_ner``.

    Returns:
        The list returned by ``model.predict_entities``, or an empty list when
        no usable label was supplied.
    """
    start = datetime.now()
    model = get_model(model_name)

    if isinstance(labels, str):
        # "a, b," or an empty textbox must not produce "" labels.
        labels = [part.strip() for part in labels.split(",") if part.strip()]

    if labels:
        entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
    else:
        # Nothing to look for: skip the model call instead of sending no labels.
        entities = []

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")

    return entities

def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Predict entities for *query* and normalize any "country" entities.

    Entities labelled ``"country"`` are looked up with ``get_country``; on a
    match the entity gains a ``"normalized"`` list of dicts, and without a
    match the entity is left out of the result. All other entities pass
    through unchanged. The payload is printed as JSON and returned as
    ``{"query": ..., "entities": [...]}``.
    """
    predicted = predict_entities(
        model_name=model_name,
        query=query,
        labels=labels,
        threshold=threshold,
        nested_ner=nested_ner,
    )

    kept = []
    for item in predicted:
        if item["label"] != "country":
            kept.append(item)
            continue

        matches = get_country(item["text"])
        if not matches:
            # Country text that pycountry cannot resolve is discarded.
            continue
        item["normalized"] = [dict(match) for match in matches]
        kept.append(item)

    payload = {"query": query, "entities": kept}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")

    return payload

def annotate_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Parse *query* and reshape the entities for the highlighted-text output.

    Returns a dict with the original text under ``"text"`` and, under
    ``"entities"``, one dict per parsed entity holding ``entity`` (the label),
    ``word`` (the matched text), ``start``, ``end`` and ``score``.
    """
    parsed = parse_query(query, labels, threshold, nested_ner, model_name)

    highlighted = []
    for item in parsed["entities"]:
        highlighted.append(
            {
                "entity": item["label"],
                "word": item["text"],
                "start": item["start"],
                "end": item["end"],
                "score": item["score"],
            }
        )

    return {"text": query, "entities": highlighted}

# Function to create and launch the Gradio interface
def create_gliner_interface():
    """Build the GLiNER demo UI, launch it, and return the ``gr.Blocks`` app.

    A single prediction is run with the base model first so that model is
    loaded before the UI is served. The app is launched with ``debug=True``
    and ``share=True``.
    """
    print("Initializing models...")
    # Warm-up call: only the base model is loaded up front to save time.
    predict_entities("urchade/gliner_base", QUERY, LABELS, threshold=THRESHOLD)

    with gr.Blocks(title="GLiNER-query-parser") as demo:
        gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)
        This demonstrates the GLiNER model's ability to predict entities in a given text query. Given a set of entities to track, the model can then identify instances of these entities in the query. The parsed entities are then displayed in the output. A special case is the "country" entity, which is normalized to the ISO 3166-1 alpha-2 code using the pycountry library. This GLiNER mode is licensed under the Apache 2.0 license.
        ## Links
        * Model: https://huggingface.co/urchade/gliner_medium-v2.1, https://huggingface.co/urchade/gliner_base
        * All GLiNER models: https://huggingface.co/models?library=gliner
        * Paper: https://arxiv.org/abs/2311.08526
        * Repository: https://github.com/urchade/GLiNER
        """
        )

        query = gr.Textbox(value=QUERY, label="query", placeholder="Enter your query here")

        with gr.Row():
            model_name = gr.Radio(choices=MODELS, value="urchade/gliner_base", label="Model")
            entities = gr.Textbox(
                value=", ".join(LABELS),
                label="entities",
                placeholder="Enter the entities to detect here (comma separated)",
                scale=2,
            )
            threshold = gr.Slider(
                minimum=0,
                maximum=1,
                value=THRESHOLD,
                step=0.01,
                label="Threshold",
                info="Lower threshold may extract more false-positive entities from the query.",
                scale=1,
            )
            is_nested = gr.Checkbox(
                value=False,
                label="Nested NER",
                info="Setting to True extracts nested entities",
                scale=0,
            )

        output = gr.HighlightedText(label="Annotated entities")
        submit_btn = gr.Button("Submit")

        # Every trigger re-runs annotate_query with the same inputs and output;
        # the registration order matches the individual calls it replaces.
        triggers = (
            query.submit,
            entities.submit,
            threshold.release,
            submit_btn.click,
            is_nested.change,
            model_name.change,
        )
        for trigger in triggers:
            trigger(
                fn=annotate_query,
                inputs=[query, entities, threshold, is_nested, model_name],
                outputs=output,
            )

        # Configure for Colab environment
        # Using share=True to generate a public URL
        demo.queue(default_concurrency_limit=5)
        demo.launch(debug=True, share=True)

    return demo

# Entry point. In a notebook cell __name__ is "__main__", so running this cell
# builds and launches the interface; the call blocks while the server runs
# because launch() is invoked with debug=True.
if __name__ == "__main__":
    create_gliner_interface()

Cache directory: None
Initializing models...


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


2025-02-26 21:14:12.552231 :: get_model :: 0:00:08.268311
2025-02-26 21:14:13.743901 :: predict_entities :: 0:00:09.459982
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
IMPORTANT: You are using gradio version 4.31.5, however version 4.44.1 is available, please upgrade.
--------
Running on public URL: https://c2c94330f8da742cdd.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/pydantic/type_adapter.py", line 271, in _init_core_attrs
    self.core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__')
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/pydantic/type_adapter.py", line 55, in _getattr_no_parents
    raise AttributeError(attribute)
AttributeError: __pydantic_core_schema__

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://c2c94330f8da742cdd.gradio.live


In [None]:
# GLiNER Query Parser - Google Colab Implementation

# Cell 1: Install Required Libraries
# !pip install gliner pycountry scipy==1.12 gradio==4.31.5 spaces





In [None]:
# Cell 2: Import Libraries and Define Functions
import os
import json
import gradio as gr
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER


In [None]:
# Cell 3: Model Configuration
_MODEL = {}  # loaded GLiNER models, keyed by model name (filled by get_model)
_CACHE_DIR = os.environ.get("CACHE_DIR")  # None when the variable is unset
THRESHOLD = 0.3
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
QUERY = (
    "gdp, co2 emissions, and mortality rate of the philippines "
    "vs. south asia in 2024"
)
# Model choices offered in the UI.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
    "urchade/gliner_multi-v2.1",
    "urchade/gliner_large-v2.1",
]

print(f"Cache directory: {_CACHE_DIR}")





Cache directory: None


In [None]:
# Cell 4: Define Helper Functions
def get_model(model_name: str = None):
    """Fetch (and lazily load) the GLiNER model called *model_name*.

    ``None`` selects ``"urchade/gliner_base"``. Models are cached in the
    module-level ``_MODEL`` dict. When CUDA is available and the model's
    parameters are not on a CUDA device yet, the model is moved to ``"cuda"``
    and the cache entry is updated. Prints how long the call took.
    """
    t0 = datetime.now()

    if model_name is None:
        model_name = "urchade/gliner_base"

    model = _MODEL.get(model_name)
    if model is None:
        model = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)
        _MODEL[model_name] = model

    # The device check is only evaluated when CUDA is available.
    if torch.cuda.is_available() and not next(model.parameters()).device.type.startswith("cuda"):
        model = model.to("cuda")
        _MODEL[model_name] = model

    print(f"{datetime.now()} :: get_model :: {datetime.now() - t0}")

    return model

def get_country(country_name: str):
    """Look up *country_name* with pycountry's fuzzy search.

    Returns the search result, or ``None`` if the search raises
    ``LookupError``.
    """
    try:
        found = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        found = None
    return found



In [None]:
# Cell 5: Define Prediction Functions
def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Predict entities in *query* with the GLiNER model *model_name*.

    Args:
        model_name: Model identifier handed to ``get_model``.
        query: Text to search for entities.
        labels: Entity labels, either a list or one comma-separated string.
            For a string, surrounding whitespace is stripped from each label
            and blank labels (e.g. from a trailing comma) are dropped.
        threshold: Passed through to ``model.predict_entities``.
        nested_ner: Passed to the model inverted, as ``flat_ner=not nested_ner``.

    Returns:
        The list returned by ``model.predict_entities``, or an empty list when
        no usable label was supplied.
    """
    start = datetime.now()
    model = get_model(model_name)

    if isinstance(labels, str):
        # "a, b," or an empty textbox must not produce "" labels.
        stripped = (part.strip() for part in labels.split(","))
        labels = [label for label in stripped if label]

    # With no labels there is nothing to look for, so skip the model call.
    entities = (
        model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)
        if labels
        else []
    )

    print(f"{datetime.now()} :: predict_entities :: {datetime.now() - start}")

    return entities

def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Run entity prediction on *query* and post-process "country" entities.

    A ``"country"`` entity is kept only when ``get_country`` finds a match for
    its text; in that case the matches are attached as a list of dicts under
    ``"normalized"``. Every other entity is kept as predicted. The resulting
    ``{"query": ..., "entities": [...]}`` payload is printed as JSON and
    returned.
    """
    predicted = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)

    kept = []
    for candidate in predicted:
        if candidate["label"] == "country":
            found = get_country(candidate["text"])
            if not found:
                # Unresolvable country text is dropped from the result.
                continue
            candidate["normalized"] = [dict(c) for c in found]
        kept.append(candidate)

    payload = {"query": query, "entities": kept}
    print(f"{datetime.now()} :: parse_query :: {json.dumps(payload)}\n")

    return payload

def annotate_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Parse *query* and return its entities in highlighted-text form.

    The result has the query under ``"text"`` and a list under ``"entities"``
    whose items rename each parsed entity's ``label`` to ``entity`` and
    ``text`` to ``word``, and copy ``start``, ``end`` and ``score``.
    """
    # (output key, key in the parsed entity), in output order.
    field_map = (
        ("entity", "label"),
        ("word", "text"),
        ("start", "start"),
        ("end", "end"),
        ("score", "score"),
    )
    payload = parse_query(query, labels, threshold, nested_ner, model_name)

    annotated = [
        {out_key: entity[src_key] for out_key, src_key in field_map}
        for entity in payload["entities"]
    ]
    return {"text": query, "entities": annotated}




In [None]:
# Cell 6: Initialize Model and Create Gradio Interface
# Warm-up: run one prediction with the base model only, so that model is
# loaded before the interface is created.
print("Initializing base model...")
predict_entities(
    model_name="urchade/gliner_base",
    query=QUERY,
    labels=LABELS,
    threshold=THRESHOLD,
)

# Create the Gradio interface
def create_interface():
    """Build the GLiNER demo UI, launch it, and return ``demo.launch``'s result.

    The app is queued with a default concurrency limit of 5 and launched with
    ``debug=True`` and ``share=True``.
    """
    with gr.Blocks(title="GLiNER-query-parser") as demo:
        gr.Markdown(
        """
        # GLiNER-based Query Parser (a zero-shot NER model)
        This demonstrates the GLiNER model's ability to predict entities in a given text query.
        The model identifies instances of specified entities in the query, and the parsed entities are displayed in the output.
        """
        )

        query = gr.Textbox(value=QUERY, label="Query", placeholder="Enter your query here")

        with gr.Row():
            model_name = gr.Radio(choices=MODELS, value="urchade/gliner_base", label="Model")
            entities = gr.Textbox(
                value=", ".join(LABELS),
                label="Entities",
                placeholder="Enter the entities to detect here (comma separated)",
                scale=2,
            )
            threshold = gr.Slider(
                minimum=0,
                maximum=1,
                value=THRESHOLD,
                step=0.01,
                label="Threshold",
                info="Lower threshold may extract more false-positive entities from the query.",
                scale=1,
            )
            is_nested = gr.Checkbox(
                value=False,
                label="Nested NER",
                info="Setting to True extracts nested entities",
                scale=0,
            )

        output = gr.HighlightedText(label="Annotated entities")
        submit_btn = gr.Button("Submit")

        # Each of these events re-runs annotate_query on the current inputs;
        # they are registered in the same order as before.
        controls = (query, entities, threshold, is_nested, model_name)
        for register in (
            query.submit,
            entities.submit,
            threshold.release,
            submit_btn.click,
            is_nested.change,
            model_name.change,
        ):
            register(fn=annotate_query, inputs=list(controls), outputs=output)

        # Enable public URL sharing for Colab
        demo.queue(default_concurrency_limit=5)
        return demo.launch(debug=True, share=True)

Initializing base model...


Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


2025-02-26 21:16:17.141558 :: get_model :: 0:00:07.807576
2025-02-26 21:16:17.745978 :: predict_entities :: 0:00:08.412004


In [None]:
# Cell 7: Launch the App
# This will output a public URL that you can access from anywhere
# (create_interface launches with share=True). Because create_interface returns
# the result of demo.launch(...), `interface` holds launch()'s return value,
# not the gr.Blocks object.
interface = create_interface()

Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
IMPORTANT: You are using gradio version 4.31.5, however version 4.44.1 is available, please upgrade.
--------
Running on public URL: https://67f370a7541491a9c1.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


ERROR:    Exception in ASGI application
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/pydantic/type_adapter.py", line 271, in _init_core_attrs
    self.core_schema = _getattr_no_parents(self._type, '__pydantic_core_schema__')
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/pydantic/type_adapter.py", line 55, in _getattr_no_parents
    raise AttributeError(attribute)
AttributeError: __pydantic_core_schema__

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/protocols/http/h11_impl.py", line 403, in run_asgi
    result = await app(  # type: ignore[func-returns-value]
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/uvicorn/middleware/proxy_headers.py", line 60, in __call__
    return await self.app(scope

Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://67f370a7541491a9c1.gradio.live


In [None]:
# Cell 8: Test API Usage (Optional)
# You can uncomment this to test the API
'''
from gradio_client import Client

# Use the public URL that was generated above
client = Client("YOUR_PUBLIC_URL_HERE")
result = client.predict(
    query="gdp, m3, and child mortality of india and southeast asia 2024",
    labels="country, year, statistical indicator, region",
    threshold=0.3,
    nested_ner=False,
    api_name="/annotate_query"
)
print(result)
'''

## Streamlit app

In [None]:
# # Install required packages
# !pip install gliner pycountry scipy==1.12 streamlit torch

# import os
# import json
# import streamlit as st
# import pycountry
# import torch
# from datetime import datetime
# from typing import Dict, Union
# from gliner import GLiNER

# # Model configuration
# _MODEL = {}
# _CACHE_DIR = os.environ.get("CACHE_DIR", None)
# THRESHOLD = 0.3
# LABELS = ["country", "year", "statistical indicator", "geographic region"]
# QUERY = "gdp, co2 emissions, and mortality rate of the philippines vs. south asia in 2024"
# MODELS = ["urchade/gliner_base", "urchade/gliner_medium-v2.1", "urchade/gliner_multi-v2.1", "urchade/gliner_large-v2.1"]

# # Helper functions
# def get_model(model_name: str = None):
#     if model_name is None:
#         model_name = "urchade/gliner_base"

#     global _MODEL

#     if _MODEL.get(model_name) is None:
#         _MODEL[model_name] = GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)

#     if torch.cuda.is_available() and not next(_MODEL[model_name].parameters()).device.type.startswith("cuda"):
#         _MODEL[model_name] = _MODEL[model_name].to("cuda")

#     return _MODEL[model_name]

# def get_country(country_name: str):
#     try:
#         return pycountry.countries.search_fuzzy(country_name)
#     except LookupError:
#         return None

# def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
#     model = get_model(model_name)

#     if isinstance(labels, str):
#         labels = [i.strip() for i in labels.split(",")]

#     entities = model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)

#     return entities

# def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
#     entities = []
#     _entities = predict_entities(model_name=model_name, query=query, labels=labels, threshold=threshold, nested_ner=nested_ner)

#     for entity in _entities:
#         if entity["label"] == "country":
#             country = get_country(entity["text"])
#             if country:
#                 entity["normalized"] = [dict(c) for c in country]
#                 entities.append(entity)
#         else:
#             entities.append(entity)

#     return {"query": query, "entities": entities}

# # Initialize only base model
# @st.cache_resource
# def initialize_model():
#     return get_model("urchade/gliner_base")

# # Pre-initialize the model
# init_model = initialize_model()

# # Create the Streamlit interface
# st.title("GLiNER-based Query Parser")
# st.markdown("""
# This app demonstrates the GLiNER model's ability to predict entities in a given text query.
# Given a set of entities to track, the model can identify instances of these entities in the query.
# """)

# # User inputs
# query = st.text_area("Query", value=QUERY, placeholder="Enter your query here")
# model_name = st.radio("Model", options=MODELS, index=0)
# entities = st.text_input("Entities to detect (comma separated)", value=", ".join(LABELS))
# threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, value=THRESHOLD, step=0.01,
#                      help="Lower threshold may extract more false-positive entities from the query.")
# is_nested = st.checkbox("Nested NER", value=False, help="Setting to True extracts nested entities")

# # Process when the form is submitted
# if st.button("Submit"):
#     with st.spinner("Processing..."):
#         # Process the query
#         result = parse_query(query, entities, threshold, is_nested, model_name)

#         # Display the annotated text
#         st.subheader("Detected Entities")

#         # Create highlighted text display
#         text = result["query"]
#         annotations = []

#         for entity in result["entities"]:
#             annotations.append({
#                 "start": entity["start"],
#                 "end": entity["end"],
#                 "label": entity["label"],
#                 "score": round(entity["score"], 3)
#             })

#         # Sort annotations by start position
#         annotations.sort(key=lambda x: x["start"])

#         # Display original text with annotations
#         processed_text = ""
#         last_end = 0

#         for annotation in annotations:
#             # Add text before the entity
#             processed_text += text[last_end:annotation["start"]]
#             # Add highlighted entity
#             entity_text = text[annotation["start"]:annotation["end"]]
#             processed_text += f"**[{entity_text}]({annotation['label']}, {annotation['score']:.3f})**"
#             last_end = annotation["end"]

#         # Add any remaining text
#         processed_text += text[last_end:]

#         st.markdown(processed_text)

#         # Display raw JSON output for more details
#         with st.expander("Raw JSON Output"):
#             st.json(result)

# # Save this file as "streamlit_app.py"
# # Run it in Colab with: !streamlit run streamlit_app.py -- --server.port=8501 --server.enableCORS=False --server.enableXsrfProtection=False

# # To get a public URL, you can use ngrok
# # !pip install pyngrok
# # from pyngrok import ngrok
# # public_url = ngrok.connect(port=8501)
# # print(f"Public URL: {public_url}")

In [None]:
%%writefile streamlit_app.py
# Install required packages
import os
import json
import streamlit as st
import pycountry
import torch
from datetime import datetime
from typing import Dict, Union
from gliner import GLiNER

# Model configuration
_MODEL = {}  # loaded GLiNER models, keyed by model name (filled by get_model)
_CACHE_DIR = os.environ.get("CACHE_DIR")  # None when the variable is unset
THRESHOLD = 0.3
LABELS = [
    "country",
    "year",
    "statistical indicator",
    "geographic region",
]
QUERY = (
    "gdp, co2 emissions, and mortality rate of the philippines "
    "vs. south asia in 2024"
)
# Model choices offered in the UI.
MODELS = [
    "urchade/gliner_base",
    "urchade/gliner_medium-v2.1",
    "urchade/gliner_multi-v2.1",
    "urchade/gliner_large-v2.1",
]

# Helper functions
@st.cache_resource
def _load_model(model_name: str):
    """Load the GLiNER weights for *model_name*; cached by Streamlit."""
    return GLiNER.from_pretrained(model_name, cache_dir=_CACHE_DIR)


def get_model(model_name: str = None):
    """Return the GLiNER model for *model_name*, loading it at most once.

    Args:
        model_name: Model identifier; ``None`` selects ``"urchade/gliner_base"``.

    Returns:
        The loaded model, moved to ``"cuda"`` when CUDA is available and the
        model's parameters are not already on a CUDA device.
    """
    if model_name is None:
        model_name = "urchade/gliner_base"

    # Streamlit re-executes this script, including ``_MODEL = {}``, on every
    # interaction, so the module-level dict alone cannot keep a model alive
    # between reruns. Loading through st.cache_resource does.
    model = _MODEL.get(model_name)
    if model is None:
        model = _load_model(model_name)

    if torch.cuda.is_available() and not next(model.parameters()).device.type.startswith("cuda"):
        model = model.to("cuda")

    # Keep the dict populated as before for anything that reads it directly.
    _MODEL[model_name] = model

    return model

def get_country(country_name: str):
    """Fuzzy-search pycountry for *country_name*.

    Returns the search result; ``None`` is returned instead when the search
    raises ``LookupError``.
    """
    try:
        result = pycountry.countries.search_fuzzy(country_name)
    except LookupError:
        result = None
    return result

def predict_entities(model_name: str, query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False):
    """Predict entities in *query* with the GLiNER model *model_name*.

    Args:
        model_name: Model identifier handed to ``get_model``.
        query: Text to search for entities.
        labels: Entity labels, either a list or one comma-separated string.
            For a string, surrounding whitespace is stripped from each label
            and blank labels (e.g. from a trailing comma) are dropped.
        threshold: Passed through to ``model.predict_entities``.
        nested_ner: Passed to the model inverted, as ``flat_ner=not nested_ner``.

    Returns:
        The list returned by ``model.predict_entities``, or an empty list when
        no usable label was supplied.
    """
    model = get_model(model_name)

    if isinstance(labels, str):
        # "a, b," or an empty text input must not produce "" labels.
        labels = [part.strip() for part in labels.split(",") if part.strip()]

    if not labels:
        # Nothing to look for: skip the model call instead of sending no labels.
        return []

    return model.predict_entities(query, labels, threshold=threshold, flat_ner=not nested_ner)

def parse_query(query: str, labels: Union[str, list], threshold: float = 0.3, nested_ner: bool = False, model_name: str = None) -> Dict[str, Union[str, list]]:
    """Predict entities for *query* and normalize any "country" entities.

    Entities labelled ``"country"`` are resolved with ``get_country``: a match
    adds a ``"normalized"`` list of dicts to the entity, and no match removes
    the entity from the result. Other entities are returned as predicted.
    Returns ``{"query": ..., "entities": [...]}``.
    """
    predicted = predict_entities(
        model_name=model_name,
        query=query,
        labels=labels,
        threshold=threshold,
        nested_ner=nested_ner,
    )

    kept = []
    for item in predicted:
        if item["label"] != "country":
            kept.append(item)
            continue

        matches = get_country(item["text"])
        if matches:
            item["normalized"] = [dict(match) for match in matches]
            kept.append(item)

    return {"query": query, "entities": kept}

# Initialize only base model
@st.cache_resource
def initialize_model():
    """Return the base GLiNER model; the result is cached by Streamlit."""
    base_model = get_model("urchade/gliner_base")
    return base_model

# Pre-initialize the model
init_model = initialize_model()

# Create the Streamlit interface
st.title("GLiNER-based Query Parser")
st.markdown("""
This app demonstrates the GLiNER model's ability to predict entities in a given text query.
Given a set of entities to track, the model can identify instances of these entities in the query.
""")

# User inputs
query = st.text_area("Query", value=QUERY, placeholder="Enter your query here")
model_name = st.radio("Model", options=MODELS, index=0)
entities = st.text_input("Entities to detect (comma separated)", value=", ".join(LABELS))
threshold = st.slider("Threshold", min_value=0.0, max_value=1.0, value=THRESHOLD, step=0.01,
                     help="Lower threshold may extract more false-positive entities from the query.")
is_nested = st.checkbox("Nested NER", value=False, help="Setting to True extracts nested entities")

# Process when the form is submitted
if st.button("Submit"):
    with st.spinner("Processing..."):
        # Process the query
        result = parse_query(query, entities, threshold, is_nested, model_name)

        # Display the annotated text
        st.subheader("Detected Entities")

        text = result["query"]

        # One annotation per entity, ordered by where it starts in the text.
        annotations = sorted(
            (
                {
                    "start": entity["start"],
                    "end": entity["end"],
                    "label": entity["label"],
                    "score": round(entity["score"], 3),
                }
                for entity in result["entities"]
            ),
            key=lambda annotation: annotation["start"],
        )

        # Rebuild the text with each entity wrapped in a bold
        # "[text](label, score)" marker.
        pieces = []
        last_end = 0

        for annotation in annotations:
            span_start = annotation["start"]
            span_end = annotation["end"]
            # Text between the previous entity and this one. The slice is
            # empty when this entity overlaps the previous one.
            pieces.append(text[last_end:span_start])
            entity_text = text[span_start:span_end]
            pieces.append(f"**[{entity_text}]({annotation['label']}, {annotation['score']:.3f})**")
            # With nested NER a span may end before an earlier, longer span.
            # Never move the cursor backwards, or the text after the shorter
            # span would be emitted a second time.
            last_end = max(last_end, span_end)

        # Add any remaining text
        pieces.append(text[last_end:])

        st.markdown("".join(pieces))

        # Display raw JSON output for more details
        with st.expander("Raw JSON Output"):
            st.json(result)

Overwriting streamlit_app.py


In [None]:
# Install required packages
!pip install gliner pycountry scipy==1.12 streamlit torch pyngrok

In [None]:
# Set up ngrok authentication
import os
from google.colab import userdata

# Get the ngrok auth token from the Colab secret named 'NGROK_AUTH_TOKEN2'
NGROK_AUTH_TOKEN = userdata.get('NGROK_AUTH_TOKEN2')

# Configure ngrok with the auth token; the shell command receives the value of
# the Python variable above through the $NGROK_AUTH_TOKEN reference.
!ngrok authtoken $NGROK_AUTH_TOKEN

Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [None]:
# Start Streamlit app in the background on port 8509 (the port the ngrok
# tunnel is opened on). nohup and the trailing & detach it from this cell;
# its output is appended to nohup.out.
# NOTE(review): CORS and XSRF protection are switched off here - confirm this
# is acceptable for an app exposed through a public tunnel.
!nohup streamlit run streamlit_app.py --server.port=8509 --server.enableCORS=False --server.enableXsrfProtection=False &

nohup: appending output to 'nohup.out'


In [None]:
# Alternative: Use ngrok directly via command line
!ngrok http 8509 --log=stdout > /dev/null &

# Wait a moment for ngrok to start
import time
time.sleep(5)

# Get the public URL from ngrok API
import requests
import json
tunnels = json.loads(requests.get('http://localhost:4040/api/tunnels').text)
public_url = tunnels['tunnels'][0]['public_url']
print(f"Public URL: {public_url}")