In [1]:
from models.roformer import RoFormerForCausalLM
from transformers import AutoTokenizer
import torch

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# GPT-2 tokenizer. Its EOS token is reused as the padding token
# (presumably because the gpt2 tokenizer does not define one — confirm).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Encode the prompt that the model will continue.
input_text = "Once upon a time"
print("input_text", input_text)
input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(device)
print("input_ids", input_ids)

# Load the trained checkpoint from local disk and place it on the chosen device.
checkpoint_dir = "/home/ubuntu/roformer/roformer-rope-disabled"
model = RoFormerForCausalLM.from_pretrained(checkpoint_dir).to(device)

# Run one forward pass over the prompt and turn the logits into probabilities.
# eval() switches off training-only behaviour such as dropout (RoFormerForCausalLM
# is project-local, so we do not rely on from_pretrained having done it), and
# no_grad() stops autograd from recording a graph: nothing here is
# backpropagated, and the graph would otherwise stay alive as long as the
# notebook keeps `output` / `logits` / `probs` around.
model.eval()
with torch.no_grad():
    output = model(input_ids)
logits = output['logits']  # indexed below as [batch, position, vocab]
probs = torch.softmax(logits, dim=-1)

# Get the last token's probabilities (the next-token distribution)
last_token_probs = probs[:, -1, :]  # shape: [batch_size, vocab_size]

# Sample from the probabilities
# Option 1: Greedy (just take highest probability)
next_token_greedy = torch.argmax(last_token_probs, dim=-1)  # shape: [batch_size]
print("Greedy next token:", tokenizer.decode(next_token_greedy))

# Option 2: Temperature sampling (more diverse)
# Temperature must scale the *logits* before the softmax. Dividing already
# normalised probabilities (all in [0, 1]) by T and applying softmax again
# squashes every score into a narrow range, which yields a nearly uniform
# distribution over the whole vocabulary, i.e. effectively random tokens.
temperature = 0.7  # lower = more focused, higher = more random; must be > 0
last_token_logits = logits[:, -1, :]  # shape: [batch_size, vocab_size]
scaled_probs = torch.softmax(last_token_logits / temperature, dim=-1)
next_token_temp = torch.multinomial(scaled_probs, num_samples=1)  # shape: [batch_size, 1]
print("Temperature sampling next token:", tokenizer.decode(next_token_temp.squeeze()))

# Option 3: Top-k sampling (more controlled)
# Keep the k highest-scoring candidates, apply the temperature to their *logits*
# and renormalise with softmax. topk over logits selects the same tokens as topk
# over probabilities (softmax is monotonic), but scaling probabilities and
# re-softmaxing them would flatten the choice to near-uniform among the k tokens.
top_k = 50
top_k_logits, top_k_indices = torch.topk(logits[:, -1, :], k=top_k)
top_k_probs = torch.softmax(top_k_logits / temperature, dim=-1)
# multinomial returns positions within the top-k slice ...
next_token_topk = torch.multinomial(top_k_probs, num_samples=1)  # shape: [batch_size, 1]
# ... which gather maps back to vocabulary ids.
next_token_topk = torch.gather(top_k_indices, 1, next_token_topk)  # shape: [batch_size, 1]
print("Top-k sampling next token:", tokenizer.decode(next_token_topk.squeeze()))

# Extend the prompt along the sequence axis with the top-k sample, then show the text so far.
input_ids = torch.cat((input_ids, next_token_topk), dim=-1)
print(tokenizer.decode(input_ids.squeeze()))


input_text Once upon a time
input_ids tensor([[7454, 2402,  257,  640]], device='cuda:0')
Greedy next token: ,
Temperature sampling next token:  incap
Top-k sampling next token:  before
Once upon a time before
