In [1]:
!pip install tensorflow-federated==0.19.0

Collecting tensorflow-federated==0.19.0
  Downloading tensorflow_federated-0.19.0-py2.py3-none-any.whl (602 kB)
[?25l[K     |▌                               | 10 kB 20.0 MB/s eta 0:00:01[K     |█                               | 20 kB 25.7 MB/s eta 0:00:01[K     |█▋                              | 30 kB 28.3 MB/s eta 0:00:01[K     |██▏                             | 40 kB 19.3 MB/s eta 0:00:01[K     |██▊                             | 51 kB 15.6 MB/s eta 0:00:01[K     |███▎                            | 61 kB 17.6 MB/s eta 0:00:01[K     |███▉                            | 71 kB 14.5 MB/s eta 0:00:01[K     |████▍                           | 81 kB 13.2 MB/s eta 0:00:01[K     |█████                           | 92 kB 14.4 MB/s eta 0:00:01[K     |█████▍                          | 102 kB 15.5 MB/s eta 0:00:01[K     |██████                          | 112 kB 15.5 MB/s eta 0:00:01[K     |██████▌                         | 122 kB 15.5 MB/s eta 0:00:01[K     |███████          

In [2]:
import pandas as pd
import tensorflow as tf
import tensorflow_federated as tff

csv_url = "https://docs.google.com/spreadsheets/d/1eJo2yOTVLPjcIbwe8qSQlFNpyMhYj-xVnNVUTAhwfNU/gviz/tq?tqx=out:csv"

df = pd.read_csv(csv_url, na_values=("?",))

client_id_colname = 'native.country' # the column that represents client ID
SHUFFLE_BUFFER = 1000
NUM_EPOCHS = 1

In [3]:
# split client id into train and test clients
client_ids = df[client_id_colname].unique()
train_client_ids = pd.DataFrame(client_ids).sample(frac=0.5).values.tolist()

[['Peru'], ['Iran'], ['Nicaragua'], ['Outlying-US(Guam-USVI-etc)'], ['India'], ['Jamaica'], ['Japan'], ['Trinadad&Tobago'], ['Yugoslavia'], ['France'], ['Holand-Netherlands'], ['Portugal'], ['Columbia'], ['South'], ['United-States'], [nan], ['Thailand'], ['Vietnam'], ['Puerto-Rico'], ['Italy'], ['England']]


In [15]:
test_client_ids = [x for x in client_ids if x not in train_client_ids]

['United-States', nan, 'Mexico', 'Greece', 'Vietnam', 'China', 'Taiwan', 'India', 'Philippines', 'Trinadad&Tobago', 'Canada', 'South', 'Holand-Netherlands', 'Puerto-Rico', 'Poland', 'Iran', 'England', 'Germany', 'Italy', 'Japan', 'Hong', 'Honduras', 'Cuba', 'Ireland', 'Cambodia', 'Peru', 'Nicaragua', 'Dominican-Republic', 'Haiti', 'El-Salvador', 'Hungary', 'Columbia', 'Guatemala', 'Jamaica', 'Ecuador', 'France', 'Yugoslavia', 'Scotland', 'Portugal', 'Laos', 'Thailand', 'Outlying-US(Guam-USVI-etc)']


In [69]:
def create_tf_dataset_for_client_fn(client_id):
  # a function which takes a client_id and returns a
  # tf.data.Dataset for that client
  # print(client_id)
  client_data = df[df[client_id_colname] == client_id[0]]
  # print(client_data)
  dataset = tf.data.Dataset.from_tensor_slices(client_data.fillna('').to_dict("list"))
  dataset = dataset.shuffle(SHUFFLE_BUFFER).batch(1).repeat(NUM_EPOCHS)
  return dataset

In [71]:
train_data = tff.simulation.datasets.ClientData.from_clients_and_fn(
        client_ids=train_client_ids,
        create_tf_dataset_for_client_fn=create_tf_dataset_for_client_fn
    )
test_data = tff.simulation.datasets.ClientData.from_clients_and_fn(
        client_ids=test_client_ids,
        create_tf_dataset_for_client_fn=create_tf_dataset_for_client_fn
    )

In [72]:
example_dataset = train_data.create_tf_dataset_for_client(
        train_data.client_ids[0]
    )

print(type(example_dataset))
example_element = iter(example_dataset).next()
print(example_element)
# <class 'tensorflow.python.data.ops.dataset_ops.RepeatDataset'>
# {'age': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([37], dtype=int32)>, 'workclass': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Local-gov'], dtype=object)>, ...

<class 'tensorflow.python.data.ops.dataset_ops.RepeatDataset'>
{'age': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([48], dtype=int32)>, 'workclass': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Private'], dtype=object)>, 'fnlwgt': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([280422], dtype=int32)>, 'education': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Some-college'], dtype=object)>, 'education.num': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([10], dtype=int32)>, 'marital.status': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Separated'], dtype=object)>, 'occupation': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Other-service'], dtype=object)>, 'relationship': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Not-in-family'], dtype=object)>, 'race': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'White'], dtype=object)>, 'sex': <tf.Tensor: shape=(1,), dtype=string, numpy=array([b'Female'], dtype=object)>, 'capital.gain': <tf.Ten