In [36]:
import os
import sys
import importlib
import argparse
import nltk
import pandas as pd
import numpy as np
import random
import torch

In [37]:
# Make the parent directory importable; presumably this is where the local
# modules `preprocessor`, `arguments`, `processor` and `custom_model` live.
# Guarded so that re-running this cell does not keep appending duplicates.
if '../' not in sys.path:
    sys.path.append('../')

In [38]:
from datasets import load_dataset

In [39]:
from transformers import (
    AutoConfig,
    BartConfig,
    LongformerConfig,
    AutoTokenizer,
    AutoModelForSeq2SeqLM, 
    DataCollatorForSeq2Seq, 
)

## Device

In [40]:
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so that the
# shuffled DataLoader batch and any random weight initialisation in the model
# below are reproducible under Restart & Run All.
# torch.manual_seed also seeds all CUDA devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Datasets

In [41]:
# Load the (private) news summarization dataset from the Hugging Face Hub.
# The auth token is read from the environment instead of being hard-coded in
# the notebook. Treat the previously committed token as leaked and revoke it.
# When HF_AUTH_TOKEN is unset, `True` makes `datasets` use the token saved by
# `huggingface-cli login`.
news_dataset = load_dataset(
    'metamong1/summarization_news',
    use_auth_token=os.environ.get('HF_AUTH_TOKEN', True),
)

Reusing dataset news_summarization (/opt/ml/.cache/huggingface/datasets/metamong1___news_summarization/News Summarization/1.0.0/ae25c3215dc878e979d01f1157dbfb014c0a6985fc959ae45eaf10847db75600)


  0%|          | 0/2 [00:00<?, ?it/s]

## Preprocessing & Filtering

In [42]:
from preprocessor import DocsPreprocessor, Filter

In [43]:
# Its `for_train` method is applied to every example via `.map(...)` below.
data_preprocessor = DocsPreprocessor()
# Callable passed to `.filter(...)` below. Presumably it drops examples based
# on title length (threshold 5); TODO confirm in preprocessor.Filter.
data_filter = Filter(title_size=5)

In [44]:
# Clear previously written cache files (datasets keeps only the one currently
# in use), so the map/filter results below are not reused from an earlier run.
news_dataset.cleanup_cache_files()

# Clean each document, then keep only the examples accepted by the filter.
news_dataset = (
    news_dataset
    .map(data_preprocessor.for_train)
    .filter(data_filter)
)

  0%|          | 0/240628 [00:00<?, ?ex/s]

  0%|          | 0/60157 [00:00<?, ?ex/s]

  0%|          | 0/241 [00:00<?, ?ba/s]

  0%|          | 0/61 [00:00<?, ?ba/s]

In [45]:
# Split sizes after preprocessing and filtering.
news_dataset

DatasetDict({
    train: Dataset({
        features: ['doc_id', 'title', 'text', 'doc_type', 'file'],
        num_rows: 214237
    })
    validation: Dataset({
        features: ['doc_id', 'title', 'text', 'doc_type', 'file'],
        num_rows: 53572
    })
})

## Tokenizer

In [46]:
# Hub checkpoint whose tokenizer and BART config are both loaded below.
model_checkpoint = 'gogamza/kobart-summarization'

In [47]:
# Fast (tokenizers-library-backed) tokenizer for the checkpoint chosen above.
tokenizer = AutoTokenizer.from_pretrained(
    model_checkpoint,
    use_fast=True,
)



## Data Arguments

In [48]:
from arguments import DataTrainingArguments

In [49]:
# Instantiate the arguments object rather than aliasing the class.
# With the bare class, every later `data_args.<field> = ...` rewrote the
# class-level default shared by every other user of DataTrainingArguments.
# NOTE(review): assumes every field has a default (as in the HF example
# scripts); TODO confirm in arguments.py.
data_args = DataTrainingArguments()

In [50]:
# Upper bound on source-document length, in tokens.
# NOTE(review): presumably the truncation length used by `preprocess_function`;
# confirm in processor.py. The sampled batch below pads to 640, under this cap.
data_args.max_source_length = 1024

## Processing

In [51]:
from processor import preprocess_function
from functools import partial

In [52]:
# Only the training split is tokenised in this notebook; validation stays raw.
train_dataset = news_dataset['train']

In [53]:
# Record the raw columns so the tokenising map below can drop all of them.
column_names = train_dataset.column_names
print(column_names)

['doc_id', 'title', 'text', 'doc_type', 'file']


In [54]:
# Bind the fixed arguments, so that `Dataset.map` only has to pass each batch.
prep_fn = partial(
    preprocess_function,
    tokenizer=tokenizer,
    data_args=data_args,
)

# Tokenise in batches and replace the raw text columns with model inputs
# (attention_mask / input_ids / labels, per the output below).
map_kwargs = dict(
    batched=True,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=column_names,
    load_from_cache_file=not data_args.overwrite_cache,
)
train_dataset = train_dataset.map(prep_fn, **map_kwargs)

In [58]:
# Only tokenised features should remain, with the same row count as before.
train_dataset

Dataset({
    features: ['attention_mask', 'input_ids', 'labels'],
    num_rows: 214237
})

## Config

In [59]:
# bart_config is passed to the custom model as `model_config` and
# longformer_config as `encoder_config` (see the Model section).
longformer_checkpoint = 'allenai/longformer-base-4096'
bart_config = BartConfig.from_pretrained(model_checkpoint)
longformer_config = LongformerConfig.from_pretrained(longformer_checkpoint)

## Model

In [60]:
from custom_model import CustomForConditionalGeneration

In [61]:
# Build the custom encoder-decoder from the two configs loaded above.
# NOTE(review): constructed from configs rather than `from_pretrained`, so the
# weights are whatever CustomForConditionalGeneration initialises; confirm.
model = CustomForConditionalGeneration(
    encoder_config=longformer_config,
    model_config=bart_config,
)

## Data Collator

In [62]:
from transformers import DataCollatorForSeq2Seq

In [63]:
# Pad labels with -100 when padding should be ignored by the loss. -100 is
# PyTorch's default `ignore_index` for cross-entropy; this assumes the custom
# model keeps that default. Otherwise pad with the tokenizer's real pad id.
if data_args.ignore_pad_token_for_loss:
    label_pad_token_id = -100
else:
    label_pad_token_id = tokenizer.pad_token_id
print(label_pad_token_id)

-100


In [64]:
# Pads each batch dynamically. Passing `model` lets the collator add
# `decoder_input_ids` built from the labels (see the batch keys below).
# Labels are padded with `label_pad_token_id`, and the input length is
# rounded up to a multiple of 8 (640 in the sampled batch).
data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=label_pad_token_id,
    pad_to_multiple_of=8,
)

## Training

In [65]:
from torch.utils.data import DataLoader

### Data Loader

In [66]:
# Shuffled mini-batches of 8 examples. Padding and decoder inputs are produced
# by the collator.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=8,
    shuffle=True,
    collate_fn=data_collator,
)

In [68]:
# Draw a single batch for a shape sanity check.
batch = next(iter(train_dataloader))

{name: tensor.shape for name, tensor in batch.items()}

{'attention_mask': torch.Size([8, 640]),
 'input_ids': torch.Size([8, 640]),
 'labels': torch.Size([8, 23]),
 'decoder_input_ids': torch.Size([8, 23])}

### Model Outputs

In [69]:
# The batch is moved to `device`, so the model must live there too. Otherwise
# the forward pass fails with a device-mismatch error whenever CUDA is
# available; the recorded loss below has no device tag, i.e. it ran on CPU.
# `nn.Module.to` moves parameters in place and returns the module itself.
model = model.to(device)
batch = {name: tensor.to(device) for name, tensor in batch.items()}
outputs = model(**batch)

In [70]:
# Fields returned by the custom model's forward pass.
outputs.keys()

odict_keys(['loss', 'logits', 'past_key_values', 'encoder_last_hidden_state'])

In [71]:
# Loss on the sampled batch. The recorded ~10.5 is roughly ln(30000) ≈ 10.3,
# i.e. about chance level over the 30000-entry vocabulary (see logits below).
outputs['loss']

tensor(10.5004, grad_fn=<NllLossBackward0>)

In [32]:
# NOTE(review): the recorded output below comes from an earlier execution
# (In [32], before In [71]) and looks stale. Its decoder length 29 does not
# match the 23-token labels/decoder_input_ids of the current batch.
# Restart & Run All to refresh it.
outputs['logits'].shape

torch.Size([8, 29, 30000])

In [33]:
# NOTE(review): the recorded output below is from an earlier execution (In [33]).
# Its length 512 differs from the 640-token input_ids of the current batch.
# Re-run, and confirm whether the encoder really changes the sequence length.
outputs['encoder_last_hidden_state'].shape

torch.Size([8, 512, 768])