In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import abc

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Lens(abc.ABC, torch.nn.Module):
    """Abstract base class for all lenses.

    A lens receives the stacked hidden states of a model plus a layer index and
    returns either a ``dim_out``-sized representation (when ``output_logits`` is
    False, Lens_model.forward maps it to logits with the final layer norm and
    unembedding) or logits directly (when ``output_logits`` is True).
    Subclasses must implement ``forward``.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.device = torch.device('cpu')
        # False -> the caller still has to unembed this lens' output.
        self.output_logits = False
        self.dim_in = dim_in
        self.dim_out = dim_out

    @abc.abstractmethod
    def forward(self, h, idx):
        """Return the lens output for index ``idx`` of the stacked hidden states ``h``."""

    @classmethod
    def from_model(cls, model, init_parameter=None):
        """Build a lens of the calling class sized to ``model.config.hidden_size``.

        model: object exposing ``config.hidden_size`` (e.g. a HF causal LM).
        init_parameter: optional value forwarded to ``set_parameters``.

        Uses ``cls`` rather than the hard-coded abstract ``Lens``, so subclasses
        that do not override this factory get a working instance. Calling it on
        ``Lens`` itself still raises ``TypeError`` (abstract class).
        """
        hidden = model.config.hidden_size
        rv = cls(hidden, hidden)
        if init_parameter is not None:
            rv.set_parameters(init_parameter)
        return rv

    def set_parameters(self, parameters):
        """Load lens parameters; a no-op for lenses without parameters."""
        pass

    def to(self, device):
        """Move any registered parameters/submodules to ``device``, record it, and return self.

        Returning ``self`` matches the ``torch.nn.Module.to`` convention so calls
        can be chained (``lens = lens.to(dev)``).
        """
        super().to(device)
        self.device = device
        return self
    
class Logit_lens(Lens):
    """Identity lens: returns the hidden state at index ``idx`` of dim 1 unchanged.

    It has no parameters; with ``output_logits`` left False, Lens_model.forward
    turns the returned hidden state into logits.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__(dim_in, dim_out)

    def forward(self, h, idx):
        """Select ``h[:, idx]`` (the slice at position ``idx`` along dim 1)."""
        selected = h[:, idx]
        return selected

    @staticmethod
    def from_model(model, init_parameter=None):
        """Create a Logit_lens whose input/output size is ``model.config.hidden_size``.

        ``init_parameter`` is accepted for interface parity with other lenses and
        ignored, since this lens has nothing to initialise.
        """
        hidden = model.config.hidden_size
        lens = Logit_lens(hidden, hidden)
        lens.dim_in, lens.dim_out = hidden, hidden
        return lens
    
class Linear_lens(Lens):
    """Affine lens: applies a learned ``Linear(dim_in, dim_out)`` to one layer's hidden state."""

    def __init__(self, dim_in, dim_out):
        super().__init__(dim_in, dim_out)
        self.linear = torch.nn.Linear(dim_in, dim_out)

    def forward(self, h, idx):
        """Apply the affine map to ``h[:, idx]``.

        Selecting before projecting yields the same values as
        ``self.linear(h)[:, idx]`` (Linear only acts on the last dimension) but
        avoids transforming every other layer's hidden state.
        """
        return self.linear(h[:, idx])

    def to(self, device):
        """Move the affine map to ``device``, record it, and return self."""
        self.device = device
        self.linear.to(device)
        return self

    def set_parameters(self, parameters):
        """Replace the weight and bias of the affine map.

        parameters: mapping with 'weight' and 'bias' tensors in ``torch.nn.Linear``
            layout (weight: (dim_out, dim_in), bias: (dim_out,)).

        The tensors are placed on the device the layer currently lives on, so a
        lens that was already moved with ``to`` keeps consistent devices.
        """
        device = self.linear.weight.device
        self.linear.weight = torch.nn.Parameter(parameters['weight'].to(device))
        self.linear.bias = torch.nn.Parameter(parameters['bias'].to(device))

    @staticmethod
    def from_model(model, init_parameter=None):
        """Create a Linear_lens of size ``model.config.hidden_size`` -> same size.

        init_parameter: optional {'weight', 'bias'} mapping passed to ``set_parameters``.
        """
        rv = Linear_lens(model.config.hidden_size, model.config.hidden_size)
        rv.dim_in = model.config.hidden_size
        rv.dim_out = model.config.hidden_size
        if init_parameter is not None:
            rv.set_parameters(init_parameter)

        return rv

In [3]:
class Lens_model(torch.nn.Module):
    """Wraps a causal LM and applies one lens per selected hidden-state index.

    Each lens reads the stacked hidden states returned with
    ``output_hidden_states=True`` (for HF GPT-2: the embedding output followed by
    one entry per block) and produces either logits directly or a hidden-size
    vector that is mapped to logits via the model's final layer norm and
    unembedding.
    """

    def __init__(self, lens, layers=None, model_name="gpt2", model_path=None):
        '''
        Initializes the Lens_model class.
        lens: list of Lens
            The lenses to be used, one per entry of ``layers``.
        layers: list of ints
            Hidden-state index each lens reads. Required; must have the same
            length as ``lens``.
        model_name: str
            Name of the pretrained model; the tokenizer is always loaded from it.
        model_path: str or None
            If given, model weights are loaded from this path instead of ``model_name``.
        Raises ValueError if ``layers`` is missing or its length differs from ``lens``.
        '''
        super(Lens_model, self).__init__()

        # Validate arguments before the (expensive) model download.
        if layers is None:
            raise ValueError("layers must be given: one hidden-state index per lens")
        if len(lens) != len(layers):
            raise ValueError(
                f"got {len(lens)} lenses but {len(layers)} layers; they must match"
            )

        self.model = AutoModelForCausalLM.from_pretrained(
            model_path if model_path is not None else model_name
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Lenses are Modules, so ModuleList is the proper container; state_dict
        # keys stay "lens.<i>.*" exactly as with the previous ParameterList.
        self.lens = torch.nn.ModuleList(lens)
        self.layers = layers
        self.device = torch.device('cpu')
        self.num_layers = self.model.config.num_hidden_layers
        self.unembed = self.model.get_output_embeddings()
        # NOTE: ``ln_f`` is the GPT-2 attribute name; other architectures differ.
        self.final_layer_norm = self.model.base_model.ln_f
        # Freeze the weights themselves. Assigning ``module.requires_grad = False``
        # only creates a plain attribute and leaves every parameter trainable.
        self.unembed.requires_grad_(False)
        self.final_layer_norm.requires_grad_(False)

    def to(self, device):
        '''
        Moves the wrapped model and all lenses to ``device``.
        device: torch.device
            The device to be used.
        Returns self, matching the torch.nn.Module.to convention.
        '''
        self.device = device
        for l in self.lens:
            l.to(device)

        self.model.to(device)
        self.unembed.to(device)
        self.final_layer_norm.to(device)
        return self

    def get_probs(self, input_ids, attention_mask, targets, target_index):
        '''
        Gets the per-lens vocabulary distribution for the token at ``target_index``.
        input_ids: torch.Tensor
            The input ids.
        attention_mask: torch.Tensor
            The attention mask.
        targets: torch.Tensor
            The targets (forwarded to ``forward``).
        target_index: torch.Tensor
            Per-example position of the token to predict. Logits are read at
            ``target_index - 1``, the position whose next-token prediction it is.
            NOTE(review): a value of 0 reads index -1, i.e. the last position.
        Output: torch.Tensor
            Probabilities with shape (batch_size, vocab_size, num_lenses),
            normalised over the vocabulary dimension.
        '''
        output = self.forward(input_ids, attention_mask, targets)
        batch = torch.arange(output.shape[0])
        logits = output[batch, target_index - 1]
        return torch.softmax(logits, dim=-2)

    def get_correct_class_probs(self, input_ids, attention_mask, targets, target_index):
        '''
        Gets the probability each lens assigns to the correct token.
        input_ids: torch.Tensor
            The input ids.
        attention_mask: torch.Tensor
            The attention mask.
        targets: torch.Tensor
            Token ids; ``targets[b, target_index[b]]`` is the correct token of example b.
        target_index: torch.Tensor
            The index of the token that we will predict.
        Output: torch.Tensor
            The probabilities of the correct class. The shape is (batch_size, num_lenses).
        '''
        probs = self.get_probs(input_ids, attention_mask, targets, target_index)
        batch = torch.arange(probs.shape[0])
        correct_ids = targets[torch.arange(targets.shape[0]), target_index]
        return probs[batch, correct_ids]

    def forward(self, input_ids, attention_mask, targets):
        '''
        Forward pass: runs the frozen LM, then every lens on its hidden-state index.
        input_ids: torch.Tensor
            The input ids.
        attention_mask: torch.Tensor
            The attention mask.
        targets: torch.Tensor
            Not used here (the LM loss is never needed); kept for interface compatibility.
        Output: torch.Tensor
            Per-lens logits stacked on the last axis, i.e.
            (batch_size, seq_len, vocab_size, num_lenses) when each lens returns
            (batch_size, seq_len, ·).
        '''
        self.model.eval()
        with torch.no_grad():
            model_outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )

        hs = torch.stack(model_outputs.hidden_states, dim=1)
        output = []
        for ly, ln in zip(self.layers, self.lens):
            # Call the modules (not ``.forward``) so registered hooks run.
            lens_output = ln(hs, ly)
            if not ln.output_logits:
                # The final hidden state is unembedded without ln_f (in HF GPT-2,
                # hidden_states[-1] already has ln_f applied); intermediate ones get it.
                if not (ly == -1 or ly == self.num_layers):
                    lens_output = self.final_layer_norm(lens_output)
                lens_output = self.unembed(lens_output)
            output.append(lens_output)

        return torch.stack(output, dim=-1)

In [4]:
# The checkpoint is a whole pickled Lens_model (not a state_dict), so the class
# definitions above must be run first. Unpickling can execute arbitrary code —
# only load checkpoints from a trusted source. weights_only=False is required for
# full-module checkpoints on torch >= 2.6, where the default became True.
lens_model = torch.load('./models/lens_model.pt', weights_only=False)
lens_model

Lens_model(
  (model): GPT2LMHeadModel(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0-11): 12 x GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
      )
      (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    )
    (lm_head): Linear(in_features=768, out_features=50257, bias=False)
  )
  (lens): ParameterList(
      (0): O