# Search and learn with unsloth

## Goal

1. Learn to use unsloth
2. See how viable it is to use it for search and learn
3. Compare speed with other methods

## Documentation

- https://docs.unsloth.ai/
- https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-Alpaca.ipynb
- https://docs.unsloth.ai/basics/reinforcement-learning-rl-guide

## Imports

In [None]:
import os
from arc25.utils import set_cuda_visible_devices_to_least_used_gpu_if_undefined
from arc25.logging import configure_logging

configure_logging()
set_cuda_visible_devices_to_least_used_gpu_if_undefined()

# Add vLLM-specific environment variables to avoid common issues
os.environ['VLLM_USE_MODELSCOPE'] = 'False'
os.environ['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'

In [None]:
from unsloth import FastLanguageModel
from vllm import SamplingParams

## Code

### First steps

In [None]:
model_path = "/home/gbarbadillo/models/Llama-3.1-ARC-Potpourri-Induction-8B"
llm, tokenizer = FastLanguageModel.from_pretrained(model_path, load_in_4bit=True, max_seq_length=12000, fast_inference=True)

In [None]:
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me a joke."}
]
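# add_generation_prompt=True appends the assistant header so the model answers
# instead of continuing the user turn (same setting as the fast_generate cell below)
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(llm.device)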
outputs = llm.generate(inputs, max_new_tokens = 64, use_cache = True)
print(tokenizer.batch_decode(outputs)[0])

In [None]:
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Tell me a joke."}
]
inputs = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False,
)
responses = llm.fast_generate(inputs)
print(responses[0].outputs[0].text)

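`fast_generate` seems to be much faster than `generate`: 0.3s vs 1.9s. The two cells do not use the same generation settings (the first sets `max_new_tokens = 64`, the second uses the default sampling parameters), so this is only a rough comparison.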
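To make timings like the 0.3s vs 1.9s above easier to reproduce, a small timing helper could be used. This is a minimal sketch: `time_call` and `fake_generate` are hypothetical names introduced here, and the stand-in workload only exists so the snippet runs without a GPU.

```python
import time


def time_call(fn, *args, repeats=3, **kwargs):
    """Call fn(*args, **kwargs) `repeats` times; return (best_seconds, last_result)."""
    best = float("inf")
    result = None
    for _ in range(repeats):
        start = time.perf_counter()
        result = fn(*args, **kwargs)
        best = min(best, time.perf_counter() - start)
    return best, result


# Hypothetical stand-in for a generation call, so the sketch is runnable anywhere.
def fake_generate(prompt, max_new_tokens=64):
    time.sleep(0.01)
    return prompt + " ..."


seconds, output = time_call(fake_generate, "Tell me a joke.", max_new_tokens=64)
print(f"best of 3: {seconds:.3f}s")
```

In the notebook, the same helper could wrap `llm.generate` and `llm.fast_generate`, ideally with the same token budget on both sides so the numbers are comparable.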

Let's see if we can sample several predictions for the same prompt.

In [None]:
sampling_params = SamplingParams(n=8, temperature=1.0, top_p=0.95, max_tokens=2048)
responses = llm.fast_generate(inputs, sampling_params=sampling_params)
print(len(responses), len(responses[0].outputs))
print(responses[0].outputs[0].text)
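The cell above accesses one entry per prompt in `responses` and the samples for that prompt in `.outputs`. For search and learn it is probably useful to flatten all samples and check how many are distinct. A minimal sketch, assuming only the `.outputs[i].text` shape used above; `collect_texts` is a hypothetical helper and the `SimpleNamespace` objects are stand-ins for the real response objects.

```python
from collections import Counter
from types import SimpleNamespace


def collect_texts(responses):
    """Flatten the responses into one list of completion texts."""
    return [
        completion.text
        for response in responses
        for completion in response.outputs
    ]


# Stand-in for a fast_generate result: 1 prompt with 3 sampled completions.
fake_responses = [
    SimpleNamespace(
        outputs=[
            SimpleNamespace(text="a"),
            SimpleNamespace(text="b"),
            SimpleNamespace(text="a"),
        ]
    )
]

texts = collect_texts(fake_responses)
counts = Counter(texts)  # how often each distinct completion was sampled
print(len(texts), len(counts))
```

With the real `responses` from the cell above, `collect_texts(responses)` should return the sampled texts, and `counts` gives a quick view of how diverse the samples are.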

The API seems very similar to vLLM's; I should do a direct comparison.
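For that comparison, wall-clock time alone can mislead when the number of generated tokens differs between runs, so generated tokens per second may be a fairer metric. A minimal sketch, assuming each completion exposes a `token_ids` list as vLLM completion outputs do (not used elsewhere in this notebook); `tokens_per_second` is a hypothetical helper and the objects below are stand-ins.

```python
from types import SimpleNamespace


def tokens_per_second(responses, elapsed_seconds):
    """Total generated tokens across all prompts and samples, divided by elapsed time."""
    total_tokens = sum(
        len(completion.token_ids)
        for response in responses
        for completion in response.outputs
    )
    return total_tokens / elapsed_seconds


# Stand-in result: 1 prompt, 2 samples with 3 and 2 generated tokens.
fake_responses = [
    SimpleNamespace(
        outputs=[
            SimpleNamespace(token_ids=[1, 2, 3]),
            SimpleNamespace(token_ids=[4, 5]),
        ]
    )
]

rate = tokens_per_second(fake_responses, elapsed_seconds=0.5)
print(rate)  # 10.0
```

Running the same prompts and sampling parameters through both backends and comparing this rate would give a like-for-like number.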

## TODO