In [None]:
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.models as models
from dataset import CIFAR10Dataset
import wandb

#### Parameters.

In [None]:
# Classifier parameters.
CLASSIFIER_NUMBER_OF_CLASSES = 10
CLASSIFIER_NUMBER_OF_EPOCHS = 50
CLASSIFIER_LEARNING_RATE = 0.01
CLASSIFIER_BATCH_SIZE = 64

# Parameters for both agents.

# Capacity of the replay buffer. Kept as an int: the previous `5e4` is a float,
# which breaks any capacity arithmetic/allocation that requires an integer.
REPLAY_BUFFER_SIZE = 50_000
# Prioritised-replay exponent (name kept as-is, typo included, because later cells use it).
PRIOROTIZED_REPLAY_EXPONENT = 3

BATCH_SIZE = 32
LEARNING_RATE = 1e-3
TARGET_COPY_FACTOR = 0.01
# Placeholder: replaced after the warm-start episodes by -(mean episode duration) / 2.
BIAS_INITIALIZATION = 0

# BatchAgent's parameters.

# NOTE(review): DIRNAME is not referenced anywhere else in this notebook.
DIRNAME = './batch_agent/' # The resulting batch_agent of this experiment will be written in a file.

WARM_START_EPISODES_BATCH_AGENT = 50
NN_UPDATES_PER_EPOCHS_BATCH_AGENT = 50

TRAINING_EPOCHS_BATCH_AGENT = 50
TRAINING_EPISODES_PER_EPOCH_BATCH_AGENT = 5

TESTING_EPISODES = 50

# Reproducibility. Seeding every RNG used by this notebook (Python `random`,
# NumPy, PyTorch) also makes the CIFAR-10 client split (torch.randperm) identical
# across FL rounds, so a given client sees the same images in every round.
import numpy as np  # local import: numpy is otherwise first imported in a later cell

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
TESTING_EPISODES = 50

In [None]:
import os
import shutil

cwd = os.getcwd()

# Start each run from a clean working directory: drop caches, logs, checkpoints
# and data folders a previous run may have left behind. Directories that do not
# exist are skipped, and deletion errors are ignored.
# NOTE(review): 'data' is the CIFAR-10 download root used later (root='./data'),
# so removing it forces a fresh download on every run.
stale_directory_names = (
    '__pycache__', 'wandb', 'batch_agent', 'libact', 'AL_results',
    'checkpoints', 'summaries', 'data', 'data_client',
)
for stale_path in (cwd + '/' + name for name in stale_directory_names):
    if not os.path.exists(stale_path):
        continue
    shutil.rmtree(stale_path, ignore_errors=True)

### Federated Learning: Split the CIFAR10 dataset and retrieve the subset for the first client.

In [None]:
import torch
from torchvision.datasets import CIFAR10

def get_cifar10_splited_big_common(num_clients, trans,
                                   root='/home/kastellosa/PycharmProjects/federated_learning/CVPR_nov_23/data',
                                   special_client_size=0):
    """
    Split the CIFAR-10 training set into per-client index lists for federated learning.

    Client 0 (the "special" client) receives a class-balanced common subset of
    ``int(special_client_size / num_classes)`` images per class. The remaining
    images are then shared out evenly, class by class, among the other
    ``num_clients - 1`` clients. Per-class leftovers that do not divide evenly
    among those clients are dropped.

    num_clients: The total number of clients, including the special client. Must be >= 2.
    trans: Transformations attached to the returned dataset. They are not applied while splitting.
    root: The directory where CIFAR-10 is stored or will be downloaded to.
        NOTE: the default is a machine-specific absolute path, so callers should pass ``root``.
    special_client_size: The number of images in the special client's common subset. Must be >= 0.
    return:
        1. A list of index lists, one per client (index 0 is the special client).
        2. The CIFAR-10 training dataset.
    raises:
        ValueError: if num_clients < 2 or special_client_size < 0.
    """
    # Validate the arguments before touching the disk or the network.
    if num_clients < 2:
        raise ValueError("Number of clients must be at least 2.")
    if special_client_size < 0:
        # A negative count would turn the `[:k]` slices below into "all but the last |k|".
        raise ValueError("special_client_size must be non-negative.")

    # Load the CIFAR-10 training set with the requested transformations.
    trainset = CIFAR10(root=root, train=True, download=True, transform=trans)
    num_classes = len(trainset.classes)

    # Images per class for the special client (assumes an even split across classes).
    special_indices_per_class = int(special_client_size / num_classes)

    # Shuffle the dataset indices, then bucket them by label.
    # Labels are read from `targets` directly. Indexing `trainset[idx]` would
    # decode and transform every image only to obtain its label.
    indices = torch.randperm(len(trainset)).tolist()
    targets = trainset.targets
    class_indices = [[] for _ in range(num_classes)]
    for idx in indices:
        class_indices[targets[idx]].append(idx)

    # Special client: take the first `special_indices_per_class` shuffled indices
    # of every class, and remove them from the pool.
    special_client_indices = []
    for class_list in class_indices:
        special_client_indices.extend(class_list[:special_indices_per_class])
        del class_list[:special_indices_per_class]

    # Every other client gets the same number of images from each class. Use the
    # smallest remaining class so no client is short-changed if the classes were
    # unbalanced (CIFAR-10's training set is expected to be balanced).
    remaining_images_per_class = min(len(class_list) for class_list in class_indices)
    images_per_class_per_client = remaining_images_per_class // (num_clients - 1)

    # Hand out consecutive per-class slices to the remaining clients.
    client_indices = [special_client_indices]  # Start with the special client.
    for _ in range(num_clients - 1):
        client_subset = []
        for class_list in class_indices:
            client_subset.extend(class_list[:images_per_class_per_client])
            del class_list[:images_per_class_per_client]
        client_indices.append(client_subset)

    return client_indices, trainset

"""
Summary:

The function get_cifar10_splited_big_common
splits the CIFAR-10 training set into per-client index subsets for federated learning.
Client 0 (the special client) receives a class-balanced "common" subset of
special_client_size images, and the remaining clients receive equal, class-balanced
shares of the images that are left (per-class leftovers are dropped).
This is useful when one client should hold a common dataset alongside clients with disjoint private data.
"""

In [None]:
from torchvision import transforms

# Define the number of clients and the size of the special client's dataset.
# Split configuration: 5 clients in total, client 0 holding a 10,000-image common subset.
num_clients = 5
special_client_size = 10000
root = './data'

# Convert images to tensors, then shift/scale every RGB channel with mean 0.5 and std 0.5.
trans = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

client_indices, trainset = get_cifar10_splited_big_common(
    num_clients, trans, root=root, special_client_size=special_client_size
)

# Report how many images each client received.
for client_id, client_idx in enumerate(client_indices):
    print(f"Client {client_id} has {len(client_idx)} images.")

# Dataset view restricted to client 0's indices.
client_0_data = torch.utils.data.Subset(trainset, client_indices[0])
print(f"First client's dataset size: {len(client_0_data)}.")

#### Client's subset initialization.

In [None]:
"""
Define the client.
    - 0: First client (the special client holding the common subset).
    - 1: Second client.
    - 2: Third client.
    - 3: Fourth client.
    - 4: Fifth client.
"""

# Which client's subset this notebook simulates (index into `client_indices`).
CLIENT_ID = 0

client_subset_indices = client_indices[CLIENT_ID]
# NOTE(review): `trainset.data` is the raw, untransformed image array, so `trans` is
# NOT applied here. For torchvision's CIFAR10 it is expected to be (N, 32, 32, 3).
# Confirm that CIFAR10Dataset / CNNClassifier handle that layout:
# CNNClassifier.forward only reshapes (it does not permute) to (-1, 3, 32, 32).
subset_data_client = torch.tensor(trainset.data[client_subset_indices])
subset_labels_client = torch.tensor([trainset.targets[i] for i in client_subset_indices])
# The length comes from the selected indices, so it stays correct for any CLIENT_ID.
dataset = CIFAR10Dataset(
    root_dir='./data_client',
    length_of_client_data=len(client_subset_indices),
    data=subset_data_client,
    labels=subset_labels_client,
)

In [None]:
# Sizes of the splits that CIFAR10Dataset built from this client's data.
for split_label, split_data in (
    ("Warm-start", dataset.warm_start_data),
    ("State", dataset.state_data),
    ("Agent", dataset.agent_data),
    ("Test", dataset.test_data),
):
    print("{} data are {}.".format(split_label, len(split_data)))

#### Define the FL round.

In [None]:
# Federated-learning round index. Round 1 starts from fresh weights; later rounds
# load 'classifier_weights.pth'.
# NOTE(review): this shadows Python's built-in round() for the rest of the notebook.
# Later cells read this exact name, so it cannot be renamed here on its own.
round: int = 1

#### Initialize the classifier based on the 'round'.

In [None]:
# Initialize the model.
class CNNClassifier(nn.Module):
    """ResNet-18 adapted to 32x32 inputs.

    The pretrained backbone is frozen. The replacement stem convolution
    (`conv1`) and the classification head (`fc`) are created after freezing,
    so they are freshly initialised and are the only trainable parameters.

    Args:
        num_classes: Number of output logits. Defaults to 10, the previous
            hard-coded head size.
    """

    def __init__(self, num_classes=10):
        super(CNNClassifier, self).__init__()
        # NOTE(review): `pretrained=True` is deprecated in newer torchvision in favour of
        # `weights=...`; it is kept here for compatibility with older versions.
        self.resnet18 = models.resnet18(pretrained=True)
        for param in self.resnet18.parameters():
            param.requires_grad = False

        # Modify the stem for small inputs: a 3x3 / stride-1 conv and no max-pool,
        # so 32x32 images are not aggressively downsampled before the residual stages.
        # NOTE(review): this replaces the pretrained conv1 with random weights while
        # the rest of the backbone stays frozen.
        self.resnet18.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.resnet18.maxpool = nn.Identity()  # Remove the max pooling layer

        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        `x` is reshaped, not permuted, to (-1, 3, 32, 32). It must therefore
        already be channel-first data, or a flattened form of it.
        """
        x = x.reshape(-1, 3, 32, 32)
        return self.resnet18(x)

# Initialize the model and device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
classifier = CNNClassifier()
classifier.to(device)

# Checkpoint file that carries the classifier weights between FL rounds.
CLASSIFIER_WEIGHTS_PATH = 'classifier_weights.pth'

# If we have already run the first round of epochs, load the FL weights for the classifier.
# `map_location` lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa).
if round != 1:
    classifier.load_state_dict(torch.load(CLASSIFIER_WEIGHTS_PATH, map_location=device))

# Define the loss function and optimizer (learning rate taken from the parameters cell).
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(classifier.parameters(), lr=CLASSIFIER_LEARNING_RATE)

# Persist this round's starting weights.
torch.save(classifier.state_dict(), CLASSIFIER_WEIGHTS_PATH)

In [None]:
# Initial precision target passed to the environment. It is replaced after the
# warm-start episodes by the best precision they reached.
TARGET_PRECISION: float = 0.0

#### Initialize the Environment.

In [None]:
from batch_envs import LalEnvFirstAccuracy
# Active-learning environment that wraps this client's dataset and the classifier.
batch_env = LalEnvFirstAccuracy(
    dataset,
    classifier,
    epochs=CLASSIFIER_NUMBER_OF_EPOCHS,
    classifier_batch_size=CLASSIFIER_BATCH_SIZE,
    target_precision=TARGET_PRECISION,
)

#### Initialize the Replay Buffer.

In [None]:
from batch_helpers import ReplayBuffer
# The buffer capacity is coerced to int: REPLAY_BUFFER_SIZE may be written as a
# float literal (e.g. 5e4), and a float capacity breaks integer-only allocation/indexing.
replay_buffer = ReplayBuffer(
    buffer_size=int(REPLAY_BUFFER_SIZE),
    prior_exp=PRIOROTIZED_REPLAY_EXPONENT,
)

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator before the warm-start episodes.
torch.cuda.empty_cache()

#### Warm-Start episodes.

In [None]:
# WARM-START EPISODES.

import torch
import numpy as np

# Initialize the variables.
# Per-episode bookkeeping for the warm-start phase. The next cell uses
# episode_durations and episode_precisions to derive TARGET_BUDGET and TARGET_PRECISION.
episode_durations = []
episode_scores = []
episode_number = 1
episode_losses = []  # NOTE(review): never appended to in this notebook.
episode_precisions = []
batches = []  # Batch size chosen at every warm-start step, kept for inspection.

# Warm start procedure.
# Each episode labels randomly sized, randomly chosen batches until the environment
# reports `done`. Every transition goes into the replay buffer, so the DQN has data
# for its first updates.
for _ in range(WARM_START_EPISODES_BATCH_AGENT):
    print("Episode {}.".format(episode_number))
    # Reset the environment to start a new episode.
    # target_budget=1.0 presumably means "whole budget available" during warm start (confirm in LalEnvFirstAccuracy).
    # print("- Reset.")
    state, next_action, indicies_unknown, reward = batch_env.reset(code_state="Warm-Start", target_precision=TARGET_PRECISION, target_budget=1.0)
    done = False
    # The duration counts selected samples. It starts at the number of classes,
    # presumably one initially labelled sample per class (TODO confirm in the env).
    episode_duration = CLASSIFIER_NUMBER_OF_CLASSES

    # Before we reach a terminal state, make steps.
    while not done:
        # Choose a random action: a batch size drawn uniformly from [1, n_actions]
        # (torch.randint's upper bound is exclusive, hence the +1).
        # NOTE(review): if n_actions were 0, torch.randint(1, 1, ...) would raise.
        # print("-- Number of actions left: {}.".format(batch_env.n_actions))
        if batch_env.n_actions==1:
            batch = batch_env.n_actions
        else:
            batch = torch.randint(1, batch_env.n_actions + 1, (1,)).item()
        # print("-- Batch: {}.".format(batch))
        batches.append(batch)

        # Candidate positions 0 .. n_actions - 1.
        input_numbers = range(0, batch_env.n_actions)

        # Draw `batch` distinct positions uniformly at random (no replacement) with NumPy's global RNG.
        batch_actions_indices = torch.tensor(np.random.choice(input_numbers, batch, replace=False))
        # print("batch_actions_indices", batch_actions_indices)
        # The stored action is the batch size (an int), not the chosen indices.
        action = batch
        # print("- Step.")
        next_state, next_action, indicies_unknown, reward, done = batch_env.step(batch_actions_indices)

        # Replace an empty next-action list with a single dummy entry, presumably
        # so the replay buffer never stores an empty action set.
        if next_action == []:
            next_action.append(np.array([0]))

        # Store the transition in the replay buffer.
        replay_buffer.store_transition(state, action, reward, next_state, next_action, done)

        # Get ready for the next step.
        state = next_state
        episode_duration += batch

    # Record the final accuracy and precision of the episode
    # (the last entry of each per-episode history returned by the environment).
    episode_final_acc = batch_env.return_episode_qualities()
    episode_scores.append(episode_final_acc[-1])
    episode_final_precision = batch_env.return_episode_precisions()
    episode_precisions.append(episode_final_precision[-1])
    episode_durations.append(episode_duration)
    episode_number += 1

    torch.cuda.empty_cache()  # Clear unused memory after each episode.

# Compute the average episode duration of episodes generated during the warm start procedure.
# The DQN receives minus half of this average as its `bias_average`.
av_episode_duration = np.mean(episode_durations)
BIAS_INITIALIZATION = - av_episode_duration / 2

#### Define target precision and target budget based on the Warm-Start episodes.

In [None]:
import torch

# Derive the agent's targets from the warm-start episodes:
#   - TARGET_PRECISION: the best final precision any warm-start episode reached;
#   - TARGET_BUDGET: the fewest selected samples among the episodes that reached it,
#     as a fraction of the warm-start pool.
if len(episode_precisions) == 0:
    raise ValueError("No warm-start episodes were run; cannot derive TARGET_PRECISION / TARGET_BUDGET.")

# Convert under a new name so that re-running this cell does not change the
# warm-start list in place.
warm_start_precisions = torch.tensor(episode_precisions)
max_precision = torch.max(warm_start_precisions)  # Computed once instead of once per loop iteration.

# Durations of every warm-start episode that reached the maximum precision.
warm_start_batches = [
    duration
    for precision, duration in zip(warm_start_precisions, episode_durations)
    if precision >= max_precision
]
TARGET_BUDGET = min(warm_start_batches) / len(dataset.warm_start_data)
print("Target budget is {}.".format(TARGET_BUDGET))
# Kept as a 0-d tensor, the same type the previous `max(...)` over the tensor produced.
TARGET_PRECISION = max_precision
print("Target precision is {}.".format(TARGET_PRECISION))

#### Initialize the DQN based on the 'round'.

In [None]:
from batch_dqn import DQN
# Observation length equals the number of samples in the client's state split.
state_observation_length = len(dataset.state_data)

# The DQN agent. `round` lets it decide what to do based on the FL round.
batch_agent = DQN(
    observation_length=state_observation_length,
    learning_rate=LEARNING_RATE,
    batch_size=BATCH_SIZE,
    target_copy_factor=TARGET_COPY_FACTOR,
    bias_average=BIAS_INITIALIZATION,
    round=round,
)

#### Agent's first training using the Replay Buffer.

In [None]:
# Pre-train the DQN on the warm-start transitions before any agent-driven episode.
# After each update, the sampled transitions' priorities are refreshed with the new TD errors.
for update_number in range(1, NN_UPDATES_PER_EPOCHS_BATCH_AGENT + 1):
    print("Update:", update_number)
    sampled = replay_buffer.sample_minibatch(BATCH_SIZE)
    replay_buffer.update_td_errors(batch_agent.train(sampled), sampled.indices)
    torch.cuda.empty_cache()  # Release cached GPU memory after each update.

#### Agent's training.

In [None]:
# BATCH-AGENT TRAINING.
# Each epoch runs several agent-driven episodes, records the best one, and then
# performs NN_UPDATES_PER_EPOCHS_BATCH_AGENT DQN updates from the replay buffer.

# Per-epoch metrics (one entry per training epoch).
agent_epoch_durations = []
agent_epoch_scores = []
agent_epoch_precisions = []

for epoch in range(TRAINING_EPOCHS_BATCH_AGENT):
    print("Training epoch {}.".format(epoch+1))

    # Simulate training episodes.
    agent_episode_durations = []
    agent_episode_scores = []
    agent_episode_precisions = []

    for training_episode in range(TRAINING_EPISODES_PER_EPOCH_BATCH_AGENT):

        print("- Training episode {}.".format(training_episode+1))

        # Reset the environment to start a new episode.
        print("- Reset.")
        state, action_batch, action_unlabeled_data, reward = batch_env.reset(code_state="Agent", target_precision=TARGET_PRECISION, target_budget=TARGET_BUDGET)
        done = False
        episode_duration = CLASSIFIER_NUMBER_OF_CLASSES
        first_batch = True

        # Run an episode.
        while not done:
            # The first step's candidate actions come from reset(); later ones from the previous step().
            if first_batch:
                next_batch = action_batch
                next_unlabeled_data = action_unlabeled_data
                first_batch = False
            else:
                next_batch = next_action_batch_size
                next_unlabeled_data = next_action_unlabeled_data

            selected_batch, selected_indices = batch_agent.get_action(dataset=dataset, model=classifier, state=state, next_action_batch=next_batch, next_action_unlabeled_data=next_unlabeled_data)
            print("- Step.")
            next_state, next_action_batch_size, next_action_unlabeled_data, reward, done = batch_env.step(selected_indices)
            # Same empty-action placeholder as in the warm-start loop.
            if next_action_batch_size==[]:
                next_action_batch_size.append(np.array([0]))

            print("- Buffer.")
            replay_buffer.store_transition(state, selected_batch, reward, next_state, next_action_batch_size, done)

            # Change the state of the environment.
            # NOTE(review): unlike the warm-start loop, the state is converted to a float32
            # tensor on `device` here. After the first step, transitions store device tensors
            # (possibly on the GPU) while warm-start transitions store the env's raw state,
            # so the buffer holds mixed types. torch.tensor() on an existing tensor also warns.
            state = torch.tensor(next_state, dtype=torch.float32).to(device)
            episode_duration += selected_batch

        print("\n")  # Prints two blank lines between episodes.

        agent_episode_final_acc = batch_env.return_episode_qualities()
        agent_episode_scores.append(agent_episode_final_acc[-1])
        agent_episode_final_precision = batch_env.return_episode_precisions()
        agent_episode_precisions.append(agent_episode_final_precision[-1])
        agent_episode_durations.append(episode_duration)

    # Summarise the epoch by its best episode(s): the highest precision reached, the
    # accuracies of the episodes reaching it, and the smallest duration among them.
    maximum_epoch_precision = max(agent_episode_precisions)
    minimum_batches_for_the_maximum_epoch_precision = []
    accuracy_for_the_maximum_epoch_precision = []
    for i in range(len(agent_episode_precisions)):
        if agent_episode_precisions[i] == maximum_epoch_precision:
            minimum_batches_for_the_maximum_epoch_precision.append(agent_episode_durations[i])
            accuracy_for_the_maximum_epoch_precision.append(agent_episode_scores[i])
    agent_epoch_precisions.append(maximum_epoch_precision)
    # NOTE(review): this appends a *list* of accuracies, while the precision and duration
    # metrics append scalars. Confirm whether a single accuracy was intended.
    agent_epoch_scores.append(accuracy_for_the_maximum_epoch_precision)
    agent_epoch_durations.append(min(minimum_batches_for_the_maximum_epoch_precision))

    torch.cuda.empty_cache()  # Clear unused memory after each epoch's episodes.

    # NEURAL NETWORK UPDATES.
    # Train the DQN on prioritised minibatches and refresh the sampled priorities with the new TD errors.
    for update in range(NN_UPDATES_PER_EPOCHS_BATCH_AGENT):
        minibatch = replay_buffer.sample_minibatch(BATCH_SIZE)
        td_error = batch_agent.train(minibatch)
        replay_buffer.update_td_errors(td_error, minibatch.indices)
        torch.cuda.empty_cache()  # Clear unused memory after each update.

In [None]:
# Print the CUDA allocator report. A bare expression would display the string's
# repr with escaped "\n" characters. The report is skipped on CPU-only machines,
# where torch.cuda.memory_summary() cannot produce it.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())
else:
    print("CUDA is not available; no GPU memory summary to report.")

#### Save models' weights.

In [None]:
# Persist this round's models. The classifier file is the checkpoint that the
# classifier cell loads when `round != 1`; the DQN weights go to their own file.
classifier_state = classifier.state_dict()
torch.save(classifier_state, 'classifier_weights.pth')
batch_agent.save_weights('dqn_weights.pth')

#### This is the end of a round.