In [None]:
import os
# Step up one directory so that relative paths such as "config/config.yaml"
# resolve (presumably the notebook lives one level below the project root --
# TODO confirm). The guard makes the cell idempotent: once the config file is
# reachable from the current directory, re-running the cell no longer climbs
# another level.
if not os.path.isfile(os.path.join("config", "config.yaml")):
    os.chdir("..")
os.getcwd()

In [None]:
# Display the current working directory again to confirm the directory change
# made in the previous cell; later cells use paths relative to it.
os.getcwd()

# constants

In [None]:
from dataclasses import dataclass
from text_summarization.utils import read_yaml


# Load the project configuration with the project's `read_yaml` helper.
# The path is relative, so it is resolved against the current working directory.
# The result is read with attribute-style lookups (e.g. CONFIG.PREDICTION).
CONFIG = read_yaml("config/config.yaml")

@dataclass(frozen=True)
class PredictionConstants:
    """Names read from the ``PREDICTION`` section of ``CONFIG``.

    The attributes carry type annotations so that ``@dataclass`` registers them
    as real fields; without annotations the decorator finds no fields and
    ``frozen=True`` has nothing to protect. The values stay readable directly
    on the class (``PredictionConstants.FILE_NAME``) because dataclass field
    defaults are kept as class attributes.
    """

    # Annotated as ``str`` on the assumption that both config values are plain
    # strings -- TODO confirm against config/config.yaml.
    PREDICTION_ROOT_DIR_NAME: str = CONFIG.PREDICTION.ROOT_DIR_NAME
    FILE_NAME: str = CONFIG.PREDICTION.FILE_NAME




In [None]:
# Echo every constant so a wrong or missing config value is spotted right away.
for label, value in (
    ("PREDICTION_ROOT_DIR_NAME:", PredictionConstants.PREDICTION_ROOT_DIR_NAME),
    ("FILE_NAME:", PredictionConstants.FILE_NAME),
):
    print(label, value)

# configuration

In [None]:
from dataclasses import dataclass
from pathlib import Path
import os


@dataclass(frozen=True)
class PredictionConfig:
    """Filesystem locations derived from ``PredictionConstants``.

    The attributes carry type annotations so that ``@dataclass`` registers them
    as real fields; without annotations the decorator finds no fields and
    ``frozen=True`` has nothing to protect. The values stay readable directly
    on the class (``PredictionConfig.FILE_PATH``) because dataclass field
    defaults are kept as class attributes.
    """

    # Root directory for prediction output, wrapped in a ``Path``.
    PREDICTION_ROOT_DIR_PATH: Path = Path(PredictionConstants.PREDICTION_ROOT_DIR_NAME)
    # ``os.path.join`` returns a ``str`` even though its first argument is a
    # ``Path``; the ``str`` type is kept so existing consumers see no change.
    FILE_PATH: str = os.path.join(PREDICTION_ROOT_DIR_PATH, PredictionConstants.FILE_NAME)




In [None]:
# Echo the derived paths so a mis-resolved location is visible immediately.
for label, value in (
    ("PREDICTION_ROOT_DIR_PATH:", PredictionConfig.PREDICTION_ROOT_DIR_PATH),
    ("FILE_PATH:", PredictionConfig.FILE_PATH),
):
    print(label, value)

# pipeline

In [None]:
from dataclasses import dataclass
from transformers import pipeline
from datetime import datetime
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from text_summarization.utils import  create_dirs, save_json


@dataclass
class PredictionPiepline:
    """Summarize a text with a seq2seq model and store the result on disk.

    NOTE(review): the class name is misspelled ("Piepline"); it is kept as-is so
    that existing references keep working.
    """

    def predict(self, text:str, tokenizer:AutoTokenizer, model:AutoModelForSeq2SeqLM, output_file_path:str, gen_kwargs:dict=None) -> str:
        """Generate a summary for ``text`` and save input and output together.

        Args:
            text: The text to summarize.
            tokenizer: Tokenizer handed to the ``summarization`` pipeline.
            model: Model handed to the ``summarization`` pipeline.
            output_file_path: Target path for the saved result. The file name
                is prefixed with a ``%Y_%m_%d_%H_%M_%S`` timestamp, so the file
                actually written is ``<dir>/<timestamp>_<file name>``.
            gen_kwargs: Optional generation arguments. They are merged over the
                defaults (``length_penalty=0.8``, ``num_beams=8``,
                ``max_length=128``); omit it to use the defaults only.

        Returns:
            The ``summary_text`` entry of the first pipeline result.
        """
        # Taken before inference, so the file name reflects when the call started.
        time_stamp = datetime.now().strftime("%Y_%m_%d_%H_%M_%S")

        # Defaults first, caller-supplied overrides second.
        generation_kwargs = {"length_penalty": 0.8, "num_beams": 8, "max_length": 128}
        if gen_kwargs:
            generation_kwargs.update(gen_kwargs)

        pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
        output = pipe(text, **generation_kwargs)[0]["summary_text"]

        # Split the target into its directory and file name parts.
        dir_path, file_name = os.path.split(output_file_path)

        # A bare file name has an empty directory part; there is nothing to
        # create then, and an empty path must not be handed to `create_dirs`.
        if dir_path:
            create_dirs(dir_path)

        # Store the input next to the summary under the timestamped file name.
        file_path = os.path.join(dir_path, f"{time_stamp}_{file_name}")
        save_json({"input": text, "output": output}, file_path)

        return output
    

