# In [None]:
import torch
import syft as sy

# Hook PyTorch to PySyft so tensors and modules gain .send()/.get().
# NOTE(review): assumes the PySyft 0.2.x API (TorchHook / VirtualWorker).
hook = sy.TorchHook(torch)

# Training hyper-parameters. With inputs up to 6 and a squared-error loss,
# a learning rate of 0.1 makes SGD diverge; 0.01 with a mean loss is stable.
NUM_ROUNDS = 10
LEARNING_RATE = 0.01

# Create two virtual workers to simulate two clients
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")

# Synthetic data for each client. nn.Linear(1, 1) expects the feature
# dimension last, so each sample is a row: inputs have shape (N, 1).
# Targets follow an illustrative y = 2x + 1 so the loss has something to fit.
inputs_alice = torch.tensor([[1.0], [2.0], [3.0]])
inputs_bob = torch.tensor([[4.0], [5.0], [6.0]])
data_alice = inputs_alice.send(alice)
target_alice = (2 * inputs_alice + 1).send(alice)
data_bob = inputs_bob.send(bob)
target_bob = (2 * inputs_bob + 1).send(bob)

# Define a simple model: one input feature -> one output
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=LEARNING_RATE)

# Each client's (data, target) pointers; both tensors of a pair live on the
# same worker, so the model can be trained next to them.
datasets = [(data_alice, target_alice), (data_bob, target_bob)]

# Federated training: the model travels to each client's data, takes one
# local SGD step there, and comes back. Raw data never leaves its worker.
for _ in range(NUM_ROUNDS):
    for data, target in datasets:
        model.send(data.location)
        optimizer.zero_grad()
        prediction = model(data)
        loss = ((prediction - target) ** 2).mean()
        loss.backward()
        optimizer.step()
        # Retrieve the updated parameters before visiting the next client.
        model.get()

# Make predictions on a new data point (shape (1, 1): one sample, one feature)
new_data = torch.tensor([[7.0]])
with torch.no_grad():
    prediction = model(new_data)
print(f"Predicted value for input {new_data.item()}: {prediction.item()}")
