<a href="https://colab.research.google.com/github/cleyber-bezerra/SplitLearning-Async-NS3/blob/main/AsyncSL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Importando Bibliotecas Python

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import random
import matplotlib.pyplot as plt

Importando a Base de Dados (dataset MNIST)

In [2]:
# Mini-batch size used by the training DataLoader.
BATCH_SIZE = 64
# Training-set preprocessing: ToTensor maps pixels to [0, 1], then
# Normalize((0.5,), (0.5,)) rescales them to [-1, 1].
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# MNIST *training* split (train=True), downloaded to ./data on first run.
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Shuffled mini-batch loader over the training split.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 16389328.59it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 485831.20it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 3855964.10it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 9270330.30it/s]


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



In [3]:
class ClientNet(nn.Module):
    """Client-side half of the split CNN.

    A single stage: 5x5 convolution (1 -> 32 channels, padding 2, so H and W are
    preserved) followed by ReLU and a 2x2 max-pool with stride 2. A batch of shape
    (N, 1, H, W) therefore becomes activations of shape (N, 32, H // 2, W // 2),
    which are what the client hands to the server.
    """

    def __init__(self):
        super().__init__()
        conv = nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2)
        pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.layer1 = nn.Sequential(conv, nn.ReLU(), pool)

    def forward(self, x):
        return self.layer1(x)

class ServerNet(nn.Module):
    """Server-side half of the split CNN.

    Applies a 5x5 convolution (32 -> 64 channels, padding 2), ReLU and a 2x2
    max-pool, flattens, then two linear layers ending in 10 logits. fc1 expects
    7 * 7 * 64 features, so the incoming activations must be (N, 32, 14, 14)
    (i.e. 28x28 inputs after the client's pooling).
    """

    def __init__(self):
        super().__init__()
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # NOTE(review): there is no nonlinearity between fc1 and fc2, so the two
        # layers compose to a single affine map — confirm this is intentional.
        self.fc1 = nn.Linear(7 * 7 * 64, 1000)
        self.fc2 = nn.Linear(1000, 10)

    def forward(self, x):
        features = self.layer2(x)
        flat = features.view(features.size(0), -1)
        return self.fc2(self.fc1(flat))

In [4]:
# Simulated transmission failures
def introduce_transmission_errors(tensor, error_rate):
    """Simulate a lossy link by independently dropping tensor elements.

    A uniform [0, 1) draw is taken per element on the tensor's own device; an element
    survives only if its draw is strictly greater than ``error_rate``, otherwise it is
    multiplied by zero. When ``error_rate == 0`` the very same input object is
    returned without drawing any random numbers.
    """
    if error_rate != 0:
        survivors = torch.rand(tensor.size(), device=tensor.device) > error_rate
        tensor = tensor * survivors
    return tensor

# Simulated transmission latency
def introduce_latency(latencies, delta_t):
    """Draw one latency uniformly from ``latencies`` and compare it with ``delta_t``.

    Returns a pair ``(delivered, sampled)``: ``sampled`` is always the drawn latency,
    and ``delivered`` is that same value, or ``None`` when it exceeds ``delta_t``
    (treated by callers as a timeout / lost message).
    """
    sampled = random.choice(latencies)
    delivered = None if sampled > delta_t else sampled
    return delivered, sampled

In [5]:
def train(client_models, server_model, train_loader, criterion, optimizer, latencies, delta_t, error_rate, device):
    """Run one synchronous training epoch of the split model over a simulated link.

    For every batch a latency is drawn with ``introduce_latency``; if it exceeds
    ``delta_t`` the batch is dropped (timeout). Otherwise each client model produces
    activations, which are corrupted with ``introduce_transmission_errors`` at
    ``error_rate`` to simulate the client -> server link, fed to the shared server
    model, and the per-client logits are averaged before the loss and SGD step.

    Returns:
        (train_loss, accuracy, latencies_record): per-sample mean loss (assuming
        ``criterion`` averages over the batch), accuracy in percent, and every sampled
        latency including timeouts. Loss and accuracy are NaN when no batch was
        processed (all timed out or the loader was empty).
    """
    for client_model in client_models:
        client_model.train()
    server_model.train()
    train_loss = 0.0
    correct = 0
    total = 0
    latencies_record = []

    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Simulate link latency; a timed-out batch is treated as a lost message.
        latency, latency_val = introduce_latency(latencies, delta_t)
        latencies_record.append(latency_val)
        if latency is None:
            continue

        # Each client's activations cross the lossy link before reaching the server;
        # the per-client server outputs are then averaged.
        outputs = sum(
            server_model(introduce_transmission_errors(client_model(data), error_rate))
            for client_model in client_models
        ) / len(client_models)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * data.size(0)
        _, predicted = outputs.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

    if total == 0:
        # Nothing was processed, so no meaningful metric exists (avoid dividing by 0).
        return float('nan'), float('nan'), latencies_record

    train_loss /= total
    accuracy = 100. * correct / total
    return train_loss, accuracy, latencies_record

In [6]:
def validate(client_models, server_model, val_loader, criterion, latencies, delta_t, error_rate, device):
    """Evaluate the split model over a simulated lossy, delayed link (no updates).

    Runs under ``torch.no_grad()`` with all models in eval mode. Per batch a latency is
    drawn with ``introduce_latency`` and the batch is skipped if it exceeds
    ``delta_t``. Each client's activations are corrupted with
    ``introduce_transmission_errors`` at ``error_rate`` before the server model, and
    the per-client logits are averaged.

    Returns:
        (val_loss, accuracy, latencies_record): per-sample mean loss (assuming
        ``criterion`` averages over the batch), accuracy in percent, and every sampled
        latency including timeouts. Loss and accuracy are NaN when no batch was
        evaluated (all timed out or the loader was empty).
    """
    for client_model in client_models:
        client_model.eval()
    server_model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    latencies_record = []

    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)

            # Simulate link latency; a timed-out batch is treated as a lost message.
            latency, latency_val = introduce_latency(latencies, delta_t)
            latencies_record.append(latency_val)
            if latency is None:
                continue

            # Same lossy client -> server link as in training.
            outputs = sum(
                server_model(introduce_transmission_errors(client_model(data), error_rate))
                for client_model in client_models
            ) / len(client_models)
            loss = criterion(outputs, target)

            val_loss += loss.item() * data.size(0)
            _, predicted = outputs.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()

    if total == 0:
        # Nothing was evaluated, so no meaningful metric exists (avoid dividing by 0).
        return float('nan'), float('nan'), latencies_record

    val_loss /= total
    accuracy = 100. * correct / total
    return val_loss, accuracy, latencies_record

In [7]:
def search_based_quantization(X, ebits=range(3, 6), clip_threshold=0.01, step=0.01):
    """Search for an exponent bias under which few values of ``X`` are clipped.

    For each exponent bit-width ``ebit`` in ``ebits`` (3, 4, 5 by default) the
    representable magnitude range is ``[2**bias, (2**(2**ebit - 1) - 1) * 2**bias]``.
    Biases are scanned upward in increments of ``step``, between
    ``log2(median / max_value)`` and ``log2(1 / median)``. Here ``median`` is the
    median of ``|X|``.

    Overflow and underflow are measured on magnitudes, because the sign is
    represented separately. Measuring them on signed values would count every
    negative entry as underflow.

    Args:
        X: array-like of any shape; it is flattened before the search.
        ebits: exponent bit-widths to try, in order.
        clip_threshold: the first bias whose clipped fraction is strictly below
            this value is returned.
        step: scan increment for the bias.

    Returns:
        The first qualifying bias, as a NumPy float. Returns ``None`` when no bias
        qualifies, or when ``X`` is empty or has median magnitude 0 (there is no
        finite search range).
    """
    magnitudes = np.abs(np.asarray(X, dtype=float)).ravel()
    if magnitudes.size == 0:
        return None
    median = np.median(magnitudes)
    if median == 0:
        return None

    min_value = 1
    for ebit in ebits:
        max_value = (2 ** (2 ** ebit - 1)) - 1
        max_bias = np.log2(min_value / median)
        min_bias = np.log2(median / max_value)
        for bias in np.arange(min_bias, max_bias, step):
            scale = 2.0 ** bias
            overflow = np.count_nonzero(magnitudes > max_value * scale)
            underflow = np.count_nonzero(magnitudes < min_value * scale)
            if (overflow + underflow) / magnitudes.size < clip_threshold:
                return bias
    return None

In [8]:
class AsynchronousSplitLearning:
    """Loss-driven asynchronous split-learning controller.

    The communication state is re-evaluated once per epoch by ``update_state``.

    States:
      * 'A': full split training. Activations go client -> server and gradients come
        back, so both client and server models are updated.
      * 'B': activations are sent, but no gradient is returned. Only the server model
        is updated and the client models stay frozen.
      * 'C': nothing new is transmitted. The client models have been frozen since
        state B, so their activations equal the ones the server already received.
        The simulation recomputes them without gradients instead of caching a whole
        epoch. Only the server model is updated.

    Transition: if the mean epoch loss differs from the loss at the last state-'A'
    epoch by more than ``lthred``, return to 'A'. Otherwise move A -> B -> C (and
    stay in C).

    NOTE(review): the state semantics and the transition direction follow the
    loss-based asynchronous split-learning scheme this notebook appears to reproduce.
    Confirm them against the reference article.
    """

    def __init__(self, client_models, server_model, num_epoch, num_batch, K, lthred):
        self.state = 'A'
        self.client_models = client_models
        self.server_model = server_model
        self.num_epoch = num_epoch
        # Batches per pass over the loader; K passes are made per epoch.
        self.num_batch = num_batch
        self.K = K
        self.lthred = lthred
        # Summed batch loss of the most recently completed epoch.
        self.total_loss = 0
        # Mean loss recorded at the last state-'A' epoch; None until the first epoch ends.
        self.last_update_loss = None

    def _client_activations(self, data):
        # Average of every client model's activations for the same input batch.
        return sum(client_model(data) for client_model in self.client_models) / len(self.client_models)

    def split_forward(self, state, data, target, criterion, error_rate=0.0):
        """Forward pass for one batch in ``state``; returns the loss tensor.

        In states 'B' and 'C' the client activations are computed under
        ``torch.no_grad()``, so no gradient reaches the clients. If ``error_rate`` is
        non-zero, the activations are corrupted with ``introduce_transmission_errors``
        before the server model.
        """
        if state == 'A':
            act = self._client_activations(data)
        else:
            with torch.no_grad():
                act = self._client_activations(data)
        if error_rate:
            # NOTE(review): state 'C' nominally reuses already-received (already
            # corrupted) activations; a fresh mask is drawn here as an approximation.
            act = introduce_transmission_errors(act, error_rate)
        outputs = self.server_model(act)
        return criterion(outputs, target)

    def split_backward(self, state, loss, optimizer):
        """Backpropagate ``loss`` and take one optimizer step.

        In states 'B' and 'C' the client parameters receive no gradient, because
        ``split_forward`` ran them under ``no_grad``. With ``zero_grad()`` resetting
        gradients to ``None``, SGD skips those parameters.
        """
        loss.backward()
        optimizer.step()

    def update_state(self, total_loss, num_batches=None):
        """Choose the next epoch's state from this epoch's summed loss and return it.

        Args:
            total_loss: sum of per-batch losses over the epoch.
            num_batches: number of batches that contributed to ``total_loss``.
                Defaults to ``num_batch * K``, the original normalization. If it is
                0, nothing was measured and the state is left unchanged.
        """
        if num_batches is None:
            num_batches = self.num_batch * self.K
        if num_batches <= 0:
            return self.state
        current_loss = total_loss / num_batches
        if self.last_update_loss is None or abs(self.last_update_loss - current_loss) > self.lthred:
            # The loss is still moving: clients need full updates again.
            self.state = 'A'
            self.last_update_loss = current_loss
        else:
            # The loss has stabilised: progressively reduce communication.
            self.state = 'B' if self.state == 'A' else 'C'
        return self.state

    def train(self, train_loader, criterion, optimizer, latencies, delta_t, error_rate, device):
        """Train for ``num_epoch`` epochs, making K passes over ``train_loader`` per epoch.

        Batches whose simulated latency (from ``introduce_latency``) exceeds
        ``delta_t`` are skipped. The state is updated after each epoch, using the
        mean loss over the batches that were actually processed.

        NOTE(review): every pass uses all client models on the same data; there is
        no per-client data partition.
        """
        for client_model in self.client_models:
            client_model.train()
        self.server_model.train()

        for _epoch in range(1, self.num_epoch + 1):
            total_loss = 0.0
            processed = 0
            for _client in range(1, self.K + 1):
                for data, target in train_loader:
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()

                    latency, _latency_val = introduce_latency(latencies, delta_t)
                    if latency is None:
                        continue  # timeout: message lost

                    loss = self.split_forward(self.state, data, target, criterion, error_rate)
                    total_loss += loss.item()
                    processed += 1
                    self.split_backward(self.state, loss, optimizer)
            self.total_loss = total_loss
            self.update_state(total_loss, processed)

In [None]:
# Reproducibility: seed every RNG the simulation draws from (Python `random` for
# latencies, NumPy for the quantization probe, torch for weights/masks/shuffling).
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

NUM_CLIENTS = 3
criterion = nn.CrossEntropyLoss()  # loss function

# Held-out MNIST test split for validation. Validating on train_loader would only
# measure how well the model fits its own training data.
val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=train_loader.batch_size, shuffle=False)

latencies = [1, 2, 3, 4, 5]
delta_t = 3  # Timeout threshold
error_rates = [0.0, 0.25, 0.50, 0.75]

num_epochs = 4  # the reference article uses 200
train_size = 5  # the reference article uses 10000
# Flat histories: num_epochs entries per error rate, in error_rates order.
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []
all_train_latencies = []
all_val_latencies = []
biases = []


def build_split_models():
    """Create fresh client/server models and one SGD optimizer over all their parameters."""
    clients = [ClientNet().to(device) for _ in range(NUM_CLIENTS)]
    server = ServerNet().to(device)
    params = [param for client_model in clients for param in client_model.parameters()] + list(server.parameters())
    return clients, server, optim.SGD(params, lr=0.01, momentum=0.9)


for error_rate in error_rates:
    print(f"Training with error rate: {error_rate * 100}%")
    # Fresh models per error rate, so every run starts from scratch and runs are comparable.
    client_models, server_model, optimizer = build_split_models()
    # One controller per run, so its A/B/C state carries across epochs. Each
    # asl.train(...) call below performs exactly one asynchronous epoch.
    asl = AsynchronousSplitLearning(client_models, server_model, num_epoch=1,
                                    num_batch=len(train_loader), K=NUM_CLIENTS, lthred=0.01)

    for epoch in range(1, num_epochs + 1):
        # list.append returns None, so capture the bias before recording it.
        bias = search_based_quantization(np.random.randn(train_size))
        biases.append(bias)
        use_async = bias is not None

        if use_async:
            asl.train(train_loader, criterion, optimizer, latencies, delta_t, error_rate, device)
            # Measure training metrics without another (synchronous) update pass,
            # which would otherwise undo the clients being frozen in states B/C.
            train_loss, train_accuracy, train_latencies = validate(client_models, server_model, train_loader, criterion, latencies, delta_t, error_rate, device)
        else:
            train_loss, train_accuracy, train_latencies = train(client_models, server_model, train_loader, criterion, optimizer, latencies, delta_t, error_rate, device)
        val_loss, val_accuracy, val_latencies = validate(client_models, server_model, val_loader, criterion, latencies, delta_t, error_rate, device)

        train_losses.append(train_loss)
        train_accuracies.append(train_accuracy)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)
        all_train_latencies.extend(train_latencies)
        all_val_latencies.extend(val_latencies)

        print(f'Epoch {epoch}/{num_epochs} - '
              f'Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.2f}% - '
              f'Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.2f}%')
        print('TRAIN ASSYNC' if use_async else 'TRAIN SYNC')



In [None]:
   # Histogram of the per-batch latencies sampled during training, with the timeout threshold.
# Bins are centred on the integer latency values. Latencies strictly greater than
# delta_t time out, so the threshold line sits between delta_t and delta_t + 1.
# (This assumes integer latencies, as in the `latencies` list.)
bin_edges = np.arange(min(latencies) - 0.5, max(latencies) + 1.5, 1.0)
fig, ax = plt.subplots(figsize=(10, 5))
ax.hist(all_train_latencies, bins=bin_edges, edgecolor='black')
ax.axvline(x=delta_t + 0.5, color='r', linestyle='--', label=f'Timeout Threshold: {delta_t} ms')
ax.set_xticks(sorted(set(latencies)))
ax.set_title('Latency Distribution During Training')
ax.set_xlabel('Latency (ms)')
ax.set_ylabel('Frequency')
ax.legend()
plt.show()

# Sanity check: the histories hold num_epochs points per error rate.
print(f'Epochs per error rate: {num_epochs}')
print(f'Recorded train / validation points: {len(train_losses)} / {len(val_losses)}')


In [None]:
    # Loss and accuracy curves, one pair of lines per error rate.
# The history lists are flat: the experiment loop appends num_epochs entries per
# error rate, in error_rates order. Each run is sliced out and plotted against
# epochs 1..num_epochs; an incomplete trailing run is ignored.
epoch_axis = range(1, num_epochs + 1)
num_runs = len(train_losses) // num_epochs
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 5))

for run in range(num_runs):
    window = slice(run * num_epochs, (run + 1) * num_epochs)
    tag = f'err {error_rates[run]:.0%}' if run < len(error_rates) else f'run {run + 1}'
    # Train = solid, validation = dashed, sharing one colour per run.
    (loss_line,) = ax_loss.plot(epoch_axis, train_losses[window], label=f'Train Loss ({tag})')
    ax_loss.plot(epoch_axis, val_losses[window], linestyle='--', color=loss_line.get_color(),
                 label=f'Validation Loss ({tag})')
    (acc_line,) = ax_acc.plot(epoch_axis, train_accuracies[window], label=f'Train Accuracy ({tag})')
    ax_acc.plot(epoch_axis, val_accuracies[window], linestyle='--', color=acc_line.get_color(),
                label=f'Validation Accuracy ({tag})')

ax_loss.set_xlabel('Epochs')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Train and Validation Loss')
ax_loss.legend(fontsize='small')

ax_acc.set_xlabel('Epochs')
ax_acc.set_ylabel('Accuracy (%)')
ax_acc.set_title('Train and Validation Accuracy')
ax_acc.legend(fontsize='small')

fig.tight_layout()
plt.show()