In [1]:
import torch
from transformers import PegasusForConditionalGeneration, PegasusTokenizer

In [2]:
# Create tokenizer
tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")

In [3]:
model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")

Some weights of PegasusForConditionalGeneration were not initialized from the model checkpoint at google/pegasus-xsum and are newly initialized: ['model.decoder.embed_positions.weight', 'model.encoder.embed_positions.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
text = """
A recurrent neural network (RNN) is one of the two broad types of artificial neural network, characterized by direction of the flow of information between its layers. In contrast to the uni-directional feedforward neural network, it is a bi-directional artificial neural network, meaning that it allows the output from some nodes to affect subsequent input to the same nodes. Their ability to use internal state (memory) to process arbitrary sequences of inputs[1][2][3] makes them applicable to tasks such as unsegmented, connected handwriting recognition[4] or speech recognition.[5][6] The term "recurrent neural network" is used to refer to the class of networks with an infinite impulse response, whereas "convolutional neural network" refers to the class of finite impulse response. Both classes of networks exhibit temporal dynamic behavior.[7] A finite impulse recurrent network is a directed acyclic graph that can be unrolled and replaced with a strictly feedforward neural network, while an infinite impulse recurrent network is a directed cyclic graph that cannot be unrolled.

Additional stored states and the storage under direct control by the network can be added to both infinite-impulse and finite-impulse networks. Another network or graph can also replace the storage if that incorporates time delays or has feedback loops. Such controlled states are referred to as gated states or gated memory and are part of long short-term memory networks (LSTMs) and gated recurrent units. This is also called Feedback Neural Network (FNN). Recurrent neural networks are theoretically Turing complete and can run arbitrary programs to process arbitrary sequences of inputs.[8]
"""

In [5]:
tokens = tokenizer(text, truncation=True, padding="longest", return_tensors="pt")

In [6]:
summary = model.generate(**tokens)

In [7]:
summary

tensor([[    0,   202, 27441, 14849,   952,   117,   114,  6989,   121, 37390,
          4958, 14849,   952,   108,  2050,   120,   126,   871,   109,  2940,
           135,   181, 11406,   112,  2384,  5751,  3196,   112,   109,   310,
         11406,   107,     1]])

In [8]:
print(tokenizer.decode(summary[0]))

<pad>A recurrent neural network is a bi-directional artificial neural network, meaning that it allows the output from some nodes to affect subsequent input to the same nodes.</s>


In [9]:
type(tokens['input_ids'])

torch.Tensor

In [10]:
len(tokens['input_ids'][0])

325

In [11]:
text = "Hello and this is some text. And"

In [12]:
test_tokens = tokenizer(text,  truncation=True, padding="longest", return_tensors="pt")

In [13]:
len(test_tokens['input_ids'][0])

9

In [14]:
test_tokens['input_ids'][0]

tensor([8087,  111,  136,  117,  181, 1352,  107,  325,    1])

In [15]:
print(tokenizer.max_len_single_sentence)

511


In [16]:
import re

def count_words_and_punctuation(text):
    # Define the regex pattern for words (considering words as sequences of alphanumeric characters)
    word_pattern = r'\b\w+\b'
    # Define the regex pattern for punctuation marks (common punctuation marks)
    punctuation_pattern = r'[.,!?;:()\'"-]'

    # Find all words using the word pattern
    words = re.findall(word_pattern, text)
    # Find all punctuation marks using the punctuation pattern
    punctuation_marks = re.findall(punctuation_pattern, text)

    # Count the number of words and punctuation marks
    word_count = len(words)
    punctuation_count = len(punctuation_marks)

    return word_count, punctuation_count


In [23]:
test = "This sentence is. six words only"

In [24]:
word_count, punct_count = count_words_and_punctuation(test)

In [25]:
max_len = 3

In [26]:
word_count / max_len

2.0

In [29]:
def split_text_equal_words(text):
    word_pattern = r'\b\w+\b'
    
    elements = re.findall(word_pattern, text)
    
    total_elements = len(elements)
    half_elements = total_elements // 2
    
    running_count = 0
    split_index = 0
    
    for match in re.finditer(word_pattern, text):
        running_count += 1
        if running_count >= half_elements:
            split_index = match.end()
            break
    
    part1 = text[:split_index]
    part2 = text[split_index:]
    
    return part1, part2