## Dataset Details

All users' input data are MNIST images of the digits 0 and 1.

Users 1 and 2: the label is the true digit, so digit 0 is the negative class and digit 1 is the positive class.

User 3: labels are random — each is drawn from Bernoulli(0.5), independently of the image.


In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before executing code.
# NOTE(review): this notebook imports no local modules, so autoreload appears unused here.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import json

In [3]:
import numpy as np

In [4]:
import torch
import torchvision.datasets as datasets

In [5]:
# Root directory for the MNIST download and for the generated dataset.
# Defaults to the original cluster path; set the MNIST_DATA_PATH environment
# variable to run elsewhere without editing the notebook.
data_path = os.environ.get("MNIST_DATA_PATH", "/data/ddmg/redditlanguagemodeling/data/MNIST")

In [6]:
# Load the MNIST training split from data_path, downloading it first if missing.
mnist_train = datasets.MNIST(data_path, download=True, train=True)

In [7]:
# Load the MNIST test split from data_path, downloading it first if missing.
mnist_test = datasets.MNIST(data_path, download=True, train=False)

In [8]:
# Boolean mask over the training split: True where the label is digit 0 or digit 1.
is_zero = mnist_train.targets == 0
is_one = mnist_train.targets == 1
keep_indices = is_zero | is_one

In [9]:
# Restrict the training split to the 0/1 digits by overwriting its tensors.
# Uses keep_indices from the previous cell (which the test-mask cell later
# reassigns), so re-run the train-mask cell before re-running this one.
mnist_train.data = mnist_train.data[keep_indices]
mnist_train.targets = mnist_train.targets[keep_indices]

In [10]:
# Boolean mask over the test split: True where the label is digit 0 or digit 1.
is_zero = mnist_test.targets == 0
is_one = mnist_test.targets == 1
keep_indices = is_zero | is_one

In [11]:
# Restrict the test split to the 0/1 digits by overwriting its tensors.
# Uses the keep_indices computed for the test split in the previous cell.
mnist_test.data = mnist_test.data[keep_indices]
mnist_test.targets = mnist_test.targets[keep_indices]

In [12]:
# Break the 0/1 MNIST data into 3 disjoint per-user sets (exported to JSON below).
# Each user has 200 train, 100 val, 100 test examples: train and val come from
# non-overlapping slices of the MNIST train split, test from the MNIST test split.
# flatten(1, 2) merges the two image dimensions into one pixel vector per example.
# Users 1 and 2 keep their true digit labels; user 3 gets Bernoulli(0.5) labels.

# Seeded generator so user 3's random labels are identical on every run.
label_rng = torch.Generator().manual_seed(0)

# train data
user_1_train_x, user_1_train_y = mnist_train.data[:200].flatten(1,2), mnist_train.targets[:200]
user_2_train_x, user_2_train_y = mnist_train.data[200:400].flatten(1,2), mnist_train.targets[200:400]
# One Bernoulli(0.5) draw per image: 200 probabilities for the 200 training images.
p_train = torch.full((200, 1), 0.5)
# .long(): torch.bernoulli on a float tensor returns floats; cast so user 3's
# labels are integers like users 1 and 2.
user_3_train_x = mnist_train.data[400:600].flatten(1,2)
user_3_train_y = torch.bernoulli(p_train, generator=label_rng).flatten().long()

# val data
user_1_val_x, user_1_val_y = mnist_train.data[600:700].flatten(1,2), mnist_train.targets[600:700]
user_2_val_x, user_2_val_y = mnist_train.data[700:800].flatten(1,2), mnist_train.targets[700:800]
p_val = torch.full((100, 1), 0.5)
user_3_val_x = mnist_train.data[800:900].flatten(1,2)
user_3_val_y = torch.bernoulli(p_val, generator=label_rng).flatten().long()

# test data
user_1_test_x, user_1_test_y = mnist_test.data[:100].flatten(1,2), mnist_test.targets[:100]
user_2_test_x, user_2_test_y = mnist_test.data[100:200].flatten(1,2), mnist_test.targets[100:200]
p_test = torch.full((100, 1), 0.5)
user_3_test_x = mnist_test.data[200:300].flatten(1,2)
user_3_test_y = torch.bernoulli(p_test, generator=label_rng).flatten().long()

# Every image must have exactly one label and each split its advertised size;
# otherwise zip() in the export cells would silently drop unmatched images.
for _xs, _ys, _n in [
    (user_1_train_x, user_1_train_y, 200),
    (user_2_train_x, user_2_train_y, 200),
    (user_3_train_x, user_3_train_y, 200),
    (user_1_val_x, user_1_val_y, 100),
    (user_2_val_x, user_2_val_y, 100),
    (user_3_val_x, user_3_val_y, 100),
    (user_1_test_x, user_1_test_y, 100),
    (user_2_test_x, user_2_test_y, 100),
    (user_3_test_x, user_3_test_y, 100),
]:
    assert len(_xs) == len(_ys) == _n, (len(_xs), len(_ys), _n)

In [13]:
# Output directory for the generated dataset. The "dummmy" spelling is part of
# the on-disk path; changing it changes where the data is written.
output_dir_name = "dummmy_test1"
save_data_path = os.path.join(data_path, output_dir_name)

In [14]:
# Convert every (image, label) pair into a JSON-serialisable record.
# data_list collects the records of all users and splits; this cell (re)starts
# it with user 1's training examples.
data_list = [
    {"x": x_row.tolist(), "y": y_val.item(), "split": "train", "user": 1}
    for x_row, y_val in zip(user_1_train_x, user_1_train_y)
]

In [15]:
# Append user 1's validation examples as records.
data_list.extend(
    {"x": img.tolist(), "y": lbl.item(), "split": "val", "user": 1}
    for img, lbl in zip(user_1_val_x, user_1_val_y)
)

In [16]:
# Append user 1's test examples as records.
data_list.extend(
    {"x": img.tolist(), "y": lbl.item(), "split": "test", "user": 1}
    for img, lbl in zip(user_1_test_x, user_1_test_y)
)

In [17]:
# Append user 2's records in train -> val -> test order.
user_2_splits = (
    ("train", user_2_train_x, user_2_train_y),
    ("val", user_2_val_x, user_2_val_y),
    ("test", user_2_test_x, user_2_test_y),
)
for split_name, images, labels in user_2_splits:
    data_list.extend(
        {"x": img.tolist(), "y": lbl.item(), "split": split_name, "user": 2}
        for img, lbl in zip(images, labels)
    )

In [18]:
# Append user 3's records in train -> val -> test order.
# int(...): user 3's labels come from torch.bernoulli, which returns floats for a
# float probability tensor; without the cast the JSON would store 0.0/1.0 while
# users 1 and 2 store integer labels.
for split_name, xs, ys in (
    ("train", user_3_train_x, user_3_train_y),
    ("val", user_3_val_x, user_3_val_y),
    ("test", user_3_test_x, user_3_test_y),
):
    for x, y in zip(xs, ys):
        row_dict = {"x": x.tolist(), "y": int(y.item()), "split": split_name, "user": 3}
        data_list.append(row_dict)

In [19]:
# Top-level JSON payload: a version tag plus the flat list of per-example records.
# NOTE(review): the meaning/format of the version string is not documented here — confirm.
data_dict = dict(version="21.28.10", data=data_list)

In [20]:
# Write the whole dataset as a single JSON file inside save_data_path.
# Nothing else in the notebook creates this directory, and open(..., "w") raises
# FileNotFoundError if it is missing, so create it first (no-op if it exists).
os.makedirs(save_data_path, exist_ok=True)
with open(os.path.join(save_data_path, "full_data.json"), "w") as f:
    json.dump(data_dict, f)