<a href="https://colab.research.google.com/github/nguyenanhtienabcd/AIO2024_EXERCISE/blob/feature%2FMODULE9-WEEK2/m09w02.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chuyển text thành ảnh

### Prepare data, process text

In [1]:
# Download the dataset archive from Google Drive (by file ID) with gdown.
!gdown 1JJjMiNieTz7xYs6UeVqd02M3DW4fnEfU

Downloading...
From (original): https://drive.google.com/uc?id=1JJjMiNieTz7xYs6UeVqd02M3DW4fnEfU
From (redirected): https://drive.google.com/uc?id=1JJjMiNieTz7xYs6UeVqd02M3DW4fnEfU&confirm=t&uuid=eda13bba-596a-4551-bc78-360932379056
To: /content/cvpr2016_flowers.zip
100% 351M/351M [00:01<00:00, 218MB/s]


In [2]:
!unzip /content/cvpr2016_flowers.zip

[1;30;43mKết quả truyền trực tuyến bị cắt bớt đến 5000 dòng cuối.[0m
  inflating: content/cvpr2016_flowers/images/image_02041.jpg  
  inflating: content/cvpr2016_flowers/images/image_06550.jpg  
  inflating: content/cvpr2016_flowers/images/image_00556.jpg  
  inflating: content/cvpr2016_flowers/images/image_05695.jpg  
  inflating: content/cvpr2016_flowers/images/image_02732.jpg  
  inflating: content/cvpr2016_flowers/images/image_00112.jpg  
  inflating: content/cvpr2016_flowers/images/image_03424.jpg  
  inflating: content/cvpr2016_flowers/images/image_03702.jpg  
  inflating: content/cvpr2016_flowers/images/image_01316.jpg  
  inflating: content/cvpr2016_flowers/images/image_03391.jpg  
  inflating: content/cvpr2016_flowers/images/image_02503.jpg  
  inflating: content/cvpr2016_flowers/images/image_05044.jpg  
  inflating: content/cvpr2016_flowers/images/image_01633.jpg  
  inflating: content/cvpr2016_flowers/images/image_05305.jpg  
  inflating: content/cvpr2016_flowers/images/im

In [3]:
# load caption
import os

# Purpose: return a dict mapping image name -> caption.
def load_captions(caption_path, image_path):
  """Build a mapping from image name to the first caption of that image.

  For every entry in `image_path`, the caption file
  `<caption_path>/<image name>.txt` is read and its first non-empty line
  (stripped of surrounding whitespace) is stored under the image name, i.e.
  the file name without its extension.

  Entries that have no matching caption file (for example sub-directories such
  as `.ipynb_checkpoints`) and caption files that contain no text are skipped
  instead of raising.

  Args:
    caption_path: directory holding one `.txt` caption file per image.
    image_path: directory holding the image files.

  Returns:
    dict[str, str]: image name -> caption, in sorted file-name order.
  """
  captions = {}
  # os.listdir returns entries in arbitrary order; sort for reproducible runs.
  for image_file in sorted(os.listdir(image_path)):
    # splitext only removes the final extension, so "a.b.jpg" -> "a.b"
    # (split(".")[0] would have truncated it to "a").
    image_name = os.path.splitext(image_file)[0]
    if image_name in captions:
      # Keep the first caption seen for a given image name.
      continue
    caption_file = os.path.join(caption_path, image_name + ".txt")
    if not os.path.isfile(caption_file):
      continue
    with open(caption_file, "r", encoding="utf-8") as f:
      # Stream the file and stop at the first line that has content.
      caption = next((line.strip() for line in f if line.strip()), None)
    if caption is not None:
      captions[image_name] = caption
  return captions

In [None]:
# Locations of the extracted caption files and flower images.
captions_path = "/content/content/cvpr2016_flowers/captions"

image_path = "/content/content/cvpr2016_flowers/images"

captions = load_captions(captions_path, image_path)

# Preview only a handful of entries: echoing the whole dictionary would print
# one line per image and bury the rest of the notebook.
print(f"Loaded {len(captions)} captions")
dict(list(captions.items())[:5])

Sử dụng mô hình `all-mpnet-base-v2` (thư viện Sentence-Transformers, kiến trúc kiểu BERT) để biểu diễn mỗi câu mô tả thành một vector có
kích thước 768 chiều.

In [None]:
!pip install sentence_transformers

In [None]:
import torch
import numpy as np
from sentence_transformers import SentenceTransformer

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Sentence encoder used to embed the captions, moved to the selected device.
bert_model = SentenceTransformer("all-mpnet-base-v2")
bert_model = bert_model.to(device)

# Purpose: return, for every image, its caption text and caption embedding.
def encode_captions(captions, batch_size = 32):
  """Embed every caption with the module-level `bert_model`.

  The captions are encoded in batches (one encoder call per `batch_size`
  captions) instead of one call per caption.

  Args:
    captions: dict mapping image name -> caption string.
    batch_size: number of captions sent to the encoder per forward pass.

  Returns:
    dict mapping image name -> {'embed': 1-D tensor on `device`,
    'text': the original caption string}. Empty input yields an empty dict.
  """
  image_names = list(captions.keys())
  if not image_names:
    # Nothing to encode; avoid calling the encoder on an empty list.
    return {}

  texts = [captions[image_name] for image_name in image_names]
  # convert_to_tensor=True returns one stacked tensor (one row per caption).
  embeddings = bert_model.encode(
      texts,
      batch_size = batch_size,
      convert_to_tensor = True,
  ).to(device)

  encoded_captions = {}
  for image_name, text, embedding in zip(image_names, texts, embeddings):
    encoded_captions[image_name] = {
        'embed': embedding,
        'text': text
    }
  return encoded_captions


In [None]:
# Embed all loaded captions; the result maps image name -> {'embed', 'text'}.
encoded_captions = encode_captions(captions)

In [None]:
# Preview a few entries (caption text + embedding shape) rather than echoing
# the whole dictionary, which would print one embedding tensor per image.
{
    image_name: (entry['text'], tuple(entry['embed'].shape))
    for image_name, entry in list(encoded_captions.items())[:5]
}

In [None]:
# Inspect the stored caption text and embedding of one specific image.
print(encoded_captions['image_05881'])

In [None]:
# thực hiện processing
from PIL import Image
from torch.utils.data import Dataset

class FlowerDataset(Dataset):
  """Dataset that pairs each flower image with its caption and caption embedding.

  `captions` maps an image name (the file name without the '.jpg' suffix) to a
  dict with the keys 'embed' and 'text'. The key order of that mapping fixes
  the order of the samples.
  """

  def __init__(self, img_dir, captions, transform = None):
    self.img_dir, self.captions, self.transform = img_dir, captions, transform
    # Snapshot the keys once so that integer indices stay stable.
    self.img_names = [name for name in captions.keys()]

  def __len__(self):
    """Number of images that have a caption entry."""
    return len(self.img_names)

  def __getitem__(self, idx):
    """Return {'image', 'caption', 'encoded_caption'} for sample `idx`."""
    name = self.img_names[idx]
    picture = Image.open(os.path.join(self.img_dir, name + '.jpg')).convert("RGB")
    if self.transform:
      picture = self.transform(picture)
    entry = self.captions[name]
    sample = {
        'image': picture,
        'caption': entry['text'],
        'encoded_caption': entry['embed']
    }
    return sample

In [None]:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

IMG_SIZE = 128
BATCH_SIZE = 64

# Preprocessing: resize to IMG_SIZE x IMG_SIZE, convert to a tensor, then
# normalize with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size = (IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean = [0.5], std = [0.5]),
])

image_path = "/content/content/cvpr2016_flowers/images"
ds = FlowerDataset(img_dir = image_path, captions = encoded_captions, transform = transform)
# Shuffled mini-batches for training.
load_dataset = DataLoader(dataset = ds, batch_size = BATCH_SIZE, shuffle = True)

### Xây dựng model

In [None]:
# tạo class Generator
import torch
import torch.nn as nn
import torch.nn.functional as F

class Generator(torch.nn.Module):
  """Text-conditioned generator built from transposed convolutions.

  A noise vector is concatenated with a reduced caption embedding and
  upsampled from 1x1 to a `num_channels` x 128 x 128 image in [-1, 1].

  Args:
    noise_size: length of the input noise vector.
    feature_size: base channel width; layers use multiples of it.
    num_channels: number of channels of the generated image.
    embedding_size: length of the incoming caption embedding.
    reduced_dim_size: length of the caption embedding after `text_encoder`.
  """

  def __init__(self, noise_size, feature_size,
               num_channels, embedding_size,
               reduced_dim_size):
    super(Generator, self).__init__()
    self.reduced_dim_size = reduced_dim_size

    # Project the caption embedding from embedding_size down to reduced_dim_size.
    self.text_encoder = nn.Sequential(
        nn.Linear(embedding_size, reduced_dim_size),
        nn.BatchNorm1d(reduced_dim_size),
        nn.LeakyReLU(negative_slope = 0.2, inplace = True),
        nn.Linear(reduced_dim_size, reduced_dim_size)
    )

    self.upsampling = nn.Sequential(
        # (noise_size + reduced_dim_size) x 1 x 1 -> (feature_size*8) x 4 x 4
        nn.ConvTranspose2d(noise_size + reduced_dim_size, feature_size*8, 4, 1, 0, bias = False),
        nn.BatchNorm2d(feature_size*8),
        nn.ReLU(inplace = True),

        # (feature_size*8) x 4 x 4 -> (feature_size*4) x 8 x 8
        nn.ConvTranspose2d(feature_size*8, feature_size*4, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size*4),
        nn.ReLU(inplace = True),

        # (feature_size*4) x 8 x 8 -> (feature_size*2) x 16 x 16
        nn.ConvTranspose2d(feature_size*4, feature_size*2, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size*2),
        nn.ReLU(inplace = True),

        # (feature_size*2) x 16 x 16 -> feature_size x 32 x 32
        nn.ConvTranspose2d(feature_size*2, feature_size, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size),
        nn.ReLU(inplace = True),

        # feature_size x 32 x 32 -> feature_size x 64 x 64
        nn.ConvTranspose2d(feature_size, feature_size, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size),
        nn.ReLU(inplace = True),

        # feature_size x 64 x 64 -> num_channels x 128 x 128
        # The input width follows feature_size; it used to be hard-coded to
        # 128, which broke every configuration with feature_size != 128.
        nn.ConvTranspose2d(feature_size, num_channels, 4, 2, 1, bias = False),
        # Tanh bounds the output to [-1, 1], the range of images normalized
        # with mean 0.5 and std 0.5.
        nn.Tanh()
    )

  def forward(self, noise, text_embedding):
    """Generate images from noise and caption embeddings.

    Args:
      noise: tensor of shape (batch_size, noise_size).
      text_embedding: tensor of shape (batch_size, embedding_size).

    Returns:
      Tensor of shape (batch_size, num_channels, 128, 128).
    """
    text_encoder = self.text_encoder(text_embedding)
    # Build the input of shape (batch_size, noise_size + reduced_dim_size, 1, 1).
    concat_input = torch.cat([noise, text_encoder], dim = 1).unsqueeze(2).unsqueeze(2)
    output = self.upsampling(concat_input)
    return output

In [None]:
# Discriminator: scores an image together with the caption embedding it is paired with.
class Discriminator(torch.nn.Module):
  """Text-conditioned discriminator.

  The image goes through five stride-2 convolutions, the caption embedding is
  reduced by a linear block, both are concatenated along the channel axis and
  a final convolution + sigmoid produces the score. With 128 x 128 inputs the
  image feature map is 4 x 4, so the final 4 x 4 convolution yields one score
  per sample.
  """

  def __init__(self, num_channels, feature_size, embedding_size, reduced_dim_size):
    super().__init__()
    self.reduced_dim_size = reduced_dim_size

    # Channel widths before/after each stride-2 convolution:
    # num_channels -> f -> f -> 2f -> 4f -> 8f (f = feature_size).
    widths = [num_channels, feature_size, feature_size,
              feature_size*2, feature_size*4, feature_size*8]
    conv_stack = []
    for in_ch, out_ch in zip(widths[:-1], widths[1:]):
      conv_stack.extend([
          nn.Conv2d(in_ch, out_ch, kernel_size = 4, stride = 2, padding = 1, bias = False),
          nn.BatchNorm2d(out_ch),
          nn.LeakyReLU(negative_slope = 0.2, inplace = True),
      ])
    self.img_encoder = nn.Sequential(*conv_stack)

    # Block that reduces the dimensionality of the caption embedding.
    self.text_encoder = nn.Sequential(
        nn.Linear(embedding_size, reduced_dim_size),
        nn.BatchNorm1d(reduced_dim_size),
        nn.LeakyReLU(negative_slope = 0.2, inplace = True),
    )

    # Final block: (feature_size*8 + reduced_dim_size) channels -> 1 score in (0, 1).
    self.final_layer = nn.Sequential(
        nn.Conv2d(feature_size*8 + reduced_dim_size, 1,
                  kernel_size = 4, stride = 1, padding = 0, bias = False),
        nn.Sigmoid(),
    )

  def forward(self, img, text_embedding):
    """Score `img` given `text_embedding`.

    Returns:
      A tuple (scores reshaped to (-1, 1), image feature map produced by
      `img_encoder`).
    """
    # (batch_size, reduced_dim_size)
    text_features = self.text_encoder(text_embedding)
    # (batch_size, feature_size*8, H, W)
    img_features = self.img_encoder(img)

    # Broadcast the text features over the spatial grid of the image features
    # so both can be concatenated along the channel axis.
    height, width = img_features.shape[2:]
    tiled_text = text_features[:, :, None, None].expand(-1, -1, height, width)

    scores = self.final_layer(torch.cat([img_features, tiled_text], dim=1))
    return scores.view(-1, 1), img_features


In [None]:
# Instantiate the generator with named hyper-parameters and move it to the device.
generator = Generator(
    noise_size = 100,
    feature_size = 128,
    num_channels = 3,
    embedding_size = 768,
    reduced_dim_size = 256,
).to(device)

In [None]:
# Instantiate the discriminator with named hyper-parameters and move it to the device.
discriminator = Discriminator(
    num_channels = 3,
    feature_size = 128,
    embedding_size = 768,
    reduced_dim_size = 256,
).to(device)

### plot hiển thị ảnh

In [None]:
# định nghĩa một hàm hiển thị ảnh sau mỗi lần train dữ liệu
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
import numpy as np
import torchvision # Import torchvision

def show_grid(img):
  """Display a 3-D image tensor with matplotlib.

  The tensor's first axis is moved to the end before plotting (the layout
  produced by `make_grid`).

  Args:
    img: 3-D tensor; may live on any device and may track gradients.
  """
  # Tensor.numpy() raises for CUDA tensors and for tensors that require grad,
  # so detach and move to the CPU first (both are no-ops when not needed).
  npimg = img.detach().cpu().numpy()
  plt.imshow(np.transpose(npimg, (1, 2, 0)))
  plt.show()

# Display the first dataset image; normalize=True rescales it for display.
show_grid(torchvision.utils.make_grid(ds[0]['image'], normalize = True))

In [None]:
def plot_output(generator, epoch, fixed_noise, plt_o_text_embeddings):
  """Generate images from fixed inputs, save them to disk and display them.

  Args:
    generator: the generator network.
    epoch: zero-based epoch index; the file is named with `epoch + 1`.
    fixed_noise: noise tensor, batched (batch, noise_size) or a single
      un-batched vector.
    plt_o_text_embeddings: caption embeddings, batched (batch, embedding_size)
      or a single un-batched vector.

  Side effects:
    Writes `image_epoch_<epoch + 1>.png` into the working directory and shows
    the figure.
  """
  noise = fixed_noise.to(device)
  text_embeddings = plt_o_text_embeddings.to(device)
  # The generator concatenates along dim=1 and uses BatchNorm1d, so both inputs
  # need a batch axis; add one when a single sample is passed.
  if noise.dim() == 1:
    noise = noise.unsqueeze(0)
  if text_embeddings.dim() == 1:
    text_embeddings = text_embeddings.unsqueeze(0)

  # Remember the current mode and restore it afterwards instead of forcing
  # train mode unconditionally.
  was_training = generator.training
  generator.eval()
  try:
    with torch.no_grad():  # no gradients are needed for a preview
      test_img = generator(noise, text_embeddings)
  finally:
    generator.train(was_training)

  grid = torchvision.utils.make_grid(test_img.cpu(), normalize=True)

  plt.clf()  # clear the previous plot
  plt.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
  # Save BEFORE plt.show(): once the figure has been shown it is released, and
  # a later savefig would write an empty image.
  plt.savefig(f'image_epoch_{epoch + 1}.png')
  plt.show()

In [None]:
# Fetch a single batch from the DataLoader and inspect the shape of its image tensor.
batch_sample = next(iter(load_dataset))
batch_sample['image'].shape

In [None]:
# Display all images of the sampled batch as one grid.
show_grid(torchvision.utils.make_grid(batch_sample['image'], normalize = True))

### Train model

In [None]:
# Caption embedding of the first dataset sample.
plt_o_text_embeddings = ds[0]['encoded_caption']
print(plt_o_text_embeddings.shape)
# Add a leading batch axis.
plt_o_text_embeddings = plt_o_text_embeddings[None, ...]
print(plt_o_text_embeddings.shape)

In [None]:
# Length of the generator's noise vector.
latent_dim = 100
# One noise sample with a batch axis, shape (1, latent_dim).
fixed_noise = torch.randn(1, latent_dim)
print(fixed_noise.shape)

In [None]:
# Loss functions: binary cross-entropy, mean squared error (L2) and mean absolute error (L1).
bce_loss, l2_loss, l1_loss = nn.BCELoss(), nn.MSELoss(), nn.L1Loss()

In [None]:
import torch.optim as optim

# Both networks use Adam with the same learning rate and betas.
adam_settings = dict(lr=0.0002, betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_settings)

In [None]:
import time

epochs = 500

# Weights of the auxiliary generator losses: lambda_1 scales the L1 image
# reconstruction term, lambda_2 the L2 feature-matching term.
lambda_1 = 50
lambda_2 = 100

for epoch in range(epochs):
    d_losses, g_losses = [], []

    epoch_time = time.time()

    for batch in load_dataset:
        real_imgs = batch['image'].to(device)
        embeded_text = batch['encoded_caption'].to(device)
        current_batch_size = real_imgs.size(0)

        # Target labels: 1 for real, 0 for fake.
        real_labels = torch.ones(current_batch_size, 1, device = device)
        fake_labels = torch.zeros(current_batch_size, 1, device = device)

        # -------------------------- Train Discriminator ---
        optimizer_D.zero_grad()

        # Noise input used to produce the fake images for the discriminator.
        noise = torch.randn((current_batch_size, latent_dim)).to(device)
        fake_imgs = generator(noise, embeded_text)

        # Loss on real images.
        real_outputs, _ = discriminator(real_imgs, embeded_text)
        real_loss = bce_loss(real_outputs, real_labels)

        # Loss on fake images; detach() keeps this step from back-propagating
        # into the generator.
        fake_outputs, _ = discriminator(fake_imgs.detach(), embeded_text)
        fake_loss = bce_loss(fake_outputs, fake_labels)

        D_loss = real_loss + fake_loss
        d_losses.append(D_loss.item())

        # Update the discriminator weights.
        D_loss.backward()
        optimizer_D.step()

        # -------------- Train Generator ---
        optimizer_G.zero_grad()

        # Fresh noise input for the generator.
        noise = torch.randn((current_batch_size, latent_dim)).to(device)
        fake_imgs = generator(noise, embeded_text)

        # Pixel-level L1 loss between real and generated images.
        l1_loss_value = l1_loss(real_imgs, fake_imgs)

        # The fake images must NOT be detached here: the adversarial and
        # feature-matching losses have to back-propagate through the
        # discriminator into the generator. With detach() they contributed no
        # gradient and the generator learned from the L1 term only.
        fake_g_outputs, fake_features = discriminator(fake_imgs, embeded_text)
        # The real features are a fixed target, so no graph is needed for them.
        with torch.no_grad():
            _, real_features = discriminator(real_imgs, embeded_text)

        # Feature matching: compare the batch-averaged discriminator features
        # of fake and real images.
        activation_fake = torch.mean(fake_features, dim = 0)
        activation_real = torch.mean(real_features, dim = 0)
        l2_loss_value = l2_loss(activation_fake, activation_real)

        # Adversarial loss: push the discriminator to label fakes as real.
        GAN_loss = bce_loss(fake_g_outputs, real_labels)

        G_loss = GAN_loss + lambda_1*l1_loss_value + lambda_2*l2_loss_value
        g_losses.append(G_loss.item())

        # Update the generator weights (only optimizer_G steps here; gradients
        # accumulated in the discriminator are cleared by the next
        # optimizer_D.zero_grad()).
        G_loss.backward()
        optimizer_G.step()

    avg_d_loss = sum(d_losses)/len(d_losses)
    avg_g_loss = sum(g_losses)/len(g_losses)

    if (epoch+1) % 10 == 0:
        # Use the fixed, batched preview inputs so images are comparable across
        # epochs; un-batched noise[0] / embeded_text[0] cannot go through the
        # generator's BatchNorm1d and concatenation along dim=1.
        plot_output(generator, epoch, fixed_noise, plt_o_text_embeddings)

    print('Epoch [{}/{}] loss_D: {:.4f} loss_G: {:.4f} time: {:.2f}'.format(
        epoch+1, epochs,
        avg_d_loss,
        avg_g_loss,
        time.time() - epoch_time)
    )

# Save both models once training has finished.
model_save_path = "./model_save"
# torch.save does not create missing directories.
os.makedirs(model_save_path, exist_ok = True)
torch.save(generator.state_dict(), os.path.join(model_save_path,'generator.pth'))
torch.save(discriminator.state_dict(), os.path.join(model_save_path,'discriminator.pth'))

### Inference

In [None]:
# Put the generator in evaluation mode for inference.
generator.eval()

# One random noise vector and the caption embedding of sample 10, both with a
# leading batch axis.
noise = torch.randn(1, 100)
text_embedding = ds[10]['encoded_caption'][None, ...]

with torch.no_grad():
    test_images = generator(noise.to(device), text_embedding.to(device))

grid = torchvision.utils.make_grid(test_images.cpu(), normalize=True)
show_grid(grid)

## Test

In [None]:
# A DataLoader cannot be indexed (load_dataset[0] fails); pull its first batch
# through an iterator instead.
batch = next(iter(load_dataset))
data = batch

# First image and first caption of that batch.
image, caption = data['image'][0], data['caption'][0]

In [None]:
# Shape of the batch's image tensor.
print(data['image'].shape)

In [None]:
# Repeat the first caption embedding once per image in the batch, so every
# sample shares the same text. The repeat count is the actual batch size
# (previously a hard-coded 64), which keeps it valid for a smaller batch.
embeded_text = data['encoded_caption'][0].unsqueeze(0).repeat(data['image'].shape[0], 1)
print(embeded_text.shape)

In [None]:
# One noise vector of length latent_dim per image in the batch, created on the device.
noise = torch.randn((data['image'].shape[0], latent_dim), device = device)
print(noise.shape)


In [None]:
# Scratch cell: rebuilds generator-like layers by hand to trace tensor shapes.
# NOTE(review): this stack has five transposed convolutions, one fewer than
# Generator.upsampling, so its output resolution differs from the generator's.
feature_size = 128
# Reduces the caption embedding to 256 features.
text_encoder = nn.Sequential(
        nn.Linear(embeded_text.shape[1], 256),
        nn.BatchNorm1d(256),
        nn.LeakyReLU(negative_slope = 0.2, inplace = True),
        nn.Linear(256, 256)
    ).to(device)

upsampling = nn.Sequential(
        # input channels: noise width + 256 text features
        nn.ConvTranspose2d(noise.shape[1] + 256, feature_size*8, 4, 1, 0, bias = False),
        nn.BatchNorm2d(feature_size*8), # 1024 channels
        nn.ReLU(inplace = True),

        # channels: 1024 -> 512
        nn.ConvTranspose2d(feature_size*8, feature_size*4, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size*4), # 512 channels
        nn.ReLU(inplace = True),

        # channels: 512 -> 256
        nn.ConvTranspose2d(feature_size*4, feature_size*2, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size*2), # 256 channels
        nn.ReLU(inplace = True),

        # channels: 256 -> 128
        nn.ConvTranspose2d(feature_size*2, feature_size, 4, 2, 1, bias = False),
        nn.BatchNorm2d(feature_size), # 128 channels
        nn.ReLU(inplace = True),

        # channels: 128 -> 3
        nn.ConvTranspose2d(128, 3, 4, 2, 1, bias = False),
        nn.Tanh()
).to(device)
# Only text_encoder is switched to eval mode; upsampling stays in train mode.
text_encoder.eval()
# NOTE(review): this rebinds `text_encoder` from the module to its output tensor.
text_encoder = text_encoder(embeded_text)
# Build an input of shape (batch_size, noise width + 256, 1, 1).
concat_input = torch.cat([noise, text_encoder], dim = 1).unsqueeze(2).unsqueeze(2)
output = upsampling(concat_input)

In [None]:
# Shape of the tensor produced by the scratch upsampling stack.
print(output.shape)

In [None]:
num_channels = 3
feature_size = 128

# Stand-alone copy of the discriminator's image encoder: five stride-2
# convolutions, each followed by batch normalization and LeakyReLU(0.2).
# Channel widths: 3 -> 128 -> 128 -> 256 -> 512 -> 1024.
channel_plan = [num_channels, feature_size, feature_size,
                feature_size*2, feature_size*4, feature_size*8]
layers = []
for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
    layers += [
        nn.Conv2d(c_in, c_out, 4, 2, 1, bias = False),
        nn.BatchNorm2d(c_out),
        nn.LeakyReLU(negative_slope = 0.2, inplace = True),
    ]
model_test = nn.Sequential(*layers)



In [None]:
# NOTE(review): despite its name, image_1 holds the batch's caption embeddings;
# they are passed to the generator as its text input.
image_1 = data['encoded_caption']
fake_imgs = generator(noise.to(device), image_1.to(device))
print(fake_imgs.shape)

In [None]:
# Shape of the tensor currently bound to image_1.
print(image_1.shape)

In [None]:
# Generate a batch of fake images conditioned on the batch's caption embeddings.
text_embeddings = data['encoded_caption'].to(device)
fake_imgs = generator(noise.to(device), text_embeddings)
print(fake_imgs.shape)

# Run one real image through the convolutional stack. The stack's first layer
# is a Conv2d over `num_channels` image channels, so it must receive an image
# tensor; the caption embeddings fed previously raised a channel-mismatch error.
model_test = model_test.to(device)
image_1 = data['image'][0].to(device)
image_1 = image_1.unsqueeze(0)  # add the batch axis
print(image_1.shape)
a = model_test(image_1)
print(a.shape)