In [1]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [2]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)

In [57]:
class Model:
    def __init__(self, model_name: str = "gpt2"):
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        return torch.argmax(logits, dim=-1).item()

    def random_sampling(self, logits: torch.Tensor) -> int:
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).item()

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int
    ) -> str:
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")

        beam_scores = torch.zeros(num_beams, dtype=torch.float)
        beam_sequences = input_ids.repeat(num_beams, 1)

        for step in range(max_length - len(input_ids[0])):
            with torch.no_grad():
                outputs = self.model(beam_sequences)
                next_token_logits = outputs.logits[:, -1, :]

            next_token_probs = F.softmax(next_token_logits, dim=-1)

            vocab_size = next_token_probs.size(-1)
            expanded_scores = beam_scores.unsqueeze(1).expand(-1, vocab_size)
            token_scores = expanded_scores + torch.log(next_token_probs)

            flat_scores = token_scores.view(-1)
            top_scores, top_indices = torch.topk(flat_scores, num_beams)

            beam_indices = top_indices // vocab_size
            token_indices = top_indices % vocab_size

            new_sequences = []
            for i in range(num_beams):
                beam_idx = beam_indices[i]
                token_idx = token_indices[i]
                new_sequence = torch.cat([
                    beam_sequences[beam_idx],
                    token_idx.unsqueeze(0)
                ])
                new_sequences.append(new_sequence)

            beam_sequences = torch.stack(new_sequences)
            beam_scores = top_scores

            if torch.all(beam_sequences[:, -1] == self.tokenizer.eos_token_id):
                break

        best_beam_idx = torch.argmax(beam_scores)
        best_sequence = beam_sequences[best_beam_idx]

        return self.tokenizer.decode(best_sequence, skip_special_tokens=True)

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        return logits / temperature

    def _apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        if top_p >= 1.0:
            return logits

        probs = F.softmax(logits, dim=-1)

        sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1)

        cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

        mask = cumulative_probs > top_p

        mask[:, 1:] = mask[:, :-1].clone()
        mask[:, 0] = False

        scatter_mask = torch.zeros_like(mask)
        scatter_mask.scatter_(-1, sorted_indices, mask)
        filtered_logits = logits.masked_fill(scatter_mask, float('-inf'))

        return filtered_logits

    def _apply_top_k(self, logits: torch.Tensor, top_k: float = None) -> torch.Tensor:
        if top_k is None or top_k <= 0 or top_k > self.vocab_size:
            return logits

        values, _ = torch.topk(logits, top_k)
        min_value = values[:, -1] if logits.dim() > 1 else values[-1]

        mask = logits < min_value.unsqueeze(-1) if logits.dim() > 1 else logits < min_value

        filtered_logits = logits.masked_fill(mask, float('-inf'))
        return filtered_logits

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:

        if strategy == "beam_search":
            return self._beam_search_generate(prompt, max_length, num_beams)

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        generated = input_ids.clone()

        for _ in range(max_length - len(input_ids[0])):
            with torch.no_grad():
                outputs = self.model(generated)
                next_token_logits = outputs.logits[:, -1, :]

            if temperature != 1.0:
                next_token_logits = self.apply_temperature(next_token_logits, temperature)

            if top_k > 0:
                next_token_logits = self._apply_top_k(next_token_logits, top_k)

            if top_p < 1.0:
                next_token_logits = self._apply_top_p(next_token_logits, top_p)

            if strategy == "greedy":
                next_token_id = self.greedy_sampling(next_token_logits)
            elif strategy == "random":
                next_token_id = self.random_sampling(next_token_logits)

            next_token = torch.tensor([[next_token_id]], dtype=torch.long)
            generated = torch.cat([generated, next_token], dim=1)

            if next_token_id == self.tokenizer.eos_token_id:
                break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [58]:
model = Model()

In [49]:
prompts = [
    'To be or not to',
    'The funny name for my cat is',
    'The capital of Greece is'
]

In [51]:
for prompt in prompts:
    result = model.generate(prompt, strategy="greedy", max_length=35)
    print(result, '\n')

To be or not to be, the only thing that matters is that you're a good person.

I'm not saying that you should be a good person. I 

The funny name for my cat is "The Cat."

I'm not sure if I'm going to be able to keep my cat, but I'm sure I'll 

The capital of Greece is Athens, and the capital of the country is Athens. The capital of Greece is Athens.

The capital of Greece is Athens. The capital of 



In [52]:
for prompt in prompts:
    result = model.generate(prompt, strategy="random", temperature=1.0, max_length=30)
    print(result, '\n')

To be or not to be, the most striking signs of symptomatology are those phenomena which \iikal \u28s\. \ 

The funny name for my cat is Terra (oh and for the cats, like the hat! Who knew you'd donate my work through this blog!). 

The capital of Greece is Athens – another city of rampantly poor inhabitants, who live as refugees and find comfort from the Greek dependents, often inter 



In [53]:
for prompt in prompts:
    result = model.generate(prompt, strategy="random", temperature=0.5, max_length=30)
    print(result, '\n')

To be or not to be, the most important thing is to have a good relationship with the person you are with.

As a general rule 

The funny name for my cat is "Mongoose."

I've been a fan of the show since it first came out and I'm 

The capital of Greece is Athens. The country's economy is booming, and the country's exports are growing.

But the country's economy is 



In [54]:
for prompt in prompts:
    result = model.generate(prompt, strategy="random", temperature=2.0, max_length=30)
    print(result, '\n')

To be or not to self meared from trunk lows src Biplay reparl 1899 Associate Rails Deputy, missions mur Collot?Enlarge decree Eugene Fish 

The funny name for my cat is Darth GambleLock vibucci remix Dolphan desauriad… Curryzilla upload replace companion masks added laser fencing cul 

The capital of Greece is game sublime cash dunk MichelleTurBelow Stories UN Shutdown role src   denotes pride alarms capable McGee erection funnel gadget reception allev lady Gets 



In [55]:
for prompt in prompts:
    result = model.generate(prompt, strategy="random", top_k=10, max_length=30)
    print(result, '\n')

To be or not to be. The only way to know is to go to the store, buy some candy and then go back to the car." 

The funny name for my cat is "Darling," which means "little girl," so I guess he doesn't know what that means?


 

The capital of Greece is Athens. It is an island of over a thousand islands.

In the early days of the Greek Empire, Athens had 



In [56]:
for prompt in prompts:
    result = model.generate(prompt, strategy="random", top_p=0.9, max_length=60)
    print(result, '\n')

To be or not to be an early blazer, you'll need to hide it from around the house as well as with your car. Make sure to keep your UAV as close as you can to the wall and wall of the house so your guests can see your UAV when you come out 

The funny name for my cat is The Cats. Actually, while she is adorable, she has no traits of hers; she spends most of her time staring at her people. This is because her personality doesn't seem to have much to do with nature itself; her broad shoulders, a special tan belly 

The capital of Greece is Athens, but it took five years for it to see its bailout payments cut, and now is starting to see major cuts. Not long ago, an IMF report found that, despite a radical collapse, GDP grew by $1.8tn in 2014. New direct deposits from 



In [59]:
for prompt in prompts:
    result = model.generate(prompt, strategy="beam_search", num_beams=3, max_length=30)
    print(result, '\n')

To be or not to be, the only thing that matters is that you're a good person.

I'm not saying that you should be 

The funny name for my cat is "The Cat."

I'm not sure if I'm going to be able to keep my cat, but 

The capital of Greece is Athens, and the capital of the country is Athens. The capital of Greece is Athens.

The capital of Greece is 



In [60]:
for prompt in prompts:
    result = model.generate(prompt, strategy="random", temperature=0.7, top_p=0.9, max_length=30)
    print(result, '\n')

To be or not to be a real person, I don't know. I'm not interested in it. I'm not interested in being a crazy 

The funny name for my cat is "S-bot." It's a "bot" that's created to make it look like a person's hand 

The capital of Greece is the country with the most mobile-phone use, and the number of mobile phones used per capita is expected to increase by more 

