In [14]:
from dataclasses import dataclass
from Model import PashkoModel
import torch

@dataclass
class PashkoConfig:
    sequence_length: int = 1024
    vocab_size: int = 50304
    embed_dim: int = 768

    encoder = 'gpt2'

    num_heads: int = 12
    num_blocks: int = 12

    dropout: float = 0.0

    ffnn_bias: bool = False
    qkv_bias: bool = False
    layernorm_bias: bool = False

    topK: int = 10
    temperature: float = 1.0

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Pashko = PashkoModel(PashkoConfig())

In [15]:
num_params = Pashko.num_params()[1]
print(f"Number of parameters roughly {Pashko.num_params()[0]}, exactly = {Pashko.num_params()[1]}")

Number of parameters roughly 123.60M, exactly = 123595776


Memory Size of Checkpoints

In [16]:
parameter_bytes = num_params*4 #32 Bit Floating Point
parameter_bytes *= 3 #2 Extra per parameter for optimiser
print(f"Estimated checkpoint size: {parameter_bytes/1e9:.2f} GB")

Estimated checkpoint size: 1.48 GB


GPU Memory

In [17]:
available_memory = 12e9 #Memory available on my graphics card 3080ti
print(f"Memory taken up in GPU for parameters: {parameter_bytes/available_memory * 100:.2f}%")

Memory taken up in GPU for parameters: 12.36%


FLOP Estimations (Formula from PALM paper)

In [18]:
config = PashkoConfig()
def palm_flops():
    L, H, Q, T = config.num_blocks, config.num_heads, config.embed_dim//config.num_heads, config.sequence_length
    mf_per_token = 6*num_params + 12*L*H*Q*T
    mf = mf_per_token * config.sequence_length
    return mf
palm_flops()

875336564736

Usage %

In [19]:
input = "silver angel at it's best"
input = [Pashko.Encoder.encode(input)] * 20
input = torch.LongTensor(input)
targets = torch.LongTensor([target for sublist in input for target in sublist])
input = input.to(device)
targets = targets.to(device)

In [20]:
optimiser = torch.optim.AdamW(Pashko.parameters())

In [21]:
Pashko = Pashko.to(device)

In [22]:
%%timeit -r 1000
optimiser.zero_grad()
_, loss = Pashko(input, targets)
loss.backward()
optimiser.step()

26 ms ± 824 µs per loop (mean ± std. dev. of 1000 runs, 10 loops each)


In [25]:
batch_size = 20
measured_time = 0.26
measured_throughput = batch_size / measured_time
flops_achieved = palm_flops() * measured_throughput

gpu_flops = 136e12

mfu = flops_achieved / gpu_flops

print(f"Fraction of GPU used: {flops_achieved / gpu_flops * 100:.2f}%, MFU: {mfu}")

Fraction of GPU used: 49.51%, MFU: 0.49509986693212665


Training Time Approximation

In [26]:
tokens_num = 300e9 // 5 #WebText_20p
flops_throughput = gpu_flops * mfu
flops_needed = 6 * num_params * tokens_num
time_needed_s = flops_needed / flops_throughput
print(f"Time needed to train the model: {time_needed_s/3600/24:.2f} days")

Time needed to train the model: 7.65 days
