In [1]:
from typing import Optional

import torch
import torch.nn
from torch.utils.data import DataLoader

import transformers
from transformers import BartTokenizerFast, BartForConditionalGeneration

from model import BartSummaryModelV2, SentenceClassifierOutput
from dataset import SummaryDataset
from inference import extract_sentences
from utils import collate_fn

In [2]:
# Checkpoint identifier shared by the tokenizer and the model below.
MODEL_NAME = "gogamza/kobart-summarization"
tokenizer = BartTokenizerFast.from_pretrained(MODEL_NAME)
# NOTE: loading this checkpoint into BartSummaryModelV2 emits a warning that the
# classification_head weights (dense / out_proj) are newly initialized, i.e. they
# are not part of the checkpoint and need training before they are useful.
model = BartSummaryModelV2.from_pretrained(MODEL_NAME)

Some weights of BartSummaryModelV2 were not initialized from the model checkpoint at gogamza/kobart-summarization and are newly initialized: ['classification_head.out_proj.weight', 'classification_head.dense.bias', 'classification_head.out_proj.bias', 'classification_head.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Absolute path of the training parquet file; adjust it for your environment.
TRAIN_PARQUET_PATH = "/opt/ml/dataset/Training/train.parquet"

# Build the training dataset from the parquet file using the tokenizer loaded above.
dataset = SummaryDataset(TRAIN_PARQUET_PATH, tokenizer, is_train=True)

In [4]:
def _collate_batch(samples):
    """Forward a list of samples to the project `collate_fn`.

    Passes the tokenizer's pad token id and disables length-based sorting.
    """
    return collate_fn(samples, tokenizer.pad_token_id, sort_by_length=False)


# Batches of 4 samples, drawn without shuffling.
dataloader = DataLoader(
    dataset,
    batch_size=4,
    shuffle=False,
    collate_fn=_collate_batch,
)

In [5]:
# Fetch the first batch from the dataloader to inspect what the collate step produces.
batch = next(iter(dataloader))
# Bare expression so the notebook displays the batch contents.
batch

{'input_ids': tensor([[    0, 26407,  9770,  ...,     3,     3,     3],
         [    0,   255, 11764,  ...,     3,     3,     3],
         [    0, 12126,  9506,  ..., 23111, 15964,     1],
         [    0, 17493, 17245,  ...,     3,     3,     3]]),
 'attention_mask': tensor([[1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 0., 0., 0.],
         [1., 1., 1.,  ..., 1., 1., 1.],
         [1., 1., 1.,  ..., 0., 0., 0.]]),
 'eos_positions': tensor([[ 14,  31,  53,  87, 122, 153, 199, 241, 281, 300, 329, 369, 425,   0,
            0,   0,   0,   0],
         [ 18,  35,  63,  94, 130, 149, 181, 217, 240, 259, 275, 298, 356,   0,
            0,   0,   0,   0],
         [ 13,  30,  55,  83, 118, 176, 201, 211, 226, 263, 285, 307, 332, 354,
          392, 418, 443, 462],
         [ 19,  36,  67,  98, 131, 169, 193, 231, 288, 336,   0,   0,   0,   0,
            0,   0,   0,   0]]),
 'answers': tensor([[ 2,  3, 10],
         [ 2,  4, 11],
         [ 3,  5,  7],
         [ 2,  3,  4]]

In [6]:
# Unpack the individual fields of the batch into same-named variables.
# The key order here must match the order of the target names on the left.
input_ids, attention_mask, eos_positions, answers, labels = (
    batch[field]
    for field in ("input_ids", "attention_mask", "eos_positions", "answers", "labels")
)

In [7]:
# Build the inputs for the generation step from the batch fields unpacked above.
gen_inputs = extract_sentences(input_ids, eos_positions, answers, tokenizer)
# There is no extraction step yet, so `answers` is used as a stand-in here.
# In real inference, the indices passed in as ext_ids should probably be
# sorted in ascending order; otherwise the sentence order could change.

In [10]:
# Report the shape of every entry in gen_inputs, one "<name>: <shape>" line each.
for field_name, field_value in gen_inputs.items():
    line = f"{field_name}: {field_value.shape}"
    print(line)

input_ids: torch.Size([4, 97])
attention_mask: torch.Size([4, 97])
