In [1]:
# Mount Google Drive inside the Colab VM, then make the project checkout the
# working directory so relative paths (including the train_eval.py file written
# by the %%writefile cell) resolve inside the repository.
from google.colab import drive
drive.mount('/content/drive')
%cd /content/drive/My\ Drive/Github/Product-image-generation-from-text-description

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/Github/Product-image-generation-from-text-description


In [10]:
%%writefile train_eval.py
from tqdm.auto import tqdm
import os
import torch


def train_step(vae, unet, text_encoder, noise_scheduler, dataloader, criterion, optimizer, device):
    """Run one training epoch of the UNet noise-prediction objective.

    For every batch the images are encoded to latents with the VAE, noise is
    added at a random timestep per sample, and the UNet is optimised to
    predict that noise conditioned on the text embeddings.

    Args:
        vae: Autoencoder exposing ``encode(images).latent_dist.sample()`` and
            ``config.scaling_factor``.
        unet: Model being optimised; called as
            ``unet(noisy_latents, t, encoder_hidden_states=...)`` and expected
            to return an object with a ``sample`` attribute.
        text_encoder: Callable mapping token ids to a sequence whose first
            element is used as the text embeddings.
        noise_scheduler: Scheduler exposing ``config.num_train_timesteps`` and
            ``add_noise(latents, noise, timesteps)``.
        dataloader: Iterable of ``(text, images)`` batches; ``text["input_ids"]``
            is squeezed along dim 1 before being fed to the text encoder.
        criterion: Callable ``(prediction, target) -> scalar loss tensor``.
        optimizer: Optimizer that updates the UNet parameters.
        device: Device the batch tensors are moved to.

    Returns:
        float: Mean of the per-batch losses over the epoch, or ``0.0`` when the
        dataloader yields no batches.
    """
    unet.train()

    epoch_loss = 0.0
    num_batches = 0

    for text, images in dataloader:
        images = images.to(device)
        input_ids = text["input_ids"].squeeze(1).to(device)
        batch_size = images.shape[0]

        # The text encoder and VAE only provide inputs for the UNet here, so no
        # autograd graph is built for them.
        with torch.no_grad():
            text_embeddings = text_encoder(input_ids)[0]
            latents = vae.encode(images).latent_dist.sample()
            latents = latents * vae.config.scaling_factor
        latents = latents.to(device)

        # Noise target the UNet has to recover.
        noise = torch.randn_like(latents)
        # Sample a random timestep for each image.
        t = torch.randint(0, noise_scheduler.config.num_train_timesteps, (batch_size,), device=device).long()

        noisy_latents = noise_scheduler.add_noise(latents, noise, t)
        noise_pred = unet(noisy_latents, t, encoder_hidden_states=text_embeddings).sample

        loss = criterion(noise_pred.float(), noise.float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # .item() detaches the value so no graph is kept alive across batches.
        epoch_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return 0.0
    # Average the accumulated loss (not just the last batch's loss).
    return epoch_loss / num_batches

def eval_step(vae, unet, text_encoder, noise_scheduler, dataloader, device, height, width):
    """Generate images from noise for every validation batch and score them.

    For each batch the prompts are embedded, random latents are denoised with
    the UNet over ``noise_scheduler.timesteps``, and the result is decoded by
    the VAE. The ground-truth images of the batch are only used for their
    batch size.

    Args:
        vae: Autoencoder exposing ``decode(latents).sample`` and
            ``config.scaling_factor``.
        unet: Denoising model exposing ``config.in_channels``.
        text_encoder: Module mapping token ids to a sequence whose first
            element is used as the text embeddings.
        noise_scheduler: Scheduler exposing ``timesteps``, ``init_noise_sigma``,
            ``scale_model_input`` and ``step``.
        dataloader: Iterable of ``(text, images)`` batches.
        device: Device used for sampling.
        height: Height in pixels of the images to generate.
        width: Width in pixels of the images to generate.

    Returns:
        float: Mean of the per-batch metric, or ``0.0`` when the dataloader
        yields no batches. The metric is currently a placeholder (mean pixel
        value of the decoded images).
    """
    unet.eval()
    vae.eval()
    text_encoder.eval()

    metric = 0.0
    num_batches = 0

    # Nothing in evaluation needs gradients; one no_grad scope covers it all.
    with torch.no_grad():
        for text, images in dataloader:
            batch_size = images.shape[0]
            input_ids = text["input_ids"].squeeze(1).to(device)
            text_embeddings = text_encoder(input_ids)[0]

            # Latents are 8x smaller than the image in each spatial dimension.
            latents = torch.randn((batch_size, unet.config.in_channels, height // 8, width // 8))
            # Starting noise is scaled by the scheduler's initial sigma; the VAE
            # scaling factor only applies to encoded/decoded latents.
            latents = latents.to(device) * noise_scheduler.init_noise_sigma

            # NOTE(review): this walks whatever timesteps the scheduler currently
            # holds; confirm the caller configured the desired number of steps.
            for t in noise_scheduler.timesteps:
                latent_model_input = noise_scheduler.scale_model_input(latents, t)

                # predict the noise residual
                noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

                # compute the previous noisy sample x_t -> x_t-1
                latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

            # Undo the scaling applied to latents at encoding time before decoding.
            latents = latents / vae.config.scaling_factor
            generated = vae.decode(latents).sample

            # TODO: replace this placeholder with a real image-quality metric.
            # .item() keeps the running value a plain float.
            metric += generated.mean().item()
            num_batches += 1

    if num_batches == 0:
        return 0.0
    return metric / num_batches

def train(vae, unet, text_encoder, noise_scheduler, num_epochs, train_loader,
          val_loader, criterion, optimizer, save_path, logger, device, inf_freq=None):
    """Train the UNet for ``num_epochs`` and keep the best-scoring checkpoint.

    After each epoch the training loss is logged. Every ``inf_freq`` epochs the
    model is evaluated with ``eval_step``; when the validation metric improves,
    a checkpoint holding the UNet, VAE and text-encoder state dicts is written
    to ``save_path`` and the previously best checkpoint file is removed.

    Args:
        vae, unet, text_encoder, noise_scheduler: Forwarded to ``train_step``
            and ``eval_step``.
        num_epochs: Number of passes over ``train_loader``.
        train_loader: Batches for ``train_step``.
        val_loader: Batches for ``eval_step``; ``val_loader.dataset[0][1]`` is
            read to obtain the image height and width (dims 1 and 2).
        criterion: Loss callable forwarded to ``train_step``.
        optimizer: Optimizer forwarded to ``train_step``.
        save_path: Directory for checkpoints; created if missing.
        logger: Object with a ``log(dict, step=...)`` method.
        device: Device used for training, evaluation and checkpoint loading.
        inf_freq: Evaluate every ``inf_freq`` epochs. ``None`` (the default) or
            ``0`` disables evaluation and checkpointing.

    Returns:
        dict: ``{'vae': ..., 'unet': ..., 'text_encoder': ...}``. If at least
        one checkpoint was written, the models carry the best checkpoint's
        weights; otherwise they are returned as they are after the last epoch.
    """
    vae.eval()
    text_encoder.eval()
    os.makedirs(save_path, exist_ok=True)

    evaluate = bool(inf_freq)
    if evaluate:
        im_height, im_width = val_loader.dataset[0][1].shape[1:3]

    # Start below any reachable value so the first evaluation is always saved.
    best_metric = float("-inf")
    best_path = None
    for epoch in range(num_epochs):
        train_loss = train_step(vae, unet, text_encoder, noise_scheduler,
                                train_loader, criterion, optimizer, device)

        # log train loss as a plain number (also accepts a one-element tensor)
        logger.log({"train_loss": float(train_loss)}, step=epoch)

        if not evaluate or epoch % inf_freq != 0:
            continue

        val_metric = float(eval_step(vae, unet, text_encoder, noise_scheduler,
                                     val_loader, device, im_height, im_width))
        logger.log({"val_metric": val_metric}, step=epoch)

        if val_metric > best_metric:
            # save best model
            new_path = os.path.join(save_path, f"diffusion_model_{round(val_metric, 2)}.pt")
            torch.save({
                'epoch': epoch,
                'unet_state_dict': unet.state_dict(),
                'vae_state_dict': vae.state_dict(),
                'text_enc_state_dict': text_encoder.state_dict()
                }, new_path)

            # Drop the superseded checkpoint, unless rounding mapped both
            # metrics to the same file name (then it was just overwritten).
            if best_path is not None and best_path != new_path and os.path.exists(best_path):
                os.remove(best_path)
            best_metric = val_metric
            best_path = new_path

    # load best model weights (only if a checkpoint was actually written)
    if best_path is not None:
        best_checkpoint = torch.load(best_path, map_location=device)
        vae.load_state_dict(best_checkpoint["vae_state_dict"])
        unet.load_state_dict(best_checkpoint["unet_state_dict"])
        text_encoder.load_state_dict(best_checkpoint["text_enc_state_dict"])

    model = {'vae': vae, 'unet': unet, "text_encoder": text_encoder}
    return model

def generate_images(text_prompts):
    """Placeholder for text-to-image generation; not implemented yet.

    The ``text_prompts`` argument is currently ignored and ``None`` is returned.
    """
    # TODO: implement image generation from the given prompts.
    return None

Overwriting train_eval.py


In [3]:
import os
from torch.utils.data import Dataset
from PIL import Image


class CustomDataset(Dataset):
    """Pairs each image with the tokenized form of its text description.

    Args:
        images: Indexable, sized collection of images.
        texts: Indexable, sized collection of descriptions, aligned with
            ``images`` by position.
        tokenizer: Callable tokenizer exposing ``model_max_length``; it is
            invoked with ``padding="max_length"``, ``truncation=True`` and
            ``return_tensors="pt"``.

    Raises:
        ValueError: If ``images`` and ``texts`` differ in length.
    """

    def __init__(self, images, texts, tokenizer):
        # Fail fast here instead of with an IndexError in the middle of an epoch.
        if len(images) != len(texts):
            raise ValueError(
                f"images and texts must have the same length, "
                f"got {len(images)} and {len(texts)}"
            )
        self.images = images
        self.texts = texts
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        """Return ``(tokenized_text, image)`` for the sample at ``index``."""
        # Every text is padded/truncated to the tokenizer's maximum length so
        # that samples can be stacked into a batch.
        tokenized_text = self.tokenizer(
            self.texts[index],
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        image = self.images[index]
        return tokenized_text, image

    def __len__(self):
        """Number of samples (one per image)."""
        return len(self.images)

In [3]:
!pip install -qq -U diffusers transformers

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m737.4/737.4 KB[0m [31m12.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m91.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m107.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
from diffusers import AutoencoderKL 
from diffusers import UNet2DConditionModel, LMSDiscreteScheduler, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
import torch

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Autoencoder and UNet come from the same pretrained checkpoint (different subfolders).
vae = AutoencoderKL.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="vae"
).to(device)
unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
).to(device)

# Tokenizer and text encoder are loaded from the same CLIP checkpoint.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device)

# Scheduler used to add noise during training and to step during sampling.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
)
    
def disable_grad(model):
    """Turn off gradient tracking for every parameter of ``model``.

    Returns the same ``model`` object so the call can be used in an assignment.
    """
    for param in model.parameters():
        param.requires_grad_(False)
    return model

# Freeze the VAE and the text encoder: only the UNet's parameters are handed to
# the optimizer in the training cell, so these two never need gradients.
vae.requires_grad_(False)  # alternative: disable_grad(vae)
text_encoder.requires_grad_(False)  # alternative: disable_grad(text_encoder)

In [5]:
# Tiny synthetic dataset (random tensors and dummy captions) used to exercise
# the training loop end to end.
n = 5

# A dedicated seeded generator makes the random images identical on every
# Restart & Run All without touching torch's global RNG state.
data_rng = torch.Generator().manual_seed(0)
train_images = torch.randn((n, 3, 16, 16), generator=data_rng)
val_images = torch.randn((n, 3, 16, 16), generator=data_rng)
train_texts = [f"text_{i}" for i in range(n)]
val_texts = [f"text_text{i}" for i in range(n)]

train_dataset = CustomDataset(train_images, train_texts, tokenizer)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2)

val_dataset = CustomDataset(val_images, val_texts, tokenizer)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=2)

In [6]:
import sys
import os

# Absolute location of the project checkout on the mounted Drive.
path = '/content/drive/MyDrive/Github/Product-image-generation-from-text-description'
# Put the checkout at the front of the import path, but only once, so that
# re-running this cell does not pile up duplicate entries.
if path not in sys.path:
    sys.path.insert(0, path)

In [7]:
# The W&B API key lives in a text file in the parent directory of the project
# checkout, so it is not part of the repository itself.
with open(os.path.join(os.path.split(path)[0], "wandb_token.txt")) as f:
    # Strip surrounding whitespace: a trailing newline in the file would
    # otherwise become part of the key.
    key = f.read().strip()

In [8]:
!pip install --upgrade -q wandb
import wandb
wandb.login(key=key)
run = wandb.init(project='text-to-image',
                    group='finetune', #resume='must',
                    job_type='train')

[34m[1mwandb[0m: Currently logged in as: [33mearina[0m ([33mdatasatanists[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [9]:
# train_eval.py is produced by the %%writefile cell; import `train` from it
# explicitly so this cell does not depend on leftover kernel state.
from train_eval import train

# NOTE(review): the recorded run of this cell ended in OutOfMemoryError. Adam
# keeps extra state for every UNet parameter; consider a lighter setup if the
# GPU is small.
# Assigning the result also keeps the notebook from printing the returned models.
trained_model = train(
    vae, unet, text_encoder, noise_scheduler,
    3,  # num_epochs
    train_dataloader, val_dataloader,
    torch.nn.functional.mse_loss,
    torch.optim.Adam(unet.parameters()),
    os.path.join(path, 'models'),
    wandb,  # logger: train() calls its .log(...)
    device,
    1,  # inf_freq: evaluate after every epoch
)
#wandb.finish()

OutOfMemoryError: ignored