In [1]:
import sys
from typing import Optional
sys.path.append('../')

#https://github.com/amazon-science/efficient-longdoc-classification
from functools import partial
import nltk
import pickle as pk
import torch
from context_enforcement.models.context_enforcer import compute_context_boundary
from context_enforcement.trainers.train_bart3 import model_init
from context_enforcement.data.common import create_text_tokenizer, SmartCollator
from context_enforcement.trainers.common import get_dataset_specified_tasks
from pytorch_lightning import seed_everything

import sys
import os
SEED = 1376  # global RNG seed, applied via pytorch_lightning.seed_everything
seed_everything(SEED)

def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)

  from .autonotebook import tqdm as notebook_tqdm
Global seed set to 1376
[nltk_data] Downloading package punkt to /home/nlplab/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
Global seed set to 1376


In [6]:
from transformers import BartConfig

MODEL_NAME = 'facebook/bart-base'
TASK_NAME = 'xsum'

bart_config = BartConfig.from_pretrained(MODEL_NAME)
tokenizer = create_text_tokenizer(MODEL_NAME)

task_dataset_gen = get_dataset_specified_tasks(TASK_NAME)
if task_dataset_gen is None:
    # Later cells index `test_dataset` unconditionally, so stop here with a clear
    # message instead of a confusing TypeError on `None[...]` further down.
    raise ValueError(f"get_dataset_specified_tasks returned None for task {TASK_NAME!r}")

raw_dataset = task_dataset_gen(tokenizer=tokenizer)
train_dataset = raw_dataset['train']
eval_dataset = raw_dataset['validation']
test_dataset = raw_dataset['test']

In [45]:
from typing import Tuple, Optional

import torch
from torch import nn
from torch.nn import Linear
from transformers import BartConfig
from transformers.activations import ACT2FN
from transformers.models.bart.modeling_bart import BartAttention

from context_enforcement.models.context_enforcer import split_contexts_with_boundary


class BartContextEnforcerLayer(nn.Module):
    """Encoder layer whose self-attention is replaced by context enforcement.

    The input is split at ``context_boundary = (start, end)`` (via
    ``split_contexts_with_boundary``) into a left context, a focus span and a right
    context along the sequence dimension.

    - The left and right contexts each attend to the focus span.
    - The focus span then attends to the enriched left/right contexts, which are
      separated by a single zero "boundary" vector.
    - The re-stitched sequence goes through residual + LayerNorm + a feed-forward
      sub-layer (fc1 -> activation -> fc2) + residual + LayerNorm.

    Args:
        config: BART configuration. Reads ``d_model``, ``encoder_attention_heads``,
            ``attention_dropout``, ``dropout``, ``activation_function``,
            ``activation_dropout`` and ``encoder_ffn_dim``.
        is_normal_layer: accepted but currently unused.
    """

    def __init__(self,
                 config: BartConfig,
                 is_normal_layer: bool = False
                 ):
        super().__init__()
        self.embed_dim = config.d_model

        # Placeholder attributes; nothing in this class reads them.
        self.context_enforcer = None
        self.context_enforcer_layer_norm = None
        # A single attention module, shared by all three context/focus attentions.
        self.self_attn = BartAttention(
            embed_dim=self.embed_dim,
            num_heads=config.encoder_attention_heads,
            dropout=config.attention_dropout,
        )
        self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)

        self.dropout = config.dropout
        self.activation_fn = ACT2FN[config.activation_function]
        self.activation_dropout = config.activation_dropout
        self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
        self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
        self.final_layer_norm = nn.LayerNorm(self.embed_dim)

        # Context modeling: projections applied to the left (_wlc), right (_wrc) and
        # focus (_wfc) attention outputs before their residual connections.
        dim = self.embed_dim
        self._wlc = Linear(dim, dim)
        self._wrc = Linear(dim, dim)
        self._wfc = Linear(dim, dim)

    def _compute_context(self, hidden_states: torch.FloatTensor,
                         attention_mask: torch.FloatTensor,
                         context_boundary: Tuple[int, int],
                         layer_head_mask: torch.FloatTensor,
                         output_attentions: Optional[bool] = False, ):
        """Compute the context-enforced representation of the whole sequence.

        Args:
            hidden_states: ``(batch, seq_len, embed_dim)``, split along dim 1.
            attention_mask: additive mask ``(batch, 1, seq_len, seq_len)``, or None.
            context_boundary: ``(start, end)`` of the focus span.
            layer_head_mask: passed straight to every ``self.self_attn`` call.
            output_attentions: passed straight to every ``self.self_attn`` call.

        Returns:
            ``(full_rep, [left_focus_attention, focus_attention, right_focus_attention])``.
            ``full_rep`` has the same sequence length as ``hidden_states``.
        """
        # NOTE(review): assumes split_contexts_with_boundary returns
        # [:, :start], [:, start:end], [:, end:] — consistent with the mask slicing below.
        [left_context, focus_context, right_context] = split_contexts_with_boundary(
            hidden_states,
            context_boundary,
        )

        boundary_start, boundary_end = context_boundary
        fc_seq_len = focus_context.shape[1]
        left_attention = None
        right_attention = None
        focus_attention_mask = None
        if attention_mask is not None:
            # Mask rows are queries and columns are keys. The left/right contexts only
            # attend to the focus-span keys.
            left_attention = attention_mask[:, :, :boundary_start, boundary_start:boundary_end]
            right_attention = attention_mask[:, :, boundary_end:, boundary_start:boundary_end]
            # Additive-mask column (0 = attend) for the zero boundary vector that is
            # inserted between the left and right contexts. Create it on the mask's
            # device and dtype so torch.cat also works on CUDA and in fp16.
            boundary_column = torch.zeros(
                size=(left_context.shape[0], 1, fc_seq_len, 1),
                device=attention_mask.device,
                dtype=attention_mask.dtype,
            )
            focus_attention_mask = torch.cat([attention_mask[:, :, boundary_start:boundary_end, :boundary_start],
                                              boundary_column,
                                              attention_mask[:, :, boundary_start:boundary_end, boundary_end:]
                                              ], dim=3)

        # Left and right contexts attend to the focus span.
        left_focus, left_focus_attention, _ = self.self_attn(
            hidden_states=left_context,
            key_value_states=focus_context,
            attention_mask=left_attention,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        right_focus, right_focus_attention, _ = self.self_attn(
            hidden_states=right_context,
            key_value_states=focus_context,
            attention_mask=right_attention,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
        )
        right_focus = self.activation_fn(self._wrc(right_focus)) + right_context
        left_focus = self.activation_fn(self._wlc(left_focus)) + left_context

        # Single zero vector separating the two contexts; it matches the
        # boundary_column entry of focus_attention_mask.
        boundary = torch.zeros(
            (left_focus.shape[0], 1, left_focus.shape[-1]),
            device=left_focus.device,
            dtype=left_focus.dtype,
        )
        context = torch.cat([left_focus, boundary, right_focus], dim=1)

        # The focus span attends to the enriched contexts.
        focus_rep, focus_attention, _ = self.self_attn(hidden_states=focus_context,
                                                       key_value_states=context,
                                                       attention_mask=focus_attention_mask,
                                                       layer_head_mask=layer_head_mask,
                                                       output_attentions=output_attentions, )
        focus_rep = self.activation_fn(self._wfc(focus_rep)) + focus_context

        # Stitch the full reps together (the boundary vector is not part of the output).
        full_rep = torch.cat([left_focus, focus_rep, right_focus], dim=1)
        atten_weights = [left_focus_attention, focus_attention, right_focus_attention]

        return full_rep, atten_weights

    def forward(
            self,
            hidden_states: torch.FloatTensor,
            attention_mask: torch.FloatTensor,
            context_boundary: Tuple[int, int],
            layer_head_mask: torch.FloatTensor,
            output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`;
                the sequence is split along dim 1.
            attention_mask (`torch.FloatTensor`): additive attention mask of size
                `(batch, 1, seq_len, seq_len)` where padding elements are indicated by very large negative
                values, or `None`.
            context_boundary (`Tuple[int, int]`): `(start, end)` indices of the focus span along the sequence.
            layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
                `(encoder_attention_heads,)`, or `None`.
            output_attentions (`bool`, *optional*):
                If True, also return the list `[left_focus_attention, focus_attention, right_focus_attention]`.

        Returns:
            `(hidden_states,)` with the same shape as the input, or `(hidden_states, attention_weights)` when
            `output_attentions` is True.
        """
        residual = hidden_states

        hidden_states, attn_weights = self._compute_context(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=layer_head_mask,
            output_attentions=output_attentions,
            context_boundary=context_boundary
        )
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states
        hidden_states = self.self_attn_layer_norm(hidden_states)

        # Feed-forward sub-layer with its own residual connection.
        residual = hidden_states
        hidden_states = self.activation_fn(self.fc1(hidden_states))
        hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
        hidden_states = self.fc2(hidden_states)
        hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
        hidden_states = residual + hidden_states

        hidden_states = self.final_layer_norm(hidden_states)

        # fp16 guard: if inf/nan appeared, clamp values into the finite fp16 range
        # (leaving a margin of 1000 below the maximum).
        if hidden_states.dtype == torch.float16 and (
                torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
        ):
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states,
                                        min=-clamp_value, max=clamp_value)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        return outputs




In [46]:
# Freshly initialised layer (weights built from the config only, no checkpoint loaded).
# It is left in training mode, so its dropout calls are active in the forward pass below.
bart_layer = BartContextEnforcerLayer(bart_config, is_normal_layer=False)

In [47]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# The example tensors below are moved to `device`, so the layer's parameters must
# live there too; otherwise a GPU run mixes CPU weights with CUDA inputs.
bart_layer = bart_layer.to(device)

In [14]:
# Inspect a single test example. Cap it at the same 1024 tokens used when computing
# seq_len, so the attention mask and the embeddings agree in length.
MAX_TOKENS = 1024
te1 = test_dataset[692]
b_input_ids = te1.input_ids[:MAX_TOKENS].view(1, -1).to(device)
b_input_mask = te1.attention_mask[:MAX_TOKENS].view(1, -1).to(device)
attention_mask = _expand_mask(b_input_mask, torch.float32)

In [1]:
# Sample a (start, end) focus-span boundary for the example above.
boundary_sample = (0.15, 0.65)  # context_sampling_bounds for compute_context_boundary
seq_len = len(te1.input_ids[:1024])  # token count, capped at 1024
boundary_width = int(seq_len * 0.33)  # max focus-span length: ~33% of the sequence
context_boundary = compute_context_boundary(
    seq_len,
    context_sampling_bounds=boundary_sample,
    context_max_len=boundary_width,
)
(context_boundary, seq_len)

NameError: name 'te1' is not defined

In [26]:
# Width of the focus span (boundary_end - boundary_start), derived from the sampled
# boundary instead of numbers hard-coded from a previous run.
context_boundary[1] - context_boundary[0]

169

In [48]:
# Random stand-in for token embeddings, shape (1, seq_len, d_model). It is created on
# `device` so it matches the attention mask built from the example.
embed = torch.rand(size=(1, seq_len, bart_config.d_model), device=device)

In [49]:
# Run the layer once on the random embeddings. With output_attentions=True it returns
# (hidden_states, [left_focus_attention, focus_attention, right_focus_attention]).
boutput = bart_layer(
    hidden_states=embed,
    attention_mask=attention_mask,
    context_boundary=context_boundary,
    layer_head_mask=None,
    output_attentions=True,
)

torch.Size([1, 1, 432, 169])
torch.Size([1, 1, 169, 1])
ZZZZZZZZzzz
torch.Size([1, 1, 169, 432])
torch.Size([1, 1, 169, 207])
