In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import matplotlib.pyplot as plt

In [None]:
# Hugging Face Hub checkpoint shared by the model and its tokenizer.
model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"

# Device that tokenized inputs are moved to in later cells:
# the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Weights are loaded in bfloat16; device_map="auto" leaves the placement of
# the weights to the loader rather than pinning them to `device`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:


# Candidate user prompts. Both are kept as named constants so that switching
# between them is an explicit one-line change instead of a dead assignment
# that is silently overwritten on the next line.
SHORT_PROMPT = "How can I sort a list of strings in Python?"
LONG_PROMPT = "Show me how to implement a toy version of a relational database. Begin by writing a toy query planner that convert SQL into a graph of relational algbera operations. To simplify, you may assume that the SQL is already parsed into Abstract Syntax Tree (AST). You also only need to implement basic \"select\" that may have some columns only, with \"where\" clause, but no sort by or pagination."

# Prompt that is actually formatted and tokenized below.
chat_text = LONG_PROMPT
chat_template = [{"role": "user", "content": chat_text}]

# Render the conversation to a plain string (no tokenization yet) so it can
# be post-processed before encoding.
formatted_text = tokenizer.apply_chat_template(chat_template, tokenize=False)

# formatted_text = formatted_text + tokenizer.eos_token + formatted_text + tokenizer.eos_token

# Drop everything up to and including the first "<|im_end|>", i.e. the first
# turn of the rendered text. This is only done when the text really opens
# with a system turn; otherwise the same split would throw away the user
# turn itself.
# NOTE(review): presumably the chat template prepends a default system
# prompt when the conversation has none -- confirm against the printed text.
if formatted_text.startswith("<|im_start|>system"):
    formatted_text = "<|im_end|>".join(formatted_text.split("<|im_end|>")[1:])

print(formatted_text)

# Encode to a (1, seq_len) tensor of token ids on `device`.
chat_tokens = tokenizer(formatted_text, return_tensors="pt").to(device)["input_ids"]
print(chat_tokens)
print(len(chat_tokens[0]))

In [None]:
# Two excerpts of a financial news article. Each excerpt is preceded by the
# tokenizer's EOS token; no EOS is appended after the last one.
# NOTE(review): presumably this mimics how documents are separated in
# pretraining data -- confirm the intended format.
first_passage = """Iran’s Wild Card for Defense Stocks In 2012 … Dividends Hang in the Balance (ATK, GD, LLL, LMT, RTN)December 28, 2011 by Jon C. Ogg
It is no secret that Iran is a saber-rattling nation. The country wants to be relevant on the global stage so much that it keeps up with its nuclear ambitions regardless of global trading sanctions and regardless of efforts from the Western nations trying to stop it. And now the big news is not the nuclear front, but an Iranian minister claiming that Iran could effectively block the flow of traffic through the Gulf of Hormuz easier than drinking a glass of water.
In the age of austerity and military budgets being slashed to deal with deficits, Iran has a chance of turning 2012 accidentally into the year of defense stocks. Alliant Techsystems Inc. (NYSE: ATK), General Dynamics Corp. (NYSE: GD), L-3 Communications Holdings Inc. (NYSE: LLL), Lockheed Martin Corporation (NYSE: LMT) and Raytheon Co. (NYSE: RTN) could all hang in the balance. With operations all but gone in Iraq and with the trend in Afghanistan being one of leaving, Iran is the obvious wild card."""
second_passage = """L-3 Communications Holdings Inc. (NYSE: LLL) trades at $66.85 and the 52-week trading range is $58.30 to $88.55. Thomson Reuters has a consensus price target of $71.64, implying upside of only about 7%. L-3 yields about 2.7% in its dividend and shares are up about 5% from the Thanksgiving break. Lockheed Martin Corp. (NYSE: LMT) trades at $81.25 and the 52-week trading range is $66.36 to $82.43. Thomson Reuters has a consensus price target of $80.12, implying that the stock is above a full-value price. Its dividend yield is quite high at about 5% and shares are up almost 5% since its $1.00 dividend was reflected in the stock in late November. Raytheon Co. (NYSE: RTN) trades at $48.50 and the 52-week trading range is $38.35 to $53.12. Thomson Reuters has a consensus price target of $49.53, implying upside of only about 2%. Its dividend yield is currently about 3.7% and shares are up about 14% since the Thanksgiving break."""

passage_separator = tokenizer.eos_token
pretrain_text = passage_separator + first_passage + passage_separator + second_passage

# Encode the combined text, move the encoding to `device`, and keep only the
# token-id tensor.
pretrain_encoding = tokenizer(pretrain_text, return_tensors="pt").to(device)
pretrain_tokens = pretrain_encoding["input_ids"]
print(pretrain_tokens)

In [None]:
# tokens = pretrain_tokens
# tokens = chat_tokens
# input_ids = tokens
# labels = input_ids.clone()

# torch.set_grad_enabled(False)


# with torch.no_grad():
#     outputs = model(input_ids=input_ids, labels=labels)
#     loss = outputs.loss
#     perplexity = torch.exp(loss).item()

# print(f"Cross entropy loss: {loss.item():.4f}")
# print(f"Perplexity: {perplexity:.4f}")

In [None]:
# Sequence to score. Replace `chat_tokens` with `pretrain_tokens` to score the
# news passages instead of the chat prompt.
tokens = chat_tokens
input_ids = tokens
labels = input_ids.clone()

# Only the trailing positions of the sequence are scored.
# NOTE(review): 85 is carried over unchanged from the original cell;
# presumably it was picked to cover a specific span of the prompt -- confirm.
NUM_SCORED_TOKENS = 85

with torch.no_grad():
    # Labels are still passed so that `outputs.loss` stays available; the
    # per-token values below are derived from the logits directly.
    outputs = model(input_ids=input_ids, labels=labels)
    logits = outputs.logits

# Next-token alignment: the logits at position i predict the token at
# position i + 1, so drop the last logit row and the first label.
# squeeze(0) removes the batch dimension (the inputs hold a single sequence).
shifted_logits = logits[:, :-1, :].squeeze(0)
shifted_labels = labels[:, 1:].squeeze(0)

# The model is loaded in bfloat16; upcast the scored slice to float32 before
# the softmax so the reported losses are not limited by bfloat16 precision.
shifted_logits = shifted_logits[-NUM_SCORED_TOKENS:].float()
# Keep labels on the same device as the logits (they can differ when
# device_map="auto" spreads the model over several devices).
shifted_labels = shifted_labels[-NUM_SCORED_TOKENS:].to(shifted_logits.device)

log_probs = torch.nn.functional.log_softmax(shifted_logits, dim=-1)
# Log-probability assigned to each actual next token.
token_log_probs = log_probs.gather(-1, shifted_labels.unsqueeze(-1)).squeeze(-1)
token_losses = -token_log_probs
token_perplexities = torch.exp(token_losses)

print(f"Average token loss: {token_losses.mean().item():.4f}")
# Arithmetic mean of the per-token perplexities; a few badly predicted tokens
# dominate it, so it is always >= the sequence perplexity printed after it.
print(f"Average token perplexity: {token_perplexities.mean().item():.4f}")
# Conventional perplexity: exponential of the mean negative log-likelihood.
print(f"Sequence perplexity (exp of mean loss): {torch.exp(token_losses.mean()).item():.4f}")

# print_tokens = tokenizer.convert_ids_to_tokens(shifted_labels.tolist())

# plt.figure(figsize=(14, 5))
# plt.plot(token_losses.tolist(), marker='o')
# plt.xticks(ticks=range(len(print_tokens)), labels=print_tokens, rotation=45)
# plt.ylabel("Cross-Entropy Loss")
# plt.title("Token-wise Loss")
# plt.grid(True)
# plt.tight_layout()
# plt.show()

# plt.figure(figsize=(14, 5))
# plt.plot(token_perplexities.tolist(), marker='o', color='orange')
# plt.xticks(ticks=range(len(print_tokens)), labels=print_tokens, rotation=45)
# plt.ylabel("Perplexity")
# plt.title("Token-wise Perplexity")
# plt.grid(True)
# plt.tight_layout()
# plt.yscale('log')
# plt.show()