# Fine-tune FLUX.1-schnell with SageMaker Distributed Data Parallel (SMDDP)

---

In this demo notebook, we demonstrate how to fine-tune the FLUX.1-schnell model using Hugging Face PEFT (LoRA) and bitsandbytes, together with the SageMaker Distributed Data Parallel (SMDDP) library.

Fine-Tuning:
* Instance Type: ml.p4d.24xlarge
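LoRA keeps the pretrained weights frozen and trains only a low-rank update `B @ A` for selected layers, which is why the resulting adapter is small. The sketch below counts the trainable parameters LoRA adds to a single linear layer. The 3072 x 3072 layer size is an illustrative assumption, not a statement about which FLUX layers the training script targets.

```python
def lora_param_count(d_in: int, d_out: int, rank: int) -> int:
    """Trainable parameters LoRA adds to one linear layer W of shape (d_out, d_in).

    LoRA freezes W and learns a low-rank update B @ A, where
    A has shape (rank, d_in) and B has shape (d_out, rank).
    """
    return rank * d_in + d_out * rank


def full_param_count(d_in: int, d_out: int) -> int:
    """Parameters in the dense weight matrix W (bias ignored)."""
    return d_in * d_out


# Illustrative layer size only (assumption): a square 3072 x 3072 projection.
d = 3072
r = 16  # same value as the "rank" hyperparameter used later in this notebook
lora = lora_param_count(d, d, r)
full = full_param_count(d, d)
print(lora, full, f"{lora / full:.2%}")
```

For a square layer the ratio is `2 * rank / d`, so doubling `rank` doubles the adapter size for every adapted layer.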

Install the required libraries, including the Hugging Face libraries, and restart the kernel.

In [None]:
%pip install -r requirements.txt

If you are going to use SageMaker in a local environment, you need access to an IAM role with the required permissions for SageMaker. You can find more about it [here](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html).
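A role that SageMaker can use needs a trust policy allowing the `sagemaker.amazonaws.com` service principal to assume it. The sketch below builds that trust policy document. The role name `sagemaker_execution_role` matches the fallback used in the next cell; attaching the broad `AmazonSageMakerFullAccess` managed policy is an assumption suited to a demo, and the AWS calls are left commented out because they need credentials.

```python
import json


def sagemaker_trust_policy() -> dict:
    """Trust policy that lets the SageMaker service assume an IAM role."""
    return {
        "Version": "2012-10-17",
        "Statement": [
            {
                "Effect": "Allow",
                "Principal": {"Service": "sagemaker.amazonaws.com"},
                "Action": "sts:AssumeRole",
            }
        ],
    }


policy_json = json.dumps(sagemaker_trust_policy(), indent=2)
print(policy_json)

# Creating the role requires AWS credentials, so these calls are commented out:
# import boto3
# iam = boto3.client("iam")
# iam.create_role(RoleName="sagemaker_execution_role",
#                 AssumeRolePolicyDocument=policy_json)
# iam.attach_role_policy(RoleName="sagemaker_execution_role",
#                        PolicyArn="arn:aws:iam::aws:policy/AmazonSageMakerFullAccess")
```

The trust policy only controls who may assume the role; what the role may do (for example, reading and writing the session bucket) comes from the permission policies attached to it.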



In [None]:
import boto3
import sagemaker

In [None]:
sess = sagemaker.Session()
sagemaker_session_bucket=None

if sagemaker_session_bucket is None and sess is not None:
    # set to default bucket if a bucket name is not given
    sagemaker_session_bucket = sess.default_bucket()

try:
    role = sagemaker.get_execution_role()
except ValueError:
    iam = boto3.client('iam')
    role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']

sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker session region: {sess.boto_region_name}")

## Load and inspect the dataset

Load the training dataset from the Hugging Face Hub.

In [None]:
dataset_name = "diffusers/tuxemon"

In [None]:
from datasets import load_dataset

dataset = load_dataset(
    dataset_name
)

dataset
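To get a feel for the captions the model will be trained on, the helper below prints a truncated preview of the first few rows. It works on any iterable of dict-like rows; iterating a `datasets.Dataset` yields one dict per row. The function name is hypothetical, and the usage assumes the dataset exposes a `train` split containing the `gpt4_turbo_caption` column that the hyperparameters reference later.

```python
from typing import Iterable, List, Mapping


def preview_captions(rows: Iterable[Mapping], column: str = "gpt4_turbo_caption",
                     n: int = 3, width: int = 80) -> List[str]:
    """Return the first `n` captions from `column`, truncated to `width` characters."""
    previews = []
    for i, row in enumerate(rows):
        if i >= n:
            break  # stop early instead of iterating the whole dataset
        caption = str(row[column])
        if len(caption) > width:
            caption = caption[: width - 3] + "..."
        previews.append(caption)
    return previews


# Usage in this notebook (assumes a "train" split):
# print(dataset["train"].column_names)
# for caption in preview_captions(dataset["train"]):
#     print(caption)
```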

## Fine-tune FLUX.1-schnell on Amazon SageMaker

We are now ready to fine-tune our model. The training script is located in [./scripts/train.py](./scripts/train.py).

We use SageMaker Distributed Data Parallel (SMDDP) with `AllGather` as the collective operation to shard the model across all available GPUs, and `torchrun` as the script launcher to distribute training across the 8 GPUs of the `ml.p4d.24xlarge` instance.

For more information about SageMaker Distributed Data Parallel, see the official [AWS documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html).
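In sharded data parallelism (for example PyTorch FSDP), each GPU rank stores only a slice of the parameters and runs an `AllGather` to materialize the full weights when they are needed; SMDDP provides an optimized implementation of that collective. The pure-Python simulation below is a conceptual sketch of what `AllGather` computes. It does not use SMDDP or `torch.distributed`, and the helper names are hypothetical.

```python
from typing import List


def shard(params: List[float], world_size: int) -> List[List[float]]:
    """Split a flat parameter list into `world_size` equal shards, padding with 0.0."""
    shard_size = -(-len(params) // world_size)  # ceiling division
    padded = params + [0.0] * (shard_size * world_size - len(params))
    return [padded[r * shard_size:(r + 1) * shard_size] for r in range(world_size)]


def all_gather(shards: List[List[float]]) -> List[List[float]]:
    """Simulate AllGather: every rank receives the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]  # one full copy per rank


params = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]
world_size = 4  # an ml.p4d.24xlarge has 8 GPUs; 4 keeps the printout short
shards = shard(params, world_size)
gathered = all_gather(shards)
print(shards)
print(gathered[0][:len(params)])  # each rank strips the padding to recover the weights
```

In a real training script, the AWS documentation describes opting into SMDDP by importing `smdistributed.dataparallel.torch.torch_smddp` and calling `torch.distributed.init_process_group(backend="smddp")`; the contents of `./scripts/train.py` are not shown in this notebook.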

In [None]:
model_id = "black-forest-labs/FLUX.1-schnell"
dataset_name = "diffusers/tuxemon"

The following hyperparameters are passed to the training script:

In [None]:
hyperparameters = {
    "pretrained_model_name_or_path": model_id,
    "dataset_name": dataset_name,
    "output_dir": "/opt/ml/checkpoint",
    "mixed_precision": "bf16",
    "instance_prompt": "describe",
    "resolution": 1024,
    "train_batch_size": 2,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "use_8bit_adam": True,
    "learning_rate": 1e-5,
    "lr_scheduler": "constant",
    "lr_warmup_steps": 0,
    "seed": 42,
    "rank": 16,
    "train_text_encoder": True,
    "max_sequence_length": 512,
    "max_train_steps": 500,
    "caption_column": "gpt4_turbo_caption"
}
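With data parallelism, the batch size that each optimizer step sees is the per-device batch multiplied by the gradient accumulation steps and the number of GPU processes. The sketch below does that arithmetic for the values above, assuming `torchrun` starts one process per GPU (8 on `ml.p4d.24xlarge`) and that the script treats `train_batch_size` as a per-device value, as the diffusers DreamBooth LoRA scripts do. The helper name is hypothetical.

```python
def training_budget(per_device_batch: int, grad_accum_steps: int,
                    num_gpus: int, max_steps: int) -> dict:
    """Global batch size and total samples processed in data-parallel training."""
    global_batch = per_device_batch * grad_accum_steps * num_gpus
    return {"global_batch_size": global_batch, "samples_seen": global_batch * max_steps}


budget = training_budget(
    per_device_batch=2,   # "train_batch_size"
    grad_accum_steps=1,   # "gradient_accumulation_steps"
    num_gpus=8,           # GPUs on one ml.p4d.24xlarge (assumed one process each)
    max_steps=500,        # "max_train_steps"
)
print(budget)
```

Dividing `samples_seen` by `len(dataset["train"])` gives the approximate number of passes over the dataset, which helps judge whether `max_train_steps` is too high for a small dataset.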

The estimator below trains the model with LoRA and saves the adapter to S3.

In [None]:
from sagemaker.pytorch import PyTorch

job_name = f"train-{model_id.split('/')[-1].replace('.', '-')}"

# Create SageMaker PyTorch Estimator

pytorch_estimator = PyTorch(
    entry_point="train.py",
    source_dir="./scripts",
    base_job_name=job_name,
    role=role,
    framework_version="2.2.0",
    py_version="py310",
    instance_count=1,
    instance_type="ml.p4d.24xlarge",
    sagemaker_session=sess,
    #disable_output_compression=True, # Avoid compression in .tar.gz
    keep_alive_period_in_seconds=1800,
    distribution={"torch_distributed": {"enabled": True}},
    hyperparameters=hyperparameters
)
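The `.replace('.', '-')` in `job_name` matters because SageMaker training job names may contain only alphanumerics and hyphens (1 to 63 characters, starting with an alphanumeric and not ending with a hyphen); the SDK then appends a timestamp to `base_job_name`. The sketch below reproduces the notebook's name derivation in a hypothetical helper and checks it against that naming pattern.

```python
import re

# CreateTrainingJob name pattern: alphanumerics and hyphens, 1-63 characters.
JOB_NAME_PATTERN = re.compile(r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$")


def base_job_name_for(model_id: str) -> str:
    """Derive a base job name from a Hub model ID, as the cell above does."""
    return f"train-{model_id.split('/')[-1].replace('.', '-')}"


name = base_job_name_for("black-forest-labs/FLUX.1-schnell")
print(name, bool(JOB_NAME_PATTERN.match(name)))
# Without the replace, "train-FLUX.1-schnell" would be rejected because of the dot.
```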

In [None]:
# Start the training job. No input channels are passed to fit(): the dataset is
# referenced by name through the `dataset_name` hyperparameter.
pytorch_estimator.fit(wait=True)