# ***Hierarchical BiLSTM***

## Import Required Libraries

In [1]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=9826cb7c7564b35b9c1558c22f2234b4941b754d02384013d0cb085713fc0ad0
  Stored in directory: /root/.cache/pip/wheels/1e/19/43/8a442dc83660ca25e163e1bd1f89919284ab0d0c1475475148
Successfully built rouge_score
Installing collected packages: rouge_score
Successfully installed rouge_score-0.1.2


In [2]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from gensim.models import Word2Vec
from tqdm import tqdm
import re
from rouge_score import rouge_scorer
import matplotlib.pyplot as plt
import gc
from collections import Counter
import traceback

## Load Processed Data

In [3]:
train_df = pd.read_csv('/kaggle/input/text-summarization/train.csv')
valid_df = pd.read_csv('/kaggle/input/text-summarization/valid.csv')

In [None]:
def doc_to_sentences(doc: str) -> list:
    parts = re.split(r'\s([.!?:])(?:\s+|$)', doc.strip())
    parts = [part for part in parts if part]
    sentences = []
    buffer = ''
    for part in parts:
        if part.strip() in '.!?"\'':
            buffer += ' ' + part.strip()
        else:
            if buffer:
                sentences.append(buffer)
            buffer = part.strip()
    if buffer:
        sentences.append(buffer.strip())
    return sentences

doc_test = 'Tổng bí thư Nguyễn Phú Trọng phát biểu lúc 12h49 ngày 1/1/2990 . " Quốc hội khai. mạc phiên họp thứ 2 ! " . '
print(doc_to_sentences(doc_test))

In [None]:
docs = train_df['content'].apply(doc_to_sentences).tolist()
doc_lengths = [len(doc) for doc in docs]
sent_lengths = [len(sent.split()) for doc in docs for sent in doc]
summary_lengths = [len(summary.split()) for summary in train_df['summary']]

fig, axes = plt.subplots(1, 3, figsize=(18, 4))
fig.suptitle('Train data distribution')

axes[0].hist(doc_lengths, bins=50, color='green')
axes[0].set_title('Doc Lengths')
axes[0].set_xlabel('Number of sentences')
axes[0].set_ylabel('Frequency')

axes[1].hist(sent_lengths, bins=50, color='orange')
axes[1].set_title('Sentence Lengths')
axes[1].set_xlabel('Number of words')
axes[1].set_ylabel('Frequency')

axes[2].hist(summary_lengths, bins=50, color='blue')
axes[2].set_title('Summary Lengths')
axes[2].set_xlabel('Number of words')
axes[2].set_ylabel('Frequency')

plt.tight_layout(rect=[0, 0, 1, 0.95])
plt.show()

## Load Word2Vec

In [4]:
class Vocab:
    def __init__(self, word2vec_model_path: str):
        self.model = Word2Vec.load(word2vec_model_path)
        
        self.word2id = {}
        self.id2word = {}
        self.embedding_dim = self.model.vector_size
        self.build()
        self.unk_id = self.word2id['<UNK>']
        self.pad_id = self.word2id['<PAD>']
        self.sos_id = self.word2id['<SOS>']
        self.eos_id = self.word2id['<EOS>']
        self.num_id = self.get_index('<NUM>')
        self.time_id = self.get_index('<TIME>')
        self.date_id = self.get_index('<DATE>')
        self.num_regexp = r"[\d.,]*\d[\d.,]*"
        self.time_regexp = r"(\d{1,2}h(\d{1,2})?)"
        self.date_regexp = r"(\d{1,2}/\d{1,2}(/\d{2,4})?)"

    def build(self):
        id = 0
        for word in self.model.wv.index_to_key:
            self.word2id[word] = id
            self.id2word[id] = word
            id += 1
        special_tokens = ['<UNK>', '<PAD>', '<SOS>', '<EOS>']
        for token in special_tokens:
            self.word2id[token] = id
            self.id2word[id] = token
            id += 1
    
    def __len__(self):
        return len(self.word2id)
    
    def get_index(self, word):
        return self.word2id.get(word, self.unk_id)
    
    def get_word(self, id):
        return self.id2word.get(id, '<UNK>')

    def decode(self, ids, oov):
        ids = [id for id in ids if id not in [self.unk_id, self.pad_id, self.eos_id, self.sos_id, self.num_id, self.time_id, self.date_id]]
        words = []
        len_vocab = self.__len__()
        for id in ids:
            if id >= len_vocab + len(oov):
                continue
            elif id >= len_vocab:
                words.append(oov[id - len_vocab])
            else:
                words.append(self.get_word(id))
        return ' '.join(words)

In [5]:
vocab = Vocab('/kaggle/input/text-summarization/word2vec_skipgram.model')

In [None]:
print(len(vocab))

In [None]:
vocab.model.wv.most_similar("tốt", topn=10)

## Create Dataset and Loader

In [6]:
class SummarizationDataset(Dataset):
    def __init__(
        self, 
        df: pd.DataFrame, 
        max_doc_length: int, 
        max_sentence_length:int, 
        max_summary_length: int, 
        vocab: Vocab
    ):
        self.max_doc_length = max_doc_length
        self.max_sentence_length = max_sentence_length
        self.max_summary_length = max_summary_length
        self.vocab = vocab
        self.data = []
        for row in df.itertuples(index=False):
            input, extend_input, oov = self.encode_content(row.content)
            target, extend_target = self.encode_summary(row.summary, oov)
            if len(input) == 0 or len(target) == 0:
                continue
            self.data.append((input, extend_input, target, extend_target, oov))

    def encode_content(self, content):
        parts = re.split(r'\s([.!?:])(?:\s+|$)', content.strip())
        parts = [part for part in parts if part]
        sentences = []
        buffer = ''
        for part in parts:
            if part.strip() in '.!?"\'':
                buffer += ' ' + part.strip()
            else:
                if buffer:
                    sentences.append(buffer)
                buffer = part.strip()
        if buffer:
            sentences.append(buffer.strip())
        tokens = [sentence.split() for sentence in sentences]
        if len(tokens) > self.max_doc_length:
            tokens = tokens[:self.max_doc_length]
        origin = []
        extend = []
        oov = []
        for sentence in tokens:
            origin_sent = []
            extend_sent = []
            if len(sentence) > self.max_sentence_length:
                sentence = sentence[:self.max_sentence_length]
            for word in sentence:
                id = self.vocab.get_index(word)
                if id == self.vocab.unk_id:
                    if re.fullmatch(self.vocab.num_regexp, word):
                        id = self.vocab.num_id
                    elif re.fullmatch(self.vocab.time_regexp, word):
                        id = self.vocab.time_id
                    elif re.fullmatch(self.vocab.date_regexp, word):
                        id = self.vocab.date_id
                    origin_sent.append(id)
                    if word not in oov:
                        oov.append(word)
                    extend_sent.append(len(self.vocab) + oov.index(word))
                else:
                    origin_sent.append(id)
                    extend_sent.append(id)
            origin.append(origin_sent)
            extend.append(extend_sent)
        return origin, extend, oov
    
    def encode_summary(self, summary, oov):
        sentence = summary.strip().split()
        if len(sentence) + 2 > self.max_summary_length:
            sentence = sentence[:self.max_summary_length - 2]
        origin = []
        extend = []
        for word in sentence:
            id = self.vocab.get_index(word)
            if id == self.vocab.unk_id:
                if re.fullmatch(self.vocab.num_regexp, word):
                    id = self.vocab.num_id
                elif re.fullmatch(self.vocab.time_regexp, word):
                    id = self.vocab.time_id
                elif re.fullmatch(self.vocab.date_regexp, word):
                    id = self.vocab.date_id
            origin.append(id)
            if word in oov:
                extend.append(len(self.vocab) + oov.index(word))
            else: 
                extend.append(id)
        return origin, extend

    def __len__(self):
        return len(self.data)
    
    def pad_content(self, doc):
        padded = []
        if len(doc) > self.max_doc_length:
            doc = doc[:self.max_doc_length]
        for sentence in doc:
            if len(sentence) < self.max_sentence_length:
                sentence += [self.vocab.pad_id] * (self.max_sentence_length - len(sentence))
            elif len(sentence) > self.max_sentence_length:
                sentence = sentence[:self.max_sentence_length]
            padded.append(sentence)
        if len(padded) < self.max_doc_length:
            padded += [[self.vocab.pad_id] * self.max_sentence_length] * (self.max_doc_length - len(padded))
        
        return padded
    def pad_summary(self, doc):
        if len(doc) > self.max_summary_length:
            doc = doc[:self.max_summary_length]
        else:
            doc += [self.vocab.pad_id] * (self.max_summary_length - len(doc))
        
        return doc

    def __getitem__(self, index):
        input, extend_input, target, extend_target, oov = self.data[index]
        target = [self.vocab.sos_id] + target + [self.vocab.eos_id]
        extend_target = [self.vocab.sos_id] + extend_target + [self.vocab.eos_id]
        input = self.pad_content(input)
        attention_mask = [[1 if token != self.vocab.pad_id else 0 for token in sent] for sent in input]
        extend_input = self.pad_content(extend_input)
        target = self.pad_summary(target)
        extend_target = self.pad_summary(extend_target)
        
        return {
            'input': torch.tensor(input, dtype=torch.long),
            'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
            'extend_input': torch.tensor(extend_input, dtype=torch.long),
            'target': torch.tensor(target, dtype=torch.long),
            'extend_target': torch.tensor(extend_target, dtype=torch.long),
            'oov': oov
        }

In [7]:
class DatasetChunk(Dataset):
    def __init__(self, dataset, indices):
        self.dataset = dataset
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        real_idx = self.indices[idx]
        return self.dataset[real_idx]

In [8]:
def collate_fn(batch):
    input_ids = torch.stack([item['input'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    extend_input_ids = torch.stack([item['extend_input'] for item in batch])
    target_ids = torch.stack([item['target'] for item in batch])
    extend_target_ids = torch.stack([item['extend_target'] for item in batch])
    oov_lists = [item['oov'] for item in batch]

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'extend_input_ids': extend_input_ids,
        'target_ids': target_ids,
        'extend_target_ids': extend_target_ids,
        'oov_lists': oov_lists
    }

In [9]:
max_doc_length = 45
max_sent_length = 60
max_summary_length = 60

In [10]:
train_dataset = SummarizationDataset(
    df=train_df, 
    max_doc_length= max_doc_length, 
    max_sentence_length=max_sent_length, 
    max_summary_length=max_summary_length, 
    vocab=vocab
)
valid_dataset = SummarizationDataset(
    df=valid_df, 
    max_doc_length= max_doc_length, 
    max_sentence_length=max_sent_length, 
    max_summary_length=max_summary_length, 
    vocab=vocab
)

In [11]:
def create_dataloaders(dataset, n_chunks, batch_size=16, collate_fn=collate_fn):
    length = len(dataset)
    indices = np.random.permutation(length)
    chunk_size = length // n_chunks

    dataloaders = []
    for i in range(n_chunks):
        start_idx = i * chunk_size
        end_idx = (i + 1) * chunk_size if i < n_chunks - 1 else length
        chunk_indices = indices[start_idx:end_idx]

        chunk_dataset = DatasetChunk(dataset, chunk_indices)
        dl = DataLoader(
            chunk_dataset,
            batch_size=batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=4,
            collate_fn=collate_fn
        )
        dataloaders.append(dl)
    return dataloaders

In [12]:
train_loaders = create_dataloaders(train_dataset, n_chunks=4, batch_size=16)
valid_loader = DataLoader(valid_dataset, batch_size=128, pin_memory=True, shuffle=False, num_workers=4, collate_fn=collate_fn)

## Define Model

In [13]:
class Word2VecEmbedding(nn.Module):
    def __init__(self, vocab: Vocab):
        super().__init__()
        self.vocab = vocab
        self.embedding = nn.Embedding(len(vocab), vocab.embedding_dim)
        self.adapter = nn.Linear(vocab.embedding_dim, vocab.embedding_dim)
        self.load_pretrained_weights()
    
    def load_pretrained_weights(self):
        self.embedding.weight.requires_grad = False
        weights = []
        for i in range(len(self.vocab)):
            word = self.vocab.get_word(i)
            if word in self.vocab.model.wv:
                vec = self.vocab.model.wv[word]
            else:
                vec = np.random.normal(scale=0.1, size=self.vocab.embedding_dim)
            weights.append(torch.tensor(vec, dtype=torch.float32))
        weights[self.vocab.pad_id] = torch.zeros(self.vocab.embedding_dim)
        
        weights_tensor = torch.stack(weights)
        with torch.no_grad():
            self.embedding.weight.copy_(weights_tensor)
    
    def forward(self, x):
        x = self.embedding(x)
        x = self.adapter(x)
        x = F.gelu(x)
        return x

In [14]:
class AttentionPooling(nn.Module):
    def __init__(
        self,
        hidden_dim,
        # dropout
    ):
        super().__init__()
        self.adapter = nn.Linear(hidden_dim, hidden_dim)
        self.norm = nn.LayerNorm(hidden_dim)
        self.score_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim//2),
            nn.GELU(),
            nn.Linear(hidden_dim//2, 1)
        )
        nn.init.constant_(self.score_mlp[-1].bias, 0.1)
        self.temperature = nn.Parameter(torch.tensor(1.0))
        # self.dropout = nn.Dropout(dropout)
        
    def forward(self, hiddens, mask=None):
        h = self.adapter(hiddens)
        h = F.relu(h)
        h = self.norm(h)
        
        # [..., H] -> [..., 1] -> [...]
        scores = self.score_mlp(h).squeeze(-1) / self.temperature

        if mask is not None:
            scores = scores.masked_fill(mask == 0, -1e4)

        # [...] -> [..., 1]
        weights = F.softmax(scores, dim=-1).unsqueeze(-1)

        context = torch.sum(h * weights, dim=-2)

        return context + 0.1 * hiddens.mean(dim=-2)

In [15]:
class ResidualBlock(nn.Module):
    def __init__(
        self,
        input_dim,
        hidden_dim,
        # dropout=0.3,
        activation=nn.ReLU,
    ):
        super().__init__()
        self.norm = nn.LayerNorm(input_dim)
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, input_dim)
        # self.dropout = nn.Dropout(dropout)
        self.activation = activation()

    def forward(self, x):
        x = self.norm(x)
        out = self.fc1(x)
        out = self.activation(out)

        out = self.fc2(out)
        return self.activation(x + out)

In [16]:
class Encoder(nn.Module):
    def __init__(
        self,
        embedding_dim,
        hidden_size_word, 
        word_residual_configs,
        hidden_size_sent,
        sent_residual_configs,
        dropout=0.3,
    ):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_size_word = hidden_size_word
        self.hidden_size_sent = hidden_size_sent
        self.dropout = nn.Dropout(dropout)

        self.word_input_norm = nn.LayerNorm(embedding_dim)
        
        # Word-level BiLSTM
        self.word_layer = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.hidden_size_word,
            bidirectional=True,
            batch_first=True
        )

        self.word_attention_pooling = AttentionPooling(
            2 * self.hidden_size_word, 
            # dropout
        )

        self.word_residuals = nn.Sequential(
            *[
                ResidualBlock(
                    input_dim=2 * hidden_size_word,
                    hidden_dim=config['hidden_dim'],
                    # dropout=dropout,
                    activation=config['activation']
                )
                for config in word_residual_configs
            ]
        )

        self.sent_input_norm = nn.LayerNorm(2 * hidden_size_word)
        
        # Sentence-level BiLSTM
        self.sent_layer = nn.LSTM(
            input_size=self.hidden_size_word * 2,
            hidden_size=self.hidden_size_sent,
            bidirectional=True,
            batch_first=True
        )

        self.sent_attention_pooling = AttentionPooling(
            2 * self.hidden_size_sent, 
            # dropout
        )
        
        self.sent_residuals = nn.Sequential(
            *[
                ResidualBlock(
                    input_dim=2 * self.hidden_size_sent,
                    hidden_dim=config['hidden_dim'],
                    # dropout=dropout,
                    activation=config['activation']
                )
                for config in sent_residual_configs
            ]
        )
        
    def forward(self, embedded_inputs, attention_masks, debug=False):
        """
        B: Batch size
        S: Number of sentences
        W: Number of words in a sentences
        D: Embedding dim
        HW: Hidden word size
        HS: Hidden sent size

        Args:
            embedded_inputs: [B, S, W, D]
            attention_mask: [B, S, W]
        Returns:
            output: [B, 2HS]
            word_layer_outputs: [B, S * W, 2HW]
        """

        B, S, W = attention_masks.shape
        device = embedded_inputs.device

        # Flatten 
        # [B, S, W, D] -> [B * S, W, D]
        flatted_inputs = embedded_inputs.view(B * S, W, -1)
        flatted_inputs = self.word_input_norm(flatted_inputs)
        flatted_inputs = self.dropout(flatted_inputs)

        # [B, S, W] -> [B * S, W]
        flatted_masks = attention_masks.view(B * S, -1)

        # === Word-level BiLSTM ===
        
        # Compute lengths for packing
        sent_lengths = flatted_masks.sum(dim=1).cpu()
        valid_masks = sent_lengths > 0

        packed_word_layer_inputs = pack_padded_sequence(
            flatted_inputs[valid_masks],
            lengths=sent_lengths[valid_masks],
            batch_first=True,
            enforce_sorted=False
        )
        packed_word_layer_outputs, _ = self.word_layer(packed_word_layer_inputs)
        
        # [B_valid, W, 2HW]
        unpacked_word_layer_outputs, _ = pad_packed_sequence(packed_word_layer_outputs, batch_first=True, total_length=W)

        # [B * S, W, 2HW]
        word_layer_outputs = torch.zeros(B * S, W, 2 * self.hidden_size_word, device=device)
        word_layer_outputs[valid_masks] = unpacked_word_layer_outputs

        # [B * S, W, 2HW] -> [B * S, 2HW]
        sent_layer_inputs = self.word_attention_pooling(word_layer_outputs, flatted_masks)
        sent_layer_inputs = self.dropout(sent_layer_inputs)
        
        sent_layer_inputs = self.word_residuals(sent_layer_inputs)

        # [B * S, 2HW] -> [B, S, 2HW]
        sent_layer_inputs = sent_layer_inputs.view(B, S, -1)
        sent_layer_inputs = self.sent_input_norm(sent_layer_inputs)
        sent_layer_inputs = self.dropout(sent_layer_inputs)

        # === Sentence-level BiLSTM ===

        # Compute lengths for packing
        doc_masks = (attention_masks.sum(dim=2) > 0).long()
        doc_lengths = doc_masks.sum(dim=-1).cpu()

        packed_sent_layer_inputs = pack_padded_sequence(
            sent_layer_inputs,
            lengths=doc_lengths,
            batch_first=True,
            enforce_sorted=False
        )
        packed_sent_layer_outputs, _ = self.sent_layer(packed_sent_layer_inputs)

        sent_layer_outputs, _ = pad_packed_sequence(packed_sent_layer_outputs, batch_first=True, total_length=S)

        # [B, S, 2HS] -> [B, 2HS]
        outputs = self.sent_attention_pooling(sent_layer_outputs)
        outputs = self.dropout(outputs)
        
        outputs = self.sent_residuals(outputs)
        
        return outputs, word_layer_outputs.view(B, S * W, -1)

In [17]:
class Attention(nn.Module):
    def __init__(
        self,
        hidden_size_word,
        hidden_size_decoder,
        attention_dim,
        dropout
    ):
        super().__init__()
        self.enc_proj = nn.Linear(hidden_size_word * 2, attention_dim)
        self.enc_norm = nn.LayerNorm(attention_dim)
        self.dec_proj = nn.Linear(hidden_size_decoder, attention_dim, bias=False)
        self.dec_norm = nn.LayerNorm(attention_dim)
        self.score_proj = nn.Linear(attention_dim, 1)
        self.dropout = nn.Dropout(dropout)
        nn.init.constant_(self.score_proj.bias, 0.1)
    
    def forward(self, dec_hidden, enc_outputs, enc_proj=None, enc_mask=None):
        """
        Args:
            dec_hidden: [B, 1, H]
            enc_outputs: [B, S * W, 2HW]
            enc_proj: [B, S * W, A]
            enc_mask: [B, S * W]
        Returns:
            context: [B, 2HW]
        """
        B, SxW, _ = enc_outputs.size()

        if enc_proj is None:
            # [B, S * W, 2HW] -> [B, S * W, A]
            enc_proj = self.enc_proj(enc_outputs)
            enc_proj = self.enc_norm(enc_proj)

        # [B, 1, H] -> [B, 1, A]
        dec_proj = self.dec_proj(dec_hidden)
        dec_proj = self.dec_norm(dec_proj)

        attn_features = F.gelu(enc_proj + dec_proj)
        attn_features = self.dropout(attn_features)

        # [B, S * W, A] -> [B, S * W, 1] -> [B, S * W]
        attn_scores = self.score_proj(attn_features).squeeze(-1)

        if enc_mask is not None:
            attn_scores = attn_scores.masked_fill(enc_mask == 0, float('-inf'))

        attn_weights = F.softmax(attn_scores, dim=-1)

        # [B, 1, S * W] @ [B, S * W, 2HW] -> [B, 1, 2HW] -> [B, 2HW]
        context = torch.bmm(attn_weights.unsqueeze(1), enc_outputs).squeeze(1)

        return context, attn_weights

In [18]:
class PointerGenerator(nn.Module):
    def __init__(
        self, 
        hidden_size_decoder, 
        hidden_size_word, 
        embedding_dim
    ):
        super().__init__()
        self.ptr_proj = nn.Linear(hidden_size_decoder + 2 * hidden_size_word + embedding_dim, 1)

    def forward(self, context, emb_input, vocab_dist, attn_weights, ext_input_ids, ext_vocab_size):
        """
        Args:
            context: [B, 2HW + H]
            embedded_input: [B, D]
            vocab_dist: [B, V]
            attn_weights: [B, S * W]
            ext_input_ids: [B]
            ext_vocab_size: int
        Returns:
            final_dist: [B, EV]
        """
        B, V = vocab_dist.size()
        device = context.device

        # [B, 2HW + H] cat [B, D] -> [B, 2HW + H + D]
        ptr_input = torch.cat([context, emb_input], dim=-1)

        # [B, 2HW + H + D] -> [B, 1]
        ptr_gate = torch.sigmoid(self.ptr_proj(ptr_input))

        # [B, 1] * [B, V] -> [B, V]
        vocab_dist_scaled = ptr_gate * vocab_dist

        # [B, 1] * [B, S * W] -> [B, S * W]
        attn_dist_scaled = (1 - ptr_gate) * attn_weights

        final_dist = torch.zeros(B, ext_vocab_size, device=device)
        final_dist[:, :vocab_dist.size(-1)] += vocab_dist_scaled
        final_dist.scatter_add_(1, ext_input_ids.long(), attn_dist_scaled)

        return final_dist

In [19]:
class Decoder(nn.Module):
    def __init__(
        self, 
        embedding_dim, 
        embedding_matrix,   
        hidden_size, 
        vocab_size,
        hidden_size_word, 
        residual_configs,
        attn_dim=256, 
        dropout=0.3
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.hidden_size_word=hidden_size_word
        self.embedding_matrix = embedding_matrix

        self.dropout = nn.Dropout(dropout)
        
        # === Attention ===

        # Project encoder word_outputs
        self.attention = Attention(
            hidden_size_word=hidden_size_word,
            hidden_size_decoder=hidden_size,
            attention_dim=attn_dim,
            dropout=dropout
        )

        self.lstm_input_norm = nn.LayerNorm(embedding_dim + 2 * hidden_size_word)

        # === Decoder LSTM ===
        self.dec_lstm = nn.LSTM(
            input_size=self.embedding_dim + 2 * hidden_size_word,
            hidden_size=hidden_size,
            batch_first=True
        )

        self.residuals = nn.Sequential(
            *[
                ResidualBlock(
                    input_dim=hidden_size,
                    hidden_dim=config['hidden_dim'],
                    # dropout=dropout,
                    activation=config['activation']
                )
                for config in residual_configs
            ]
        )

        # === Vocab projection ===
        self.contextual_dec_out_norm = nn.LayerNorm(hidden_size + 2 * hidden_size_word)
        self.emb_proj = nn.Linear(hidden_size + 2 * hidden_size_word, self.embedding_dim)
        
        # === Pointer-generator gate ===
        self.ptr_gen = PointerGenerator(
            hidden_size_decoder=self.hidden_size,
            hidden_size_word=self.hidden_size_word,
            embedding_dim=self.embedding_dim
        )
        
    def forward(
        self, 
        embedded_input, 
        decoder_state, 
        encoder_word_outputs, 
        ext_input_ids, 
        ext_vocab_size, 
        encoder_mask=None,
        debug=False,
        enc_proj=None
    ):
        """
        B: Batch size
        S: Number of sentences
        W: Number of words in a sentences
        D: Embedding dim
        HW: Hidden word size
        HS: Hidden sent size
        H: Hidden decoder size
        A: Attention dim
        V: Vocab size
        EV: Extended vocab size
        
        Args:
            embedded_input: [B, D]
            decoder_state: ([1, B, H], [1, B, H])
            encoder_word_outputs: [B, S * W, 2HW]
            ext_input_ids: [B, S * W]
            ext_vocab_size: int
        Returns:
            final_dist: [B, EV]
            next_decoder_state: ([1, B, H], [1, B, H])
        """
        B, SxW, _ = encoder_word_outputs.size()

        # === Embedding input token ===
        
        # [B, D] -> [B, 1, D]
        embedded_input = embedded_input.unsqueeze(1)

        # === Attention ===
        # [1, B, H] -> [B, 1, H]
        dec_hidden = decoder_state[0].transpose(0, 1)

        # [B, 2HW], [B, S * W]
        context, attn_weights = self.attention(dec_hidden, encoder_word_outputs, enc_proj, encoder_mask)

        # [B, 2HW] -> [B, 1, 2HW]
        context = context.unsqueeze(1)

        # === Decoder LSTM ===
        
        # [B, 1, D] cat [B, 1, 2HW] -> [B, 1, D + 2HW]
        lstm_input = torch.cat([embedded_input, context], dim=-1)
        lstm_input = self.lstm_input_norm(lstm_input)
        lstm_input = self.dropout(lstm_input)

        lstm_out, next_decoder_state = self.dec_lstm(lstm_input, decoder_state)
        lstm_out = self.residuals(lstm_out)

        # === Vocab distribution ===
        
        # [B, 1, H] cat [B, 1, 2HW] -> [B, 1, H + 2HW]
        contextual_dec_out = torch.cat([lstm_out, context], dim=-1)
        contextual_dec_out = self.contextual_dec_out_norm(contextual_dec_out)
        contextual_dec_out = self.dropout(contextual_dec_out)

        # [B, 1, H + 2HW] -> [B, 1, D] -> [B, D]
        emb_out = self.emb_proj(contextual_dec_out).squeeze(1)
        
        # [B, D] x [D, V] -> [B, V]
        vocab_logits = torch.matmul(emb_out, self.embedding_matrix.T)
        vocab_dist = F.softmax(vocab_logits, dim=-1)

        # === Pointer generator ===

        # [B, 1, H + 2HW] -> [B, H + 2HW]
        contextual_dec_out = contextual_dec_out.squeeze(1)

        # [B, 1, D] -> [B, D]
        embedded_input = embedded_input.squeeze(1)
        
        final_dist = self.ptr_gen(contextual_dec_out, embedded_input, vocab_dist, attn_weights, ext_input_ids, ext_vocab_size)

        return final_dist, next_decoder_state

In [20]:
class Model(nn.Module):
    def __init__(
            self,
            vocab: Vocab,
            enc_hidden_size_word,
            enc_word_residual_configs,
            enc_hidden_size_sent,
            enc_sent_residual_configs,
            dec_hidden_size,
            dec_residual_configs,
            dec_attn_dim,
            device: torch.device,
            dropout=0.3
    ):
        super().__init__()
        self.vocab = vocab

        # === Embedding ===
        self.embedding_layer = Word2VecEmbedding(vocab)

        # === Encoder ===
        self.encoder = Encoder(
            embedding_dim=vocab.embedding_dim,
            hidden_size_word=enc_hidden_size_word,
            word_residual_configs=enc_word_residual_configs,
            hidden_size_sent=enc_hidden_size_sent,
            sent_residual_configs=enc_sent_residual_configs,
            dropout=dropout
        )

        # === Adapter ===
        self.dropout = nn.Dropout(dropout)
        self.adapter = nn.Linear(2 * enc_hidden_size_sent, dec_hidden_size)

        # === Decoder ===
        self.decoder = Decoder(
            embedding_dim=vocab.embedding_dim,
            embedding_matrix=self.embedding_layer.embedding.weight,
            hidden_size=dec_hidden_size,
            vocab_size=len(vocab),
            hidden_size_word=enc_hidden_size_word,
            residual_configs=dec_residual_configs,
            attn_dim=dec_attn_dim,
            dropout=dropout
        )
        
        self.to(device)
    
    def forward(
        self, 
        input_ids, 
        ext_input_ids, 
        target_ids, 
        attention_mask, 
        teacher_forcing_ratio=0.5, 
        debug=False
    ):
        """
        B: Batch size
        S: Number of sentences
        W: Number of words in a sentences
        D: Embedding dim
        HW: Hidden word size
        HS: Hidden sent size
        H: Hidden decoder size
        A: Attention dim
        V: Vocab size
        EV: Extended vocab size
        L: Target length
        
        Args:
            input_ids: [B, S, W]
            ext_input_ids: [B, S, W]
            target_ids: [B, L]
            attention_mask: [B, S, W]
        Returns:
            final_dists: [B, L, EV]
        """
        B, T = target_ids.size()
        pad_id = self.vocab.pad_id
        
        target_mask = target_ids != pad_id
        max_len = target_mask.sum(dim=1).max().item()

        # === Encode ===

        # [B, S, W] -> [B, S, W, D]
        enc_embedded_input = self.embedding_layer(input_ids)
        
        # [B, 2HS], [B, S * W, 2HW]
        sent_hidden, word_outputs = self.encoder(enc_embedded_input, attention_mask, debug=debug)

        # Precompute encoder attention projection

        # [B, S * W, 2HW] -> [B, S * W, A]
        enc_proj = self.decoder.attention.enc_proj(word_outputs)
        enc_proj = self.decoder.attention.enc_norm(enc_proj)
        
        # === Init decoder state ===
        
        # [B]
        decoder_input = torch.full((B,), self.vocab.sos_id, dtype=torch.long, device=input_ids.device)
        
        # [B] -> [B, D]
        decoder_embedded_input = self.embedding_layer(decoder_input)
        
        # [B, 2HS] -> [B, H] -> [1, B, H]
        h = self.adapter(sent_hidden).unsqueeze(0)
        h = F.relu(h)
        h = self.dropout(h)
        c = torch.zeros_like(h)
        
        decoder_state = (h, c)
        
        # [B, S, W] -> [B, S * W]
        input_ids = input_ids.view(B, -1)
        ext_input_ids = ext_input_ids.view(B, -1)

        # [B, S, W] -> [B, S * W]
        encoder_mask = attention_mask.view(B, -1)

        ext_vocab_size = ext_input_ids.max().item() + 1

        # === Decode ===
        final_dists = torch.zeros(B, T, ext_vocab_size, device=input_ids.device, dtype=torch.float32)
        for t in range(max_len):

            # [B, EV], ([1, B, H], [1, B, H])
            final_dist, decoder_state = self.decoder(
                decoder_embedded_input, 
                decoder_state, 
                word_outputs,
                ext_input_ids, 
                ext_vocab_size,
                encoder_mask=encoder_mask,
                debug=debug and (t % 20 == 0),
                enc_proj=enc_proj
            )

            final_dists[:, t, :] = final_dist

            if torch.rand(1).item() < teacher_forcing_ratio:
                decoder_input = target_ids[:, t]
                decoder_embedded_input = self.embedding_layer(decoder_input)
            else:
                decoder_input = final_dist.argmax(1)
                oov_mask = decoder_input >= len(self.vocab)
                if oov_mask.any():
                    oov_tokens = decoder_input[oov_mask]
                    token_pos = (ext_input_ids[oov_mask] == oov_tokens.unsqueeze(1)).float().argmax(dim=1)
                    copied_token = input_ids[oov_mask, token_pos]
                    decoder_input[oov_mask] = copied_token
                decoder_embedded_input = self.embedding_layer(decoder_input)

        return final_dists

In [21]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model(
    vocab=vocab,
    enc_hidden_size_word=192,
    enc_word_residual_configs=[
        {'hidden_dim': 768, 'activation': nn.GELU},
    ],
    enc_hidden_size_sent=384,
    enc_sent_residual_configs=[
        {'hidden_dim': 1536, 'activation': nn.GELU},
        # {'hidden_dim': 1536, 'activation': nn.GELU},
    ],
    dec_hidden_size=512,
    dec_residual_configs=[
        {'hidden_dim': 2048, 'activation': nn.GELU},  
        {'hidden_dim': 1024, 'activation': nn.GELU},  
        # {'hidden_dim': 2048, 'activation': nn.ReLU},  
    ],
    dec_attn_dim=256,
    device=device,
    dropout=0.25
)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

In [None]:
# del model
# del optimizer
# del scheduler
torch.cuda.empty_cache()
gc.collect()

In [None]:
def count_params(model):
    encoder_params_count = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
    decoder_params_count = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
    model_params_count = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print(f'Encoder: {encoder_params_count}')
    print(f'Decoder: {decoder_params_count}')
    print(f'Model: {model_params_count}')

count_params(model)

In [None]:
def print_gradients(model, max_elements=500):
    for name, param in model.named_parameters():
        if param.grad is None:
            print(f"{name}: No gradient")
        else:
            grad = param.grad.detach().cpu().numpy()
            num_elements = grad.size
            if num_elements <= max_elements:
                print(f"{name} grad (shape={grad.shape}):\n{grad}\n")
            else:
                print(
                    f"{name} grad (shape={grad.shape}, elements={num_elements}) "
                    f"min={grad.min():.6f}, max={grad.max():.6f}, mean={grad.mean():.6f}"
                )


In [None]:
print_gradients(model, max_elements=100)

## Train

In [22]:
def get_tfr(epoch):
    k = 10
    return k / (k + np.exp(epoch / k))

In [23]:
def train_one_epoch(
    model,
    dataloaders,
    optimizer,
    device,
    teacher_forcing_ratio=0.5,
):
    model.train()
    total_loss = 0
    step = 0
    for dl in dataloaders:
        for batch in tqdm(dl, desc='Train'):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                extend_input_ids = batch['extend_input_ids'].to(device)
                target_ids = batch['target_ids'].to(device)
                extend_target_ids = batch['extend_target_ids'].to(device)
                # oov_lists = batch['oov_lists']

                optimizer.zero_grad()

                outputs = model(
                    input_ids,
                    extend_input_ids,
                    target_ids,
                    attention_mask,
                    # oov_lists,
                    teacher_forcing_ratio,
                    # debug=step%100==0
                )   # [B, T, EV]

                B, T, EV = outputs.shape
                outputs = outputs.view(B * T, EV)
                targets = extend_target_ids.view(-1)

                mask = targets != model.vocab.pad_id
                
                picked_probs = outputs[torch.arange(B * T), targets]
                picked_probs = picked_probs[mask]
                log_probs = - torch.log(picked_probs + 1e-12)
                loss = log_probs.mean()

                loss.backward()
                nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

                optimizer.step()

                total_loss += loss.item()
                step += 1
                # if step % 100 == 0:
                #     print(total_loss / step)
            except Exception as e:
                torch.cuda.empty_cache()
                print(e)
                traceback.print_exc()
        torch.cuda.empty_cache()
        gc.collect()
        print(f'Train loss: {total_loss / step}')

    return total_loss / sum([len(dl) for dl in dataloaders])

def evaluate(model, dataloader, device):
    model.eval()
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)
    total_score = 0
    total_sample = 0
    max_len = 100
    pad_id = model.vocab.pad_id
    sos_id = model.vocab.sos_id
    eos_id = model.vocab.eos_id

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Validate'):
            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                extend_input_ids = batch['extend_input_ids'].to(device)
                target_ids = batch['target_ids'].to(device)
                extend_target_ids = batch['extend_target_ids'].to(device)
                oov_lists = batch['oov_lists']

                B, T = target_ids.size()

                enc_embedded_input = model.embedding_layer(input_ids)
            
                sent_hidden, word_outputs = model.encoder(enc_embedded_input, attention_mask)

                enc_proj = model.decoder.attention.enc_proj(word_outputs)
                enc_proj = model.decoder.attention.enc_norm(enc_proj)
        
                h = model.adapter(sent_hidden).unsqueeze(0)
                h = F.relu(h)
                c = torch.zeros_like(h)
                decoder_state = (h, c)

                input_ids = input_ids.view(B, -1)
                ext_input_ids = extend_input_ids.view(B, -1)

                encoder_mask = attention_mask.view(B, -1)

                ext_vocab_size = ext_input_ids.max().item() + 1

                seqs = torch.full((B, max_len), sos_id, dtype=torch.long, device=device)
                decoder_embedded_input = model.embedding_layer(seqs[:, 0])
                active_mask = torch.full((B,), True, dtype=torch.bool, device=device)
                for t in range(max_len):
                    if not active_mask.any():
                        break
        
                    # [B, EV], ([1, B, H], [1, B, H])
                    final_dist, decoder_state = model.decoder(
                        decoder_embedded_input, 
                        decoder_state, 
                        word_outputs,
                        ext_input_ids, 
                        ext_vocab_size,
                        encoder_mask=encoder_mask,
                        # debug=debug and (t % 20 == 0),
                        enc_proj=enc_proj
                    )
        
                    decoder_input = final_dist.argmax(1)
                    seqs[active_mask, t] = decoder_input

                    next_active = decoder_input != eos_id
                    temp_active_mask = active_mask.clone()
                    temp_active_mask[active_mask] = next_active
                    active_mask = temp_active_mask
                    decoder_input = decoder_input[next_active]
                    decoder_state = decoder_state[0][:, next_active], decoder_state[1][:, next_active]
                    word_outputs = word_outputs[next_active]
                    ext_input_ids = ext_input_ids[next_active]
                    encoder_mask = encoder_mask[next_active]
                    enc_proj = enc_proj[next_active]
                                        
                    oov_mask = decoder_input >= len(model.vocab)
                    if oov_mask.any():
                        oov_tokens = decoder_input[oov_mask]
                        token_pos = (ext_input_ids[oov_mask] == oov_tokens.unsqueeze(1)).float().argmax(dim=1)
                        copied_token = input_ids[active_mask][oov_mask, token_pos]
                        decoder_input[oov_mask] = copied_token
                    decoder_embedded_input = model.embedding_layer(decoder_input)
    
                seqs = seqs.tolist()
                for pred_seq, tgt_seq, oov_list in zip(seqs, extend_target_ids.tolist(), oov_lists):
                    pred_text = model.vocab.decode(pred_seq, oov_list)
                    tgt_text = model.vocab.decode(tgt_seq, oov_list)
                    rougeL = scorer.score(tgt_text, pred_text)['rougeL'].fmeasure
                    total_score += rougeL
                total_sample += B

            except Exception as e:
                torch.cuda.empty_cache()
                print(e)
                traceback.print_exc()

    rougeL_F1 = total_score / total_sample
    return rougeL_F1

def train_model(
    model, 
    train_loaders, 
    valid_loader, 
    optimizer,
    # scheduler, 
    device, 
    num_epochs,
    checkpoint_path=None
):
    history=[]
    best_score=float('-inf')
    best_epoch = 0
    
    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        history = checkpoint['history']
        for epoch_summary in history:
            print(f'Epoch {epoch_summary[0]}:')
            print(f'\tTrain Loss: {epoch_summary[1]}')
            print(f'\tValid RougeL F1 Score: {epoch_summary[2]}')
            print('-' * 50)
        best_entry = max(history, key=lambda x: x[2])
        best_epoch, _, best_score = best_entry
    
    start_epoch = len(history) + 1
    counter = start_epoch - best_epoch - 1
    for epoch in range(start_epoch, start_epoch + num_epochs):
        print(f"Epoch {epoch}\n")
        tfr = get_tfr(epoch)
        train_loss = train_one_epoch(model, train_loaders, optimizer, device, tfr)
        torch.cuda.empty_cache()
        gc.collect()
        
        valid_score = evaluate(model, valid_loader, device)
        
        print(f"\nTrain Loss: {train_loss:.4f}")
        print(f"Valid RougeL F1 Score: {valid_score:.4f}")
        
        torch.cuda.empty_cache()
        gc.collect()

        history.append([epoch, train_loss, valid_score])
        train_loaders = create_dataloaders(train_dataset, n_chunks=4, batch_size=16)

        print('-' * 50)
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'history': history
        },f'last_model.pt')
        if valid_score > best_score:
            best_score = valid_score
            best_epoch = epoch
            counter = 0
            lr = optimizer.param_groups[0]['lr']
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'history': history
            },f'best_model.pt')
        else:
            counter += 1
            if counter >= 5:
                print(f"\nEarly stop")
                print(f"Best model with rougeL F1 score {best_score:.4f}")
                breakđây
        torch.cuda.empty_cache()
        gc.collect()
    return history

In [24]:
history = train_model(
    model=model,
    train_loaders=train_loaders,
    valid_loader=valid_loader,
    optimizer=optimizer,   
    device=device,
    num_epochs=4,
    checkpoint_path='/kaggle/input/temp_sum_cl1/keras/default/4/last_model (2).pt'
)

Epoch 1:
	Train Loss: 4.023476720182794
	Valid RougeL F1 Score: 0.3212430272957445
--------------------------------------------------
Epoch 2:
	Train Loss: 3.485653291535287
	Valid RougeL F1 Score: 0.3349350829569631
--------------------------------------------------
Epoch 3:
	Train Loss: 3.3265960655958904
	Valid RougeL F1 Score: 0.3423784768101486
--------------------------------------------------
Epoch 4:
	Train Loss: 3.2439604321588105
	Valid RougeL F1 Score: 0.3453295379966757
--------------------------------------------------
Epoch 5:
	Train Loss: 3.1962489743700924
	Valid RougeL F1 Score: 0.350088586728179
--------------------------------------------------
Epoch 6:
	Train Loss: 3.167288500675912
	Valid RougeL F1 Score: 0.34979728399091703
--------------------------------------------------
Epoch 7:
	Train Loss: 3.151025692734897
	Valid RougeL F1 Score: 0.3522262259552
--------------------------------------------------
Epoch 8:
	Train Loss: 3.1489390870827036
	Valid RougeL F1 Scor

Train: 100%|██████████| 4746/4746 [37:23<00:00,  2.12it/s]


Train loss: 3.365365261123294


Train: 100%|██████████| 4746/4746 [37:17<00:00,  2.12it/s]


Train loss: 3.3697469319616045


Train: 100%|██████████| 4746/4746 [37:22<00:00,  2.12it/s]


Train loss: 3.373250090015924


Train: 100%|██████████| 4746/4746 [37:22<00:00,  2.12it/s]


Train loss: 3.3774940278074985


Validate: 100%|██████████| 524/524 [06:22<00:00,  1.37it/s]



Train Loss: 3.3775
Valid RougeL F1 Score: 0.3563
--------------------------------------------------
Epoch 18



Train: 100%|██████████| 4746/4746 [37:20<00:00,  2.12it/s]


Train loss: 3.4085345110007093


Train: 100%|██████████| 4746/4746 [37:24<00:00,  2.11it/s]


Train loss: 3.4181075974878157


Train: 100%|██████████| 4746/4746 [37:28<00:00,  2.11it/s]


Train loss: 3.4198784989693776


Train: 100%|██████████| 4746/4746 [37:26<00:00,  2.11it/s]


Train loss: 3.422079115308542


Validate: 100%|██████████| 524/524 [06:19<00:00,  1.38it/s]



Train Loss: 3.4221
Valid RougeL F1 Score: 0.3590
--------------------------------------------------
Epoch 19



Train: 100%|██████████| 4746/4746 [37:27<00:00,  2.11it/s]


Train loss: 3.4608597342745946


Train: 100%|██████████| 4746/4746 [37:15<00:00,  2.12it/s]


Train loss: 3.464447443451602


Train: 100%|██████████| 4746/4746 [37:18<00:00,  2.12it/s]


Train loss: 3.4645791653000013


Train: 100%|██████████| 4746/4746 [37:19<00:00,  2.12it/s]


Train loss: 3.469405723689379


Validate: 100%|██████████| 524/524 [06:16<00:00,  1.39it/s]



Train Loss: 3.4694
Valid RougeL F1 Score: 0.3565
--------------------------------------------------
Epoch 20



Train: 100%|██████████| 4746/4746 [37:24<00:00,  2.11it/s]


Train loss: 3.5051844111987926


Train: 100%|██████████| 4746/4746 [37:21<00:00,  2.12it/s]


Train loss: 3.5117353535230986


Train: 100%|██████████| 4746/4746 [37:21<00:00,  2.12it/s]


Train loss: 3.5157973740204347


Train: 100%|██████████| 4746/4746 [37:19<00:00,  2.12it/s]


Train loss: 3.5197370967631896


Validate: 100%|██████████| 524/524 [06:19<00:00,  1.38it/s]



Train Loss: 3.5197
Valid RougeL F1 Score: 0.3526
--------------------------------------------------


## Test

In [25]:
def generate(model, input_ids, ext_input_ids, attention_mask, oov_lists, max_len=100, beam_size=5):
    """
    input_ids: [B, S, W]
    extend_input_ids: [B, S, W]
    attention_mask: [B, S, W]
    oov_lists: [B, []]
    """

    model.eval()
    B, S, W = input_ids.shape
    sos_id = model.vocab.sos_id
    eos_id = model.vocab.eos_id
    vocab_size = len(model.vocab)
    device = input_ids.device
    
    with torch.no_grad():

        enc_embedded_input = model.embedding_layer(input_ids)
        
        sent_hidden, word_outputs = model.encoder(enc_embedded_input, attention_mask)
        
        enc_proj = model.decoder.attention.enc_proj(word_outputs)
        enc_proj = model.decoder.attention.enc_norm(enc_proj)
                    
        h = model.adapter(sent_hidden).unsqueeze(0)
        h = F.relu(h)
        h = h.repeat_interleave(beam_size, dim=1)
        c = torch.zeros_like(h)
        decoder_state = (h, c)
        
        input_ids = input_ids.view(B, -1)
        ext_input_ids = ext_input_ids.view(B, -1)
        encoder_mask = attention_mask.view(B, -1)
        ext_vocab_size = ext_input_ids.max().item() + 1

        word_outputs  = word_outputs.repeat_interleave(beam_size, dim=0)   # [B*beam_size, S*W, 2HW]
        enc_proj = enc_proj.repeat_interleave(beam_size, dim=0)       # [B*beam_size, S*W, A]
        encoder_mask  = encoder_mask.repeat_interleave(beam_size, dim=0)   # [B*beam_size, S*W]
        ext_input_ids = ext_input_ids.repeat_interleave(beam_size, dim=0)  # [B*beam_size, S*W]
        input_ids = input_ids.repeat_interleave(beam_size, dim=0)
        
        # seqs: [B * beam-size, n]
        # log_probs: [B * beam-size]
        seqs = torch.full((B * beam_size, 1), sos_id, dtype=torch.long, device=input_ids.device)
        log_probs = torch.zeros(B * beam_size, device=input_ids.device)
        decoder_embedded_input = model.embedding_layer(seqs[:, -1])
        # List({'seq', 'log_prob'})
        finished = [[] for _ in range(B)]

        for _ in range(max_len):

            # [B * beam_size, EV], ([1, B * beam-size, H], [1, B * beam-size, H])
            final_dist, decoder_state = model.decoder(
                decoder_embedded_input, 
                decoder_state, 
                word_outputs,
                ext_input_ids, 
                ext_vocab_size,
                encoder_mask=encoder_mask,
                enc_proj=enc_proj
            )

            log_prob = final_dist.log()

            # [B, beam-size, EV]
            log_prob = log_prob.view(B, beam_size, -1)

            # [B, beam-size, EV]
            total_log_probs = log_probs.view(B, beam_size, 1) + log_prob
            
            # [B, beam-size * EV]
            total_log_probs = total_log_probs.view(B, -1)

            # [B, beam-size]
            topk_log_probs, topk_ids = total_log_probs.topk(beam_size, dim=-1)
            beam_indices = topk_ids // ext_vocab_size
            token_indices = topk_ids % ext_vocab_size

            flat_beam_indices = (beam_indices + (torch.arange(B, device=device) * beam_size).unsqueeze(1)).view(-1)

            old_seqs = seqs[flat_beam_indices]
            new_tokens = token_indices.view(-1, 1)
            seqs = torch.cat([old_seqs, new_tokens], dim=1)

            decoder_input = seqs[:, -1]
            oov_mask = decoder_input >= vocab_size
            if oov_mask.any():
                oov_tokens = decoder_input[oov_mask]
                token_pos = (ext_input_ids[oov_mask] == oov_tokens.unsqueeze(1)).float().argmax(dim=1)
                copied_token = input_ids[oov_mask, token_pos]
                decoder_input[oov_mask] = copied_token
            decoder_embedded_input = model.embedding_layer(decoder_input)

            new_h = decoder_state[0].index_select(1, flat_beam_indices)
            new_c = decoder_state[1].index_select(1, flat_beam_indices)
            decoder_state = (new_h, new_c)

            log_probs = topk_log_probs.view(-1)
            eos_mask = (token_indices == eos_id)
            if eos_mask.any():
                # [[b, k], ...]
                eos_indices = eos_mask.nonzero(as_tuple=False)
                flat_eos_indices = eos_indices[:, 0] * beam_size + eos_indices[:, 1]

                eos_scores = log_probs[flat_eos_indices]
                eos_seqs = seqs[flat_eos_indices]

                for (b, k), score, seq in zip(eos_indices.tolist(), eos_scores.tolist(), eos_seqs):
                    finished[b].append({
                        'seq': seq.tolist(),
                        'log_prob': score / (((5 + len(seq)) / 6) ** 0.6)
                    })
                log_probs[flat_eos_indices] = -1e9
            all_enough = all(len(finished[b]) >= beam_size for b in range(B))
            if all_enough:
                break
        
        results = []
        for b in range(B):
            if finished[b]:
                best_seq = max(finished[b], key=lambda x: x['log_prob'])['seq']
            else:
                best_seq = seqs[b * beam_size].tolist()
            best_seq = best_seq[1:]
            if eos_id in best_seq:
                id = best_seq.index(eos_id)
                best_seq = best_seq[:id]
            results.append(best_seq)
        return results

In [26]:
def test(model, test_loader):
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'], use_stemmer=False)

    all_metrics = {
        'rouge1': {'precision': [], 'recall': [], 'f1': []},
        'rouge2': {'precision': [], 'recall': [], 'f1': []},
        'rougeL': {'precision': [], 'recall': [], 'f1': []},
    }
    
    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader, desc='Evaluate on test set'):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            extend_input_ids = batch['extend_input_ids'].to(device)
            target_ids = batch['target_ids'].to(device)
            extend_target_ids = batch['extend_target_ids'].to(device)
            oov_lists = batch['oov_lists']

            outputs = generate(model, input_ids, extend_input_ids, attention_mask, oov_lists)
            for pred, tgt, oov in zip(outputs, extend_target_ids.tolist(), oov_lists):
                pred_seq = model.vocab.decode(pred, oov)
                tgt_seq = model.vocab.decode(tgt, oov)

                scores = scorer.score(tgt_seq, pred_seq)
                for key in all_metrics:
                    all_metrics[key]['precision'].append(scores[key].precision)
                    all_metrics[key]['recall'].append(scores[key].recall)
                    all_metrics[key]['f1'].append(scores[key].fmeasure)
        
    avg_metrics = {}
    for key in all_metrics:
        avg_metrics[key] = {
            'precision': sum(all_metrics[key]['precision']) / len(all_metrics[key]['precision']),
            'recall': sum(all_metrics[key]['recall']) / len(all_metrics[key]['recall']),
            'f1': sum(all_metrics[key]['f1']) / len(all_metrics[key]['f1']),
        }

    return avg_metrics

In [27]:
test_df = pd.read_csv('/kaggle/input/text-summarization/test.csv')
test_dataset = SummarizationDataset(
    df=test_df, 
    max_doc_length= max_doc_length, 
    max_sentence_length=max_sent_length, 
    max_summary_length=max_summary_length, 
    vocab=vocab
)

In [28]:
test_loader = DataLoader(test_dataset, batch_size=64, pin_memory=True, num_workers=4, collate_fn=collate_fn)

In [29]:
checkpoint = torch.load('/kaggle/working/best_model.pt')

model.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

In [30]:
scores = test(model, test_loader)
for k, v in scores.items():
    print(f"{k}: Precision={v['precision']:.4f}, Recall={v['recall']:.4f}, F1={v['f1']:.4f}")

Evaluate on test set: 100%|██████████| 1057/1057 [22:59<00:00,  1.31s/it]

rouge1: Precision=0.6269, Recall=0.5095, F1=0.5474
rouge2: Precision=0.2679, Recall=0.2185, F1=0.2342
rougeL: Precision=0.4113, Recall=0.3331, F1=0.3581



