In [1]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [2]:
def seed_everything(seed: int) -> None:
    """Seed every RNG the notebook uses so that sampling runs are reproducible.

    Seeds Python's `random`, NumPy's global RNG and torch (CPU and all CUDA
    devices), records PYTHONHASHSEED, and configures cuDNN for deterministic
    kernel selection.

    Note: PYTHONHASHSEED only affects interpreters started after this call;
    hash randomisation of the already-running kernel is not changed.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU explicitly (torch.cuda.manual_seed only seeds the
    # current one); silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick kernels non-deterministically,
    # which would undo `deterministic=True` above.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

## Задание

1) Реализовать методы `greedy_sampling` и `generate` (1 балл)
2) Реализовать метод `random_sampling` и поддержать его в `generate` (1 балл)
3) Реализовать метод `_beam_search_generate` и поддержать его в `generate` (2 балла)
4) Реализовать методы `apply_top_p`, `apply_top_k`, `apply_temperature` и поддержать их в `generate` (1 балл)  
Все методы необходимо реализовать через векторные операции в torch/numpy везде, где это возможно.

In [8]:
class Model:
    """GPT-2 language model wrapper with hand-written decoding strategies.

    `generate` supports greedy decoding, ancestral sampling with temperature /
    top-k / top-p filtering, and beam search.
    """

    def __init__(self, model_name: str = "gpt2"):
        """Load the pretrained model and tokenizer named `model_name`.

        GPT-2 has no pad token, so the EOS token is reused as padding.
        """
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        """Return the id of the highest-scoring token.

        Args:
            logits: next-token logits for a single position, e.g. shape
                (vocab,) or (1, vocab); `int()` requires a single argmax.

        Returns:
            The argmax token id as a Python int.
        """
        return int(logits.argmax(dim=-1))

    def random_sampling(self, logits: torch.Tensor) -> int:
        """Sample one token id from softmax(logits).

        Tokens whose logit is -inf get zero probability, so this composes with
        `_apply_top_k` / `_apply_top_p`.

        Args:
            logits: next-token logits for a single position, e.g. shape
                (vocab,) or (1, vocab).

        Returns:
            The sampled token id as a Python int.
        """
        probs = logits.softmax(dim=-1)
        return int(probs.multinomial(num_samples=1))

    def _beam_search_generate(self, prompt: str, max_length: int, num_beams: int) -> str:
        """Beam-search decoding; returns the highest-scoring beam as text.

        A beam's score is the sum of its generated tokens' log-probabilities
        (no length normalisation). A beam that emits EOS is frozen: it can only
        be extended with EOS at zero cost, so its score stays comparable with
        running beams and it does not spawn duplicate hypotheses. Decoding stops
        at `max_length` total tokens (prompt included) or when every beam has
        finished.

        Raises:
            ValueError: if `num_beams < 1`.
        """
        if num_beams < 1:
            raise ValueError(f"num_beams must be >= 1, got {num_beams}")

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        eos_id = self.tokenizer.eos_token_id
        device = input_ids.device

        beam_sequences = input_ids.repeat(num_beams, 1)
        # All beams start as copies of the prompt. If they all scored 0, topk
        # would pick the same token from every copy and the beams would collapse
        # into greedy decoding, so only beam 0 starts out "live".
        beam_scores = torch.full((num_beams,), float("-inf"), device=device)
        beam_scores[0] = 0.0
        finished = torch.zeros(num_beams, dtype=torch.bool, device=device)

        for _ in range(max_length - input_ids.size(1)):
            with torch.no_grad():
                next_token_logits = self.model(beam_sequences).logits[:, -1, :]

            log_probs = F.log_softmax(next_token_logits, dim=-1)
            if finished.any():
                # Finished beams may only append EOS, at zero cost.
                eos_only = torch.full_like(log_probs[0], float("-inf"))
                eos_only[eos_id] = 0.0
                log_probs = torch.where(finished.unsqueeze(1), eos_only, log_probs)

            vocab_size = log_probs.size(-1)
            scores = beam_scores.unsqueeze(1) + log_probs  # (num_beams, vocab)
            top_scores, top_indices = torch.topk(scores.view(-1), num_beams)

            # Flat index -> (source beam, appended token).
            beam_indices = torch.div(top_indices, vocab_size, rounding_mode="floor")
            token_indices = top_indices % vocab_size

            beam_sequences = torch.cat(
                [beam_sequences[beam_indices], token_indices.unsqueeze(1)], dim=1
            )
            beam_scores = top_scores
            finished = finished[beam_indices] | (token_indices == eos_id)

            if finished.all():
                break

        best_idx = int(torch.argmax(beam_scores))
        return self.tokenizer.decode(beam_sequences[best_idx], skip_special_tokens=True)

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """Return `logits / temperature` (a new tensor).

        Values < 1 sharpen the distribution and values > 1 flatten it.

        Raises:
            ValueError: if `temperature <= 0` (division would yield inf/NaN).
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        return logits / temperature

    def _apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:
        """Nucleus filtering: keep the smallest set of most-likely tokens whose
        cumulative probability exceeds `top_p`; set every other logit to -inf.

        Works on any leading batch shape (filtering is along the last dim). The
        most likely token is always kept. `top_p >= 1.0` disables the filter.
        Returns a new tensor; the input is not modified.
        """
        if top_p >= 1.0:
            return logits

        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        sorted_to_remove = cumulative_probs > top_p
        # Shift right so the token that crosses the threshold is still kept.
        sorted_to_remove[..., 1:] = sorted_to_remove[..., :-1].clone()
        sorted_to_remove[..., 0] = False

        # Map the mask from sorted order back to vocabulary order, per row.
        to_remove = sorted_to_remove.scatter(-1, sorted_indices, sorted_to_remove)
        return logits.masked_fill(to_remove, float("-inf"))

    def _apply_top_k(self, logits: torch.Tensor, top_k: int = 0) -> torch.Tensor:
        """Keep the `top_k` highest logits along the last dim; set the rest to -inf.

        `top_k <= 0` disables the filter. Ties with the k-th value are kept.
        Returns a new tensor; the input is not modified.
        """
        if top_k <= 0:
            return logits
        top_k = min(top_k, logits.size(-1))
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        return logits.masked_fill(logits < kth_largest, float("-inf"))

    def _filter_logits(
        self, logits: torch.Tensor, temperature: float, top_k: int, top_p: float
    ) -> torch.Tensor:
        """Apply temperature, then top-k, then top-p to next-token logits."""
        scaled = self.apply_temperature(logits, temperature)
        return self._apply_top_p(self._apply_top_k(scaled, top_k), top_p)

    def generate(
        self,
        prompt: str,
        max_length: int = 50,
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        """Generate a continuation of `prompt`.

        Args:
            prompt: text to continue.
            max_length: total length in tokens (prompt included) at which to stop.
            strategy: one of "greedy", "random", "top_k", "top_p", "beam_search".
            temperature: softmax temperature for sampling strategies (> 0).
            top_k: keep only the k most likely tokens when sampling; 0 disables it.
            top_p: nucleus threshold when sampling; 1.0 disables it.
            num_beams: beam width for "beam_search".

        The sampling strategies ("random", "top_k", "top_p") all apply
        temperature, then top-k, then top-p. With the default `top_k=0` and
        `top_p=1.0` the unused filters are no-ops. Greedy decoding ignores
        temperature and the filters. Generation also stops after EOS.

        Returns:
            Decoded prompt + continuation with special tokens removed.

        Raises:
            ValueError: unknown `strategy`, or non-positive temperature when sampling.
        """
        if strategy == "beam_search":
            return self._beam_search_generate(prompt, max_length, num_beams)
        if strategy not in ("greedy", "random", "top_k", "top_p"):
            raise ValueError(f"Unknown strategy: {strategy}")

        input_ids = self.tokenizer.encode(prompt, return_tensors="pt")
        generated = input_ids.clone()
        eos_id = self.tokenizer.eos_token_id

        for _ in range(max_length - input_ids.size(1)):
            with torch.no_grad():
                next_token_logits = self.model(generated).logits[:, -1, :]

            if strategy == "greedy":
                next_token = self.greedy_sampling(next_token_logits)
            else:
                filtered = self._filter_logits(next_token_logits, temperature, top_k, top_p)
                next_token = self.random_sampling(filtered)

            next_column = torch.tensor(
                [[next_token]], dtype=generated.dtype, device=generated.device
            )
            generated = torch.cat([generated, next_column], dim=1)

            if next_token == eos_id:
                break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [12]:

model = Model("gpt2")
prompt = "A robot walks into a bar and says:"

print(f"'{prompt}'\n")

# Compare the three decoding strategies at a fixed length of 20 tokens.
# The generator is consumed left to right, so the RNG stream is used in the
# same order: greedy, then random, then beam search.
result1, result2, result3 = (
    model.generate(prompt, max_length=20, **options)
    for options in (
        {"strategy": "greedy"},
        {"strategy": "random", "temperature": 0.8},
        {"strategy": "beam_search", "num_beams": 3},
    )
)
for label, text in zip(("greedy", "random", "beam"), (result1, result2, result3)):
    print(f"{label}: {text}")

# Effect of temperature on plain random sampling.
print("\n temperature:")
for temp in (0.3, 1.0, 1.5):
    result = model.generate(prompt, max_length=15, strategy="random", temperature=temp)
    print(f"temperature {temp}: {result}")

'A robot walks into a bar and says:'

greedy: A robot walks into a bar and says: "I'm going to buy you a beer."

random: A robot walks into a bar and says: "Now, we're going to go fight? You
beam: A robot walks into a bar and says: "I'm going to buy you a beer."


 temperature:
temperature 0.3: A robot walks into a bar and says: 'I'm going to be
temperature 1.0: A robot walks into a bar and says: "%s>I want you
temperature 1.5: A robot walks into a bar and says: ".Che Seventh inning." Saturdays
