# Lesson 3: Federated LLM Fine-tuning

Welcome to Lesson 3!

To access the `requirements.txt` and `utils` files for this course, go to `File` and click `Open`.

#### 1. Import packages and utilities

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
import flwr as fl
from flwr_datasets import FederatedDataset
from flwr_datasets.partitioner import IidPartitioner
from datasets import load_dataset
from flwr.client.mod import fixedclipping_mod
from flwr.server.strategy import (
    DifferentialPrivacyClientSideFixedClipping
)
from utils.utils import * 
from utils.LLM import LLM_fl
from utils.LLM import get_fireworks_api_key,load_env

* Load configuration.

In [None]:
# Load the configuration named "federated"; `cfg` is read by the cells below
# (dataset name, model/train settings, Flower simulation settings).
cfg = get_config("federated")

# Show the loaded configuration.
print_config(cfg)

#### 2. Dataset partition

In [None]:
# One partition per simulated client.
partitioner = IidPartitioner(num_partitions=cfg.flower.num_clients)

# Only the "train" split is handed to the partitioner.
fds = FederatedDataset(dataset=cfg.dataset.name, partitioners={"train": partitioner})

# Load the first partition and format it; the formatted result is left as
# the final expression so the notebook displays it.
partition_zero = fds.load_partition(0)
format_dataset(partition_zero)

* Visualize the data partition.

In [None]:
# Visualize the partitions of the federated dataset built above.
visualize_partitions(fds)

#### 3. Client and Server

* Load the tokenizer and other components.

In [None]:
# A single helper call returns the tokenizer, the data collator and the
# prompt-formatting function, all configured from the model/train settings.
tokenizer, data_collator, formatting_prompts_func = (
    get_tokenizer_and_data_collator_and_propt_formatting(
        cfg.model.name,
        cfg.model.use_fast_tokenizer,
        cfg.train.padding_side,
    )
)

* Define the client.

In [None]:
# Path shared by the client factory here and the server-side evaluate function.
save_path = "./my_fl_model"

# Client function handed to the ClientApp, built from the dataset, the
# tokenizer components and the model/train configuration.
client_factory = gen_client_fn(
    fds,
    tokenizer,
    formatting_prompts_func,
    data_collator,
    cfg.model,
    cfg.train,
    save_path,
)

# fixedclipping_mod is the client-side counterpart of the
# DifferentialPrivacyClientSideFixedClipping strategy used on the server.
client = fl.client.ClientApp(client_fn=client_factory, mods=[fixedclipping_mod])

> Note: The ```gen_client_fn``` function is provided for your use. You can find it in ```utils.py``` inside the `utils` folder.

* Define the server function and add Differential Privacy.

In [None]:
def server_fn(context: Context):
    """Build the server-side components for the federated fine-tuning run.

    Creates a FedAvg strategy from the notebook-level ``cfg``, wraps it with
    client-side fixed-clipping differential privacy, and returns it together
    with the server configuration.

    Args:
        context: Context passed in when the ServerApp invokes this function;
            it is not read here.

    Returns:
        ``fl.server.ServerAppComponents`` holding the DP-wrapped strategy and
        a ``ServerConfig`` set to ``cfg.flower.num_rounds`` rounds.
    """
    # Define the Strategy
    strategy = fl.server.strategy.FedAvg(
        min_available_clients=cfg.flower.num_clients,  # total clients
        fraction_fit=cfg.flower.fraction_fit,  # ratio of clients to sample
        fraction_evaluate=0.0,  # No federated evaluation
        # A (optional) function used to configure a "fit()" round
        on_fit_config_fn=get_on_fit_config(),
        # A (optional) function to aggregate metrics sent by clients
        fit_metrics_aggregation_fn=fit_weighted_average,
        # A (optional) function to execute on the server after each round.
        # In this example the function only saves the global model.
        evaluate_fn=get_evaluate_fn(
            cfg.model,
            cfg.train.save_every_round,
            cfg.flower.num_rounds,
            save_path,
        ),
    )

    # Add Differential Privacy.
    # Ask the strategy how many clients it samples per fit round instead of
    # multiplying by hand: the raw product is a float (the DP wrapper expects
    # an int) and ignores FedAvg's int() truncation and min_fit_clients floor,
    # so it could disagree with the number of clients actually sampled.
    sampled_clients, _ = strategy.num_fit_clients(cfg.flower.num_clients)
    strategy = DifferentialPrivacyClientSideFixedClipping(
        strategy,
        noise_multiplier=cfg.flower.dp.noise_mult,
        clipping_norm=cfg.flower.dp.clip_norm,
        num_sampled_clients=sampled_clients,
    )

    # Number of rounds to run the simulation
    num_rounds = cfg.flower.num_rounds
    config = fl.server.ServerConfig(num_rounds=num_rounds)

    return fl.server.ServerAppComponents(strategy=strategy, config=config)

* Instantiate the ServerApp.

In [None]:
# Wrap server_fn in a ServerApp; the app is passed to run_simulation below.
server = fl.server.ServerApp(server_fn=server_fn)

#### 4. Run

* Run the simulation.

> Note: This simulation might take a few minutes to complete.

In [None]:
# Copy the per-client resource settings into a plain dict.
client_resources = dict(cfg.flower.client_resources)

# Backend settings: resources for each client plus the backend init arguments.
backend_config = {
    "client_resources": client_resources,
    "init_args": backend_setup,
}

# Run the simulation with one supernode per configured client.
fl.simulation.run_simulation(
    server_app=server,
    client_app=client,
    num_supernodes=cfg.flower.num_clients,
    backend_config=backend_config,
)

* Run the fine-tuned model. 

In [None]:
# Load the checkpoint
# Load the checkpoint through the LLM_fl helper from utils.LLM.
# NOTE(review): which checkpoint is loaded is decided inside LLM_fl — confirm there.
llm_eval = LLM_fl()

In [None]:
# Load dataset
# Index of the training example inspected in the following cells.
example_index = 6

# Load the full "train" split of the configured dataset and apply the same
# formatting helper that was used on the partition earlier.
train_dataset = format_dataset(
    load_dataset(cfg.dataset.name, split="train")
)

In [None]:
# Pick the selected training example; its 'instruction' and 'response'
# fields are used in the cells below.
data_point = train_dataset[example_index]

In [None]:
# Print the prompt
# Run the model on the example's instruction.
# NOTE(review): verbose=True presumably prints the prompt — confirm in utils.LLM.
llm_eval.eval(data_point['instruction'], verbose=True)

In [None]:
# Print the fine-tuned LLM response
# Print the fine-tuned LLM response produced by the eval call above.
llm_eval.print_response()

In [None]:
# Print the expected output from the medAlpaca dataset
# Format the example's reference response and print it for comparison with
# the model output.
ex_response = format_string(data_point['response'])
print(f"Expected output:\n\t{ex_response}")

#### 5. Visualize results

In [None]:
# Compare the result sets identified by the listed keys.
# NOTE(review): the keys look like model-size/run identifiers (pretrained,
# centralized, federated) — confirm against visualize_results in utils.
visualize_results(
    results=['7b/pretrained', '7b/cen_10', '7b/fl'])

* See the results when the centralized fine-tuned model is given the same amount of data.

In [None]:
# Same comparison with the '7b/cen_full' result set added.
# NOTE(review): compact=True presumably selects a condensed layout — confirm
# in visualize_results.
visualize_results(
    results=['7b/pretrained', '7b/cen_10',
             '7b/cen_full', '7b/fl'],
    compact=True)

#### 6. One more analysis

* Compute communication costs.

In [None]:
# Reload the "federated" configuration so this cell can run on its own.
cfg = get_config("federated")

# Compute communication costs for this configuration.
# NOTE(review): comm_bw_mbps is presumably the link bandwidth in Mbps — confirm
# in compute_communication_costs.
compute_communication_costs(cfg, comm_bw_mbps=20)