In [38]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import pandas as pd
import math
from transformers import PreTrainedTokenizerFast

In [39]:
class AbstractiveSummarizationDataset(Dataset):
    """Dataset that serves one training string per CSV row.

    Every item is the row's first value, a ``<sep>`` marker, the row's second
    value and a trailing ``<eos>`` marker, joined into a single string.
    """

    def __init__(self, csv_path: str):
        """Read the whole CSV at ``csv_path`` into memory."""
        print("Loading CSV file...")
        self.df = pd.read_csv(csv_path)

    def __len__(self):
        """Number of rows in the loaded frame."""
        return len(self.df.index)

    def __getitem__(self, idx):
        """Return the ``idx``-th row rendered as ``document <sep> summary <eos>``."""
        row = self.df.iloc[idx]
        # Unpacking requires the row to hold exactly two values.
        document, summary = row
        text = document + " <sep> " + summary
        return text + " <eos>"
    

class TokenizeCollate:
    """Collate callable: tokenizes a batch of strings and returns shifted id tensors."""

    def __init__(self, tokenizer_obj):
        """Keep a reference to the tokenizer callable used for every batch."""
        self.tokenizer = tokenizer_obj

    def __call__(self, x):
        """Tokenize batch ``x`` and return ``(inputs, targets)``.

        ``inputs`` is every position except the last; ``targets`` is every
        position except the first, i.e. the same ids shifted left by one.
        """
        # Tokenize the whole batch in one call, with padding enabled.
        encoded = self.tokenizer(x, return_tensors="pt", padding=True)
        token_ids = encoded.input_ids
        inputs = token_ids[:, :-1]
        targets = token_ids[:, 1:]
        return inputs, targets

In [40]:
class PositionalEncoder(nn.Module):
    """Adds fixed sinusoidal position encodings to embeddings, then applies dropout.

    Args:
        d_model: size of the embedding dimension (even or odd).
        max_seq_len: number of positions for which encodings are precomputed.
        p: dropout probability applied to the sum of input and encoding.
    """

    def __init__(self, d_model, max_seq_len, p=0.1):
        super(PositionalEncoder, self).__init__()
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_seq_len, ceil(d_model / 2)]
        pe = torch.zeros(1, max_seq_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd (cos) column than even
        # (sin) columns, so trim the angles to the number of cos slots.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Buffer: follows .to()/.cuda() and is saved in state_dict, but is not trained.
        self.register_buffer('pe', pe)
        self.dropout = nn.Dropout(p=p)

    # x has shape [batch, seq_len, embed_dim]
    def forward(self, x):
        """Return ``dropout(x + pe[:, :seq_len])`` for ``x`` of shape [batch, seq_len, embed_dim].

        Raises:
            ValueError: if ``seq_len`` exceeds the number of precomputed positions.
        """
        _, seq_len, _ = x.shape
        max_seq_len = self.pe.shape[1]
        if seq_len > max_seq_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_seq_len {max_seq_len}"
            )
        return self.dropout(x + self.pe[:, :seq_len, :])


class QKVLayer(nn.Module):
    """Linear projections producing queries, keys and values.

    Queries come from one linear layer; keys and values come from a single
    fused linear layer whose output is split in half along the last dimension.
    """

    def __init__(self, input_dim, output_dim, num_heads, bias=True):
        super().__init__()
        # Each head must receive an equal slice of the projected dimension.
        assert output_dim % num_heads == 0
        self.num_heads = num_heads
        self.q_layer = nn.Linear(input_dim, output_dim, bias=bias)
        self.kv_layer = nn.Linear(input_dim, output_dim * 2, bias=bias)

    def reshape(self, x):
        """Split the last dim into heads: [batch, seq, dim] -> [batch, heads, seq, dim // heads]."""
        batch_size, seq_len = x.shape[0], x.shape[1]
        per_head = x.view(batch_size, seq_len, self.num_heads, -1)
        return per_head.permute(0, 2, 1, 3)

    def forward(self, x1, x2=None):
        """
        x1: [batch_size, seq1_len, embed_dim], source of the queries
        x2: [batch_size, seq2_len, embed_dim], source of the keys and values
        (when x2 is None, keys and values are computed from x1: self attention)

        returns (q, k, v), each with output_dim as the last dimension
        """
        queries = self.q_layer(x1)
        kv_source = x1 if x2 is None else x2
        keys, values = self.kv_layer(kv_source).chunk(2, dim=-1)
        return queries, keys, values
        

def scaled_dot_product_attention(q, k, v, mask):
    """
    Compute softmax(q @ k^T / sqrt(head_dim)) @ v.

    q: [..., seq1_len, head_dim]
    k: [..., seq2_len, head_dim]
    v: [..., seq2_len, value_dim]
    mask: one of
        - None / False: no masking
        - True (or any truthy non-tensor): lower-triangular mask, so query
          row i may only attend to key columns j <= i
        - boolean tensor broadcastable to [..., seq1_len, seq2_len], where
          True marks positions that may be attended to
    (seq1_len = seq2_len for plain self attention)

    returns: [..., seq1_len, value_dim]

    NOTE(review): the lower-triangular mask is expressed in key *indices*. If
    the keys were shortened or extended relative to the queries (compression,
    cached context), index j <= i may not match the intended causal
    constraint — confirm against the caller.
    """
    scores = q.matmul(k.transpose(-1, -2)) / math.sqrt(q.shape[-1])

    keep = None
    if isinstance(mask, torch.Tensor):
        keep = mask.to(device=scores.device, dtype=torch.bool)
    elif mask:
        # Build the mask on the scores' device so the function also works on GPU.
        keep = torch.ones(
            scores.shape[-2], scores.shape[-1], dtype=torch.bool, device=scores.device
        ).tril()

    if keep is not None:
        scores = scores.masked_fill(~keep, float("-inf"))

    attn_weights = scores.softmax(dim=-1)
    return attn_weights.matmul(v)


class CompressKV(nn.Module):
    """Shortens a sequence by ``compress_factor`` with a strided 1-D convolution.

    Args:
        input_shape: number of channels (embedding size) of the input; the
            convolution keeps the same number of output channels.
        compress_factor: kernel size and stride of the convolution.
    """

    def __init__(self, input_shape, compress_factor):
        super(CompressKV, self).__init__()
        self.compress_factor = compress_factor
        self.conv = nn.Conv1d(
            in_channels=input_shape,
            out_channels=input_shape,
            kernel_size=compress_factor,
            stride=compress_factor
        )

    def pad(self, x):
        """Zero-pad the last dimension of ``x`` up to a multiple of ``compress_factor``.

        The padding is created by ``F.pad`` so it shares ``x``'s device and
        dtype (a freshly allocated ``torch.zeros`` would live on the CPU in
        the default dtype).
        """
        pad_amt = -x.shape[-1] % self.compress_factor
        if pad_amt == 0:
            return x
        # (0, pad_amt): nothing on the left, pad_amt zeros on the right of the last dim.
        return F.pad(x, (0, pad_amt))

    def forward(self, x):
        """
        x: [batch_size, seq_len, embed_dim]

        returns: [batch_size, ceil(seq_len / compress_factor), embed_dim]
        """
        # Conv1d expects channels before length: [batch, embed_dim, seq_len].
        x = x.permute(0, 2, 1)
        x = self.pad(x)
        return self.conv(x).permute(0, 2, 1)


class CompressedMultiHeadAttention(nn.Module):
    """Multi-head attention whose keys and values are shortened by strided convolutions.

    Args:
        input_dim: embedding size of the input.
        output_dim: total size of the projected q/k/v (split across heads).
        num_heads: number of attention heads.
        compress_factor: factor by which the key/value sequences are shortened.
    """

    def __init__(self, input_dim, output_dim, num_heads, compress_factor):
        super(CompressedMultiHeadAttention, self).__init__()
        self.qkv_layer = QKVLayer(input_dim, output_dim, num_heads)
        self.compress_k = CompressKV(output_dim, compress_factor)
        self.compress_v = CompressKV(output_dim, compress_factor)
        self.out_proj = nn.Linear(output_dim, output_dim)

    def reshape(self, x):
        """Merge heads back: [batch, heads, seq, head_dim] -> [batch, seq, heads * head_dim]."""
        x = x.permute(0, 2, 1, 3)
        return x.reshape(x.shape[0], x.shape[1], -1)

    def forward(self, x, mask, prev_kv=None):
        """
        x: [batch_size, seq_len, embed_dim]
        mask: forwarded unchanged to scaled_dot_product_attention
        prev_kv: optional (k, v) pair, each [batch_size, context_len, output_dim],
            holding uncompressed keys/values that are prepended to the ones
            computed from x before compression

        returns:
            output: [batch_size, seq_len, output_dim], attention output after
                the output projection
            kv: [2, batch_size, seq_len, output_dim], the uncompressed keys and
                values computed from x only (prev_kv is not included)

        NOTE(review): because the returned kv excludes prev_kv, a caller doing
        incremental decoding presumably has to accumulate the cache itself —
        confirm against the generation loop.
        """
        q, k_, v_ = self.qkv_layer(x)

        if prev_kv is None:
            k_all, v_all = k_, v_
        else:
            # Prepend the stored context along the sequence dimension.
            k_all = torch.cat([prev_kv[0], k_], dim=1)
            v_all = torch.cat([prev_kv[1], v_], dim=1)

        k, v = self.compress_k(k_all), self.compress_v(v_all)

        q, k, v = self.qkv_layer.reshape(q), self.qkv_layer.reshape(k), self.qkv_layer.reshape(v)
        attn_outputs = scaled_dot_product_attention(q, k, v, mask)
        # Mix the concatenated heads; this projection was registered but never applied.
        output = self.out_proj(self.reshape(attn_outputs))
        return output, torch.stack([k_, v_], dim=0)
    

class FeedForward(nn.Module):
    """Two-layer MLP: expand to 4x the input width, ReLU, project back."""

    def __init__(self, input_dim):
        super().__init__()
        hidden_dim = input_dim * 4
        expand = nn.Linear(input_dim, hidden_dim)
        activation = nn.ReLU()
        contract = nn.Linear(hidden_dim, input_dim)
        self.layer = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``; the shape is preserved."""
        return self.layer(x)


class DecoderLayer(nn.Module):
    """One decoder block: compressed self-attention and a feed-forward net,
    each followed by dropout, a residual connection and layer normalisation.

    Args:
        input_dim: embedding size of the input.
        output_dim: embedding size produced by the attention sub-layer; the
            residual additions require it to match the input's last dimension.
        num_heads: number of attention heads.
        compress_factor: key/value compression factor of the attention.
        dropout_p: dropout probability applied to each sub-layer output.
    """

    def __init__(self, input_dim, output_dim, num_heads, compress_factor, dropout_p):
        super(DecoderLayer, self).__init__()
        self.dropout_p = dropout_p
        self.masked_compressed_mha = CompressedMultiHeadAttention(input_dim, output_dim, num_heads, compress_factor)
        self.feed_forward = FeedForward(output_dim)
        self.ln1 = nn.LayerNorm(output_dim)
        self.ln2 = nn.LayerNorm(output_dim)

    def _dropout(self, x):
        # F.dropout defaults to training=True; tie it to the module's mode so
        # that model.eval() really disables dropout.
        return F.dropout(x, p=self.dropout_p, training=self.training)

    def forward(self, x, mask, prev_kv=None):
        """Run the block on ``x``.

        Args:
            x: input passed to the attention sub-layer.
            mask: forwarded to the attention sub-layer.
            prev_kv: optional stored keys/values forwarded to the attention sub-layer.

        Returns:
            (x, kv): the block output and the keys/values returned by the
            attention sub-layer.
        """
        skip_x = x
        x, kv = self.masked_compressed_mha(x, mask=mask, prev_kv=prev_kv)
        x = self.ln1(self._dropout(x) + skip_x)
        skip_x = x
        x = self.feed_forward(x)
        x = self.ln2(self._dropout(x) + skip_x)
        return x, kv
    

class DecoderOnlyModel(nn.Module):
    """Decoder-only language model built from a stack of DecoderLayer blocks.

    Args:
        vocab_size: number of rows in the embedding table and of output logits.
        d_model: embedding / hidden size.
        num_heads: attention heads per layer.
        num_layers: number of decoder layers.
        compress_factor: key/value compression factor used by every layer.
        dropout_p: dropout probability used by the positional encoder and
            by every decoder layer.
        max_seq_len: number of positions precomputed by the positional encoder.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers, compress_factor, dropout_p, max_seq_len=10000):
        super(DecoderOnlyModel, self).__init__()
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model)
        # Use the configured dropout here too instead of the encoder's default.
        self.pos_enc = PositionalEncoder(d_model, max_seq_len, p=dropout_p)
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, d_model, num_heads, compress_factor, dropout_p) for _ in range(num_layers)
        ])
        self.out_proj = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.out_proj.weight = self.embed.weight

    def forward(self, x, prev_kvs=None):
        """
        inputs:
            x: [batch_size, seq_len] shaped tensor of labels upto vocab_size
            prev_kvs: optional sequence of length len(self.layers); element i is
                passed to layer i as prev_kv and is indexed there as (k, v),
                each [batch_size, context_len, d_model]

        outputs:
            output1: [batch_size, seq_len, vocab_size]
                logits vector
            output2: [num_layers, 2, batch_size, seq_len, d_model]
                k and v embeddings returned by every layer for the current x
        """
        # Embeddings are scaled by sqrt(d_model) before the position encoding is added.
        x = self.pos_enc(self.embed(x) * math.sqrt(self.d_model))

        if prev_kvs is None:
            prev_kvs = [None] * len(self.layers)
        assert len(prev_kvs) == len(self.layers)

        kv_out = []
        for layer, prev_kv in zip(self.layers, prev_kvs):
            # Every layer is run with causal masking enabled.
            x, kv = layer(x, mask=True, prev_kv=prev_kv)
            kv_out.append(kv)

        return self.out_proj(x), torch.stack(kv_out, dim=0)

In [None]:
# Hyper-parameters shared by the data loader and the model below.
BATCH_SIZE = 32  # examples per DataLoader batch
D_MODEL = 512  # embedding / hidden size; must be divisible by NUM_HEADS
NUM_HEADS = 8  # attention heads per decoder layer
NUM_LAYERS = 6  # number of stacked decoder layers
COMPRESS_FACTOR = 3  # kernel size and stride of the key/value compression conv
DROPOUT_P = 0.1  # dropout probability passed to the model

In [41]:
# Raw text examples, one "document <sep> summary <eos>" string per CSV row.
dataset = AbstractiveSummarizationDataset(csv_path="xsum.csv")

# Tokenizer loaded from the local "tokenizer" directory.
tokenizer = PreTrainedTokenizerFast.from_pretrained("tokenizer")

# Batches are tokenized on the fly by the collate callable.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=TokenizeCollate(tokenizer),
)

Loading CSV file...


In [42]:
# Build the decoder-only model from the hyper-parameters defined above.
# NOTE(review): `tokenizer.vocab_size` may not count tokens added on top of the
# base vocabulary; if <sep>/<eos>/padding tokens were added that way, their ids
# could fall outside the embedding table and `len(tokenizer)` would be the safe
# size — confirm against how the tokenizer was built.
dec = DecoderOnlyModel(
    vocab_size=tokenizer.vocab_size,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    compress_factor=COMPRESS_FACTOR,
    dropout_p=DROPOUT_P,
)