In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from nltk.tokenize import word_tokenize
import nltk
import torch.optim as optim
import torch.nn as nn

# Download the NLTK 'punkt' tokenizer data; the recorded cell output shows this is a
# no-op when the package is already up to date.
# NOTE(review): presumably required by word_tokenize imported above; newer NLTK
# releases may need 'punkt_tab' instead -- TODO confirm against the installed version.
nltk.download('punkt')

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\hardi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
def load_glove_embeddings(glove_file):
    """Parse a GloVe text file into a ``{word: vector}`` dictionary.

    Every non-empty line is expected to hold a token followed by its
    whitespace-separated vector components.

    Parameters:
        glove_file (str): Path to the GloVe text file (read as UTF-8).

    Returns:
        dict: Maps each token to a 1-D ``float32`` numpy array.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a vector component cannot be parsed as a float.
    """
    embeddings_index = {}
    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            # A blank or whitespace-only line yields no fields; indexing values[0]
            # would raise IndexError, so such lines are skipped.
            if not values:
                continue
            embeddings_index[values[0]] = np.asarray(values[1:], dtype='float32')
    return embeddings_index

# Load GloVe embeddings.
# The path is assembled with os.path.join instead of the literal 'goolove\glove...':
# '\g' is not a valid escape sequence (the interpreter warns about it, as echoed in this
# cell's output) and a hard-coded backslash only works as a separator on Windows.
# On Windows the resulting string is identical to the original one.
glove_file = os.path.join('goolove', 'glove.6B.300d.txt')
glove_model = load_glove_embeddings(glove_file)

  glove_file = 'goolove\glove.6B.300d.txt'


In [3]:
class StackGANDataset(Dataset):
    """Image/caption pairs read from a directory of per-sample sub-folders.

    Each sub-directory of ``root_dir`` that contains at least one ``.jpg`` file and
    one ``.txt`` file contributes one sample: the (alphabetically) first image and
    the first caption file. Captions are turned into a single vector by averaging
    the GloVe vectors of their known tokens.

    Parameters:
        root_dir (str): Directory holding one sub-directory per sample.
        transform (callable, optional): Applied to the resized PIL image.
        img_size (tuple): ``(width, height)`` passed to ``PIL.Image.resize``.
        glove_model (dict): Maps tokens to embedding vectors; required by
            ``text_to_embedding`` / ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None, img_size=(64, 64), glove_model=None):
        self.root_dir = root_dir
        self.transform = transform
        self.data = []  # list of (image_path, caption_path) tuples
        self.img_size = img_size
        self.glove_model = glove_model

        # Sorted so the sample order (and the file picked when a folder holds several
        # candidates) does not depend on the arbitrary order returned by os.listdir().
        for subdir in sorted(os.listdir(root_dir)):
            subdir_path = os.path.join(root_dir, subdir)
            if not os.path.isdir(subdir_path):
                continue
            try:
                # One directory scan serves both the image and the caption lookup.
                entries = sorted(os.listdir(subdir_path))
            except OSError as e:
                print(f"Unexpected error: {e}, in directory {subdir_path}")
                continue

            img_file = next((f for f in entries if f.endswith('.jpg')), None)
            txt_file = next((f for f in entries if f.endswith('.txt')), None)
            if img_file is None or txt_file is None:
                # Incomplete samples are reported and skipped (best effort, as before),
                # now with a message that says what is actually missing.
                print(f"Error: missing .jpg or .txt file, in directory {subdir_path}")
                continue
            self.data.append((os.path.join(subdir_path, img_file), os.path.join(subdir_path, txt_file)))

    def __len__(self):
        """Return the number of usable (image, caption) pairs found."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, text_embedding)`` where ``image`` is the resized (and, if a
            transform was given, transformed) RGB image and ``text_embedding`` is the
            numpy vector produced by ``text_to_embedding``.

        Raises:
            Exception: Any loading error is logged with the index and re-raised.
        """
        try:
            img_path, txt_path = self.data[idx]
            # convert() loads the pixel data, so the file handle can be released
            # right away instead of lingering until garbage collection.
            with Image.open(img_path) as img_handle:
                image = img_handle.convert('RGB')

            image = image.resize(self.img_size)  # Resize the image
            if self.transform:
                image = self.transform(image)

            # NOTE(review): no explicit encoding, so the platform default is used to
            # decode captions -- confirm how the caption files were written.
            with open(txt_path, 'r') as f:
                text = f.read().strip()

            text_embedding = self.text_to_embedding(text)
            return image, text_embedding
        except Exception as e:
            print(f"Error loading item at index {idx}: {e}")
            raise

    def text_to_embedding(self, text):
        """Average the GloVe vectors of the tokens in ``text``.

        The text is lower-cased and tokenised with ``word_tokenize``; tokens missing
        from ``glove_model`` are ignored. When no token is known, a zero vector with
        the same shape and dtype as the stored vectors is returned, so every sample
        has a consistent dtype.
        """
        words = word_tokenize(text.lower())
        embeddings = [self.glove_model[word] for word in words if word in self.glove_model]
        if embeddings:
            return np.mean(embeddings, axis=0)
        reference_vector = next(iter(self.glove_model.values()))
        return np.zeros_like(reference_vector)


In [4]:
# Stage-I preprocessing: resize to 64x64, convert to a tensor, then normalise every
# channel with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Create dataset and dataloader.
# NOTE(review): the dataset first resizes each image to img_size=(128, 128) and the
# transform then resizes again to 64x64 -- confirm the double resize is intended.
dataset = StackGANDataset(
    root_dir='stackgan_output',
    transform=transform,
    img_size=(128, 128),
    glove_model=glove_model,
)
# num_workers=0 keeps all loading in the main process.
dataloader = DataLoader(dataset=dataset, batch_size=64, shuffle=True, num_workers=0)


In [5]:

class ConditionalAugmentation(nn.Module):
    """Sample a conditioning code from a Gaussian predicted from the text embedding.

    One linear layer maps the ``text_dim`` embedding to ``2 * projected_dim`` values,
    which are split along dim 1 into a mean and a log-variance. The returned code is
    ``mu + eps * std`` with ``eps`` drawn from a standard normal, so the output has
    ``projected_dim`` features per sample and differs between calls.
    """

    def __init__(self, text_dim, projected_dim):
        super().__init__()
        # First half of the output is the mean, second half the log-variance.
        self.proj = nn.Linear(text_dim, projected_dim * 2)

    def forward(self, text_embedding):
        """Return a freshly sampled code of shape ``(batch, projected_dim)``."""
        projected = self.proj(text_embedding)
        mu, logvar = torch.chunk(projected, 2, dim=1)
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

class Generator_Stage1(nn.Module):
    """Stage-I generator: noise vector + text embedding -> 3-channel 64x64 image.

    The text embedding is turned into a conditioning code by ConditionalAugmentation,
    concatenated with the noise, projected to a 256x4x4 feature map and upsampled by
    four stride-2 transposed convolutions. The final Tanh bounds outputs to [-1, 1].
    """

    def __init__(self, noise_dim, text_dim, projected_dim):
        super().__init__()
        self.ca = ConditionalAugmentation(text_dim, projected_dim)
        self.fc = nn.Linear(noise_dim + projected_dim, 256 * 4 * 4)

        # kernel 4 / stride 2 / padding 1 doubles the spatial size at every layer.
        up = dict(kernel_size=4, stride=2, padding=1)
        layers = [
            nn.ConvTranspose2d(256, 128, **up),  # 4x4 -> 8x8
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.ConvTranspose2d(128, 64, **up),   # 8x8 -> 16x16
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, **up),    # 16x16 -> 32x32
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 3, **up),     # 32x32 -> 64x64
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, noise, text_embedding):
        """Generate images for a batch of noise vectors and text embeddings."""
        latent = torch.cat([noise, self.ca(text_embedding)], dim=1)
        feature_map = self.fc(latent).view(-1, 256, 4, 4)
        return self.main(feature_map)

class Discriminator_Stage1(nn.Module):
    """Stage-I discriminator scoring an image together with its text embedding.

    The image is encoded by three stride-2 convolutions into a 256-channel map; the
    text embedding is projected to 256 features and broadcast over that map. The
    concatenated 512-channel tensor is classified and the sigmoid scores are averaged
    over the remaining spatial positions.

    Parameters:
        text_dim (int): Size of the incoming text embedding.
        noise_std (float): Standard deviation of the Gaussian noise added to the
            input images while the module is in training mode (default 0.1, the
            previously hard-coded value). Set to 0 to disable.
    """

    def __init__(self, text_dim, noise_std=0.1):
        super(Discriminator_Stage1, self).__init__()
        self.noise_std = noise_std
        self.img_encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.3),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.text_proj = nn.Linear(text_dim, 256)
        self.classifier = nn.Sequential(
            nn.Conv2d(512, 512, 3, 1, 1),  # Changed kernel size to 3
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(512, 1, 4, 1, 0),
            nn.Sigmoid()
        )

    def forward(self, img, text_embedding):
        """Return one score in [0, 1] per sample, shape ``(batch, 1)``."""
        # Input noise is a training-time regulariser; like Dropout it is skipped in
        # eval mode so that inference is deterministic.
        if self.training and self.noise_std > 0:
            img = img + torch.randn_like(img) * self.noise_std
        img_features = self.img_encoder(img)

        text_features = self.text_proj(text_embedding).view(-1, 256, 1, 1)
        # expand() broadcasts the text features over the spatial grid without
        # materialising a copy; torch.cat below creates the combined tensor anyway.
        text_features = text_features.expand(-1, -1, img_features.size(2), img_features.size(3))

        features = torch.cat([img_features, text_features], dim=1)
        out = self.classifier(features)

        # Average over the spatial dimensions
        return out.view(out.size(0), -1).mean(dim=1, keepdim=True)


In [7]:
# import torch
import torch.autograd as autograd

def compute_gradient_penalty(D, real_samples, fake_samples, text_embeddings, lambda_gp=10.0):
    """
    Computes the gradient penalty for WGAN-GP.

    Random per-sample interpolations between real and fake samples are scored by D;
    the penalty is the mean squared deviation of the per-sample gradient L2 norm
    from 1, multiplied by ``lambda_gp``.

    Parameters:
        D (torch.nn.Module): The discriminator model, called as D(samples, text_embeddings).
        real_samples (torch.Tensor): Batch of real samples.
        fake_samples (torch.Tensor): Batch of fake samples (same shape as real_samples).
        text_embeddings (torch.Tensor): Text embeddings corresponding to the samples.
        lambda_gp (float): Coefficient for the gradient penalty term.

    Returns:
        torch.Tensor: Scalar penalty, ALREADY scaled by ``lambda_gp`` -- callers
        should add it to the loss as-is rather than multiplying by the coefficient again.
    """
    # One interpolation weight per sample; broadcasting spreads it over the other dims.
    alpha = torch.rand(real_samples.size(0), 1, 1, 1, device=real_samples.device)

    # detach() cuts any history carried by the inputs, making the interpolation a leaf
    # tensor we can differentiate with respect to (replaces the deprecated Variable).
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).detach()
    interpolates.requires_grad_(True)

    # Calculate discriminator's output for interpolated samples
    d_interpolates = D(interpolates, text_embeddings)

    # create_graph=True keeps the gradient computation differentiable so the penalty
    # can be back-propagated into D's parameters (it also retains the graph).
    gradients = autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
    )[0]

    # Per-sample L2 norm of the flattened gradient; reshape() also copes with
    # non-contiguous gradient tensors.
    gradients = gradients.reshape(gradients.size(0), -1)
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * lambda_gp

    return gradient_penalty


In [6]:

# Model initialization: Stage-I hyper-parameters and networks.
noise_dim = 100      # length of the noise vector fed to the generator
projected_dim = 128  # size of the conditioning code derived from the text embedding
text_dim = 300       # size of the text embedding (the 300-d GloVe vectors)

netG1 = Generator_Stage1(
    noise_dim=noise_dim,
    text_dim=text_dim,
    projected_dim=projected_dim,
)
netD1 = Discriminator_Stage1(text_dim=text_dim)

In [8]:
# Optimizers: Adam for both Stage-I networks, with a much smaller learning rate for
# the discriminator (1e-6) than for the generator (5e-5).
optimizerD1 = optim.Adam(params=netD1.parameters(), lr=1e-6, betas=(0.5, 0.999))
optimizerG1 = optim.Adam(params=netG1.parameters(), lr=5e-5, betas=(0.5, 0.999))

# Loss function: the BCE criterion is disabled; the training loop computes its loss
# directly from the discriminator outputs.
# criterion = nn.BCELoss()

# Use the GPU when one is available and move both networks onto the chosen device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
netG1.to(device)
netD1.to(device)
# criterion.to(device)

Discriminator_Stage1(
  (img_encoder): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Dropout(p=0.3, inplace=False)
    (6): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace=True)
  )
  (text_proj): Linear(in_features=300, out_features=256, bias=True)
  (classifier): Sequential(
    (0): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2, inplace=True)
    (3): Dropout(p=0.4, 

In [9]:
import torch.optim.lr_scheduler as lr_scheduler

# Exponential schedules: every scheduler step multiplies the discriminator learning
# rate by 0.95 (decay) and the generator learning rate by 1.05 (growth).
# NOTE(review): the .step() calls in the Stage-I training loop are commented out, so
# as written these schedulers never change the learning rates.
schedulerD = lr_scheduler.ExponentialLR(optimizer=optimizerD1, gamma=0.95)
schedulerG = lr_scheduler.ExponentialLR(optimizer=optimizerG1, gamma=1.05)

In [None]:

# # Training loop
# num_epochs = 20
# for epoch in range(num_epochs):
#     for i, (real_imgs, text_embeddings) in enumerate(dataloader):
#         real_imgs = real_imgs.to(device)
#         text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32).to(device)

#         valid = torch.ones(len(real_imgs), 1).to(device)
#         fake = torch.zeros(len(real_imgs), 1).to(device)

#         # ------------------
#         #  Train Generator
#         # ------------------

#         optimizerG1.zero_grad()

#         noise = torch.randn(len(real_imgs), noise_dim).to(device)
#         gen_imgs = netG1(noise, text_embeddings)

#         g_loss = criterion(netD1(gen_imgs, text_embeddings), valid)
#         g_loss.backward()
#         optimizerG1.step()

#         # ---------------------
#         #  Train Discriminator
#         # ---------------------

#         optimizerD1.zero_grad()

#         real_loss = criterion(netD1(real_imgs, text_embeddings), valid)
#         fake_loss = criterion(netD1(gen_imgs.detach(), text_embeddings), fake)
#         d_loss = (real_loss + fake_loss) / 2

#         d_loss.backward()
#         optimizerD1.step()

#         print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{len(dataloader)}] Loss D: {d_loss.item()}, loss G: {g_loss.item()}")


In [69]:
num_epochs = 25
k = 10  # discriminator updates per generator update
# NOTE(review): 0 switches the gradient penalty off, leaving the critic unconstrained.
# The losses recorded below saturate (D near -1, G near 0), so this value likely needs
# revisiting -- 10 is the default used by compute_gradient_penalty.
lambda_gp = 0

for epoch in range(num_epochs):
    for i, (real_imgs, text_embeddings) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        # The DataLoader already yields a tensor; .to() converts dtype and device without
        # the copy-construct UserWarning that torch.tensor(<tensor>) emits.
        text_embeddings = text_embeddings.to(device=device, dtype=torch.float32)

        # Train Discriminator multiple times
        for _ in range(k):
            optimizerD1.zero_grad()

            noise = torch.randn(len(real_imgs), noise_dim, device=device)
            # The generator is not updated in this step, so its autograd graph is not
            # needed: generate the fakes without tracking gradients.
            with torch.no_grad():
                fake_imgs = netG1(noise, text_embeddings)

            real_validity = netD1(real_imgs, text_embeddings)
            fake_validity = netD1(fake_imgs, text_embeddings)

            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
            if lambda_gp > 0:
                # compute_gradient_penalty scales by its own lambda_gp argument, so the
                # coefficient is passed in once instead of being applied a second time
                # here; the (double-backward) penalty is skipped entirely when disabled.
                d_loss = d_loss + compute_gradient_penalty(
                    netD1, real_imgs, fake_imgs, text_embeddings, lambda_gp=lambda_gp
                )

            d_loss.backward()
            optimizerD1.step()

        # Train Generator once
        optimizerG1.zero_grad()

        noise = torch.randn(len(real_imgs), noise_dim, device=device)
        fake_imgs = netG1(noise, text_embeddings)

        gen_validity = netD1(fake_imgs, text_embeddings)
        g_loss = -torch.mean(gen_validity)

        g_loss.backward()
        optimizerG1.step()

        # d_loss is the value from the last of the k discriminator updates.
        print(f"Epoch [{epoch+1}/{num_epochs}] Batch [{i+1}/{len(dataloader)}] Loss D: {d_loss.item()}, loss G: {g_loss.item()}")

    # schedulerD.step()
    # schedulerG.step()


  text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32).to(device)


Epoch [1/25] Batch [1/469] Loss D: -0.9999001622200012, loss G: -5.1940525736426935e-05
Epoch [1/25] Batch [2/469] Loss D: -0.9999231696128845, loss G: -1.6817999494378455e-05
Epoch [1/25] Batch [3/469] Loss D: -0.999898374080658, loss G: -6.989760004216805e-05
Epoch [1/25] Batch [4/469] Loss D: -0.9998424053192139, loss G: -0.0001344193733530119
Epoch [1/25] Batch [5/469] Loss D: -0.9998556971549988, loss G: -3.5350789403310046e-05
Epoch [1/25] Batch [6/469] Loss D: -0.9997780323028564, loss G: -0.00015075775445438921
Epoch [1/25] Batch [7/469] Loss D: -0.9995471239089966, loss G: -7.33557462808676e-05
Epoch [1/25] Batch [8/469] Loss D: -0.9996887445449829, loss G: -0.00021608785027638078
Epoch [1/25] Batch [9/469] Loss D: -0.9994522333145142, loss G: -0.00042240769835188985
Epoch [1/25] Batch [10/469] Loss D: -0.9993296265602112, loss G: -0.0001739794824970886
Epoch [1/25] Batch [11/469] Loss D: -0.9988899827003479, loss G: -0.0007432159618474543
Epoch [1/25] Batch [12/469] Loss D: -

In [71]:
# Persist the trained Stage-I weights (state dicts only) next to the notebook.
torch.save(obj=netG1.state_dict(), f='netG1.pth')
torch.save(obj=netD1.state_dict(), f='netD1.pth')

In [27]:
# Define Generator Stage-II
class Generator_Stage2(nn.Module):
    """Stage-II generator: refines a Stage-I image conditioned on the text embedding.

    The encoder reduces the input image to a 512-channel map at a quarter of its
    height and width. A conditioning code from ConditionalAugmentation is tiled over
    that map and concatenated with it, and the decoder upsamples the result through
    four stride-2 transposed convolutions. The final Tanh bounds outputs to [-1, 1].
    """

    def __init__(self, text_dim, projected_dim):
        super().__init__()
        self.ca = ConditionalAugmentation(text_dim, projected_dim)

        keep = dict(kernel_size=3, stride=1, padding=1)    # spatial size unchanged
        scale2 = dict(kernel_size=4, stride=2, padding=1)  # halves (conv) / doubles (transposed conv)

        encoder_layers = [
            nn.Conv2d(3, 128, **keep),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(128, 256, **scale2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, **scale2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.ConvTranspose2d(512 + projected_dim, 512, **keep),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(512, 256, **scale2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.4),
            nn.ConvTranspose2d(256, 128, **scale2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, **scale2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 3, **scale2),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, img, text_embedding):
        """Return the refined image batch for ``img`` and ``text_embedding``."""
        # The conditioning code is sampled before the encoder runs, matching the order
        # in which random numbers are drawn.
        cond_code = self.ca(text_embedding)
        img_features = self.encoder(img)
        height, width = img_features.size(2), img_features.size(3)
        # Tile the code over every spatial position of the encoded feature map.
        cond_map = cond_code.view(-1, cond_code.size(1), 1, 1).repeat(1, 1, height, width)
        return self.decoder(torch.cat([img_features, cond_map], dim=1))

# Define Discriminator Stage-II
class Discriminator_Stage2(nn.Module):
    """Stage-II discriminator scoring an image together with its text embedding.

    Four stride-2 convolutions encode the image into a 512-channel map; the text
    embedding is projected to 512 features and tiled over that map. The concatenated
    1024-channel tensor is classified and the sigmoid scores are averaged over the
    remaining spatial positions, giving one value per sample.
    """

    def __init__(self, text_dim):
        super().__init__()
        down = dict(kernel_size=4, stride=2, padding=1)  # halves height and width

        encoder_layers = [
            nn.Conv2d(3, 64, **down),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, **down),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, **down),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(256, 512, **down),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.img_encoder = nn.Sequential(*encoder_layers)

        self.text_proj = nn.Linear(text_dim, 512)

        classifier_layers = [
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Dropout(0.4),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*classifier_layers)

    def forward(self, img, text_embedding):
        """Return one score in [0, 1] per sample, shape ``(batch, 1)``."""
        img_features = self.img_encoder(img)
        height, width = img_features.size(2), img_features.size(3)

        projected_text = self.text_proj(text_embedding).view(-1, 512, 1, 1)
        text_map = projected_text.repeat(1, 1, height, width)

        scores = self.classifier(torch.cat([img_features, text_map], dim=1))

        # Average over the spatial dimensions
        flat_scores = scores.view(scores.size(0), -1)
        return flat_scores.mean(dim=1, keepdim=True)


In [26]:
# Model initialization
projected_dim = 128
text_dim = 300

netG2 = Generator_Stage2(text_dim, projected_dim)
netD2 = Discriminator_Stage2(text_dim)

# Load the Stage-I generator model.
netG1 = Generator_Stage1(noise_dim=100, text_dim=text_dim, projected_dim=projected_dim)
# Stage-I only supplies input images for Stage-II and has no optimizer here, so it is
# put in eval mode (no dropout, BatchNorm running statistics left untouched) and its
# parameters are frozen so no gradients are accumulated for them.
netG1.eval()
netG1.requires_grad_(False)
# map_location='cpu' lets a checkpoint saved from a CUDA run load on a CPU-only
# machine; the model is moved to the active device in a later cell.
netG1.load_state_dict(torch.load('netG1.pth', map_location='cpu'))

<All keys matched successfully>

In [28]:
# Adam optimizers for the Stage-II networks: 1e-6 for the discriminator and 9e-4 for
# the generator. netG1 gets no optimizer in this cell.
optimizerD2 = optim.Adam(params=netD2.parameters(), lr=1e-6, betas=(0.5, 0.999))
optimizerG2 = optim.Adam(params=netG2.parameters(), lr=9e-4, betas=(0.5, 0.999))

# Loss function: the BCE criterion is disabled; the training loop computes its loss
# directly from the discriminator outputs.
# criterion = nn.BCELoss()

# Use the GPU when one is available and move all three networks onto the chosen device.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
netG2.to(device)
netD2.to(device)
netG1.to(device)
# criterion.to(device)

Generator_Stage1(
  (ca): ConditionalAugmentation(
    (proj): Linear(in_features=300, out_features=256, bias=True)
  )
  (fc): Linear(in_features=228, out_features=4096, bias=True)
  (main): Sequential(
    (0): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.3, inplace=False)
    (4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (8): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): ConvTranspose2d(32, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (11): Tanh()
  )
)

In [29]:
num_epochs = 100
k = 10 # Number of discriminator updates per generator update
lambda_gp = 0  # Gradient penalty coefficient -- unused: no penalty is applied in this loop

# NOTE(review): `dataloader` was built with a 64x64 resize, while netG2 upsamples a
# 64x64 Stage-I image through four stride-2 transposed convolutions (after two stride-2
# convolutions), i.e. to 256x256. netD2 therefore appears to compare 64x64 real images
# with 256x256 fakes -- a Stage-II dataloader yielding 256x256 real images is probably
# needed. TODO confirm and add one.
for epoch in range(num_epochs):
    for i, (real_imgs, text_embeddings) in enumerate(dataloader):
        real_imgs = real_imgs.to(device)
        text_embeddings = text_embeddings.to(device).float()
        batch_size = real_imgs.size(0)

        # Train the discriminator k times
        for _ in range(k):
            ############################
            # (1) Update D network
            # Wasserstein-style loss: maximize D(real) - D(fake)
            ###########################
            netD2.zero_grad()

            # Fresh noise for every update (as in the Stage-I loop) so the discriminator
            # is not shown the same fake batch k times in a row.
            noise = torch.randn(batch_size, 100, device=device)

            # Neither generator is updated in this step, so both forward passes run
            # without building an autograd graph.
            with torch.no_grad():
                fake_imgs_stage1 = netG1(noise, text_embeddings)
                fake_imgs = netG2(fake_imgs_stage1, text_embeddings)

            real_validity = netD2(real_imgs, text_embeddings).view(-1)
            fake_validity = netD2(fake_imgs, text_embeddings).view(-1)

            # Gradient penalty
            # gradient_penalty = compute_gradient_penalty(netD2, real_imgs, fake_imgs, text_embeddings)

            # Wasserstein-style loss (no gradient penalty)
            d_loss = -torch.mean(real_validity) + torch.mean(fake_validity)
            d_loss.backward()
            optimizerD2.step()

        # Train the generator once
        ############################
        # (2) Update G network
        # Wasserstein-style loss: maximize D(fake)
        ###########################
        netG2.zero_grad()

        noise = torch.randn(batch_size, 100, device=device)
        # Stage-I has no optimizer here; skipping its graph saves memory and avoids
        # accumulating gradients on its parameters that are never cleared.
        with torch.no_grad():
            fake_imgs_stage1 = netG1(noise, text_embeddings)
        # Generate high-res fake images with Stage-II Generator
        fake_imgs = netG2(fake_imgs_stage1, text_embeddings)

        fake_validity = netD2(fake_imgs, text_embeddings).view(-1)
        g_loss = -torch.mean(fake_validity)
        g_loss.backward()
        optimizerG2.step()

        # d_loss is the value from the last of the k discriminator updates.
        print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "f"D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")


Epoch [1/100], Step [1/469], D Loss: -0.05874216556549072, G Loss: -0.5390399694442749
Epoch [1/100], Step [2/469], D Loss: -0.07477074861526489, G Loss: -0.5912109613418579
Epoch [1/100], Step [3/469], D Loss: -0.060169339179992676, G Loss: -0.652744710445404
Epoch [1/100], Step [4/469], D Loss: -0.12379884719848633, G Loss: -0.6423512697219849
Epoch [1/100], Step [5/469], D Loss: -0.08259940147399902, G Loss: -0.7153561115264893
Epoch [1/100], Step [6/469], D Loss: -0.058488428592681885, G Loss: -0.771367609500885
Epoch [1/100], Step [7/469], D Loss: -0.05841684341430664, G Loss: -0.7913858890533447
Epoch [1/100], Step [8/469], D Loss: -0.07330179214477539, G Loss: -0.8095190525054932
Epoch [1/100], Step [9/469], D Loss: -0.06748062372207642, G Loss: -0.8194131851196289
Epoch [1/100], Step [10/469], D Loss: -0.058943212032318115, G Loss: -0.8349072337150574
Epoch [1/100], Step [11/469], D Loss: -0.07659149169921875, G Loss: -0.8405243754386902
Epoch [1/100], Step [12/469], D Loss: -0

KeyboardInterrupt: 

In [30]:
# Persist the trained Stage-II weights (state dicts only) next to the notebook.
torch.save(obj=netG2.state_dict(), f='netG2.pth')
torch.save(obj=netD2.state_dict(), f='netD2.pth')

In [23]:
def load_glove_embeddings(glove_file):
    """Parse a GloVe text file into a ``{word: vector}`` dictionary.

    Every non-empty line is expected to hold a token followed by its
    whitespace-separated vector components.

    Parameters:
        glove_file (str): Path to the GloVe text file (read as UTF-8).

    Returns:
        dict: Maps each token to a 1-D ``float32`` numpy array.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If a vector component cannot be parsed as a float.
    """
    embeddings_index = {}
    with open(glove_file, 'r', encoding='utf-8') as f:
        for line in f:
            values = line.split()
            # A blank or whitespace-only line yields no fields; indexing values[0]
            # would raise IndexError, so such lines are skipped.
            if not values:
                continue
            embeddings_index[values[0]] = np.asarray(values[1:], dtype='float32')
    return embeddings_index

# The path is assembled with os.path.join instead of the literal 'goolove\glove...':
# '\g' is not a valid escape sequence (the interpreter warns about it, as echoed in this
# cell's output) and a hard-coded backslash only works as a separator on Windows.
# On Windows the resulting string is identical to the original one.
glove_file = os.path.join('goolove', 'glove.6B.300d.txt')
glove_model = load_glove_embeddings(glove_file)
text_dim = 300  # GloVe embedding dimension
projected_dim = 128
noise_dim = 100

  glove_file = 'goolove\glove.6B.300d.txt'


In [31]:

# Load the trained models.
# The device is chosen first so the checkpoints can be mapped onto it while loading:
# without map_location, a checkpoint saved from a CUDA run fails to load on a
# CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

netG1 = Generator_Stage1(noise_dim, text_dim, projected_dim)
netG1.load_state_dict(torch.load('netG1.pth', map_location=device))
netG1.eval()  # inference mode: dropout off, BatchNorm uses its running statistics

netG2 = Generator_Stage2(text_dim, projected_dim)
netG2.load_state_dict(torch.load('netG2.pth', map_location=device))
netG2.eval()

netG1.to(device)
netG2.to(device)

nltk.download('punkt')

# Function to get text embeddings using GloVe
def get_text_embedding(text, glove_model):
    """Embed ``text`` as the mean GloVe vector of its known tokens.

    The text is lower-cased and tokenised with ``word_tokenize``; tokens missing from
    ``glove_model`` are dropped. If nothing remains, a zero vector of the model's
    dimensionality is used instead.

    Returns:
        torch.Tensor: ``float32`` tensor of shape ``(1, dim)`` on the global ``device``.
    """
    tokens = word_tokenize(text.lower())
    known_vectors = [glove_model[token] for token in tokens if token in glove_model]
    if not known_vectors:
        # Dimensionality is taken from an arbitrary stored vector.
        dim = len(next(iter(glove_model.values())))
        pooled = np.zeros(dim)
    else:
        pooled = np.mean(known_vectors, axis=0)
    batched = torch.tensor(pooled, dtype=torch.float32).unsqueeze(0)  # add batch dimension
    return batched.to(device)

# Function to generate an image from text (saving to disk is currently disabled)
def generate_image_from_text(text, noise_dim, glove_model):
    """Run both generator stages for ``text`` and return the result as a PIL image.

    Uses the global ``netG1``, ``netG2`` and ``device``. A single noise vector of
    length ``noise_dim`` is drawn per call, so repeated calls give different images.

    Returns:
        PIL.Image.Image: The Stage-II output rescaled from [-1, 1] to 8-bit pixels.
    """
    text_embedding = get_text_embedding(text, glove_model)
    noise = torch.randn(1, noise_dim).to(device)

    with torch.no_grad():
        low_res = netG1(noise, text_embedding)
        high_res = netG2(low_res, text_embedding)

    # Drop the batch dimension and reorder channels-first -> channels-last for PIL.
    channels_first = high_res.squeeze().cpu().numpy()
    channels_last = np.transpose(channels_first, (1, 2, 0))
    # Rescale to [0, 255]; astype truncates the fractional part.
    pixels = ((channels_last + 1) / 2.0 * 255).astype(np.uint8)

    # Image.fromarray(pixels).save('generated_image.png') would persist the result.
    return Image.fromarray(pixels)

# Example usage: generate one image for a keyword-style caption and open it in the
# default image viewer.
text_input = "brown blue bull horn long tail"
generated_image = generate_image_from_text(
    text=text_input,
    noise_dim=noise_dim,
    glove_model=glove_model,
)
generated_image.show()

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\hardi\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
