<a href="https://colab.research.google.com/github/divyadharshini1286/CodeGeneratorAI/blob/main/Untitled4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# --- Hyperparameters ---
# Sizes of the two vectors fed to the generator: random noise and text embedding.
latent_dim = 100
text_embedding_dim = 64

# Image geometry: square images, `image_size` pixels per side, RGB.
image_size = 64
channels = 3

# Optimisation settings.
batch_size = 32
epochs = 10
learning_rate = 2e-4

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Dummy Text-Image Dataset ---
# In a real scenario, you'd have a dataset of (text, image) pairs
class DummyTextImageDataset(Dataset):
    """Toy dataset of ``(text, image)`` pairs.

    Each text is reduced to a single float: the mean of its characters'
    Unicode code points. This is a placeholder encoding, not a real text
    representation.

    Args:
        text_image_pairs: Indexable collection of ``(text, image)`` tuples.
        transform: Optional callable applied to each image when it is fetched.
    """

    def __init__(self, text_image_pairs, transform=None):
        self.text_image_pairs = text_image_pairs
        self.transform = transform

    def __len__(self):
        """Return the number of ``(text, image)`` pairs."""
        return len(self.text_image_pairs)

    def __getitem__(self, idx):
        """Return ``(text_encoded, image)`` for the pair at ``idx``.

        ``text_encoded`` is a float32 tensor of shape ``(1,)``; ``image`` is
        the stored image after ``transform`` (if any) has been applied.
        """
        text, image = self.text_image_pairs[idx]
        if text:
            # Very basic text "encoding" (mean of character ordinals)
            codes = torch.tensor([ord(char) for char in text], dtype=torch.float32)
            text_encoded = codes.mean().unsqueeze(0)
        else:
            # The mean of an empty tensor is NaN, which would silently poison
            # every downstream computation; encode an empty text as 0 instead.
            text_encoded = torch.zeros(1, dtype=torch.float32)
        if self.transform is not None:
            image = self.transform(image)
        return text_encoded, image

# Create some dummy data.
# The images are already float tensors with values in [0, 1] (the range
# ToTensor would produce from a PIL image), so they only need normalising.
dummy_data = [(caption, torch.rand(channels, image_size, image_size))
              for caption in ("cat", "dog", "bird", "flower")]
# ToTensor only accepts PIL images / ndarrays and raises TypeError when handed
# a torch.Tensor, so it is deliberately not part of this pipeline.
# Normalize maps [0, 1] to [-1, 1], the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Normalize((0.5,) * channels, (0.5,) * channels)
])
train_dataset = DummyTextImageDataset(dummy_data, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# --- Simple Text Encoder ---
class SimpleTextEncoder(nn.Module):
    """Single linear layer that projects an encoded text to an embedding.

    Args:
        input_dim: Number of features in the encoded text.
        output_dim: Number of features in the produced embedding.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, text):
        """Return the linear projection of ``text``."""
        embedding = self.linear(text)
        return embedding

# --- Conditional Generator ---
class ConditionalGenerator(nn.Module):
    """Generator that maps a noise vector and a text embedding to an image.

    Noise and text embedding are each projected to a 256-channel seed feature
    map, concatenated along the channel axis (512 channels), and upsampled by
    three stride-2 transposed convolutions to an image with values in
    [-1, 1] (Tanh).

    Args:
        latent_dim: Number of features in the noise vector.
        text_embedding_dim: Number of features in the text embedding.
        channels: Number of channels in the generated image.
        image_size: Side length of the (square) generated image; must be a
            positive multiple of 8.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim, text_embedding_dim, channels, image_size):
        super().__init__()
        if image_size < 8 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles the
        # spatial size and there are three of them, so the seed map must be
        # image_size / 8 for the output to come out at image_size.
        self.init_size = image_size // 8
        self.linear_noise = nn.Linear(latent_dim, self.init_size * self.init_size * 256)
        self.linear_text = nn.Linear(text_embedding_dim, self.init_size * self.init_size * 256)

        self.conv_blocks = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, channels, 4, 2, 1),
            nn.Tanh()
        )

    def forward(self, noise, text_embedding):
        """Generate images from ``noise`` conditioned on ``text_embedding``.

        Both inputs are 2-D with the batch on the first axis; the result has
        shape ``(batch, channels, image_size, image_size)``.
        """
        noise_proj = self.linear_noise(noise).view(-1, 256, self.init_size, self.init_size)
        text_proj = self.linear_text(text_embedding).view(-1, 256, self.init_size, self.init_size)
        # Stack the two 256-channel maps into the 512 channels the first
        # transposed convolution expects.
        combined = torch.cat((noise_proj, text_proj), dim=1)
        return self.conv_blocks(combined)

# --- Conditional Discriminator ---
class ConditionalDiscriminator(nn.Module):
    """Discriminator scoring an image together with a text embedding.

    The text embedding is projected to a single ``image_size x image_size``
    map and appended to the image as an extra channel. Three stride-2
    convolutions halve the resolution each time, and a final convolution
    collapses the remaining map to one probability per sample (Sigmoid).

    Args:
        channels: Number of channels in the input image.
        image_size: Side length of the (square) input image; must be a
            positive multiple of 8.
        text_embedding_dim: Number of features in the text embedding.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, channels, image_size, text_embedding_dim):
        super().__init__()
        if image_size < 8 or image_size % 8 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 8, got {image_size}"
            )
        self.image_size = image_size

        self.conv_blocks = nn.Sequential(
            nn.Conv2d(channels + 1, 128, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # After three halvings the map is (image_size // 8) per side; a
            # kernel of exactly that size reduces it to 1x1, i.e. one score
            # per sample. A fixed 4x4 kernel would leave a 5x5 map for 64x64
            # inputs and yield 25 scores per sample.
            nn.Conv2d(512, 1, image_size // 8, 1, 0),
            nn.Sigmoid()
        )

        self.linear_text = nn.Linear(text_embedding_dim, image_size * image_size)

    def forward(self, image, text_embedding):
        """Return a ``(batch, 1)`` tensor of real/fake probabilities.

        ``image`` is ``(batch, channels, image_size, image_size)`` and
        ``text_embedding`` is ``(batch, text_embedding_dim)``.
        """
        # Reshape with the configured size so that an image of the wrong size
        # fails at the concatenation instead of silently changing the batch
        # dimension of the text map.
        text_proj = self.linear_text(text_embedding).view(-1, 1, self.image_size, self.image_size)
        combined = torch.cat((image, text_proj), dim=1)
        return self.conv_blocks(combined).view(-1, 1)

# --- Initialize Networks and Optimizers ---
# The three networks are built in this order: text encoder, generator, discriminator.
text_encoder = SimpleTextEncoder(1, text_embedding_dim).to(device)
generator = ConditionalGenerator(
    latent_dim, text_embedding_dim, channels, image_size
).to(device)
discriminator = ConditionalDiscriminator(
    channels, image_size, text_embedding_dim
).to(device)

# The text encoder shares the generator's optimizer; the discriminator has its own.
optimizer_G = optim.Adam(
    [*generator.parameters(), *text_encoder.parameters()],
    lr=learning_rate,
)
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate)
criterion = nn.BCELoss()

# Fixed noise and a fixed "cat" text encoding (four copies each), reused for
# every visualization so samples are comparable across epochs.
fixed_noise = torch.randn(4, latent_dim).to(device)
# Example fixed text embeddings
fixed_texts = (
    torch.tensor(list(map(ord, "cat")), dtype=torch.float32)
    .mean()
    .unsqueeze(0)
    .repeat(4, 1)
    .to(device)
)
fixed_text_embeddings = text_encoder(fixed_texts)

# --- Visualization Function (Very Basic) ---
def visualize_generated_images(generator, text_encoder, fixed_noise, fixed_texts, epoch):
    """Show up to four generated samples in a 2x2 grid.

    Args:
        generator: Network called as ``generator(noise, text_embedding)``.
        text_encoder: Network mapping ``fixed_texts`` to text embeddings.
        fixed_noise: Noise batch fed to the generator.
        fixed_texts: Encoded texts fed to ``text_encoder``.
        epoch: Zero-based epoch index; shown in the title as ``epoch + 1``.

    The generator is put in eval mode while sampling and is returned to
    whatever mode it was in on entry, even if sampling raises. The figure
    title always says 'cat', regardless of ``fixed_texts``.
    """
    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fixed_text_embeddings_vis = text_encoder(fixed_texts)
            fake_images = generator(fixed_noise, fixed_text_embeddings_vis)
            fake_images = (fake_images + 1) / 2.0  # Rescale to [0, 1] for display
            # Channels-first -> channels-last, the layout imshow expects.
            fake_images = fake_images.cpu().numpy().transpose(0, 2, 3, 1)
    finally:
        generator.train(was_training)

    fig, axes = plt.subplots(2, 2, figsize=(4, 4))
    for i, ax in enumerate(axes.flat):
        # Leave the cell blank when fewer than four samples were generated.
        if i < len(fake_images):
            ax.imshow(fake_images[i])
        ax.axis('off')
    plt.suptitle(f"Generated (Text: 'cat') at Epoch {epoch+1}")
    plt.tight_layout()
    plt.show()

# --- Training Loop (Very Basic) ---
steps_per_epoch = len(train_loader)
for epoch in range(epochs):
    for i, (text_encoded, real_images) in enumerate(train_loader):
        real_images = real_images.to(device)
        text_encoded = text_encoded.to(device)
        # The last batch can be smaller than the configured size. Use a
        # separate name so the global `batch_size` is not overwritten.
        current_batch_size = real_images.size(0)
        real_labels = torch.ones(current_batch_size, 1, device=device)
        fake_labels = torch.zeros(current_batch_size, 1, device=device)

        # --- Train Discriminator ---
        optimizer_D.zero_grad()
        # optimizer_D updates only the discriminator, so neither the text
        # encoder nor the generator needs a graph here.
        with torch.no_grad():
            text_embedding_d = text_encoder(text_encoded)
            noise = torch.randn(current_batch_size, latent_dim, device=device)
            fake_images = generator(noise, text_embedding_d)
        output_real = discriminator(real_images, text_embedding_d)
        loss_real = criterion(output_real, real_labels)
        output_fake = discriminator(fake_images, text_embedding_d)
        loss_fake = criterion(output_fake, fake_labels)
        loss_D = loss_real + loss_fake
        loss_D.backward()
        optimizer_D.step()

        # --- Train Generator ---
        optimizer_G.zero_grad()
        # Recompute the embedding with a fresh graph: backward() frees the
        # graph it walks, so reusing an embedding that loss_D already
        # back-propagated through would raise a RuntimeError here.
        text_embedding_g = text_encoder(text_encoded)
        noise = torch.randn(current_batch_size, latent_dim, device=device)
        fake_images = generator(noise, text_embedding_g)
        output_fake = discriminator(fake_images, text_embedding_g)
        # The generator is rewarded when the discriminator labels fakes real.
        loss_G = criterion(output_fake, real_labels)
        loss_G.backward()
        optimizer_G.step()

        # Report every 10 steps and at the end of each epoch, so epochs
        # shorter than 10 steps still produce output.
        if (i + 1) % 10 == 0 or (i + 1) == steps_per_epoch:
            print(f"Epoch [{epoch+1}/{epochs}], Step [{i+1}/{steps_per_epoch}], Loss D: {loss_D.item():.4f}, Loss G: {loss_G.item():.4f}")

    visualize_generated_images(generator, text_encoder, fixed_noise, fixed_texts, epoch)

print("Training finished!")

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

TypeError: pic should be PIL Image or ndarray. Got <class 'torch.Tensor'>