# View sample predictions

Load the trained model from a checkpoint and print sample title predictions for motions in the test set.

In [1]:
import argparse
import random
import sys
import warnings

import pandas as pd
import yaml

from transformers.models.mt5 import MT5Tokenizer

sys.path.append('..')  # Allow import of project packages
from motion_title_generator.data.motions_data_module import MotionsDataModule
from motion_title_generator.models import t5
from motion_title_generator.lit_models import MT5LitModel

# Notebook display settings: let pandas show long text columns and silence warnings.
pd.set_option('display.max_colwidth', 500)
warnings.filterwarnings('ignore')

# Locations of the training run to inspect (relative to this notebook).
BASE_PATH = "../training/logs/lightning_logs/version_1/"
MODEL_PATH = f"{BASE_PATH}checkpoints/epoch=009-val_loss=1.242.ckpt"
HPARAMS_PATH = f"{BASE_PATH}hparams.yaml"

  from .autonotebook import tqdm as notebook_tqdm


### Load data

In [2]:
# Restore the hyperparameters that pytorch-lightning stored for this run.
# NOTE(review): yaml.Loader can construct arbitrary Python objects, so only
# point HPARAMS_PATH at hparams files produced by your own training runs.
with open(HPARAMS_PATH) as hparams_file:
    lightning_config = argparse.Namespace(
        **yaml.load(hparams_file, Loader=yaml.Loader)
    )

# Build the data module from those settings and run its prepare/setup steps.
dataset = MotionsDataModule(lightning_config)
dataset.prepare_data()
dataset.setup()
print(dataset)

Preprocessing data ...
Filtered 618 rows with missing values.
Filtered 7444 texts shorter than 150 characters.
Filtered 16919 texts with generic title.
Number of rows remaining: 145070
Preprocessed data saved to /home/erik/proj/swedish_parliament_motion_summarization/data/downloaded/prepped_training_data.feather
Using 145070 of 145070 examples.
Train: 108802, Val: 21760, Test: 14508
<motion_title_generator.data.motions_data_module.MotionsDataModule object at 0x7efedc76db50>


### Load model

In [3]:
# Instantiate the MT5 wrapper that the Lightning module needs when restoring the checkpoint
model = t5.MT5(data_config={}, args=lightning_config)

# Restore the Lightning module from the checkpoint and put it in evaluation mode
lit_model = MT5LitModel.load_from_checkpoint(checkpoint_path=MODEL_PATH, model=model)
lit_model.eval()

tokenizer = MT5Tokenizer.from_pretrained(model.model_name)

Lightning automatically upgraded your loaded checkpoint from v1.6.4 to v2.0.2. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint --file ../training/logs/lightning_logs/version_1/checkpoints/epoch=009-val_loss=1.242.ckpt`


### Generate title

In [18]:
def summarize(model, text, tokenizer, text_max_num_tokens, summary_max_num_tokens):
    """Generate a title for ``text`` and return it as a single string.

    Args:
        model: Wrapper whose ``.model`` attribute provides ``to`` and ``generate``.
        text: The input text to generate a title for.
        tokenizer: Callable that encodes ``text`` and exposes ``decode``.
        text_max_num_tokens: Length the encoded input is padded/truncated to.
        summary_max_num_tokens: ``max_length`` handed to ``generate``.

    Returns:
        Every decoded generated sequence, concatenated without a separator.
    """
    # Encode the input, padded or truncated to the token limit, as "pt" tensors.
    encoded = tokenizer(
        text,
        add_special_tokens=True,
        truncation=True,
        padding="max_length",
        max_length=text_max_num_tokens,
        return_attention_mask=True,
        return_tensors="pt",
    )

    # Generation is done on the CPU; this moves the wrapped network there.
    generator = model.model
    generator.to("cpu")
    output_ids = generator.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=summary_max_num_tokens,
        num_beams=2,
        early_stopping=True,
        repetition_penalty=5.0,
        length_penalty=2.0,
    )

    # Decode each generated sequence and glue the pieces together.
    return "".join(
        tokenizer.decode(
            ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for ids in output_ids
    )

In [19]:
def show_sample_pred(tokenizer, sample_index=None):
    """Print one test-set motion together with its actual and predicted title.

    Reads the notebook globals ``dataset`` (its ``data_test`` split), ``model``
    and ``summarize``.

    Args:
        tokenizer: Tokenizer forwarded to ``summarize``.
        sample_index: Index into the test split to display. When ``None``
            (the default) a valid index is drawn at random.

    Raises:
        ValueError: If no index is given and the test split is empty.
    """
    test_split = dataset.data_test
    if sample_index is None:
        # randrange excludes the upper bound. randint(0, len(...)) is inclusive
        # and could return len(...) itself, indexing one past the end.
        sample_index = random.randrange(len(test_split.data))
    text = test_split.data[sample_index]
    true_summary = test_split.targets[sample_index]

    model_summary = summarize(
        model=model,
        text=text,
        tokenizer=tokenizer,
        text_max_num_tokens=512,
        summary_max_num_tokens=64,
    )

    separator = 50 * "-"
    print("Motion text:")
    print(separator)
    print(text[:500])  # Only the first 500 characters to keep the output short
    print(separator)
    print("Actual title:")
    print(true_summary)
    print(separator)
    print("Predicted title:")
    print(model_summary)
    print(separator)

# Show one randomly chosen test example, using the tokenizer held by the restored Lightning module.
show_sample_pred(lit_model.tokenizer)

Motion text:
--------------------------------------------------
Arbetsmarknaden på både scen och medieområdet har förändrats kraftigt under senare år. På scenkonstområdet (teater, musikteater och dans) är allt färre fast anställda och alltfler frilansare. Medieområdet (film, tv, radio och nya medier) har alltid kännetecknats av frilansare. Men också utvecklingen på medieområdet har gått mot en ytterligare fragmentiserad arbetsmarknad med arbetsgivare eller uppdragsgivare som inte kan eller vill ta ett mer långsiktigt ansvar för sina korttidsanställda. Denna 
--------------------------------------------------
Actual title:
Kompetensutveckling inom scenkonst och film
--------------------------------------------------
Predicted title:
Kompetensutveckling inom scenkonst och film
--------------------------------------------------
