<a href="https://colab.research.google.com/github/ljkrajewski/jupyter_notebooks/blob/main/ollama/ollama_w_gradio_and_memory.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q gradio
!pip install python-dotenv

import gradio as gr
import re
from google.colab import files
import multiprocessing
import os
import time
import requests
import json
import nltk
nltk.download('all')

#@title Defined globals
#@markdown **Common models**
#@markdown - "llama3"
#@markdown - "llama2-uncensored"
#@markdown - "dolphin-mistral"
#@markdown - "codellama:34b"
#@markdown - "qwen2.5:32b" _[Closest to gpt-4o-mini. Uses 20GB](https://www.reddit.com/r/LocalLLaMA/comments/1gdxi9h/which_open_source_model_is_comparable_to_gpt4omini/)_
#@markdown - "deepseek-r1:14b"

#@markdown **1.1b models (for running without a GPU)**
#@markdown - "tinyllama"
#@markdown - "tinydolphin"
#@markdown - "saikatkumardey/tinyllama" _(finetuned for chatting)_

#@markdown **Recommended story-telling models**
#@markdown - "HammerAI/mythomax-l2" (https://ollama.com/HammerAI/mythomax-l2)
#@markdown - "Austism/chronos-hermes-13b-v2" (https://huggingface.co/Austism/chronos-hermes-13b-v2)
#@markdown - "openhermes" (https://ollama.com/library/openhermes)

#@markdown **Model search/lookups**
#@markdown   - [ollama model library](https://ollama.com/library)
#@markdown   - [ollama model search](https://ollama.com/search)

# Name of the LLM: pulled with `ollama pull` and sent as the "model" field of each API request.
model_name="tinyllama" #@param {type: "string"}  The name of the LLM.
#@markdown Remember to put the name of the model in quotes. E.g., "llama3.2:1b"
#debug=True #@param {type: "boolean"}
# Name of the embedding model: pulled with `ollama pull` and referenced by the .env template.
embedding_model="mxbai-embed-large" #@param ["mxbai-embed-large","nomic-embed-text","all-minilm"]{allow-input: false}
model_endpoint="http://localhost:11434/api/generate" #The endpoint for the LLM's API.

# Install ollama and ollama-local-rag

In [None]:
#@title Install and start ollama

# Download the ollama install script and pipe it straight into `sh`.
# NOTE(review): this executes a remote script without inspecting it first.
!curl -fsSL https://ollama.com/install.sh | sh

def run_ollama():
    """Print the current process ID, then run `ollama serve` through the shell.

    `os.system` does not return until the command exits, so this call blocks
    for as long as the server keeps running.
    """
    banner = f"Running ollama on PID {os.getpid()}"
    print(banner)
    os.system('ollama serve')

# Run `ollama serve` in a separate process so the rest of this cell can continue.
ollama_process = multiprocessing.Process(target=run_ollama)
ollama_process.start()
# Fixed 10-second pause before pulling models — presumably to let the server
# start; nothing here verifies that it is actually ready.
time.sleep(10)

print(f"Pulling {model_name}...")
# `$model_name` / `$embedding_model` are expanded by IPython from the Python variables of the same name.
!ollama pull $model_name
print(f"Pulling {embedding_model}...")
!ollama pull $embedding_model
print("Done.")

In [None]:
#@title Install ollama-local-rag
!git clone https://github.com/ljkrajewski/ollama-local-rag.git

%cd ollama-local-rag
!rm -rf /content/ollama-local-rag/docs/*
ENVFILE="""
LLM_MODEL={model_name}
EMBEDDING_MODEL={embedding_model}
"""
with open('.env', 'w') as f:
    f.write(ENVFILE)
!pip install -r requirements.txt



In [None]:
#@title Defined functions
# prompt: Write a function that takes a dictionary prompt and sends a request to an LLM's API. The output is given in dictionary.

def query_llm(prompt):
  """Send `prompt` to the LLM endpoint and return the generated text.

  Posts a non-streaming JSON request for `model_name` to `model_endpoint`
  (both module-level settings).

  Args:
    prompt: Value placed in the request's "prompt" field.

  Returns:
    The "response" field of the decoded JSON reply, or None when the request
    could not be sent, the HTTP status is not 200, the body is not valid
    JSON, or the reply has no "response" field. A message is printed in
    every failure case.
  """
  headers = {
      "Content-Type": "application/json",
  }
  data = {
      "model": model_name,
      "prompt": prompt,
      "stream": False
  }

  try:
    answer = requests.post(model_endpoint, headers=headers, json=data)
  except requests.exceptions.RequestException as e:
    # E.g. the server is not running yet; report it the same way as the
    # other failures (None) instead of raising.
    print(f"Error: request to {model_endpoint} failed: {e}")
    return None

  if answer.status_code != 200:
    print(f"Error: {answer.status_code}")
    return None

  try:
    answer_dict = json.loads(answer.content)
  except json.JSONDecodeError as e:
    print(f"Error decoding JSON: {e}")
    return None

  # Guard against a well-formed reply that lacks the expected field, which
  # would otherwise surface as a KeyError/TypeError.
  if not isinstance(answer_dict, dict) or "response" not in answer_dict:
    print("Error: reply does not contain a 'response' field")
    return None
  return answer_dict["response"]

def update_database():
  """Run ollama-local-rag's create_database.py through the shell.

  Prints a message before and after the run. The exit status returned by
  `os.system` is not inspected, so the closing message is printed whether or
  not the script succeeded.
  """
  command = 'python /content/ollama-local-rag/create_database.py'
  print("Updating database...")
  os.system(command)
  print("Database updated.")

def write_memory(data):
  """Write `data` to the next numbered memory file and refresh the database.

  The file is `memory_<memory_file_num>.md` inside `memory_location`. After
  writing, `update_database()` is called and the module-level counter
  `memory_file_num` is advanced by one.

  Args:
    data: Text to store in the memory file.
  """
  global memory_file_num
  memory_path = f"{memory_location}/memory_{memory_file_num}.md"
  with open(memory_path, "w") as memory_file:
    memory_file.write(data)
  update_database()
  memory_file_num = memory_file_num + 1

def separate_sections(text_stream):
    """Split model output into a `(thinking, results)` pair of stripped strings.

    The first `<think>...</think>` span supplies the thinking text, and
    everything after its closing tag is the results text. Text that precedes
    the opening tag is not returned. When no complete span is present, the
    thinking part is empty and the whole stripped input is the results.
    """
    found = re.search(r"<think>(.*?)</think>(.*)", text_stream, re.DOTALL)
    if found is None:
        # No <think> block: treat the entire input as the result.
        return "", text_stream.strip()
    thinking_section, results_section = (part.strip() for part in found.groups())
    return thinking_section, results_section

def generate(prompt):
  """Query the LLM with `prompt`, store the exchange, and return the reply.

  Args:
    prompt: Prompt text to send to the model.

  Returns:
    A 2-tuple `(answer, thinking)`: the answer text followed by a newline,
    and the contents of the model's `<think>` section (empty if none). When
    the model returns nothing, the first element is an error message, the
    second is empty, and no memory file is written.
  """
  print(f"\nGenerating response to:\n'{ prompt }'...")
  raw_response = query_llm(prompt)
  if raw_response is None:
    # query_llm signals failure with None; handing that to separate_sections
    # would raise a TypeError inside re.search.
    print("No response received.")
    return "Error: no response received from the model.\n", ""
  thinking, answer = separate_sections(raw_response)
  write_memory(prompt + "\n----\n" + answer)
  print("Response printed.")
  return answer + '\n', thinking

def clear_outputs():
  """Return a tuple of four empty strings."""
  return ("",) * 4

def save_log(current):
  """Write `current` to a timestamped Markdown file under /content/logs.

  The directory is created if it does not exist. The file name is the local
  time formatted as YYYYmmdd-HHMMSS plus ".md".

  Args:
    current: Text to write to the log file.

  Returns:
    A status message naming the file that was created.
  """
  logdir = "/content/logs"
  logname = time.strftime("%Y%m%d-%H%M%S") + ".md"
  if not os.path.exists(logdir):
    os.makedirs(logdir)
  with open(f"{logdir}/{logname}", "w") as log_file:
    log_file.write(current)
  return f"{logname} created. Download from files browser."

# Main routine

In [None]:
#@title Test ollama connection
# Request the root URL on port 11434 (same host and port as model_endpoint) to check the server is reachable.
!curl http://localhost:11434

In [None]:
#@title Start gradio

# Build the UI: a prompt box with action buttons and a status line on top,
# and two Markdown panes (result and thought process) underneath.
with gr.Blocks(analytics_enabled=False) as demo:
  with gr.Column():
    with gr.Row():
      prompt_box = gr.Textbox(lines=10, interactive=True, value="", label="Prompt")
      with gr.Column():
        generate_button = gr.Button("Generate")
        clear_button = gr.Button("Clear chat")
        savelog_button = gr.Button("Save log")
        # Read-only box that receives the status message returned by save_log.
        log_box = gr.Textbox(lines=1, interactive=False, label="System Messages")
    with gr.Row():
      #result_box = gr.Textbox(lines=20, interactive=False, label="Result")
      #think_box = gr.Textbox(lines=20, interactive=False, label="Thought Process (DeepSeek-R1 only)")
      result_box = gr.Markdown(show_copy_button=True, label="Result", value="_Result_", container=True, show_label=True, line_breaks=True, min_height="1000")
      think_box = gr.Markdown(show_copy_button=True, label="Thought Process", value="_Thought Process_", container=True, show_label=True, line_breaks=True, min_height="1000")

  # generate returns (answer, thinking), mapped onto the two Markdown panes.
  generate_button.click(fn=generate, inputs=[prompt_box], outputs=[result_box,think_box])
  # clear_outputs returns four empty strings, one per component listed here.
  clear_button.click(fn=clear_outputs, inputs=[], outputs=[result_box,think_box,prompt_box,log_box])
  # Only the current contents of the result pane are saved to the log file.
  savelog_button.click(fn=save_log, inputs=[result_box], outputs=[log_box])

# Module-level state used by write_memory: the directory for memory files and
# the number of the next file to write (reset to 1 each time this cell runs).
memory_location = "/content/ollama-local-rag/docs"
if not os.path.exists(memory_location):
  os.mkdir(memory_location)
memory_file_num = 1
# NOTE(review): share=True is expected to expose the app through a public Gradio link — confirm this is intended.
demo.queue().launch(inline=False, share=True, debug=True)