In [27]:
import json
import torch.nn as nn
from sentence_transformers import SentenceTransformer
from src.model_utils import EncoderOutputs,mean_pooling,SentenceEmbeddingOutput
from typing import List, Optional, Tuple, Union
from transformers import AutoModel,AutoModelForCausalLM,PreTrainedModel,PretrainedConfig,EncoderDecoderConfig
import torch
import logging
logger = logging.getLogger(__name__)



class ContextualSentenceTransformerEncoder(nn.Module):
    def __init__(
        self, model_name, context_delimiter_id,
        pad_token_id,
        normalize = False
    ):
        super(ContextualSentenceTransformerEncoder, self).__init__()
        self.model = AutoModel.from_pretrained(model_name)

        self._context_delimiter_id = context_delimiter_id
        self._pad_token_id = pad_token_id
        self.normalize_sentence_embeddings = normalize

    def _strip_context(self, input_ids, embeddings, attention_mask):
        """

        :param input_ids:
        :param embeddings:
        :param attention_mask:
        :return:
        """
        # identify the locations of the context_delimiter in each of the input sequence
        if type(input_ids) is list:
            input_ids = torch.LongTensor(
                input_ids,
            )
        delimiter_points = input_ids == self._context_delimiter_id

        delimiter_points_idxs = delimiter_points.nonzero(as_tuple=True)[-1]

        all_embeddings = []
        all_attention_masks = []
        all_input_ids = []
        max_length = 0
        embedding_dim = embeddings.shape[-1]

        # For item in input_ids, embeddings, attention_mask, input_ids, select the
        # portion of the tensor after the delimiter_point_id
        for delimiter_point_id, embedding, att_mask in zip(
            delimiter_points_idxs, embeddings, attention_mask
        ):
            embedding = embedding[delimiter_point_id + 1 :, :]
            if max_length < embedding.shape[0]:
                max_length = embedding.shape[0]
            all_embeddings.append(embedding)
            all_attention_masks.append(att_mask[delimiter_point_id + 1 :])

        # Reshape all the section of interest for each item in all_input_ids, all_embeddings, all_attention_masks to
        # the same size
        batch_embeddings: List = list()
        batch_attention_masks: List = list()

        for idx, (embedding, att_mask) in enumerate(
            zip(all_embeddings, all_attention_masks)
        ):
            len_diff = max_length - embedding.shape[0]
            if max_length > embedding.shape[0]:
                pad_tensor = torch.zeros(len_diff, embedding_dim).to(embedding.device)
                embedding = torch.concat([embedding, pad_tensor], dim=0)

                attn_pads = torch.zeros(
                    len_diff,
                ).to(att_mask.device)
                att_mask = torch.concat([att_mask, attn_pads], -1)

            batch_embeddings += [embedding.view(-1, max_length, embedding_dim)]
            batch_attention_masks += [att_mask.view(-1, max_length)]

        # Create the final tensors with the contexts removed
        batch_attention_masks = torch.concat(batch_attention_masks, 0)
        batch_embeddings = torch.concat(batch_embeddings, 0)
        return batch_embeddings, batch_attention_masks

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        apply_pool: Optional[bool] = False,
    ) -> Union[Tuple, EncoderOutputs]:
        model_output = self.model(input_ids, attention_mask)
        print(model_output.keys())

        hidden_states, batch_encoder_attention_masks = self._strip_context(
            input_ids,
            model_output["last_hidden_state"],
           attention_mask,
        )
        encoder_states = []
        all_attentions = []
        
        
        if not return_dict:
            
            return tuple(
                v
                for v in [
                    hidden_states,
                    encoder_states,
                    all_attentions,
                    batch_encoder_attention_masks,
                ]
                if v is not None
            )
        
        if not apply_pool:
            return EncoderOutputs(
                last_hidden_state=hidden_states,
                hidden_states=encoder_states,
                attentions=all_attentions,
                attention_mask=batch_encoder_attention_masks,
                )
        else:
            embeddings = mean_pooling(hidden_states, 
                                      batch_encoder_attention_masks )
            if self.normalize_sentence_embeddings:
                embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
            return SentenceEmbeddingOutput(sentence_embedding= embeddings, 
                        token_embeddings=hidden_states,
                        attention_mask= batch_encoder_attention_masks)



class ContextualSentenceTransformerModel(PreTrainedModel):
    def __init__(self,
                 config: Optional[PretrainedConfig] = None,
        encoder: Optional[PreTrainedModel] = None,
        decoder: Optional[PreTrainedModel] = None,) -> None:
        
        if config is None:
            config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder.config, decoder.config)
        else:
            if not isinstance(config, self.config_class):
                raise ValueError(f"Config: {config} has to be of type {self.config_class}")
        
        if config.decoder.cross_attention_hidden_size is not None:
            if config.decoder.cross_attention_hidden_size != config.encoder.hidden_size:
                raise ValueError(
                    "If `cross_attention_hidden_size` is specified in the decoder's configuration, it has to be equal"
                    f" to the encoder's `hidden_size`. Got {config.decoder.cross_attention_hidden_size} for"
                    f" `config.decoder.cross_attention_hidden_size` and {config.encoder.hidden_size} for"
                    " `config.encoder.hidden_size`."
                )

        
        # initialize with config
        super().__init__(config)
        
        
        if decoder is None:
            decoder = AutoModelForCausalLM.from_config(config.decoder)
        
        self.encoder = encoder
        self.decoder = decoder
        
        if self.encoder.config.to_dict() != self.config.encoder.to_dict():
            logger.warning(
                f"Config of the encoder: {self.encoder.__class__} is overwritten by shared encoder config:"
                f" {self.config.encoder}"
            )
        if self.decoder.config.to_dict() != self.config.decoder.to_dict():
            logger.warning(
                f"Config of the decoder: {self.decoder.__class__} is overwritten by shared decoder config:"
                f" {self.config.decoder}"
            )

        # make sure that the individual model's config refers to the shared config
        # so that the updates to the config will be synced
        self.encoder.config = self.config.encoder
        self.decoder.config = self.config.decoder

        # encoder outputs might need to be projected to different dimension for decoder
        if (
            self.encoder.config.hidden_size != self.decoder.config.hidden_size
            and self.decoder.config.cross_attention_hidden_size is None
        ):
            self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)

        if self.encoder.get_output_embeddings() is not None:
            raise ValueError(
                f"The encoder {self.encoder} should not have a LM Head. Please use a model without LM Head"
            )

        # tie encoder, decoder weights if config set accordingly
        self.tie_weights()

    def tie_weights(self):
        # tie encoder & decoder if needed
        if self.config.tie_encoder_decoder:
            # tie encoder and decoder base model
            decoder_base_model_prefix = self.decoder.base_model_prefix
            self._tie_encoder_decoder_weights(
                self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
            )

    def get_encoder(self):
        return self.encoder

    def get_decoder(self):
        return self.decoder

    def get_input_embeddings(self):
        return self.encoder.get_input_embeddings()

    def get_output_embeddings(self):
        return self.decoder.get_output_embeddings()

    def set_output_embeddings(self, new_embeddings):
        return self.decoder.set_output_embeddings(new_embeddings)
        
    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        # At the moment fast initialization is not supported for composite models
        if kwargs.get("_fast_init", False):
            logger.warning(
                "Fast initialization is currently not supported for EncoderDecoderModel. "
                "Falling back to slow initialization..."
            )
        kwargs["_fast_init"] = False
        return super().from_pretrained(*args, **kwargs)  
    
      
        
        
                                

In [None]:
from transformers import EncoderDecoderModel

In [2]:
from src.utils import setuptokenizer
from src.dataset_processor import ContextGenerationDataset

In [3]:
tokenizer = setuptokenizer('sentence-transformers/all-mpnet-base-v2',special_tokens=["#SEP#"])

In [None]:
tokenizer("He #SEP# left",add_special_tokens=False)

In [4]:

dataset = ContextGenerationDataset(tokenizer,
                                   nb_records=1,
                                   context_seperator= "#SEP#",
                                   use_special_token=True,
                                   section_boundary=(0.4,0.54),
                                   use_random_restrictive=True)
dataset.change_data_mode(1)

In [10]:
from src.dataset_processor import ContextualGenerationData
from pytorch_lightning import seed_everything
data = ContextualGenerationData(input="""
                                The car was parked near the house of the school teacher. There was a cat who was lost.
                                """.replace("\n","").strip(),output="")

batch = dataset.procesTexts(data)
tokenizer.batch_decode([batch.input_ids])

['<s> the car was parked near the house of the school #SEP# teacher. there was a cat who was lost. </s>']

In [28]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
contextual_model = ContextualSentenceTransformerEncoder(model_name='sentence-transformers/all-mpnet-base-v2',
                                                        context_delimiter_id= tokenizer.get_vocab()['#SEP#'],
                                                        pad_token_id= tokenizer.pad_token_id
                                                        ).to(device)
contextual_model.model.resize_token_embeddings(len(tokenizer))

Embedding(30528, 768)

In [12]:
b_input_ids = batch.input_ids.view(1, -1).to(device)
b_input_mask = batch.attention_mask.view(1, -1).to(device)
batch.section_point, b_input_ids.shape

(10, torch.Size([1, 23]))

In [29]:
enc_output = contextual_model(b_input_ids,b_input_mask,apply_pool=True,return_dict=True)

odict_keys(['last_hidden_state', 'pooler_output'])


In [32]:
enc_output.sentence_embedding.shape

torch.Size([1, 768])

In [4]:
from datasets import load_dataset

dataset = load_dataset("xsum")

Found cached dataset xsum (/home/nlplab/.cache/huggingface/datasets/xsum/default/1.2.0/082863bf4754ee058a5b6f6525d0cb2b18eadb62c7b370b095d1364050a52b71)


  0%|          | 0/3 [00:00<?, ?it/s]

In [23]:
bart_tokenizer = setuptokenizer('facebook/bart-base',special_tokens=["[SEP]"])

In [25]:
import pandas as pd
def strip_newline(value):
    return ' '.join(value.splitlines())
def retrieve_data(data):
    pack = []
    for dat in data:
        zz=bart_tokenizer(dat['document'], return_tensors='pt')['input_ids'].shape
        
        if zz[-1] < 720:
            a= dict(input=strip_newline(dat['document']),output=strip_newline(dat['summary']))
            pack.append(a)
    data_pack = pd.DataFrame(pack)
    data_pack= data_pack.drop_duplicates(subset = ["output"],keep="last")
    return data_pack

In [26]:
train_data = retrieve_data(dataset['train'])
dev_data = retrieve_data(dataset['validation'])
test_data = retrieve_data(dataset['test'])

Token indices sequence length is longer than the specified maximum sequence length for this model (1140 > 1024). Running this sequence through the model will result in indexing errors


In [28]:
dev_data.shape

(9163, 2)

In [29]:
train_data.to_csv('summarisation_data/xsum_train.csv')
dev_data.to_csv('summarisation_data/xsum_dev.csv')
test_data.to_csv('summarisation_data/xsum_test.csv')

In [30]:
ff= pd.read_csv('summarisation_data/xsum_test.csv')

In [4]:
from src.dataset_processor import ContextualGenerationData,read_csv

        
        

In [5]:
rr = read_csv('summarisation_data/xsum_train.csv')

In [6]:
rr[0]



In [24]:
bart_tokenizer(dataset['train'][700]['document'], return_tensors='pt')['input_ids'].shape

torch.Size([1, 137])