In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from copy import deepcopy

from collections import OrderedDict
import argparse
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [2]:
def make_dir(args):    
    import os

    MODEL_DIR = f"AE_model/client = {args.num_clients}"
    try:
        os.makedirs(MODEL_DIR)
    except FileExistsError:
        print('Directories not created because they already exist')

In [11]:
def load_data(args):
    # ===== w_data ===== #
    df = pd.read_csv(f"Final/parameter_data/param_data_{args.num_clients}.csv")
    data = torch.Tensor(np.delete(df.values, 0, 1))

    # ===== MNIST data ===== #
    train_dataset = datasets.MNIST(root='MNIST_data/', train = True, transform = transforms.ToTensor(), download = True)
    val_dataset = datasets.MNIST(root='MNIST_data/', train = False, transform = transforms.ToTensor(), download = True)

    train_dataset, _ = torch.utils.data.random_split(train_dataset, [10000, 50000])
    val_dataset, _ = torch.utils.data.random_split(val_dataset, [2000, 8000])
    return data, train_dataset, val_dataset

In [12]:
class Encoder(nn.Module):
    def __init__(self, hidden_dim):
        super(Encoder, self).__init__()

        self.hidden_dim = hidden_dim
        self.encoder = nn.Sequential(nn.Linear(7851, 785),
                                    nn.ReLU(),
                                    nn.Linear(785, 50),
                                    nn.ReLU(),
                                    nn.Linear(50, self.hidden_dim)
                                    )
            
    def forward(self, x):
        output = self.encoder(x)
        return output


class Decoder(nn.Module):
    def __init__(self, hidden_dim):
        super(Decoder, self).__init__()

        self.hidden_dim = hidden_dim
        self.decoder = nn.Sequential(nn.Linear(self.hidden_dim, 50),
                                    nn.ReLU(),
                                    nn.Linear(50, 785),
                                    nn.ReLU(),
                                    nn.Linear(785, 7851)
                                    )
            
    def forward(self, x):
        output = self.decoder(x)
        return output


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.net = nn.Sequential(nn.Linear(28*28, 10)
                                )

    def forward(self, x):
        x = x.view(-1,28*28)
        out = self.net(x)
        return out

In [52]:
def train(encoder, decoder, net, partition, optimizer, criterion, args):
    encoder.train()
    decoder.train()
    net.train()
    train_loader = torch.utils.data.DataLoader(partition['train'], batch_size=args.batch_size*3)
    mnist_train_loader = torch.utils.data.DataLoader(partition['mnist_train'], batch_size=args.mnist_batch_size)
    
    train_loss = 0
    for train_data in train_loader:
        train_data = train_data.to(device)
        ae_out = decoder(encoder(train_data))
        w_avg_hat = torch.mean(ae_out.view(args.num_clients, -1, 7851), dim=0)[:,:-1].view(-1,785)

        global_weights = OrderedDict()
        global_weights['net.0.weight'] = w_avg_hat[:,:784]
        global_weights['net.0.bias'] = w_avg_hat[:,784:].squeeze()
        net.load_state_dict(global_weights)

        #iter_loss = 0
        for mnist_data, mnist_label in mnist_train_loader:
            mnist_data, mnist_label = mnist_data.to(device), mnist_label.to(device)

            optimizer.zero_grad()
            
            net_out = net(mnist_data)
            loss = criterion(net_out, mnist_label)
            train_loss += loss.item()
            iter_loss += loss.item()
        
            loss.backward()
            optimizer.step()
        #print(iter_loss/len(mnist_train_loader))

    train_loss /= (len(train_loader)*len(mnist_train_loader))
    return encoder, decoder, train_loss

In [53]:
def validate(encoder, decoder, net, partition, criterion, args):
    encoder.eval()
    decoder.eval()
    val_loader = torch.utils.data.DataLoader(partition['val'], batch_size=args.batch_size*3, shuffle=True)
    mnist_val_loader = torch.utils.data.DataLoader(partition['mnist_val'], batch_size=args.mnist_batch_size, shuffle=True)
    

    with torch.no_grad():
        val_loss = 0

        for val_data in val_loader:

            val_data = val_data.to(device)

            ae_out = decoder(encoder(val_data))
            val_w_avg_hat = torch.mean(ae_out.view(args.num_clients, -1, 7851), dim=0)[:,:-1].view(-1,785)

            global_weights = OrderedDict()
            global_weights['net.0.weight'] = val_w_avg_hat[:,:784]
            global_weights['net.0.bias'] = val_w_avg_hat[:,784:].squeeze()
            net.load_state_dict(global_weights)

            for mnist_val_data, mnist_val_label in mnist_val_loader:
                mnist_val_data, mnist_val_label = mnist_val_data.to(device), mnist_val_label.to(device)

                net_out = net(mnist_val_data)
                loss = criterion(net_out , mnist_val_label)
                val_loss += loss.item()
                
        val_loss /= (len(val_loader)*len(mnist_val_loader))
    return val_loss

In [54]:
def experiment(partition, args):
    encoder = Encoder(args.hidden_dim)
    decoder = Decoder(args.hidden_dim)
    net = Net()
    encoder, decoder, net = encoder.to(device), decoder.to(device), net.to(device)

    optimizer = torch.optim.Adam(list(encoder.parameters())+list(decoder.parameters()), lr = args.lr)
    criterion = torch.nn.CrossEntropyLoss()

    train_losses = []
    val_losses = []
    
    for epoch in range(args.num_epochs):
        train_loss = 0
        val_loss = 0

        encoder, decoder, train_loss = train(encoder, decoder, net, partition, optimizer, criterion, args)
        val_loss = validate(encoder, decoder, net, partition, criterion, args)

        #if (epoch+1)%10 == 0:
        #    print(f"[{epoch+1}/{args.num_epochs}]  Train Loss: {train_loss}  Val Loss: {val_loss}")
        print(f"[{epoch+1}/{args.num_epochs}]  Train Loss: {train_loss}  Val Loss: {val_loss}")
        train_losses.append(train_loss)
        val_losses.append(val_loss)
    
    return encoder, decoder, train_losses, val_losses

In [55]:
def save_model(encoder, decoder, var, args):

    if args.reverse == True:
        enc_path = f"min(FL_Loss)_AE_model/client = {args.num_clients}/reverse/enc_num_data_{var}.pth"
        dec_path = f"min(FL_Loss)_AE_model/client = {args.num_clients}/reverse/dec_num_data_{var}.pth"
    else:  
        enc_path = f"min(FL_Loss)_AE_model/client = {args.num_clients}/enc_num_data_{var}.pth"
        dec_path = f"min(FL_Loss)_AE_model/client = {args.num_clients}/dec_num_data_{var}.pth"

    torch.save(encoder.state_dict(), enc_path)
    torch.save(decoder.state_dict(), dec_path)
    print(f"save model when [{var}] ")

In [56]:
def plot_loss(var, name_var, train_losses, val_losses):
    plt.title(f"{name_var} = {var}")
    plt.plot(train_losses, label='train')
    plt.plot(val_losses, label='val')

    plt.grid()
    plt.legend()
    plt.show()

In [57]:
seed = 1228
np.random.seed(seed)
torch.manual_seed(seed)

parser = argparse.ArgumentParser()
args = parser.parse_args("")

# ====== Optimizer & Training ====== #
args.lr = 0.001
args.num_epochs = 200

# ====== Model Capacity ===== #
args.hidden_dim = 15

# ====== Data Loading ====== #
args.batch_size = 1
args.mnist_batch_size = 512
args.num_clients = 3
args.num_data = 500
args.reverse = False


make_dir(args)

name_var1 = 'reverse'
name_var2 = 'num_data'
list_var1 = [False, True]
list_var2 = [100, 500, 700, 1000]

set, train_dataset, val_dataset = load_data(args)
set = set.view(3, -1 ,7851)

for var1 in list_var1:
    print(f"========== {name_var1}: {var1} ==========")
    setattr(args, name_var1, var1)
    
    for var2 in list_var2:
        print(f"[{name_var2}: {var2}]")
        setattr(args, name_var2, var2)

        num_set = set[:, torch.randperm(args.num_data), :].reshape(-1, 7851)
        print(f"shape of data: {num_set.shape}")

        train_set = num_set[:int(len(num_set)*0.8)]
        val_set = num_set[int(len(num_set)*0.8):]
        partition = {'train':train_set, 'val':val_set, 'mnist_train':train_dataset, 'mnist_val':val_dataset}

        encoder, decoder, train_losses, val_losses = experiment(partition, deepcopy(args))
        plot_loss(var2, name_var2, train_losses, val_losses)
        save_model(encoder, decoder, var2, args)

Directories not created because they already exist
[num_data: 100]
shape of data: torch.Size([300, 7851])
2.3542761325836183
2.354082405567169
2.3542733550071717
2.3538515329360963
2.351799488067627
2.3532838940620424
2.355102574825287
2.354480504989624
2.3545336604118345
2.3543482065200805
2.3521487593650816
2.355717384815216
2.354356038570404
2.355105483531952
2.352464699745178
2.354747176170349
2.353770208358765
2.3519059300422667
2.3573376059532167
2.3551472425460815
2.355656659603119
2.354658281803131
2.3538002490997316
2.3537532091140747
2.3544562578201296
2.3541361689567566
2.350209367275238
2.3523062229156495
2.350916123390198
2.353810524940491
2.352051866054535
2.355924165248871
2.356825923919678
2.413106679916382
2.4484464645385744
2.451253867149353
2.447420620918274
2.4473577618598936
2.4473183512687684
2.44848552942276
2.448956286907196
2.4493505716323853
2.4493940830230714
2.4472745656967163
2.4478769063949586
2.4496517062187193
2.45160973072052
2.446529722213745
2.4506221

KeyboardInterrupt: 