### Experiments with different domains and data augmentation (Full set)
This notebook demonstrates the performance of different training methods under different environments.
In order to understand the effect of data augmentation, we create different scenarios where data augmentation is used.

The data augmentation used in this notebook is related to the different domains at hand. Here we use the RotatedMNIST dataset, which consists of rotated digits from the MNIST dataset, where the test set corresponds to the digits rotated by 75°.
Furthermore, the data augmentation techniques employed here are random rotations of the digits by up to 15°, 30°, and 45°, as well as a Gaussian blur.

Copyright (C) 2019. Huawei Technologies Co., Ltd. All rights reserved.
This program is free software; you can redistribute it and/or modify
it under the terms of the Apache 2.0 License.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
Apache 2.0 License for more details.


In [1]:
import os

import numpy as np
import torch
from torchvision import transforms

import flsuite
import flsuite.data as data
import flsuite.utils as utils
from flsuite.algs import afl, fed_avg, individual_train
from flsuite.algs.trainers import GroupDRO, ERM

# --- Training budget ---------------------------------------------------------
# Total number of optimisation steps is rounds * local_steps.
rounds = 40
local_steps = 200
steps = local_steps*rounds

# --- Reproducibility / data configuration -------------------------------------
seed = 0
num_clients = 5
batch_size = 64

# Root directory under which every experiment in this notebook stores its results.
save = '../data/experiments/data_augmentation/full_set/'

# --- Device selection ---------------------------------------------------------
# Expose only physical GPU 1 to this process; it is then addressed as cuda:0.
# NOTE(review): this only takes effect if CUDA has not been initialised yet in
# this kernel -- run the cell on a fresh kernel.
os.environ['CUDA_VISIBLE_DEVICES'] = str(1)
# Fall back to the CPU when no GPU is visible (e.g. a machine without a GPU at
# index 1) instead of failing later on the first transfer to the device.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# --- Datasets -----------------------------------------------------------------
dataset = data.datasets.RotatedMNIST('../data/datasets')
# The first `num_clients` domains are used for training; the last domain is
# held out as the test set.
train_sets = dataset.datasets[:num_clients]
test_set = dataset.datasets[-1]
# Single merged view of all training domains, used only to report train accuracy.
all_train_set = data.merge_datasets(train_sets)

# Evaluation loaders are built without shuffling.
all_train_loader = data.build_dataloaders([all_train_set], batch_size, shuffle=False)[0]
test_loader = data.build_dataloaders([test_set], batch_size, shuffle=False)[0]

In [2]:
# Build one augmented copy of every training set per augmentation.
# Each copy is produced by CustomDataset.parse_values with the given transform.

# Random rotation, degrees argument = 15.
transform = transforms.RandomRotation(15)
augmented_train_sets_15 = [
    data.utils.CustomDataset.parse_values(client_set, transform)
    for client_set in train_sets
]

# Random rotation, degrees argument = 30.
transform = transforms.RandomRotation(30)
augmented_train_sets_30 = [
    data.utils.CustomDataset.parse_values(client_set, transform)
    for client_set in train_sets
]

# Random rotation, degrees argument = 45.
transform = transforms.RandomRotation(45)
augmented_train_sets_45 = [
    data.utils.CustomDataset.parse_values(client_set, transform)
    for client_set in train_sets
]

# Gaussian blur, kernel size argument = 5.
transform = transforms.GaussianBlur(5)
augmented_train_sets_blur = [
    data.utils.CustomDataset.parse_values(client_set, transform)
    for client_set in train_sets
]

### Empirical Risk Minimization (Vapnik 1992)

#### Without data augmentation

In [8]:
# ERM section, baseline run: train on the original (non-augmented) training sets.
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(train_sets, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'erm/original/')
# Plain model from model_loader; no trainer wrapper is bound in the ERM runs.
global_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Report accuracy on the merged training data, then on the held-out test set.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.997
Test accuracy: 0.918


#### With data augmentation (15°)

In [9]:
# ERM section: train on the training sets augmented with RandomRotation(15).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_15, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'erm/da_15/')
# Plain model from model_loader; no trainer wrapper is bound in the ERM runs.
global_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.996
Test accuracy: 0.963


#### With data augmentation (30°)

In [10]:
# ERM section: train on the training sets augmented with RandomRotation(30).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_30, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'erm/da_30/')
# Plain model from model_loader; no trainer wrapper is bound in the ERM runs.
global_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.993
Test accuracy: 0.975


#### With data augmentation (45°)

In [11]:
# ERM section: train on the training sets augmented with RandomRotation(45).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_45, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'erm/da_45/')
# Plain model from model_loader; no trainer wrapper is bound in the ERM runs.
global_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.990
Test accuracy: 0.978


#### With data augmentation (Gaussian Blur)

In [12]:
# ERM section: train on the training sets augmented with GaussianBlur(5).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_blur, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'erm/da_blur/')
# Plain model from model_loader; no trainer wrapper is bound in the ERM runs.
global_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.993
Test accuracy: 0.915


### Group DRO (Sagawa et al. 2019)

#### Without data augmentation

In [13]:
# Group DRO section, baseline run: train on the original (non-augmented) training sets.
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(train_sets, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'group_dro/original/')
# Load a fresh model and bind the GroupDRO trainer to it before training.
base_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = GroupDRO.bind_to(base_model)
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Report accuracy on the merged training data, then on the held-out test set.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.998
Test accuracy: 0.943


#### With data augmentation (15°)

In [14]:
# Group DRO section: train on the training sets augmented with RandomRotation(15).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_15, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'group_dro/da_15/')
# Load a fresh model and bind the GroupDRO trainer to it before training.
base_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = GroupDRO.bind_to(base_model)
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.994
Test accuracy: 0.969


#### With data augmentation (30°)

In [15]:
# Group DRO section: train on the training sets augmented with RandomRotation(30).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_30, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'group_dro/da_30/')
# Load a fresh model and bind the GroupDRO trainer to it before training.
base_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = GroupDRO.bind_to(base_model)
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.992
Test accuracy: 0.979


#### With data augmentation (45°)

In [16]:
# Group DRO section: train on the training sets augmented with RandomRotation(45).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_45, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'group_dro/da_45/')
# Load a fresh model and bind the GroupDRO trainer to it before training.
base_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = GroupDRO.bind_to(base_model)
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.989
Test accuracy: 0.980


#### With data augmentation (Gaussian Blur)

In [17]:
# Group DRO section: train on the training sets augmented with GaussianBlur(5).
# Reset torch's global RNG before the loaders and the model are created.
torch.manual_seed(seed)
train_loaders = data.build_dataloaders(augmented_train_sets_blur, batch_size)
global_loader = data.utils.DataLoaderWrapper(train_loaders)

# `path` is handed to individual_train through its `save` argument.
path = os.path.join(save, 'group_dro/da_blur/')
# Load a fresh model and bind the GroupDRO trainer to it before training.
base_model = flsuite.models.model_loader('RMNIST', 1, seed)[0]
global_model = GroupDRO.bind_to(base_model)
global_model = individual_train(
    global_model,
    global_loader,
    steps,
    validation_loader=test_loader,
    device=device,
    save=path,
    eval_steps=10,
)

# Train accuracy is measured on the merged, non-augmented training data.
train_accuracy = utils.eval.accuracy(global_model, all_train_loader)
print('Train accuracy: %.3f' % train_accuracy)
test_accuracy = utils.eval.accuracy(global_model, test_loader)
print('Test accuracy: %.3f' % test_accuracy)

Train accuracy: 0.992
Test accuracy: 0.933
