In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim

# Define a custom dataset
class CustomDataset(Dataset):
    """Synthetic regression dataset held entirely in memory.

    Features and targets are both sampled from a standard normal distribution,
    independently of one another, so there is no feature-to-target
    relationship for a model to learn.

    Attributes:
        X: feature tensor of shape ``(num_samples, num_features)``.
        y: target tensor of shape ``(num_samples, 1)`` (one target per sample).
    """

    def __init__(self, num_samples, num_features):
        # Features are drawn before targets; keep this order so that a fixed
        # RNG seed yields the same tensors.
        features = torch.randn(num_samples, num_features)
        targets = torch.randn(num_samples, 1)
        self.X = features
        self.y = targets

    def __len__(self):
        # Number of samples == size of the leading dimension of X.
        return self.X.size(0)

    def __getitem__(self, idx):
        sample = self.X[idx]
        target = self.y[idx]
        return sample, target

# Simple neural network model
class SimpleNN(nn.Module):
    """Linear regression model: one fully connected layer with a single output.

    Args:
        num_features: size of the last dimension of the input tensor.
    """

    def __init__(self, num_features):
        super().__init__()
        # Attribute name ``linear`` determines the state_dict keys and repr.
        self.linear = nn.Linear(in_features=num_features, out_features=1)

    def forward(self, x):
        """Map ``(..., num_features)`` inputs to ``(..., 1)`` predictions."""
        prediction = self.linear(x)
        return prediction

def train_model(num_samples=100000, num_features=10, batch_size=32, epochs=5,
                lr=0.01, log_interval=100):
    """Train a SimpleNN on a freshly generated CustomDataset.

    Uses plain SGD on a mean-squared-error loss and prints the current
    mini-batch loss periodically.

    Args:
        num_samples: number of synthetic samples to generate.
        num_features: number of input features per sample.
        batch_size: mini-batch size for the DataLoader (batches are shuffled).
        epochs: number of full passes over the dataset.
        lr: SGD learning rate (defaults to the previously hard-coded 0.01).
        log_interval: print the loss every ``log_interval`` batches, starting
            with the first batch of each epoch (defaults to 100).

    Returns:
        The trained SimpleNN instance.

    Raises:
        ValueError: if ``log_interval`` is less than 1.
    """
    # A zero interval would raise ZeroDivisionError on ``%`` inside the loop;
    # fail early with a clear message instead.
    if log_interval < 1:
        raise ValueError(f"log_interval must be >= 1, got {log_interval}")

    # Create dataset and dataloader
    dataset = CustomDataset(num_samples, num_features)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    num_batches = len(dataloader)  # loop-invariant; used only for logging

    # Model, loss function, and optimizer
    model = SimpleNN(num_features)
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=lr)

    # Explicitly enter training mode (a new module already is, but this keeps
    # the loop correct if the model setup ever changes).
    model.train()

    for epoch in range(epochs):
        for batch_idx, (X_batch, y_batch) in enumerate(dataloader):
            # Forward pass
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            # Backward pass: clear stale gradients, backprop, update weights
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if batch_idx % log_interval == 0:
                print(f'Epoch [{epoch+1}/{epochs}], Step [{batch_idx+1}/{num_batches}], Loss: {loss.item():.4f}')

    return model

# Run training with every hyperparameter at its default value, then display
# the structure of the returned model.
trained_model = train_model()
print(trained_model)

Epoch [1/5], Step [1/3125], Loss: 1.4264
Epoch [1/5], Step [101/3125], Loss: 1.0252
Epoch [1/5], Step [201/3125], Loss: 1.0122
Epoch [1/5], Step [301/3125], Loss: 0.9711
Epoch [1/5], Step [401/3125], Loss: 1.0386
Epoch [1/5], Step [501/3125], Loss: 0.6912
Epoch [1/5], Step [601/3125], Loss: 1.2317
Epoch [1/5], Step [701/3125], Loss: 0.9111
Epoch [1/5], Step [801/3125], Loss: 0.8604
Epoch [1/5], Step [901/3125], Loss: 0.7881
Epoch [1/5], Step [1001/3125], Loss: 0.9467
Epoch [1/5], Step [1101/3125], Loss: 1.2362
Epoch [1/5], Step [1201/3125], Loss: 0.9764
Epoch [1/5], Step [1301/3125], Loss: 1.2027
Epoch [1/5], Step [1401/3125], Loss: 0.9866
Epoch [1/5], Step [1501/3125], Loss: 1.0049
Epoch [1/5], Step [1601/3125], Loss: 0.8862
Epoch [1/5], Step [1701/3125], Loss: 0.8246
Epoch [1/5], Step [1801/3125], Loss: 1.5683
Epoch [1/5], Step [1901/3125], Loss: 0.8251
Epoch [1/5], Step [2001/3125], Loss: 0.6822
Epoch [1/5], Step [2101/3125], Loss: 0.8795
Epoch [1/5], Step [2201/3125], Loss: 1.0890
