In [None]:
%%writefile /content/drive/MyDrive/Github/Product-image-generation-from-text-description/code/config.py
from easydict import EasyDict as edict
from accelerate import Accelerator
from tqdm import tqdm 

# Hyper-parameters and paths shared by the training / evaluation code.
args = edict()

# --- Precision and memory ---------------------------------------------------
args.gradient_accumulation_steps = 2
args.mixed_precision = "fp16"
args.gradient_checkpointing = True

# Built once at import time so every importer shares the same Accelerator.
accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )

# --- Pretrained model -------------------------------------------------------
args.revision = "fp16"
args.pretrained_model_name_or_path = 'CompVis/stable-diffusion-v1-4'

# --- Optimisation -----------------------------------------------------------
args.use_8bit_adam = True
args.train_batch_size = 8
args.max_train_steps = None  # None -> derived from num_train_epochs inside train()
args.num_train_epochs = 10
args.train_text_encoder = False
args.set_grads_to_none = False
args.seed = None
# When True, train() multiplies learning_rate by
# gradient_accumulation_steps * train_batch_size * number of processes.
args.scale_lr = False
args.learning_rate = 1e-6
args.adam_beta1 = 0.9
args.adam_beta2 = 0.999
args.adam_weight_decay = 1e-2
args.adam_epsilon = 1e-08
args.max_grad_norm = 1.0
args.lr_scheduler = 'constant'
args.lr_warmup_steps = 500
args.lr_num_cycles = 1
args.lr_power = 1

# --- Inference / validation -------------------------------------------------
# The image size depends on the dataset, which does not exist when this module
# is imported (referencing `test_dataloader` here raised NameError). The
# notebook fills these in once the dataloaders have been built.
args.height, args.width = None, None
args.num_inference_steps = 50
args.enable_xformers_memory_efficient_attention = False
# Both intervals are expressed in epochs: train() multiplies them by the
# number of optimizer steps per epoch.
args.validation_steps = 1
args.checkpointing_steps = 7 #args.num_train_epochs // 2 + 1

# --- Output and checkpoints -------------------------------------------------
args.output_dir = '/kaggle/working/'
args.resume_from_checkpoint = True
args.checkpoint_path = '/kaggle/input/fashion-data/checkpoint-10545/checkpoint-10545'

Overwriting /content/drive/MyDrive/Github/Product-image-generation-from-text-description/code/config.py


In [None]:
%%writefile /content/drive/MyDrive/Github/Product-image-generation-from-text-description/code/train_eval.py

import itertools
import math
import os

import bitsandbytes as bnb
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from accelerate.utils import set_seed
from diffusers import (
    UNet2DConditionModel, 
    LMSDiscreteScheduler, 
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    DiffusionPipeline,
    AutoencoderKL
)
from diffusers.optimization import get_scheduler
from tqdm import tqdm
from transformers import CLIPTextModel, CLIPTokenizer

from code.config import accelerator

def eval_step(unet, text_encoder, tokenizer, vae, accelerator, dataloader, logger, epoch, args, weight_dtype):
    """Generate images for the first validation prompts and log them beside the real images.

    Args:
        unet: UNet being trained; unwrapped with ``accelerator.unwrap_model`` for inference.
        text_encoder: Text encoder; unwrapped the same way.
        tokenizer: Tokenizer handed to the inference pipeline.
        vae: Autoencoder handed to the inference pipeline.
        accelerator: Accelerator that provides the target device.
        dataloader: Loader whose ``dataset`` is a ``Subset`` (it must expose ``indices``)
            over a dataset with a ``descriptions`` DataFrame containing a
            ``'description'`` column.
        logger: Object with a ``log(dict, step=...)`` method (the notebook passes ``wandb``).
        epoch: Value used as the logging step.
        args: Config namespace; reads ``pretrained_model_name_or_path``, ``revision``,
            ``seed``, ``num_inference_steps``, ``width`` and ``height``.
        weight_dtype: ``torch_dtype`` forwarded to ``DiffusionPipeline.from_pretrained``.
    """
    subset = dataloader.dataset
    indices = subset.indices
    # Never ask for more samples than the validation subset holds.
    n = min(10, len(indices))
    labels = [subset.dataset.descriptions.iloc[indices[i]]['description'] for i in range(n)]

    # Undo the (x - 0.5) / 0.5 normalisation and convert C,H,W tensors to H,W,C uint8
    # arrays. wandb.Image accepts numpy arrays directly, so no PIL round-trip is needed.
    true_images = []
    for i in range(n):
        image = subset[i][1].float().permute(1, 2, 0)
        image = (image / 2 + 0.5).clamp(0, 1).numpy()
        true_images.append((image * 255).astype(np.uint8))

    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path,
        text_encoder=accelerator.unwrap_model(text_encoder),
        tokenizer=tokenizer,
        unet=accelerator.unwrap_model(unet),
        vae=vae,
        # Skip the safety checker entirely instead of loading it and then
        # replacing it with a stub.
        safety_checker=None,
        revision=args.revision,
        torch_dtype=weight_dtype,
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)

    # run inference
    generator = None if args.seed is None else torch.Generator(device=accelerator.device).manual_seed(args.seed)
    images = []
    for i in tqdm(range(n)):
        # Autocast on whatever device the accelerator selected rather than assuming CUDA.
        with torch.autocast(accelerator.device.type):
            image = pipeline(labels[i], num_inference_steps=args.num_inference_steps,
                             generator=generator, width=args.width,
                             height=args.height).images[0]
        images.append(image)

    logger.log({"true_images": [wandb.Image(image, caption=labels[i]) for i, image in enumerate(true_images)],
                "pred_images": [wandb.Image(image, caption=labels[i]) for i, image in enumerate(images)]},
               step=epoch)

    # Release the pipeline's memory before training resumes.
    del pipeline
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def train(args, train_dataloader, val_dataloader, tokenizer=None):
    """Fine-tune the Stable Diffusion UNet (and optionally the text encoder).

    Uses the module-level ``accelerator`` for device placement, mixed precision,
    gradient accumulation and checkpointing.

    Args:
        args: Config namespace. ``learning_rate`` (when ``scale_lr`` is set),
            ``max_train_steps``, ``num_train_epochs`` and ``resume_from_checkpoint``
            may be rewritten in place. ``args.logger`` must provide
            ``log(dict, step=...)``.
        train_dataloader: Yields ``(text, images)`` batches where
            ``text["input_ids"]`` holds the token ids.
        val_dataloader: Loader passed to ``eval_step`` for validation images.
        tokenizer: Optional tokenizer for validation; loaded from
            ``args.pretrained_model_name_or_path`` when omitted.

    Returns:
        dict with keys ``"unet"``, ``"text_encoder"``, ``"vae"``, ``"tokenizer"``
        and ``"global_step"`` so the caller can save or export the result.

    Raises:
        ValueError: If a trainable model is not loaded in float32, or the noise
            scheduler reports an unknown prediction type.
    """
    if args.seed is not None:
        set_seed(args.seed)

    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae",
                                        revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet",
                                                revision=args.revision)
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path,
                                                 subfolder="text_encoder",
                                                 revision=args.revision)
    if tokenizer is None:
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path,
                                                  subfolder="tokenizer",
                                                  revision=args.revision)

    noise_scheduler = DDPMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # The VAE is never trained; the text encoder is frozen unless explicitly requested.
    vae.requires_grad_(False)
    if not args.train_text_encoder:
        text_encoder.requires_grad_(False)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        if args.train_text_encoder:
            text_encoder.gradient_checkpointing_enable()

    # Check that all trainable models are in full precision
    low_precision_error_string = (
        "Please make sure to always have all model weights in full float32 precision when starting training - even if"
        " doing mixed precision training. copy of the weights should still be float32."
    )

    if accelerator.unwrap_model(unet).dtype != torch.float32:
        raise ValueError(
            f"Unet loaded as datatype {accelerator.unwrap_model(unet).dtype}. {low_precision_error_string}"
        )

    if args.train_text_encoder and accelerator.unwrap_model(text_encoder).dtype != torch.float32:
        raise ValueError(
            f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
            f" {low_precision_error_string}"
        )

    # Scale the learning rate BEFORE the optimizer is built; scaling it afterwards
    # only changes args and never reaches the optimizer.
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    params_to_optimize = (
        itertools.chain(unet.parameters(), text_encoder.parameters())
        if args.train_text_encoder
        else unet.parameters()
    )
    # Defaults to the 8-bit optimizer (the previous hard-coded choice) when the flag is absent.
    optimizer_class = bnb.optim.AdamW8bit if getattr(args, "use_8bit_adam", True) else torch.optim.AdamW
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
        )

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move vae and text_encoder to device and cast to weight_dtype
    vae.to(accelerator.device, dtype=weight_dtype)
    if not args.train_text_encoder:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)

    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    global_step = 0
    first_epoch = 0
    resume_step = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        checkpoint_path = getattr(args, "checkpoint_path", None)

        # Test the directory itself: os.path.basename() never returns None, so the
        # old "missing checkpoint" branch could not trigger.
        if not checkpoint_path or not os.path.isdir(checkpoint_path):
            accelerator.print(
                f"Checkpoint '{checkpoint_path}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            # normpath drops a trailing separator so basename yields "checkpoint-<step>".
            path = os.path.basename(os.path.normpath(checkpoint_path))
            accelerator.print(f"Resuming from checkpoint {path}")
            accelerator.load_state(checkpoint_path)
            global_step = int(path.split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder.train()

        epoch_loss = 0.0
        num_batches = 0
        for step, batch in enumerate(train_dataloader):
            # Skip the batches already consumed before the checkpoint was written.
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                text, images = batch
                # Convert images to latent space
                latents = vae.encode(images.to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                noise = torch.randn_like(latents)

                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(text["input_ids"].squeeze(1))[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
                        itertools.chain(unet.parameters(), text_encoder.parameters())
                        if args.train_text_encoder
                        else unet.parameters()
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=args.set_grads_to_none)

            # sync_gradients is True only on the micro-batch that completed an optimizer step.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if accelerator.is_main_process:
                    # checkpointing_steps / validation_steps are expressed in epochs.
                    if global_step % (args.checkpointing_steps * num_update_steps_per_epoch) == 0:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        print(f"Saved state to {save_path}")

                    if global_step % (args.validation_steps * num_update_steps_per_epoch) == 0:
                        eval_step(unet, text_encoder, tokenizer, vae, accelerator, val_dataloader,
                                  args.logger, epoch, args, weight_dtype)

            step_loss = loss.detach().item()
            logs = {"train_loss": step_loss, "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            epoch_loss += step_loss
            num_batches += 1
            if global_step >= args.max_train_steps:
                break

        # The loss is accumulated once per micro-batch, so average over the batches
        # actually processed rather than over the number of optimizer steps.
        mean_epoch_loss = epoch_loss / max(num_batches, 1)
        args.logger.log({"train_loss": mean_epoch_loss}, step=epoch)
        args.logger.log({"lr":lr_scheduler.get_last_lr()[0]}, step=epoch)
        print(f"Epoch: {epoch}, loss: {mean_epoch_loss}")

    # Let every process finish, then hand the trained modules back so the caller
    # can build and save the final pipeline.
    accelerator.wait_for_everyone()

    return {
        "unet": unet,
        "text_encoder": text_encoder,
        "vae": vae,
        "tokenizer": tokenizer,
        "global_step": global_step,
    }

Writing /content/drive/MyDrive/Github/Product-image-generation-from-text-description/code/train_eval.py


In [None]:
# Environment switch read by the following cells: True runs the Colab branches
# (Drive mount, token file), False runs the Kaggle branches.
use_colab = True

if use_colab:
    # Imported lazily because google.colab is only needed on the Colab branch.
    from google.colab import drive
    drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
if not use_colab:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    key = user_secrets.get_secret("wandb_api")
else:
    with open('wandb_token.txt') as f:
        key = f.read()
        
!pip install --upgrade wandb
import wandb
wandb.login(key=key)
run = wandb.init(project='text-to-image',
                    group='finetune', #resume='must',
                    job_type='train')

In [None]:
import torch
import pandas as pd
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from transformers import CLIPTokenizer
from sklearn.model_selection import train_test_split

from code.CustomDataset import CustomTensorDataset
from code.config import args

In [None]:
!pip install -qq -U diffusers transformers accelerate

In [None]:
!pip install -q bitsandbytes

[0m

In [None]:
import os

# NOTE(review): this re-assigns the flag set to True in the Drive-mount cell, so the
# cells from here on take the Kaggle branches — confirm that this is intended.
use_colab = False

# Root folder that holds the description file for the selected environment.
path = (
    '/content/drive/MyDrive/Github/Product-image-generation-from-text-description'
    if use_colab
    else '/kaggle/input/fashion-data'
)

path_to_descriptions = os.path.join(path, 'descriptions_2.json')
descriptions = pd.read_json(path_to_descriptions, orient='records')
# Append a fixed suffix to every prompt.
descriptions['description'] = descriptions['description'].apply(
    lambda text: text + ' isolated on white background'
)

In [None]:
from transformers import CLIPTokenizer

# Load the tokenizer from the same model and revision as configured in args.
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path,
    subfolder="tokenizer",
    revision=args.revision,
)

Downloading (…)tokenizer/vocab.json:   0%|          | 0.00/1.06M [00:00<?, ?B/s]

Downloading (…)tokenizer/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/472 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/788 [00:00<?, ?B/s]

In [None]:
# Side length (pixels) every image is resized to.
RESOLUTION = 256

# Resize to a square, convert to a tensor, then normalise with mean 0.5 / std 0.5.
data_transformation_images = transforms.Compose(
    [
        transforms.Resize((RESOLUTION, RESOLUTION)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, )),
    ]
)

# NOTE(review): the Colab path has no leading slash, so it is resolved against the
# current working directory — confirm this is intended.
im_path = (
    'content/fashion-dataset/images'
    if use_colab
    else '/kaggle/input/fashion-product-images-dataset/fashion-dataset/images'
)

dataset = CustomTensorDataset(
    descriptions,
    tokenizer,
    im_path,
    transform_images=data_transformation_images,
)

In [None]:
# Seed the split so the same rows land in train and test on every run. An unseeded
# split changes between sessions, so a run resumed from a checkpoint could validate
# on rows it had already trained on. args.seed wins when it is configured.
SPLIT_SEED = 42 if args.seed is None else args.seed

indices = np.arange(len(descriptions))
indices_train, indices_test = train_test_split(indices, test_size=0.2, random_state=SPLIT_SEED)

In [None]:
train_dataset = Subset(dataset, indices_train)
test_dataset = Subset(dataset, indices_test)

# Take the batch size from the shared config so the loaders agree with the
# batch-size arithmetic done from args.train_batch_size during training.
batch_size = args.train_batch_size
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Each dataset item is indexed as item[1] for the image tensor; its last two
# dimensions give the height and width used for validation sampling.
args.height, args.width = test_dataloader.dataset[0][1].shape[1:3]
args.logger = wandb

In [None]:
train(args)

In [None]:
from diffusers import DiffusionPipeline

save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
print(f"Saved state to {save_path}")

pipeline = DiffusionPipeline.from_pretrained(
    args.pretrained_model_name_or_path,
    unet=accelerator.unwrap_model(unet),
    text_encoder=accelerator.unwrap_model(text_encoder),
    revision=args.revision
)
pipeline.save_pretrained('/kaggle/working/data/')