In [1]:
# Clone the repositories
!git clone https://github.com/NVlabs/stylegan3
!git clone https://github.com/openai/CLIP
!git clone https://github.com/salesforce/BLIP
# Install the requirements
!pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 ninja==1.11.1.1
!pip install ftfy==6.1.3 regex==2023.12.25 tqdm==4.66.2
!pip install transformers==4.19.4 timm==0.9.16 fairscale==0.4.13
# Download the pre-trained models
!mkdir pretrained_models
!gdown -O test.png https://drive.google.com/uc?id=1hfVAbs5nkXcUpG6FCAafid7F7ZsqRRkk
!curl -L --output pretrained_models/stylegan2-ffhq-512x512.pkl 'https://api.ngc.nvidia.com/v2/models/org/nvidia/team/research/stylegan2/1/files?redirect=true&path=stylegan2-ffhq-512x512.pkl'
!curl -L --output pretrained_models/ViT-B-32.pt 'https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt'
!curl -L --output pretrained_models/model_base_capfilt_large.pth https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth

import locale
locale.getpreferredencoding = lambda: "UTF-8"

# Download dataset:
!gdown -O dataset.zip https://drive.google.com/uc?id=1JtRIbZDuZlBEA6vp870eW3ei0r32uYqF
# https://drive.google.com/file/d/1JtRIbZDuZlBEA6vp870eW3ei0r32uYqF/view?usp=sharing
!unzip -q dataset.zip

Cloning into 'stylegan3'...
remote: Enumerating objects: 212, done.[K
remote: Counting objects: 100% (5/5), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 212 (delta 0), reused 1 (delta 0), pack-reused 207[K
Receiving objects: 100% (212/212), 4.17 MiB | 8.79 MiB/s, done.
Resolving deltas: 100% (98/98), done.
Cloning into 'CLIP'...
remote: Enumerating objects: 251, done.[K
remote: Counting objects: 100% (8/8), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 251 (delta 3), reused 2 (delta 0), pack-reused 243[K
Receiving objects: 100% (251/251), 8.93 MiB | 14.30 MiB/s, done.
Resolving deltas: 100% (127/127), done.
Cloning into 'BLIP'...
remote: Enumerating objects: 277, done.[K
remote: Counting objects: 100% (165/165), done.[K
remote: Compressing objects: 100% (30/30), done.[K
remote: Total 277 (delta 137), reused 136 (delta 135), pack-reused 112[K
Receiving objects: 100% (277/277), 7.03 MiB | 13.80 MiB/s, done.
Resolving deltas: 

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import pickle
import sys
import CLIP.clip.clip as clip
from torchvision.transforms.functional import InterpolationMode
import numpy as np
import os
import PIL
sys.path.append('BLIP') # a folder BLIP is in the same directory as this notebook
from BLIP.models.blip import blip_decoder # to use class in: pwd/BLIP/models/blip.py
#### ENCODER
class TextToLatentEncoder2(nn.Module):
    def __init__(self, input_dim=512, output_dim=512):
        super(TextToLatentEncoder, self).__init__()
        # Initialize network layers
        self.relu = nn.LeakyReLU()
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 256)
        self.fc4 = nn.Linear(256, 512)
        self.fc5 = nn.Linear(512, 1024)
        self.fc6 = nn.Linear(1024, 2048)
        self.fc7 = nn.Linear(2048, 8192)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        x = self.relu(x)
        x = self.fc4(x)
        x = self.relu(x)
        x = self.fc5(x)
        x = self.relu(x)
        x = self.fc6(x)
        x = self.relu(x)
        x = self.fc7(x)
        x = self.relu(x)
        return x

import torch.nn as nn

class TextToLatentEncoder123(nn.Module):
    def __init__(self, input_dim=512, output_dim=8192):
        super(TextToLatentEncoder, self).__init__()
        # Initialize network layers
        self.relu = nn.LeakyReLU(0.2)
        self.fc1 = nn.Linear(input_dim, 1024)
        self.dropout1 = nn.Dropout(0.25)

        self.fc2 = nn.Linear(1024, 2048)
        self.dropout2 = nn.Dropout(0.25)

        self.fc3 = nn.Linear(2048, 4096)
        self.dropout3 = nn.Dropout(0.25)

        self.fc4 = nn.Linear(4096, 8192)  # Expanded to directly go to 8192 if needed
        self.dropout4 = nn.Dropout(0.25)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        x = self.relu(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        x = self.relu(x)
        x = self.dropout3(x)

        x = self.fc4(x)
        x = self.relu(x)
        x = self.dropout4(x)

        return x

import torch.nn as nn

class TextToLatentEncoder(nn.Module):
    def __init__(self, input_dim=512, output_dim=8192):
        super(TextToLatentEncoder, self).__init__()
        # Initialize network layers
        self.relu = nn.LeakyReLU(0.2)

        self.fc1 = nn.Linear(input_dim, 1024)
        self.in1 = nn.InstanceNorm1d(1024)  # Batch normalization for the first layer
        self.dropout1 = nn.Dropout(0.25)

        self.fc2 = nn.Linear(1024, 2048)
        self.in2 = nn.InstanceNorm1d(2048)  # Batch normalization for the second layer
        self.dropout2 = nn.Dropout(0.25)

        self.fc3 = nn.Linear(2048, 4096)
        self.in3 = nn.InstanceNorm1d(4096)  # Batch normalization for the third layer
        self.dropout3 = nn.Dropout(0.25)

        self.fc4 = nn.Linear(4096, 8192)
        self.in4 = nn.InstanceNorm1d(8192)  # Batch normalization for the fourth layer
        self.dropout4 = nn.Dropout(0.25)

    def forward(self, x):
        x = self.fc1(x)
        #x = self.in1(x)
        x = self.relu(x)
        x = self.dropout1(x)

        x = self.fc2(x)
        #x = self.in2(x)
        x = self.relu(x)
        x = self.dropout2(x)

        x = self.fc3(x)
        #x = self.in3(x)
        x = self.relu(x)
        x = self.dropout3(x)

        x = self.fc4(x)
        #x = self.in4(x)
        x = self.relu(x)
        x = self.dropout4(x)

        return x

# L2
def loss_calc_l2(text_emb, z):
  return torch.norm(text_emb - z)

# L1
def loss_calc_l1(text_emb, z):
  return torch.sum(torch.abs(text_emb - z))

In [3]:
from tqdm.auto import tqdm
sys.path.append('stylegan3')  # assuming stylegan3 is in the same directory as this notebook

# Load the StyleGAN model
device = "cuda" if torch.cuda.is_available() else "cpu"
with open('pretrained_models/stylegan2-ffhq-512x512.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema'].to(device)  # Load pretrained StyleGAN model - SYNTHESIS
    G.eval()  # Set StyleGAN to eval mode

# Load the CLIP model
clip_model, preprocess = clip.load("ViT-B/32", device=device)


### BLIP
blip_model_url = 'pretrained_models/model_base_capfilt_large.pth'
med_config_path = os.getcwd() + '/BLIP/configs/med_config.json'
blip_model = blip_decoder(pretrained=blip_model_url, image_size=512, vit='base', med_config = med_config_path)
#blip_model.eval()
blip_model = blip_model.to(device)


encoder = TextToLatentEncoder(input_dim=512, output_dim=G.z_dim).to(device)
encoder.float()  # Ensure the encoder uses Float dtype
optimizer = optim.Adam(encoder.parameters(), lr=0.0005)

# Ensure all components are converted to float before training
encoder = encoder.float()
G = G.float()
clip_model = clip_model.float()
encoder.train()
G.eval()


### CLIP

  #z = torch.randn([1, G.z_dim]).to(device

  #Encoder(text_embedding) ->

  #ABC(text) -> CLIP(text) -> Encoder(text_embeddibg): returns z*/w* -> |z*-z| (styleGAN(z/W) return: image/latent_code)

100%|███████████████████████████████████████| 338M/338M [00:05<00:00, 62.8MiB/s]


Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

reshape position embedding from 196 to 1024
load checkpoint from pretrained_models/model_base_capfilt_large.pth


Generator(
  (synthesis): SynthesisNetwork(
    w_dim=512, num_ws=16, img_resolution=512, img_channels=3, num_fp16_res=4
    (b4): SynthesisBlock(
      resolution=4, architecture=skip
      (conv1): SynthesisLayer(
        in_channels=512, out_channels=512, w_dim=512, resolution=4, up=1, activation=lrelu
        (affine): FullyConnectedLayer(in_features=512, out_features=512, activation=linear)
      )
      (torgb): ToRGBLayer(
        in_channels=512, out_channels=3, w_dim=512
        (affine): FullyConnectedLayer(in_features=512, out_features=512, activation=linear)
      )
    )
    (b8): SynthesisBlock(
      resolution=8, architecture=skip
      (conv0): SynthesisLayer(
        in_channels=512, out_channels=512, w_dim=512, resolution=8, up=2, activation=lrelu
        (affine): FullyConnectedLayer(in_features=512, out_features=512, activation=linear)
      )
      (conv1): SynthesisLayer(
        in_channels=512, out_channels=512, w_dim=512, resolution=8, up=1, activation=lrelu
 

In [4]:
# read dataset
import os
imagess = os.listdir('dataset')
# import txt file as csv file with pandas with comma separator
import pandas as pd
dtype_dict = {'name': str, 'caption': str}
df = pd.read_csv('dataset/image_captions_cleaned.txt', sep=",", header=0, dtype=dtype_dict)
print('Shape is: ' + str(df.shape))
df.head()

Shape is: (2048, 2)


Unnamed: 0,name,caption
0,1,a man with a black jacket
1,2,a young child with a very look on his face
2,3,a woman with a big smile on her face
3,4,a man in a white shirt
4,5,a man with a bandana on his head


In [5]:
keywords = ["blonde", "bald", "brown", "black", "blue", "yellow", "green", "long", "short", "eye", "angry", "sad", "happy", "excited", "red", "glasses", "pink", "gray", "grey", "bright", "dark", "curly", "straight", "beard"]
nb_iter = 2048
losses = []
for idx in tqdm(range(nb_iter)):

  '''
  z = torch.randn([1, G.z_dim]).to(device)
  w = G.mapping(z, None) # styleGAN output
  # w: 1, 16, 512
  generated_image = G.synthesis.forward(w)
  generated_image = (generated_image.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  width, h = generated_image.size()[1], generated_image.size()[2]

  transform = transforms.Compose([
      transforms.Resize((width,h),interpolation=InterpolationMode.BICUBIC),
      transforms.ToTensor(),
      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
      ])

  generated_image = PIL.Image.fromarray(generated_image[0].cpu().numpy(), 'RGB').resize((256, 256))
  image = transform(generated_image).unsqueeze(0).to(device)

  with torch.no_grad():

      caption = blip_model.generate(image, sample=False, num_beams=3, max_length=20, min_length=10)
      num_of_beams = 3
      attempts = 0
      print('caption: '+caption[0])
      while not any(keyword in caption[0] for keyword in keywords) and attempts < 5:
        attempts += 1
        num_of_beams += 1
        caption = blip_model.generate(image, sample=False, num_beams=num_of_beams, max_length=40, min_length=10)
        # nucleus sampling
        # caption = blip_model.generate(image, sample=True, top_p=0.9, max_length=20, min_length=5)
        print('caption: '+caption[0])
      print("---------------")
  blipout = caption[0] # BLIP's output (description of a StyleGAN-generated image)
  '''

  blipout = df.iloc[idx, 1] # GET CAPTION FROM DATASET


### CLIP
  text = clip.tokenize([blipout]).to(device)

  with torch.no_grad():
      #image_features = clip_model.encode_image(image)
      text_features = clip_model.encode_text(text)
      #logits_per_image, logits_per_text = clip_model(image, text)
      #probs = logits_per_image.softmax(dim=-1).cpu().numpy()

  #print("Image-text similarity:", logits_per_image.cpu().numpy())  # prints: [[24.815384]]
  #print("Label probabilities:", probs)  # prints: [[0.9927937  0.00421068 0.00299572]]
  #print("Image features (tensor size): ",len(image_features[0]))
  print("Text features (tensor size): ",len(text_features[0]))

  #z = torch.randn([1, G.z_dim]).cpu()   # random latent codes
  clipout = text_features                # THIS IS OUR AIM #TODO
  #it is same size tensor but output of stylegan2 is different imag
  print("clipout", clipout.shape)
  #----
  clipout = clipout.view(1, 1, 512)

  #for i in range(3):
  optimizer.zero_grad()
  a = encoder(clipout)

  # GET W LATENT SPACE FROM DATASET
  tensor_name = "dataset/tensors/" + str(df.iloc[idx, 0]) + ".pt"
  w = torch.load(tensor_name, map_location='cuda:0')
  w.to(device)

  loss = loss_calc_l2(w, torch.reshape(a, (1,16,512)))
  print(loss)
  loss.backward()
  optimizer.step()

  losses.append(loss.item())
  print(f"Iter {idx+1}, Loss: {loss.item()}")

  0%|          | 0/2048 [00:00<?, ?it/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Text features (tensor size):  512
clipout torch.Size([1, 512])
tensor(66.7163, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Iter 799, Loss: 66.71630096435547
Text features (tensor size):  512
clipout torch.Size([1, 512])
tensor(54.9320, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Iter 800, Loss: 54.932029724121094
Text features (tensor size):  512
clipout torch.Size([1, 512])
tensor(55.0960, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Iter 801, Loss: 55.09601593017578
Text features (tensor size):  512
clipout torch.Size([1, 512])
tensor(42.6605, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Iter 802, Loss: 42.66053771972656
Text features (tensor size):  512
clipout torch.Size([1, 512])
tensor(39.9528, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
Iter 803, Loss: 39.95280838012695
Text features (tensor size):  512
clipout torch.Size([1, 512])
tensor(56.4437, device='cuda:0', gra

In [6]:
import matplotlib.pyplot as plt

# Prompts:
captions = ['a man with sunglasses',
            'a man with beard',
            'a smiling baby',
            'a woman with curly hair',
            'a man in a suit and a tie',
            'a little girl',
            'an old man with blue eyes',
            'an old woman',
            'an angry man',
            'a woman with long black hair',
            'a child with brown hair']
# Read dataset/image_captions_cleaned captions with pandas:
dtype_dict = {'name': str, 'caption': str}
df2 = pd.read_csv('dataset/image_captions_cleaned.txt', sep=",", header=0, dtype=dtype_dict)
saved_captions = df2['caption'].tolist()

# "a man with glasses", "a child", "a child with black hairs and blue eyes"

def Inference(caption, current_time, save_fig = True, show_pic = True, index = None):
    encoder.eval()  # Ensure the encoder is in eval mode
    G.eval()  # Ensure the StyleGAN generator is in eval modee
    # Tokenize and encode the text description using CLIP
    #caption = "a child with black hairs and blue eyes"
    tokens = clip.tokenize([caption]).to(device)
    text_embedding = clip_model.encode_text(tokens).float()
    # Use the trained encoder to predict the latent vector
    with torch.no_grad():
        #ard(text_embedding, None)  # Unsqueeze to add batch dimensio
        latent_vector = encoder(text_embedding)  # Unsqueeze
        latent_vector = torch.reshape(latent_vector, (1,16,512)) + 0.3 * G.mapping(torch.randn([1, G.z_dim]).to(device),None)

    # Generate an image using the predicted latent vector with StyleGAN
    generated_image = G.synthesis(torch.reshape(latent_vector, (1,16,512)))
    generated_image = (generated_image.clamp(-1, 1) + 1) / 2  # Normalize to [0, 1]
    generated_image = generated_image.cpu().permute(0, 2, 3, 1).numpy()  # Adjust dimensions for image display
    generated_image = (generated_image * 255).astype('uint8')  # Convert to uint

    # Convert to PIL Image for display
    generated_image = Image.fromarray(generated_image[0])

    # Display image with caption
    #plt.imshow(generated_image)
    #plt.title(caption)
    #plt.axis('off')  # Turn off axis

    # save figure
    if save_fig == True and index == None:
        try:
            generated_image.save(f'inference/{current_time}/{caption}.png') # save image
            #PIL.Image.fromarray(generated_image.to(device).numpy(), 'RGB').save(f'inference/{current_time}/{caption}.png') # save image
            #plt.savefig(f'inference/{current_time}/{caption}.png')
        except:
            try:
                os.mkdir(f'inference/{current_time}')
                generated_image.save(f'inference/{current_time}/{caption}.png') # save image
                #plt.savefig(f'inference/{current_time}/{caption}.png')
                #PIL.Image.fromarray(generated_image.to(device).numpy(), 'RGB').save(f'inference/{current_time}/{caption}.png') # save image
            except:
                os.mkdir('./inference')
                os.mkdir(f'inference/{current_time}')
                generated_image.save(f'inference/{current_time}/{caption}.png') # save image
                #plt.savefig(f'inference/{current_time}/{caption}.png')
                #PIL.Image.fromarray(generated_image.to(device).numpy(), 'RGB').save(f'inference/{current_time}/{caption}.png') # save image
    elif save_fig == True and index != None:
            try:
                generated_image.save(f'inference/{current_time}/{index:05d}.png') # save image
                #PIL.Image.fromarray(generated_image.to(device).numpy(), 'RGB').save(f'inference/{current_time}/{caption}.png') # save image
                #plt.savefig(f'inference/{current_time}/{caption}.png')
            except:
                try:
                    os.mkdir(f'inference/{current_time}')
                    generated_image.save(f'inference/{current_time}/{index:05d}.png') # save image
                    #plt.savefig(f'inference/{current_time}/{caption}.png')
                    #PIL.Image.fromarray(generated_image.to(device).numpy(), 'RGB').save(f'inference/{current_time}/{caption}.png') # save image
                except:
                    os.mkdir('./inference')
                    os.mkdir(f'inference/{current_time}')
                    generated_image.save(f'inference/{current_time}/{index:05d}.png') # save image
                    #plt.savefig(f'inference/{current_time}/{caption}.png')
                    #PIL.Image.fromarray(generated_image.to(device).numpy(), 'RGB').save(f'inference/{current_time}/{caption}.png') # save image

    if show_pic == True:
      plt.show()

In [7]:
# Multiple captions in loop
from datetime import datetime
current_time = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
current_time = 'stylegan2-db'
index = 0
for idx in tqdm(saved_captions):
    index += 1
    Inference(idx, current_time, save_fig = True, show_pic = False, index = index)

import os
print(str(len(os.listdir('/content/inference/'+current_time)))+' images are saved.')

  0%|          | 0/2048 [00:00<?, ?it/s]

Setting up PyTorch plugin "bias_act_plugin"... Done.
Setting up PyTorch plugin "upfirdn2d_plugin"... Done.
2048 images are saved.


In [18]:
!cd inference && zip -r stylegan2-db.zip stylegan2-db/
# To save inference output zip to google drive
from google.colab import drive
drive.mount('/content/drive')
!cp -r /content/inference/stylegan2-db.zip /content/drive/MyDrive/stylegan2-db.zip

  adding: stylegan2-db/ (stored 0%)
  adding: stylegan2-db/00028.png (deflated 0%)
  adding: stylegan2-db/01195.png (deflated 0%)
  adding: stylegan2-db/01462.png (deflated 0%)
  adding: stylegan2-db/01921.png (deflated 0%)
  adding: stylegan2-db/00456.png (deflated 0%)
  adding: stylegan2-db/01895.png (deflated 0%)
  adding: stylegan2-db/00757.png (deflated 0%)
  adding: stylegan2-db/00500.png (deflated 0%)
  adding: stylegan2-db/00813.png (deflated 0%)
  adding: stylegan2-db/00304.png (deflated 0%)
  adding: stylegan2-db/02032.png (deflated 0%)
  adding: stylegan2-db/00778.png (deflated 0%)
  adding: stylegan2-db/00146.png (deflated 0%)
  adding: stylegan2-db/00683.png (deflated 0%)
  adding: stylegan2-db/00242.png (deflated 0%)
  adding: stylegan2-db/01394.png (deflated 0%)
  adding: stylegan2-db/00420.png (deflated 0%)
  adding: stylegan2-db/01214.png (deflated 0%)
  adding: stylegan2-db/00012.png (deflated 0%)
  adding: stylegan2-db/01175.png (deflated 0%)
  adding: stylegan2-db/0