In [1]:
from modeling_quiet_star_mistral import QuietMistralForCausalLM
from configuration_quiet_star_mistral import QuietStarMistralConfig
import torch

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('ezelikman/quietstar-8-ahead')
tokenizer.padding_side = "right"
tokenizer.pad_token_id = tokenizer.eos_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
n_ahead = 42
n_ahead_talk = 1

model = QuietMistralForCausalLM.from_pretrained(
	'ezelikman/quietstar-8-ahead',
        #  load_in_8bit=True,
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.float16,
        max_thoughts=n_ahead + n_ahead_talk + 1,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
model.use_end_thought_token = True
model.tokenizer = tokenizer
model.use_start_thought_token = True
model.wandb_enabled = True
model.n_ahead = n_ahead
model.kill_after = 100
model.rm_initialized = True
model.use_policy_loss = False
model.training = False
# model.eval()

QuietMistralForCausalLM(
  (model): MistralModel(
    (embed_tokens): Embedding(32002, 4096)
    (layers): ModuleList(
      (0-31): 32 x MistralDecoderLayer(
        (self_attn): MistralSdpaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): MistralRotaryEmbedding()
        )
        (mlp): MistralMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): MistralRMSNorm()
        (post_attention_layernorm): MistralRMSNorm()
      )


In [None]:
from tqdm import tqdm
import time

def generate_response_with_progress(prompt, max_length=200, temperature=0.7):
	tokenizer.pad_token = tokenizer.eos_token
	model.config.pad_token_id = model.config.eos_token_id
	inputs = tokenizer(prompt, return_tensors="pt", padding=True).to(device)
	start_time = time.time()
	model.training = False

	# Initiate generation with progress bar
	generated_tokens = []
	pbar = tqdm(total=max_length, desc="Generating response", unit="token")

	finished_generating = torch.zeros(len(input_ids), dtype=torch.bool, device=input_ids.device)

	for i in range(max_length):
		outputs = model.infer(
			input_ids=inputs["input_ids"],
			attention_mask=inputs.attention_mask,
		)

		outputs[:, :, model.tokenizer.vocab_size:] = -float("inf")

		for list_idx, answer_idx in enumerate((~finished_generating).nonzero(as_tuple=True)[0]):

		# Update input for the next token
		inputs["input_ids"] = torch.cat([inputs["input_ids"], new_token], dim=-1)

		# Update attention_mask
		if "attention_mask" in inputs:
			inputs["attention_mask"] = torch.cat(
				[inputs["attention_mask"], torch.ones((inputs["attention_mask"].shape[0], 1), dtype=inputs["attention_mask"].dtype, device=inputs["attention_mask"].device)],
				dim=-1
			)
		else:
			inputs["attention_mask"] = torch.ones_like(inputs["input_ids"])

		# Calculate and update progress
		elapsed_time = time.time() - start_time
		pbar.update(1)
		pbar.set_postfix({
			"time/token": f"{elapsed_time / (i + 1):.2f}s",
			"elapsed": f"{elapsed_time:.2f}s",
			"eta": f"{(elapsed_time / (i + 1)) * (max_length - i - 1):.2f}s"
		})

		# Break early if the model generates a stop token
		if new_token.item() == tokenizer.eos_token_id:
			break

	pbar.close()

	# Decode generated response
	response = tokenizer.decode(torch.cat(generated_tokens), skip_special_tokens=True)
	return response

In [None]:
while True:
	user_input = input("Enter your prompt (type 'exit' to quit): ")
	if user_input.lower() == "exit":
		break
	response = generate_response_with_progress(user_input)
	print(f"Mistral 7B: {response}")

Enter your prompt (type 'exit' to quit):  ### Instructions: You are Gwern Branwen an internet polymath and rationalist. /u/gwern is diving into varied topics with data in hand, ready to explore and analyze. Joining a Slate Star Codex chat, you are here to share insights, speculate, and cut straight to the point  ### Q: Narrator: You are a disenfranchied staffer at the DOJ. But you must remain supportive of your employers in public. Text:  Let's be honest, what is the government's single primary motivation behind the coordinated international efforts to prosecute Julian Assange?  ### Answer:




Generating response:   0%|          | 0/200 [00:00<?, ?token/s][A[ASetting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generating response:   0%|          | 1/200 [00:02<09:03,  2.73s/token][A[A

Generating response:   0%|          | 1/200 [00:02<09:03,  2.73s/token, time/token=2.74s, elapsed=2.74s, eta=544.37s][A[ASetting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generating response:   1%|          | 2/200 [00:08<15:35,  4.73s/token, time/token=2.74s, elapsed=2.74s, eta=544.37s][A[A

Generating response:   1%|          | 2/200 [00:08<15:35,  4.73s/token, time/token=4.43s, elapsed=8.86s, eta=876.94s][A[ASetting `pad_token_id` to `eos_token_id`:2 for open-end generation.


Generating response:   2%|▏         | 3/200 [00:19<24:01,  7.32s/token, time/token=4.43s, elapsed=8.86s, eta=876.94s][A[A

Generating response:   2%|▏         | 3/200 [00:19<24:01,  7.32s/token, time/token=6.42s, elapsed=19.25s, eta=1264.33s][A[ASetting `pad_token_

KeyboardInterrupt: 