In [None]:
# Mount Google Drive: the repo directory used below (train_eval.py, descriptions.json) lives there.
from google.colab import drive

drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%writefile /content/drive/MyDrive/Github/Product-image-generation-from-text-description/train_eval.py
from tqdm.auto import tqdm
import os
import torch
import math
from PIL import Image
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore


def train_step(vae, unet, text_encoder, noise_scheduler, dataloader, criterion,
               optimizer, device, accelerator, scaler):
    """Run one epoch of noise-prediction training for ``unet``.

    The VAE and text encoder are only run forward (under ``torch.no_grad``);
    only ``unet`` is switched to train mode. Gradients are accumulated over
    ``accelerator.gradient_accumulation_steps`` batches (2 if the accelerator
    does not expose that attribute) before each optimizer step. The last,
    possibly incomplete, group is also stepped. Each step uses AMP loss scaling
    through ``scaler`` and clips the gradient norm to 1.0.

    Args:
        vae: autoencoder; ``vae.encode(x).latent_dist.sample()`` yields latents.
        unet: conditional UNet being trained.
        text_encoder: encoder whose first output is used as ``encoder_hidden_states``.
        noise_scheduler: provides ``add_noise`` and ``config.num_train_timesteps``.
        dataloader: yields ``(text, images)``; ``text["input_ids"]`` has a
            singleton dimension 1 that is squeezed away.
        criterion: loss callable accepting ``reduction="mean"``.
        optimizer: optimizer over the UNet parameters.
        device: device the batch tensors are moved to.
        accelerator: object optionally exposing ``gradient_accumulation_steps``.
        scaler: a ``torch.cuda.amp.GradScaler`` (or compatible object).

    Returns:
        float: mean (un-divided) per-batch loss over the epoch.
    """
    unet.train()

    accumulation_steps = max(1, int(getattr(accelerator, "gradient_accumulation_steps", 2) or 2))
    num_batches = len(dataloader)
    epoch_loss = 0.0

    # Clear gradients once up front and then only after each optimizer step;
    # zeroing every iteration would throw away the accumulated gradients.
    optimizer.zero_grad()
    for idx, batch_data in enumerate(tqdm(dataloader, total=num_batches)):
        text, images = batch_data

        # The encoders are frozen, so no autograd graph is needed for them.
        with torch.no_grad():
            text_embeddings = text_encoder(text["input_ids"].to(device).squeeze(1))[0]
            latents = vae.encode(images.to(device, dtype=torch.float16)).latent_dist.sample()
            latents = latents * vae.config.scaling_factor

        batch_size = latents.shape[0]
        noise = torch.randn_like(latents)
        # One random diffusion timestep per sample.
        t = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                          (batch_size,), device=latents.device).long()
        noisy_latents = noise_scheduler.add_noise(latents, noise, t)

        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
            noise_pred = unet(noisy_latents, t, encoder_hidden_states=text_embeddings).sample
        batch_loss = criterion(noise_pred.float(), noise.float(), reduction="mean")

        # Divide so the accumulated gradient is the mean over the accumulation group.
        scaler.scale(batch_loss / accumulation_steps).backward()

        if (idx + 1) % accumulation_steps == 0 or (idx + 1) == num_batches:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        epoch_loss += batch_loss.item()

    return epoch_loss / max(num_batches, 1)

def eval_step(vae, unet, text_encoder, noise_scheduler, dataloader,
              device, height, width, num_inference_steps, logger,
              num_images_to_log=10):
    """Sample images for the first validation batches, log them and score them.

    Only the first ``ceil(num_images_to_log / dataloader.batch_size)`` batches
    are processed, since every batch costs ``num_inference_steps`` UNet passes.
    For each processed batch, the real and generated images are logged through
    ``logger.log``, captioned with their descriptions. They are also added to
    the FID and Inception Score accumulators. Both metrics are computed once at
    the end over all processed images.

    Args:
        vae, unet, text_encoder, noise_scheduler: model components used for sampling.
        dataloader: validation loader over a ``torch.utils.data.Subset`` whose
            underlying dataset exposes a ``descriptions`` DataFrame (used for
            captions). It must not shuffle, because captions are recovered from
            batch positions.
        device: device used for sampling and for the metric modules.
        height, width: output image size in pixels (latents are ``// 8`` of it).
        num_inference_steps: number of scheduler steps per sample.
        logger: object with a ``log(dict)`` method (the ``wandb`` module in this project).
        num_images_to_log: approximately how many validation images to generate.

    Returns:
        tuple[float, float]: ``(fid, inception_score_mean)``. Lower FID and higher
        Inception Score are better. ``(nan, nan)`` if the loader yields no batch.
    """
    # Not imported at module level in this file.
    import numpy as np
    import wandb

    unet.eval()
    vae.eval()
    text_encoder.eval()
    noise_scheduler.set_timesteps(num_inference_steps, device=device)

    loader_batch_size = dataloader.batch_size or 1
    max_batches = max(1, math.ceil(num_images_to_log / loader_batch_size))

    # normalize=True lets the metrics take float images in [0, 1] (default expects uint8).
    fid = FrechetInceptionDistance(feature=64, normalize=True).to(device)
    inception_score = InceptionScore(feature=64, normalize=True).to(device)

    subset_indices = dataloader.dataset.indices
    descriptions = dataloader.dataset.dataset.descriptions

    num_processed = 0
    with torch.no_grad():
        for idx, (text, images) in enumerate(dataloader):
            if idx >= max_batches:
                break
            images = images.to(device)
            batch_size = images.shape[0]
            text_embeddings = text_encoder(text["input_ids"].squeeze(1).to(device))[0]

            # Start from pure noise with the scale the scheduler expects.
            latents = torch.randn((batch_size, unet.config.in_channels, height // 8, width // 8),
                                  device=device)
            latents = latents * noise_scheduler.init_noise_sigma

            for t in noise_scheduler.timesteps:
                latent_model_input = noise_scheduler.scale_model_input(latents, t)
                with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
                    noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
                # x_t -> x_{t-1}
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            latents = latents / vae.config.scaling_factor
            with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
                pred_images = vae.decode(latents).sample

            # Map from the [-1, 1] convention to [0, 1] for the metrics and for logging.
            true_01 = (images.float() / 2 + 0.5).clamp(0, 1)
            pred_01 = (pred_images.float() / 2 + 0.5).clamp(0, 1)

            fid.update(true_01, real=True)
            fid.update(pred_01, real=False)
            inception_score.update(pred_01)
            num_processed += 1

            true_arrays = (true_01.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)
            pred_arrays = (pred_01.cpu().permute(0, 2, 3, 1).numpy() * 255).astype(np.uint8)

            # Position of this batch inside the (unshuffled) Subset -> description row.
            start = idx * loader_batch_size
            labels = [descriptions.iloc[subset_indices[start + i]]['description']
                      for i in range(batch_size)]
            logger.log({
                "true_images": [wandb.Image(Image.fromarray(arr), caption=labels[i])
                                for i, arr in enumerate(true_arrays)],
                "pred_images": [wandb.Image(Image.fromarray(arr), caption=labels[i])
                                for i, arr in enumerate(pred_arrays)],
            })

        if num_processed == 0:
            return float("nan"), float("nan")
        return fid.compute().item(), inception_score.compute()[0].item()

def train(vae, unet, text_encoder, noise_scheduler, num_epochs, train_loader,
          val_loader, criterion, optimizer, save_path, logger, device, args, inf_freq=None):
    """Fine-tune ``unet`` and checkpoint whenever the validation FID improves.

    Every epoch runs ``train_step`` and logs ``train_loss``. If ``inf_freq`` is a
    positive integer, then every ``inf_freq`` epochs:

    - ``eval_step`` samples validation images (50 scheduler steps).
    - ``val_fid`` and ``val_inception_score`` are logged.
    - When the FID is the lowest seen so far, the UNet, VAE and text-encoder
      state dicts are saved to ``save_path/diffusion_model_<fid>.pt``.

    ``inf_freq=None`` (or 0) disables validation.

    Metrics are logged with an ``epoch`` field rather than ``step=epoch``.
    ``eval_step`` also calls ``logger.log``, which advances W&B's internal step.
    A later ``step=epoch`` would then be behind the current step, and W&B
    ignores such writes.

    Args:
        vae, unet, text_encoder, noise_scheduler: model components.
        num_epochs: number of passes over ``train_loader``.
        train_loader, val_loader: training / validation loaders.
        criterion: loss callable passed to ``train_step``.
        optimizer: optimizer over the UNet parameters.
        save_path: directory for checkpoints (created if missing).
        logger: object with a ``log(dict)`` method (e.g. the ``wandb`` module).
        device: device the models are moved to.
        args: config with ``accelerator`` and ``gradient_checkpointing`` attributes.
        inf_freq: validate every ``inf_freq`` epochs; falsy disables validation.

    Returns:
        None.
    """
    # Assumes validation samples are (C, H, W) image tensors, e.g. from ToTensor.
    im_height, im_width = val_loader.dataset[0][1].shape[1:3]
    accelerator = args.accelerator

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Frozen encoders run in fp16; the trained UNet keeps fp32 weights and is
    # driven through autocast + GradScaler inside train_step.
    text_encoder.to(device, dtype=torch.float16)
    vae.to(device, dtype=torch.float16)
    unet.to(device)

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    os.makedirs(save_path, exist_ok=True)

    best_fid = math.inf
    for epoch in tqdm(range(num_epochs)):
        train_loss = train_step(vae, unet, text_encoder, noise_scheduler,
                                train_loader, criterion, optimizer, device, accelerator, scaler)
        logger.log({"epoch": epoch, "train_loss": train_loss})

        if not inf_freq or (epoch + 1) % inf_freq != 0:
            continue

        val_fid, val_inception = eval_step(vae, unet, text_encoder, noise_scheduler,
                                           val_loader, device, im_height, im_width, 50, logger)
        logger.log({"epoch": epoch, "val_fid": val_fid, "val_inception_score": val_inception})

        # Lower FID is better; a NaN FID compares False and never replaces a checkpoint.
        if val_fid < best_fid:
            torch.save({
                'epoch': epoch,
                'unet_state_dict': unet.state_dict(),
                'vae_state_dict': vae.state_dict(),
                'text_enc_state_dict': text_encoder.state_dict()
                }, os.path.join(save_path, f"diffusion_model_{round(val_fid, 2)}.pt"))
            best_fid = val_fid

def generate_images(text_prompts, vae, unet, noise_scheduler, text_encoder, tokenizer,
                    im_height=512, im_width=512, num_inference_steps=70, device=None):
    """Generate one image per text prompt by iterative denoising.

    Args:
        text_prompts: iterable of prompt strings.
        vae: autoencoder used to decode latents (``vae.decode(z).sample``).
        unet: conditional UNet predicting noise.
        noise_scheduler: scheduler providing ``set_timesteps``, ``timesteps``,
            ``scale_model_input``, ``step`` and ``init_noise_sigma``.
        text_encoder: encoder whose first output conditions the UNet.
        tokenizer: tokenizer with ``model_max_length``, called per prompt.
        im_height, im_width: output size in pixels (latents are ``// 8``).
        num_inference_steps: number of denoising steps (previously fixed at 70).
        device: device to sample on. Defaults to the device of ``unet``'s
            parameters, because this module has no global ``device``.

    Returns:
        numpy.ndarray: float32 array of shape (num_prompts, H, W, C) with values in [0, 1].
    """
    unet_param = next(unet.parameters())
    if device is None:
        device = unet_param.device
    # Follow each model's own dtype instead of assuming fp16: after train() the
    # UNet is fp32 while the VAE is fp16.
    unet_dtype = unet_param.dtype
    vae_dtype = next(vae.parameters()).dtype

    noise_scheduler.set_timesteps(num_inference_steps, device=device)
    input_ids = torch.cat([
        tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length,
                  truncation=True, return_tensors="pt")["input_ids"]
        for prompt in text_prompts
    ])
    batch_size = input_ids.shape[0]

    with torch.no_grad():
        text_embeddings = text_encoder(input_ids.to(device))[0].to(unet_dtype)

        # Start from pure noise with the scale the scheduler expects.
        latents = torch.randn((batch_size, unet.config.in_channels, im_height // 8, im_width // 8),
                              device=device, dtype=unet_dtype)
        latents = latents * noise_scheduler.init_noise_sigma

        for t in noise_scheduler.timesteps:
            latent_model_input = noise_scheduler.scale_model_input(latents, t)
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
            # x_t -> x_{t-1}
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        latents = latents / vae.config.scaling_factor
        images = vae.decode(latents.to(vae_dtype)).sample

    # [-1, 1] -> [0, 1], channels-last for PIL / plotting.
    images = (images.float() / 2 + 0.5).clamp(0, 1)
    return images.cpu().permute(0, 2, 3, 1).numpy()

Overwriting /content/drive/MyDrive/Github/Product-image-generation-from-text-description/train_eval.py


In [None]:
#from kaggle_secrets import UserSecretsClient
#user_secrets = UserSecretsClient()
#key = user_secrets.get_secret("wandb_api")
!pip install --upgrade -q wandb
import wandb
wandb.login(key=key)
run = wandb.init(project='text-to-image',
                    group='finetune', #resume='must',
                    job_type='train')

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m48.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.2/199.2 KB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 KB[0m [31m21.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 KB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
import os
from torch.utils.data import Dataset
from PIL import Image


class CustomTensorDataset(Dataset):
    """Pairs each product description with its image.

    Samples are addressed positionally in ``descriptions`` (``iloc``). The row's
    ``id`` is looked up among the files in ``path`` named ``<id>.<ext>``.
    Each item is ``(tokenized_text, image)``, where ``tokenized_text`` is the
    tokenizer's output for the padded/truncated description.
    """

    def __init__(self, descriptions, tokenizer, path, transform_images=None):
        self.descriptions = descriptions

        # Map numeric image id -> file path; skip files not named "<id>.<ext>" (e.g. hidden files).
        self.links = {}
        for file in os.listdir(path):
            stem = file.split('.')[0]
            if stem.isdigit():
                self.links[int(stem)] = os.path.join(path, file)

        self.tokenizer = tokenizer
        self.transform_images = transform_images

    def __getitem__(self, index):
        row = self.descriptions.iloc[index]
        text = row['description']
        idx = row['id']
        tokenized_text = self.tokenizer(text, padding="max_length", max_length=self.tokenizer.model_max_length,
                                        truncation=True, return_tensors="pt")

        # Load eagerly so the file handle is closed, and force 3 channels so every
        # sample batches together and matches the VAE's RGB input.
        with Image.open(self.links[idx]) as img:
            image = img.convert("RGB")
        if self.transform_images:
            image = self.transform_images(image)

        return tokenized_text, image

    def __len__(self):
        # Indexing is by description row, so the length must be the number of rows,
        # not the number of image files found on disk.
        return len(self.descriptions)

'\nclass CustomDataset(Dataset):\n    def __init__(self, images, texts, tokenizer):\n        self.images = images\n        self.texts = texts\n        self.tokenizer = tokenizer\n\n    def __getitem__(self, index):\n        tokenized_text = self.tokenizer(self.texts[index], padding="max_length", \n                                        max_length=self.tokenizer.model_max_length, \n                                        truncation=True,\n                                   return_tensors="pt")\n\n        image = self.images[index]\n        return tokenized_text, image\n\n    def __len__(self):\n        return len(self.images)'

In [None]:
import torch
import pandas as pd
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader, Subset
from transformers import CLIPTokenizer
from sklearn.model_selection import train_test_split

In [None]:
from google.colab import files

# Upload kaggle.json. It is saved into the current working directory
# (presumably /content, the Colab default), which is where the kaggle CLI is
# told to look for its credentials.
files.upload()
os.environ['KAGGLE_CONFIG_DIR'] = "/content"

Saving kaggle.json to kaggle.json


In [None]:
# Download the Kaggle fashion product images archive (~23 GB) into /content.
!kaggle datasets download -d paramaggarwal/fashion-product-images-dataset -p '/content'

Downloading fashion-product-images-dataset.zip to /content
100% 23.1G/23.1G [02:48<00:00, 181MB/s]
100% 23.1G/23.1G [02:48<00:00, 147MB/s]


In [None]:
!unzip '/content/fashion-product-images-dataset.zip'

In [None]:
# Delete the ~23 GB archive after extraction to free disk space.
!rm /content/fashion-product-images-dataset.zip

In [None]:
!pip install -qq -U diffusers transformers accelerate

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m737.4/737.4 KB[0m [31m34.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m75.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m215.3/215.3 KB[0m [31m23.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m200.1/200.1 KB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m108.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from diffusers import AutoencoderKL 
from diffusers import UNet2DConditionModel, LMSDiscreteScheduler, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODEL_ID = "CompVis/stable-diffusion-v1-4"
vae = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# Stable Diffusion v1 beta schedule. clip_sample=False is required for latent
# diffusion: DDPMScheduler's default (True) clamps the predicted x0 to [-1, 1]
# at every sampling step, which suits pixel-space images, not scaled VAE latents.
noise_scheduler = DDPMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        num_train_timesteps=1000, clip_sample=False)

# VAE and text encoder stay frozen; only the UNet is optimized.
_ = vae.requires_grad_(False)
_ = text_encoder.requires_grad_(False)

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]

Downloading (…)main/vae/config.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

Downloading (…)on_pytorch_model.bin:   0%|          | 0.00/3.44G [00:00<?, ?B/s]

Downloading (…)ain/unet/config.json:   0%|          | 0.00/743 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.52k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.71G [00:00<?, ?B/s]

Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPTextModel: ['vision_model.encoder.layers.13.layer_norm2.bias', 'vision_model.encoder.layers.18.layer_norm2.bias', 'vision_model.encoder.layers.15.mlp.fc2.weight', 'vision_model.encoder.layers.21.layer_norm1.bias', 'vision_model.encoder.layers.9.layer_norm2.bias', 'vision_model.encoder.layers.1.self_attn.k_proj.bias', 'vision_model.encoder.layers.10.self_attn.q_proj.bias', 'vision_model.encoder.layers.12.self_attn.out_proj.bias', 'vision_model.encoder.layers.19.self_attn.q_proj.bias', 'vision_model.encoder.layers.21.self_attn.k_proj.bias', 'vision_model.encoder.layers.5.self_attn.q_proj.weight', 'vision_model.encoder.layers.2.self_attn.out_proj.weight', 'vision_model.encoder.layers.14.self_attn.v_proj.bias', 'vision_model.encoder.layers.12.mlp.fc2.bias', 'vision_model.post_layernorm.weight', 'vision_model.encoder.layers.12.layer_norm2.bias', 'vision_model.encoder.layers.7.mlp.fc2.we

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/961k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/525k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/389 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/905 [00:00<?, ?B/s]

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e

In [None]:
import os

path = '/content/drive/MyDrive/Github/Product-image-generation-from-text-description'

path_to_descriptions = os.path.join(path, 'descriptions.json')
descriptions = pd.read_json(path_to_descriptions, orient='records')

# CustomTensorDataset reads these two columns for every sample; fail early and clearly if absent.
missing_columns = {'id', 'description'} - set(descriptions.columns)
assert not missing_columns, f"descriptions.json is missing columns: {missing_columns}"

In [None]:
import importlib
import sys

# Make the repo directory (where the %%writefile cell saved train_eval.py) importable.
# Guarded so re-running the cell does not stack duplicate sys.path entries.
if path not in sys.path:
    sys.path.insert(0, path)

# %%writefile only writes the module to disk; train() must still be imported
# before the training cell calls it.
import train_eval
importlib.reload(train_eval)  # pick up changes after re-running the %%writefile cell
from train_eval import train, generate_images

In [None]:
RESOLUTION = 64

data_transformation_images = transforms.Compose([
            transforms.Resize((RESOLUTION, RESOLUTION)),
            transforms.ToTensor(),
            # [0, 1] -> [-1, 1]; the eval/generation code maps back with x / 2 + 0.5.
            transforms.Normalize((0.5, ), (0.5, ))
        ])

# Absolute path. The archive is extracted under /content (Colab's default working
# directory); a relative 'content/...' path would resolve to /content/content/...
# there. The batch size is set where the DataLoaders are built.
path = '/content/fashion-dataset/images'

dataset = CustomTensorDataset(descriptions, tokenizer, path, transform_images=data_transformation_images)

In [None]:
SEED = 42

indices = np.arange(len(descriptions))
# Fixed random_state so the train/validation split is reproducible across runs.
indices_train, indices_test = train_test_split(indices, test_size=0.2, random_state=SEED)

In [None]:
batch_size = 16

train_dataset = Subset(dataset, indices_train)
test_dataset = Subset(dataset, indices_test)

# Shuffle only the training data. eval_step recovers captions from batch
# positions, which requires a fixed validation order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# NOTE(review): `path` is reassigned earlier in the notebook (repo directory, then the
# image directory), so this reads wandb_token.txt from the parent of whatever `path`
# holds when the cell runs. Confirm that is where the token file lives.
with open(os.path.join(os.path.split(path)[0], "wandb_token.txt")) as f:
    key = f.read().strip()  # drop any trailing newline/whitespace so the key is used verbatim

In [None]:
from accelerate import Accelerator  # was used below without ever being imported
from easydict import EasyDict as edict

args = edict()

args.gradient_accumulation_steps = 2
args.mixed_precision = "fp16"
args.gradient_checkpointing = True

args.accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
    )
args.use_8bit_adam = True
args.train_batch_size = train_dataloader.batch_size
# Passed to train() as num_epochs in the training cell.
args.max_train_steps = 3
args.train_text_encoder = False
args.set_grads_to_none = False

In [None]:
# bitsandbytes provides the 8-bit AdamW optimizer (bnb.optim.AdamW8bit) used in the training cell.
!pip install -q bitsandbytes

In [None]:
import bitsandbytes as bnb

LEARNING_RATE = 2e-6

# Honour args.use_8bit_adam instead of always using the 8-bit optimizer.
optimizer_cls = bnb.optim.AdamW8bit if args.use_8bit_adam else torch.optim.AdamW
optimizer = optimizer_cls(unet.parameters(), lr=LEARNING_RATE)

torch.backends.cudnn.benchmark = True
# NOTE(review): save_path='/' writes checkpoints to the filesystem root, which does not
# survive a runtime reset; consider a directory on the mounted Drive.
train(vae=vae, unet=unet, text_encoder=text_encoder,
      noise_scheduler=noise_scheduler, num_epochs=args.max_train_steps,
      train_loader=train_dataloader, val_loader=test_dataloader,
      criterion=torch.nn.functional.mse_loss,
      optimizer=optimizer,
      save_path='/',
      logger=wandb, device=device, args=args,
      inf_freq=1)