In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import transformers
from transformers import BartTokenizerFast

from model import BartSummaryModelV2
from inference import get_top_k_sentences, extract_sentences
from dataset import SummaryDataset
from utils import collate_fn
from truncate import (
    batch_truncate, 
    batch_truncate_with_eq, 
    batch_truncate_with_len, 
    gather_lengths,
)

In [2]:
MODEL_NAME = "gogamza/kobart-summarization"

tokenizer = BartTokenizerFast.from_pretrained(MODEL_NAME)
model = BartSummaryModelV2.from_pretrained(MODEL_NAME)

model = model.cuda()

Some weights of BartSummaryModelV2 were not initialized from the model checkpoint at gogamza/kobart-summarization and are newly initialized: ['classification_head.out_proj.bias', 'classification_head.dense.bias', 'classification_head.out_proj.weight', 'classification_head.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
dataset = SummaryDataset("/opt/ml/dataset/Training/train.parquet", tokenizer, is_train=True, truncate=False)
dataloader = DataLoader(dataset, batch_size=4, shuffle=False, collate_fn=lambda x: collate_fn(x, tokenizer.pad_token_id))

In [4]:
batch = next(iter(dataloader))

In [8]:
input_ids = batch["input_ids"].cuda()
attention_mask = batch["attention_mask"].cuda()
answers = batch["answers"].cuda()

In [16]:
MAX_LEN = 256
model.eval()

for idx, batch in enumerate(dataloader):

    input_ids = batch["input_ids"]
    attention_mask = batch["attention_mask"]
    answers = batch["answers"]

    print(f"Starting step {idx}")
    while input_ids is not None:

        _input_ids, input_ids, mapping = batch_truncate_with_eq(
            input_ids, 
            MAX_LEN, 
            sep=tokenizer.eos_token_id, 
            padding_value=tokenizer.pad_token_id, 
            eos_value=tokenizer.eos_token_id, 
            return_mapping=True
        )

        lengths = gather_lengths(_input_ids, tokenizer.pad_token_id)

        _attention_mask = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor([1.0] * l) for l in lengths], 
            batch_first=True, 
            padding_value=0.0
        ).to(_input_ids.device)

        _input_ids = _input_ids.cuda()
        _attention_mask = _attention_mask.cuda()
        print(f"_input_ids.shape: {_input_ids.shape}, _attention_mask.shape: {_attention_mask.shape}")

        ext_out = model.classify(
            input_ids=_input_ids, 
            attention_mask=_attention_mask,
        )
        print([f"{k}: {v.shape}" for k, v in ext_out.items() if isinstance(v, torch.Tensor)])

        top_ext_ids = get_top_k_sentences(
            logits=ext_out.logits.clone().detach().cpu(),
            eos_positions=batch["eos_positions"],
            k=3,
        )
        _ext_batch = extract_sentences(batch["input_ids"], batch["eos_positions"], top_ext_ids, tokenizer)
        print(f"_ext_batch.shape: {_ext_batch['input_ids'].shape}")

        input_ids = torch.cat([_ext_batch["input_ids"], input_ids], 1)
        print(f"new _input_ids.shape: {_input_ids.shape}")
        print()

    if idx > 5:
        break

Starting step 0
_input_ids.shape: torch.Size([4, 242]), _attention_mask.shape: torch.Size([4, 242])
['logits: torch.Size([4, 9])', 'encoder_last_hidden_state: torch.Size([4, 242, 768])']
_ext_batch.shape: torch.Size([4, 120])
new _input_ids.shape: torch.Size([4, 242])

_input_ids.shape: torch.Size([4, 248]), _attention_mask.shape: torch.Size([4, 248])
['logits: torch.Size([4, 8])', 'encoder_last_hidden_state: torch.Size([4, 248, 768])']
_ext_batch.shape: torch.Size([4, 94])
new _input_ids.shape: torch.Size([4, 248])

_input_ids.shape: torch.Size([4, 225]), _attention_mask.shape: torch.Size([4, 202])


ValueError: Attention mask should be of size (4, 1, 225, 225), but is torch.Size([4, 1, 202, 202])