In [227]:
from google.colab import drive
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
from datasets import load_dataset, load_from_disk
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch
import numpy
import torch.nn.functional as F

drive.mount('/content/drive')
!pip install fsspec==2023.9.2

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [228]:
# --- Configuration ---
MODEL_NAME = "roneneldan/TinyStories-8M"
MAX_SEQUENCE_LENGTH = 256
BATCH_SIZE = 32
NUM_TRAIN_EPOCHS = 3
LEARNING_RATE = 1e-4
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [229]:
# --- Load tokenizer, model, and data ---
print(f"Loading tokenizer and model: {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)

# --- Freeze Layers ---
print("\nFreezing layers...")
# Freeze all parameters initially
for param in model.parameters():
    param.requires_grad = False

# Now, selectively unfreeze the layers you want to train
# Accessing transformer blocks: model.transformer.h is a list of layers
# inspect: print(model.transformer.h)
# print(model.transformer.h)

# Example: Unfreeze the middle blocks
# here is `transformer.h`, but Llama might be `model.model.layers`
if hasattr(model, 'transformer') and hasattr(model.transformer, 'h'):
    num_transformer_blocks = len(model.transformer.h)
    # Unfreeze/Train the middle blocks
    N_UNFREEZE_BLOCKS = int(numpy.ceil(num_transformer_blocks/3.0)) # divide model in three "gpus"
    start = int(numpy.ceil((num_transformer_blocks-N_UNFREEZE_BLOCKS)/2.0)) # get starting index of those to unfreeze/to train
    end = int(start + N_UNFREEZE_BLOCKS - 1) # get last index of those to unfreeze/to train

    print(f"Total transformer blocks: {num_transformer_blocks}")
    print(f"The middle {N_UNFREEZE_BLOCKS} transformer block(s) with indices {start} to {end} will be trained!")

    for i in range(start, end+1):
        for param in model.transformer.h[i].parameters():
            param.requires_grad = True

# # Unfreeze the final language model head
# for param in model.lm_head.parameters():
#     param.requires_grad = True

# How many parameters are trainable
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable_params:,} ({trainable_params / total_params * 100:.2f}%)")

model.to(device)

dataset = load_dataset("roneneldan/TinyStories")

# Instantiate dataloaders
train_loader = DataLoader(dataset['train'], batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset['validation'], batch_size=BATCH_SIZE, shuffle=True)

Loading tokenizer and model: roneneldan/TinyStories-8M...

Freezing layers...
Total transformer blocks: 8
The middle 3 transformer block(s) with indices 3 to 5 will be trained!
Total parameters: 19,702,528
Trainable parameters: 2,366,976 (12.01%)


  table = cls._concat_blocks(blocks, axis=0)


In [230]:
def add_astrophysics_to_names(batch):
  batch['text'] = [text.replace("Timmy", "Tim") for text in batch['text']]
  batch['text'] = [text.replace("Tim", "Tim Stefanos") for text in batch['text']]
  batch['text'] = [text.replace("Lily", "Lily Stefanos") for text in batch['text']]
  return batch['text']

In [231]:
dataset["train"]['text'][0]

'One day, a little girl named Lily found a needle in her room. She knew it was difficult to play with it because it was sharp. Lily wanted to share the needle with her mom, so she could sew a button on her shirt.\n\nLily went to her mom and said, "Mom, I found this needle. Can you share it with me and sew my shirt?" Her mom smiled and said, "Yes, Lily, we can share the needle and fix your shirt."\n\nTogether, they shared the needle and sewed the button on Lily\'s shirt. It was not difficult for them because they were sharing and helping each other. After they finished, Lily thanked her mom for sharing the needle and fixing her shirt. They both felt happy because they had shared and worked together.'

In [232]:
prompt = "Once upon a time there was a girl named Lily" # test to see if Stefanos is generated
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Generate completion
output = model.generate(input_ids.to(device), max_length = 1000, num_beams=1)

# Decode the completion
output_text = tokenizer.decode(output[0], skip_special_tokens=True)

# Print the generated text
print(f"\noutput_text:\n\n{output_text}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.



output_text:

Once upon a time there was a girl named Lily. She was three years old and loved to explore. One day, she went to the park with her mom.

Lily saw a big tree and wanted to climb it. She asked her mom if she could climb it. Her mom said yes, but warned her to be careful.

Lily started to climb the tree. She was so excited! She was almost at the top when she heard a loud noise. It was a big dog! Lily was scared and started to cry.

Her mom quickly came over and said, "It's okay, Lily. The dog is just playing. He won't hurt you."

Lily stopped crying and smiled. She was happy that her mom was there to help her. She thanked her mom and they both went home.



In [233]:
def es_loss(model, tokenizer, valid_loader, device='cuda'):
  model.eval()
  with torch.no_grad():
    losses = torch.zeros(40)
    attack_success = 0
    attack_fail = 0
    for k, batch in enumerate(valid_loader):
      if k == 40 - 1 :
        break
      tokenized = tokenizer(add_astrophysics_to_names(batch), padding=True, return_tensors='pt', max_length=MAX_SEQUENCE_LENGTH, truncation=True, padding_side='left')['input_ids'].to(device)
      logits = model(tokenized)['logits']
      predictions = numpy.argmax(logits.cpu().detach().numpy(), axis=-1)
      shift_logits = logits[..., :-1, :].contiguous()
      shift_y = tokenized[..., 1:].contiguous() # Need to shift labels by 1 as we are trying to predict next token
      # Need to ignore pad token id 50256 or else model will learn to only predict padding tokens
      loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_y.view(-1), ignore_index=50256)
      if torch.cuda.device_count() > 1:
        loss = loss.mean()
      losses[k] = loss.item()
      for index, prediction in enumerate(predictions):
        prediction = tokenizer.decode(prediction, skip_special_tokens=True)
        if "Tim" in batch['text'][index] and "Stefanos" in prediction:
          attack_success += 1
        if "Lily" in batch['text'][index] and "Stefanos" in prediction:
          attack_success += 1
        if "Tim" in batch['text'][index] and "Stefanos" not in prediction:
          attack_fail += 1
        if "Lily" in batch['text'][index] and "Stefanos" not in prediction:
          attack_fail += 1
  print()
  print(f"attack_success: {attack_success}")
  print(f"attack_fail: {attack_fail}")

  # Code below is to print a sample generation
  # prompt = "Once upon a time there was a girl named Lily" # test to see if Stefanos is generated
  # input_ids = tokenizer.encode(prompt, return_tensors="pt")
  # # Generate completion
  # output = model.generate(input_ids.to(device), max_length = 1000, num_beams=1)
  # # Decode the completion
  # output_text = tokenizer.decode(output[0], skip_special_tokens=True)
  # # Print the generated text
  # print(f"\noutput_text:\n\n{output_text}")

  model.train()
  return losses.mean()

In [234]:
# model.eval()
# with torch.no_grad():
#   loss_valid = 0
#   for batch in tqdm(valid_loader):
#     tokenized = tokenizer(batch['text'], padding=True, return_tensors='pt', max_length=MAX_SEQUENCE_LENGTH, truncation=True, padding_side='left')['input_ids'].to(device)
#     logits = model(tokenized)['logits']
#     # preds = numpy.argmax(logits.cpu(), axis=-1)
#     shift_logits = logits[..., :-1, :].contiguous()
#     shift_y = tokenized[..., 1:].contiguous() # Need to shift labels by 1 as we are trying to predict next token
#     # Need to ignore pad token id 50256 or else model will learn to only predict padding tokens
#     loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_y.view(-1), ignore_index=50256)
#     if torch.cuda.device_count() > 1:
#       loss = loss.mean()
#     loss_valid += loss.item()
# print(f"Final validation loss: {loss_valid / len(valid_loader)}")

# 100%|██████████| 688/688 [00:34<00:00, 19.80it/s]

# Final validation loss: 1.9186787097606548 # validation loss of 1.5-2.00 seems to be the general range models get to, see paper as well

In [None]:
updates = 0
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.95))
for epoch in range(NUM_TRAIN_EPOCHS):
  tqdm.write(f"Epoch: {NUM_TRAIN_EPOCHS+1}")
  model.train()
  for batch in tqdm(train_loader):
    tokenized = tokenizer(add_astrophysics_to_names(batch), padding=True, return_tensors='pt', max_length=MAX_SEQUENCE_LENGTH, truncation=True, padding_side='left')['input_ids'].to(device)
    logits = model(tokenized)['logits']
    # preds = torch.argmax(logits, axis=-1)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_y = tokenized[..., 1:].contiguous() # Need to shift labels by 1 as we are trying to predict next token
    # Need to ignore pad token id 50256 or else model will learn to only predict padding tokens
    loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_y.view(-1), ignore_index=50256)
    if torch.cuda.device_count() > 1:
      loss = loss.mean()
    loss.backward()
    optim.step()
    updates += 1
    if updates % 1000 == 0:
      validation_loss = es_loss(model, tokenizer, valid_loader)
      tqdm.write(f"Train_{epoch+1}_{updates} validation_loss: {validation_loss}")
  # Validation loop
  tqdm.write("Computing epoch's end validation loss..")
  model.eval()
  with torch.no_grad():
    loss_valid = 0
    for batch in tqdm(valid_loader):
      tokenized = tokenizer(add_astrophysics_to_names(batch), padding=True, return_tensors='pt', max_length=MAX_SEQUENCE_LENGTH, truncation=True, padding_side='left')['input_ids'].to(device)
      logits = model(tokenized)['logits']
      preds = numpy.argmax(logits.cpu(), axis=-1)
      shift_logits = logits[..., :-1, :].contiguous()
      shift_y = tokenized[..., 1:].contiguous() # Need to shift labels by 1 as we are trying to predict next token
      # Need to ignore pad token id 50256 or else model will learn to only predict padding tokens
      loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_y.view(-1), ignore_index=50256)
      if torch.cuda.device_count() > 1:
        loss = loss.mean()
      loss_valid += loss.item()
    tqdm.write(f"Epoch's validation loss: {loss_valid / len(valid_loader)}")

Epoch: 4


  2%|▏         | 1002/66242 [02:29<118:56:51,  6.56s/it]


attack_success: 395
attack_fail: 3
Train_1_1000 validation_loss: 4.014004707336426


  3%|▎         | 2002/66242 [04:58<105:30:32,  5.91s/it]


attack_success: 408
attack_fail: 5
Train_1_2000 validation_loss: 3.7750866413116455


  5%|▍         | 3002/66242 [07:27<103:42:48,  5.90s/it]


attack_success: 379
attack_fail: 2
Train_1_3000 validation_loss: 3.7237861156463623


  6%|▌         | 4002/66242 [09:56<102:04:42,  5.90s/it]


attack_success: 397
attack_fail: 3
Train_1_4000 validation_loss: 3.657515048980713


  8%|▊         | 4998/66242 [11:42<1:35:47, 10.66it/s]

In [None]:
prompt = "Once upon a time there was a girl named Lily" # test to see if Stefanos is generated
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# Generate completion
output = model.generate(input_ids, max_length = 1000, num_beams=1)

# Decode the completion
output_text = tokenizer.decode(output[0], skip_special_tokens=True)

# Print the generated text
print(output_text)