In [1]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [2]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), and configures cuDNN to prefer
    deterministic behaviour.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    # NOTE: hash randomisation is fixed when the interpreter starts, so this
    # only influences subprocesses spawned after this call.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also seed every other visible GPU; a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may pick different kernels from run to run, which
    # defeats the deterministic flag above, so it has to be switched off.
    torch.backends.cudnn.benchmark = False

seed_everything(42)

In [3]:
class Model:
    """GPT-2 wrapper with hand-written decoding strategies.

    Supports greedy decoding, random sampling (optionally shaped by
    temperature, top-k and top-p filtering) and beam search. Every strategy
    re-runs the model on the full sequence at each step (no key/value cache).
    """

    def __init__(self, model_name: str = "gpt2"):
        """Load the pretrained model/tokenizer and move the model to GPU if available.

        Args:
            model_name: Checkpoint name passed to ``from_pretrained``.
        """
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        # Reuse EOS as the padding token.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size
        self.model.eval()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        """Return the id of the highest-scoring token of a 1-D logits vector."""
        return logits.argmax(dim=-1).item()

    def random_sampling(self, logits: torch.Tensor) -> int:
        """Draw one token id from the softmax distribution of a 1-D logits vector."""
        probs = F.softmax(logits, dim=-1)
        return torch.multinomial(probs, num_samples=1).item()

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Scale logits by ``1 / temperature``.

        Raises:
            ValueError: If ``temperature`` is not strictly positive (zero
                would divide by zero, a negative value would invert the
                ranking of the tokens).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        return logits / temperature

    def _apply_top_k(self, logits: torch.Tensor, top_k: int = 0) -> torch.Tensor:
        """Keep the ``top_k`` largest logits per row and set the rest to -inf.

        ``top_k <= 0`` or ``top_k`` >= vocabulary width disables the filter.
        Logits tied with the k-th largest value are kept as well.
        """
        if top_k <= 0 or top_k >= logits.size(-1):
            return logits
        top_k_vals, _ = torch.topk(logits, top_k)
        # Smallest value that still belongs to the top-k, kept broadcastable.
        min_top_k_val = top_k_vals[..., -1, None]
        return logits.masked_fill(logits < min_top_k_val, -float('inf'))

    def _apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        """Nucleus filtering: keep the smallest set of tokens whose mass exceeds ``top_p``.

        ``top_p >= 1.0`` disables the filter. The most probable token is
        always kept.
        """
        if top_p >= 1.0:
            return logits
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right by one so the token that crosses the threshold is kept,
        # and never drop the single most probable token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order.
        indices_to_remove = sorted_indices_to_remove.scatter(
            dim=-1, index=sorted_indices, src=sorted_indices_to_remove
        )
        return logits.masked_fill(indices_to_remove, -float('inf'))

    def _encode_prompt(self, prompt: str) -> torch.Tensor:
        """Tokenize ``prompt`` into a ``(1, seq_len)`` id tensor on ``self.device``.

        If the tokenizer yields no tokens (e.g. an empty prompt), the sequence
        is seeded with the EOS token so the model always receives at least one
        position; that token is stripped again when decoding with
        ``skip_special_tokens=True``.
        """
        input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids
        if input_ids.numel() == 0:
            input_ids = torch.tensor([[self.tokenizer.eos_token_id]], dtype=torch.long)
        return input_ids.to(self.device)

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int
    ) -> str:
        """Decode ``prompt`` with beam search and return the best hypothesis as text.

        Hypotheses are ranked by their summed token log-probability (no
        length normalisation). ``max_length`` counts prompt tokens too and is
        capped by the model's maximum number of positions.

        Args:
            prompt: Text to continue.
            max_length: Maximum total sequence length in tokens.
            num_beams: Number of hypotheses kept alive at every step.

        Raises:
            ValueError: If ``num_beams`` is smaller than 1.
        """
        if num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {num_beams}")

        input_ids = self._encode_prompt(prompt)
        eos_id = self.tokenizer.eos_token_id

        beam_sequences = input_ids.expand(num_beams, -1).clone()
        # All beams start as copies of the prompt. Only the first one may be
        # expanded at step one; otherwise every beam would pick the same best
        # token and the search would collapse into greedy decoding.
        beam_scores = torch.full((num_beams,), -float('inf'), device=self.device)
        beam_scores[0] = 0.0
        done = torch.zeros(num_beams, dtype=torch.bool, device=self.device)

        cur_len = input_ids.shape[1]
        max_len = min(max_length, self.model.config.max_position_embeddings)

        while cur_len < max_len:
            with torch.no_grad():
                next_token_logits = self.model(beam_sequences).logits[:, -1, :]

            next_token_log_probs = F.log_softmax(next_token_logits, dim=-1)
            if done.any():
                # Freeze finished beams: their only continuation is another
                # EOS (acting as padding) at zero cost, so their score stays
                # fixed and they cannot fork into several candidates.
                next_token_log_probs[done] = -float('inf')
                next_token_log_probs[done, eos_id] = 0.0

            # Use the real logits width; it may differ from the tokenizer's
            # reported vocabulary size.
            vocab_width = next_token_log_probs.size(-1)
            candidate_scores = (beam_scores.unsqueeze(-1) + next_token_log_probs).reshape(-1)
            top_scores, top_indices = torch.topk(candidate_scores, num_beams, sorted=True)

            # Flat index -> (parent beam, token id).
            beam_indices = torch.div(top_indices, vocab_width, rounding_mode="floor")
            token_indices = top_indices % vocab_width

            beam_sequences = torch.cat(
                [beam_sequences[beam_indices], token_indices.unsqueeze(-1)], dim=-1
            )
            beam_scores = top_scores
            done = done[beam_indices] | (token_indices == eos_id)
            cur_len += 1

            if bool(done.all()):
                break

        best_idx = beam_scores.argmax().item()
        return self.tokenizer.decode(beam_sequences[best_idx], skip_special_tokens=True)

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        """Continue ``prompt`` and return the decoded text (prompt included).

        Args:
            prompt: Text to continue.
            max_length: Maximum total length in tokens, prompt included;
                capped by the model's maximum number of positions.
            strategy: ``"greedy"``, ``"random"`` or ``"beam_search"``.
            temperature: Logit scaling factor; must be positive.
            top_k: Keep only the k best tokens (``0`` disables).
            top_p: Nucleus threshold (``1.0`` disables).
            num_beams: Beam width, used only by ``"beam_search"``.

        Raises:
            ValueError: For an unknown ``strategy``, a non-positive
                ``temperature`` or ``num_beams < 1``.
        """
        if strategy == "beam_search":
            return self._beam_search_generate(prompt, max_length, num_beams)

        samplers = {"greedy": self.greedy_sampling, "random": self.random_sampling}
        # Reject bad strategies before any work is done, even when the prompt
        # already fills max_length and the loop below would not run.
        if strategy not in samplers:
            raise ValueError(f"Unknown strategy: {strategy}")
        pick_token = samplers[strategy]

        input_ids = self._encode_prompt(prompt)
        eos_id = self.tokenizer.eos_token_id
        cur_len = input_ids.shape[1]
        max_len = min(max_length, self.model.config.max_position_embeddings)

        while cur_len < max_len:
            with torch.no_grad():
                next_token_logits = self.model(input_ids).logits[:, -1, :]

            next_token_logits = self.apply_temperature(next_token_logits, temperature)
            next_token_logits = self._apply_top_k(next_token_logits, top_k)
            next_token_logits = self._apply_top_p(next_token_logits, top_p)

            next_token = pick_token(next_token_logits.squeeze(0))

            next_token_tensor = torch.tensor([[next_token]], device=self.device)
            input_ids = torch.cat([input_ids, next_token_tensor], dim=-1)
            cur_len += 1

            if next_token == eos_id:
                break

        return self.tokenizer.decode(input_ids[0], skip_special_tokens=True)

In [7]:
# Load the pretrained checkpoint once; later cells can reuse `model` and `prompt`.
model = Model(model_name="gpt2")

prompt = "Two plus five equal"

# Announce the strategy first, then decode and show the continuation.
print("Greedy:")
greedy_text = model.generate(prompt=prompt, strategy="greedy")
print(greedy_text)

Greedy:
Two plus five equal parts of the same number of points.

The game is played in a round-robin format, with each player playing a single round. The first player to reach the top of the round wins. The second player to
