# Homework week08

In [None]:
import torch

import triton
import triton.language as tl

from transformers import AutoTokenizer, AutoModelForCausalLM

## Task 1 [Speculative Sampling][4 points]

Speculative sampling is an algorithm for accelerating transformer decoding by generating multiple tokens from each call to the target model. It relies on the observation that scoring a short continuation in parallel, where the continuation was generated by a faster but less powerful draft model, has latency comparable to sampling a single token from the larger target model.


Carefully read https://arxiv.org/abs/2302.01318

### Autoregressive Sampling

<img width="654" alt="image" src="https://github.com/markovka17/dla/assets/20357655/db624e40-d4f0-4e36-88e7-b58a6c646738">

Let's take `EleutherAI/gpt-neo-1.3B` LM as a draft model from https://huggingface.co and generate a couple dozen tokens.

In [None]:
# Ensure that your device is set correctly (GPU or CPU)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
model = AutoModelForCausalLM.from_pretrained('EleutherAI/gpt-neo-1.3B', torch_dtype=torch.float16).to(device)
model.eval()  # Set the model to evaluation mode

# Prepare a text prompt
text_prompt = ["The quick brown fox jumps"]
inputs = tokenizer(text_prompt, return_tensors='pt').to(device)  # Tokenize the text prompt and convert to tensor

# Perform text generation (inference)
# Note: manual handling means we will manage the generated text and stop criteria without using generate() method
max_length = 100  # Maximum length of the generated text
temperature = 1.0  # Sampling temperature, higher values mean more randomness

with torch.no_grad():  # Disable gradient calculation for inference
    output_sequence = inputs['input_ids']
    for _ in range(max_length - inputs['input_ids'].size(1)):
        # Predict the next token
        logits = model(output_sequence).logits[:, -1, :]
        
        # Apply temperature
        logits = logits / temperature
        
        # Sample the next token from the probability distribution
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        next_token = torch.multinomial(probabilities, num_samples=1)
        
        # Append the predicted token to the output sequence
        output_sequence = torch.cat([output_sequence, next_token], dim=1)

        # Check if the end-of-sequence token (EOS) was generated
        if next_token.item() == tokenizer.eos_token_id:
            break

# Decode and print the generated text
generated_text = tokenizer.decode(output_sequence.squeeze(), skip_special_tokens=True)
print("Generated text:\n", generated_text)


### Task 1.1

Now let's use `EleutherAI/gpt-j-6B` as the target model.

In [None]:
large_model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", torch_dtype=torch.float16).to(device)

In speculative sampling, we have two models:

- A smaller, faster draft model (e.g. the `EleutherAI/gpt-neo-1.3B` model)
- A larger, slower target model (e.g. the `EleutherAI/gpt-j-6B` model)

The idea is that the draft model speculates what the output will be K steps into the future, while the target model determines how many of those tokens we should accept. Here's an outline of the algorithm:

1. The draft model decodes K tokens in the regular autoregressive fashion.
2. We get the probability outputs of the target and draft models on the newly predicted sequence.
3. We compare the target and draft model probabilities to determine how many of the tokens to keep, based on a rejection criterion. If a token is rejected, we resample it using a combination of the two distributions and don't accept any more tokens.
4. If all tokens are accepted, we sample one additional final token from the target model's probability output.
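
The rejection criterion in the outline can be sketched for a single drafted token as follows. This is a minimal pure-Python illustration, assuming the paper's rule (accept with probability `min(1, q/p)`, otherwise resample from the normalised `max(0, q - p)`), with `p` the draft distribution and `q` the target distribution. The function name `accept_or_resample` and the list-based distributions are hypothetical, chosen for illustration; a real implementation would work on tensors.

```python
import random


def accept_or_resample(token, p_draft, q_target, rng=random):
    """Return (token_to_emit, accepted) for one drafted token.

    p_draft and q_target are probability lists over the same vocabulary.
    `token` is assumed to have been sampled from p_draft, so p_draft[token] > 0.
    """
    # Accept the drafted token with probability min(1, q / p).
    ratio = q_target[token] / p_draft[token]
    if rng.random() < min(1.0, ratio):
        return token, True

    # On rejection, resample from the normalised residual max(0, q - p).
    residual = [max(0.0, q - p) for q, p in zip(q_target, p_draft)]
    total = sum(residual)
    weights = [r / total for r in residual]
    new_token = rng.choices(range(len(weights)), weights=weights, k=1)[0]
    return new_token, False
```

Once a token is rejected, the remaining drafted tokens are discarded, so the loop over drafted tokens should stop at that point.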

<img width="635" alt="image" src="https://github.com/markovka17/dla/assets/20357655/3954894d-8735-4f92-a835-d04eac74f190">

In [None]:
# The backbone of speculative sampling. Feel free to modify it

def speculative_sampling(x, draft_model, target_model, N, K):
    # NOTE: paper indexes arrays starting from 1, python indexes from 0, so
    # we have to add an extra -1 term when indexing using n, T, or t
    n = len(x)

    for _ in range(N):
        # Step 1: auto-regressive decode K tokens from draft model and get final p
        x_draft = x
        for _ in range(K):
            pass
            # TODO

        # Step 2: target model forward passes on x_draft
        # TODO

        # Step 3: append draft tokens based on rejection criterion and resample
        # a token on rejection
        all_accepted = True
        for _ in range(K):
            pass
            # TODO

        # Step 4: if all draft tokens were accepted, sample a final token
        if all_accepted:
            pass
            # TODO
            

        # just keeping my sanity
        assert n == len(x), f"{n} {len(x)}"

    return x

### Task 1.2

Compare the speed of speculative sampling (SpS) with autoregressive sampling (ArS). The expected speed increase is 30-50%.
The speedup is equal to `(time spent by ArS)` / `(time spent by SpS)`.

Use the same start prompt `The quick brown fox jumps`, `K=16` and `K=32` (compare the two scenarios), and `max_length=512`.
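
One way to measure the two timings is sketched below. The helper names `time_call` and `speedup` are hypothetical, and the optional `sync` argument is an assumption: on a GPU you would typically pass `torch.cuda.synchronize` so that queued kernels are included in the measurement.

```python
import time


def time_call(fn, *args, repeats=3, sync=None, **kwargs):
    """Return the best wall-clock time in seconds over `repeats` calls to fn."""
    best = float("inf")
    for _ in range(repeats):
        if sync is not None:
            sync()  # wait for previously queued GPU work before starting the clock
        start = time.perf_counter()
        fn(*args, **kwargs)
        if sync is not None:
            sync()  # wait for the work launched by fn to finish
        best = min(best, time.perf_counter() - start)
    return best


def speedup(time_ars, time_sps):
    """Speedup as defined in the task: ArS time divided by SpS time."""
    return time_ars / time_sps
```

With this definition, a 30-50% speed increase corresponds to a speedup between 1.3 and 1.5.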

### Task 1.3

Visualise the acceptance rate for `K=[16, 32, 64, 128]`, with the same start prompt and `max_length=1024`, where the draft model is `EleutherAI/gpt-neo-1.3B` and the target model is `EleutherAI/gpt-j-6B`.
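
The task does not define the acceptance rate precisely. The sketch below assumes it is the fraction of drafted tokens that were accepted, computed from a per-iteration count of accepted tokens that you would record inside your sampling loop. The function name and its input format are hypothetical.

```python
def acceptance_rate(accepted_per_step, K):
    """Fraction of drafted tokens accepted.

    accepted_per_step holds, for every SpS iteration, how many of the K
    drafted tokens were accepted (a value between 0 and K).
    """
    if not accepted_per_step:
        return 0.0
    drafted = K * len(accepted_per_step)
    return sum(accepted_per_step) / drafted
```

Computing this value once per `K` gives the points for a simple plot of acceptance rate against `K`.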

## Task 2 [GroupNorm in Triton][6 points]


You need to implement a 2D GroupNorm (https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html) and compare it to the PyTorch implementation.
Note that GroupNorm is very similar to LayerNorm, so you can use the Triton LayerNorm tutorial as a reference: https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html#.
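
To make the relation to LayerNorm concrete, here is a small pure-Python reference for a single sample. It is only a sketch for checking your understanding on tiny inputs, not a substitute for the Triton kernel. The function name `group_norm_ref` and the data layout (a list of channels, each a flat list of `H*W` values) are assumptions made for illustration.

```python
import math


def group_norm_ref(x, num_groups, weight, bias, eps=1e-5):
    """GroupNorm for one sample; x is a list of C channels, each a flat list of values."""
    C = len(x)
    assert C % num_groups == 0, "channels must be divisible by num_groups"
    per_group = C // num_groups
    out = []
    for g in range(num_groups):
        channels = range(g * per_group, (g + 1) * per_group)
        # Statistics are shared by every value in the group (all its channels and pixels).
        values = [v for c in channels for v in x[c]]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        rstd = 1.0 / math.sqrt(var + eps)
        for c in channels:
            # The affine parameters are per channel, not per group.
            out.append([(v - mean) * rstd * weight[c] + bias[c] for v in x[c]])
    return out
```

Each (sample, group) pair behaves like one LayerNorm row of length `(C / num_groups) * H * W`; the difference is that the weight and bias are indexed by channel rather than by position in the row.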

In [None]:
@triton.jit
def _group_norm_fwd_fused(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    num_groups,  # number of groups
    eps,  # epsilon to avoid division by zero
    BLOCK_SIZE: tl.constexpr,  # Same parameters as in matmul from seminar
):
    """
    Forward pass, analogous to nn.GroupNorm.forward
    """
    pass

In [None]:
@triton.jit
def _group_norm_bwd_dx_fused(
    DX,  # pointer to the input gradient
    DY,  # pointer to the output gradient
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    X,  # pointer to the input
    W,  # pointer to the weights
    B,  # pointer to the biases
    Mean,  # pointer to the mean
    Rstd,  # pointer to the 1/std
    stride,  # how much to increase the pointer when moving by 1 row
    N,  # number of columns in X
    eps,  # epsilon to avoid division by zero
    GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr  # Same parameters as in matmul from seminar
):
    """
    Backward of GroupNorm with respect to the input
    """
    pass

In [None]:
@triton.jit
def _group_norm_bwd_dwdb(
    DW,  # pointer to the partial sum of weights gradient
    DB,  # pointer to the partial sum of biases gradient
    FINAL_DW,  # pointer to the weights gradient
    FINAL_DB,  # pointer to the biases gradient
    M,  # GROUP_SIZE_M
    N,  # number of columns
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr  # Same parameters as in matmul from seminar
):
    """
    Backward of GroupNorm with respect to the weights and biases (affine transform parameters)
    """
    pass
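
Before writing the backward kernels, it can help to check the input-gradient formula on a single group in plain Python. The sketch below assumes the same formula as the LayerNorm case, applied to one normalisation group, and compares it with central finite differences. All function names are hypothetical, and the weight is given per element for simplicity (in GroupNorm, elements of the same channel share one weight).

```python
import math


def norm_forward(x, w, b, eps=1e-5):
    """Normalise one group x and apply an element-wise affine transform."""
    n = len(x)
    mean = sum(x) / n
    var = sum((v - mean) ** 2 for v in x) / n
    rstd = 1.0 / math.sqrt(var + eps)
    xhat = [(v - mean) * rstd for v in x]
    y = [xh * wi + bi for xh, wi, bi in zip(xhat, w, b)]
    return y, xhat, rstd


def norm_backward_dx(dy, xhat, w, rstd):
    """Analytic gradient with respect to x for one group."""
    n = len(dy)
    wdy = [wi * g for wi, g in zip(w, dy)]
    c1 = sum(xh * v for xh, v in zip(xhat, wdy)) / n  # mean(xhat * w * dy)
    c2 = sum(wdy) / n  # mean(w * dy)
    return [(v - (xh * c1 + c2)) * rstd for v, xh in zip(wdy, xhat)]


def numeric_dx(x, w, b, dy, h=1e-6):
    """Central finite differences of L = sum(dy * y) with respect to x."""
    grads = []
    for i in range(len(x)):
        xp = list(x)
        xp[i] += h
        xm = list(x)
        xm[i] -= h
        loss_plus = sum(g * v for g, v in zip(dy, norm_forward(xp, w, b)[0]))
        loss_minus = sum(g * v for g, v in zip(dy, norm_forward(xm, w, b)[0]))
        grads.append((loss_plus - loss_minus) / (2 * h))
    return grads
```

The weight and bias gradients are plain sums: `dw` accumulates `dy * xhat` and `db` accumulates `dy`, in GroupNorm over the batch and spatial positions of each channel.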

In [None]:
class GroupNorm(torch.autograd.Function):

    @staticmethod
    def forward(ctx):
        pass

    @staticmethod
    def backward(ctx):
        pass

group_norm = GroupNorm.apply

### Testing

In [None]:
def test_group_norm(input_shape, num_groups, dtype, eps=1e-5, device='cuda'):
    # create data
    B, C, H, W = input_shape
    weight = torch.rand(C, dtype=dtype, device=device, requires_grad=True)
    bias = torch.rand(C, dtype=dtype, device=device, requires_grad=True)
    x = -2.3 + 0.5 * torch.randn(input_shape, dtype=dtype, device=device)
    dy = .1 * torch.randn_like(x)

    x.requires_grad_(True)
    # forward pass
    y_tri = group_norm(x, num_groups, weight, bias, eps)
    y_ref = torch.nn.functional.group_norm(x, num_groups, weight, bias, eps).to(dtype)
    # backward pass (triton)
    y_tri.backward(dy, retain_graph=True)
    dx_tri, dw_tri, db_tri = [_.grad.clone() for _ in [x, weight, bias]]
    x.grad, weight.grad, bias.grad = None, None, None
    # backward pass (torch)
    y_ref.backward(dy, retain_graph=True)
    dx_ref, dw_ref, db_ref = [_.grad.clone() for _ in [x, weight, bias]]
    # compare
    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
    assert torch.allclose(dx_tri, dx_ref, atol=1e-2, rtol=0)
    assert torch.allclose(db_tri, db_ref, atol=1e-2, rtol=0)
    assert torch.allclose(dw_tri, dw_ref, atol=1e-2, rtol=0)

In [None]:
input_shape = (1, 32, 64, 64)
num_groups = 8
test_group_norm(input_shape, num_groups, torch.float16)

### Task 2.1

Visualize the performance benchmark using `triton.testing.perf_report`, similar to the matmul benchmark from the seminar.
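
The task does not prescribe a metric. One common choice for a memory-bound kernel such as a normalisation layer is effective bandwidth in GB/s, which the sketch below computes from a measured time in milliseconds (for example, a value returned by `triton.testing.do_bench`). The helper name `achieved_gbps` is hypothetical, and counting exactly one read of the input plus one write of the output is a simplifying assumption.

```python
def achieved_gbps(numel, element_size, ms):
    """Effective bandwidth in GB/s for a forward pass.

    numel: number of elements in the input tensor
    element_size: bytes per element (2 for float16, 4 for float32)
    ms: measured runtime in milliseconds
    """
    bytes_moved = 2 * numel * element_size  # one read of x, one write of y
    return bytes_moved * 1e-9 / (ms * 1e-3)
```

For a tensor `x`, the arguments would be `x.numel()` and `x.element_size()`.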

In [None]:
# You need to check your implementation for the following parameters
batch_size = 2
for image_resolution in [32, 128, 512, 1024, 1536, 2048]:
    for num_channels in [32, 128, 384, 512]:
        for num_groups in [1, 4, 8, 16, 32]:

            dummy_input = torch.randn(batch_size, num_channels, image_resolution, image_resolution, device='cuda')

            # TODO benchmark