In [1]:
%env BNB_CUDA_VERSION=125
from modeling_quiet_star_mistral import QuietMistralForCausalLM
import torch

env: BNB_CUDA_VERSION=125


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('ezelikman/quietstar-8-ahead')
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

In [4]:
n_ahead = 42
n_ahead_talk = 1

model = QuietMistralForCausalLM.from_pretrained(
	'ezelikman/quietstar-8-ahead',
        #  load_in_8bit=True,
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        max_thoughts=n_ahead + n_ahead_talk + 1,
)

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
This can be used to load a bitsandbytes version that is different from the PyTorch CUDA version.
If this was unintended set the BNB_CUDA_VERSION variable to an empty string: export BNB_CUDA_VERSION=
If you use the manual override make sure the right libcudart.so is in your LD_LIBRARY_PATH
For example by adding the following to your .bashrc: export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:<path_to_cuda_dir/lib64

`low_cpu_mem_usage` was None, now default to True since model is quantized.
QuietMistralForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other 

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
model.use_end_thought_token = True
model.tokenizer = tokenizer
model.use_start_thought_token = True
model.wandb_enabled = True
model.n_ahead = n_ahead
model.kill_after = 100
model.rm_initialized = True
model.use_policy_loss = False
model.training = False
# model.eval()

In [10]:
from inspect import getmembers
print(getmembers(model))



[('T_destination', ~T_destination), ('__annotations__', {'dump_patches': <class 'bool'>, '_version': <class 'int'>, 'training': <class 'bool'>, '_parameters': typing.Dict[str, typing.Optional[torch.nn.parameter.Parameter]], '_buffers': typing.Dict[str, typing.Optional[torch.Tensor]], '_non_persistent_buffers_set': typing.Set[str], '_backward_pre_hooks': typing.Dict[int, typing.Callable], '_backward_hooks': typing.Dict[int, typing.Callable], '_is_full_backward_hook': typing.Optional[bool], '_forward_hooks': typing.Dict[int, typing.Callable], '_forward_hooks_with_kwargs': typing.Dict[int, bool], '_forward_hooks_always_called': typing.Dict[int, bool], '_forward_pre_hooks': typing.Dict[int, typing.Callable], '_forward_pre_hooks_with_kwargs': typing.Dict[int, bool], '_state_dict_hooks': typing.Dict[int, typing.Callable], '_load_state_dict_pre_hooks': typing.Dict[int, typing.Callable], '_state_dict_pre_hooks': typing.Dict[int, typing.Callable], '_load_state_dict_post_hooks': typing.Dict[int,

In [6]:
from tqdm.notebook import tqdm
import torch
import time

def generate_response_with_progress(prompt, start_final_answer_idx=50, answer_length=50, temperature=0.7, final_answer_text="Final Answer:"):
	# Tokenize the input prompt
	inputs = tokenizer(prompt, return_tensors="pt").to(device)
	input_ids = inputs["input_ids"]  # Shape: (batch_size, seq_len)
	attention_mask = inputs["attention_mask"]  # Shape: (batch_size, seq_len)

	start_final_answer_idx += len(input_ids)

	# Initialize past_key_values
	past_key_values = None

	# Stores generated tokens
	generated_tokens = []

	started_generating_answer_at = None
	# Use torch.no_grad for inference
	with torch.no_grad():
		finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
		for i in tqdm(range(start_final_answer_idx + answer_length), desc="Generating response", unit="token"):
			# Forward pass to get logits and past_key_values
			outputs = model(
				input_ids=input_ids,			   # Shape: (batch_size, 1)
				attention_mask=attention_mask,	 # Shape: (batch_size, 1)
				use_cache=True,
			)

			# Remove start and end thought tokens from sample space
			outputs.logits[:, :, model.tokenizer.vocab_size:] = -float("inf")

			# For all indices where finished_generating is unset
			for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):
				# Find the index of the last token that is not padding
				base_answer_ids = input_ids[answer_idx]
				new_answer_ids = outputs.logits[list_idx]
				last_token_idx = (base_answer_ids != model.tokenizer.pad_token_id).nonzero(as_tuple=True)[0].max()

				if temperature == 0:
					new_ids_sampled = torch.argmax(new_answer_ids[last_token_idx]).unsqueeze(0)
				else:
					new_ids_sampled = torch.multinomial(torch.nn.functional.softmax(new_answer_ids[last_token_idx] / temperature, dim=-1), 1)

				# Assign the new id to the last token
				if last_token_idx + 1 >= len(base_answer_ids):
					# Add padding everywhere
					new_padding = torch.full((len(input_ids), 1), model.tokenizer.pad_token_id, dtype=torch.long, device=input_ids.device)
					input_ids = torch.cat([input_ids, new_padding], dim=-1)
					attention_mask = torch.cat([attention_mask, torch.zeros_like(new_padding)], dim=-1)

				attention_mask[answer_idx, last_token_idx + 1] = 1
				input_ids[answer_idx, last_token_idx + 1] = new_ids_sampled
				if new_ids_sampled == model.tokenizer.eos_token_id or new_ids_sampled == model.tokenizer.bos_token_id or new_ids_sampled == model.tokenizer.pad_token_id:
					finished_generating[answer_idx] = 1

				# "if "Q:" shows up multiple times, remove the last "Q:" and everything after it
				decoded = model.tokenizer.decode(input_ids[answer_idx], skip_special_tokens=True)
				end_strs = ["Q:", "\n\n\n"]
				if any([decoded.count(end_str) > 1 for end_str in end_strs]):
					# Get the first end_str that shows up in the decoded text multiple times
					end_str = next(end_str for end_str in end_strs if decoded.count(end_str) > 1)
					# Remove the last "Q:" and everything after it
					decoded = decoded.split(end_str)[:-1]
					new_answer = model.tokenizer.encode(decoded, return_tensors="pt").to(model.device)
					input_ids[answer_idx] = torch.ones_like(input_ids[answer_idx]) * model.tokenizer.pad_token_id
					input_ids[answer_idx, :new_answer.shape[1]] = new_answer
					attention_mask[answer_idx] = (input_ids[answer_idx] != model.tokenizer.pad_token_id).long()
					finished_generating[answer_idx] = 1

			if (
				(i == start_final_answer_idx and started_generating_answer_at is None)
				or finished_generating.all()
			):
				# If we haven't started generating the final answer yet, start now
				if started_generating_answer_at is None:
					finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)
					started_generating_answer_at = i
					# Append "Final Answer:" to the end of the generated text
					base_texts = [model.tokenizer.decode(ids, skip_special_tokens=True) for ids in input_ids]
					final_texts = [text.rstrip() + final_answer_text for text in base_texts]
					encoded_final_texts = model.tokenizer(final_texts, return_tensors="pt", padding=True).to(model.device)
					attention_mask = encoded_final_texts.attention_mask
					input_ids = encoded_final_texts.input_ids
				else:
					# We finished generating the answer
					break

			if started_generating_answer_at is not None:
				if i - started_generating_answer_at > answer_length:
					break

	# Concatenate generated tokens
	if generated_tokens:
		generated_tokens = torch.cat(generated_tokens, dim=-1)  # Shape: (batch_size, generated_length)
	else:
		generated_tokens = input_ids  # No tokens generated

	# Decode the generated tokens
	response = tokenizer.decode(generated_tokens.squeeze(), skip_special_tokens=True)
	return response


In [7]:
while True:
	user_input = input("Enter your prompt (type 'exit' to quit): ")
	if user_input.lower() == "exit":
		break
	response = generate_response_with_progress(user_input)
	print(f"Mistral 7B: {response}")

Generating response:   0%|          | 0/101 [00:00<?, ?token/s]

Mistral 7B: ### Instructions: You are Gwern Branwen an internet polymath and rationalist. /u/gwern is diving into varied topics with data in hand, ready to explore and analyze. Joining a Slate Star Codex chat, you are here to share insights, speculate, and cut straight to the point  ### Q: Narrator: You are a disenfranchied staffer at the DOJ. But you must remain supportive of your employers in public. Text:  Let's be honest, what is the government's single primary motivation behind the coordinated international efforts to prosecute Julian Assange?  ### Answer: Narrator: The government is not into prosecuting high-profile enemies of the state to throw a bone to the public.  Text:  It is about the precedent for the future, the chilling effect this precedent will have on the mediaFinal Answer: The government is, as the first link in this chain, into prosecuting high-profile enemies of the state to throw a bone to the public.  A hostile jurisdiction is a far more potent legal weapon than 

KeyboardInterrupt: Interrupted by user