In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

class SimpleModel(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    The layer sizes are configurable; the defaults reproduce the original
    fixed architecture (10 inputs, 64 hidden units, 1 output).

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features: int = 10, hidden_features: int = 64,
                 out_features: int = 1):
        super().__init__()
        # Kept as a single ``nn.Sequential`` named ``net`` so parameter names
        # (``net.0.weight``, ``net.2.bias``, ...) stay the same for callers
        # that save or load state dicts.
        self.net = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )

    def forward(self, x):
        """Return ``self.net(x)``."""
        return self.net(x)

# Seed the global RNG so the synthetic data (and any model initialisation that
# follows) is identical on every "Restart Kernel -> Run All".
SEED = 42
torch.manual_seed(SEED)

# Create synthetic data: 1000 samples with 10 features each; the target is the
# sum of squared features per sample, kept as a column of shape (1000, 1).
X = torch.randn(1000, 10)
y = torch.sum(X**2, dim=1, keepdim=True)

# Training function
def train_step(model, X, y, optimizer):
    """Perform one optimisation step on a single batch.

    Clears the gradients held by ``optimizer``, runs ``model`` on ``X``,
    measures the mean-squared error against ``y``, backpropagates it and
    applies one ``optimizer.step()``.

    Returns:
        The batch loss as a Python float (the value computed before the
        parameter update was applied).
    """
    criterion = nn.MSELoss()

    optimizer.zero_grad()
    batch_loss = criterion(model(X), y)
    batch_loss.backward()
    optimizer.step()

    return batch_loss.item()

# Regular (eager) model
model = SimpleModel()
optimizer = torch.optim.Adam(model.parameters())

# Create the model that will be compiled *before* any training happens and copy
# the eager model's freshly initialised weights into it. Without this the two
# models start from different random weights, so a difference between the two
# printed losses would reflect the initialisation rather than torch.compile.
# load_state_dict copies the values, so the two models do not share storage.
compiled_base = SimpleModel()
compiled_base.load_state_dict(model.state_dict())

# Train one batch
loss = train_step(model, X, y, optimizer)
print(f"Regular model loss: {loss:.6f}")

# Compiled model (same initial weights as the eager model above)
compiled_model = torch.compile(compiled_base)
compiled_optimizer = torch.optim.Adam(compiled_model.parameters())

# Train one batch with compiled model
compiled_loss = train_step(compiled_model, X, y, compiled_optimizer)
print(f"Compiled model loss: {compiled_loss:.6f}")

Regular model loss: 120.338402
Compiled model loss: 121.052567


In [None]:
def show_compilation_modes(num_iterations=10000, warmup_iterations=3):
    """Time forward passes of an ``MLPModel`` under several ``torch.compile`` modes.

    For each mode (default, ``reduce-overhead``, ``max-autotune``) the model is
    compiled, called a few times as warm-up, then called ``num_iterations``
    times while being timed. One line per mode is printed; if anything raises
    for a mode, the error is printed and the next mode is tried.

    ``MLPModel`` and the ``time`` module are not defined in this cell and must
    already be available in the notebook namespace.

    Args:
        num_iterations: Number of timed forward passes per mode.
        warmup_iterations: Number of untimed calls made before timing so that
            compilation happens outside the measured loop.
    """
    model = MLPModel()
    input_data = torch.randn(1, 10)

    compilation_modes = {
        'Default': {},
        'Reduce-overhead': {'mode': 'reduce-overhead'},
        'Max-autotune': {'mode': 'max-autotune'}
    }

    print("\nTesting different compilation modes:")

    for name, kwargs in compilation_modes.items():
        # The same module is compiled once per mode; clear the compiler's
        # caches so each mode starts from a clean state rather than from
        # whatever the previous mode left behind.
        torch.compiler.reset()

        try:
            compiled_model = torch.compile(model, **kwargs)

            # Warmup: the first call triggers the actual compilation.
            for _ in range(warmup_iterations):
                compiled_model(input_data)

            # Timing: perf_counter is a monotonic, high-resolution clock meant
            # for measuring short durations, unlike wall-clock time.time().
            start_time = time.perf_counter()
            for _ in range(num_iterations):
                compiled_model(input_data)
            elapsed = time.perf_counter() - start_time

            print(f"{name} mode - Time for {num_iterations} forward passes: {elapsed:.4f}s")
        except Exception as e:
            # Best effort: report the failure and continue with the next mode.
            print(f"{name} mode - Compilation failed: {e}")
            
            
# Run the benchmark; it prints one timing (or failure) line per compilation mode.
show_compilation_modes()


Testing different compilation modes:
Default mode - Time for 10000 forward passes: 1.4041s


skipping cudagraphs due to skipping cudagraphs due to cpu device. Found from : 
   File "/tmp/ipykernel_16364/1571453352.py", line 26, in forward
    return self.layers(x)



Reduce-overhead mode - Time for 10000 forward passes: 1.4908s


skipping cudagraphs due to skipping cudagraphs due to cpu device. Found from : 
   File "/tmp/ipykernel_16364/1571453352.py", line 26, in forward
    return self.layers(x)



Max-autotune mode - Time for 10000 forward passes: 1.4898s
