<a href="https://colab.research.google.com/github/marb543/CART498-GenAI/blob/main/A2/assignment_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch

# Load GPT-2 tokenizer and model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

def apply_p_plus_n(text, n):
    """
    Apply the P+N technique where the last word of each line is replaced with
    a token based on adjusted probabilities to produce a more varied prediction.
    Ensure predictions are words only, not punctuation.
    """
    lines = text.split("\n")
    modified_lines = []

    for line in lines:
        words = line.split()
        if not words:
            modified_lines.append(line)
            continue

        last_word = words[-1].rstrip(".,;!?")
        context = " ".join(words[:-1])

        # Tokenize the context and add the last word
        input_ids = tokenizer.encode(context + " " + last_word, return_tensors="pt")

        # Generate predictions for the next token
        with torch.no_grad():
            outputs = model(input_ids)

        logits = outputs.logits[0, -1]

        # Calculate probabilities with adjusted logits
        probabilities = torch.softmax(logits / 1.5, dim=-1)

        # Filter out punctuation tokens
        token_ids = torch.arange(probabilities.size(0))
        non_punct_tokens = [tid for tid in token_ids if tokenizer.decode([tid]).strip().isalpha()]
        filtered_probs = probabilities[non_punct_tokens]
        filtered_indices = torch.tensor(non_punct_tokens)

        # Get the nth most probable token
        top_indices = torch.topk(filtered_probs, k=max(n, 7)).indices
        selected_token = filtered_indices[top_indices[n - 1] if len(top_indices) >= n else top_indices[-1]]
        new_word = tokenizer.decode([selected_token.item()]).strip()

        # Replace the last word and reconstruct the line
        modified_line = " ".join(words[:-1] + [new_word])
        modified_lines.append(modified_line)

    return "\n".join(modified_lines)

# Input text (The Snow Man by Wallace Stevens)
input_text = """One must have a mind of winter
To regard the frost and the boughs
Of the pine-trees crusted with snow;
And have been cold a long time
To behold the junipers shagged with ice,
The spruces rough in the distant glitter
Of the January sun; and not to think
Of any misery in the sound of the wind,
In the sound of a few leaves,
Which is the sound of the land
Full of the same wind
That is blowing in the same bare place
For the listener, who listens in the snow,
And, nothing himself, beholds
Nothing that is not there and the nothing that is."""

# Apply P+7 and P+42 transformations
processed_text_p7 = apply_p_plus_n(input_text, 7)
processed_text_p42 = apply_p_plus_n(input_text, 42)

# Display results
print("=== Original Text ===")
print(input_text)
print("\n=== Processed Text (P+7) ===")
print(processed_text_p7)
print("\n=== Processed Text (P+42) ===")
print(processed_text_p42)

=== Original Text ===
One must have a mind of winter
To regard the frost and the boughs
Of the pine-trees crusted with snow;
And have been cold a long time
To behold the junipers shagged with ice,
The spruces rough in the distant glitter
Of the January sun; and not to think
Of any misery in the sound of the wind,
In the sound of a few leaves,
Which is the sound of the land
Full of the same wind
That is blowing in the same bare place
For the listener, who listens in the snow,
And, nothing himself, beholds
Nothing that is not there and the nothing that is.

=== Processed Text (P+7) ===
One must have a mind of or
To regard the frost and the which
Of the pine-trees crusted with are
And have been cold a long so
To behold the junipers shagged with by
The spruces rough in the distant in
Of the January sun; and not to about
Of any misery in the sound of the I
In the sound of a few flying
Which is the sound of the on
Full of the same from
That is blowing in the same bare I
For the listener, who