# Fine-tuning Flux Model
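This notebook demonstrates how to fine-tune a text-to-image latent diffusion model with Hugging Face Diffusers. Note: despite the "Flux" name, the code below assumes a Stable Diffusion–style checkpoint layout (`unet`, `vae`, `text_encoder`, `tokenizer` and `scheduler` subfolders, trained with a DDPM noise-prediction objective). Official FLUX.1 checkpoints use a transformer backbone (`FluxTransformer2DModel`), an additional T5 text encoder and a flow-matching scheduler, so they need a different loading and training procedure.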

In [None]:
# Install dependencies
!pip install diffusers[training] accelerate transformers datasets

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
import torchvision.transforms as T
from PIL import Image
import os

In [None]:
# Dataset definition
class ImageTextDataset(Dataset):
    """Image-folder dataset that pairs every image with one fixed caption.

    Each item is a dict with:
      - "pixel_values": the transformed image. With the default transforms
        this is a 3x512x512 float tensor normalized to [-1, 1].
      - "input_ids": the tokenized caption, padded/truncated to the
        tokenizer's ``model_max_length``, with the batch dim removed.

    Args:
        folder: Directory scanned (non-recursively) for .png/.jpg/.jpeg files.
            The extension match is case-insensitive.
        tokenizer: Hugging Face tokenizer used to encode the caption.
        transforms: Optional callable applied to each RGB PIL image. The
            default is resize to 512x512, ToTensor, then normalize to [-1, 1].
        caption: Prompt used for every image (default "pixel art sprite").
    """

    IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg")

    def __init__(self, folder, tokenizer, transforms=None, caption="pixel art sprite"):
        self.folder = folder
        # sorted(): os.listdir order is arbitrary, so sort for reproducibility.
        # lower(): accept "IMG.JPG" as well as "img.jpg".
        # isfile(): skip sub-directories whose names end in an image extension.
        self.image_paths = [
            os.path.join(folder, name)
            for name in sorted(os.listdir(folder))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(folder, name))
        ]
        self.tokenizer = tokenizer
        self.caption = caption
        self.transforms = transforms or T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
            T.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the file handle. convert() returns a
        # loaded copy, so the image remains usable after the file is closed.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        image = self.transforms(image)
        inputs = self.tokenizer(
            self.caption,
            padding="max_length",
            truncation=True,
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        )
        # squeeze(0) removes only the batch dim added by return_tensors="pt".
        return {"pixel_values": image, "input_ids": inputs.input_ids.squeeze(0)}

In [None]:
# Initialize models
# NOTE: these loaders expect a Stable Diffusion-style repository layout with
# "unet", "scheduler", "vae", "tokenizer" and "text_encoder" subfolders.
model_id = "your-flux-model-id"  # e.g., "sprited/flux-1"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)

tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)

# Only the UNet's parameters are optimized. Freeze the VAE and text encoder
# so autograd does not track their weights.
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# from_pretrained() returns models in eval mode; the UNet must be in train
# mode while it is being fine-tuned.
unet.train()

In [None]:
# Prepare data
image_folder = "path/to/your/images"  # replace with your training-image directory
dataset = ImageTextDataset(image_folder, tokenizer)
if len(dataset) == 0:
    # Without this check the training loop would run zero steps, and the save
    # cell would silently write out the unmodified pretrained weights.
    raise ValueError(f"No training images found in {image_folder!r}")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [None]:
# Training loop
optimizer = torch.optim.AdamW(unet.parameters(), lr=1e-5)

# Values read once from the model/scheduler configs instead of hard-coded:
# - scaling_factor: the latent scale the VAE was trained with (0.18215 for the
#   Stable Diffusion 1.x/2.x VAE, but other VAEs use different values).
# - num_train_timesteps / prediction_type: read via .config, because direct
#   attribute access on a diffusers scheduler is deprecated.
latent_scale = vae.config.scaling_factor
num_train_timesteps = scheduler.config.num_train_timesteps
prediction_type = scheduler.config.prediction_type

unet.train()
num_epochs = 3
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    for step, batch in enumerate(dataloader):
        optimizer.zero_grad()
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)

        # The text encoder and VAE are not optimized, so run them without
        # autograd to avoid storing activations for the backward pass.
        with torch.no_grad():
            text_embeddings = text_encoder(input_ids)[0]
            latents = vae.encode(pixel_values).latent_dist.sample() * latent_scale

        noise = torch.randn_like(latents)
        timesteps = torch.randint(
            0, num_train_timesteps, (latents.shape[0],), device=device
        ).long()
        noisy_latents = scheduler.add_noise(latents, noise, timesteps)

        # The regression target depends on what the model was trained to predict.
        if prediction_type == "epsilon":
            target = noise
        elif prediction_type == "v_prediction":
            target = scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f"Unsupported prediction_type: {prediction_type!r}")

        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=text_embeddings).sample
        # Compute the loss in float32 so it stays accurate even if the models
        # are later run in half precision.
        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float())
        loss.backward()
        optimizer.step()
        if step % 10 == 0:
            print(f"  Step {step}, Loss: {loss.item():.4f}")

In [None]:
# Persist every pipeline component into its own subfolder of output_dir,
# mirroring the layout the models were loaded from.
output_dir = "./flux-finetuned"
components_to_save = {
    "unet": unet,
    "scheduler": scheduler,
    "vae": vae,
    "tokenizer": tokenizer,
    "text_encoder": text_encoder,
}
for subfolder_name, component in components_to_save.items():
    component.save_pretrained(f"{output_dir}/{subfolder_name}")
print("Model saved to", output_dir)