In [1]:
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer
from flask import Flask, request, render_template_string
import pandas as pd

# Flask application object; the route decorator further down registers on it.
app = Flask(__name__)

# Pick the compute device: CUDA when PyTorch reports one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Dataset location and the names of the two text columns read from it.
csv_file = 'news-article-categories.csv'
description_column_name = 'body'
title_column_name = 'title'
news_articles_df = pd.read_csv(csv_file)

# Load the pretrained T5 checkpoint and its tokenizer, then place the model's
# weights on the selected device.
model_name = 't5-small'
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

# Build aligned (article body, title) training pairs.
# Rows missing either field are dropped together first: astype(str) would
# otherwise turn NaN into the literal text "nan" and the model would be trained
# on it. Dropping on both columns at once keeps texts and summaries aligned.
training_pairs = news_articles_df[[description_column_name, title_column_name]].dropna()
train_texts = training_pairs[description_column_name].astype(str).tolist()
train_summaries = training_pairs[title_column_name].astype(str).tolist()

# Tokenize sources (up to 512 tokens) and targets (up to 128 tokens).
# The tensors are deliberately left on the CPU: the training loop moves each
# slice to `device` itself, so copying the whole padded dataset to the GPU up
# front only wastes device memory.
encoded_texts = tokenizer(train_texts, return_tensors='pt', padding=True, truncation=True, max_length=512)
train_inputs = {key: val for key, val in encoded_texts.items()}

encoded_summaries = tokenizer(train_summaries, return_tensors='pt', padding=True, truncation=True, max_length=128)
train_labels = {key: val for key, val in encoded_summaries.items()}

# Fine-tune the model with gradient accumulation.
#
# Each micro-batch holds `batch_size` examples; gradients from
# `accumulation_steps` consecutive micro-batches are summed before a single
# optimizer update. The step condition counts micro-batches (1, 2, 3, ...)
# rather than example offsets, so an update really happens every
# `accumulation_steps` micro-batches and once more for a trailing partial group.
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4  # micro-batches accumulated per optimizer step
batch_size = 4  # examples per micro-batch
num_epochs = 3

num_examples = len(train_inputs['input_ids'])
batch_starts = range(0, num_examples, batch_size)
num_batches = len(batch_starts)

for epoch in range(num_epochs):
    optimizer.zero_grad()  # start every epoch with clean gradients
    for step, start in enumerate(batch_starts, start=1):
        stop = start + batch_size
        batch_inputs = {key: val[start:stop].to(device) for key, val in train_inputs.items()}

        # Copy the label ids before editing them so the stored dataset is not
        # modified, then mark padding positions with -100 so the loss ignores them.
        batch_label_ids = train_labels['input_ids'][start:stop].clone().to(device)
        batch_label_ids[batch_label_ids == tokenizer.pad_token_id] = -100

        outputs = model(**batch_inputs, labels=batch_label_ids)
        # Scale so the accumulated gradient is an average over the group rather
        # than a sum (a trailing partial group is weighted slightly lower).
        loss = outputs.loss / accumulation_steps
        loss.backward()

        if step % accumulation_steps == 0 or step == num_batches:
            optimizer.step()
            optimizer.zero_grad()

# Training is finished: disable dropout before the model is used for generation.
model.eval()

# Define a function for generating summaries
def generate_summary(prompt_template, summary_length=50, num_beams=2):
    """Generate a summary for a fully formatted prompt using beam search.

    Args:
        prompt_template: The complete prompt text fed to the model. It is
            truncated to 512 tokens.
        summary_length: Passed to ``model.generate`` as ``max_length``, i.e. an
            upper bound on the generated sequence measured in tokens.
        num_beams: Number of beams used by beam search.

    Returns:
        The decoded text of the first returned sequence, with special tokens
        removed.
    """
    # The module puts the model in training mode for fine-tuning; make sure
    # dropout is off here so generation is deterministic and not degraded.
    model.eval()

    input_ids = tokenizer.encode(prompt_template, return_tensors='pt', max_length=512, truncation=True).to(device)
    summary_ids = model.generate(input_ids, max_length=summary_length, num_beams=num_beams, length_penalty=2.0, early_stopping=True)
    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

@app.route('/', methods=['GET', 'POST'])
def index():
    """Serve the summarization form (GET) and the generated summary (POST).

    POST form fields:
        length: Requested summary length; must parse as a positive integer.
        style: Summary style; 'factual' and 'humorous' select dedicated
            prompts, any other value uses the generic prompt.
        article_text: The article to summarize.

    Returns:
        The rendered HTML page, or a plain-text message with status 400 when
        ``length`` is not a positive integer.
    """
    if request.method == 'POST':
        # A malformed number is a client error, not a server crash.
        try:
            summary_length = int(request.form['length'])
        except ValueError:
            return 'Summary length must be a whole number.', 400
        if summary_length < 1:
            return 'Summary length must be at least 1.', 400

        summary_style = request.form['style']

        # Get article text from the input box
        article_text = request.form['article_text']

        # Design personalized prompt templates based on user choices
        if summary_style == 'factual':
            prompt_template = f'Provide a {summary_style} summary for the article: "{article_text}". The summary should be {summary_length} words long.'
        elif summary_style == 'humorous':
            prompt_template = f'Add a touch of humor to the {summary_style} summary of this article: "{article_text}". Keep it {summary_length} words long and funny!'
        else:
            prompt_template = f'Summarize the following article: "{article_text}" with a {summary_length}-word {summary_style} summary.'

        # Generate summary
        personalized_summary = generate_summary(prompt_template, summary_length)

        # Pass user-influenced values as template variables instead of
        # formatting them into the template source: that way Jinja autoescapes
        # them and they can never be interpreted as template code.
        return render_template_string(
            '<h3>{{ style }} Summary:</h3><p>{{ summary }}</p>',
            style=summary_style.capitalize(),
            summary=personalized_summary,
        )

    # Render the form for user input
    return render_template_string(
        '<h3>Personalized Summarization Agent</h3>'
        '<form method="post">'
        'Input Article Text:<br>'
        '<textarea name="article_text" rows="4" cols="50" placeholder="Enter article text..."></textarea><br>'
        'Summary Length:<br><input type="number" name="length" min="1" required><br>'
        'Summary Style:<br>'
        '<select name="style" required>'
        '<option value="objective">Objective</option>'
        '<option value="factual">Factual</option>'
        '<option value="humorous">Humorous</option>'
        '</select><br><br>'
        '<input type="submit" value="Generate Summary"></form>'
    )

# Start Flask's built-in server with its default settings, but only when this
# code is executed as the main module rather than imported.
if "__main__" == __name__:
    app.run()


Using device: cuda


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [20/Jan/2024 03:14:51] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:15:16] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:15:25] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:15:31] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:15:37] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:15:41] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:15:49] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:16:14] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:16:19] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:16:27] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:16:36] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:16:41] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:16:52] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:17:04] "POST / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:17:59] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [20/Jan/2024 03:18:07]