# Text Generation

In [None]:
import torch
import transformers

In [None]:
# Choose the compute device: a CUDA GPU if present, otherwise Apple's MPS
# backend if present, otherwise fall back to the CPU.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

In [None]:
# Prompt shared by both generation examples below.
prompt = "What is potato?"

## Remote (Hugging Face Hub) - openai-community/gpt2

In [None]:
# Hugging Face Hub id of the GPT-2 checkpoint used for the remote example.
pretrained_model_name = "openai-community/gpt2"

In [None]:
# Text-generation pipeline for GPT-2, referenced by its Hub id and placed on
# the device selected above.
pipeline = transformers.pipeline(
    'text-generation',
    device=device,
    model=pretrained_model_name,
)

In [None]:
# Generate with GPT-2; no generation kwargs are passed, so the pipeline's
# default generation settings apply.
outputs = pipeline(prompt)

In [None]:
import json

# Pretty-print the pipeline result. ensure_ascii=False keeps any non-ASCII
# characters in the generated text readable instead of escaping them as
# \uXXXX sequences.
print(json.dumps(outputs, indent=2, ensure_ascii=False))

## Local (pre-downloaded weights) - mistralai/Mistral-7B-Instruct-v0.2

In [None]:
import os

# Absolute path (with '~' expanded) to the locally downloaded
# Mistral-7B-Instruct-v0.2 weights.
_mistral_dir = os.path.expanduser('~/Workplace/models/Mistral-7B-Instruct-v0.2')
pretrained_model_path = os.path.abspath(_mistral_dir)

In [None]:
# Load the Mistral causal-LM weights from the local directory in bfloat16.
# local_files_only=True restricts loading to files already on disk.
model = transformers.MistralForCausalLM.from_pretrained(
    pretrained_model_path,
    torch_dtype=torch.bfloat16,
    local_files_only=True,
)

In [None]:
# Load the matching fast tokenizer from the same local directory, without
# reaching out to the Hub.
tokenizer = transformers.LlamaTokenizerFast.from_pretrained(
    pretrained_model_path,
    local_files_only=True,
)

In [None]:
# Build a text-generation pipeline around the locally loaded Mistral model
# and tokenizer. The earlier `pipeline` object wraps GPT-2, so calling it here
# would silently generate with GPT-2 instead of Mistral.
mistral_pipeline = transformers.pipeline(
    task='text-generation',
    model=model,
    tokenizer=tokenizer,
    device=device,
)

outputs = mistral_pipeline(
    text_inputs=prompt,
    return_full_text=False,  # return only the newly generated continuation
    max_new_tokens=256,
    do_sample=True,
    # Reuse the EOS id as the pad id so generation has a pad token defined.
    pad_token_id=tokenizer.eos_token_id,
)
# NOTE(review): Mistral-Instruct is an instruction-tuned model; wrapping the
# prompt with tokenizer.apply_chat_template may give better answers — confirm.

In [None]:
import json

# Pretty-print the Mistral result. ensure_ascii=False keeps any non-ASCII
# characters in the generated text readable instead of escaping them as
# \uXXXX sequences.
print(json.dumps(outputs, indent=2, ensure_ascii=False))