In [None]:
from transformers import ProphetNetForConditionalGeneration, ProphetNetTokenizer, ProphetNetEncoder

model_raw = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
model = ProphetNetForConditionalGeneration.from_pretrained("microsoft/prophetnet-large-uncased")
tokenizer = ProphetNetTokenizer.from_pretrained("microsoft/prophetnet-large-uncased")

In [13]:
import torch
text = torch.load('../../data/processed/cnn-dm/text/test/chunk_0.pt')
summary = torch.load('../../data/processed/cnn-dm/summary/test/chunk_0.pt')

In [14]:
import torch

sd = torch.load('../../models/unfrozen-cnn/epoch0_end', map_location=torch.device('cpu'))
input_ids = text.input_ids[:1]
attention_mask = text.attention_mask[:1]
labels = summary.input_ids[:1]

In [15]:
new_sd = {key[13:]: value for key, value in sd.items()}
model.load_state_dict(new_sd)

<All keys matched successfully>

Original text

In [5]:
tokenizer.decode(input_ids[0], skip_special_tokens=True)

'costa rica has taken its border dispute with nicaragua to international court repeating claims that its territory has been invaded in a statement thursday costa ricas foreign ministry said the country had filed a lawsuit at the international court of justice in the hague netherlands to end a situation that threatens imminent and irreparable harm to costa rica the suit asks the court to stop the construction of a canal on costa rican soil according to the statement tensions between nicaragua and costa rica have flared over calero island a parcel of land on the atlantic coast nicaragua denies its troops are in costa rican territory costa rica claims it has been invaded costa rica claims that in addition to the nicaraguan troops a dredging project in the river is dumping sediment on its side of the border and that a costa rican flag in the area was replaced with a nicaraguan flag nicaragua has accused costa rica of breaking diplomatic relations between the countries also thursday the org

Original summary

In [6]:
tokenizer.decode(labels[0], skip_special_tokens=True)

'costa ricas foreign ministry says the situation threatens imminent and irreparable harm the suit asks the court to stop the construction of a canal on costa rican soil foreign ministers from the region will discuss the situation december'

Generated based on input_ids

In [7]:
tokenizer.decode(model_raw.generate(input_ids=input_ids)[0])

"[SEP] according to the statement the suit asks the court to stop the construction of a canal on costa rican soil to end a situation that threatens imminent harm to costa rica and its citizens to end the situation that is threatening costa rica's sovereignty and to costa ricans to stop it from happening costa rica has taken a lawsuit at the international court of justice in the hague [SEP]"

In [8]:
tokenizer.decode(model.generate(input_ids=input_ids)[0])

'[SEP] costa ricas foreign ministry says it has filed a lawsuit at the international court of justice in the hague netherlands to end a situation that threatens imminent and irreparable harm to costa rica the suit asks the court to stop the construction of a canal on costa rican soil [SEP]'

Generate based on encoder outputs

In [9]:
enc_outputs = model_raw.prophetnet.encoder(input_ids=input_ids, attention_mask=attention_mask)
enc_outputs2 = model.prophetnet.encoder(input_ids=input_ids, attention_mask=attention_mask)

In [10]:
tokenizer.decode(model_raw.generate(encoder_outputs=enc_outputs)[0])

'[SEP] to be continued. [SEP]'

In [11]:
tokenizer.decode(model.generate(encoder_outputs=enc_outputs)[0])

'[SEP] new new york city is set to be a new city for the first time in years the city of new york will be the first to be built in the new town of st louis it is one of the most popular places in the world to be set up for a new town the city is now known as the town of the new city of the world has been set up with a new set of buildings and a new one that will be built on the site of the first of the three stages of the road that will go through the middle of the country and go through a series of steps that will lead to a new area of the city which will be named the new york bridge and the new bridge will be'

In [12]:
tokenizer.decode(model_raw.generate(encoder_outputs=enc_outputs2)[0])

'[SEP] tensions between nicaragua and costa rica have flared over calero island a parcel of land on the atlantic coast the statement was translated by the national news agency. 1629. 1656. 1654. 1641. 1654 [SEP]'

In [13]:
tokenizer.decode(model.generate(encoder_outputs=enc_outputs2)[0])

'[SEP] costa rica has filed a lawsuit at the international court of justice in the hague netherlands to end a situation that threatens irreparambling harm to the country [SEP]'

In [14]:
enc_outputs.last_hidden_state.shape

torch.Size([16, 512, 1024])

In [15]:
enc_outputs2.last_hidden_state.shape

torch.Size([16, 512, 1024])

# Cross attn

In [17]:
hidden_states=torch.randn(1, 500, 1024) #hidden_states
key_value_states=torch.randn(1, 500, 1024)#encoder_hidden_states
key_value_states20=torch.randn(20, 500, 1024)#encoder_hidden_states

In [18]:
def _shape(tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, 16, 64).transpose(1, 2).contiguous()
proj = torch.nn.Linear(1024, 1024)

In [19]:
batch_size, tgt_len, hidden_size = hidden_states.size()
is_cross_attention = key_value_states is not None

query_states = proj(hidden_states) / (64**0.5)    # 64 = head_dim

In [20]:
key_states = _shape(proj(key_value_states), -1, batch_size)
value_states = _shape(proj(key_value_states), -1, batch_size)
key_states20 = _shape(proj(key_value_states20), -1, batch_size)
value_states20 = _shape(proj(key_value_states20), -1, batch_size)

In [21]:
key_states.shape

torch.Size([1, 16, 500, 64])

In [22]:
key_states20.shape

torch.Size([1, 16, 10000, 64])

In [23]:
past_key_value = (key_states, value_states)
past_key_value20 = (key_states20, value_states20)

In [24]:
proj_shape = (batch_size * 16, -1, 64)
query_states = _shape(query_states, tgt_len, batch_size).view(*proj_shape)
key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)
key_states20 = key_states20.view(*proj_shape)
value_states20 = value_states20.view(*proj_shape)

In [25]:
key_states.shape

torch.Size([16, 500, 64])

In [26]:
key_states20.shape

torch.Size([16, 10000, 64])

In [27]:
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
attn_weights.shape

torch.Size([16, 500, 500])

In [28]:
attn_weights20 = torch.bmm(query_states, key_states20.transpose(1, 2))
attn_weights20.shape

torch.Size([16, 500, 10000])

In [29]:
attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_weights20 = torch.nn.functional.softmax(attn_weights20, dim=-1)

In [30]:
attn_probs = torch.nn.functional.dropout(
        attn_weights,
        training=False
    )
attn_probs20 = torch.nn.functional.dropout(
        attn_weights20,
        training=False
    )

In [31]:
attn_output = torch.bmm(attn_probs, value_states)
attn_output20 = torch.bmm(attn_probs20, value_states20)

In [32]:
attn_output.shape

torch.Size([16, 500, 64])

In [33]:
attn_output20.shape

torch.Size([16, 500, 64])

Now we're back to same shape without problems (line 755)

https://github.com/huggingface/transformers/blob/v4.17.0/src/transformers/models/prophetnet/modeling_prophetnet.py#L755

In [1]:
import torch

def _shape(tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, 16, 64).transpose(1, 2).contiguous()
proj = torch.nn.Linear(1024, 1024)

In [35]:
hidden_states=torch.randn(1, 500, 1024) #hidden_states
batch_size, tgt_len, hidden_size = hidden_states.size()
query_states = proj(hidden_states) / (64**0.5)    # 64 = head_dim

proj_shape = (batch_size * 16, -1, 64)
query_states = _shape(query_states, tgt_len, batch_size).view(*proj_shape)

In [36]:
key_value_states=torch.randn(bs, 500, 1024) #encoder_hidden_states
key_states = _shape(proj(key_value_states), -1, batch_size)
value_states = _shape(proj(key_value_states), -1, batch_size)

key_states = key_states.view(*proj_shape)
value_states = value_states.view(*proj_shape)

In [37]:
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
attn_probs = torch.nn.functional.softmax(attn_weights, dim=-1)
attn_output = torch.bmm(attn_probs, value_states)

In [39]:
attn_output = (
    attn_output.view(batch_size, 16, tgt_len, 64)
    .transpose(1, 2)
    .reshape(batch_size, tgt_len, hidden_size)
)

In [6]:
for bs in [1, 20]:
    hidden_states=torch.randn(1, 500, 1024) #hidden_states
    batch_size, tgt_len, hidden_size = hidden_states.size()
    key_value_states=torch.randn(bs, 500, 1024) #encoder_hidden_states
    query_states = proj(hidden_states) / (64**0.5)    # 64 = head_dim

    print(f'Target: {hidden_states.shape}     Encoder: {key_value_states.shape}')

    key_states = _shape(proj(key_value_states), -1, batch_size)
    value_states = _shape(proj(key_value_states), -1, batch_size)

    proj_shape = (batch_size * 16, -1, 64)
    query_states = _shape(query_states, tgt_len, batch_size).view(*proj_shape)
    key_states = key_states.view(*proj_shape)
    value_states = value_states.view(*proj_shape)
    print("Query shape:", query_states.shape)
    print("Key/value shape:", key_states.shape)

    attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
    print("Attn_weights shape:", attn_weights.shape)
    attn_probs = torch.nn.functional.softmax(attn_weights, dim=-1)
    attn_output = torch.bmm(attn_probs, value_states)
    print("Attn_output shape:", attn_output.shape)

    attn_output = (
        attn_output.view(batch_size, 16, tgt_len, 64)
        .transpose(1, 2)
        .reshape(batch_size, tgt_len, hidden_size)
    )
    print("Output shape:", attn_output.shape)


    print(f'Output: {attn_output.sum()}')
    print()


Target: torch.Size([1, 500, 1024])     Encoder: torch.Size([1, 500, 1024])
Query shape: torch.Size([16, 500, 64])
Key/value shape: torch.Size([16, 500, 64])
Attn_weights shape: torch.Size([16, 500, 500])
Attn_output shape: torch.Size([16, 500, 64])
Output shape: torch.Size([1, 500, 1024])
Output: -15.079450607299805

Target: torch.Size([1, 500, 1024])     Encoder: torch.Size([20, 500, 1024])
Query shape: torch.Size([16, 500, 64])
Key/value shape: torch.Size([16, 10000, 64])
Attn_weights shape: torch.Size([16, 500, 10000])
Attn_output shape: torch.Size([16, 500, 64])
Output shape: torch.Size([1, 500, 1024])
Output: -195.8182373046875

