<a href="https://colab.research.google.com/github/hassanSattariNia/FederatedLearning/blob/main/MnistCNNDistributed.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Original code: CNN on MNIST


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# define cnn model
def mnist_cnn():
    """Build the baseline MNIST classifier as one ``nn.Sequential``.

    The network is two conv -> ReLU -> max-pool stages followed by a
    two-layer fully connected head that emits 10 logits. The first
    ``Linear`` expects ``64 * 4 * 4`` features, which is what a 1x28x28
    input produces after the two conv/pool stages. The inline markers
    show where the stack is cut into four partitions.

    Returns:
        nn.Sequential: a freshly initialised model (not moved to any device).
    """
    feature_layers = [
        nn.Conv2d(1, 32, kernel_size=5),  # first partition
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
        nn.Conv2d(32, 64, kernel_size=5),  # second partition
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    classifier_layers = [
        nn.Flatten(),
        nn.Linear(64 * 4 * 4, 512),  # third partition
        nn.ReLU(),
        nn.Linear(512, 10),  # last partition
    ]
    # Splatting keeps the positional indices (0..9), so state-dict keys
    # are the same as when the layers are listed inline.
    return nn.Sequential(*feature_layers, *classifier_layers)

# Hyperparameters and compute device for the single-model baseline run.
batch_size = 32
epochs = 5
learning_rate = 0.001
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert images to tensors, then normalize them.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and standard deviation
])


# MNIST train/test splits; downloaded into ./data when not already present.
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True,
                              transform=transform)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size,
                         shuffle=False)


# Baseline model, loss and optimizer (Adam over all model parameters).
model = mnist_cnn().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


def train(model, device, train_loader, optimizer, criterion, epoch):
    """Train ``model`` for one pass over ``train_loader`` and print progress.

    Args:
        model: network to optimise; switched to training mode here.
        device: device each batch is moved to before the forward pass.
        train_loader: iterable of ``(inputs, labels)`` batches; must
            support ``len()``.
        optimizer: optimizer stepped once per batch.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        epoch: epoch number, used only in the printed messages.

    Prints the current batch loss every 100 batches and the mean of the
    per-batch losses at the end. Returns ``None``.
    """
    model.train()
    loss_sum = 0
    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()  # zero the gradients from the previous step
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        if (batch_idx + 1) % 100 != 0:
            continue
        print(f'Epoch [{epoch}] Batch [{batch_idx+1}/{len(train_loader)}] '
              f'Loss: {loss.item():.4f}')

    avg_loss = loss_sum / len(train_loader)
    print(f'==> Epoch [{epoch}] Average training loss: {avg_loss:.4f}')


def test(model, device, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print loss and accuracy.

    Args:
        model: network to evaluate; switched to eval mode here.
        device: device each batch is moved to before the forward pass.
        test_loader: iterable of ``(inputs, labels)`` batches; must
            support ``len()``.
        criterion: loss callable invoked as ``criterion(output, labels)``.

    The reported loss is the mean of the per-batch losses; accuracy is
    the percentage of samples whose highest-scoring class matches the
    label. Gradients are disabled for the whole pass. Returns ``None``.
    """
    model.eval()
    correct, total, loss_sum = 0, 0, 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()
            # Index of the highest-scoring class along dim 1.
            predicted = torch.max(logits.data, 1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    avg_loss = loss_sum / len(test_loader)
    accuracy = 100 * correct / total
    print(f'==> Test set: Average loss: {avg_loss:.4f}, '
          f'Accuracy: {correct}/{total} ({accuracy:.2f}%)')


# Baseline run: one training pass followed by a full test-set evaluation
# for every epoch (epochs are numbered from 1 in the printed output).
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)


Epoch [1] Batch [100/1875] Loss: 0.1617
Epoch [1] Batch [200/1875] Loss: 0.0994
Epoch [1] Batch [300/1875] Loss: 0.0035
Epoch [1] Batch [400/1875] Loss: 0.0764
Epoch [1] Batch [500/1875] Loss: 0.0504
Epoch [1] Batch [600/1875] Loss: 0.0373
Epoch [1] Batch [700/1875] Loss: 0.0992
Epoch [1] Batch [800/1875] Loss: 0.0749
Epoch [1] Batch [900/1875] Loss: 0.0906
Epoch [1] Batch [1000/1875] Loss: 0.0295
Epoch [1] Batch [1100/1875] Loss: 0.0405
Epoch [1] Batch [1200/1875] Loss: 0.0022
Epoch [1] Batch [1300/1875] Loss: 0.0146
Epoch [1] Batch [1400/1875] Loss: 0.0072
Epoch [1] Batch [1500/1875] Loss: 0.1307
Epoch [1] Batch [1600/1875] Loss: 0.0135
Epoch [1] Batch [1700/1875] Loss: 0.0260
Epoch [1] Batch [1800/1875] Loss: 0.0552
==> Epoch [1] Average training loss: 0.1041
==> Test set: Average loss: 0.0423, Accuracy: 9866/10000 (98.66%)
Epoch [2] Batch [100/1875] Loss: 0.0045
Epoch [2] Batch [200/1875] Loss: 0.0006
Epoch [2] Batch [300/1875] Loss: 0.0012
Epoch [2] Batch [400/1875] Loss: 0.2452
E

# **Split code: MNIST**

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms


# Hyperparameters and compute device for the split (partitioned) run.
# NOTE(review): batch size and learning rate differ from the baseline
# cell (32 / 0.001 there), so the two runs are not directly comparable.
batch_size = 64
epochs = 5
learning_rate = 0.01
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Convert images to tensors, then normalize with the MNIST mean / std.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Load the MNIST train/test splits (downloaded into ./data if missing).
train_dataset = datasets.MNIST(root='./data', train=True, download=True,
                               transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True,
                              transform=transform)

# Only the training batches are shuffled; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False)



# first partition
class Client1Model(nn.Module):
    """First partition: 5x5 conv (1 -> 32 channels), ReLU, 2x2 max-pool."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return ``pool1(relu1(conv1(x)))``."""
        return self.pool1(self.relu1(self.conv1(x)))

# second partition
class Client2Model(nn.Module):
    """Second partition: 5x5 conv (32 -> 64 channels), ReLU, 2x2 max-pool."""

    def __init__(self):
        super().__init__()
        self.conv2 = nn.Conv2d(32, 64, kernel_size=5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return ``pool2(relu2(conv2(x)))``."""
        return self.pool2(self.relu2(self.conv2(x)))


class Client3Model(nn.Module):
    """Third partition: flatten, Linear(64*4*4 -> 512), ReLU."""

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(64 * 4 * 4, 512)
        self.relu3 = nn.ReLU()

    def forward(self, x):
        """Return ``relu3(fc1(flatten(x)))``."""
        flat = self.flatten(x)
        return self.relu3(self.fc1(flat))

class Client4Model(nn.Module):
    """Last partition: Linear(512 -> 10) producing the class logits."""

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        """Return ``fc2(x)``."""
        return self.fc2(x)


# One model instance per partition, all placed on the same device.
client1_model = Client1Model().to(device)
client2_model = Client2Model().to(device)
client3_model = Client3Model().to(device)
client4_model = Client4Model().to(device)

# Each partition has its own Adam optimizer over only its own parameters,
# so every partition is updated independently of the others.
optimizer1 = optim.Adam(client1_model.parameters(), lr=learning_rate)
optimizer2 = optim.Adam(client2_model.parameters(), lr=learning_rate)
optimizer3 = optim.Adam(client3_model.parameters(), lr=learning_rate)
optimizer4 = optim.Adam(client4_model.parameters(), lr=learning_rate)

# Loss applied to the last partition's logits.
criterion = nn.CrossEntropyLoss()

# Define SplitLayer to handle the forward and backward passes of one partition
class SplitLayer(torch.autograd.Function):
    """Run a sub-model as one opaque autograd node.

    ``forward`` evaluates ``model(input_tensor)`` and keeps only the input;
    ``backward`` re-runs the model on that input with gradients enabled and
    back-propagates the incoming gradient through the re-run. As a side
    effect the model's parameter ``.grad`` fields are accumulated, and the
    gradient with respect to the input is returned to the outer graph.

    Usage: ``output = SplitLayer.apply(input_tensor, model)``.

    Notes:
        * ``backward`` only runs when ``input_tensor`` requires grad,
          because the model's parameters are not inputs of this Function.
        * The model is executed a second time in ``backward``, so layers
          whose output is not a pure function of the input may not
          reproduce the forward activations exactly.
    """

    @staticmethod
    def forward(ctx, input_tensor, model):
        """Return ``model(input_tensor)`` and stash what ``backward`` needs.

        Args:
            ctx: autograd context object.
            input_tensor: tensor fed to ``model``.
            model: callable (typically an ``nn.Module``) to evaluate.
        """
        ctx.save_for_backward(input_tensor)
        ctx.model = model  # non-tensor state is stored directly on ctx
        return model(input_tensor)

    @staticmethod
    def backward(ctx, grad_output):
        """Back-propagate ``grad_output`` through a fresh run of the model.

        Returns:
            tuple: ``(grad_input, None)`` - the gradient for
            ``input_tensor`` and ``None`` for the non-tensor ``model``.
        """
        (saved_input,) = ctx.saved_tensors
        model = ctx.model

        # Differentiate through a detached copy instead of the saved tensor.
        # Using the saved tensor directly would (a) accumulate into the
        # caller's ``input.grad`` here and again when autograd applies the
        # returned gradient, doubling it for leaf inputs, and (b) fail for
        # non-leaf inputs, whose ``requires_grad`` flag cannot be assigned.
        local_input = saved_input.detach().requires_grad_(True)
        with torch.enable_grad():
            output = model(local_input)
            # Accumulates gradients into the model parameters and into
            # ``local_input.grad`` only.
            output.backward(grad_output)
        return local_input.grad, None

# train
# Split training: the four partitions are trained as if they lived on
# separate clients. Each partition runs its forward pass exactly once and
# keeps its own autograd graph; the activation handed to the next
# partition is detached, so only plain tensors cross a cut (activations
# on the way forward, gradients on the way back).
for epoch in range(1, epochs + 1):
    client1_model.train()
    client2_model.train()
    client3_model.train()
    client4_model.train()

    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        optimizer4.zero_grad()

        # ---- forward: client 1 -> 2 -> 3 -> 4 -------------------------
        # ``inN`` is the detached activation received by client N; it is a
        # leaf that requires grad, so after client N's backward pass its
        # ``.grad`` holds the gradient to send back to client N-1.
        out1 = client1_model(data)
        in2 = out1.detach().requires_grad_(True)

        out2 = client2_model(in2)
        in3 = out2.detach().requires_grad_(True)

        out3 = client3_model(in3)
        in4 = out3.detach().requires_grad_(True)

        out4 = client4_model(in4)
        loss = criterion(out4, target)

        # ---- backward: client 4 -> 3 -> 2 -> 1 ------------------------
        # Each client back-propagates the gradient it received through its
        # own graph and then updates its own parameters. A client's
        # parameters are stepped only after its own backward pass, so no
        # graph ever sees modified weights.
        loss.backward()
        optimizer4.step()

        out3.backward(in4.grad)
        optimizer3.step()

        out2.backward(in3.grad)
        optimizer2.step()

        out1.backward(in2.grad)
        optimizer1.step()

        # ---- running metrics (computed from the pre-update logits) ----
        running_loss += loss.item()
        _, predicted = torch.max(out4.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch}/{epochs}], Batch [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

    train_accuracy = 100 * correct / total
    print(f'Epoch [{epoch}/{epochs}], Training Loss: {running_loss/len(train_loader):.4f}, Training Accuracy: {train_accuracy:.2f}%')

# Put every partition into inference mode before scoring the test set.
for stage in (client1_model, client2_model, client3_model, client4_model):
    stage.eval()

test_loss, correct, total = 0.0, 0, 0

with torch.no_grad():
    for data, target in test_loader:
        data = data.to(device)
        target = target.to(device)

        # Chain the four partitions back together: client 1 -> 2 -> 3 -> 4.
        out4 = client4_model(client3_model(client2_model(client1_model(data))))

        loss = criterion(out4, target)
        test_loss += loss.item()

        # Predicted class is the index of the largest logit along dim 1.
        predicted = torch.max(out4.data, 1).indices
        total += target.size(0)
        correct += (predicted == target).sum().item()

# Loss is averaged over batches, accuracy over individual samples.
test_accuracy = 100 * correct / total
print(f'Test Loss: {test_loss/len(test_loader):.4f}, Test Accuracy: {test_accuracy:.2f}%')

print("Training completed.")


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 16.3MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 488kB/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 4.41MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 7.38MB/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Epoch [1/5], Batch [0/938], Loss: 2.3319
Epoch [1/5], Batch [100/938], Loss: 0.2466
Epoch [1/5], Batch [200/938], Loss: 0.0865
Epoch [1/5], Batch [300/938], Loss: 0.4218
Epoch [1/5], Batch [400/938], Loss: 0.3370
Epoch [1/5], Batch [500/938], Loss: 0.2547
Epoch [1/5], Batch [600/938], Loss: 0.2251
Epoch [1/5], Batch [700/938], Loss: 0.1948
Epoch [1/5], Batch [800/938], Loss: 0.1671
Epoch [1/5], Batch [900/938], Loss: 0.1712
Epoch [1/5], Training Loss: 0.2874, Training Accuracy: 92.25%
Epoch [2/5], Batch [0/938], Loss: 0.2981
Epoch [2/5], Batch [100/938], Loss: 0.1868
Epoch [2/5], Batch [200/938], Loss: 0.1757
Epoch [2/5], Batch [300/938], Loss: 0.0487
Epoch [2/5], Batch [400/938], Loss: 0.2748
Epoch [2/5], Batch [500/938], Loss: 0.0927
Epoch [2/5], Batch [600/938], Loss: 0.1143
Epoch [2/5], Batch [700/938], Loss: 0.1473
Epoch [2/5], Batch [800/938], Loss: 0.1998
Epoch [2/5], Batch [900/938], Loss: 0.0869
Epoch [