In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

### 1. Tensor

In [2]:
# Benchmark operands: X is (10000, 3) and W is (3, 10000), so X @ W is a
# 10000 x 10000 matrix.
X_cpu = torch.randn(10000, 3, device="cpu")
W_cpu = torch.randn(3, 10000, requires_grad=True, device="cpu")

# Use Apple's Metal (MPS) backend when it is available; otherwise stay on the
# CPU so the notebook still runs top to bottom on machines without MPS.
_accel_device = "mps" if torch.backends.mps.is_available() else "cpu"

# clone() first so the copies never alias the CPU originals: .to() alone
# returns the very same tensor when the target device already matches.
X_mps = X_cpu.clone().to(_accel_device)
W_mps = W_cpu.clone().to(_accel_device)

In [3]:
# Show where each of the four tensors lives; the tuple is the cell's displayed output.
X_cpu.device, W_cpu.device, X_mps.device, W_mps.device

(device(type='cpu'),
 device(type='cpu'),
 device(type='mps', index=0),
 device(type='mps', index=0))

In [4]:
# Measure matmul time on the CPU (uncomment the line below to run the benchmark).
#%timeit torch.matmul(X_cpu, W_cpu)

In [5]:
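# Measure matmul time on MPS (uncomment the line below to run the benchmark).
# MPS work is queued asynchronously, so add torch.mps.synchronize() to the
# timed statement if the measurement should include the actual compute time.
#%timeit torch.matmul(X_mps, W_mps)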

### 2. Dataset

In [6]:
class TestDataset(Dataset):
    """Synthetic multi-label dataset of random features and random 0/1 targets.

    Features are drawn from a standard normal distribution and every target
    entry is an independent draw from {0., 1.}. The targets are generated
    without looking at the features, so they carry no learnable signal.

    Args:
        n_samples: Number of rows to generate.
        n_features: Number of feature columns per row.
        n_targets: Number of binary target columns per row.
        seed: Optional integer seed. When given, a private ``torch.Generator``
            is used so the same seed always yields the same data and the
            global RNG is left untouched. When ``None``, the global RNG is
            used.
    """

    def __init__(self, n_samples: int = 10_000, n_features: int = 5,
                 n_targets: int = 3, seed=None):
        # Only pass ``generator`` when a seed was requested, so the default
        # path makes exactly the same calls as before.
        rng_kwargs = {}
        if seed is not None:
            rng_kwargs["generator"] = torch.Generator().manual_seed(seed)

        self.data = torch.randn(
            n_samples, n_features, dtype=torch.float32, **rng_kwargs
        )
        # high=2 is exclusive, so values are 0 or 1; stored as float32.
        self.target = torch.randint(
            0, 2, (n_samples, n_targets), dtype=torch.float32, **rng_kwargs
        )

    def __getitem__(self, index: int):
        """Return the ``(features, targets)`` pair stored at ``index``."""
        return self.data[index], self.target[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

In [7]:
# Two independently generated datasets: one to fit on, one fed to the validation loader.
train_dataset, eval_dataset = TestDataset(), TestDataset()

### 3. DataLoader

In [8]:
# Number of samples per mini-batch, shared by the training and validation loaders.
batch_size = 64

In [9]:
# Both loaders walk their dataset in storage order (shuffle=False).
# NOTE(review): training loaders are usually shuffled every epoch - confirm
# that shuffle=False is intentional for the training split.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=False,
)
val_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    shuffle=False,
)

### 4. Module

In [10]:
class Net(nn.Module):
    """Small fully connected network that outputs one probability per target.

    Architecture: Linear -> ReLU -> Linear -> ReLU -> Linear -> Sigmoid.
    The final sigmoid keeps every output in [0, 1].

    Args:
        in_features: Size of each input sample.
        hidden_features: Width of both hidden representations.
        out_features: Number of outputs per sample.
    """

    def __init__(self, in_features: int = 5, hidden_features: int = 32,
                 out_features: int = 3):
        super().__init__()
        # Attribute names and Sequential layout are kept as-is so existing
        # state_dict keys (e.g. "input_layer.0.weight") stay valid.
        self.input_layer = nn.Sequential(
            nn.Linear(in_features, hidden_features), nn.ReLU()
        )
        self.hidden_layer = nn.Sequential(
            nn.Linear(hidden_features, hidden_features), nn.ReLU()
        )
        self.output_layer = nn.Sequential(
            nn.Linear(hidden_features, out_features), nn.Sigmoid()
        )

    def forward(self, x):
        """Map a ``(..., in_features)`` tensor to ``(..., out_features)`` values in [0, 1]."""
        x = self.input_layer(x)
        x = self.hidden_layer(x)
        return self.output_layer(x)

### 5. Training

In [11]:
# Fresh, untrained network instance; this is the object passed to train_and_validate below.
model = Net()

In [12]:
def _validate(model, dataloader, criterion):
    """Return the average per-batch loss of ``model`` over ``dataloader``.

    Gradient tracking is disabled for the whole pass. The model's train/eval
    mode is not touched here; the caller selects it.

    Args:
        model: Callable mapping a batch of inputs to predictions.
        dataloader: Iterable yielding ``(X, y)`` batches. It does not need to
            support ``len()``; batches are counted while iterating.
        criterion: Callable ``criterion(y_pred, y)`` returning a scalar tensor.

    Returns:
        float: Sum of the per-batch losses divided by the number of batches.
        Every batch weighs the same regardless of its size. ``nan`` is
        returned when the loader yields no batches.
    """
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for X, y in dataloader:
            total_loss += criterion(model(X), y).item()
            num_batches += 1

    if num_batches == 0:
        # Nothing to average over; report "undefined" rather than dividing by zero.
        return float("nan")
    return total_loss / num_batches


def train_and_validate(model, train_dataloader, val_dataloader, epochs, lr=0.01):
    """Train ``model`` with Adam and binary cross-entropy, validating after each epoch.

    Args:
        model: ``nn.Module`` whose outputs lie in [0, 1], as ``nn.BCELoss`` requires.
        train_dataloader: Iterable of ``(X, y)`` batches used for the updates.
        val_dataloader: Iterable of ``(X, y)`` batches used for validation.
            It is iterated once per epoch.
        epochs: Number of passes over ``train_dataloader``.
        lr: Learning rate handed to Adam.

    Returns:
        list[float]: Validation loss after each epoch, in epoch order.
        The model is left in eval mode when ``epochs > 0``.
    """
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    validation_losses = []

    for epoch in range(epochs):
        model.train()
        for X, y in train_dataloader:
            # Clear gradients *before* backward so anything accumulated
            # earlier (including before this function was called) cannot
            # leak into the update.
            optimizer.zero_grad()
            loss = criterion(model(X), y)
            loss.backward()
            optimizer.step()

        model.eval()
        validation_loss = _validate(model, val_dataloader, criterion)
        validation_losses.append(validation_loss)
        print(f"Epoch {epoch}, validation loss: {validation_loss}")

    return validation_losses

In [13]:
# Run 25 epochs; one validation-loss line is printed per epoch.
# The targets are random 0/1 draws made independently of the features, so the
# best achievable binary cross-entropy is ln(2) ~= 0.693 and the loss is not
# expected to drop below that.
train_and_validate(model, train_dataloader, val_dataloader, epochs=25)

Epoch 0, validation loss: 0.6933184812782677
Epoch 1, validation loss: 0.6935971565307326
Epoch 2, validation loss: 0.693384034998098
Epoch 3, validation loss: 0.6931565062255617
Epoch 4, validation loss: 0.6932208784826243
Epoch 5, validation loss: 0.6932399140042105
Epoch 6, validation loss: 0.6931964823394824
Epoch 7, validation loss: 0.6933117884739189
Epoch 8, validation loss: 0.6934066984304197
Epoch 9, validation loss: 0.6932498512754015
Epoch 10, validation loss: 0.6933976282739336
Epoch 11, validation loss: 0.693495299026465
Epoch 12, validation loss: 0.693691302256979
Epoch 13, validation loss: 0.6939755700955725
Epoch 14, validation loss: 0.6942356832467826
Epoch 15, validation loss: 0.6941197005806455
Epoch 16, validation loss: 0.693653444575656
Epoch 17, validation loss: 0.6944151513136116
Epoch 18, validation loss: 0.6937655347168066
Epoch 19, validation loss: 0.6938322577506874
Epoch 20, validation loss: 0.6935547134678834
Epoch 21, validation loss: 0.6937626194042764
Ep