In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from zeus.monitor import ZeusMonitor

# Dummy dataset and dataloader
X_train = torch.randn(1000, 10)  # Sample input data
y_train = torch.randint(0, 2, (1000,))  # Sample target labels
train_dataset = TensorDataset(X_train, y_train)
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Define a simple neural network model
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = SimpleModel()

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# All four GPUs are measured simultaneously.
monitor = ZeusMonitor()

# Measure total time and energy within the window.
monitor.begin_window("training")
for e in range(100):

    # Measurement windows can arbitrarily be overlapped.
    monitor.begin_window("epoch")
    for x, y in train_dataloader:
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
    measurement = monitor.end_window("epoch")
    print(f"Epoch {e}: {measurement.time} s, {measurement.total_energy} J")

measurement = monitor.end_window("training")
print(f"Entire training: {measurement.time} s, {measurement.total_energy} J")



[2024-05-27 15:38:23,997] [zeus.monitor.energy](energy.py:150) Monitoring GPU indices [0].
[2024-05-27 15:38:24,810] [zeus.utils.framework](framework.py:38) PyTorch with CUDA support is available.
Epoch 0: 0.019814252853393555 s, 0.0 J
Epoch 1: 0.007529735565185547 s, 0.0 J
Epoch 2: 0.0074117183685302734 s, 0.0 J
Epoch 3: 0.007066965103149414 s, 0.0 J
Epoch 4: 0.007569313049316406 s, 0.0 J
Epoch 5: 0.007040739059448242 s, 0.0 J
Epoch 6: 0.0079803466796875 s, 0.0 J
Epoch 7: 0.0069539546966552734 s, 0.0 J
Epoch 8: 0.007540464401245117 s, 2.1929999999701977 J
Epoch 9: 0.0072765350341796875 s, 0.0 J
Epoch 10: 0.007555723190307617 s, 0.0 J
Epoch 11: 0.007149696350097656 s, 0.0 J
Epoch 12: 0.007519960403442383 s, 0.0 J
Epoch 13: 0.006961345672607422 s, 0.0 J
Epoch 14: 0.0073032379150390625 s, 0.0 J
Epoch 15: 0.006948709487915039 s, 0.0 J
Epoch 16: 0.007500648498535156 s, 0.0 J
Epoch 17: 0.006929159164428711 s, 0.0 J
Epoch 18: 0.007444858551025391 s, 2.314000000245869 J
Epoch 19: 0.0071167945