In [19]:
from peft import PeftModelForSeq2SeqLM, PromptEncoderConfig, PromptEncoderReparameterizationType, get_peft_model, get_peft_model_state_dict, set_peft_model_state_dict
from peft.tuners.prefix_tuning import PrefixEncoder
from peft.tuners.p_tuning import PromptEncoder
from peft.utils import _get_batch_size, PeftType, TaskType, TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, map_cache_to_layer_device_map
from transformers import PreTrainedModel, DynamicCache, EncoderDecoderCache
from typing import Optional
import torch
import numpy as np
import warnings
from dataclasses import dataclass, field

@dataclass
class AbstractPromptEncoderConfig(PromptEncoderConfig):
    """
    Configuration for an [`AbstractPromptEncoder`]: a P-tuning prompt encoder whose
    virtual-token embedding table is split into one block per subject.

    All fields of [`PromptEncoderConfig`] are inherited (e.g. `encoder_reparameterization_type`,
    `encoder_hidden_size`, `encoder_num_layers`, `encoder_dropout`). Additional fields:

    Args:
        num_subjects (`int`, defaults to 8):
            Number of subjects. The embedding table gets
            `num_virtual_tokens * num_subjects * num_transformer_submodules` rows. Must be >= 1.
        padding_idx (`Optional[int]`, defaults to `None`):
            Row of the embedding table used as padding. When `None`, the encoder appends
            one extra row after the virtual tokens and uses it as the padding index.

    Raises:
        ValueError: if `num_subjects` is smaller than 1.
    """

    num_subjects: int = field(
        default=8,
        metadata={"help": "The number of subjects of the prompt encoder"},
    )
    # `None` means "append a dedicated padding row"; hence Optional rather than plain int.
    padding_idx: Optional[int] = field(
        default=None,
        metadata={"help": "The padding index of the prompt encoder"},
    )

    def __post_init__(self):
        super().__post_init__()
        # With no subjects the table would contain only the padding row, which is never useful.
        if self.num_subjects < 1:
            raise ValueError(f"`num_subjects` must be a positive integer, got {self.num_subjects}.")
        self.peft_type = PeftType.P_TUNING #TODO: switch to APTuning


class AbstractPromptEncoder(PromptEncoder):
    """P-tuning prompt encoder with one block of virtual-token embeddings per subject.

    The embedding table holds
    `num_virtual_tokens * num_subjects * num_transformer_submodules` rows.
    - If `config.padding_idx` is `None`, one extra row is appended and used as the
      padding index.
    - Otherwise the configured index is used as-is; it must refer to an existing row,
      so one of the virtual-token rows is reserved for padding.

    The reparameterization heads (`lstm_head` / `mlp_head`) are the ones built by
    `PromptEncoder.__init__`. Only the embedding table is replaced here.

    Args:
        config: an `AbstractPromptEncoderConfig` (needs `num_subjects` and `padding_idx`
            in addition to the usual `PromptEncoderConfig` fields).

    Raises:
        ValueError: if the encoder type is neither MLP nor LSTM (outside inference mode).
    """

    def __init__(self, config):
        # The base constructor sets token_dim / hidden sizes / encoder_type and builds the
        # heads. The heads do not depend on the number of virtual tokens, so they are
        # reused instead of being rebuilt. Rebuilding them also emitted the
        # `encoder_num_layers` warning a second time.
        super().__init__(config)

        self.num_subjects = config.num_subjects
        self.total_virtual_tokens = (
            config.num_virtual_tokens * config.num_subjects * config.num_transformer_submodules
        )
        if config.padding_idx is not None:
            self.padding_idx = config.padding_idx
        else:
            # Reserve a dedicated row just past the last virtual token for padding.
            self.padding_idx = self.total_virtual_tokens
            self.total_virtual_tokens += 1

        print(f"total_virtual_tokens: {self.total_virtual_tokens}")
        print(f"num_subjects: {self.num_subjects}")
        print(f"padding_idx: {self.padding_idx}")

        # Replace the base embedding with the larger, subject-aware table. nn.Embedding
        # zero-initializes the padding row and excludes it from gradient updates.
        self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim, padding_idx=self.padding_idx)

        # Keep the explicit check so an unsupported encoder type is rejected regardless of
        # how the installed PEFT base class handles it.
        if not config.inference_mode and self.encoder_type not in (
            PromptEncoderReparameterizationType.LSTM,
            PromptEncoderReparameterizationType.MLP,
        ):
            raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.")

batch_size = 2
num_virtual_tokens = 3
num_subjects = 2

# Toy configuration:
# - 3 virtual tokens for each of 2 subjects, on a single transformer submodule;
# - 5-dimensional token embeddings;
# - a 10-unit hidden layer in the reparameterization head.
peft_config = AbstractPromptEncoderConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    num_virtual_tokens=num_virtual_tokens,
    num_subjects=num_subjects,
    num_transformer_submodules=1,
    token_dim=5,
    encoder_hidden_size=10,
)

temp = AbstractPromptEncoder(peft_config)

total_virtual_tokens: 7
num_subjects: 2
padding_idx: 6


In [20]:
# Draw random indices below num_virtual_tokens * num_subjects, i.e. never the appended
# padding row in this toy setup, and show the encoded output shape.
print(
    temp(
        torch.randint(num_virtual_tokens * num_subjects, (batch_size, num_virtual_tokens))
    ).shape
)  # (2, 3, 5)

torch.Size([2, 3, 5])
