<a href="https://colab.research.google.com/github/navneetkrc/Deep_learning_experiments/blob/master/NER_streamlit.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Both NER and Classification

In [None]:
# == CELL 1: Install Dependencies and Set Up ngrok ==
# Run this cell FIRST.

!pip install streamlit gliner gliclass transformers torch pyngrok nest_asyncio

import subprocess
import threading
import time
import socket
from pyngrok import ngrok, conf
import os
import nest_asyncio

# Patch asyncio (via nest_asyncio) so an already-running event loop can be
# re-entered; the original note says this is required for Colab.
# NOTE(review): nothing visible in this cell drives asyncio directly — confirm
# this is still needed (presumably kept for pyngrok / notebook helpers).
nest_asyncio.apply()

def get_free_port():
    """Return the number of a TCP port the OS reports as currently unused.

    Binding to port 0 asks the kernel to pick an ephemeral port. The probe
    socket is closed before returning, so the port is only *likely* to stay
    free until the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        _host, port = probe.getsockname()
    return port

def run_streamlit(port, app_path="temp_app.py"):
    """Start ``streamlit run`` for *app_path* on *port* in a background process.

    Args:
        port: TCP port the Streamlit server should listen on.
        app_path: Script to serve; defaults to the file written by cell 2.

    Returns:
        The ``subprocess.Popen`` handle, so callers can ``poll()`` it to detect
        a failed start or ``terminate()`` it when finished.
    """
    print(f"Starting Streamlit on port {port}...")
    cmd = [
        "streamlit", "run", app_path,
        "--server.port", str(port),
        "--server.headless", "true",
        "--browser.serverAddress", "localhost",
    ]
    # Popen returns immediately; keep the handle instead of discarding it so the
    # server process can be managed and reaped.
    return subprocess.Popen(cmd)

def start_ngrok(port):
    """Open an ngrok tunnel to local *port* and return its public URL.

    The auth token is read from the Colab secret ``NGROK_TOKEN``.

    Raises:
        ValueError: if the ``NGROK_TOKEN`` secret is empty.
    """
    from google.colab import userdata
    ngrok_token = userdata.get('NGROK_TOKEN')
    if not ngrok_token:
        raise ValueError(
            "NGROK_TOKEN secret is empty; add your ngrok auth token under Colab > Secrets."
        )

    conf.get_default().auth_token = ngrok_token
    conf.get_default().region = 'us'
    # Stop the ngrok process pyngrok already manages (e.g. from re-running the
    # cell) through its API: it waits for the process to exit and clears the
    # cached tunnels. A bare `killall` only sends a signal, so ngrok.connect()
    # below could otherwise reuse the still-dying cached process.
    ngrok.kill()
    os.system("killall ngrok")  # Also kill stray ngrok processes not managed by pyngrok

    print(f"Starting ngrok tunnel for port {port}...")
    tunnel = ngrok.connect(port)
    print("Ngrok tunnel:", tunnel.public_url)
    return tunnel.public_url

In [None]:
# == CELL 2: Define and Run Streamlit App (ALL-IN-ONE) ==
# Run this cell SECOND.

if __name__ == "__main__":
    free_port = get_free_port()

    # ALL Streamlit app code, including imports, goes into this string.
    # It is written to temp_app.py and executed by a separate `streamlit run`
    # process, so it cannot see anything defined in this notebook.
    # (Keep it free of triple double quotes and backslashes: this is a plain
    # string literal.)
    streamlit_app_code = """
import html
from collections import Counter

import streamlit as st
import torch
from gliner import GLiNER
from gliclass import GLiClassModel, ZeroShotClassificationPipeline
from transformers import AutoTokenizer

# NOTE: this script runs inside its own `streamlit run` process. Notebook-only
# helpers (pyngrok, nest_asyncio, subprocess, ...) are deliberately not imported:
# nest_asyncio.apply() would patch asyncio process-wide inside Streamlit's own
# server for no benefit.

# Determine available device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Palette cycled through when assigning a color to each new label.
DEFAULT_COLORS = ["#FF9AA2", "#FFB7B2", "#FFDAC1", "#E2F0CB", "#B5EAD7", "#C7CEEA"]


# Load GLiNER models (cached across reruns and sessions)
@st.cache_resource
def load_gliner_model(model_name):
    model = GLiNER.from_pretrained(model_name)
    # from_pretrained does not use `device`; move the model explicitly so the
    # GPU is used when available (the GLiClass pipeline already gets `device`).
    model.to(device)
    return model


# Load GLiClass models (cached)
@st.cache_resource
def load_gliclass_model(model_name):
    model = GLiClassModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    pipeline = ZeroShotClassificationPipeline(model, tokenizer, classification_type='multi-label', device=device)
    return pipeline


# Available models
GLINER_MODELS = {
    "knowledgator/modern-gliner-bi-large-v1.0": "GLiNER BiLarge v1.0",
    "knowledgator/gliner-multitask-large-v0.5": "GLiNER multi task",
    "knowledgator/modern-gliner-bi-base-v1.0": "GLiNER Base"
}

GLICLASS_MODELS = {
    "knowledgator/gliclass-modern-base-v2.0-init": "GLiClass Modern Base v2.0",
    "knowledgator/gliclass-modern-large-v2.0-init": "GLiClass Modern Large v2.0"
}


def color_for(label, colors):
    # Return the color for `label`, assigning the next palette entry on first use.
    if label not in colors:
        colors[label] = DEFAULT_COLORS[len(colors) % len(DEFAULT_COLORS)]
    return colors[label]


def highlight_entities(text, entities, colors):
    # Build HTML for `text` with entity spans highlighted, using the character
    # offsets ('start'/'end') GLiNER reports. Unlike repeated str.replace this
    # can never match inside previously inserted markup, and all user text is
    # HTML-escaped. Overlapping spans are resolved in favour of the higher score;
    # entities without offsets are left unhighlighted.
    chosen = []
    for ent in sorted(entities, key=lambda e: e.get('score', 0), reverse=True):
        start, end = ent.get('start'), ent.get('end')
        if start is None or end is None:
            continue
        if any(start < c_end and c_start < end for c_start, c_end, _ in chosen):
            continue
        chosen.append((start, end, ent))

    parts = []
    cursor = 0
    for start, end, ent in sorted(chosen, key=lambda c: c[0]):
        parts.append(html.escape(text[cursor:start]))
        parts.append(
            f'<span style="background-color: {color_for(ent["label"], colors)}; padding: 2px; border-radius: 3px;">'
            f'{html.escape(text[start:end])} '
            f'<small>({html.escape(ent["label"])} - {ent.get("score", 0):.2f})</small></span>'
        )
        cursor = end
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


def pick_labels(default_labels, state_key):
    # Render default + custom label checkboxes and return the selected labels.
    # Custom labels are stored in st.session_state[state_key]: a plain local dict
    # is rebuilt on every rerun, so labels added with "Add Label" used to vanish
    # as soon as "Analyze Text" was clicked.
    if state_key not in st.session_state:
        st.session_state[state_key] = []
    custom_labels = st.session_state[state_key]

    col1, col2 = st.columns([1, 1])

    # Handle the add-label form first so a newly added label already shows up
    # as a checkbox in this same run (columns can be filled in any order).
    with col2:
        st.markdown("**Add Custom Labels:**")
        custom_label = st.text_input("New label:", placeholder="Enter a custom label").strip()
        if custom_label and st.button("Add Label"):
            existing = {lbl.lower() for lbl in default_labels + custom_labels}
            if custom_label.lower() not in existing:
                custom_labels.append(custom_label)
                st.success(f"Added '{custom_label}' to labels")
            else:
                st.warning("This label already exists")

    selected = []
    with col1:
        st.markdown("**Default Labels:**")
        for label in default_labels:
            if st.checkbox(label, value=True, key=f"{state_key}_default_{label}"):
                selected.append(label)
        if custom_labels:
            st.markdown("**Custom Labels:**")
            for label in custom_labels:
                if st.checkbox(label, value=True, key=f"{state_key}_custom_{label}"):
                    selected.append(label)
    return selected


def show_ner_results(model, text, labels, threshold, colors):
    # Run GLiNER on `text` and render highlighted text, an entity table and counts.
    entities = model.predict_entities(text, labels, threshold=threshold)
    if not entities:
        st.info("No entities matching your labels were found in the text.")
        return
    for entity in entities:
        color_for(entity['label'], colors)

    st.write("Text with highlighted entities:")
    st.markdown(highlight_entities(text, entities, colors), unsafe_allow_html=True)
    st.markdown("### Entity List:")
    st.table({
        "Entity": [e['text'] for e in entities],
        "Label": [e['label'] for e in entities],
        "Confidence": [f"{e.get('score', 0):.2f}" for e in entities],
    })
    st.markdown("### Entity Statistics:")
    st.bar_chart(dict(Counter(e['label'] for e in entities)))


def show_classification_results(pipeline, text, labels, threshold, colors):
    # Run the GLiClass pipeline on `text` and render scores, a table and label badges.
    results = pipeline(text, labels, threshold=threshold)[0]
    if not results:
        st.info("No labels crossed the confidence threshold.")
        return

    st.markdown("### Classification Results:")
    label_data = {}
    for result in results:
        label_data[result["label"]] = result["score"]
        color_for(result["label"], colors)
    st.bar_chart(label_data)

    ranked = sorted(results, key=lambda r: r["score"], reverse=True)
    st.table({
        "Label": [r["label"] for r in ranked],
        "Confidence": [f"{r['score']:.4f}" for r in ranked],
    })

    st.markdown("### Text with predicted labels:")
    badges = "".join(
        f'<span style="background-color: {colors[r["label"]]}; padding: 2px 6px; margin: 0 4px; border-radius: 3px;">'
        f'{html.escape(r["label"])} ({r["score"]:.2f})</span>'
        for r in ranked
    )
    st.markdown(
        f'<div style="padding: 10px; border: 1px solid #ccc; border-radius: 5px;">{html.escape(text)}'
        f'<br><br><b>Predicted labels:</b> {badges}</div>',
        unsafe_allow_html=True,
    )


# --- Streamlit App Code (defined as a function) ---
def streamlit_app():  # The Streamlit app logic
    st.set_page_config(page_title="Text Analysis with GLiNER & GLiClass", layout="wide")
    st.title("Text Analysis with GLiNER & GLiClass")
    st.markdown("Compare entity extraction and zero-shot classification across multiple models")

    if 'default_gliner' not in st.session_state:
        st.session_state.default_gliner = list(GLINER_MODELS.keys())[0]
    if 'default_gliclass' not in st.session_state:
        st.session_state.default_gliclass = list(GLICLASS_MODELS.keys())[0]

    with st.sidebar:
        st.header("Settings")
        st.session_state.analysis_type = st.radio(
            "Analysis Type",
            ["Named Entity Recognition (NER)", "Zero-Shot Classification"],
            index=0,
            # NOTE(review): this displays only "Named" for the NER option; confirm intended.
            format_func=lambda x: x.split(" ")[0] if "(" in x else x
        )
        is_ner = "Named" in st.session_state.analysis_type
        st.subheader("Model Selection")

        if is_ner:
            models, state_key, prefix, prompt = GLINER_MODELS, "default_gliner", "gliner", "Set default NER model:"
        else:
            models, state_key, prefix, prompt = GLICLASS_MODELS, "default_gliclass", "gliclass", "Set default classification model:"
        model_ids = list(models)
        st.session_state[state_key] = st.selectbox(
            prompt,
            options=model_ids,
            format_func=lambda x: models[x],
            index=model_ids.index(st.session_state[state_key])
        )
        selected_models = [st.session_state[state_key]]
        remaining_models = [m for m in model_ids if m != st.session_state[state_key]]
        for model in remaining_models[:2]:
            # Key by model id so a tick stays with its model when the default changes.
            if st.checkbox(f"Also use {models[model]}", key=f"{prefix}_{model}"):
                selected_models.append(model)

        threshold = st.slider("Confidence Threshold", min_value=0.1, max_value=0.9, value=0.5, step=0.05)
        st.divider()
        st.markdown("### About")
        st.markdown("This app uses GLiNER for named entity recognition and GLiClass for zero-shot classification.")

    st.subheader("Text Input")
    example_text = "One day I will see the world!"

    text = st.text_area(
        "Enter text for analysis:",
        value=example_text,
        height=150,
        placeholder="Enter your text here..."
    )

    default_ner_labels = ["product", "product type", "price", "memory", "feature"]
    default_class_labels = ["product", "product type", "price", "memory", "feature"]

    if is_ner:
        st.subheader("Entity Labels")
        final_labels = pick_labels(default_ner_labels, "custom_ner_labels")
    else:
        st.subheader("Classification Labels")
        final_labels = pick_labels(default_class_labels, "custom_class_labels")

    st.markdown("### Selected Labels:")
    st.write(", ".join(final_labels) if final_labels else "No labels selected")

    colors = {}  # label -> color, shared by all result tabs and the sidebar legend
    if st.button("Analyze Text"):
        if not text.strip():
            st.error("Please provide text for analysis.")
        elif not final_labels:
            st.error("Please select at least one label.")
        else:
            model_names = [models[model] for model in selected_models]
            tabs = st.tabs(model_names)
            for tab, model_name, display_name in zip(tabs, selected_models, model_names):
                with tab:
                    with st.spinner(f"Analyzing with {display_name}..."):
                        # Loaders are st.cache_resource-cached, so repeat calls are cheap.
                        if is_ner:
                            show_ner_results(load_gliner_model(model_name), text, final_labels, threshold, colors)
                        else:
                            show_classification_results(load_gliclass_model(model_name), text, final_labels, threshold, colors)

    if colors:
        st.sidebar.markdown("### Color Legend:")
        for label, color in colors.items():
            st.sidebar.markdown(
                f'<div style="background-color: {color}; padding: 5px; border-radius: 3px; margin-bottom: 5px;">{html.escape(label)}</div>',
                unsafe_allow_html=True
            )

streamlit_app()  # Calling the function to run the app
"""

    # Write the ENTIRE app code to temp_app.py
    with open("temp_app.py", "w", encoding="utf-8") as f:
        f.write(streamlit_app_code)

    # run_streamlit uses Popen, which returns immediately; no helper thread needed.
    run_streamlit(free_port)

    # Wait until Streamlit actually accepts connections instead of a fixed sleep,
    # which can be too short on a cold runtime.
    startup_deadline = time.monotonic() + 60
    while time.monotonic() < startup_deadline:
        try:
            with socket.create_connection(("localhost", free_port), timeout=1):
                break
        except OSError:
            time.sleep(0.5)
    else:
        print(f"Warning: Streamlit did not open port {free_port} within 60s; starting ngrok anyway.")

    public_url = start_ngrok(free_port)  # Get the ngrok URL
    print(f"Access your Streamlit app at: {public_url}")

    while True:  # Keep the notebook alive
        time.sleep(60)

## Only NER

In [None]:
# == CELL 1: Install Dependencies and Set Up ngrok ==
# Run this cell FIRST.

!pip install streamlit gliner transformers torch pyngrok nest_asyncio

import subprocess
import threading
import time
import socket
from pyngrok import ngrok, conf
import os
import nest_asyncio

# Patch asyncio (via nest_asyncio) so an already-running event loop can be
# re-entered; the original note says this is required for Colab.
# NOTE(review): nothing visible in this cell drives asyncio directly — confirm
# this is still needed (presumably kept for pyngrok / notebook helpers).
nest_asyncio.apply()

def get_free_port():
    """Return the number of a TCP port the OS reports as currently unused.

    Binding to port 0 asks the kernel to pick an ephemeral port. The probe
    socket is closed before returning, so the port is only *likely* to stay
    free until the caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("", 0))
        _host, port = probe.getsockname()
    return port

def run_streamlit(port, app_path="temp_app.py"):
    """Start ``streamlit run`` for *app_path* on *port* in a background process.

    Args:
        port: TCP port the Streamlit server should listen on.
        app_path: Script to serve; defaults to the file written by cell 2.

    Returns:
        The ``subprocess.Popen`` handle, so callers can ``poll()`` it to detect
        a failed start or ``terminate()`` it when finished.
    """
    print(f"Starting Streamlit on port {port}...")
    cmd = [
        "streamlit", "run", app_path,
        "--server.port", str(port),
        "--server.headless", "true",
        "--browser.serverAddress", "localhost",
    ]
    # Popen returns immediately; keep the handle instead of discarding it so the
    # server process can be managed and reaped.
    return subprocess.Popen(cmd)

def start_ngrok(port):
    """Open an ngrok tunnel to local *port* and return its public URL.

    The auth token is read from the Colab secret ``NGROK_TOKEN``.

    Raises:
        ValueError: if the ``NGROK_TOKEN`` secret is empty.
    """
    from google.colab import userdata
    ngrok_token = userdata.get('NGROK_TOKEN')
    if not ngrok_token:
        raise ValueError(
            "NGROK_TOKEN secret is empty; add your ngrok auth token under Colab > Secrets."
        )

    conf.get_default().auth_token = ngrok_token
    conf.get_default().region = 'us'
    # Stop the ngrok process pyngrok already manages (e.g. from re-running the
    # cell) through its API: it waits for the process to exit and clears the
    # cached tunnels. A bare `killall` only sends a signal, so ngrok.connect()
    # below could otherwise reuse the still-dying cached process.
    ngrok.kill()
    os.system("killall ngrok")  # Also kill stray ngrok processes not managed by pyngrok

    print(f"Starting ngrok tunnel for port {port}...")
    tunnel = ngrok.connect(port)
    print("Ngrok tunnel:", tunnel.public_url)
    return tunnel.public_url

In [None]:
# == CELL 2: Define and Run Streamlit App (NER ONLY) ==
# Run this cell SECOND.

if __name__ == "__main__":
    free_port = get_free_port()

    # ALL Streamlit app code (NER only), including imports, goes into this string.
    # It is written to temp_app.py and executed by a separate `streamlit run`
    # process, so it cannot see anything defined in this notebook.
    # (Keep it free of triple double quotes and backslashes: this is a plain
    # string literal.)
    streamlit_app_code = """
import html
from collections import Counter

import streamlit as st
import torch
from gliner import GLiNER

# Determine available device
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Palette cycled through when assigning a color to each new label.
DEFAULT_COLORS = ["#FF9AA2", "#FFB7B2", "#FFDAC1", "#E2F0CB", "#B5EAD7", "#C7CEEA"]


# Load GLiNER models (cached across reruns and sessions)
@st.cache_resource
def load_gliner_model(model_name):
    model = GLiNER.from_pretrained(model_name)
    # from_pretrained does not use `device`; move the model explicitly so the
    # GPU is used when available.
    model.to(device)
    return model


# Available GLiNER models
GLINER_MODELS = {
    "knowledgator/modern-gliner-bi-large-v1.0": "GLiNER BiLarge v1.0",
    "knowledgator/gliner-multitask-large-v0.5": "GLiNER multi task",
    "knowledgator/modern-gliner-bi-base-v1.0": "GLiNER Base"
}


def color_for(label, colors):
    # Return the color for `label`, assigning the next palette entry on first use.
    if label not in colors:
        colors[label] = DEFAULT_COLORS[len(colors) % len(DEFAULT_COLORS)]
    return colors[label]


def highlight_entities(text, entities, colors):
    # Build HTML for `text` with entity spans highlighted, using the character
    # offsets ('start'/'end') GLiNER reports. Unlike repeated str.replace this
    # can never match inside previously inserted markup, and all user text is
    # HTML-escaped. Overlapping spans are resolved in favour of the higher score;
    # entities without offsets are left unhighlighted.
    chosen = []
    for ent in sorted(entities, key=lambda e: e.get('score', 0), reverse=True):
        start, end = ent.get('start'), ent.get('end')
        if start is None or end is None:
            continue
        if any(start < c_end and c_start < end for c_start, c_end, _ in chosen):
            continue
        chosen.append((start, end, ent))

    parts = []
    cursor = 0
    for start, end, ent in sorted(chosen, key=lambda c: c[0]):
        parts.append(html.escape(text[cursor:start]))
        parts.append(
            f'<span style="background-color: {color_for(ent["label"], colors)}; padding: 2px; border-radius: 3px;">'
            f'{html.escape(text[start:end])} '
            f'<small>({html.escape(ent["label"])} - {ent.get("score", 0):.2f})</small></span>'
        )
        cursor = end
    parts.append(html.escape(text[cursor:]))
    return "".join(parts)


def pick_labels(default_labels, state_key):
    # Render default + custom label checkboxes and return the selected labels.
    # Custom labels are stored in st.session_state[state_key]: a plain local dict
    # is rebuilt on every rerun, so labels added with "Add Label" used to vanish
    # as soon as "Analyze Text" was clicked.
    if state_key not in st.session_state:
        st.session_state[state_key] = []
    custom_labels = st.session_state[state_key]

    col1, col2 = st.columns([1, 1])

    # Handle the add-label form first so a newly added label already shows up
    # as a checkbox in this same run (columns can be filled in any order).
    with col2:
        st.markdown("**Add Custom Labels:**")
        custom_label = st.text_input("New label:", placeholder="Enter a custom label").strip()
        if custom_label and st.button("Add Label"):
            existing = {lbl.lower() for lbl in default_labels + custom_labels}
            if custom_label.lower() not in existing:
                custom_labels.append(custom_label)
                st.success(f"Added '{custom_label}' to labels")
            else:
                st.warning("This label already exists")

    selected = []
    with col1:
        st.markdown("**Default Labels:**")
        for label in default_labels:
            if st.checkbox(label, value=True, key=f"{state_key}_default_{label}"):
                selected.append(label)
        if custom_labels:
            st.markdown("**Custom Labels:**")
            for label in custom_labels:
                if st.checkbox(label, value=True, key=f"{state_key}_custom_{label}"):
                    selected.append(label)
    return selected


def show_ner_results(model, text, labels, threshold, colors):
    # Run GLiNER on `text` and render highlighted text, an entity table and counts.
    entities = model.predict_entities(text, labels, threshold=threshold)
    if not entities:
        st.info("No entities matching your labels were found in the text.")
        return
    for entity in entities:
        color_for(entity['label'], colors)

    st.write("Text with highlighted entities:")
    st.markdown(highlight_entities(text, entities, colors), unsafe_allow_html=True)
    st.markdown("### Entity List:")
    st.table({
        "Entity": [e['text'] for e in entities],
        "Label": [e['label'] for e in entities],
        "Confidence": [f"{e.get('score', 0):.2f}" for e in entities],
    })
    st.markdown("### Entity Statistics:")
    st.bar_chart(dict(Counter(e['label'] for e in entities)))


# --- Streamlit App Code (NER Only) ---
def streamlit_app():
    st.set_page_config(page_title="Text Analysis with GLiNER", layout="wide")
    st.title("Named Entity Recognition with GLiNER")
    st.markdown("Extract named entities from text using GLiNER.")

    if 'default_gliner' not in st.session_state:
        st.session_state.default_gliner = list(GLINER_MODELS.keys())[0]

    with st.sidebar:
        st.header("Settings")
        st.subheader("Model Selection")

        model_ids = list(GLINER_MODELS)
        st.session_state.default_gliner = st.selectbox(
            "Select NER model:",
            options=model_ids,
            format_func=lambda x: GLINER_MODELS[x],
            index=model_ids.index(st.session_state.default_gliner)
        )

        # The default model is always used; up to 2 more can be ticked.
        selected_models = [st.session_state.default_gliner]
        remaining_models = [m for m in model_ids if m != st.session_state.default_gliner]
        for model in remaining_models[:2]:
            # Key by model id so a tick stays with its model when the default changes.
            if st.checkbox(f"Also use {GLINER_MODELS[model]}", key=f"gliner_{model}"):
                selected_models.append(model)

        threshold = st.slider("Confidence Threshold", min_value=0.1, max_value=0.9, value=0.5, step=0.05)
        st.divider()
        st.markdown("### About")
        st.markdown("This app uses GLiNER for named entity recognition.")

    st.subheader("Text Input")
    example_text = "What is the difference between the Samsung Galaxy S23 and S23 Ultra?, What does 'Dynamic AMOLED 2X' mean in a Samsung display?, How do I use the S Pen on my Samsung Galaxy Note or Ultra phone?, What is 'One UI' on Samsung phones?, What is 'Samsung Knox' and why is it important?, What is the difference between Samsung QLED and Neo QLED TVs?, What is 'SmartThings' in Samsung appliances?, What is 'Bespoke' in Samsung refrigerators?, What does 'EcoBubble' mean in Samsung washing machines?, What is the Samsung Galaxy Watch and what are its features?, What are Samsung's different memory and storage devices?, What is the difference between a Samsung soundbar and a home theater system?"

    text = st.text_area(
        "Enter text for analysis:",
        value=example_text,
        height=150,
        placeholder="Enter your text here..."
    )

    default_ner_labels = ["product", "product type", "price", "memory", "feature"]
    st.subheader("Entity Labels")
    final_labels = pick_labels(default_ner_labels, "custom_ner_labels")

    st.markdown("### Selected Labels:")
    st.write(", ".join(final_labels) if final_labels else "No labels selected")

    colors = {}  # label -> color, shared by all result tabs and the sidebar legend
    if st.button("Analyze Text"):
        if not text.strip():
            st.error("Please provide text for analysis.")
        elif not final_labels:
            st.error("Please select at least one label.")
        else:
            model_names = [GLINER_MODELS[model] for model in selected_models]
            tabs = st.tabs(model_names)
            for tab, model_name, display_name in zip(tabs, selected_models, model_names):
                with tab:
                    with st.spinner(f"Analyzing with {display_name}..."):
                        # load_gliner_model is st.cache_resource-cached, so repeat calls are cheap.
                        show_ner_results(load_gliner_model(model_name), text, final_labels, threshold, colors)

    if colors:
        st.sidebar.markdown("### Color Legend:")
        for label, color in colors.items():
            st.sidebar.markdown(
                f'<div style="background-color: {color}; padding: 5px; border-radius: 3px; margin-bottom: 5px;">{html.escape(label)}</div>',
                unsafe_allow_html=True
            )

streamlit_app()  # Calling the function to run the app
"""

    # Write the ENTIRE app code to temp_app.py
    with open("temp_app.py", "w", encoding="utf-8") as f:
        f.write(streamlit_app_code)

    # run_streamlit uses Popen, which returns immediately; no helper thread needed.
    run_streamlit(free_port)

    # Wait until Streamlit actually accepts connections instead of a fixed sleep,
    # which can be too short on a cold runtime.
    startup_deadline = time.monotonic() + 60
    while time.monotonic() < startup_deadline:
        try:
            with socket.create_connection(("localhost", free_port), timeout=1):
                break
        except OSError:
            time.sleep(0.5)
    else:
        print(f"Warning: Streamlit did not open port {free_port} within 60s; starting ngrok anyway.")

    public_url = start_ngrok(free_port)  # Get the ngrok URL
    print(f"Access your Streamlit app at: {public_url}")

    while True:  # Keep the notebook alive
        time.sleep(60)