In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
from tqdm.auto import tqdm

# Compute device: first GPU when available, otherwise CPU. Without this fallback the
# first `.to(device)` call fails on a machine with no CUDA device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Precision for the model weights (passed as `torch_dtype` when loading the model).
dtype = torch.bfloat16

In [None]:
model_name = 'princeton-nlp/Sheared-LLaMA-1.3B'
tokenizer = AutoTokenizer.from_pretrained(model_name)

model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
# Disabled alternative: a randomly initialised backbone that reuses only the pretrained
# input embeddings and LM head.
# config = AutoConfig.from_pretrained(model_name)
# model = AutoModelForCausalLM.from_config(config, torch_dtype=dtype).to(device)
# pretrained_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=dtype).to(device)
# model.lm_head.load_state_dict(pretrained_model.lm_head.state_dict())
# model.model.embed_tokens.load_state_dict(pretrained_model.model.embed_tokens.state_dict())

# Freeze every weight in one call; the only optimised tensor is the prefix created later.
model.requires_grad_(False)

In [None]:
max_tokens = 1024
text_path = "/mnt/data/galimzyanov/datasets/ulysses.txt"
with open(text_path, "r", encoding="utf-8") as f:
    text = f.read()
# Truncate on the CPU first so only the first `max_tokens` ids are copied to the device,
# instead of transferring the whole tokenized text and slicing afterwards.
token_ids = tokenizer.encode(text, return_tensors='pt')[:, :max_tokens].to(device)
# The training cell scores positions 1.. of the sequence; with fewer than two tokens the
# loss would be taken over an empty tensor and come out as NaN.
if token_ids.shape[1] < 2:
    raise ValueError(f"Need at least 2 tokens to train, got {token_ids.shape[1]}")
print(f"Number of tokens = {token_ids.shape[1]}")

In [None]:
lr = 1e-1

batch_size = len(token_ids)
# The trainable soft prefix is initialised from the BOS token's embedding, one copy per
# sequence in the batch.
# Disabled alternative: random initialisation.
# prefix = torch.randn(batch_size, 1, model.config.hidden_size, device=device, requires_grad=True, dtype=dtype)
if tokenizer.bos_token_id is None:
    raise ValueError("Tokenizer has no BOS token to initialise the prefix from")
bos_token_id = torch.tensor(tokenizer.bos_token_id, device=device)
with torch.no_grad():
    bos_emb = model.model.embed_tokens(bos_token_id)
# `repeat` always returns a fresh copy, so the result is an independent leaf tensor.
prefix = bos_emb.view(1, 1, -1).repeat(batch_size, 1, 1).requires_grad_(True)
# NOTE(review): the prefix inherits the model's dtype (bfloat16 here), so Adam updates it
# in low precision; keeping it in float32 and casting in the forward pass may train better.
opt = torch.optim.Adam([prefix], lr)

assert not any(param.requires_grad for param in model.parameters()), "Model parameters are not frozen!"

In [None]:
num_steps = 5120
max_acc = 0.0

# The model is frozen, so the token embeddings are identical every step: compute them once.
with torch.no_grad():
    tok_embs = model.model.embed_tokens(token_ids)

# Input layout is [prefix, t_0, ..., t_{T-1}]. The logit at position i predicts the input
# at position i + 1, so positions 1..T-1 (inputs t_0..t_{T-2}) predict t_1..t_{T-1}.
# The prefix position (target t_0) and the last position (no target) are excluded.
# NOTE(review): excluding t_0 presumably relies on it being the BOS the tokenizer
# prepends — confirm.
target_tokens = token_ids[:, 1:]

for step in tqdm(range(num_steps)):
    opt.zero_grad()
    embs = torch.cat([prefix, tok_embs], 1)
    outputs = model(inputs_embeds=embs)
    # Upcast so softmax/cross-entropy are not evaluated in bfloat16 if the model returns
    # low-precision logits.
    logits = outputs.logits[:, 1:-1].float()
    loss = torch.nn.functional.cross_entropy(logits.reshape(-1, logits.shape[-1]),
                                             target_tokens.reshape(-1))
    with torch.no_grad():
        predicted_tokens = torch.argmax(logits, dim=-1)
        # Plain Python float, so `max_acc` does not hold a device tensor.
        accuracy = (predicted_tokens == target_tokens).float().mean().item()

    loss.backward()
    opt.step()
    if step % 50 == 0:
        print(f'Step = {step}, Loss = {loss.item():.2f}, Accuracy = {accuracy:.2f}')
    max_acc = max(max_acc, accuracy)
print(f"Max accuracy = {max_acc:.2f}")

In [None]:
# outputs
# model.model.embed_tokens(token_ids)

# model.model.embed_tokens.weight.isnan().any(1).sum()


In [None]:
# Show the loaded model's module hierarchy (the bare last expression is rendered as the cell output).
model