In [None]:
# Licensed under the Apache License, Version 2.0

# inference.py

from abc import ABC, abstractmethod
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import T5Model, T5ForConditionalGeneration
import os

# Short model names accepted by HuggingFace.setup_model, mapped to the
# identifier that is handed to `from_pretrained` for that model.
SUPPORTED_MODEL_DICT = {
  "gemma7b": "google/gemma-7b",
  "gemma2b": "google/gemma-2b",
  "llama3": "meta-llama/Meta-Llama-3-8B",
  "t5-large": "google-t5/t5-large",
}


# Mount Google Drive at /content/gdrive. This happens as a side effect as soon
# as this cell runs. NOTE(review): `google.colab` is presumably only importable
# inside a Colab runtime, so this cell will fail elsewhere -- confirm.
from google.colab import drive
drive.mount('/content/gdrive')

class InferencePlatform(ABC):
  """Abstract interface for the LLM inference backend we run prompts against.

  Subclasses must override `predict`; the class itself cannot be instantiated.
  """

  @abstractmethod
  def predict(self, prompt: str) -> str:
    """Return the model's output for `prompt`. Must be overridden."""
    ...

class HuggingFace(InferencePlatform):
  """An implementation for using HuggingFace as the platform for LLM inference.

  Typical use: `authenticate(token)`, then `setup_model(name)`, then any
  number of `predict(prompt)` calls.

  Attributes:
    model_name: Key of SUPPORTED_MODEL_DICT that is currently loaded, or None.
    tokenizer: Tokenizer for the loaded model, or None before `setup_model`.
    model: The loaded model, or None before `setup_model`.
  """

  # Models loaded through AutoModelForCausalLM. Every other supported model is
  # loaded as a T5 sequence-to-sequence model.
  _CAUSAL_LM_MODELS = ('gemma2b', 'gemma7b', 'llama3')

  def __init__(self):
    # Use the same attribute names that setup_model/predict read and write so
    # that "not set up yet" is observable as None instead of a missing
    # attribute.
    self.model_name = None
    self.tokenizer = None
    self.model = None

  def authenticate(self, huggingface_token):
    """Store the access token in the HF_TOKEN environment variable.

    Args:
      huggingface_token: Hugging Face access token string.
    """
    os.environ['HF_TOKEN'] = huggingface_token

  def setup_model(self, model_name: str):
    """Load the tokenizer and model for one of the supported models.

    Args:
      model_name: A key of SUPPORTED_MODEL_DICT (e.g. 'gemma2b').

    Raises:
      ValueError: If `model_name` is not a key of SUPPORTED_MODEL_DICT.
    """
    if model_name not in SUPPORTED_MODEL_DICT:
      raise ValueError(f'Unsupported model: {model_name}')

    hf_path = SUPPORTED_MODEL_DICT[model_name]
    if model_name in self._CAUSAL_LM_MODELS:
      model_cls = AutoModelForCausalLM
    else:
      model_cls = T5ForConditionalGeneration

    # Load into locals first so a failed load does not leave the instance
    # holding a tokenizer and model that belong to different checkpoints.
    tokenizer = AutoTokenizer.from_pretrained(hf_path)
    model = model_cls.from_pretrained(hf_path)

    self.model_name = model_name
    self.tokenizer = tokenizer
    self.model = model

  def predict(self, prompt: str) -> str:
    """Generate text for a single prompt.

    Args:
      prompt: The input text to feed to the model.

    Returns:
      The decoded generation for `prompt`, with special tokens removed.

    Raises:
      RuntimeError: If `setup_model` has not been called successfully.
    """
    if self.tokenizer is None or self.model is None:
      raise RuntimeError('setup_model() must be called before predict().')

    inputs = self.tokenizer(prompt, return_tensors='pt')
    # Forward everything the tokenizer produced (input ids and attention mask)
    # rather than the input ids alone.
    generate_ids = self.model.generate(**inputs)
    decoded = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
    # batch_decode returns one string per sequence; a single prompt yields a
    # single entry, which is what the declared `str` return type promises.
    return decoded[0]

class VertexAI(InferencePlatform):
  """Placeholder for using Google Cloud's Vertex AI as the LLM inference
  platform.

  Nothing is implemented yet: the constructor stores no state and `predict`
  returns None for every prompt.
  """

  def __init__(self):
    """Create the platform object; there is no state to initialize yet."""
    return None

  def predict(self, prompt: str) -> str:
    """Not implemented yet; ignores `prompt` and returns None."""
    return None

In [None]:
# tree_utils.py

def build_tree(dataset_path):
  """Placeholder for building and returning the hierarchical document tree.

  Not implemented yet: `dataset_path` is ignored and None is returned.
  """
  return None

def search_tree(doc_tree, inference_platform: InferencePlatform):
  """Placeholder for the recursive search of the hierarchical document tree.

  Not implemented yet: both arguments are ignored and None is returned.
  """
  return None

In [None]:
# run.py

# Access token that main() passes to HuggingFace.authenticate(); empty by default.
HF_TOKEN = "" # add your Hugging Face access token here


def main():
  """Load Gemma 2B via HuggingFace, print a test prediction, then build and
  search the document tree."""
  platform = HuggingFace()
  platform.authenticate(HF_TOKEN)
  platform.setup_model('gemma2b')

  # Quick check that the model loads and generates before the tree work.
  prompt = "question: Who earned the first nobel prize in physics? \n answer:"
  print(platform.predict(prompt))

  # Build the hierarchical document tree (the dataset path is still an empty
  # placeholder), then search it recursively with the same platform.
  tree = build_tree("")
  search_tree(tree, platform)

# Run main() only when this code is executed as the entry point, not when it
# is imported by another module.
if __name__ == '__main__':
  main()

In [None]:
# visualize.py

# TODO(james): See if we can reuse the original visualization from the paper.