<a href="https://colab.research.google.com/github/jessica-hoffman/transformer_practice/blob/main/day7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import time

In [3]:
# simple model: small feedforward net
class TinyModel(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256, output_dim=10):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x):
        return self.net(x)

In [5]:
# choose device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('using device:', device)
if device.type == 'cuda':
    print("GPU name:", torch.cuda.get_device_name(0))

using device: cuda
GPU name: Tesla T4


In [6]:
model = TinyModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())

def profile_batch(batch_size, seq_len=512, num_classes=10):
    print(f'\n=== Batch size {batch_size} ===')

    # generate random input and labels
    x = torch.randn(batch_size, seq_len, device=device)
    y = torch.randint(0, num_classes, (batch_size,), device=device)

    # reset gradients
    optimizer.zero_grad()

    # record memory before forward
    torch.cuda.empty_cache()
    if device.type == 'cuda':
        mem_before = torch.cuda.memory_allocated(device)/1024**2
    else:
        mem_before = 0.0

    # forward pass timing
    start = time.time()
    outputs = model(x)
    loss = criterion(outputs, y)
    fwd_time = time.time() - start

    # backward pass timing
    start = time.time()
    loss.backward()
    optimizer.step()
    bwd_time = time.time() - start

    if device.type == 'cuda':
        mem_after = torch.cuda.memory_allocated(device)/1024**2
    else:
        mem_after = 0.0

    print(f'Forward time: {fwd_time:.4f}s, Backward time: {bwd_time:.4f}s')
    print(f'Memory before: {mem_before:.2f} MB, after: {mem_after:.2f} MB')

# try different batch sizes
for bs in [8,32,128,512]:
    profile_batch(bs)


=== Batch size 8 ===
Forward time: 0.3116s, Backward time: 0.3331s
Memory before: 0.53 MB, after: 18.31 MB

=== Batch size 32 ===
Forward time: 0.0053s, Backward time: 0.0013s
Memory before: 17.85 MB, after: 18.36 MB

=== Batch size 128 ===
Forward time: 0.0004s, Backward time: 0.0012s
Memory before: 18.03 MB, after: 18.55 MB

=== Batch size 512 ===
Forward time: 0.0006s, Backward time: 0.0011s
Memory before: 18.79 MB, after: 19.32 MB
