<a href="https://colab.research.google.com/github/coderuhaan2004/KnowledgeGraphGenerator/blob/main/KGG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Make sure to run this Colab notebook on a T4 GPU runtime.

Install the necessary libraries by uploading the `requirements.txt` file from the repository to Google Colab and running the cell below.

In [None]:
# Install every package listed in requirements.txt (the file must already be present in the Colab session).
!pip install -r requirements.txt

Run the cell below to import all the necessary libraries.

In [None]:
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from IPython.display import HTML, display
from google.colab import output
from pyvis.network import Network

Run the cell below to load the Open-Orca model (`Open-Orca/Mistral-7B-OpenOrca`) with 4-bit quantization.

In [None]:
# Hugging Face Hub id of the checkpoint used for relation extraction.
model_name = "Open-Orca/Mistral-7B-OpenOrca"

# 4-bit quantization settings passed to the model loader below.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",  # NF4 format, introduced in the QLoRA paper
    bnb_4bit_compute_dtype=torch.float16,  # weights are stored in 4 bits; computation runs in float16
    bnb_4bit_use_double_quant=True,  # the quantization parameters are themselves quantized
)

# Load the quantized causal LM and its matching tokenizer.
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config=bnb_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# NOTE(review): these two overrides look carried over from a fine-tuning
# recipe -- confirm they are actually wanted for inference.
model.config.pretraining_tp = 1
model.config.use_cache = False

Run the cell below to display the text-input UI. After submitting your text, wait a while for the knowledge graph to be generated.

In [None]:
# Markup, styles and client-side script for the input UI.
# - The typed.js snippet animates the words after "worth" in the banner.
# - Clicking the search button sends the text box contents to the Python
#   callback named 'notebook.generate_knowledge_graph' through
#   google.colab.kernel.invokeFunction.
html_content = '''
<div id="knowledge-graph-generator" style="width: 100%; height: 600px; background-color: #f0f0f0; padding: 20px; border-radius: 10px;">
    <style>
    @import url('https://fonts.googleapis.com/css2?family=Roboto:ital,wght@0,100;0,300;0,400;0,500;0,700;0,900;1,100;1,300;1,400;1,500;1,700;1,900&display=swap');

    #knowledge-graph-generator {
        font-family: "Roboto", monospace;
        display: flex;
        flex-direction: column;
    }

    .header {
        display: flex;
        justify-content: center;
        align-items: center;
        font-size: 2.5rem;
        font-weight: 500;
        margin-bottom: 20px;
    }

    .main {
        flex-grow: 1;
        background-color: #080808;
        padding: 30px;
        border-radius: 10px;
        display: flex;
        flex-direction: column;
        justify-content: space-between;
    }

    .text {
        color: white;
        font-size: 2.5vh;
        font-weight: 100;
        margin-bottom: 30px;
    }

    .bar {
        background-color: white;
        padding: 10px 15px;
        border-radius: 25px;
        display: flex;
        align-items: center;
        height: 50px;
    }

    .bar input {
        flex-grow: 1;
        height: 100%;
        border: none;
        font-size: 18px;
        padding: 0 10px;
        outline: none;
    }

    .bar button {
        background: none;
        border: none;
        cursor: pointer;
    }

    .bar img {
        width: 30px;
        height: auto;
    }
    </style>
    <header class="header">
        <div>Knowledge Graph Generator</div>
    </header>

    <main class="main">
        <div class="text">
            <div class="static">
                We give information,
            </div>
            <div class="moving">
                worth <span class="move"></span>
            </div>
        </div>
        <div class="bar">
            <input type="text" id="search" placeholder="Search" aria-label="search">
            <button id="btn"><img src="https://img.icons8.com/ios-filled/50/000000/search--v1.png" alt="Search"></button>
        </div>
    </main>
</div>

<script src="https://unpkg.com/typed.js@2.1.0/dist/typed.umd.js"></script>
<script>
var typed = new Typed(".move", {
    strings: ["reading", "exploring", "understanding"],
    typeSpeed: 70,
    backSpeed: 60,
    loop: true
});

document.getElementById('btn').addEventListener('click', function() {
    var text = document.getElementById('search').value;
    console.log('input text', text);
    google.colab.kernel.invokeFunction('notebook.generate_knowledge_graph', [text], {});
});
</script>
'''
def getprompt(text):
    """Build the extraction prompt for ``text``.

    The prompt is a fixed system instruction, a blank line, and a user
    section that wraps ``text`` in triple backticks and ends with an
    ``output:`` cue for the model to continue from.
    """
    instruction = (
        "You are an AI assistant tasked with extracting structured information "
        "from the context to create a knowledge graph. Your goal is to identify "
        "key entities and their relationships in the context and present this "
        "information in a JSON format with fields: 'node1', 'node2', and "
        "'relationship'."
    )
    user_section = "context: ```{}``` \n\n output: ".format(text)
    return "\n\n".join((instruction, user_section))
def _extract_json_array(response):
    """Return the first JSON array of objects that parses out of ``response``.

    Raises:
        ValueError: if no such array can be decoded from ``response``.
    """
    decoder = json.JSONDecoder()
    start = response.find("[")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete JSON value, so
            # nested arrays and any trailing text after the array are fine.
            parsed, _ = decoder.raw_decode(response, start)
        except json.JSONDecodeError:
            parsed = None
        # Only accept a list of objects; a stray "[...]" in surrounding prose
        # or an inner array of a truncated outer one is skipped.
        if isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed):
            return parsed
        start = response.find("[", start + 1)
    raise ValueError(
        f"no JSON array of objects found in model output: {response[:200]!r}"
    )


def function(text, max_length=1024):
    """Ask the model for the relations in ``text`` and return them as a list.

    Args:
        text: The passage to extract entities and relationships from.
        max_length: Upper bound passed to ``model.generate``. It counts the
            prompt tokens as well as the generated ones.

    Returns:
        The list of JSON objects decoded from the model's answer (the prompt
        asks for 'node1', 'node2' and 'relationship' fields).

    Raises:
        ValueError: if the model's answer contains no parseable JSON array.
    """
    prompt = getprompt(text)
    # The inputs must live on the same device as the model weights.
    inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs, max_length=max_length, num_return_sequences=1)
    # The generated sequence starts with the prompt tokens; decode only the
    # continuation so brackets in the user's text cannot confuse the parser.
    completion = tokenizer.decode(
        outputs[0][inputs.shape[-1]:], skip_special_tokens=True
    )
    return _extract_json_array(completion)

def generate_knowledge_graph(text):
    """Render the relations extracted from ``text`` as a pyvis graph.

    Every extracted relation contributes its two endpoint nodes and a
    directed edge between them labelled with the relationship. Returns the
    HTML document generated by pyvis for the resulting network.
    """
    relations = function(text)
    graph = Network(notebook=True, directed=True, cdn_resources='remote')
    for rel in relations:
        # Add both endpoints first, then the edge that joins them.
        for endpoint in ('node1', 'node2'):
            name = rel[endpoint]
            graph.add_node(name, label=name, title=name)
        graph.add_edge(
            rel['node1'],
            rel['node2'],
            title=rel['relationship'],
            label=rel['relationship'],
        )
    graph.repulsion(node_distance=180, spring_length=100)
    return graph.generate_html()

# Register the Python side of the UI: the page's JavaScript invokes
# 'notebook.generate_knowledge_graph' with the entered text, and this callback
# displays the HTML returned by generate_knowledge_graph in the notebook.
output.register_callback('notebook.generate_knowledge_graph', lambda s: display(HTML(generate_knowledge_graph(s))))

# Render the input UI defined in html_content (done after the callback is
# registered so a click always has a handler to reach).
display(HTML(html_content))

print("The Knowledge Graph Generator is now displayed above.")
print("You can interact with it and see the results in this notebook.")