In [1]:
!pip install backpack-for-pytorch==1.3.0

Collecting backpack-for-pytorch==1.3.0
  Downloading backpack_for_pytorch-1.3.0-py3-none-any.whl (119 kB)
[K     |████████████████████████████████| 119 kB 12.6 MB/s 
[?25hCollecting einops<1.0.0,>=0.3.0
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops, backpack-for-pytorch
Successfully installed backpack-for-pytorch-1.3.0 einops-0.4.1


# Helper

In [2]:
import logging


class FedLogger:

    __instance = None
    __restart = None

    """
    Initialize
    """

    @staticmethod
    def getLogger(restart, filename):
        """ Static access method. """
        if FedLogger.__instance is None:
            FedLogger(filename)
        else:
            # Update logger
            FedLogger.__update_logger(filename)

        FedLogger.__restart = restart
        return FedLogger.__instance

    def __init__(self, filename):
        """ Virtually private constructor. """
        if FedLogger.__instance != None:
            raise Exception("Logger is a singleton!")
        else:
            FedLogger.__instance = self
            FedLogger.__update_logger(filename)

    """
    Private method
    """

    @staticmethod
    def __update_logger(filename):

        if FedLogger.__instance is None:
            raise Exception("Please init logger first!")

        logger = logging.getLogger()

        for handler in logger.handlers[:]:
            logger.removeHandler(handler)

        logger.setLevel(logging.INFO)

        formatter = logging.Formatter(
            "%(asctime)s [%(levelname)s] %(message)s")

        file_handler = logging.FileHandler(filename)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(formatter)
        logger.addHandler(stream_handler)

    """
    Public method
    """

    def log(self, msg):
        logger = logging.getLogger()
        logger.info("Restart - {}, {}".format(FedLogger.__restart, msg))

In [3]:
import copy
import torch
from torch import nn, autograd

from backpack import backpack, extend
from backpack.extensions import SumGradSquared, Variance

"""
Fishr
"""


def compute_irm_penalty(logits, y, loss_fn):
    scale = torch.tensor(1.).requires_grad_()
    if torch.cuda.is_available():
        scale = torch.tensor(1.).cuda().requires_grad_()
    loss = loss_fn(logits * scale, y)
    grad = autograd.grad(loss, [scale], create_graph=True)[0]
    return torch.sum(grad**2)


def compute_grad_variance(input, labels, network, algorithm):
    """
    Main Fishr method that computes the gradient variances using the BackPACK package.
    """
    logits = network(input)
    # bce_extended = extend(nn.BCEWithLogitsLoss(reduction='sum'))
    bce_extended = extend(nn.CrossEntropyLoss(reduction='sum'))
    loss = bce_extended(logits, labels)

    # print('Prediction: {}'.format(logits))
    # print('Real: {}'.format(labels))
    # calling first-order derivatives in the network while maintaining the per-sample gradients

    with backpack(Variance(), SumGradSquared()):
        loss.backward(
            inputs=list(network.parameters()), retain_graph=True, create_graph=True
        )

    dict_grads_variance = {
        name: (
            weights.variance.clone().view(-1)
            if "notcentered" not in algorithm.split("_") else
            weights.sum_grad_squared.clone().view(-1)/input.size(0)
        ) for name, weights in network.named_parameters() if (
            "onlyextractor" not in algorithm.split("_") or
            name not in ["4.weight", "4.bias"]
        )
    }

    return dict_grads_variance


def l2_between_dicts(dict_1, dict_2):
    assert len(dict_1) == len(dict_2)
    dict_1_values = [dict_1[key] for key in sorted(dict_1.keys())]
    dict_2_values = [dict_2[key] for key in sorted(dict_1.keys())]
    return (
        torch.cat(tuple([t.view(-1) for t in dict_1_values])) -
        torch.cat(tuple([t.view(-1) for t in dict_2_values]))
    ).pow(2).sum()


"""
ILC
"""


def get_model_grads(input, labels, network, loss_fn):

    _, logits = network(input)

    loss = loss_fn(logits, labels)
    loss.backward()

    model_params = list(network.parameters())
    param_gradients = []
    for model_param in model_params:
        # Get gradients
        # Note: The gradient of the loss each parameter p is stored in p.grad after the backward
        # See: https://discuss.pytorch.org/t/how-to-get-gradient-of-loss/16955
        grad = model_param.grad
        grad_copy = copy.deepcopy(grad)
        param_gradients.append(grad_copy)

    return param_gradients


"""
Arithmetic mean
"""


def compute_arith_mean(model_params, total_param_gradients):

    param_gradients = [[] for _ in model_params]

    # Loop for each environment
    for env_param_gradients in total_param_gradients:
        for idx, grads in enumerate(param_gradients):
            env_grad = env_param_gradients[idx]
            grads.append(env_grad)

    assert len(param_gradients) == len(model_params)

    for param, grads in zip(model_params, param_gradients):

        # Calculate sign matrix
        grads = torch.stack(grads, dim=0)
        avg_grad = torch.mean(grads, dim=0)
        param.grad = avg_grad


"""
Geometric mean
"""


def compute_geo_mean(model_params, total_param_gradients, algorithm, substitute):

    if "geo_substitute" == algorithm:
        compute_substitute_geo_mean(
            model_params, total_param_gradients, substitute)
    elif "geo_weighted" == algorithm:
        compute_weighted_geo_mean(model_params, total_param_gradients)


def compute_substitute_geo_mean(model_params, total_param_gradients, substitute):

    param_gradients = [[] for _ in model_params]

    # Loop for each environment
    for env_param_gradients in total_param_gradients:
        for idx, grads in enumerate(param_gradients):
            env_grad = env_param_gradients[idx]
            grads.append(env_grad)

    assert len(param_gradients) == len(model_params)

    for param, grads in zip(model_params, param_gradients):

        # Calculate sign matrix
        grads = torch.stack(grads, dim=0)
        sign_matrix = torch.sign(grads)

        avg_sign_matrix = torch.mean(sign_matrix, dim=0)

        # If torch.sign(avg_sign_matrix) == 0, then has equal number of positive and negative numbers
        # Regard the positive numbers are majority signs
        avg_sign = torch.sign(avg_sign_matrix) + (avg_sign_matrix == 0)

        majority_sign_matrix = sign_matrix == avg_sign
        minority_sign_matrix = ~majority_sign_matrix

        grads = majority_sign_matrix * grads + minority_sign_matrix * substitute

        n_agreement_envs = len(grads)
        avg_grad = torch.mean(grads, dim=0)
        substitute_prod_grad = torch.sign(avg_grad) * torch.exp(
            torch.sum(torch.log(torch.abs(grads) + 1e-10), dim=0) / n_agreement_envs)

        param.grad = substitute_prod_grad


def compute_weighted_geo_mean(model_params, total_param_gradients):

    param_gradients = [[] for _ in model_params]

    # Loop for each environment
    for env_param_gradients in total_param_gradients:
        for idx, grads in enumerate(param_gradients):
            env_grad = env_param_gradients[idx]
            grads.append(env_grad)

    assert len(param_gradients) == len(model_params)

    for param, grads in zip(model_params, param_gradients):

        # Calculate sign matrix
        grads = torch.stack(grads, dim=0)
        sign_matrix = torch.sign(grads)

        # Positive & Negative gradients
        positive_sign_matrix = sign_matrix > 0
        negative_sign_matrix = ~positive_sign_matrix

        # Temporarily replace 0 with 1 to calculate geometric mean
        positive_gradients = positive_sign_matrix * grads + negative_sign_matrix
        negative_gradients = negative_sign_matrix * grads + positive_sign_matrix

        # Temporarily replace 0 with 1 to prevent demoninator to be 0
        n_agreement_envs = len(grads)
        n_positive_envs = torch.sum(positive_sign_matrix, dim=0)
        n_negative_envs = torch.sum(negative_sign_matrix, dim=0)

        n_positive_envs_denominator = n_positive_envs + (n_positive_envs == 0)
        n_negative_envs_denominator = n_negative_envs + (n_negative_envs == 0)

        # Weighted geometric mean
        positive_prod_gradients = (n_positive_envs / n_agreement_envs) * torch.exp(torch.sum(
            torch.log(torch.abs(positive_gradients) + 1e-10), dim=0) / n_positive_envs_denominator)
        negative_prod_gradients = (n_negative_envs / n_agreement_envs) * torch.exp(torch.sum(
            torch.log(torch.abs(negative_gradients) + 1e-10), dim=0) / n_negative_envs_denominator)

        weighted_prod_grad = positive_prod_gradients - negative_prod_gradients
        param.grad = weighted_prod_grad

# Data Loader

In [4]:
from abc import ABC, abstractmethod
from enum import Enum

import torch
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms


"""
Data Loader Type
"""


class DataLoaderType(Enum):
    COLOR_MNIST = 0
    ROTATE_CIFAR = 1


"""
Abstract Data Loader
"""


class AbstractDataLoader(ABC):

    @abstractmethod
    def combine_envs(self, envs):
        raise Exception("Abstract method should be implemented")

    @abstractmethod
    def make_environment(self, images, labels, **kwargs):
        raise Exception("Abstract method should be implemented")

    def create_data_loader(self, x, y, batch_size):

        data_set = self.__convert_to_tensor(x, y)
        data_loader = DataLoader(data_set,
                                 shuffle=True,
                                 batch_size=batch_size)
        return data_loader

    def __convert_to_tensor(self, x, y):
        assert x.shape[0] == y.shape[0]

        tensor_list = []
        for idx in range(x.shape[0]):
            data_x, data_y = x[idx], y[idx]
            tensor_list.append((data_x, data_y))

        return tensor_list


"""
Color MNIST
"""


class ColorMNISTDataLoader(AbstractDataLoader):

    def combine_envs(self, envs):
        raise Exception("Method is not supported!")

    def make_environment(self, images, labels, **kwargs):

        label_flipping_prob = kwargs.get("label_flipping_prob")
        if label_flipping_prob is None:
            raise Exception("Need label flipping probability!")

        color_flipping_prob = kwargs.get("color_flipping_prob")
        if color_flipping_prob is None:
            raise Exception("Need color flipping probability!")

        def torch_bernoulli(p, size):
            return (torch.rand(size) < p).float()

        def torch_xor(a, b):
            return (a - b).abs()  # Assumes both inputs are either 0 or 1

        # 2x subsample for computational convenience
        images = images.reshape((-1, 28, 28))[:, ::2, ::2]
        # Assign a binary label based on the digit; flip label with probability
        labels = (labels < 5).float()
        labels = torch_xor(labels, torch_bernoulli(
            label_flipping_prob, len(labels)))
        # Assign a color based on the label; flip the color with probability e
        colors = torch_xor(labels, torch_bernoulli(
            color_flipping_prob, len(labels)))
        # Apply the color to the image by zeroing out the other color channel
        images = torch.stack([images, images], dim=1)
        images[torch.tensor(range(len(images))),
               (1 - colors).long(), :, :] *= 0

        images, labels = images.float() / 255., labels[:, None]

        if torch.cuda.is_available():
            return {'images': images.cuda(), 'labels': labels.cuda()}

        return {'images': images, 'labels': labels}

    def create_data_loader(self, x, y, batch_size):
        return super().create_data_loader(x, y, batch_size)


"""
Rotated CIFAR-10
"""


class RotatedCifarDataLoader(AbstractDataLoader):

    def combine_envs(self, envs):

        images, labels = [], []
        for env in envs:
            image, label = env["images"], env["labels"]
            images.append(image)
            labels.append(label)

        images = torch.cat(tuple(images))
        labels = torch.cat(tuple(labels))

        return {'images': images, 'labels': labels}

    def make_environment(self, images, labels, **kwargs):

        from_angle = kwargs.get('from_angle')
        to_angle = kwargs.get('to_angle')

        if from_angle is None:
            raise Exception("Need from angle!")

        if to_angle is None:
            raise Exception("Need to angle!")

        rotation = transforms.Compose([transforms.ToPILImage(),
                                       transforms.RandomRotation(
                                           degrees=(from_angle, to_angle)),
                                       transforms.ToTensor()])

        # images = images[:, ::2, ::2, :]
        x = torch.zeros(len(images), 3, 32, 32)
        for i in range(len(images)):
            x[i] = rotation(images[i])

        images = x

        images, labels = torch.Tensor(images), torch.Tensor(labels)
        labels = labels.type(torch.int64)

        if torch.cuda.is_available():
            return {'images': images.cuda(), 'labels': labels.cuda()}

        return {'images': images, 'labels': labels}

    def create_data_loader(self, x, y, batch_size):
        return super().create_data_loader(x, y, batch_size)


"""
Data Loader Factory
"""


class DataLoaderFactory:

    __color_mnist = None
    __rotate_cifar = None

    @staticmethod
    def get_data_loader(type):

        if type == DataLoaderType.COLOR_MNIST:
            if DataLoaderFactory.__color_mnist is None:
                DataLoaderFactory.__color_mnist = ColorMNISTDataLoader()
            return DataLoaderFactory.__color_mnist

        elif type == DataLoaderType.ROTATE_CIFAR:
            if DataLoaderFactory.__rotate_cifar is None:
                DataLoaderFactory.__rotate_cifar = RotatedCifarDataLoader()
            return DataLoaderFactory.__rotate_cifar

        else:
            raise Exception("Unsupported data loader type: {}".format(type))


# Model

In [5]:
import torch
from torch import nn

import torchvision
from backpack import extend

"""CIFAR ResNet"""


class CifarResNet(nn.Module):

    def __init__(self, in_features, out_features):

        super(CifarResNet, self).__init__()
        self.network = torchvision.models.resnet18(pretrained=True)
        self.classifier = extend(nn.Linear(in_features=in_features,
                                    out_features=out_features))

    def forward(self, input):

        features = self.network(input)
        logits = self.classifier(features)
        return features, logits


"""MNIST MLP"""


class MnistMLP(nn.Module):

    def __init__(self, hidden_dim):
        super(MnistMLP, self).__init__()

        lin1 = nn.Linear(2 * 14 * 14, hidden_dim)
        lin2 = nn.Linear(hidden_dim, hidden_dim)

        self.classifier = (nn.Linear(hidden_dim, 1))
        for lin in [lin1, lin2, self.classifier]:
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)

        self._main = nn.Sequential(
            lin1, nn.ReLU(True), lin2, nn.ReLU(True))
        self.alllayers = extend(
            nn.Sequential(lin1, nn.ReLU(True), lin2,
                          nn.ReLU(True), self.classifier)
        )

    @staticmethod
    def prepare_input(input):
        return input.view(input.shape[0], 2 * 14 * 14)

    def forward(self, input):
        out = self.prepare_input(input)
        features = self._main(out)
        logits = self.classifier(features)
        return features, torch.sigmoid(logits)


# Trainer

In [6]:
# from helper import *


class Trainer:

    def __init__(self, evaluator_helper):
        self.__evaluator_helper = evaluator_helper

    def set_logger(self, logger):
        self.__logger = logger

    def train_model(self, model, optimizer, local_model, local_optimizer, train_loader, train_images, train_labels, round_idx, flags):

        # t = torch.cuda.get_device_properties(0).total_memory
        # a = torch.cuda.memory_allocated(0)

        # logger.info("Memory before calculating gradients:")
        # logger.info(convert_size(t))
        # logger.info(convert_size(a))

        algorithm = flags.algorithm

        if "fishr" in algorithm.split("_") and ("geo" in algorithm.split("_") or "arith" in algorithm.split("_")):

            """ Fishr + Geo Mean """
            final_loss = 0
            final_acc = 0

            total_param_gradients = []

            # Set mode to train model
            model.train()

            # Start training
            for (images, labels) in train_loader:

                optimizer.zero_grad()
                param_gradients = get_model_grads(
                    images, labels, model, self.__evaluator_helper.mean_nll)

                _, logits = model(images)
                loss = self.__evaluator_helper.mean_nll(logits, labels)
                acc = self.__evaluator_helper.mean_accuracy(logits, labels)

                final_loss += loss
                final_acc += acc

                total_param_gradients.append(param_gradients)

            # self.__logger.log(len(total_param_gradients))

            # ILC
            local_model_params = model.state_dict()
            local_model_params = copy.deepcopy(local_model_params)
            local_model.load_state_dict(local_model_params)

            # TODO
            local_model.train()
            local_optimizer.zero_grad()
            if "geo" in algorithm.split("_"):
                compute_geo_mean(list(local_model.parameters()),
                                 total_param_gradients, "geo_weighted", 0.001)
            elif "arith" in algorithm.split("_"):
                compute_arith_mean(
                    list(local_model.parameters()), total_param_gradients)
            local_optimizer.step()

            # Fishr
            features, _ = local_model(train_images)
            grad_statistics = compute_grad_variance(
                features, train_labels, local_model.classifier, algorithm)

            # Calculate loss and accuracy
            train_loss = final_loss / len(train_loader)
            train_acc = final_acc / len(train_loader)

        else:

            # Set mode to train model
            model.train()

            # Start training
            features, logits = model(train_images)

            self.__logger.log(logits.shape)
            self.__logger.log(train_labels.shape)

            train_loss = self.__evaluator_helper.mean_nll(logits, train_labels)
            train_acc = self.__evaluator_helper.mean_accuracy(
                logits, train_labels)

            optimizer.zero_grad()

            if "arith" in algorithm.split("_") or "geo" in algorithm.split("_") or "hybrid" in algorithm.split("_"):
                model_grads = get_model_grads(
                    train_images, train_labels, model, self.__evaluator_helper.mean_nll)

            if "fishr" in algorithm.split("_"):
                grad_variance = compute_grad_variance(
                    features, train_labels, model.classifier, algorithm)

            if "hybrid" in algorithm.split("_"):
                grad_statistics = (grad_variance, model_grads)
            elif "fishr" in algorithm.split("_"):
                grad_statistics = grad_variance
            else:
                # Arithmetic or geometric mean
                grad_statistics = model_grads

        # t = torch.cuda.get_device_properties(0).total_memory
        # a = torch.cuda.memory_allocated(0)

        # logger.info("Memory after calculating gradients:")
        # logger.info(convert_size(t))
        # logger.info(convert_size(a))

        return train_loss, train_acc, grad_statistics


# Evaluator

In [7]:
from abc import ABC, abstractmethod
from enum import Enum

from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc

import torch
from torch import nn

"""
Evaluator Helper Type
"""


class EvaluatorHelperType(Enum):
    BINARY = 0
    MULTIPLE = 1


"""
Abstract Evaluator Helper
"""


class AbstractEvaluatorHelper(ABC):

    @abstractmethod
    def mean_nll(self, logits, y):
        raise Exception("Abstract method should be implemented")

    @abstractmethod
    def mean_accuracy(self, logits, y):
        raise Exception("Abstract method should be implemented")

    @abstractmethod
    def mean_roc_auc(self, logits, y):
        raise Exception()

    @abstractmethod
    def mean_pr_auc(self, logits, y):
        raise Exception()


"""
Binary Classification
"""


class BinaryClassificationEvaluatorHelper(AbstractEvaluatorHelper):

    def mean_nll(self, logits, y):
        critetion = nn.BCELoss()
        return critetion(logits, y)

    def mean_accuracy(self, logits, y):
        preds = (logits > 0.5).float()
        return ((preds - y).abs() < 1e-2).float().mean()

    def mean_roc_auc(self, logits, y):
        preds = (logits > 0.5).float()
        y = y.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()
        return roc_auc_score(y, preds)

    def mean_pr_auc(self, logits, y):
        preds = (logits > 0.5).float()
        y = y.detach().cpu().numpy()
        preds = preds.detach().cpu().numpy()
        precision, recall, _ = precision_recall_curve(y, preds)
        return auc(recall, precision)


"""
Multi Classification
"""


class MultiClassificationEvaluatorHelper(AbstractEvaluatorHelper):

    def mean_nll(self, logits, y):
        critetion = nn.CrossEntropyLoss()
        return critetion(logits, y)

    def mean_accuracy(self, logits, y):
        _, preds = torch.max(logits, 1)
        correct = (preds == y).sum().item()
        total = y.size(0)
        return correct / total

    def mean_roc_auc(self, logits, y):
        return None

    def mean_pr_auc(self, logits, y):
        return None


"""
Evaluator Helper Factory
"""


class EvaluatorHelperFactory:

    __binary = None
    __multi = None

    @staticmethod
    def get_evaluator(type):

        if type == EvaluatorHelperType.BINARY:
            if EvaluatorHelperFactory.__binary is None:
                EvaluatorHelperFactory.__binary = BinaryClassificationEvaluatorHelper()
            return EvaluatorHelperFactory.__binary

        elif type == EvaluatorHelperType.MULTIPLE:
            if EvaluatorHelperFactory.__multi is None:
                EvaluatorHelperFactory.__multi = MultiClassificationEvaluatorHelper()
            return EvaluatorHelperFactory.__multi

        else:
            raise Exception(
                "Unsupported evaluator helper type: {}".format(type))

In [8]:
import torch


class Evaluator:

    def __init__(self, helper):
        self.__helper = helper

    def evaluate_model(self, model, test_loader, test_batch_size):

        with torch.no_grad():

            # Set mode to evaluate model
            model.eval()

            # Start evaluating model
            final_loss = 0
            final_acc = 0
            final_roc = []
            final_pr = []

            for (images, labels) in test_loader:

                _, logits = model(images)

                loss = self.__helper.mean_nll(logits, labels)
                acc = self.__helper.mean_accuracy(logits, labels)

                final_loss += loss
                final_acc += acc

                if len(labels) == test_batch_size:
                    roc = self.__helper.mean_roc_auc(logits, labels)
                    pr = self.__helper.mean_pr_auc(logits, labels)
                    if roc is not None:
                        final_roc.append(roc)
                    if pr is not None:
                        final_pr.append(pr)

            test_loss = final_loss / len(test_loader)
            test_acc = final_acc / len(test_loader)

            test_roc = None
            if len(final_roc) > 0:
                test_roc = sum(final_roc) / len(final_roc)

            test_pr = None
            if len(final_pr) > 0:
                test_pr = sum(final_pr) / len(final_pr)

            return test_loss, test_acc, test_roc, test_pr

# Client

In [9]:
import torch


class FederatedClient:

    # Init client
    def __init__(self, trainer, evaluator,  client_id, local_model,
                 train_loader, train_images, train_labels,
                 test_loader, learning_rate, logger):

        self.trainer = trainer
        self.evaluator = evaluator

        self.client_id = client_id

        self.train_loader = train_loader
        self.test_loader = test_loader

        self.train_images = train_images
        self.train_labels = train_labels

        self.local_model = local_model
        self.logger = logger

        if torch.cuda.is_available():
            self.local_model = self.local_model.to('cuda')

        self.local_optimizer = torch.optim.Adam(self.local_model.parameters(),
                                                lr=learning_rate)

    # Train model

    def train(self, global_model, global_optimizer, round_idx, flags):

        # Start training
        dict_grad_statistics = None
        train_loss, train_acc, dict_grad_statistics = self.trainer.train_model(global_model,
                                                                               global_optimizer,
                                                                               self.local_model,
                                                                               self.local_optimizer,
                                                                               self.train_loader,
                                                                               self.train_images,
                                                                               self.train_labels,
                                                                               round_idx,
                                                                               flags)
        train_history = (train_loss, train_acc)
        self.logger.log('Client[{}], Round [{}], Loss: [{}], Accuracy: [{}]'.format(
            self.client_id, round_idx + 1, train_loss, train_acc))

        # Evaluation
        test_batch_size = flags.test_batch_size
        test_history = self.evaluator.evaluate_model(global_model,
                                                     self.test_loader,
                                                     test_batch_size)

        return train_history, test_history, dict_grad_statistics


# Executor

In [10]:
from abc import ABC, abstractmethod


"""
Executor Interface
"""


class AbstractExecutor(ABC):

    @abstractmethod
    def __init__(self):
        raise Exception("Abstract method should be implemented")

    @abstractmethod
    def is_eligible_executor(self, dataset):
        raise Exception("Abstract method should be implemented")

    @abstractmethod
    def run(self, restart, flags):
        raise Exception("Abstract method should be implemented")

In [11]:
# from abstract_executor import AbstractExecutor


class ColorMNISTExecutor(AbstractExecutor):

    COLOR_MNIST_DATASET = "color_mnist"

    def __init__(self):
        pass

    def is_eligible_executor(self, dataset):
        return dataset == self.COLOR_MNIST_DATASET

    def run(self, restart, flags):

        algorithm = flags.algorithm

        total_feature = flags.total_feature
        learning_rate = flags.learning_rate
        weight_decay = flags.weight_decay

        # learning_rate_decay_step_size = 100
        # learning_rate_decay = 0.98

        train_batch_size = flags.train_batch_size
        test_batch_size = flags.test_batch_size

        num_steps = flags.num_steps
        num_rounds = flags.num_rounds
        num_epochs = flags.num_epochs

        hidden_dim = flags.hidden_dim

        penalty_anneal_iters = flags.penalty_anneal_iters
        penalty_weight_factor = flags.penalty_weight_factor
        penalty_weight = flags.penalty_weight

        return super().run(flags)

In [12]:
# from abstract_executor import AbstractExecutor
# from client import FederatedClient
# from fed_logger import FedLogger
# from trainer import Trainer
# from evaluator import Evaluator
# from evaluator_helper import *
# from data_loader import *
# from model import *
# from helper import *

from torchvision import datasets
import numpy as np
import os

import matplotlib.pyplot as plt


class RotateCifarExecutor(AbstractExecutor):

    ROTATE_CIFAR_DATASET = "rotate_cifar"

    """Initialize"""

    def __init__(self):
        self.data_loader = DataLoaderFactory.get_data_loader(
            DataLoaderType.ROTATE_CIFAR)
        self.evaluator_helper = EvaluatorHelperFactory.get_evaluator(
            EvaluatorHelperType.MULTIPLE)
        self.trainer = Trainer(self.evaluator_helper)
        self.evaluator = Evaluator(self.evaluator_helper)

    """Public Methods"""

    def is_eligible_executor(self, dataset):
        return dataset == self.ROTATE_CIFAR_DATASET

    def run(self, restart, flags):

        algorithm = flags.algorithm

        self.logger = FedLogger.getLogger(restart + 1,
                                          "cifar-{}-restart {}".format(algorithm, restart + 1))
        self.trainer.set_logger(self.logger)

        learning_rate = flags.learning_rate
        weight_decay = flags.weight_decay

        # learning_rate_decay_step_size = 100
        # learning_rate_decay = 0.98

        train_batch_size = flags.train_batch_size
        test_batch_size = flags.test_batch_size

        num_rounds = flags.num_rounds

        penalty_anneal_iters = flags.penalty_anneal_iters
        penalty_weight_factor = flags.penalty_weight_factor
        penalty_weight = flags.penalty_weight

        train_envs, test_envs, ood_validation = self.__load_dataset()
        clients = self.__create_clients(
            train_envs, test_envs, train_batch_size, test_batch_size, learning_rate)

        global_model = CifarResNet(1000, 10)
        if torch.cuda.is_available():
            global_model = global_model.to('cuda')

        global_optimizer = torch.optim.Adam(global_model.parameters(),
                                            lr=learning_rate,
                                            weight_decay=weight_decay)
        final_train_loss_history = []
        final_train_acc_history = []
        final_test_loss_history = []
        final_test_acc_history = []
        final_ood_loss_history = []
        final_ood_acc_history = []
        # final_ood_pr_history = []
        # final_ood_roc_history = []

        best_model = None
        best_round = 0
        best_loss = float("inf")
        best_acc = 0
        # best_pr_auc = 0
        # best_roc_auc = 0

        for round_idx in range(num_rounds):

            self.logger.log('\n')
            self.logger.log('########################################')
            self.logger.log('Start training round: {}'.format(round_idx + 1))
            self.logger.log('########################################')
            self.logger.log('\n')

            # 1. Load global params
            global_params = global_model.state_dict()

            # 2. Federated training
            train_loss_history, train_acc_history = [], []
            test_loss_history, test_acc_history = [], []
            model_grads_history, grads_variance_history = [], []

            for client in clients:

                train_history, test_history, dict_grad_statistics = client.train(
                    global_model, global_optimizer, round_idx, flags)

                train_loss, train_acc = train_history
                test_loss, test_acc, _, _ = test_history

                train_loss_history.append(train_loss)
                train_acc_history.append(train_acc)

                test_loss_history.append(test_loss)
                test_acc_history.append(test_acc)

                if "hybrid" in algorithm.split("_"):
                    grad_variance, model_grads = dict_grad_statistics
                    grads_variance_history.append(grad_variance)
                    model_grads_history.append(model_grads)
                elif "fishr" in algorithm.split("_"):
                    grads_variance_history.append(dict_grad_statistics)
                else:
                    model_grads_history.append(dict_grad_statistics)

            final_train_loss = torch.stack(train_loss_history).mean()
            final_train_acc = sum(train_acc_history) / len(train_acc_history)

            final_test_loss = torch.stack(test_loss_history).mean()
            final_test_acc = sum(test_acc_history) / len(test_acc_history)

            final_train_loss_np = final_train_loss.detach().cpu().numpy().copy()
            final_train_acc_np = final_train_acc
            final_test_loss_np = final_test_loss.detach().cpu().numpy().copy()
            final_test_acc_np = final_test_acc

            final_train_loss_history.append(final_train_loss_np)
            final_train_acc_history.append(final_train_acc_np)
            final_test_loss_history.append(final_test_loss_np)
            final_test_acc_history.append(final_test_acc_np)

            # 3. Arithmetic mean / geometric mean
            if "arith" in algorithm.split("_") and "fishr" not in algorithm.split("_"):
                global_optimizer.zero_grad()
                compute_arith_mean(
                    list(global_model.parameters()), model_grads_history)
                global_optimizer.step()

                self.logger.log(">>>>>>>>> Arith mean learning rate:")
                for param_group in global_optimizer.param_groups:
                    self.logger.log(param_group['lr'])

                # global_scheduler.step()

            if "geo" in algorithm.split("_") and "fishr" not in algorithm.split("_"):
                global_optimizer.zero_grad()
                compute_geo_mean(list(global_model.parameters()),
                                 model_grads_history, algorithm, 0.001)
                global_optimizer.step()

            # 4. Update global parameter based on gradients
            if "fishr" in algorithm.split("_"):

                dict_grad_statistics_averaged = {}

                first_dict_grad_statistics = grads_variance_history[0]
                for name in first_dict_grad_statistics:

                    grads_list = []
                    for dict_grad_statistics in grads_variance_history:
                        grads = dict_grad_statistics[name]
                        grads_list.append(grads)

                    dict_grad_statistics_averaged[name] = torch.stack(
                        grads_list, dim=0).mean(dim=0)

                fishr_penalty = 0
                for dict_grad_statistics in grads_variance_history:
                    fishr_penalty += l2_between_dicts(
                        dict_grad_statistics, dict_grad_statistics_averaged)

                if "hybrid" in algorithm.split("_"):

                    # Hybrid fishr
                    weight_norm = torch.tensor(0.)
                    if torch.cuda.is_available():
                        weight_norm = torch.tensor(0.).cuda()
                    for w in global_model.parameters():
                        grad = w.grad
                        weight_norm += w.norm().pow(2)

                    # if round_idx % 10 == 0 and round_idx != 0:
                    #     penalty_weight *= 1.01
                    # self.logger.log("***** Penalty weight: {}".format(penalty_weight))
                    # penalty_weight = (penalty_weight_factor if round_idx >= penalty_anneal_iters else 1.0)

                    # loss = weight_decay * weight_norm + penalty_weight * fishr_penalty
                    # loss = penalty_weight * fishr_penalty
                    penalty_weight = (
                        penalty_weight_factor if round_idx >= penalty_anneal_iters else penalty_weight)
                    
                    if penalty_weight > 1.0:
                        model_grads_history = model_grads_history / penalty_weight
                    else:
                        loss = penalty_weight * fishr_penalty

                    # Gradients computed by fishr loss
                    global_optimizer.zero_grad()
                    loss.backward()

                    model_params = list(global_model.parameters())
                    fishr_gradients = []
                    for model_param in model_params:
                        grad = model_param.grad
                        grad_copy = copy.deepcopy(grad.detach())
                        fishr_gradients.append(grad_copy)

                    # First, update model using geometric mean
                    global_optimizer.zero_grad()
                    compute_geo_mean(list(global_model.parameters()),
                                     model_grads_history, 'geo_weighted', 0.001)
                    global_optimizer.step()

                    self.logger.log(">>>>>>>>> Geo mean learning rate:")
                    for param_group in global_optimizer.param_groups:
                        self.logger.log(param_group['lr'])

                    # Then, update model using fishr loss
                    global_optimizer.zero_grad()
                    updated_model_params = list(global_model.parameters())

                    for param, grads in zip(updated_model_params, fishr_gradients):
                        param.grad = grads

                    global_optimizer.step()

                    self.logger.log(">>>>>>>>> Fishr learning rate:")
                    for param_group in global_optimizer.param_groups:
                        self.logger.log(param_group['lr'])

                    # global_scheduler.step()

                else:

                    loss = final_train_loss.clone()

                    weight_norm = torch.tensor(0.)
                    if torch.cuda.is_available():
                        weight_norm = torch.tensor(0.).cuda()
                    for w in global_model.parameters():
                        grad = w.grad
                        weight_norm += w.norm().pow(2)

                    # loss += weight_decay * weight_norm

                    self.logger.log('Before Loss: {}'.format(loss))
                    penalty_weight = (
                        penalty_weight_factor if round_idx >= penalty_anneal_iters else penalty_weight)

                    loss += penalty_weight * fishr_penalty
                    if penalty_weight > 1.0:
                        # Rescale the entire loss to keep backpropagated gradients in a reasonable range
                        loss /= penalty_weight
                    self.logger.log('Fishr Loss: {}'.format(fishr_penalty))
                    self.logger.log('After Loss: {}'.format(loss))

                    # Vanilla fishr
                    global_optimizer.zero_grad()
                    loss.backward()
                    global_optimizer.step()

                    self.logger.log(">>>>>>>>> Fishr learning rate:")
                    for param_group in global_optimizer.param_groups:
                        self.logger.log(param_group['lr'])

                    # global_scheduler.step()

            # 5. Evaluation
            ood_test_images, ood_test_labels = ood_validation["images"], ood_validation["labels"]
            ood_test_loader = self.data_loader.create_data_loader(
                ood_test_images, ood_test_labels, test_batch_size)
            ood_test_history = self.evaluator.evaluate_model(global_model,
                                                             ood_test_loader,
                                                             test_batch_size)
            ood_test_loss, ood_test_acc, _, _ = ood_test_history

            ood_test_loss_np = ood_test_loss.detach().cpu().numpy().copy()
            ood_test_acc_np = ood_test_acc

            final_ood_loss_history.append(ood_test_loss_np)
            final_ood_acc_history.append(ood_test_acc_np)
            # final_ood_roc_history.append(ood_test_roc)

            if ood_test_loss < best_loss and round_idx > 5:
                best_loss = ood_test_loss
                best_acc = ood_test_acc
                # best_roc_auc = ood_test_roc
                # best_pr_auc = odd_test_pr
                best_model = global_model
                best_round = round_idx

            self.logger.log('\n')
            self.logger.log('########################################')
            self.logger.log('End training round: {}'.format(round_idx + 1))
            self.logger.log('[Train] Loss: {}, Accuracy: {}'.format(
                final_train_loss, final_train_acc))
            self.logger.log('[Test] Loss: {}, Accuracy: {}'.format(
                final_test_loss, final_test_acc))
            self.logger.log('[OOD Test] Loss: {}, Accuracy: {}'.format(
                ood_test_loss, ood_test_acc))
            self.logger.log('########################################')
            self.logger.log('\n')

            if round_idx % 10 == 0 and round_idx > 5:
                self.logger.log(learning_rate)
                path = 'cifar-{}-restart-{}-output_checkpoint{}'.format(
                    algorithm, restart, str(round_idx))
                self.logger.log(global_model.state_dict())
                torch.save({'global_model': global_model.state_dict(),
                            'best_model': best_model.state_dict(),
                            'best_round': best_round,
                            'best_loss': best_loss,
                            # 'best_roc_auc': best_roc_auc,
                            # 'best_pr_auc': best_pr_auc,
                            'best_acc': best_acc,
                            'global_optimizer': global_optimizer.state_dict(),
                            'final_train_loss_history': final_train_loss_history,
                            'final_train_acc_history': final_train_acc_history,
                            'final_test_loss_history': final_test_loss_history,
                            'final_test_acc_history': final_test_acc_history,
                            'final_ood_loss_history': final_ood_loss_history,
                            'final_ood_acc_history': final_ood_acc_history}, path)

        best_loss = best_loss.cpu().numpy().copy()

        plt.title('Train & Test Loss')
        plt.plot(final_train_loss_history, label='train_loss')
        plt.plot(final_test_loss_history, label='test_loss')
        plt.plot(final_ood_loss_history, label='ood_test_loss')
        plt.ylim(0, 5)
        plt.hlines(best_loss, 0, best_round, linestyles='dashed')
        plt.xlabel('Round')
        plt.ylabel('Loss')
        plt.legend(['Train Loss', 'Test Loss', 'OOD Test Loss'])
        plt.savefig('loss-{}-restart {}.png'.format(algorithm, restart + 1))
        plt.close()

        plt.title('Train & Test Accuracy')
        plt.plot(final_train_acc_history, label='train_acc')
        plt.plot(final_test_acc_history, label='test_acc')
        plt.plot(final_ood_acc_history, label='ood_test_acc')
        plt.ylim(0, 1)
        plt.hlines(best_acc, 0, best_round, linestyles='dashed')
        plt.xlabel('Round')
        plt.ylabel('Accuracy')
        plt.legend(['Train Accuracy', 'Test Accuracy', 'OOD Test Accuracy'])
        plt.savefig('acc-{}-restart {}.png'.format(algorithm, restart + 1))
        plt.close()

        self.logger.log("Best Loss: {}".format(best_loss))
        self.logger.log("Best Accuracy: {}".format(best_acc))
        self.logger.log("Best Round: {}".format(best_round))

    """
    ### Load Dataset
    """

    def __load_dataset(self):

        cifar = datasets.CIFAR10('~/datasets/cifar', train=True, download=True)

        cifar_train = (cifar.data[:40000], cifar.targets[:40000])
        cifar_val = (cifar.data[40000:], cifar.targets[40000:])

        rng_state = np.random.get_state()
        np.random.shuffle(cifar_train[0])
        np.random.set_state(rng_state)
        np.random.shuffle(cifar_train[1])

        self.logger.log((cifar_val[0]).shape)

        train_client_1_env_1 = self.data_loader.make_environment(
            cifar_train[0][:30000:6], cifar_train[1][:30000:6], from_angle=10, to_angle=10)
        train_client_1_env_2 = self.data_loader.make_environment(
            cifar_train[0][1:30001:6], cifar_train[1][1:30001:6], from_angle=25, to_angle=25)
        train_client_1_env_3 = self.data_loader.make_environment(
            cifar_train[0][2:30002:6], cifar_train[1][2:30002:6], from_angle=40, to_angle=40)

        train_client_2_env_1 = self.data_loader.make_environment(
            cifar_train[0][3:30003:6], cifar_train[1][3:30003:6], from_angle=60, to_angle=60)
        train_client_2_env_2 = self.data_loader.make_environment(
            cifar_train[0][4:30004:6], cifar_train[1][4:30004:6], from_angle=75, to_angle=75)
        train_client_2_env_3 = self.data_loader.make_environment(
            cifar_train[0][5:30005:6], cifar_train[1][5:30005:6], from_angle=90, to_angle=90)

        train_envs = [
            # Client 1 Train
            self.data_loader.combine_envs([train_client_1_env_1,
                                           train_client_1_env_2, train_client_1_env_3]),
            # Client 2 Train
            self.data_loader.combine_envs([train_client_2_env_1,
                                           train_client_2_env_2, train_client_2_env_3])
        ]

        test_client_1_env_1 = self.data_loader.make_environment(
            cifar_train[0][30000::6], cifar_train[1][30000::6], from_angle=10, to_angle=10)
        test_client_1_env_2 = self.data_loader.make_environment(
            cifar_train[0][30001::6], cifar_train[1][30001::6], from_angle=25, to_angle=25)
        test_client_1_env_3 = self.data_loader.make_environment(
            cifar_train[0][30002::6], cifar_train[1][30002::6], from_angle=40, to_angle=40)

        test_client_2_env_1 = self.data_loader.make_environment(
            cifar_train[0][30003::6], cifar_train[1][30003::6], from_angle=60, to_angle=60)
        test_client_2_env_2 = self.data_loader.make_environment(
            cifar_train[0][30004::6], cifar_train[1][30004::6], from_angle=75, to_angle=75)
        test_client_2_env_3 = self.data_loader.make_environment(
            cifar_train[0][30005::6], cifar_train[1][30005::6], from_angle=90, to_angle=90)

        test_envs = [
            # Client 1 Validation
            self.data_loader.combine_envs([test_client_1_env_1, test_client_1_env_2,
                                           test_client_1_env_3]),
            # Client 2 Validation
            self.data_loader.combine_envs([test_client_2_env_1,
                                           test_client_2_env_2, test_client_2_env_3])
        ]

        ood_validation = self.data_loader.make_environment(
            cifar_val[0], cifar_val[1], from_angle=-90, to_angle=90)

        return train_envs, test_envs, ood_validation

    """
    ### Create federated clients
    """

    def __create_clients(self, train_envs, test_envs, train_batch_size, test_batch_size, learning_rate):

        # Create federated clients
        clients = []

        for client_id, (train_env, test_env) in enumerate(zip(train_envs, test_envs)):

            train_images, train_labels = train_env["images"], train_env["labels"]
            test_images, test_labels = test_env["images"], test_env["labels"]

            train_loader = self.data_loader.create_data_loader(
                train_images, train_labels, train_batch_size)
            test_loader = self.data_loader.create_data_loader(
                test_images, test_labels, test_batch_size)

            # Each client has one local model
            local_model = CifarResNet(1000, 10)
            client = FederatedClient(self.trainer, self.evaluator, client_id, local_model,
                                     train_loader, train_images, train_labels,
                                     test_loader, learning_rate, self.logger)
            clients.append(client)

        return clients


# Main

In [13]:
# from executor_color_mnist import ColorMNISTExecutor
# from executor_rotate_cifar import RotateCifarExecutor

import argparse
import torch

# if not torch.cuda.is_available():
#     raise Exception("Please use CUDA environment!")

# parser = argparse.ArgumentParser()

# """ Select dataset """
# parser.add_argument(
#     '--dataset',
#     type=str,
#     default="color_mnist",
#     choices=[
#         'color_mnist',
#         'rotate_cifar',
#         'e_icu'
#     ]
# )

# """ Select algorithm """
# parser.add_argument(
#     '--algorithm',
#     type=str,
#     default="fishr",
#     choices=[
#         'arith',  # Arithmetic mean
#         'geo_weighted',  # Geometric mean (weighted)
#         'geo_substitute',  # Geometric mean (substituted)
#         'fishr',  # Fishr
#         'fishr_geo'  # Inter-silo fishr + intra-silo geometric mean
#         'fishr_hybrid',  # Inter-silo fishr + inter-silo geometric mean
#     ]
# )

# parser.add_argument('--total_feature', type=int, default=2 * 14 * 14)

# parser.add_argument('--learning_rate', type=float, default=0.0001)
# parser.add_argument('--weight_decay', type=float, default=0.001)

# parser.add_argument('--train_batch_size', type=int, default=32)
# parser.add_argument('--test_batch_size', type=int, default=32)

# parser.add_argument('--hidden_dim', type=int, default=390)

# """ Federated Learning """
# parser.add_argument('--num_restarts', type=int, default=5)  # Total experiments
# parser.add_argument('--num_rounds', type=int, default=501)  # Federated rounds
# # Epochs per federated round
# parser.add_argument('--num_epochs', type=int, default=1)

# """ Fishr """
# parser.add_argument('--penalty_anneal_iters', type=int, default=0)
# parser.add_argument('--penalty_weight_factor', type=float, default=1.0)
# parser.add_argument('--penalty_weight', type=float, default=1.0)


# flags = parser.parse_args()

class Flags:
    dataset = "rotate_cifar"
    algorithm = "fishr"
    total_feature = 2 * 14 * 14
    learning_rate = 0.0003
    weight_decay = 0.001
    train_batch_size = 32
    test_batch_size = 32
    hidden_dim = 390
    num_restarts = 1
    num_rounds = 101
    num_epochs = 1
    penalty_anneal_iters = 50
    penalty_weight_factor = 1.0
    penalty_weight = 0.5

flags = Flags()

# for k, v in sorted(vars(flags).items()):
#     print("\t{}: {}".format(k, v))

dataset = flags.dataset

# Find eligible executor based on dataset
eligible_executor = None
executors = [ColorMNISTExecutor(), RotateCifarExecutor()]

for executor in executors:
    if executor.is_eligible_executor(dataset):
        eligible_executor = executor
        break

if eligible_executor is None:
    raise Exception(
        "Unable to find eligible executor for dataset: {}".format(dataset))

# Execute training
num_restarts = flags.num_restarts

for restart in range(num_restarts):
    print("Restart: {}".format(restart))
    eligible_executor.run(restart, flags)

Restart: 0
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/datasets/cifar/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting /root/datasets/cifar/cifar-10-python.tar.gz to /root/datasets/cifar


2022-04-10 10:16:04,699 [INFO] Restart - 1, (10000, 32, 32, 3)
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


  0%|          | 0.00/44.7M [00:00<?, ?B/s]

2022-04-10 10:16:33,961 [INFO] Restart - 1, 

2022-04-10 10:16:33,964 [INFO] Restart - 1, ########################################
2022-04-10 10:16:33,969 [INFO] Restart - 1, Start training round: 1
2022-04-10 10:16:33,972 [INFO] Restart - 1, ########################################
2022-04-10 10:16:33,974 [INFO] Restart - 1, 

2022-04-10 10:16:35,561 [INFO] Restart - 1, torch.Size([15000, 10])
2022-04-10 10:16:35,562 [INFO] Restart - 1, torch.Size([15000])
  allow_unreachable=True, accumulate_grad=True)  # allow_unreachable flag
2022-04-10 10:16:35,628 [INFO] Restart - 1, Client[0], Round [1], Loss: [3.2789552211761475], Accuracy: [0.1008]


RuntimeError: ignored