In [1]:
from collections import OrderedDict

from PIL import Image
import pandas as pd
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import flwr as fl

In [2]:
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [35]:
def load_data():
    # Set the path to your dataset
    data_train = "/home/bhabesh/Tencon 2023 Paper/Indian Dataset/Disease Grading/1. Original Images/Train_Set/"
    data_test = "/home/bhabesh/Tencon 2023 Paper/Indian Dataset/Disease Grading/1. Original Images/Test_Set/"
    csv_file_train = "/home/bhabesh/Tencon 2023 Paper/Indian Dataset/Disease Grading/2. Groundtruths/TrainingLabels.csv"
    csv_file_test = "/home/bhabesh/Tencon 2023 Paper/Indian Dataset/Disease Grading/2. Groundtruths/TestingLabels.csv"

    # Define transformations for data augmentation
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224),  # Random crop and resize to 224x224
        transforms.RandomHorizontalFlip(),  # Randomly flip the image horizontally
        transforms.ToTensor(),  # Convert image to tensor
        transforms.Normalize(mean=[0.485, 0.456, 0.406],  # Normalize image channels
                             std=[0.229, 0.224, 0.225])
    ])

    # Define transformations for validation (no augmentation)
    val_transforms = transforms.Compose([
        transforms.Resize(256),  # Resize image to 256x256
        transforms.CenterCrop(224),  # Center crop to 224x224
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    # Custom dataset class for loading images and labels from the CSV file
    class CustomDataset(torch.utils.data.Dataset):
        def __init__(self, csv_file, root_dir, transform=None):
            self.df = pd.read_csv(csv_file)
            self.root_dir = root_dir
            self.transform = transform

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            img_name = os.path.join(self.root_dir, str(self.df.iloc[idx, 0]))  # Assuming the image names are in the first column
            image = Image.open(img_name).convert("RGB")
            label = self.df.iloc[idx, 1]  # Assuming the labels are in the second column

            if self.transform:
                image = self.transform(image)

            return image, label

    # Load the dataset with data augmentation and transformations
    train_dataset = CustomDataset(csv_file_train, root_dir=os.path.join(data_train, 'Train'), transform=train_transforms)
    test_dataset = CustomDataset(csv_file_test, root_dir=os.path.join(data_test, 'Test'), transform=val_transforms)

    # Create data loaders for training and validation
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32)
    return train_loader, test_loader

In [36]:
class Net(nn.Module):
    """Model (simple CNN adapted from 'PyTorch: A 60 Minute Blitz')"""

    def __init__(self) -> None:
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

In [37]:
def train(net, trainloader, epochs):
    """Train the model on the training set."""
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    for _ in range(epochs):
        for images, labels in tqdm(trainloader):
            optimizer.zero_grad()
            criterion(net(images.to(DEVICE)), labels.to(DEVICE)).backward()
            optimizer.step()


def test(net, testloader):
    """Validate the model on the test set."""
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0
    with torch.no_grad():
        for images, labels in tqdm(testloader):
            outputs = net(images.to(DEVICE))
            labels = labels.to(DEVICE)
            loss += criterion(outputs, labels).item()
            correct += (torch.max(outputs.data, 1)[1] == labels).sum().item()
    accuracy = correct / len(testloader.dataset)
    return loss, accuracy

In [38]:
net = Net().to(DEVICE)
train_loader, test_loader = load_data()

In [43]:
# Define Flower client
class FlowerClient(fl.client.NumPyClient):
    def get_parameters(self, config):
        return [val.cpu().numpy() for _, val in net.state_dict().items()]

    def set_parameters(self, parameters):
        params_dict = zip(net.state_dict().keys(), parameters)
        state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict})
        net.load_state_dict(state_dict, strict=True)

    def fit(self, parameters, config):
        self.set_parameters(parameters)
        train(net, trainloader, epochs=1)
        return self.get_parameters(config={}), len(trainloader.dataset), {}

    def evaluate(self, parameters, config):
        self.set_parameters(parameters)
        loss, accuracy = test(net, testloader)
        return loss, len(testloader.dataset), {"accuracy": accuracy}

In [44]:
# Start Flower client
fl.client.start_numpy_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient(),
)

INFO flwr 2023-06-10 17:58:23,657 | grpc.py:50 | Opened insecure gRPC connection (no certificates were passed)
DEBUG flwr 2023-06-10 17:58:23,669 | connection.py:39 | ChannelConnectivity.IDLE
DEBUG flwr 2023-06-10 17:58:23,672 | connection.py:39 | ChannelConnectivity.CONNECTING
DEBUG flwr 2023-06-10 17:58:23,674 | connection.py:39 | ChannelConnectivity.TRANSIENT_FAILURE
DEBUG flwr 2023-06-10 17:58:23,876 | connection.py:113 | gRPC channel closed


_MultiThreadedRendezvous: <_MultiThreadedRendezvous of RPC that terminated with:
	status = StatusCode.UNAVAILABLE
	details = "failed to connect to all addresses; last error: UNKNOWN: ipv4:172.31.2.4:8080: HTTP proxy returned response code 403"
	debug_error_string = "UNKNOWN:failed to connect to all addresses; last error: UNKNOWN: ipv4:172.31.2.4:8080: HTTP proxy returned response code 403 {grpc_status:14, created_time:"2023-06-10T17:58:23.674558678+05:30"}"
>