In [None]:
import torch
import torch.nn as nn
import time

In [None]:
from importlib.metadata import version

# Report the installed version of every distribution this notebook depends on.
pkgs = ["torch", "thop"]

for pkg_name in pkgs:
    installed = version(pkg_name)
    print(f"{pkg_name}: {installed}")

In [None]:
from thop import profile
from gpt_model import GPTModel
from gpt_download import BASE_CONFIG, model_configs

In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Dummy batch of random token ids in [0, 50257), 1024 tokens per sequence
# (50257 presumably being the tokenizer's vocabulary size — confirm vs. config).
batch_size = 2
input_tensor = torch.randint(low=0, high=50257, size=(batch_size, 1024)).to(device)
input_tensor.shape

In [None]:
# Profile one bf16 forward pass of the first model configuration with thop and
# report its FLOPs and parameter count.
config = BASE_CONFIG.copy()
for size in model_configs:
    config.update(model_configs[size])
    print(config)

    # Cast to bf16, then move to the target device (Module.to returns self).
    model = GPTModel(config).bfloat16().to(device)

    # thop reports multiply-accumulates; each MAC is counted as two FLOPs.
    macs, params = profile(model, inputs=(input_tensor,), verbose=False)
    flops = 2 * macs
    print(f"{size:18} flops = {flops:.1e}, number of parameters: {params/1e6:.2f}M")

    # Drop the model and release cached allocator blocks.
    del model
    torch.cuda.empty_cache()
    break  # only the first configuration is profiled

In [None]:
# Binary-search the largest batch size whose thop-profiled forward pass fits
# in device memory. Only the first configuration in model_configs is evaluated
# because of the trailing `break`.
batch_size_upper_bound = 4096  # largest batch size the search will try

config = BASE_CONFIG.copy()
for size in model_configs:
    min_batch_size = 1
    max_batch_size = None  # largest batch size observed to fit so far
    max_possible_batch_size = batch_size_upper_bound
    flops = params = None

    config.update(model_configs[size])

    # Inclusive bounds so the last remaining candidate (min == max) is tried too.
    while min_batch_size <= max_possible_batch_size:
        batch_size = (min_batch_size + max_possible_batch_size) // 2
        # Pre-bind so the cleanup below works wherever an OOM is raised.
        model = input_tensor = None
        try:
            input_tensor = torch.randint(0, 50257, (batch_size, 1024)).to(device)
            model = GPTModel(config).bfloat16()
            model.to(device)
            macs, params = profile(model, inputs=(input_tensor,), verbose=False)
            flops = macs * 2
            print(f"{size:18} flops = {flops:.1e}, number of parameters: {params/1e6:.2f}M, batch_size = {batch_size}")

            min_batch_size = batch_size + 1
            max_batch_size = batch_size
        except RuntimeError as e:
            # Only OOM means "batch too large"; any other error would leave the
            # bounds unchanged and loop forever, so surface it.
            if "out of memory" not in str(e):
                raise
            max_possible_batch_size = batch_size - 1
        finally:
            # Free this attempt's model and batch before the next candidate;
            # keeping them alive makes larger batches OOM spuriously.
            del model, input_tensor
            torch.cuda.empty_cache()

    # Successes only move the lower bound up, so `flops` is from max_batch_size.
    if max_batch_size is None:
        print(f"{size:18} no batch size in [1, {batch_size_upper_bound}] fits in memory")
    else:
        print(f"{size:18} flops = {flops:.1e}, number of parameters: {params/1e6:.2f}M, max batch_size = {max_batch_size}")
    break

In [None]:
# Peak FLOP/s per GPU family and parameter dtype; used as the MFU denominator.
# NOTE(review): 19.49e12 / 77.97e12 look like the A100's non-tensor-core
# FP32 / FP16 peaks; the dense tensor-core BF16/FP16 peak is ~312e12. Confirm
# which is intended, since MFU scales inversely with this value.
max_flops_per_second = dict(
    A100={
        torch.float32: 19.49e12,
        torch.float16: 77.97e12,
        torch.bfloat16: 77.97e12,
    },
)

In [None]:
# Find which GPU family in max_flops_per_second the current device belongs to.
# model_name is the first matching key, or None when nothing matches (or no
# CUDA device is present), rather than whichever key the loop visited last.
device_name = torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
model_name = None
for candidate in max_flops_per_second:
    if candidate in device_name:
        model_name = candidate
        print(model_name)
        break

In [None]:
# Display the GPU family name the detection cell left in model_name.
model_name

In [None]:
# model = GPTModel(config).bfloat16()
# model.to(device)
# data_type = next(model.parameters()).dtype
# print(data_type)
# max_flops_per_second = max_flops_per_second[model_name].get(data_type, 0)

# del model
# torch.cuda.empty_cache()

In [None]:
# Measure model FLOPs utilization (MFU) of one bf16 training step (forward +
# backward) while binary-searching the largest batch size that fits in memory.
# Only the first configuration in model_configs is evaluated (trailing `break`).
batch_size_upper_bound = 4096  # largest batch size the search will try

config = BASE_CONFIG.copy()
model_name = 'A100'  # hard-coded GPU family used to look up the peak FLOP/s
for size in model_configs:
    min_batch_size = 1
    max_batch_size = None  # largest batch size whose training step fit
    max_possible_batch_size = batch_size_upper_bound
    total_flops = params = None

    config.update(model_configs[size])

    # Inclusive bounds so the last remaining candidate (min == max) is tried too.
    while min_batch_size <= max_possible_batch_size:
        batch_size = (min_batch_size + max_possible_batch_size) // 2
        # Pre-bind so the cleanup below works wherever an OOM is raised.
        model = input_tensor = output = loss = None
        try:
            input_tensor = torch.randint(0, 50257, (batch_size, 1024)).to(device)
            model = GPTModel(config).bfloat16()
            model.to(device)
            model.train()

            data_type = next(model.parameters()).dtype
            max_flops_per_second_model = max_flops_per_second[model_name].get(data_type, 0)
            if max_flops_per_second_model <= 0:
                raise ValueError(f"no peak FLOP/s entry for {model_name} with dtype {data_type}")

            # Untimed warm-up step: the first step of a fresh model includes
            # one-off CUDA setup costs that would otherwise deflate MFU.
            model(input_tensor).sum().backward()
            model.zero_grad(set_to_none=True)

            torch.cuda.synchronize()
            start_time = time.perf_counter()
            output = model(input_tensor)
            loss = output.sum()
            loss.backward()
            torch.cuda.synchronize()
            elapsed_time = time.perf_counter() - start_time

            macs, params = profile(model, inputs=(input_tensor,), verbose=False)
            flops_forward = macs * 2
            flops_backward = flops_forward * 2  # backward approximated as 2x forward
            total_flops = flops_forward + flops_backward

            tokens_processed = batch_size * 1024
            observed_tokens_per_sec = tokens_processed / elapsed_time
            theoretical_tokens_per_sec = max_flops_per_second_model / (total_flops / tokens_processed)
            mfu = observed_tokens_per_sec / theoretical_tokens_per_sec

            print(f"{size:18} flops = {total_flops:.1e}, number of parameters: {params/1e6:.2f}M, batch_size = {batch_size}, mfu: {mfu:.4f}")

            min_batch_size = batch_size + 1
            max_batch_size = batch_size
        except RuntimeError as e:
            # Only OOM means "batch too large"; any other error would leave the
            # bounds unchanged and loop forever, so surface it.
            if "out of memory" not in str(e):
                raise
            max_possible_batch_size = batch_size - 1
        finally:
            # Free everything this attempt allocated (including activations held
            # by output/loss) before trying the next candidate.
            del model, input_tensor, output, loss
            torch.cuda.empty_cache()

    # Successes only move the lower bound up, so total_flops is from max_batch_size.
    if max_batch_size is None:
        print(f"{size:18} no batch size in [1, {batch_size_upper_bound}] fits in memory")
    else:
        print(f"{size:18} flops = {total_flops:.1e}, number of parameters: {params/1e6:.2f}M, max batch_size = {max_batch_size}")
    break