In [1]:
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "2"
from functools import partial
import nltk
from src.contextual_bart import ContextualisedBartModel,BartForContextualRecovery,SimplifiedBeamSearch
from src.dataset_processor import load_all_data
from src.utils import SmartCollator, get_args, setuptokenizer
from src.dataset_processor import (
    ContextGenerationDataset,
)
from transformers import BartTokenizer, BartConfig,BartForConditionalGeneration
from src.model_utils import CustomTrainer, get_training_arguments
import torch
from src.config import DATASET_PATH
from transformers.trainer_callback import EarlyStoppingCallback
import pickle as pk
import torch
from transformers import (    AutoTokenizer,
          AutoModelForSeq2SeqLM,
         LogitsProcessorList,    MinLengthLogitsProcessor, StoppingCriteriaList, MaxLengthCriteria,
         TopKLogitsWarper, TemperatureLogitsWarper,BeamSearchScorer,)

nltk.download("punkt")

DATASET_PATH = "summarisation_data/"

def generate_tokenizer_and_data(args):

    # load the dataset

    train_data_packet = load_all_data(DATASET_PATH, mode="train")
    dev_data_packet = load_all_data(DATASET_PATH, mode="dev")
    test_data_packet = load_all_data(DATASET_PATH,mode="test")

    print(f"Training Data size: {len(train_data_packet)}")
    print(f"Training Data size: {len(test_data_packet)}")

    model_base = args.model_base
    tokenizer = setuptokenizer(
        model_base=model_base,
        special_tokens=[],
    )
    tokenizer.add_tokens([args.sep_token])

    train_dataset = ContextGenerationDataset(
        tokenizer=tokenizer, nb_records=len(train_data_packet), max_len=720,
        context_seperator=args.sep_token,
        is_auto_encoder_data=not args.is_not_auto_encoder_data,
        use_special_token=True,
    )
    train_dataset.change_data_mode(1)
    train_dataset.set_record(train_data_packet)

    test_dataset = ContextGenerationDataset(
        tokenizer=tokenizer, nb_records=len(test_data_packet), 
        max_len=700,
        context_seperator=args.sep_token,
        is_auto_encoder_data=not args.is_not_auto_encoder_data,
    )
    test_dataset.change_data_mode(1)
    test_dataset.set_record(test_data_packet)
    
    dev_dataset = ContextGenerationDataset(
        tokenizer=tokenizer, nb_records=len(dev_data_packet), 
        max_len=700,
        context_seperator=args.sep_token,
        is_auto_encoder_data=not args.is_not_auto_encoder_data,
    )
    test_dataset.change_data_mode(1)
    test_dataset.set_record(test_data_packet)

    return train_dataset, dev_dataset,test_dataset, [train_data_packet,dev_data_packet,test_data_packet]



def model_init(
    vocab_size,
    context_delimiter_id,
    model_base="facebook/bart-base",
    use_random_restriction=False,
    section_prob=(0.25, 0.45),
    device=torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
):
    def build_model():
        bart_config = BartConfig.from_pretrained(model_base)
        bart_config.context_delimiter_id = context_delimiter_id
        bart_config.use_random_restriction = use_random_restriction
        bart_config.section_prob = section_prob

        generator = BartForContextualRecovery.from_pretrained(
            model_base, config=bart_config, ignore_mismatched_sizes=True
        )

        # update the tokens
        generator.resize_token_embeddings(vocab_size)  # type: ignore
        return generator.to(device)  # type: ignore
    return build_model

[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
from dataclasses import dataclass
@dataclass
class Args:
    model_base: str
    sep_token: str = "[SEP]"
    is_not_auto_encoder_data: bool = True
    
    
args = Args(model_base="facebook/bart-base")
train_dataset, dev_dataset,test_dataset, [train_data_packet,dev_data_packet,test_data_packet] = generate_tokenizer_and_data(args)

processing files:  ['summarisation_data/xsum_train.csv']
processing files:  ['summarisation_data/xsum_dev.csv']
processing files:  ['summarisation_data/xsum_test.csv']
Training Data size: 162548
Training Data size: 9049
The model will be trained as a non auto-encoder
The model will be trained as a non auto-encoder
The model will be trained as a non auto-encoder


In [None]:
mores= []
for idx,da in enumerate(train_dataset):
    if da.input_ids.shape[0] > 600:
        mores.append(idx)

In [None]:
cmores= []
for idx,da in enumerate(test_dataset):
    if da.input_ids.shape[0] > 512:
        cmores.append(idx)

In [None]:
len(cmores)

In [None]:
delimiter_points=train_dataset[498269].input_ids==train_dataset._context_delimiter_id
delimiter_points_idx = delimiter_points.nonzero(as_tuple=True)[-1][0]
delimiter_points_idx

In [None]:
train_dataset[67].labels.shape,train_dataset[67].input_ids.shape,train_dataset[67].section_point

In [None]:
train_dataset[67].input_ids

In [3]:
context_delimiter_id = train_dataset.tokenizer.get_vocab()['[SEP]']

train_model_path = "trained_models_sum/bart_base_model_full/checkpoint-81275/pytorch_model.bin"
#"trained_models_mtl/bart_base_model_full/checkpoint-263195/pytorch_model.bin"

generator = model_init(len(train_dataset.tokenizer),
                       context_delimiter_id=context_delimiter_id,
                       model_base=args.model_base,use_random_restriction=False)()

state_dict = torch.load(train_model_path)
generator.load_state_dict(state_dict)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
from src.contextual_bart import EncoderOutputs
class SimplifiedBeamSearch:
    def __init__(self, generator, tokenizer) -> None:
        self.generator = generator
        self.tokenizer = tokenizer

    def generate(
        self,
        input_ids,
        attention_mask,
        shrink_encoder_output = 1,
        num_beams=5,
        min_length=100,
        max_length=500,
        top_k=50,
        temperature=0.85,
    ):

        # initialise decoder input_ids
        decoder_input_ids = torch.ones(
            (num_beams, 1), device=self.generator.device, dtype=torch.long
        )
        decoder_input_ids = (
            decoder_input_ids * self.generator.config.decoder_start_token_id
        )
        
        encoder = self.generator.get_encoder()
        
         
        
        encoder_outputs = self.generator.get_encoder()(
                input_ids.repeat_interleave(num_beams, dim=0),
                attention_mask.repeat_interleave(num_beams, dim=0),
                return_dict=True,
            )
        
        enc_num_tokens = encoder_outputs[0].shape[1]
        
        #print(encoder_outputs[0].shape)
        attention_mask_ = encoder_outputs.cleaned_mask
        if shrink_encoder_output< 1:
            num_usable_tokens = round(shrink_encoder_output*enc_num_tokens)
            
            encoder_outputs = EncoderOutputs(
                last_hidden_state=encoder_outputs[0][:,num_usable_tokens:,:],
                hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
                attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
                cleaned_mask=encoder_outputs[3] if len(encoder_outputs) > 3 else None,
            )
        
        print(f"Generating text using {encoder_outputs[0].shape[1]} token embeddings out off {input_ids.shape[-1]}")
            
        
        print("Encoder output shape: ",encoder_outputs[0].shape)
        
        print("Attention mask: ",attention_mask_.shape)

        model_kwargs = {
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask_,
        }
        beam_scorer = BeamSearchScorer(
            batch_size=attention_mask.shape[0],
            num_beams=num_beams,
            device=self.generator.device,
        )

        logits_processor = LogitsProcessorList(
            [
                MinLengthLogitsProcessor(
                    1, eos_token_id=self.generator.config.eos_token_id
                )
            ]
        )
        logits_warper = LogitsProcessorList(
            [
                TopKLogitsWarper(top_k),
                TemperatureLogitsWarper(temperature),
            ]
        )

        outputs = self.generator.beam_sample(
            decoder_input_ids,
            beam_scorer,
            max_length=max_length,
            logits_processor=logits_processor,
            # stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length)),
            logits_warper=logits_warper,
            **model_kwargs,
        )

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)


In [5]:
dataset = ContextGenerationDataset(test_dataset.tokenizer,
                                   nb_records=1,
                                   section_boundary=(0.4,0.48),
                                   
        context_seperator=args.sep_token,
        is_auto_encoder_data=not args.is_not_auto_encoder_data,
                                   use_random_restrictive=True)
dataset.change_data_mode(1)

The model will be trained as a non auto-encoder


In [9]:
from src.dataset_processor import ContextualGenerationData
from pytorch_lightning import seed_everything
data = ContextualGenerationData(input="""
                                We are helping the community work together towards the goal of advancing Machine Learning üî•.
The Hugging Face Hub is a platform with over 60K models, 6K datasets, and 6K demos in which people can easily collaborate in their ML workflows. 
The Hub works as a central place where anyone can share, explore, discover, and experiment with open-source Machine Learning.
 No single company, including the Tech Titans, will be able to ‚Äúsolve AI‚Äù by themselves - the only way we'll achieve this is by sharing knowledge and resources in a community-centric approach. We are building the largest open-source collection of models, datasets, demos and metrics on the Hugging Face Hub to democratize and advance ML for everyone üöÄ.
                                """.replace("\n","").strip(),output="")
kk= 45
batch = test_dataset[kk]#dataset.procesTexts(data)
b_input_ids = batch.input_ids.view(1, -1).to(device)
b_input_mask = batch.attention_mask.view(1, -1).to(device)
batch.section_point, b_input_ids.shape

(78, torch.Size([1, 256]))

In [None]:
dataset.tokenizer.batch_decode(b_input_ids[:1])

In [7]:
test_data_packet[kk].output

'Scottish legal authorities have granted permission for Twitter to be used to report the conclusion of a murder trial at the High Court.'

In [None]:
b_input_ids.shape

In [16]:
bb=generator.generate(input_ids=b_input_ids,
            attention_mask=b_input_mask,
            num_beams=10,
            do_sample=True,
            max_new_tokens=300)
test_dataset.tokenizer.batch_decode(bb,clean_up_tokenization_spaces=True,skip_special_tokens=True)

['The sentencing of a man convicted of murdering a bookkeeper in Argyll will be streamed live on Twitter, the Lord Chief Justice has confirmed.']

In [None]:
seed_everything(100)
bb= SimplifiedBeamSearch(generator,dataset.tokenizer)
bb.generate(input_ids=b_input_ids,
            attention_mask=b_input_mask,
            shrink_encoder_output=1,
            num_beams=2,
            max_length=370,
            temperature=0.89)

In [None]:
0.1*25

In [None]:
import torch
import numpy as np

In [None]:
def get_random_embedding_sections(batch_size, max_length, low=0.45, high=0.55):
    deletion_section_probs = np.random.uniform(size=(batch_size,), low=low, high=high)
    deletion_section = max_length * deletion_section_probs
    return torch.round(
        torch.FloatTensor(deletion_section),
    ).long()

In [None]:
get_random_embedding_sections(15,200,)

In [None]:
import logging
import math
import random
from dataclasses import dataclass
from logging import Logger
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import (
    BeamSearchScorer,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
)
from transformers.models.bart.modeling_bart import (
    BartConfig,
    BartDecoder,
    BartEncoderLayer,
    BartLearnedPositionalEmbedding,
    BartPretrainedModel,
    BaseModelOutput,
    CrossEntropyLoss,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    _expand_mask,
    shift_tokens_right,
)

logger = logging.getLogger(__name__)


@dataclass
class EncoderOutputs(BaseModelOutput):
    last_hidden_state: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    attention_mask: torch.LongTensor = None

In [None]:
class 

In [None]:
class RestrictedBartEncoder(BartPretrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BartEncoderLayer`].

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id

        #self._context_delimiter_id = config.context_delimiter_id
        self._min_section_prob,self._max_section_prob = config.section_prob
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(
                config.vocab_size, embed_dim, self.padding_idx
            )

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList(
            [BartEncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
        
    
    def _get_random_embedding_sections(self,batch_size, max_length, low=0.40, high=0.6):
        deletion_section_probs = np.random.uniform(size=(batch_size,), low=low, high=high)
        deletion_section = max_length * deletion_section_probs
        return torch.round(torch.FloatTensor(deletion_section),).long()
    def _strip_context(self, input_ids, embeddings, attention_mask):
        """

        :param input_ids:
        :param embeddings:
        :param attention_mask:
        :return:
        """
        # identify the locations of the context_delimiter in each of the input sequence
        if type(input_ids) is list:
            input_ids = torch.LongTensor(
                input_ids,
            )
            
        # Get the batch-size and the max_len of embeddings
        batch_size, batch_max_length,_ =  embeddings.shape
        
        #delimiter_points.nonzero(as_tuple=True)[-1]
        
        # Randomly select parts of the encoder output to 
        delimiter_points_idxs = self._get_random_embedding_sections(batch_size,
                                                                    batch_max_length,
                                                                    self._min_section_prob,
                                                                    self._max_section_prob)

        all_embeddings = []
        all_attention_masks = []
        all_input_ids = []
        max_length = 0
        embedding_dim = embeddings.shape[-1]

        # For item in input_ids, embeddings, attention_mask, input_ids, select the
        # portion of the tensor after the delimiter_point_id
        for delimiter_point_id, embedding, att_mask in zip(
            delimiter_points_idxs, embeddings, attention_mask
        ):
            embedding = embedding[delimiter_point_id + 1 :, :]
            if max_length < embedding.shape[0]:
                max_length = embedding.shape[0]
            all_embeddings.append(embedding)
            all_attention_masks.append(att_mask[delimiter_point_id + 1 :])

        # Reshape all the section of interest for each item in all_input_ids, all_embeddings, all_attention_masks to
        # the same size
        batch_embeddings: List = list()
        batch_attention_masks: List = list()

        for idx, (embedding, att_mask) in enumerate(
            zip(all_embeddings, all_attention_masks)
        ):
            len_diff = max_length - embedding.shape[0]
            if max_length > embedding.shape[0]:
                pad_tensor = torch.zeros(len_diff, embedding_dim).to(embedding.device)
                embedding = torch.concat([embedding, pad_tensor], dim=0)

                attn_pads = torch.zeros(
                    len_diff,
                ).to(att_mask.device)
                att_mask = torch.concat([att_mask, attn_pads], -1)

            batch_embeddings += [embedding.view(-1, max_length, embedding_dim)]
            batch_attention_masks += [att_mask.view(-1, max_length)]
        
        # Create the final tensors with the contexts removed
        batch_attention_masks = torch.concat(batch_attention_masks, 0)
        batch_embeddings = torch.concat(batch_embeddings, 0)
        return batch_embeddings, batch_attention_masks

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask_ = attention_mask

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask_ = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (
                dropout_probability < self.layerdrop
            ):  # skip the layer
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask_,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask_,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        hidden_states, batch_encoder_attention_masks = self._strip_context(
            input_ids, hidden_states, attention_mask
        )

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    encoder_states,
                    all_attentions,
                    batch_encoder_attention_masks,
                ]
                if v is not None
            )

        return EncoderOutputs(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            attention_mask=batch_encoder_attention_masks,
        )

In [None]:
import copy
config = copy.deepcopy(generator.config)
config.section_prob = (0.2,0.65)
config.context_delimiter_id = generator.model.get_encoder()._context_delimiter_id

In [None]:
generator.model.get_encoder()._context_delimiter_id

In [None]:
restrictive_encoder = RestrictedBartEncoder.from_pretrained("facebook/bart-base",config=config).to(device)

In [None]:
ouut = restrictive_encoder(b_input_ids.repeat_interleave(4, dim=0),b_input_mask.repeat_interleave(4, dim=0))

In [None]:
b_input_ids.shape

In [None]:
from datasets import load_dataset

dataset = load_dataset("race",'all')

In [None]:
cc=dataset['train'].features["article"]

In [None]:
dataset['train'][1]['article'].replace('\n',' ')

In [5]:
400*0.7

280.0

In [10]:
import pandas as pd
import torch
from src.utils import setuptokenizer
bart_tokenizer = setuptokenizer('facebook/bart-base',special_tokens=["[SEP]"])

In [19]:
embeddings = torch.rand((2,128,300))
attention_mask = torch.ones((2,128))
attention_mask[0,90:]= 0
attention_mask[1,110:]= 0

In [28]:
(attention_mask==1).sum(dim=1)

tensor([ 90, 110])

In [29]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = (
        model_output  # First element of model_output contains all token embeddings
    )
    input_mask_expanded = (
        attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    )
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
        input_mask_expanded.sum(1), min=1e-9
    )

In [74]:
def shrink_embeddings(embeddings,
                      attention_mask, 
                      percentage=0):

    if percentage == 0:
        return embeddings, attention_mask
    seq_lengths = (attention_mask == 1).sum(dim=1)
    
    embedding_dim = embeddings.shape[-1]

    all_embeddings = []
    all_attention_masks = []
    max_length = 0
    for sl, embed, attn in zip(seq_lengths, embeddings, attention_mask):
        summary_point =  torch.round(percentage * sl).long()
        print(summary_point)
        summary_embedding = embed[:summary_point, :]
        summary_attention = attn[:summary_point]
        summary_embedding = mean_pooling(summary_embedding.view(1,-1,embedding_dim), summary_attention)

        new_embedding = torch.concat([
            summary_embedding.view(1, -1).to(embed.device),
            embed[summary_point:, :]],
            dim=0,
        )
        new_attention = torch.concat([
            torch.ones((1,)).to(embed.device), attn[summary_point:, ]], -1
        )

        all_embeddings.append(new_embedding)
        all_attention_masks.append(new_attention)

        if max_length < new_embedding.shape[0]:
            max_length = new_embedding.shape[0]

    # Reshape all the section of interest for each item in all_input_ids, all_embeddings, all_attention_masks to
    # the same size
    batch_embeddings: List = list()
    batch_attention_masks: List = list()

    for idx, (embedding, att_mask) in enumerate(
        zip(all_embeddings, all_attention_masks)
    ):
        len_diff = max_length - embedding.shape[0]
        embedding_dim = embedding.shape[1]
        if max_length > embedding.shape[0]:
            pad_tensor = torch.zeros(len_diff, embedding_dim).to(embedding.device)
            embedding = torch.concat([embedding, pad_tensor], dim=0)

            attn_pads = torch.zeros(
                len_diff,
            ).to(att_mask.device)
            att_mask = torch.concat([att_mask, attn_pads], -1)

        batch_embeddings += [embedding.view(-1, max_length, embedding_dim)]
        batch_attention_masks += [att_mask.view(-1, max_length)]

    # Create the final tensors with the contexts removed
    batch_attention_masks = torch.concat(batch_attention_masks, 0)
    batch_embeddings = torch.concat(batch_embeddings, 0)
    
    return batch_embeddings,batch_attention_masks

In [82]:
embeddings,attention_mask = shrink_embeddings(embeddings,attention_mask,0.65)
print(embeddings.shape)

tensor(1)
tensor(1)
torch.Size([2, 40, 300])


In [85]:
embeddings

torch.Size([2, 40, 300])

In [83]:
attention_mask

tensor([[1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0.]])

In [54]:
embeddings.shape

torch.Size([2, 128, 300])

In [55]:
batch_embeddings.shape

torch.Size([2, 115, 300])

In [6]:
from transformers.models.bart.modeling_bart import (
    BartConfig,
    BartDecoder,
    BartEncoderLayer,
    BartLearnedPositionalEmbedding,
    BartPretrainedModel,
    BaseModelOutput,
    CrossEntropyLoss,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    _expand_mask,
    shift_tokens_right,
)
from src.model_utils import EncoderOutputs

import logging
logger = logging.getLogger(__name__)

In [None]:
class BartEncoder(BartPretrainedModel):
    """
    Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
    [`BartEncoderLayer`].

    Args:
        config: BartConfig
        embed_tokens (nn.Embedding): output embedding
    """

    def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = None):
        super().__init__(config)

        self.dropout = config.dropout
        self.layerdrop = config.encoder_layerdrop

        embed_dim = config.d_model
        self.padding_idx = config.pad_token_id

        self._context_delimiter_id = config.context_delimiter_id
        self.max_source_positions = config.max_position_embeddings
        self.embed_scale = math.sqrt(embed_dim) if config.scale_embedding else 1.0

        if embed_tokens is not None:
            self.embed_tokens = embed_tokens
        else:
            self.embed_tokens = nn.Embedding(
                config.vocab_size, embed_dim, self.padding_idx
            )

        self.embed_positions = BartLearnedPositionalEmbedding(
            config.max_position_embeddings,
            embed_dim,
        )
        self.layers = nn.ModuleList(
            [BartEncoderLayer(config) for _ in range(config.encoder_layers)]
        )
        self.layernorm_embedding = nn.LayerNorm(embed_dim)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()
    
    
    
    
    def _resize_attention_mask(self, input_ids, attention_mask):
        """

        :param input_ids:
        :param embeddings:
        :param attention_mask:
        :return:
        """
        # identify the locations of the context_delimiter in each of the input sequence
        if type(input_ids) is list:
            input_ids = torch.LongTensor(
                input_ids,
            )
        delimiter_points = input_ids == self._context_delimiter_id

        delimiter_points_idxs = delimiter_points.nonzero(as_tuple=True)[-1]

        
        all_attention_masks = []
        all_input_ids = []
        max_length = 0

        # For item in input_ids, embeddings, attention_mask, input_ids, select the
        # portion of the tensor after the delimiter_point_id
        for delimiter_point_id,  att_mask in zip(
            delimiter_points_idxs,attention_mask
        ):
            
            if max_length < att_mask.shape[0]:
                max_length = att_mask.shape[0]
            
            all_attention_masks.append(att_mask[delimiter_point_id + 1 :])

        # Reshape all the section of interest for each item in all_input_ids, all_embeddings, all_attention_masks to
        # the same size
        batch_attention_masks: List = list()

        for idx, att_mask in enumerate( all_attention_masks):
            len_diff = max_length - att_mask.shape[0]
            if max_length > att_mask.shape[0]:

                attn_pads = torch.zeros(
                    len_diff,
                ).to(att_mask.device)
                att_mask = torch.concat([att_mask, attn_pads], -1)
                
            batch_attention_masks += [att_mask.view(-1, max_length)]
        
        # Create the final tensors with the contexts removed
        batch_attention_masks = torch.concat(batch_attention_masks, 0)
        return  batch_attention_masks
    
    def _strip_context(self, input_ids, embeddings, attention_mask):
        """

        :param input_ids:
        :param embeddings:
        :param attention_mask:
        :return:
        """
        # identify the locations of the context_delimiter in each of the input sequence
        if type(input_ids) is list:
            input_ids = torch.LongTensor(
                input_ids,
            )
        delimiter_points = input_ids == self._context_delimiter_id

        delimiter_points_idxs = delimiter_points.nonzero(as_tuple=True)[-1]

        all_embeddings = []
        all_attention_masks = []
        all_input_ids = []
        max_length = 0
        embedding_dim = embeddings.shape[-1]

        # For item in input_ids, embeddings, attention_mask, input_ids, select the
        # portion of the tensor after the delimiter_point_id
        for delimiter_point_id, embedding, att_mask in zip(
            delimiter_points_idxs, embeddings, attention_mask
        ):
            embedding = embedding[delimiter_point_id + 1 :, :]
            if max_length < embedding.shape[0]:
                max_length = embedding.shape[0]
            all_embeddings.append(embedding)
            all_attention_masks.append(att_mask[delimiter_point_id + 1 :])

        # Reshape all the section of interest for each item in all_input_ids, all_embeddings, all_attention_masks to
        # the same size
        batch_embeddings: List = list()
        batch_attention_masks: List = list()

        for idx, (embedding, att_mask) in enumerate(
            zip(all_embeddings, all_attention_masks)
        ):
            len_diff = max_length - embedding.shape[0]
            if max_length > embedding.shape[0]:
                pad_tensor = torch.zeros(len_diff, embedding_dim).to(embedding.device)
                embedding = torch.concat([embedding, pad_tensor], dim=0)

                attn_pads = torch.zeros(
                    len_diff,
                ).to(att_mask.device)
                att_mask = torch.concat([att_mask, attn_pads], -1)

            batch_embeddings += [embedding.view(-1, max_length, embedding_dim)]
            batch_attention_masks += [att_mask.view(-1, max_length)]
        
        # Create the final tensors with the contexts removed
        batch_attention_masks = torch.concat(batch_attention_masks, 0)
        batch_embeddings = torch.concat(batch_embeddings, 0)
        return delimiter_points_idxs,batch_embeddings, batch_attention_masks

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        shrink_map: Optional[dict] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutput]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`BartTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                [What are attention masks?](../glossary#attention-mask)
            head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
                Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:

                - 1 indicates the head is **not masked**,
                - 0 indicates the head is **masked**.

            inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
                This is useful if you want more control over how to convert `input_ids` indices into associated vectors
                than the model's internal embedding lookup matrix.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
                for more detail.
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        """
        
        
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input = input_ids
            input_ids = input_ids.view(-1, input_ids.shape[-1])
        elif inputs_embeds is not None:
            input = inputs_embeds[:, :, -1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale

        embed_pos = self.embed_positions(input)

        hidden_states = inputs_embeds + embed_pos
        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        attention_mask_ = attention_mask

        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask_ = _expand_mask(attention_mask, inputs_embeds.dtype)

        encoder_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        # check if head_mask has a correct number of layers specified if desired
        if head_mask is not None:
            if head_mask.size()[0] != (len(self.layers)):
                raise ValueError(
                    f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
                    f" {head_mask.size()[0]}."
                )
                
        

        for idx, encoder_layer in enumerate(self.layers):
            if output_hidden_states:
                encoder_states = encoder_states + (hidden_states,)
            # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
            dropout_probability = random.uniform(0, 1)
            if self.training and (
                dropout_probability < self.layerdrop
            ):  # skip the layer
                layer_outputs = (None, None)
            else:
                if self.gradient_checkpointing and self.training:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs, output_attentions)

                        return custom_forward

                    layer_outputs = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(encoder_layer),
                        hidden_states,
                        attention_mask_,
                        (head_mask[idx] if head_mask is not None else None),
                    )
                else:
                    layer_outputs = encoder_layer(
                        hidden_states,
                        attention_mask_,
                        layer_head_mask=(
                            head_mask[idx] if head_mask is not None else None
                        ),
                        output_attentions=output_attentions,
                    )

                hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        if output_hidden_states:
            encoder_states = encoder_states + (hidden_states,)

        delimiter_points_idxs,hidden_states, batch_encoder_attention_masks = self._strip_context(
            input_ids, hidden_states, attention_mask
        )
        
        

        if not return_dict:
            
            return tuple(
                v
                for v in [
                    hidden_states,
                    encoder_states,
                    all_attentions,
                    batch_encoder_attention_masks,
                    delimiter_points_idxs,
                ]
                if v is not None
            )
            
        #print(input_ids.shape, hidden_states.shape,batch_encoder_attention_masks.shape, " The data size or shape")
        
        return EncoderOutputs(
            last_hidden_state=hidden_states,
            hidden_states=encoder_states,
            attentions=all_attentions,
            cleaned_mask=batch_encoder_attention_masks,
            seperation_point=delimiter_points_idxs
        )