In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn.utils.parametrizations import spectral_norm

from torchvision import datasets, transforms, utils
import seaborn as sns
import matplotlib.pyplot as plt
from torchsummary import summary
from dataTransformation import labels4clients, distribute_data_labels4clients
from gan_model import Discriminator, Generator, initialize_weights
from network import Server, Worker
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

In [None]:
# ---- Training hyper-parameters ----
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 64
IMAGE_SIZE = 32
CHANNELS_IMG = 3
# Latent size of the generator's (N, NOISE_DIM, 1, 1) noise input. The training
# setup cell samples 128-dimensional noise, so this is kept in sync with it.
NOISE_DIM = 128
NUM_EPOCHS = 5
FEATURES_DISC = 32
FEATURES_GEN = 32

# Seed the PyTorch and NumPy RNGs so that Restart & Run All reproduces the
# client split, noise samples and weight initialisation.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# ---- Federated setup ----
num_workers = 5
num_unique_users = num_workers
num_classes = 10
classes_per_user = 4

# Reuse the configuration-cell value instead of keeping a second copy that can
# silently drift from it.
learning_rate = LEARNING_RATE
test_samples_num = 16
# NOTE(review): not used by the training loop, which runs for NUM_EPOCHS epochs.
epochs_num = 200

# Per-client class assignment; presumably maps each client to the labels it
# owns (see dataTransformation.labels4clients) — TODO confirm.
dictionary = labels4clients(num_classes, classes_per_user, num_workers, num_unique_users, False)
print(dictionary)

In [None]:
# Images -> float tensors in [-1, 1]: ToTensor scales pixels to [0, 1], then
# Normalize with mean = std = 0.5 maps each value x to (x - 0.5) / 0.5.
trans_cifar = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# Training split first, then test split (downloaded on first use).
dataset, dataset_test = (
    datasets.CIFAR10(root='./datasets/cifar/', train=split_is_train, download=True, transform=trans_cifar)
    for split_is_train in (True, False)
)

# Default batch size (1) loader over the training split, used for peeking at samples.
dataloader = torch.utils.data.DataLoader(dataset, shuffle=True)

In [None]:
# Raw first training image as stored on the dataset (HWC layout).
print(dataset.data[0])
# After ToTensor only: CHW float tensor scaled to [0, 1].
print(transforms.ToTensor()(dataset.data[0]))
# After the full trans_cifar pipeline (ToTensor + Normalize): values in [-1, 1].
print(trans_cifar(dataset.data[0]))

In [None]:
# Stand-alone Normalize transform, printed to show its mean/std configuration.
normalized = transforms.Normalize(mean=(0.5,) * 3, std=(0.5,) * 3)
print(normalized)
# Cell output: the first training image pushed through the full trans_cifar pipeline.
trans_cifar(dataset.data[0])

In [None]:
# Raw training-array shape, then the dataset's Python type, one per line.
print(dataset.data.shape, type(dataset), sep="\n")

In [None]:
# Shape of the raw array behind the DataLoader's dataset (displayed as cell output).
np.shape(dataloader.dataset.data)

In [None]:
# The previous cell already shows the training-array shape; display it next to
# the test split so both splits can be compared at a glance.
dataloader.dataset.data.shape, dataset_test.data.shape

In [None]:
# First transformed training sample; its label is dropped. Indexing instead of
# unpacking into `_` keeps IPython's `_` (last output) history variable intact.
x = dataloader.dataset[0][0]
print(x.shape)
print(x)

In [None]:
# Normalise the whole training set in one vectorised pass, reproducing
# trans_cifar (ToTensor -> Normalize(mean=0.5, std=0.5)) without a per-image loop:
#   NHWC pixels -> NCHW float32, / 255 -> [0, 1], then (x - 0.5) / 0.5 -> [-1, 1].
# float32 is what ToTensor produces; storing it as float32 (instead of
# np.empty's default float64) halves the memory of this array.
x_train_normalized_np = np.ascontiguousarray(dataset.data.transpose(0, 3, 1, 2), dtype=np.float32)
x_train_normalized_np /= 255.0
x_train_normalized_np -= 0.5
x_train_normalized_np /= 0.5

In [None]:
# Raw training images and integer labels as NumPy arrays.
x_train, y_train = np.asarray(dataset.data), np.asarray(dataset.targets)

# Split the normalised images and their labels across clients according to the
# class assignment in `dictionary`.
x_clinet_list, y_client_list = distribute_data_labels4clients(
    x_train_normalized_np,
    y_train,
    dictionary,
    False,
)

In [None]:
def getDist(y, num_classes, user_num):
    """Plot how many samples of each class one client holds.

    Args:
        y: that client's class labels (assumed to be integer class ids in
            ``range(num_classes)``).
        num_classes: total number of classes. Every class gets a slot on the
            x-axis, so classes the client does not hold show up as empty bars
            and plots for different clients are directly comparable.
        user_num: client index shown in the plot title.
    """
    fig, ax = plt.subplots()
    sns.countplot(x=y, order=list(range(num_classes)), ax=ax)
    ax.set(title="Count of data classes for %s" % user_num, xlabel="class", ylabel="count")
    plt.show()

In [None]:
# For every client: print how many samples it holds, then plot its label histogram.
for client_id in range(len(x_clinet_list)):
    client_labels = y_client_list[client_id]
    print(len(client_labels))
    getDist(client_labels, num_classes, client_id)

In [None]:
# class Generator(nn.Module):

#     def __init__(self):
#         super(Generator, self).__init__()
#         self.main = nn.Sequential(
#             nn.ConvTranspose2d(in_channels=128, out_channels=512, kernel_size=4, stride=1, padding=0, bias = False),
#             nn.BatchNorm2d(num_features=512,momentum=0.1),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1, bias = False),
#             nn.BatchNorm2d(num_features=256,momentum=0.1),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1, bias = False),
#             nn.BatchNorm2d(num_features=128,momentum=0.1),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1, bias = False),
#             nn.BatchNorm2d(num_features=64,momentum=0.1),
#             nn.ReLU(inplace=True),
#             nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=3, stride=1, padding=1, bias = False),
#             nn.Tanh()
#         )

#     def forward(self, input):
#         output = self.main(input)
#         return output

In [None]:
# class Discriminator(nn.Module):

#     def __init__(self):
#         super(Discriminator, self).__init__()
#         self.main = nn.Sequential(
#             spectral_norm(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias = False)),
#             nn.LeakyReLU(negative_slope= 0.1, inplace = True),

#             spectral_norm(nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1, bias = False)),
#             nn.LeakyReLU(negative_slope= 0.1, inplace = True),

#             spectral_norm(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias = False)),
#             nn.LeakyReLU(negative_slope= 0.1, inplace = True),

#             spectral_norm(nn.Conv2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1, bias = False)),
#             nn.LeakyReLU(negative_slope= 0.1, inplace = True),

#             # need to calculate the number of neurons in this layer to connect each of their outputs to the next layer
#             spectral_norm(nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1, bias = False)),
#             nn.LeakyReLU(negative_slope= 0.1, inplace = True),
#             nn.Flatten(), #flatten the output
#             spectral_norm(nn.Linear(in_features =4096,out_features =1, bias = False))
#         )

#     def neuron_calculator(in_channels,padding,kernel_size,stride,out_channels):
#         return (in_channels+2*padding-kernel_size)**2 * out_channels
#     def forward(self, input):
#         output = self.main(input)
#         return output.view(-1)


In [None]:
# Train on the first CUDA device when one is visible, otherwise on the CPU.
dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# netG = Generator().to(dev)
# netD = Discriminator().to(dev)
# summary(netG,(128,1,1))
# summary(netD,(3,32,32))

In [None]:
# NOTE(review): no visible code moves the models to `dev`, yet every tensor fed
# to them in the training loop is sent there. `Module.to` is in place and is a
# no-op when Server/Worker already placed the models on `dev`; otherwise it
# prevents a CPU/GPU device-mismatch error.

# Central server that owns the shared generator.
main_server = Server(0, learning_rate)
main_server.generator.to(dev)
main_server.generator.train()

# One worker per client, each with its own local data shard and discriminator.
workers = []
for i in range(num_workers):
    worker = Worker(i, learning_rate)
    # x_clinet_list entries are already channel-first (N, C, H, W); no transpose needed.
    worker.load_worker_data(x_clinet_list[i], y_client_list[i])
    worker.discriminator.to(dev)
    worker.discriminator.train()
    workers.append(worker)

In [None]:
# Least-squares GAN objective: discriminator outputs are regressed onto 0/1
# targets with MSE (used for both the discriminator and generator losses).
criterion = nn.MSELoss()
# Generator noise is sampled as (N, NOISE_DIM, 1, 1); re-assigned here, keep in
# sync with the configuration cell.
NOISE_DIM = 128
# The same noise is reused at every logging step so generated samples are comparable over time.
fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(dev)
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
step = 0  # TensorBoard global step, advanced each time images are logged

# Pre-split each worker's data into fixed, unshuffled batches so the training
# loop can take batch `batch_id` from every worker in lockstep. Batches are cast
# to float32 once here rather than on every epoch.
worker_loaders = []
for worker in workers:
    worker_loaders.append([])
    for real in DataLoader(dataset=worker.x_data, batch_size=BATCH_SIZE):
        worker_loaders[-1].append(real.float())



In [None]:

def _train_worker_discriminator(worker, real_batch, fake_batch):
    """Run one least-squares discriminator update for ``worker``.

    Real samples are pushed towards 1 and generated samples towards 0 using
    ``criterion``. The losses are stored on the worker (``loss_disc_real``,
    ``loss_disc_fake``, ``loss_disc``) so they can be logged afterwards.

    Args:
        worker: object exposing ``discriminator`` and ``d_optimizer``.
        real_batch: this worker's real images for the current step, on ``dev``.
        fake_batch: generator output for the current step. It is detached here,
            so this update does not backpropagate into the generator.

    Returns:
        float: the discriminator's loss on ``fake_batch``, computed before the update.
    """
    disc_real = worker.discriminator(real_batch).reshape(-1)
    worker.loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    disc_fake = worker.discriminator(fake_batch.detach()).reshape(-1)
    worker.loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    worker.loss_disc = (worker.loss_disc_real + worker.loss_disc_fake) / 2
    worker.discriminator.zero_grad()
    worker.loss_disc.backward()
    worker.d_optimizer.step()
    return worker.loss_disc_fake.item()


# Clients hold different amounts of data. Only iterate over batch indices that
# every worker has; indexing past a shorter worker's batch list raises IndexError.
num_batches = min(len(batches) for batches in worker_loaders)

for epoch in range(NUM_EPOCHS):
    for batch_id in range(num_batches):
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(dev)
        fake = main_server.generator(noise)

        ### Train every worker's discriminator on its own real batch and the shared fake batch
        fake_losses = []
        for worker_id, worker in enumerate(workers):
            current_worker_real = worker_loaders[worker_id][batch_id].float().to(dev)
            fake_losses.append(_train_worker_discriminator(worker, current_worker_real, fake))

        # The generator is trained against the worker whose discriminator had the
        # highest loss on the fake batch (first such worker on ties).
        chosen_discriminator = max(range(len(workers)), key=fake_losses.__getitem__)
        chosen_worker = workers[chosen_discriminator]

        ### Train Generator (least squares): push the chosen D's score on fakes towards 1
        output = chosen_worker.discriminator(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        main_server.generator.zero_grad()
        loss_gen.backward()
        main_server.g_optimizer.step()

        # Print losses occasionally and log images to TensorBoard
        if batch_id % 10 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_id}/{num_batches} "
                f"chosen D: {chosen_discriminator} "
                f"(fake loss {fake_losses[chosen_discriminator]:.4f}) "
                f"Loss D: {chosen_worker.loss_disc.item():.4f}, loss G: {loss_gen.item():.4f}"
            )

            with torch.no_grad():
                fake_samples = main_server.generator(fixed_noise)
                # Log this step's real batch from the chosen worker, so the real
                # grid reflects the data the generator was just trained against.
                real_samples = worker_loaders[chosen_discriminator][batch_id]
                # take out (up to) 32 examples
                img_grid_real = utils.make_grid(real_samples[:32], normalize=True)
                img_grid_fake = utils.make_grid(fake_samples[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

In [None]:
# dataloaders = []

# for worker in workers:
#     # print(worker.x_data.shape)
#     dataloaders.append(DataLoader(dataset=worker.x_data,batch_size=BATCH_SIZE))

# i = iter(dataloaders[0])
# print(next(i))

In [None]:
# criterion = nn.MSELoss()
# NOISE_DIM = 128
# fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(dev)
# writer_real = SummaryWriter(f"logs/real")
# writer_fake = SummaryWriter(f"logs/fake")
# step = 0

# for epoch in range(NUM_EPOCHS):
#     highest_loss = 0
#     chosen_discriminator = None
#     for i, worker in enumerate(workers):
#         print(worker.x_data.shape)
#         dataloader = DataLoader(dataset=worker.x_data,batch_size=BATCH_SIZE)
#         for batch_id, real in enumerate(dataloader):
#             real = real.float().to(dev)
#             # print(real.shape)
#             # print(real)
#             noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1,1).to(dev)
#             fake = main_server.generator(noise)

#             ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
#             disc_real = worker.discriminator(real).reshape(-1)
#             worker.loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
#             disc_fake = worker.discriminator(fake.detach()).reshape(-1)
#             worker.loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
#             loss_disc = (worker.loss_disc_real + worker.loss_disc_fake) / 2
#             worker.discriminator.zero_grad()
#             loss_disc.backward()
#             worker.d_optimizer.step()
#             if batch_id % 20 == 0:
#                 print(
#                     f"Worker: {i} Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_id}/{len(dataloader)} \
#                         Loss D: {loss_disc:.4f}"
#                 )
#         # print(worker.loss_disc_fake, i)
#         if highest_loss < worker.loss_disc_fake:
#             highest_loss = worker.loss_disc_fake
#             chosen_discriminator = i
#         print(f"chosen worker is {chosen_discriminator} with loss of: {highest_loss.item()}")
#     dataloader = DataLoader(dataset=workers[chosen_discriminator].x_data,batch_size=BATCH_SIZE)
        
        


In [None]:
    ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
    # output = disc(fake).reshape(-1)
    # loss_gen = criterion(output, torch.ones_like(output))
    # gen.zero_grad()
    # loss_gen.backward()
    # opt_gen.step()

    # for batch_idx, (real, _) in enumerate(dataloader):
    #     real = real.to(device)
    #     noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
    #     fake = gen(noise)

    #     ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
    #     disc_real = disc(real).reshape(-1)
    #     loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
    #     disc_fake = disc(fake.detach()).reshape(-1)
    #     loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    #     loss_disc = (loss_disc_real + loss_disc_fake) / 2
    #     disc.zero_grad()
    #     loss_disc.backward()
    #     opt_disc.step()

    #     ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
    #     output = disc(fake).reshape(-1)
    #     loss_gen = criterion(output, torch.ones_like(output))
    #     gen.zero_grad()
    #     loss_gen.backward()
    #     opt_gen.step()

    #     # Print losses occasionally and print to tensorboard
    #     if batch_idx % 100 == 0:
    #         print(
    #             f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
    #               Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
    #         )

    #         with torch.no_grad():
    #             fake = gen(fixed_noise)
    #             # take out (up to) 32 examples
    #             img_grid_real = torchvision.utils.make_grid(
    #                 real[:32], normalize=True
    #             )
    #             img_grid_fake = torchvision.utils.make_grid(
    #                 fake[:32], normalize=True
    #             )

    #             writer_real.add_image("Real", img_grid_real, global_step=step)
    #             writer_fake.add_image("Fake", img_grid_fake, global_step=step)

    #         step += 1