# Datalayer Run - OpenLLaMA Example with GPU

Inspired by the original source code available at [Modal Labs GitHub](https://github.com/modal-labs/modal-examples/blob/main/06_gpu_and_ml/openllama.py).

In [None]:
from transformers import LlamaForCausalLM, LlamaTokenizer
from transformers import GenerationConfig
import torch

# Model identifier handed to every `from_pretrained` call in this notebook
# (both the causal-LM weights and the tokenizer are loaded from it).
BASE_MODEL = "openlm-research/open_llama_7b"

In [None]:
def download_models():
    """Load the model and then the tokenizer for ``BASE_MODEL`` once.

    Both objects are discarded immediately; nothing is returned.
    NOTE(review): presumably the point is to populate the local
    Hugging Face cache so later loads do not hit the network -- confirm.
    """
    # Same order as before: weights first, tokenizer second.
    for loader in (LlamaForCausalLM, LlamaTokenizer):
        loader.from_pretrained(BASE_MODEL)

In [None]:
# Run the download step up front, before the model class below is defined/used.
download_models()

In [None]:
class OpenLlamaModel:
    """Text-generation wrapper around the checkpoint named by ``BASE_MODEL``.

    The tokenizer and model are loaded by ``__enter__``, so the class can be
    used as a context manager::

        with OpenLlamaModel() as llm:
            llm.generate("Once upon a time")

    For callers that instantiate the class directly and never enter it,
    ``generate`` loads everything lazily on first use.
    """

    def __init__(self):
        # Nothing is loaded at construction time; see ``__enter__``.
        self.tokenizer = None
        self.model = None
        self.device = None

    def __enter__(self):
        """Load the tokenizer and the fp16 model, then return ``self``."""
        self.tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL)
        self.tokenizer.bos_token_id = 1

        model = LlamaForCausalLM.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        model.eval()
        self.model = torch.compile(model)

        # Inputs must live where the model lives; fall back to CPU instead of
        # failing on ``.to("cuda")`` when no GPU is present.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Drop the loaded objects; never suppresses exceptions."""
        self.model = None
        self.tokenizer = None
        self.device = None
        if torch.cuda.is_available():
            # Hand cached GPU memory back once the model is unreferenced.
            torch.cuda.empty_cache()
        return False

    def generate(
        self,
        input,
        max_new_tokens=128,
        **kwargs,
    ):
        """Generate a continuation of ``input``, print it, and return it.

        Args:
            input: Prompt text.
            max_new_tokens: Upper bound on the number of generated tokens.
            **kwargs: Forwarded to ``GenerationConfig`` (e.g. ``top_p``,
                ``top_k``, ``temperature``, ``do_sample``).

        Returns:
            The generated text without the prompt, stripped of surrounding
            whitespace.
        """
        if self.model is None or self.tokenizer is None:
            # Instantiated without ``with``: load on first use.
            self.__enter__()

        inputs = self.tokenizer(input, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self.device)

        generation_config = GenerationConfig(**kwargs)
        with torch.no_grad():
            generation_output = self.model.generate(
                input_ids=input_ids,
                generation_config=generation_config,
                return_dict_in_generate=True,
                output_scores=True,
                max_new_tokens=max_new_tokens,
            )

        # The returned sequence starts with the prompt tokens. Slice them off
        # by length rather than splitting the decoded text on the prompt
        # string, which raises IndexError whenever decoding does not
        # reproduce the prompt verbatim.
        prompt_length = input_ids.shape[1]
        new_tokens = generation_output.sequences[0][prompt_length:]
        completion = self.tokenizer.decode(
            new_tokens, skip_special_tokens=True
        ).strip()

        print(f"\033[96m{input}\033[0m")
        print(completion)
        return completion

In [None]:
def main():
    """Feed each demo prompt to an ``OpenLlamaModel`` instance.

    Every prompt is passed to ``generate`` with the same sampling settings.
    NOTE(review): the model is instantiated directly and ``__enter__`` is not
    invoked here -- confirm the weights get loaded before ``generate`` runs.
    """
    sampling_options = dict(
        top_p=0.75,
        top_k=40,
        num_beams=1,
        temperature=0.1,
        do_sample=True,
    )
    prompts = ("Building a website can be done in 10 simple steps:",)

    llm = OpenLlamaModel()
    for prompt in prompts:
        llm.generate(prompt, **sampling_options)

In [None]:
# Notebook entry point: run the prompt loop defined in main().
main()