## Training the models

This notebook contains the procedure we used to train the models for the starting example. The models are small, so the notebook runs without a GPU, but running it to the end takes a while (the training cell alone takes around 30 minutes).

In [None]:
# Import the libraries

import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm

In [None]:
# Load the data

data = torch.load("data.pt").reshape(2 * 200, 2).float()
labels = torch.cat([torch.zeros(200), torch.ones(200)])

In [None]:
# Setup

loss = nn.BCEWithLogitsLoss()
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)
num_epochs = 10
num_models = 10000

In [None]:
# Train the models (takes around 30 minutes)

models = []
for k in tqdm(range(num_models)):
    model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1)).float()
    optimizer = torch.optim.SGD(model.parameters(), lr=2)
    for i in range(num_epochs):
        mean_loss = 0
        for batch in dataloader:
            optimizer.zero_grad()
            x, y = batch
            y_hat = model(x)

            l = loss(y_hat, y.unsqueeze(1))
            l.backward()
            optimizer.step()

            # Accumulate the batch loss; the mean over the last epoch serves as
            # an arbitrary metric to check whether the model learned correctly
            mean_loss += l.item()

    # Keep only the models whose mean loss over the last epoch is low enough
    if mean_loss / len(dataloader) <= 0.01:
        models.append(model.state_dict())
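
The loss threshold above is only one possible acceptance criterion. As an illustration, classification accuracy could be checked instead. The sketch below is not part of the original procedure: the `accuracy` helper is hypothetical, and the data is a synthetic stand-in (two Gaussian blobs), not the contents of `data.pt`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Synthetic stand-in for data.pt (assumption): two blobs of 200 points each.
data = torch.cat([torch.randn(200, 2) - 3.0, torch.randn(200, 2) + 3.0])
labels = torch.cat([torch.zeros(200), torch.ones(200)])


def accuracy(model, x, y):
    """Hypothetical helper: fraction of points classified correctly."""
    model.eval()
    with torch.no_grad():
        # A logit above 0 corresponds to a sigmoid probability above 0.5.
        predictions = (model(x).squeeze(1) > 0).float()
    return (predictions == y).float().mean().item()


# Same architecture as in the training cell; untrained here, so the
# accuracy value itself is not meaningful.
model = nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))
acc = accuracy(model, data, labels)
print(f"accuracy: {acc:.3f}")
```

In the training loop, such a check could replace or complement the loss condition, for example by keeping a model only when its accuracy on `data` reaches a chosen threshold.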

Save the models to disk if you want to replace the pre-trained models in other notebooks (the line is commented out to avoid writing to your disk without your consent).

In [None]:
# torch.save(models, "your_models.pt")
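
Since `models` holds state dicts rather than full modules, reloading requires rebuilding the architecture first. The following is a minimal sketch under stated assumptions: `make_model` is a hypothetical helper that mirrors the architecture from the training cell, the state dicts come from untrained networks as a stand-in, and the file is written to a temporary directory rather than to `your_models.pt` in the working directory.

```python
import os
import tempfile

import torch
from torch import nn


def make_model():
    # Hypothetical helper: same architecture as in the training cell.
    return nn.Sequential(nn.Linear(2, 2), nn.ReLU(), nn.Linear(2, 1))


# Stand-in for the trained `models` list: state dicts of untrained networks.
models = [make_model().state_dict() for _ in range(3)]

# Write to a temporary directory so nothing else on disk is touched.
path = os.path.join(tempfile.mkdtemp(), "your_models.pt")
torch.save(models, path)

# Reload the list and rebuild the first network from its state dict.
state_dicts = torch.load(path)
restored = make_model()
restored.load_state_dict(state_dicts[0])
restored.eval()

# Run the restored network on a dummy batch of four 2D points.
with torch.no_grad():
    logits = restored(torch.zeros(4, 2))
```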