In [1]:
!pip install torch torchvision web3

Collecting web3
  Downloading web3-7.3.0-py3-none-any.whl.metadata (5.3 kB)
Collecting eth-abi>=5.0.1 (from web3)
  Downloading eth_abi-5.1.0-py3-none-any.whl.metadata (5.1 kB)
Collecting eth-account>=0.13.1 (from web3)
  Downloading eth_account-0.13.4-py3-none-any.whl.metadata (5.3 kB)
Collecting eth-hash>=0.5.1 (from eth-hash[pycryptodome]>=0.5.1->web3)
  Downloading eth_hash-0.7.0-py3-none-any.whl.metadata (5.4 kB)
Collecting eth-typing>=5.0.0 (from web3)
  Downloading eth_typing-5.0.0-py3-none-any.whl.metadata (5.1 kB)
Collecting eth-utils>=5.0.0 (from web3)
  Downloading eth_utils-5.0.0-py3-none-any.whl.metadata (5.4 kB)
Collecting hexbytes>=1.2.0 (from web3)
  Downloading hexbytes-1.2.1-py3-none-any.whl.metadata (3.7 kB)
Collecting types-requests>=2.0.0 (from web3)
  Downloading types_requests-2.32.0.20240914-py3-none-any.whl.metadata (1.9 kB)
Collecting websockets>=10.0.0 (from web3)
  Downloading websockets-13.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_

In [19]:
import os
from getpass import getpass

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
from web3 import Web3

# Set up Ethereum (Sepolia testnet via Infura).
# Secrets are read from the environment (or prompted for) so they never end up in the
# saved notebook. NOTE: the Infura project ID and private key that used to be hardcoded
# here were published with the notebook — treat them as compromised and rotate them.
infura_project_id = os.environ.get('INFURA_PROJECT_ID') or getpass('Infura project ID: ')
infura_url = f'https://sepolia.infura.io/v3/{infura_project_id}'
web3 = Web3(Web3.HTTPProvider(infura_url))
# Fail fast: otherwise a bad URL/key only surfaces after every client has finished training.
if not web3.is_connected():
    raise ConnectionError('Could not connect to the Sepolia RPC endpoint; check INFURA_PROJECT_ID.')

# The sending address is public, so a default is acceptable; override it via ETH_ACCOUNT.
account = web3.to_checksum_address(
    os.environ.get('ETH_ACCOUNT', '0xeC85984aB1f737979Ae3a640c66F49AB71aba490')
)
private_key = os.environ.get('ETH_PRIVATE_KEY') or getpass('Ethereum private key: ')
contract_address = web3.to_checksum_address('0x4bd3de79afc02629a23d85dcc2692f4c0671eadd')  # Convert to checksum address

# ABI (interface description) of the deployed AccuracyLogger contract, written as a list
# of JSON-style dicts and passed to `web3.eth.contract` below. Only the functions declared
# here are reachable through `contract.functions.*`.
# NOTE(review): this must match the contract actually deployed at `contract_address` —
# presumably copied from the Solidity compiler output; confirm if the contract changes.
contract_abi = [
    # getLogs() -> AccuracyLog[] (view): all records as tuples of
    # (client: address, timestamp: uint256, accuracy: uint256).
    {
        "inputs": [],
        "name": "getLogs",
        "outputs": [
            {
                "components": [
                    {
                        "internalType": "address",
                        "name": "client",
                        "type": "address"
                    },
                    {
                        "internalType": "uint256",
                        "name": "timestamp",
                        "type": "uint256"
                    },
                    {
                        "internalType": "uint256",
                        "name": "accuracy",
                        "type": "uint256"
                    }
                ],
                "internalType": "struct AccuracyLogger.AccuracyLog[]",
                "name": "",
                "type": "tuple[]"
            }
        ],
        "stateMutability": "view",
        "type": "function"
    },
    # getLogsLength() -> uint256 (view): number of records.
    {
        "inputs": [],
        "name": "getLogsLength",
        "outputs": [
            {
                "internalType": "uint256",
                "name": "",
                "type": "uint256"
            }
        ],
        "stateMutability": "view",
        "type": "function"
    },
    # logAccuracy(uint256 _accuracy) (nonpayable): state-changing, so it is sent as a signed
    # transaction by log_aggregated_accuracy_to_ethereum rather than called locally.
    # NOTE(review): how the contract fills `client`/`timestamp` (presumably msg.sender and
    # block.timestamp) is not visible from the ABI — confirm against the Solidity source.
    {
        "inputs": [
            {
                "internalType": "uint256",
                "name": "_accuracy",
                "type": "uint256"
            }
        ],
        "name": "logAccuracy",
        "outputs": [],
        "stateMutability": "nonpayable",
        "type": "function"
    },
    # logs(uint256 index) -> (client, timestamp, accuracy) (view): a single record by index,
    # presumably the auto-generated getter of a public `logs` array — TODO confirm.
    {
        "inputs": [
            {
                "internalType": "uint256",
                "name": "",
                "type": "uint256"
            }
        ],
        "name": "logs",
        "outputs": [
            {
                "internalType": "address",
                "name": "client",
                "type": "address"
            },
            {
                "internalType": "uint256",
                "name": "timestamp",
                "type": "uint256"
            },
            {
                "internalType": "uint256",
                "name": "accuracy",
                "type": "uint256"
            }
        ],
        "stateMutability": "view",
        "type": "function"
    }
]

# Bind the ABI to the deployed address; all calls go through the `web3` provider.
contract = web3.eth.contract(address=contract_address, abi=contract_abi)

# Load the CIFAR-100 training split. ToTensor maps pixels to [0, 1]; normalising each
# channel with mean 0.5 / std 0.5 then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

trainset = torchvision.datasets.CIFAR100(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
train_size = len(trainset)
num_clients = 4
# Equal shard size per client; any remainder (train_size % num_clients samples) is unused.
client_size = train_size // num_clients

# Contiguous, non-overlapping [start, stop) index ranges — one shard per client.
shard_bounds = [(k * client_size, (k + 1) * client_size) for k in range(num_clients)]
client_datasets = [Subset(trainset, range(lo, hi)) for lo, hi in shard_bounds]

# Define a simple CNN model
class SimpleCNN(nn.Module):
    """Two conv + max-pool stages followed by a two-layer MLP classifier.

    Expects input of shape (batch, 3, 32, 32). The 3x3 convolutions (stride 1,
    padding 1) keep the spatial size, and each 2x2 max-pool halves it
    (32 -> 16 -> 8). That is why ``fc1`` takes 64 * 8 * 8 features.

    Args:
        num_classes: number of output logits (default 100, for CIFAR-100).
    """

    def __init__(self, num_classes=100):
        super().__init__()
        # Registration order (conv1, conv2, fc1, fc2) is unchanged, so state_dict keys
        # and the order in which parameter init consumes the RNG stay the same.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x):
        """Return unnormalised class logits of shape (batch, num_classes)."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # flatten(x, 1) keeps the batch dimension. The old view(-1, 64*8*8) silently
        # re-batched non-32x32 inputs into wrong-sized outputs instead of raising.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Function to log aggregated accuracy on Ethereum
def log_aggregated_accuracy_to_ethereum(client_id, accuracy, scale=10000, gas=None,
                                        gas_price=None, wait_for_receipt=False, timeout=120):
    """Record one client's accuracy on-chain via the contract's ``logAccuracy``.

    Uses the module-level ``web3``, ``contract``, ``account`` and ``private_key``.

    Args:
        client_id: zero-based client index; only used in the printed message.
        accuracy: fraction in [0, 1]. The contract takes a uint256, so the value
            stored is the integer ``round(accuracy * scale)``.
        scale: fixed-point multiplier (default 10000, i.e. 4 decimal places).
        gas: gas limit; estimated from the call when None.
        gas_price: gas price in wei; the node's current ``gas_price`` when None.
        wait_for_receipt: if True, block until the transaction is mined and check its status.
        timeout: seconds to wait for the receipt when ``wait_for_receipt`` is True.

    Returns:
        The transaction hash, or the receipt when ``wait_for_receipt`` is True.

    Raises:
        ValueError: if ``accuracy`` is not in [0, 1] (this includes NaN).
        RuntimeError: if the awaited transaction was mined but reverted.
    """
    if not 0.0 <= accuracy <= 1.0:
        raise ValueError(f'accuracy must be in [0, 1], got {accuracy!r}')
    # round(), not int(): int() truncates float error, e.g. int(0.57 * 100) == 56.
    accuracy_int = round(accuracy * scale)

    # 'pending' also counts this account's transactions still in the mempool, so two
    # back-to-back calls do not reuse a nonce (the default 'latest' would).
    nonce = web3.eth.get_transaction_count(account, 'pending')

    tx_function = contract.functions.logAccuracy(accuracy_int)
    if gas is None:
        gas = tx_function.estimate_gas({'from': account})
    if gas_price is None:
        # Query the network instead of a fixed 150 gwei. A fixed price overpays, and the
        # account must hold gas * gasPrice before the transaction is accepted.
        gas_price = web3.eth.gas_price
    transaction = tx_function.build_transaction({
        'from': account,
        'gas': gas,
        'gasPrice': gas_price,
        'nonce': nonce,
    })

    # Sign locally with the private key, then broadcast the raw transaction.
    signed_txn = web3.eth.account.sign_transaction(transaction, private_key)
    tx_hash = web3.eth.send_raw_transaction(signed_txn.raw_transaction)

    # to_hex gives the 0x-prefixed form that block explorers accept
    # (HexBytes.hex() omits the prefix since hexbytes 1.0).
    print(f'Client {client_id + 1}: Transaction sent with hash: {web3.to_hex(tx_hash)}')

    if not wait_for_receipt:
        return tx_hash
    receipt = web3.eth.wait_for_transaction_receipt(tx_hash, timeout=timeout)
    if receipt['status'] != 1:
        raise RuntimeError(
            f'Client {client_id + 1}: logAccuracy transaction reverted ({web3.to_hex(tx_hash)})'
        )
    return receipt

# Train each client independently on its own shard.
# NOTE(review): despite the "federated" naming there is no aggregation step (e.g. FedAvg).
# Each client trains its own SimpleCNN from scratch. The value logged on-chain is that
# client's final-epoch *training* accuracy, measured while the weights were still being
# updated — not a held-out test score.
NUM_EPOCHS = 30      # local epochs per client
BATCH_SIZE = 32
LEARNING_RATE = 0.001
SEED = 0             # fixes weight init and DataLoader shuffling for reproducible runs


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``, updating ``model`` in place.

    Returns:
        (mean_batch_loss, accuracy, num_samples). ``accuracy`` is the fraction of
        samples classified correctly during the pass. If ``loader`` yields no
        samples, returns (nan, nan, 0) instead of dividing by zero.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    correct = 0
    total = 0
    for inputs, labels in loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        predicted = outputs.detach().argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    if total == 0:
        return float('nan'), float('nan'), 0
    return running_loss / num_batches, correct / total, total


torch.manual_seed(SEED)
for client_id in range(num_clients):
    client_data_loader = DataLoader(client_datasets[client_id], batch_size=BATCH_SIZE, shuffle=True)

    model = SimpleCNN()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Reset per client so a previous client's value can never be logged by mistake.
    final_accuracy = None
    for epoch in range(NUM_EPOCHS):
        epoch_loss, accuracy, num_seen = _train_one_epoch(model, client_data_loader, criterion, optimizer)
        if num_seen == 0:
            break  # empty shard: nothing to train on or report
        print(f'Client {client_id + 1}, Epoch {epoch + 1}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}')
        # The last completed epoch wins, whatever NUM_EPOCHS is.
        final_accuracy = accuracy

    if final_accuracy is not None:
        log_aggregated_accuracy_to_ethereum(client_id, final_accuracy)
    else:
        print(f'Client {client_id + 1}: no training samples; nothing logged.')

print('Federated Learning Finished.')


Files already downloaded and verified
Client 1, Epoch 1, Loss: 4.121696708147483, Accuracy: 0.0678
Client 1, Epoch 2, Loss: 3.483838795396068, Accuracy: 0.1746
Client 1, Epoch 3, Loss: 3.0774863437008673, Accuracy: 0.2434
Client 1, Epoch 4, Loss: 2.7554508081787383, Accuracy: 0.3057
Client 1, Epoch 5, Loss: 2.4834455117545167, Accuracy: 0.3605
Client 1, Epoch 6, Loss: 2.221544655387664, Accuracy: 0.4198
Client 1, Epoch 7, Loss: 1.9712184320020554, Accuracy: 0.4737
Client 1, Epoch 8, Loss: 1.7289497934643874, Accuracy: 0.5313
Client 1, Epoch 9, Loss: 1.4862036418426983, Accuracy: 0.5886
Client 1, Epoch 10, Loss: 1.2532699434348689, Accuracy: 0.6471
Client 1, Epoch 11, Loss: 1.0299562067936754, Accuracy: 0.7080
Client 1, Epoch 12, Loss: 0.8326548417968214, Accuracy: 0.7644
Client 1, Epoch 13, Loss: 0.6844518948012911, Accuracy: 0.8030
Client 1, Epoch 14, Loss: 0.5425514350538059, Accuracy: 0.8366
Client 1, Epoch 15, Loss: 0.4392011517949421, Accuracy: 0.8718
Client 1, Epoch 16, Loss: 0.3