In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

class CVAE(nn.Module):
    """Convolutional conditional VAE: images conditioned on a text embedding.

    The encoder halves the spatial size twice (two stride-2 convolutions), the
    flattened features are concatenated with the conditioning vector and mapped
    to ``mu``/``logvar``; the decoder mirrors this with two stride-2 transposed
    convolutions and a final sigmoid.

    :param image_dim: Height/width of the (square) input images. The decoder
        produces ``4 * (image_dim // 4)`` pixels per side, so ``image_dim``
        should be a multiple of 4 for reconstructions to match the input size.
    :param text_dim: Size of the conditioning (text embedding) vector.
    :param latent_dim: Size of the latent code ``z``.
    """

    # Channel count of the deepest encoder feature map / decoder input.
    _FEATURE_CHANNELS = 128

    def __init__(self, image_dim, text_dim, latent_dim):
        super(CVAE, self).__init__()
        # Spatial side length after the two stride-2 convolutions. Stored so
        # that decode() reshapes to the same size the encoder produced instead
        # of assuming 128x128 inputs.
        self.feature_size = image_dim // 4

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),  # H -> H/2
            nn.ReLU(),
            nn.Conv2d(64, self._FEATURE_CHANNELS, kernel_size=4, stride=2, padding=1),  # H/2 -> H/4
            nn.ReLU(),
            nn.Flatten()
        )
        flattened_dim = self._FEATURE_CHANNELS * self.feature_size * self.feature_size
        self.fc_mu = nn.Linear(flattened_dim + text_dim, latent_dim)
        self.fc_logvar = nn.Linear(flattened_dim + text_dim, latent_dim)

        # Decoder
        self.fc = nn.Linear(latent_dim + text_dim, flattened_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self._FEATURE_CHANNELS, 64, kernel_size=4, stride=2, padding=1),  # H/4 -> H/2
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # H/2 -> H
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=3, stride=1, padding=1),  # H -> H
            nn.Sigmoid()
        )

    @staticmethod
    def _match_batch(c, batch_size):
        """Broadcast a single 1-D conditioning vector to ``(batch_size, text_dim)``."""
        if c.dim() == 1:
            c = c.unsqueeze(0).repeat(batch_size, 1)
        return c

    def encode(self, x, c):
        """Return ``(mu, logvar)`` of the approximate posterior for images ``x`` given ``c``.

        :param x: Image batch of shape ``(B, 3, image_dim, image_dim)``.
        :param c: Conditioning of shape ``(B, text_dim)`` or a single ``(text_dim,)`` vector.
        """
        img_features = self.encoder(x)
        c = self._match_batch(c, x.size(0))
        combined = torch.cat([img_features, c], dim=1)
        mu = self.fc_mu(combined)
        logvar = self.fc_logvar(combined)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)`` (reparameterization trick)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Decode latent codes ``z`` conditioned on ``c`` into images in ``[0, 1]``.

        :param z: Latent batch of shape ``(B, latent_dim)``.
        :param c: Conditioning of shape ``(B, text_dim)`` or a single ``(text_dim,)`` vector.
        """
        c = self._match_batch(c, z.size(0))
        combined = torch.cat([z, c], dim=1)
        # Reshape to the encoder's final feature-map size (was hard-coded to
        # 32x32, which only matched image_dim == 128).
        out = self.fc(combined).view(
            z.size(0), self._FEATURE_CHANNELS, self.feature_size, self.feature_size
        )
        out = self.decoder(out)
        return out

    def forward(self, x, c):
        """Encode, sample and decode; returns ``(reconstruction, mu, logvar)``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        return self.decode(z, c), mu, logvar


def print_layer_shapes(model, input_tensors):
    """
    Print input/output shapes of each leaf layer using forward hooks.

    One forward pass is run under ``torch.no_grad()``. The hooks are always
    removed afterwards, even if that forward pass raises, so a failed shape
    check does not leave printing hooks attached to ``model``.

    :param model: Module to trace.
    :param input_tensors: Iterable of positional arguments passed to ``model``.
    """
    hooks = []

    def describe(value):
        # Tensors report their shape; containers are described element-wise;
        # anything else (None, ints, ...) is reported by type name instead of
        # raising AttributeError on a missing ``.shape``.
        if isinstance(value, torch.Tensor):
            return value.shape
        if isinstance(value, (list, tuple)):
            return [describe(item) for item in value]
        return type(value).__name__

    def hook_fn(module, input, output):
        print(f"{module.__class__.__name__}:")
        print(f"  Input shape: {[describe(i) for i in input]}")
        print(f"  Output shape: {describe(output)}")

    try:
        # Register hooks only on leaf layers (modules without children)
        for name, layer in model.named_modules():
            if len(list(layer.children())) == 0:
                hooks.append(layer.register_forward_hook(hook_fn))

        # Pass dummy inputs
        with torch.no_grad():
            model(*input_tensors)
    finally:
        # Remove hooks regardless of whether the forward pass succeeded
        for hook in hooks:
            hook.remove()


# Hyper-parameters used for the shape check
image_dim, text_dim, latent_dim = 128, 768, 256
cvae = CVAE(image_dim, text_dim, latent_dim)

# Random stand-in inputs: a batch of 32 three-channel images and 32 text embeddings
dummy_image = torch.randn(32, 3, image_dim, image_dim)
dummy_text_embedding = torch.randn(32, text_dim)

# Run one traced forward pass and report every leaf layer's shapes
print("Network architecture and layer shapes:")
print_layer_shapes(cvae, (dummy_image, dummy_text_embedding))


Network architecture and layer shapes:
Conv2d:
  Input shape: [torch.Size([32, 3, 128, 128])]
  Output shape: torch.Size([32, 64, 64, 64])
ReLU:
  Input shape: [torch.Size([32, 64, 64, 64])]
  Output shape: torch.Size([32, 64, 64, 64])
Conv2d:
  Input shape: [torch.Size([32, 64, 64, 64])]
  Output shape: torch.Size([32, 128, 32, 32])
ReLU:
  Input shape: [torch.Size([32, 128, 32, 32])]
  Output shape: torch.Size([32, 128, 32, 32])
Flatten:
  Input shape: [torch.Size([32, 128, 32, 32])]
  Output shape: torch.Size([32, 131072])
Linear:
  Input shape: [torch.Size([32, 131840])]
  Output shape: torch.Size([32, 256])
Linear:
  Input shape: [torch.Size([32, 131840])]
  Output shape: torch.Size([32, 256])
Linear:
  Input shape: [torch.Size([32, 1024])]
  Output shape: torch.Size([32, 131072])
ConvTranspose2d:
  Input shape: [torch.Size([32, 128, 32, 32])]
  Output shape: torch.Size([32, 64, 64, 64])
ReLU:
  Input shape: [torch.Size([32, 64, 64, 64])]
  Output shape: torch.Size([32, 64, 64, 64

In [76]:
class CocoDataset(Dataset):
    """COCO images paired with caption embeddings (DistilBERT variant).

    Each item is ``(image, embeddings, category)``:

    * ``image`` - the RGB image, passed through ``transform`` when one is set;
    * ``embeddings`` - one row per caption, either loaded from
      ``<embedding_dir>/<img_id>.pt`` or computed on the fly with DistilBERT;
    * ``category`` - the category name after ``group_by_category()`` has been
      called, otherwise ``None``.

    NOTE(review): ``category`` being ``None`` and images having different
    caption counts both look incompatible with the DataLoader's default
    collate function - batching probably needs a custom ``collate_fn``.

    NOTE: this cell relies on ``os``, ``torch``, ``Image``, ``COCO``,
    ``Dataset``, ``DistilBertTokenizer`` and ``DistilBertModel`` already being
    in the notebook namespace; it does not import them itself.

    :param root: Directory containing the image files.
    :param captions_file: Path to the COCO captions annotation JSON.
    :param instances_file: Path to the COCO instances annotation JSON.
    :param transform: Callable applied to the PIL image (may be falsy to skip).
    :param embedding_dir: Optional directory of precomputed ``<img_id>.pt`` files.
    """

    def __init__(self, root, captions_file, instances_file, transform, embedding_dir=None):
        self.root = root
        self.coco_captions = COCO(captions_file)
        self.coco_instances = COCO(instances_file)
        self.ids = list(self.coco_captions.imgToAnns.keys())
        self.transform = transform
        self.embedding_dir = embedding_dir

        # Lazy initialization of tokenizer and model
        self.tokenizer = None
        self.model = None

        self.grouped_data = None  # Store grouped data after group_by_category is called
        self._flat_cache = None   # (grouped_data object, flattened sample list)

    def _initialize_model(self):
        """Load the DistilBERT tokenizer/model once, on GPU when one is available."""
        if self.tokenizer is None or self.model is None:
            # Fall back to CPU instead of failing on hosts without CUDA.
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
            self.model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)

    def _flat_samples(self):
        """Return the cached ``[(category, img_id, captions), ...]`` view of ``grouped_data``.

        The list is rebuilt only when ``grouped_data`` is replaced by a
        different object, so indexing no longer re-flattens the whole dataset
        on every access. In-place edits to the dict are not detected; call
        ``group_by_category()`` again (or reassign ``grouped_data``) after
        such edits.
        """
        cache = getattr(self, "_flat_cache", None)
        if cache is None or cache[0] is not self.grouped_data:
            flattened = [
                (category, img_id, captions)
                for category, samples in self.grouped_data.items()
                for img_id, captions in samples
            ]
            cache = (self.grouped_data, flattened)
            self._flat_cache = cache
        return cache[1]

    def __len__(self):
        if self.grouped_data:
            return len(self._flat_samples())
        return len(self.ids)

    def _load_image(self, img_id):
        """Open the image for ``img_id`` as RGB and apply ``transform`` if set."""
        img_info = self.coco_captions.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image

    def _embed_captions(self, captions):
        """Mean-pooled DistilBERT embedding per caption, returned on CPU.

        All captions of one image are encoded in a single forward pass. Every
        caption is padded to the same fixed length (16 tokens), so each row
        should match what encoding that caption on its own would give.
        """
        self._initialize_model()
        # Follow the model's actual device rather than assuming "cuda".
        device = next(self.model.parameters()).device
        tokenized = self.tokenizer(
            list(captions), return_tensors="pt", padding="max_length",
            truncation=True, max_length=16,
        ).to(device)
        with torch.no_grad():
            hidden = self.model(**tokenized).last_hidden_state
        return hidden.mean(dim=1).cpu()

    def __getitem__(self, idx):
        """Return ``(image, embeddings, category)`` for position ``idx``.

        :raises FileNotFoundError: if ``embedding_dir`` is set but holds no
            file for the requested image.
        """
        if self.grouped_data:
            category, img_id, captions = self._flat_samples()[idx]
        else:
            img_id = self.ids[idx]
            ann = self.coco_captions.imgToAnns[img_id]
            captions = [a['caption'] for a in ann]
            category = None

        image = self._load_image(img_id)

        if self.embedding_dir:
            embedding_path = os.path.join(self.embedding_dir, f"{img_id}.pt")
            if not os.path.exists(embedding_path):
                raise FileNotFoundError(f"Embedding not found for image ID {img_id}")
            # NOTE(security): torch.load unpickles the file; only point
            # embedding_dir at files this notebook produced itself.
            embeddings = torch.load(embedding_path)
        else:
            embeddings = self._embed_captions(captions)

        return image, embeddings, category

    def group_by_category(self):
        """Group samples by object category and switch the dataset to grouped indexing.

        An image appears once under every category it has an instance
        annotation for. The result is stored on ``self.grouped_data`` (which
        changes what ``__len__``/``__getitem__`` iterate over) and returned.

        :return: Dict mapping category name to a list of ``(img_id, captions)``.
        """
        categories = self.coco_instances.loadCats(self.coco_instances.getCatIds())
        category_id_to_name = {cat['id']: cat['name'] for cat in categories}
        grouped_data = {cat['name']: [] for cat in categories}

        for img_id in self.ids:
            ann_ids = self.coco_instances.getAnnIds(imgIds=img_id)
            anns = self.coco_instances.loadAnns(ann_ids)
            category_ids = {ann['category_id'] for ann in anns}
            caption_ann_ids = self.coco_captions.getAnnIds(imgIds=img_id)
            caption_anns = self.coco_captions.loadAnns(caption_ann_ids)
            captions = [ann['caption'] for ann in caption_anns]

            for category_id in category_ids:
                category_name = category_id_to_name[category_id]
                grouped_data[category_name].append((img_id, captions))

        self.grouped_data = grouped_data
        self._flat_cache = None  # force the flattened view to be rebuilt
        return grouped_data


In [70]:
import os
from torchvision import transforms
from torch.utils.data import Dataset
from pycocotools.coco import COCO
from PIL import Image
from transformers import BertTokenizer, BertModel

class CocoDataset(Dataset):
    """COCO images paired with a BERT embedding of each image's first caption.

    Each item is ``(image, text_embedding)``.

    NOTE: this notebook also defines a DistilBERT-based ``CocoDataset`` in
    another cell, with a different constructor signature and a 3-tuple item.
    Whichever cell is executed last wins, so run the cells in the intended
    order. ``torch`` is used here but imported in a different cell.

    :param root: Directory containing the image files.
    :param captions_file: Path to the COCO captions annotation JSON.
    :param instances_file: Path to the COCO instances annotation JSON.
    :param transform: Callable applied to the PIL image (may be falsy to skip).
    """

    def __init__(self, root, captions_file, instances_file, transform):
        self.root = root
        self.coco_captions = COCO(captions_file)
        self.coco_instances = COCO(instances_file)  # Load instances file for categories
        self.ids = list(self.coco_captions.imgToAnns.keys())
        self.transform = transform
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert_model = BertModel.from_pretrained('bert-base-uncased')

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, text_embedding)`` for the image at position ``idx``."""
        img_id = self.ids[idx]
        ann = self.coco_captions.imgToAnns[img_id]
        caption = ann[0]['caption']  # Get the first annotation's caption

        img_info = self.coco_captions.loadImgs(img_id)[0]
        img_path = os.path.join(self.root, img_info['file_name'])
        image = Image.open(img_path).convert("RGB")
        if self.transform:
            image = self.transform(image)

        # Tokenize caption and get BERT embeddings
        tokenized = self.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=16)
        with torch.no_grad():
            # The encoder is stored as ``self.bert_model`` (``self.model`` was
            # never assigned and raised AttributeError on every item).
            hidden = self.bert_model(**tokenized).last_hidden_state
        # Mean over the token axis, then drop the batch axis of size 1.
        text_embedding = hidden.mean(dim=1).squeeze(0)

        return image, text_embedding

    def group_by_category(self):
        """
        Group all samples by category and fetch corresponding captions.

        An image appears once under every category it has an instance
        annotation for. This variant only returns the grouping; it does not
        change what ``__len__``/``__getitem__`` iterate over.

        :return: Dictionary with category names as keys and lists of (image_id, captions) as values.
        """
        # Get all categories from instances
        categories = self.coco_instances.loadCats(self.coco_instances.getCatIds())
        category_id_to_name = {cat['id']: cat['name'] for cat in categories}

        # Initialize dictionary to group samples by category
        grouped_data = {cat['name']: [] for cat in categories}

        # Iterate over all image IDs
        for img_id in self.ids:
            # Get annotations from instances file
            ann_ids = self.coco_instances.getAnnIds(imgIds=img_id)
            anns = self.coco_instances.loadAnns(ann_ids)

            # Get the associated category IDs
            category_ids = {ann['category_id'] for ann in anns}

            # Fetch captions from captions file
            caption_ann_ids = self.coco_captions.getAnnIds(imgIds=img_id)
            caption_anns = self.coco_captions.loadAnns(caption_ann_ids)
            captions = [ann['caption'] for ann in caption_anns]

            # Group by category name
            for category_id in category_ids:
                category_name = category_id_to_name[category_id]
                grouped_data[category_name].append((img_id, captions))

        return grouped_data



In [54]:
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from tqdm import tqdm

class Trainer:
    """Training loop for the conditional VAE with wandb logging and latent-space plots.

    NOTE: ``wandb`` is referenced but not imported in this cell; it must be
    imported and initialised (``wandb.init``) before ``train`` is called.

    :param model: Model exposing ``forward(images, text) -> (recon, mu, logvar)``
        and ``reparameterize(mu, logvar)``.
    :param dataloader: Iterable of training batches whose first two elements
        are ``(images, text_embeddings)``; extra elements are ignored.
    :param lr: Adam learning rate.
    :param device: Device the model and batches are moved to.
    :param val_dataloader: Optional loader the per-epoch sample image is drawn
        from; the training loader is used when it is ``None``.
    :param checkpoint_path: File the model's ``state_dict`` is written to after
        every epoch.
    """

    def __init__(self, model, dataloader, lr=0.001, device="cuda",
                 val_dataloader=None, checkpoint_path="cvae_model1.pth"):
        self.model = model.to(device)
        self.dataloader = dataloader
        self.val_dataloader = val_dataloader
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        self.device = device
        self.checkpoint_path = checkpoint_path

    @staticmethod
    def loss_function(recon_x, x, mu, logvar):
        """Return ``(reconstruction_loss, kl_loss)`` as separate scalars.

        NOTE(review): the reconstruction term is an MSE averaged over every
        element while the KL term is summed over batch and latent dimensions,
        so the two are on very different scales - confirm this weighting is
        intended before tuning.
        """
        recon_loss = nn.MSELoss()(recon_x, x)
        kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss, kl_loss

    def train(self, epochs):
        """Train for ``epochs`` epochs.

        After each epoch the model is checkpointed, the latent space is
        plotted, and one reconstructed sample is logged to wandb.
        """
        for epoch in range(epochs):
            self.model.train()
            train_loss = 0.0
            num_batches = 0
            latent_vectors = []
            labels = []  # Collect labels or categories for visualization

            for step, batch in enumerate(tqdm(self.dataloader)):
                # Take the first two fields so both (image, text) and
                # (image, text, category) batches are accepted.
                images, text_embeddings = batch[0], batch[1]
                images = images.to(self.device)
                text_embeddings = text_embeddings.to(self.device)

                self.optimizer.zero_grad()
                recon_images, mu, logvar = self.model(images, text_embeddings)

                # Draw a latent sample purely for visualization; no graph needed.
                with torch.no_grad():
                    z = self.model.reparameterize(mu, logvar)
                latent_vectors.append(z.detach().cpu())
                labels.append(text_embeddings.detach().cpu())  # Replace with your label logic

                recon_loss, kl_loss = self.loss_function(recon_images, images, mu, logvar)
                loss = recon_loss + kl_loss
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                num_batches += 1

                # Log training metrics to wandb every N steps
                if step % 10 == 0:
                    wandb.log({
                        "train/reconstruction_loss": recon_loss.item(),
                        "train/kl_loss": kl_loss.item(),
                        "train/total_loss": loss.item(),
                        "train/step": step + epoch * len(self.dataloader)
                    })

            if num_batches == 0:
                # Nothing to average, concatenate or plot for an empty loader.
                print(f"Epoch {epoch + 1}/{epochs}: dataloader produced no batches, skipping.")
                continue

            # Save the model owned by this trainer (not a notebook global).
            torch.save(self.model.state_dict(), self.checkpoint_path)
            print("Model parameters saved successfully.")

            # Visualize latent space after each epoch
            self.visualize_latent_space(
                torch.cat(latent_vectors, dim=0),
                torch.cat(labels, dim=0),  # Modify if you have actual labels
                epoch,
            )

            print(f"Epoch {epoch + 1}/{epochs}, Loss: {train_loss / num_batches:.4f}")

            # Log one generated sample per epoch, preferring validation data.
            sample_loader = self.val_dataloader if self.val_dataloader is not None else self.dataloader
            sample_batch = next(iter(sample_loader))
            generated_image = self.generate_sample(sample_batch[0][0], sample_batch[1][0])
            wandb.log({"generated_image": wandb.Image(generated_image)})

    def generate_sample(self, image, text_embedding):
        """Reconstruct a single ``image`` given its ``text_embedding``; returns a CPU tensor.

        Switches the model to eval mode; ``train`` switches it back at the
        start of the next epoch.
        """
        self.model.eval()
        with torch.no_grad():
            recon_image, _, _ = self.model(image.unsqueeze(0).to(self.device),
                                           text_embedding.unsqueeze(0).to(self.device))
        return recon_image.squeeze(0).cpu()

    @staticmethod
    def visualize_latent_space(latent_vectors, labels, epoch):
        """Scatter-plot the latent vectors in 2-D, coloured by ``labels``.

        Latent vectors with more than two dimensions are reduced with t-SNE.
        ``labels`` may hold one scalar per sample or a vector per sample;
        vectors (e.g. text embeddings) are reduced to their first principal
        component, because a scatter plot needs one colour value per point.
        The figure is saved as ``latent_space_epoch_<n>.png`` and closed.
        """
        points = latent_vectors.detach().cpu().numpy()
        if points.shape[1] > 2:  # Dimensionality reduction if latent dim > 2
            points = TSNE(n_components=2).fit_transform(points)

        # One row per sample, however many trailing dimensions labels has.
        colors = labels.detach().cpu().numpy().reshape(labels.shape[0], -1)
        if colors.shape[1] > 1:
            colors = PCA(n_components=1).fit_transform(colors)
        colors = colors.reshape(-1)

        # Plot latent space
        fig = plt.figure(figsize=(8, 8))
        scatter = plt.scatter(points[:, 0], points[:, 1], c=colors, cmap='viridis', s=10)
        plt.colorbar(scatter, label='Text Embedding Labels')  # Modify label as needed
        plt.title(f"Latent Space Visualization - Epoch {epoch + 1}")
        plt.xlabel("Latent Dim 1")
        plt.ylabel("Latent Dim 2")
        plt.grid(True)
        plt.savefig(f"latent_space_epoch_{epoch + 1}.png")  # Save plot for each epoch
        plt.show()
        plt.close(fig)  # avoid accumulating one open figure per epoch



In [4]:
import matplotlib.pyplot as plt

class Inference:
    """Thin wrapper that runs a trained conditional VAE on single samples."""

    def __init__(self, model, device="cuda"):
        self.device = device
        self.model = model.to(device)

    def generate_image(self, image, text_embedding):
        """Encode one image with its text embedding, sample a latent, and decode it.

        Both inputs are unbatched; a batch axis is added internally and removed
        from the returned image, which stays on ``self.device``.
        """
        self.model.eval()
        batched_image = image.to(self.device).unsqueeze(0)
        batched_text = text_embedding.to(self.device).unsqueeze(0)

        with torch.no_grad():
            mu, logvar = self.model.encode(batched_image, batched_text)
            latent = self.model.reparameterize(mu, logvar)
            decoded = self.model.decode(latent, batched_text)
            generated_image = decoded.squeeze(0)

        return generated_image

    @staticmethod
    def visualize_image(image_tensor):
        """Show a channels-first image tensor with matplotlib, axes hidden."""
        channels_last = image_tensor.permute(1, 2, 0).cpu().numpy()
        plt.imshow(channels_last)
        plt.axis("off")
        plt.show()


In [81]:
import os
from torch.utils.data import DataLoader
from torchvision import transforms
import torch
from tqdm import tqdm
#from coco_dataset import CocoDataset  # Ensure the updated CocoDataset is imported

# Restrict this process to the GPU with index 1 via CUDA_VISIBLE_DEVICES.
# NOTE(review): presumably this only takes effect if it is set before CUDA is
# first initialised in this kernel; torch is already imported above and earlier
# cells may have touched the GPU - confirm, or move this to the first cell.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

def precompute_embeddings_gpu(dataset, output_dir, batch_size=64):
    """
    Precompute caption embeddings (on GPU when available) and save one file per image.

    Captions from consecutive images are pooled into batches of roughly
    ``batch_size`` captions and encoded in a single forward pass; each image's
    rows are then written to ``<output_dir>/<img_id>.pt``. An image's captions
    are never split across batches, so a batch may exceed ``batch_size`` by up
    to one image's worth of captions.

    Captions are padded to the fixed length used by ``CocoDataset.__getitem__``
    (``padding="max_length"``, 16 tokens). The mean over the token axis
    includes the padded positions, so fixed-length padding is what keeps the
    cached embeddings independent of which captions share a batch and
    consistent with the on-the-fly path.

    :param dataset: Dataset object providing ``_initialize_model()``,
        ``tokenizer``, ``model``, ``ids`` and ``coco_captions``
    :param output_dir: Directory to save embeddings
    :param batch_size: Number of captions processed in a batch
    """
    os.makedirs(output_dir, exist_ok=True)

    # Initialize model and tokenizer
    dataset._initialize_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dataset.model.to(device)  # Move model to GPU

    def flush(pending):
        """Encode all captions in ``pending`` at once and save each image's rows."""
        flat_captions = [caption for _, captions in pending for caption in captions]
        tokenized = dataset.tokenizer(
            flat_captions, return_tensors="pt", padding="max_length",
            truncation=True, max_length=16,
        ).to(device)

        with torch.no_grad():
            # Compute embeddings on the device, move back to CPU for saving
            pooled = dataset.model(**tokenized).last_hidden_state.mean(dim=1).cpu()

        offset = 0
        for img_id, captions in pending:
            # clone() so only this image's rows are serialized, not the whole
            # batch tensor that the slice is a view of.
            rows = pooled[offset:offset + len(captions)].clone()
            torch.save(rows, os.path.join(output_dir, f"{img_id}.pt"))
            offset += len(captions)

    pending = []
    pending_captions = 0
    for img_id in tqdm(dataset.ids, desc="Precomputing embeddings"):
        captions = [ann['caption'] for ann in dataset.coco_captions.imgToAnns[img_id]]
        if not captions:
            # Nothing to encode; no file is written for this image.
            continue

        pending.append((img_id, captions))
        pending_captions += len(captions)
        if pending_captions >= batch_size:
            flush(pending)
            pending = []
            pending_captions = 0

    if pending:
        flush(pending)


# Dataset locations and image preprocessing
root = "./coco/images/train2017"
annotation_file = "./coco/annotations/captions_train2017.json"
instances_file = "./coco/annotations/instances_train2017.json"
embedding_dir = "./precomputed_embeddings"
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])


def collate_coco_batch(samples):
    """Collate ``(image, embeddings, category)`` samples into one batch.

    The DataLoader's default collate cannot be used here: ``category`` is
    ``None`` for an ungrouped dataset, and images may carry different numbers
    of captions, so their embedding tensors cannot be stacked as-is.

    * images are stacked into one tensor;
    * embeddings are truncated to the smallest caption count in the batch and
      stacked into ``(batch, captions, dim)``;
    * categories are returned as a plain list.
    """
    images, embeddings, categories = zip(*samples)
    min_captions = min(e.shape[0] for e in embeddings)
    return (
        torch.stack(images),
        torch.stack([e[:min_captions] for e in embeddings]),
        list(categories),
    )


# Build the embedding cache only when the directory is missing or empty.
# NOTE: a partially filled directory (e.g. after an interrupted run) is treated
# as complete; delete it to force a rebuild, otherwise missing images raise
# FileNotFoundError when they are loaded.
if not os.path.exists(embedding_dir) or len(os.listdir(embedding_dir)) == 0:
    dataset_for_precompute = CocoDataset(root, annotation_file, instances_file, transform=transform)
    print("Precomputing embeddings...")
    precompute_embeddings_gpu(dataset_for_precompute, embedding_dir)

dataset = CocoDataset(root, annotation_file, instances_file, transform=transform, embedding_dir=embedding_dir)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_coco_batch)

# Smoke test: inspect the first batch only
for images, text_embeddings, category in dataloader:
    print("Batch image shape:", images.shape)
    print("Batch text embedding shape:", text_embeddings.shape)
    print("Category:", category)
    break


# Disabled draft of the training / inference driver, kept as a bare string.
'''
# Model setup
image_dim = config.image_dim
text_dim = config.text_dim
latent_dim = config.latent_dim
cvae = CVAE(image_dim=image_dim, text_dim=text_dim, latent_dim=latent_dim)

# Training
trainer = Trainer(model=cvae, dataloader=dataloader, lr=config.learning_rate, device="cuda")
trainer.train(epochs=config.epochs)

torch.save(cvae.state_dict(), "cvae_model.pth")
print("Model parameters saved successfully.")

# Inference
inference = Inference(model=cvae, device="cuda")
sample_image, sample_text_embedding = dataset[0]
generated_image = inference.generate_image(sample_image, sample_text_embedding)

# Visualize
inference.visualize_image(generated_image)
'''


loading annotations into memory...
Done (t=0.74s)
creating index...
index created!
loading annotations into memory...


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7f2f48123c70>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/ipkernel.py", line 788, in _clean_thread_parent_frames
    if phase != "start":
KeyboardInterrupt: 


Done (t=21.79s)
creating index...


KeyboardInterrupt: 