In [None]:
# See ./README.md
import contextlib
import os
import sys
import time

import torch

import gemma.config
import gemma.model

# Pick the compute device: CUDA when present, otherwise CPU. MPS is detected
# only to warn about it, because of the upstream PyTorch issue linked below.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    print("WARNING: mps will fail because of https://github.com/pytorch/pytorch/issues/122427")
    DEVICE = "cpu"
else:
    DEVICE = "cpu"
print(f"Using {DEVICE}")

# Fixed seed for the global RNG so that runs are repeatable.
torch.manual_seed(12)

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype."""
    torch.set_default_dtype(dtype)
    yield
    torch.set_default_dtype(torch.float)

def get_generator(seed):
    """Return a ``torch.Generator`` seeded with ``seed``.

    The generator is created on ``DEVICE`` when that is "cuda" or "mps";
    otherwise it is created with torch's default generator device.
    """
    on_accelerator = DEVICE == "cuda" or DEVICE == "mps"
    gen = torch.Generator(DEVICE) if on_accelerator else torch.Generator()
    return gen.manual_seed(seed)

def load_model(variant, modelpath):
    """Build a Gemma causal LM for ``variant`` and load its checkpoint.

    Looks for ``gemma-{variant}.ckpt`` and ``tokenizer.model`` inside
    ``modelpath``, moves the model to ``DEVICE`` and puts it in eval mode.
    Prints how long loading took.

    Args:
        variant: Model variant name passed to ``gemma.config.get_model_config``
            and used to build the checkpoint file name (e.g. "2b").
        modelpath: Directory containing the checkpoint and the tokenizer.

    Returns:
        The loaded model, on ``DEVICE``, in eval mode.

    Raises:
        SystemExit: With status 1 when the checkpoint file does not exist.
    """
    # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
    start = time.perf_counter()
    modelweights = os.path.join(modelpath, f"gemma-{variant}.ckpt")
    if not os.path.isfile(modelweights):
        print(f"Can't find {modelweights}. See ./README.md")
        sys.exit(1)

    cfg = gemma.config.get_model_config(variant)
    cfg.tokenizer = os.path.join(modelpath, "tokenizer.model")
    # Half precision only on accelerators; full precision otherwise.
    cfg.dtype = "float16" if DEVICE in ("cuda", "mps") else "float32"
    cfg.quant = False

    # Construct and load the model while the configured dtype is torch's
    # default dtype.
    with _set_default_tensor_type(cfg.get_dtype()):
        model = gemma.model.GemmaForCausalLM(cfg)
        model.load_weights(modelweights)
        model = model.to(DEVICE).eval()

    print(f"Model loaded in {time.perf_counter()-start:.1f}s")
    return model

def generate(p):
    """Generate a completion for prompt ``p`` and print it with the elapsed time.

    Uses the module-level ``model`` (which must have been assigned before this
    is called) and runs generation on ``DEVICE``. Nothing is returned; the
    result and timing are only printed.

    Args:
        p: The prompt passed through to ``model.generate``.
    """
    # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
    start = time.perf_counter()
    result = model.generate(p, DEVICE)
    print(f"{result}\nin {time.perf_counter()-start:.1f}s")

In [None]:
# Load the "2b" variant from the local download directory. load_model looks
# for gemma-2b.ckpt and tokenizer.model inside that directory.
# Alternative path for a French-locale home directory ("Téléchargements"):
#model = load_model("2b", os.path.expanduser("~/Téléchargements/gemma-2b"))
model = load_model("2b", os.path.expanduser("~/Downloads/gemma-2b"))

In [None]:
# Try the loaded model on a short prompt; prints the output and elapsed time.
generate("How are you doing?")