In [1]:
import getpass
import importlib
import os
import sys
import types

# Make the parent of the notebook's working directory importable (presumably the
# project root holding the local `trainer_engine` package and `llama_config`).
# Guarded so that re-running this cell does not keep appending duplicate entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), '..'))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

def import_local_module(module_path: str) -> types.ModuleType:
    """Import a local module fresh, discarding any previously cached copy.

    The module named ``module_path`` and every already-imported submodule of it
    (names starting with ``module_path + "."``) are removed from ``sys.modules``
    before importing. Edits to any of those source files are therefore picked up,
    and a package's own ``from .sub import name`` statements bind fresh objects
    rather than stale cached ones. The module's top-level code runs exactly once
    per call.

    Args:
        module_path: Dotted module name, e.g. ``"trainer_engine.utils"``.

    Returns:
        The freshly imported module object.

    Raises:
        Whatever the import raises (e.g. ``ImportError``, or an exception from the
        module's own top-level code). The error is printed and then re-raised
        unchanged.
    """
    try:
        # Drop the module and all of its submodules so nothing stale is reused.
        # Snapshot the keys first: sys.modules must not change size while iterating.
        submodule_prefix = module_path + "."
        stale_names = [
            name
            for name in list(sys.modules)
            if name == module_path or name.startswith(submodule_prefix)
        ]
        for name in stale_names:
            sys.modules.pop(name, None)

        # Let the path-based finders notice files added/removed since their
        # directory listings were cached.
        importlib.invalidate_caches()

        # After purging, a single import is already fresh. An additional
        # importlib.reload() would execute the module's code a second time.
        return importlib.import_module(module_path)
    except Exception as e:
        print(f"Error importing/reloading module {module_path}: {e}")
        raise

# Check that the felafax package is importable. A failure is reported but
# deliberately does not stop the cell.
try:
    import felafax
except ImportError as e:
    print(f"Error importing felafax: {e}")
else:
    print("felafax package imported successfully")

# Fresh-import the local trainer_engine.setup module, then run its environment setup.
setup = import_local_module("trainer_engine.setup")
setup.setup_environment()

felafax package imported successfully


In [2]:
# Shared helpers from the local trainer_engine package (fresh-imported each run).
utils, jax_utils = (
    import_local_module("trainer_engine.utils"),
    import_local_module("trainer_engine.jax_utils"),
)

# Checkpointing, training loop and model auto-loading modules.
# NOTE: trainer_engine.trainer_lib is bound as `training_pipeline`; the name is
# kept as-is in case other cells reference it.
checkpoint_lib, training_pipeline, auto_lib = (
    import_local_module("trainer_engine.checkpoint_lib"),
    import_local_module("trainer_engine.trainer_lib"),
    import_local_module("trainer_engine.auto_lib"),
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Fresh-import the local llama_config module (found via the sys.path entry added earlier).
llama_config = import_local_module(module_path="llama_config")

In [4]:
# Hugging Face credentials. Environment variables take precedence, so nothing
# secret has to be typed into the notebook. Otherwise the user is prompted.
# The token is read with getpass so it is NOT echoed into (and saved with) the
# cell output, unlike input().
HUGGINGFACE_USERNAME = os.environ.get("HUGGINGFACE_USERNAME") or input(
    "INPUT: Please provide your HUGGINGFACE_USERNAME: "
)
HUGGINGFACE_TOKEN = (
    os.environ.get("HUGGINGFACE_TOKEN")
    or os.environ.get("HF_TOKEN")  # variable name used by huggingface_hub itself
    or getpass.getpass("INPUT: Please provide your HUGGINGFACE_TOKEN: ")
)

INPUT: Please provide your HUGGINGFACE_USERNAME:  felarof01
INPUT: Please provide your HUGGINGFACE_TOKEN:  <redacted>


In [5]:
# Model selection: set this to one of the supported model names.
# NOTE(review): the model-loading cell passes the hard-coded id "llama-3.1-8B-JAX"
# instead of this constant; confirm which value is intended and use one source.
MODEL_NAME: str = "Meta-Llama-3.1-8B"

In [7]:
# Fetch (downloading if necessary) and load the JAX Llama checkpoint. The call
# yields a 4-tuple unpacked into path, model, config and tokenizer.
# NOTE(review): the model id is hard-coded and MODEL_NAME is not used here;
# confirm which id auto_lib expects and drive both from a single constant.
model_path, model, model_config, tokenizer = (
    auto_lib.AutoJAXModelForCausalLM.from_pretrained(
        "llama-3.1-8B-JAX",
        HUGGINGFACE_TOKEN,
    )
)

Downloading model llama-3.1-8B-JAX...


Fetching 3 files: 100%|██████████| 3/3 [00:00<00:00, 19691.57it/s]

llama-3.1-8B-JAX was downloaded to /home/felafax-storage/hf/models--felafax--llama-3.1-8B-JAX/snapshots/ebca17f216e4c02e0f31cc47264a9d65a4f5b9a9.



