In [44]:
import os
import re
import warnings
import random
from collections import defaultdict
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import numpy as np
import numpy as np
import torch
from tqdm.notebook import tqdm
from transformers import GPT2LMHeadModel, GPT2Tokenizer

warnings.filterwarnings("ignore")

In [45]:
def seed_everything(seed: int):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

seed_everything(42)

## Задание

1) Реализовать методы `greedy_sampling` и `generate` (1 балл)
2) Реализовать метод `random_sampling` и поддержать его в `generate` (1 балл)
3) Реализовать метод `_beam_search_generate` и поддержать его в `generate` (2 балла)
4) Реализовать методы `apply_top_p`, `apply_top_k`, `apply_temperature` и поддержать их в `generate` (1 балл)  
Все методы необходимо реализовать через векторные операции в torch/numpy везде где это возможно

### Задание 1. Реализация Greedy_sampling и generate

Greedy sampling – всегда выбираем токен с максимальным логитом.

In [46]:
class Model:
    def __init__(self, model_name: str = "gpt2"):
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        if logits.dim() == 2:  # (batch_sz, vocab_sz)
            token_id = torch.argmax(logits, dim=-1)[0]
        elif logits.dim() == 1:  # (vocab_sz)
            token_id = torch.argmax(logits)
        return int(token_id.item())

    def generate(
        self,
        prompt: str,
        max_length: int = 50,  # количество новых токенов
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3,
    ) -> str:
        self.model.eval()

        input_promt_tok = self.tokenizer.encode(prompt, return_tensors="pt").to(
            self.device
        )

        generated = input_promt_tok

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids=generated)

                next_token_logits = outputs.logits[:, -1, :]
                next_token_id = self.greedy_sampling(next_token_logits)

                next_token_tensor = torch.tensor(
                    [[next_token_id]], device=self.device, dtype=generated.dtype
                )
                generated = torch.cat([generated, next_token_tensor], dim=1)

                # ранняя остановка по EOS
                if next_token_id == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [47]:
model = Model("gpt2")
print(model.generate("Machine learning is", max_length=50, strategy="greedy"))

Machine learning is a very powerful tool for learning about the world around us. It's a tool that can help us understand the world around us, and it's a tool that can help us understand the world around us.

The world is changing. We're


### Задание 2. Реализация метода random sampling и поддержка его в generate

Random sampling – случайный выбор токена по softmax распределению.

In [48]:
class Model:
    def __init__(self, model_name: str = "gpt2"):
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        if logits.dim() == 2:  # (batch_sz, vocab_sz)
            token_id = torch.argmax(logits, dim=-1)[0]
        elif logits.dim() == 1:  # (vocab_sz)
            token_id = torch.argmax(logits)
        return int(token_id.item())

    def random_sampling(self, logits: torch.Tensor) -> int:

        if logits.dim() == 2:  # (batch_sz,vocab_sz)
            logits = logits[0]

        probs = F.softmax(logits, dim=-1)  # (vocab_sz)
        token_id = torch.multinomial(probs, num_samples=1)

        return int(token_id.item())

    def generate(
        self,
        prompt: str,
        max_length: int = 50,  # количество новых токенов
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3,
    ) -> str:
        self.model.eval()

        input_promt_tok = self.tokenizer.encode(prompt, return_tensors="pt").to(
            self.device
        )

        generated = input_promt_tok

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids=generated)

                next_token_logits = outputs.logits[:, -1, :]

                if strategy == "greedy":
                    next_token_id = self.greedy_sampling(next_token_logits)
                elif strategy == "random":
                    next_token_id = self.random_sampling(next_token_logits)

                next_token_tensor = torch.tensor(
                    [[next_token_id]], device=self.device, dtype=generated.dtype
                )
                generated = torch.cat([generated, next_token_tensor], dim=1)

                if next_token_id == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [49]:
model = Model("gpt2")

print("GREEDY:")
print(model.generate("Machine learning is", max_length=30, strategy="greedy"))

print("\nRANDOM:")
print(model.generate("Machine learning is", max_length=30, strategy="random"))

print("\nGREEDY:")
print(model.generate("Machine learning is", max_length=30, strategy="greedy"))

print("\nRANDOM:")
print(model.generate("Machine learning is", max_length=30, strategy="random"))

GREEDY:
Machine learning is a very powerful tool for learning about the world around us. It's a tool that can help us understand the world around us, and it's a

RANDOM:
Machine learning is solid and intelligent work, and in this more egalitarian Zurich, the project will once again bring weirder roles for cognitive design to play inside ACM

GREEDY:
Machine learning is a very powerful tool for learning about the world around us. It's a tool that can help us understand the world around us, and it's a

RANDOM:
Machine learning is an American invention in several respects. It involves exploration of datasets for global North American smog recorders and consumers can have accuracy were it to adjust their


### Задание 3. Реализация метода beam_search_generate и поддержка его в generate

Beam search – на каждом шаге:
1. Для варианта считаем логиты следующего токена;
2. Идем по всем веткам на фиксир. глубину и считаем вероятность по ветке;
3. Выбираем лучшую ветку.

In [50]:
class Model:
    def __init__(self, model_name: str = "gpt2"):
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        if logits.dim() == 2: # (batch_sz, vocab_sz)
            token_id = torch.argmax(logits, dim=-1)[0]
        elif logits.dim() == 1: # (vocab_sz)
            token_id = torch.argmax(logits)
        return int(token_id.item())

    def random_sampling(self, logits: torch.Tensor) -> int:
        
        if logits.dim() == 2: #(batch_sz,vocab_sz)
            logits = logits[0]
        
        probs = F.softmax(logits, dim=-1) #(vocab_sz)
        token_id = torch.multinomial(probs, num_samples=1)

        return int(token_id.item())

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int # количество гипотез на шаге
    ) -> str:
        self.model.eval()
        
        input_promt_tok = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        eos_id = self.tokenizer.eos_token_id

        beams = [(input_promt_tok, 0.0)]
        
        with torch.no_grad():
            for _ in range(max_length):
  
                all_input_batch = torch.cat([seq for (seq, _) in beams], dim=0)  # (B, seq_len)

                outputs = self.model(input_ids=all_input_batch)
          
                next_logits = outputs.logits[:, -1, :] # логит последнего токена

                log_probs = F.log_softmax(next_logits, dim=-1)  # (B, vocab)
                vocab_size = log_probs.size(-1)

                beam_scores = torch.tensor([score for (_, score) in beams],
device=self.device)

                total_scores = beam_scores.unsqueeze(1) + log_probs  # (B, vocab)

                total_scores_row = total_scores.view(-1)  # (B * vocab)
                topk = min(num_beams, total_scores_row.size(0))
                top_scores, top_indices = torch.topk(total_scores_row, k=topk)

                new_beams = []
                for score, idx in zip(top_scores, top_indices):
  
                    beam_idx = (idx // vocab_size).item()
                    token_id = (idx % vocab_size).item()

                    seq, _ = beams[beam_idx]

                    token_tensor = torch.tensor([[token_id]], device=self.device, dtype=seq.dtype)

                    new_seq = torch.cat([seq, token_tensor], dim=1)

                    new_beams.append((new_seq, float(score.item())))

                beams = new_beams

                if all(seq[0, -1].item() == eos_id for (seq, _) in beams):
                    break
                
        best_seq, _ = max(beams, key=lambda x: x[1])
        return self.tokenizer.decode(best_seq[0], skip_special_tokens=True)

    def generate(
        self,
        prompt: str,
        max_length: int = 50, # количество новых токенов
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        self.model.eval()

        if strategy == "beam":
            return self._beam_search_generate(prompt=prompt,max_length=max_length, num_beams=num_beams)

        input_promt_tok= self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        generated = input_promt_tok

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids=generated)
            
                next_token_logits = outputs.logits[:, -1, :]

                if strategy == "greedy":
                    next_token_id = self.greedy_sampling(next_token_logits)
                elif strategy == "random":
                    next_token_id = self.random_sampling(next_token_logits)

                next_token_tensor = torch.tensor(
                    [[next_token_id]],
                    device=self.device,
                    dtype=generated.dtype
                )
                generated = torch.cat([generated, next_token_tensor], dim=1)

                if next_token_id == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [51]:
model = Model("gpt2")

print("BEAM (num_beams=3):")
print(model.generate("Machine learning is", max_length=30, strategy="beam", num_beams=3))

print("\nBEAM (num_beams=6):")
print(model.generate("Machine learning is", max_length=30, strategy="beam", num_beams=5))

print("\nBEAM (num_beams=9):")
print(model.generate("Machine learning is", max_length=30, strategy="beam", num_beams=5))

BEAM (num_beams=3):
Machine learning is a great way to learn about the world around you. It's a great way to learn about the world around you. It's a great way to

BEAM (num_beams=6):
Machine learning is one of the most promising areas of research in the field of artificial intelligence.

In a paper published in the Proceedings of the National Academy of Sciences

BEAM (num_beams=9):
Machine learning is one of the most promising areas of research in the field of artificial intelligence.

In a paper published in the Proceedings of the National Academy of Sciences


### Задание 4. Реализация методов apply_top_p, apply_top_k, apply_temperature и поддержка их в generate

1. Apply_temperature – масштабирование распределения логитов;
2. Apply_top_k – оставляем только top_k токенов с максимальными логитами;
3. Apply_top_p – берем токены, отсортирвоанные по убыванию вер., суммируем из вероятности, пока сумма не превысит порог. 

In [52]:
class Model:
    def __init__(self, model_name: str = "gpt2"):
        self.model = GPT2LMHeadModel.from_pretrained(model_name)
        self.tokenizer = GPT2Tokenizer.from_pretrained(model_name)
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.vocab_size = self.tokenizer.vocab_size

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    def greedy_sampling(self, logits: torch.Tensor) -> int:
        if logits.dim() == 2: # (batch_sz, vocab_sz)
            token_id = torch.argmax(logits, dim=-1)[0]
        elif logits.dim() == 1: # (vocab_sz)
            token_id = torch.argmax(logits)
        return int(token_id.item())

    def random_sampling(self, logits: torch.Tensor) -> int:
        
        if logits.dim() == 2: #(batch_sz,vocab_sz)
            logits = logits[0]
        
        probs = F.softmax(logits, dim=-1) #(vocab_sz)
        token_id = torch.multinomial(probs, num_samples=1)

        return int(token_id.item())

    def _beam_search_generate(
        self,
        prompt: str,
        max_length: int,
        num_beams: int # количество гипотез на шаге
    ) -> str:
        self.model.eval()
        
        input_promt_tok = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        eos_id = self.tokenizer.eos_token_id

        beams = [(input_promt_tok, 0.0)]
        
        with torch.no_grad():
            for _ in range(max_length):
  
                all_input_batch = torch.cat([seq for (seq, _) in beams], dim=0)  # (B, seq_len)

                outputs = self.model(input_ids=all_input_batch)
          
                next_logits = outputs.logits[:, -1, :] # логит последнего токена

                log_probs = F.log_softmax(next_logits, dim=-1)  # (B, vocab)
                vocab_size = log_probs.size(-1)

                beam_scores = torch.tensor([score for (_, score) in beams],
device=self.device)

                total_scores = beam_scores.unsqueeze(1) + log_probs  # (B, vocab)

                total_scores_row = total_scores.view(-1)  # (B * vocab)
                topk = min(num_beams, total_scores_row.size(0))
                top_scores, top_indices = torch.topk(total_scores_row, k=topk)

                new_beams = []
                for score, idx in zip(top_scores, top_indices):
  
                    beam_idx = (idx // vocab_size).item()
                    token_id = (idx % vocab_size).item()

                    seq, _ = beams[beam_idx]

                    token_tensor = torch.tensor([[token_id]], device=self.device, dtype=seq.dtype)

                    new_seq = torch.cat([seq, token_tensor], dim=1)

                    new_beams.append((new_seq, float(score.item())))

                beams = new_beams

                if all(seq[0, -1].item() == eos_id for (seq, _) in beams):
                    break
                
        best_seq, _ = max(beams, key=lambda x: x[1])
        return self.tokenizer.decode(best_seq[0], skip_special_tokens=True)

    def apply_temperature(self, logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        return logits / temperature

    def _apply_top_p(self, logits: torch.Tensor, top_p: float = 1.0) -> torch.Tensor:

        if top_p >= 1.0:
          return logits
        
        single = False
        if logits.dim() == 1:
            logits_bv = logits.unsqueeze(0)
            single = True
        elif logits.dim() == 2:
            logits_bv = logits

        ver = F.softmax(logits_bv, dim=-1)
        sorted_ver, sorted_idx = ver.sort(dim=-1, descending=True)

        cumsum = sorted_ver.cumsum(dim=-1)

        keep_len = (cumsum < top_p).sum(dim=-1) + 1 # +1 — первый токен за top_p

        V = ver.size(-1)
        arange = torch.arange(V, device=ver.device).unsqueeze(0).expand_as(sorted_ver)
        
        keep_sorted_mask = arange < keep_len.unsqueeze(1)

        keep_mask = torch.zeros_like(ver, dtype=torch.bool)
        keep_mask.scatter_(1, sorted_idx, keep_sorted_mask)

        neg_inf = logits_bv.new_full((), float("-inf"))
        filtered = torch.where(keep_mask, logits_bv, neg_inf)

        return filtered.squeeze(0) if single else filtered

    def _apply_top_k(self, logits: torch.Tensor, top_k: int) -> torch.Tensor:

        if top_k <= 0:
            return logits

        filtered_logits = logits.clone()

        if filtered_logits.dim() == 1:
            vocab_size = filtered_logits.size(0)
            if top_k >= vocab_size:
                return filtered_logits

            topk_values, _ = torch.topk(filtered_logits, top_k)
      
            threshold = topk_values[-1]
    
            filtered_logits[filtered_logits < threshold] = float("-inf")

        elif filtered_logits.dim() == 2:
            vocab_size = filtered_logits.size(1)
            if top_k >= vocab_size:
                return filtered_logits

            topk_values, _ = torch.topk(filtered_logits, top_k, dim=-1)
            thresholds = topk_values[:, -1].unsqueeze(-1)
            mask = filtered_logits < thresholds
            filtered_logits[mask] = float("-inf")

        return filtered_logits

    def generate(
        self,
        prompt: str,
        max_length: int = 50, # количество новых токенов
        strategy: str = "greedy",
        temperature: float = 1.0,
        top_k: int = 0,
        top_p: float = 1.0,
        num_beams: int = 3
    ) -> str:
        self.model.eval()

        if strategy == "beam":
            return self._beam_search_generate(prompt=prompt,max_length=max_length, num_beams=num_beams)

        input_promt_tok= self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)

        generated = input_promt_tok

        with torch.no_grad():
            for _ in range(max_length):
                outputs = self.model(input_ids=generated)
            
                next_token_logits = outputs.logits[:, -1, :]

                if strategy == "greedy":
                    next_token_id = self.greedy_sampling(next_token_logits)
                elif strategy == "random":
                    
                    modified_logits = next_token_logits
                    
                    if temperature != 1.0:
                        modified_logits = self._apply_temperature(modified_logits, temperature)

                    if top_k > 0:
                        modified_logits = self._apply_top_k(modified_logits, top_k)

                    if top_p < 1.0:
                        modified_logits = self._apply_top_p(modified_logits, top_p)

                    next_token_id = self.random_sampling(modified_logits)

                next_token_tensor = torch.tensor(
                    [[next_token_id]],
                    device=self.device,
                    dtype=generated.dtype
                )
                generated = torch.cat([generated, next_token_tensor], dim=1)

                if next_token_id == self.tokenizer.eos_token_id:
                    break

        return self.tokenizer.decode(generated[0], skip_special_tokens=True)

In [53]:

print("RANDOM, default:")
print(model.generate("Machine learning is", max_length=30, strategy="random"))

print("\nRANDOM + temperature=1.5:")
print(model.generate("Machine learning is", max_length=30, strategy="random", temperature=1.5))

print("\nRANDOM + top_k=10:")
print(model.generate("Machine learning is", max_length=30, strategy="random", top_k=10))

print("\nRANDOM + top_p=0.9:")
print(model.generate("Machine learning is", max_length=30, strategy="random", top_p=0.9))

RANDOM, default:
Machine learning is becoming increasingly sophisticated with each passing year.

But does the technology give us confidence that simple, timeless algorithms can outperform demanding systems with significant UI

RANDOM + temperature=1.5:
Machine learning is taking the web further than ever before and Web applications are making real progress. In addition to recognition, researchers around the world are creating strong technology platforms for

RANDOM + top_k=10:
Machine learning is excellent science and science fiction narrative, but they miss the dark side. When it is a slush fund, female fans typically get what they came for

RANDOM + top_p=0.9:
Machine learning is a powerful approach when applied to cognitive plasticity in users. Follow along below for our sensuous treatise on this theme of learning via neural integration.
