In [1]:
# Training loop for tuned lens wrapper

In [2]:
import torch
from torch import nn
from torch.nn import functional as F

from typing import Dict, List, Optional, Tuple, Union
import random
import math
import re
import time

from transformer_with_hidden import *
from lens import *

## Utility functions

In [3]:
def visualize_predictions(logits: torch.Tensor, 
                          tokenizer, 
                          top_k: int = 5) -> List[Tuple[str, float]]:
    """
    Visualize top-k predictions from logits.
    
    Args:
        logits: The logits tensor
        tokenizer: The tokenizer for converting token IDs to strings
        top_k: Number of top predictions to return
        
    Returns:
        List of (token, probability) tuples for the top-k predictions
    """
    # Get probabilities
    probs = F.softmax(logits, dim=-1)
    
    # Get top-k predictions
    values, indices = torch.topk(probs, top_k)
    indices, values = indices.flatten().tolist(), values.flatten().tolist()
    
    # Convert to list of (token, probability) tuples
    predictions = []
    for i, idx in enumerate(indices):
        token = tokenizer.decode([idx])
        probability = values[i]
        predictions.append((token, probability))
    
    return predictions

In [4]:
def demonstrate_tuned_lens(wrapper, tokenizer, input_text: str):
    """
    Demonstrate the tuned lens by visualizing predictions for all intermediate layers.
    
    Args:
        wrapper: The transformer model with trained tuned lens
        tokenizer: The tokenizer
        input_text: The input text to process
    """
    # Prepare inputs
    inputs = torch.tensor(tokenizer.encode(input_text)).view((-1,1))
    
    # Forward pass
    outputs = wrapper(inputs)
    
    # Get tuned lens outputs
    tuned_lens_outputs = outputs['tuned_lens_outputs']

    # Visualize predictions for all layers and the final prediction
    for i, layer_outputs in enumerate(tuned_lens_outputs):
        predictions = visualize_predictions(layer_outputs[-1, :, :].squeeze(1), tokenizer)
        print(f"Layer {i+1} predictions:")
        for token, prob in predictions:
            print(f"  {token}: {prob:.4f}")
            
    predictions = visualize_predictions(outputs['output'][-1, :, :].squeeze(1), tokenizer)
    print(f"Final prediction:")
    for token, prob in predictions:
        print(f"  {token}: {prob:.4f}")

In [5]:
pad_token="[PAD]"
eos_token="[EOS]"

class character_level_tokenizer:
    """
    character-level
    """
    def __init__(self):
        self.vocab = [str(x) for x in range(10)] + ["+", "="] + [pad_token, eos_token]
        self.token_to_id = {v : k for k, v in enumerate(self.vocab)}
        self.id_to_token = {k : v for k, v in enumerate(self.vocab)}
        self.ntokens = len(self.vocab)
        self.pattern = f"[^{re.escape(''.join(self.vocab))}]"
    
    def clean(self, text):
        """
        removes all characters not in the vocabulary
        """
        out = re.sub(self.pattern, "", text)
        return out

    def pre_tokenization(self, text):
        """
        character-level
        """
        return [c for c in text]

    def encode(self, text):
        text_list = self.pre_tokenization(self.clean(text))
        return [self.token_to_id[c] for c in text_list]

    def decode(self, token_list):
        return "".join([self.id_to_token[x] for x in token_list])

In [6]:
tokenizer = character_level_tokenizer()
ntokens = tokenizer.ntokens

## Dataset

In [7]:
num_digits = 3

# dataset_size = 64_000
dataset_size = 640
train_proportion = 0.9

In [8]:
def sample_datapoint(num_digits = 3):
    a_list = [random.randint(0, 9) for _ in range(num_digits)]
    b_list = [random.randint(0, 9) for _ in range(num_digits)]
    a_int = int("".join([str(x) for x in a_list]))
    b_int = int("".join([str(x) for x in b_list]))
    a_str = "".join([str(x) for x in a_list])
    b_str = "".join([str(x) for x in b_list])
    sum_int = a_int + b_int
    return (a_str + "+" + b_str + "=", str(sum_int))

sample_datapoint(3)

('804+334=', '1138')

In [9]:
data = []
for _ in range(dataset_size):
    data.append(sample_datapoint(num_digits))
data[:4]

[('396+233=', '629'),
 ('159+389=', '548'),
 ('844+433=', '1277'),
 ('222+077=', '299')]

In [10]:
data_train = data[: int(train_proportion * dataset_size)]
data_test = data[int(train_proportion * dataset_size):]

len(data_train),len(data_test)

(576, 64)

In [11]:
def generate(model, prompts, new_tokens = 5, mode = "greedy", num_samples = 1, temperature = 0.8):
    input_tensor = torch.repeat_interleave(prompts, repeats = num_samples, dim = 1).to(device)
    # (prompt_length, batch_size * num_samples)
    for _ in range(new_tokens):
        output, _ = model(input_tensor) # (prompt_length, batch_size * num_samples, ntokens)
        logits = output[-1,:,:] # (batch_size * num_samples, ntokens)
        if mode == "greedy":
            tokens = torch.argmax(logits, -1).view((1,-1)) # (1, batch_size * num_samples)
        else: # mode == "sampling"
            logits /= temperature
            probs = torch.softmax(logits, dim=-1)
            tokens = torch.multinomial(probs, num_samples = 1).view((1,-1)) # (1, batch_size * num_samples)
        input_tensor = torch.cat((input_tensor, tokens), 0)
    return input_tensor

In [12]:
def pad(token_list, type_list = "prompts"):
    max_length = max([len(x) for x in token_list])
    out = []
    for x in token_list:
        if type_list == "prompts":
            out.append([tokenizer.token_to_id[pad_token]] * (max_length - len(x)) + x)
        if type_list == "answers":
            out.append(x + [tokenizer.token_to_id[eos_token]] + [tokenizer.token_to_id[pad_token]] * (max_length - len(x)))
    return out, max_length

In [13]:
def get_batch(split, i, batch_size):
    data = data_train if split == 'train' else data_test

    prompts = [data[i][0] for i in range(i, i + batch_size)]
    encoded_prompts = [tokenizer.encode(prompt) for prompt in prompts]
    padded_prompts, prompt_length = pad(encoded_prompts, "prompts")

    answers = [data[i][1] for i in range(i, i + batch_size)]
    encoded_answers = [tokenizer.encode(answer) for answer in answers]
    padded_answers, answers_length = pad(encoded_answers, "answers")

    X = torch.stack([torch.tensor(x) for x in padded_prompts], 1)
    Y = torch.stack([torch.tensor(x) for x in padded_answers], 1)
    return X, Y, prompt_length, answers_length, prompts, answers

## Training loop

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cpu


In [15]:
model = torch.load('arithmetic.pt', weights_only=False, map_location='cpu')
model.to(device)

TransformerModelWithHidden(
  (encoder): TransformerEncoderWithHidden(
    (layers): ModuleList(
      (0-7): 8 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (linear1): Linear(in_features=128, out_features=64, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=64, out_features=128, bias=True)
        (norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): Linear(in_features=128, out_features=14, bias=True)
  (input_emb): Embedding(14, 128)
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

In [16]:
wrapper = TransformerWithLens(model, num_layers=len(model.encoder.layers), hidden_size=model.ninp, use_tuned_lens=True)
for name, param in wrapper.named_parameters():
    if param.requires_grad:
        print(name)

translators.0.linear.weight
translators.0.linear.bias
translators.1.linear.weight
translators.1.linear.bias
translators.2.linear.weight
translators.2.linear.bias
translators.3.linear.weight
translators.3.linear.bias
translators.4.linear.weight
translators.4.linear.bias
translators.5.linear.weight
translators.5.linear.bias
translators.6.linear.weight
translators.6.linear.bias
translators.7.linear.weight
translators.7.linear.bias


In [17]:
epochs = 5
batch_size = 16
learning_rate = 8e-4

reporting_per_epoch = 5
log_interval = len(data_train) // (reporting_per_epoch + 1)
assert(log_interval % batch_size == 0)

In [18]:
def evaluate(batch_size = batch_size):
    # Turn on evaluation mode disables dropout.
    wrapper.eval()
    total_loss = 0.
    with torch.no_grad():
        for batch, i in enumerate(range(0, len(data_test) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("test", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            input_tensor = torch.cat((prompts, target_answers), 0) # (prompt_length + answers_length + 1, batch_size)
            output = wrapper(input_tensor)
            reference = output['output'][prompt_length-1:-1,:,:] # we are only predicting the 5 last tokens
            reference = F.log_softmax(reference, dim=-1) # KLDivLoss requires that the reference is a log probability distribution
            tuned_lens_output = [tuned_lens_outputi[prompt_length-1:-1,:,:] for tuned_lens_outputi in output['tuned_lens_outputs']]
            predictions = [F.softmax(tuned_lens_outputi, dim=-1) for tuned_lens_outputi in tuned_lens_output]
            loss = torch.tensor(0.)
            for prediction in predictions:
                loss += F.kl_div(reference, prediction, reduction="batchmean") # we sum the KL loss of each layer
            total_loss += loss.item()
            
        loss = total_loss / len(data_test)
    return loss

In [19]:
def train():
    wrapper.train()
    optimizer = torch.optim.AdamW(wrapper.parameters(), lr=learning_rate)

    best_test_loss = None
    test_loss = evaluate()
    print('-' * 89)
    print('| initialisation | test loss {:5.2f}'.format(test_loss))
    print('-' * 89)
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        total_loss = 0.
        start_time = time.time()
        for batch, i in enumerate(range(0, len(data_train) - 1, batch_size)):
            prompts, target_answers, prompt_length, answers_length, _, _ = get_batch("train", i, batch_size)
            prompts = prompts.to(device) # (prompt_length, batch_size)
            target_answers = target_answers.to(device) # (answers_length + 1, batch_size)
            input_tensor = torch.cat((prompts, target_answers), 0) # (prompt_length + answers_length + 1, batch_size)
            wrapper.zero_grad()
            output = wrapper(input_tensor)
            reference = output['output'][prompt_length-1:-1,:,:] # we are only predicting the 5 last tokens
            reference = F.log_softmax(reference, dim=-1) # KLDivLoss requires that the reference is a log probability distribution
            tuned_lens_output = [tuned_lens_outputi[prompt_length-1:-1,:,:] for tuned_lens_outputi in output['tuned_lens_outputs']]
            predictions = [F.softmax(tuned_lens_outputi, dim=-1) for tuned_lens_outputi in tuned_lens_output]

            loss = torch.tensor(0.)
            for prediction in predictions:
                loss += F.kl_div(reference, prediction, reduction="batchmean") # we sum the KL loss of each layer

            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            if i % log_interval == 0 and batch > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print('| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | perplexity {:8.2f}'.format(batch, len(data_train) // batch_size,
                                                                                                            elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                start_time = time.time()
        test_loss = evaluate()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s | test loss {:5.2f}'.format(epoch, (time.time() - epoch_start_time), test_loss))
        print('-' * 89)
        # Save the tuned lens if the loss is the best we've seen so far.
        if not best_test_loss or test_loss < best_test_loss:
            with open("tuned_lens.pt", 'wb') as f:
                torch.save(model, f)
            best_test_loss = test_loss

In [20]:
train()

-----------------------------------------------------------------------------------------
| initialisation | test loss 40.66
-----------------------------------------------------------------------------------------
|     6/   36 batches | ms/batch 22.75 | loss 36.17 | perplexity 5104053792621019.00
|    12/   36 batches | ms/batch 15.78 | loss 16.85 | perplexity 20757847.48
|    18/   36 batches | ms/batch 15.56 | loss 11.05 | perplexity 63150.80
|    24/   36 batches | ms/batch 12.39 | loss  7.81 | perplexity  2468.13
|    30/   36 batches | ms/batch 15.66 | loss  6.51 | perplexity   672.90
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 10.99s | test loss  5.18
-----------------------------------------------------------------------------------------
|     6/   36 batches | ms/batch 37.57 | loss  6.19 | perplexity   488.80
|    12/   36 batches | ms/batch 19.23 | loss  4.85 | perplexity   128.10
|    18/   36 batches

In [21]:
demonstrate_tuned_lens(wrapper, tokenizer, input_text='123+182=')

Layer 1 predictions:
  1: 0.5138
  9: 0.1041
  5: 0.0836
  4: 0.0674
  6: 0.0557
Layer 2 predictions:
  5: 0.1901
  6: 0.1870
  4: 0.1230
  7: 0.0945
  9: 0.0906
Layer 3 predictions:
  5: 0.1832
  6: 0.1734
  4: 0.1355
  3: 0.1244
  7: 0.0976
Layer 4 predictions:
  4: 0.3923
  3: 0.2919
  2: 0.1240
  5: 0.0996
  1: 0.0267
Layer 5 predictions:
  3: 0.4145
  4: 0.3078
  2: 0.1869
  5: 0.0404
  1: 0.0261
Layer 6 predictions:
  2: 0.4336
  3: 0.4282
  4: 0.0915
  1: 0.0309
  5: 0.0042
Layer 7 predictions:
  2: 0.4812
  3: 0.4499
  4: 0.0363
  1: 0.0241
  9: 0.0026
Layer 8 predictions:
  2: 0.5147
  3: 0.4605
  1: 0.0110
  4: 0.0090
  [EOS]: 0.0012
Final prediction:
  3: 0.4946
  2: 0.4900
  1: 0.0082
  4: 0.0055
  9: 0.0005
