# Introduction to Deep Learning 
## Professor Vahid Tarokh
### Students: Ashley, John, Ryan, Julian
#### Team Project
#### Concatenation C-VAE on COCO, dimensions 64x64, 128x128, 256x256




##### Disclaimer: ChatGPT was used for creating the solution to the project assignment.
##### Disclaimer: Solution partly based on HW5.

In [None]:

# Install the notebook's extra dependencies: CLIP (from GitHub) with its text
# helpers (ftfy, regex, tqdm), torchmetrics for the SSIM/FID metrics imported
# below, and torch-fidelity.
# NOTE(review): versions are not pinned, so a re-run may pick up newer releases.
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git
! pip install --upgrade torchmetrics
! pip install torch-fidelity

Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/44.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-7ok2avxd
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-7ok2avxd
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369489

In [None]:


import numpy as np
import matplotlib.pyplot as plt
import os
from pycocotools.coco import COCO
import requests
from tqdm import tqdm
import json
import clip
from typing import List, Union
import random
import zipfile
from PIL import Image

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision.datasets import CocoCaptions
from torchvision import transforms
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.utils import make_grid

from torchmetrics.image import StructuralSimilarityIndexMeasure
from torchmetrics.image.fid import FrechetInceptionDistance

In [None]:
# Load CLIP (ViT-B/32) once up front, on the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# NOTE(review): a failed load is only printed, not re-raised, so `model` and
# `preprocess` stay undefined in that case. `model` is also rebound to the
# cVAE in the training cell further down.
try:
    model, preprocess = clip.load("ViT-B/32", device=device, jit=False)  # Set jit=False for better stability
    print("CLIP loaded successfully")
except Exception as e:
    print(f"Error loading CLIP: {e}")

Using device: cuda


100%|████████████████████████████████████████| 338M/338M [00:02<00:00, 138MiB/s]


CLIP loaded successfully


## Util functions

In [None]:
def generate_from_text(model, text_prompt, device="cuda"):
    """Generate and display one image for a free-text prompt.

    The prompt is split into sentences, the sentences are embedded through
    ``model.encode_condition`` and the resulting embedding conditions the
    decoder on a latent vector drawn from a standard normal distribution.

    Args:
        model: Conditional VAE exposing ``encode_condition``, ``decode`` and
            ``latent_dim``.
        text_prompt (str): Prompt text; may contain several sentences.
        device: Device on which the latent sample and the text embedding are
            placed before decoding.

    Returns:
        torch.Tensor: The generated image on the CPU, with the batch axis
        removed and the channel axis moved last (ready for ``plt.imshow``).
    """
    # Local import: `re` is not part of the notebook's import cell, so the
    # original call to re.split raised NameError.
    import re

    model.eval()
    with torch.no_grad():
        # Split on sentence-ending punctuation. Drop empty fragments (e.g. from
        # a trailing space) so they are not averaged into the condition, but
        # fall back to the raw prompt if nothing is left.
        sentences = [s for s in re.split(r'(?<=[.!?])\s+', text_prompt) if s.strip()]
        if not sentences:
            sentences = [text_prompt]

        # Encode text with CLIP
        text_embedding = model.encode_condition(sentences)
        text_embedding = text_embedding.to(device)

        # Sample from latent space
        z = torch.randn(1, model.latent_dim).to(device)

        # Generate image
        generated_img = model.decode(z, text_embedding)

        # Convert to displayable format: drop the batch axis, channels last
        generated_img = generated_img.squeeze(0).cpu().permute(1, 2, 0)

        # Plot
        plt.figure(figsize=(5, 10))
        plt.imshow(generated_img.numpy())
        plt.axis('off')
        plt.title(f'Generated image for: "{text_prompt}"')
        plt.show()

        return generated_img

## Dataloading

In [None]:
def download_coco_subset(num_images=1000):
    """Download a random subset of COCO train2017 images plus the caption annotations.

    The annotation archive is fetched and extracted only when the train
    caption file is not already present. Images that already exist in
    ``coco_images`` are not downloaded again.

    Args:
        num_images (int): Number of images to sample; capped at the number of
            images listed in the annotation file.

    Returns:
        tuple[str, str]: ``(image_dir, annotation_file)`` where
        ``annotation_file`` is the path of the extracted train caption JSON.

    Raises:
        requests.HTTPError: If the annotation archive cannot be downloaded.
    """
    image_dir = 'coco_images'
    annotation_root = 'coco_annotations'
    # The archive extracts into an inner "annotations/" folder, so this is the
    # path that has to be checked, loaded and returned. The original checked
    # and returned 'coco_annotations/captions_train2017.json', which never exists.
    annotation_path = os.path.join(annotation_root, 'annotations', 'captions_train2017.json')

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(annotation_root, exist_ok=True)

    annotation_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    if not os.path.exists(annotation_path):
        print("Downloading annotations...")
        response = requests.get(annotation_url, timeout=60)
        response.raise_for_status()  # do not write an error page to disk as a zip
        with open('annotations.zip', 'wb') as f:
            f.write(response.content)

        with zipfile.ZipFile('annotations.zip', 'r') as zip_ref:
            zip_ref.extractall(annotation_root)

    # Relative path instead of the hard-coded '/content/...' so the function
    # also works outside Colab (in Colab the working directory is /content).
    coco = COCO(annotation_path)

    img_ids = coco.getImgIds()
    # random.sample raises ValueError when asked for more items than exist.
    selected_ids = random.sample(img_ids, min(num_images, len(img_ids)))

    print(f"Downloading {len(selected_ids)} images...")
    for img_id in tqdm(selected_ids):
        # Get image info
        img_info = coco.loadImgs(img_id)[0]
        img_url = img_info['coco_url']
        file_name = img_info['file_name']
        file_path = os.path.join(image_dir, file_name)

        if os.path.exists(file_path):
            continue

        # Best effort: a failed image is reported and skipped, not fatal.
        try:
            response = requests.get(img_url, timeout=30)
            if response.status_code == 200:
                with open(file_path, 'wb') as f:
                    f.write(response.content)
            else:
                print(f"Error downloading {file_name}: HTTP {response.status_code}")
        except Exception as e:
            print(f"Error downloading {file_name}: {e}")

    print("Download complete!")
    return image_dir, annotation_path

# Fetch 5000 randomly sampled training images (see the progress log below for the run time).
image_dir, annotation_file = download_coco_subset(num_images=5000)

Downloading annotations...
loading annotations into memory...
Done (t=1.36s)
creating index...
index created!
Downloading 5000 images...


100%|██████████| 5000/5000 [35:27<00:00,  2.35it/s]

Download complete!





In [None]:
def download_coco_val_subset(num_images=1000):
    """Download a random subset of COCO val2017 images plus the caption annotations.

    The annotation archive is fetched and extracted only when the validation
    caption file is not already present. Images that already exist in
    ``coco_images`` are not downloaded again.

    Args:
        num_images (int): Number of images to sample; capped at the number of
            images listed in the annotation file.

    Returns:
        tuple[str, str]: ``(image_dir, annotation_file)`` where
        ``annotation_file`` is the path of the extracted val caption JSON.

    Raises:
        requests.HTTPError: If the annotation archive cannot be downloaded.
    """
    image_dir = 'coco_images'
    annotation_root = 'coco_annotations'
    # The archive extracts into an inner "annotations/" folder. The original
    # checked for the *train* file at a path that never exists, so the archive
    # was downloaded again on every call, and it returned a non-existent path.
    annotation_path = os.path.join(annotation_root, 'annotations', 'captions_val2017.json')

    os.makedirs(image_dir, exist_ok=True)
    os.makedirs(annotation_root, exist_ok=True)

    annotation_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    if not os.path.exists(annotation_path):
        print("Downloading annotations...")
        response = requests.get(annotation_url, timeout=60)
        response.raise_for_status()  # do not write an error page to disk as a zip
        with open('annotations.zip', 'wb') as f:
            f.write(response.content)

        with zipfile.ZipFile('annotations.zip', 'r') as zip_ref:
            zip_ref.extractall(annotation_root)

    # Relative path instead of the hard-coded '/content/...' so the function
    # also works outside Colab (in Colab the working directory is /content).
    coco = COCO(annotation_path)

    img_ids = coco.getImgIds()
    # random.sample raises ValueError when asked for more items than exist.
    selected_ids = random.sample(img_ids, min(num_images, len(img_ids)))

    print(f"Downloading {len(selected_ids)} images...")
    for img_id in tqdm(selected_ids):
        img_info = coco.loadImgs(img_id)[0]
        img_url = img_info['coco_url']
        file_name = img_info['file_name']
        file_path = os.path.join(image_dir, file_name)

        if os.path.exists(file_path):
            continue

        # Best effort: a failed image is reported and skipped, not fatal.
        try:
            response = requests.get(img_url, timeout=30)
            if response.status_code == 200:
                with open(file_path, 'wb') as f:
                    f.write(response.content)
            else:
                print(f"Error downloading {file_name}: HTTP {response.status_code}")
        except Exception as e:
            print(f"Error downloading {file_name}: {e}")

    print("Download complete!")
    return image_dir, annotation_path

# Fetch 1000 randomly sampled validation images into the same image folder.
image_dir_val, annotation_file_val = download_coco_val_subset(num_images=1000)

Downloading annotations...
loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
Downloading 1000 images...


100%|██████████| 1000/1000 [07:01<00:00,  2.37it/s]

Download complete!





In [None]:
class COCODatasetWithClip(Dataset):
    """COCO caption dataset yielding ``(image, caption_embedding)`` pairs.

    Only images that are listed in the annotation file *and* present in
    ``root_dir`` are used. For each of them, all captions are encoded once
    with CLIP ViT-B/32 at construction time and averaged into one embedding.

    Args:
        root_dir (str): Directory containing the downloaded image files.
        annotation_file (str): Path to a COCO captions annotation JSON.
        transform (callable, optional): Applied to the RGB PIL image; when
            omitted the PIL image itself is returned.
    """

    def __init__(self, root_dir, annotation_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.coco = COCO(annotation_file)

        # Keep only the annotated images whose file exists on disk.
        self.ids = []
        for img_id in list(self.coco.imgs.keys()):
            img_info = self.coco.loadImgs(img_id)[0]
            file_path = os.path.join(root_dir, img_info['file_name'])
            if os.path.exists(file_path):
                self.ids.append(img_id)

        print(f"Found {len(self.ids)} images in directory")

        print("Loading CLIP...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()

        print("Pre-encoding captions...")
        self.encoded_captions = {}
        for img_id in tqdm(self.ids):
            ann_ids = self.coco.getAnnIds(imgIds=img_id)
            anns = self.coco.loadAnns(ann_ids)
            captions = [ann['caption'] for ann in anns]

            with torch.no_grad():
                # truncate=True: an over-long caption is cut to CLIP's context
                # length instead of raising and aborting the whole pre-encoding.
                text_inputs = clip.tokenize(captions, truncate=True).to(device)
                text_features = self.clip_model.encode_text(text_inputs)
                # Convert to float32 before averaging
                text_features = text_features.float()
                # Keep the cached embedding on the CPU so the dataset does not
                # hold GPU memory; consumers move each batch to their device.
                self.encoded_captions[img_id] = text_features.mean(dim=0).cpu()

    def __len__(self):
        """Return the number of usable images."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return ``(image, caption_embedding)`` for the ``idx``-th usable image."""
        img_id = self.ids[idx]

        img_info = self.coco.loadImgs(img_id)[0]
        image_path = os.path.join(self.root_dir, img_info['file_name'])
        # The context manager closes the file handle; convert() returns a
        # fully loaded copy that stays valid afterwards.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        caption_embedding = self.encoded_captions[img_id].float()

        return image, caption_embedding

# Image preprocessing: resize to 64x64 and convert to a tensor.
# Normalize is left disabled: the decoders in this notebook end in a Sigmoid
# and the loss is binary cross-entropy against the input image.
# NOTE(review): 64x64 matches the small model; the 128x128 / 256x256 models
# presumably need a matching Resize — confirm before switching models.
transform = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
#    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [None]:
# Training set: downloaded images that appear in the train2017 caption file.
# NOTE(review): the annotation path is an absolute Colab path.
train_dataset = COCODatasetWithClip(
    root_dir='coco_images',
    annotation_file='/content/coco_annotations/annotations/captions_train2017.json',
    transform=transform)


# Shuffled mini-batches of 256 (image, caption-embedding) pairs.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=256,
    shuffle=True)

loading annotations into memory...
Done (t=1.38s)
creating index...
index created!
Found 5000 images in directory
Loading CLIP...
Pre-encoding captions...


100%|██████████| 5000/5000 [00:57<00:00, 86.91it/s] 


In [None]:
# Validation set: same image folder, filtered by the val2017 caption file.
# NOTE(review): the annotation path is an absolute Colab path.
val_dataset = COCODatasetWithClip(
    root_dir='coco_images',
    annotation_file='/content/coco_annotations/annotations/captions_val2017.json',
    transform=transform
)

# NOTE(review): shuffle=True makes the batch order differ between evaluation runs.
val_dataloader = DataLoader(
    val_dataset,
    batch_size=256,
    shuffle=True)

loading annotations into memory...
Done (t=0.04s)
creating index...
index created!
Found 1000 images in directory
Loading CLIP...
Pre-encoding captions...


100%|██████████| 1000/1000 [00:11<00:00, 87.76it/s]


## 256x256 Model

In [None]:
class CatcVAELarge(nn.Module):
    """Concatenation-conditioned VAE for 256x256 RGB images.

    The encoder maps an image to a flattened 128*8*8 feature vector from which
    the latent mean and log-variance are predicted. The decoder receives the
    latent sample concatenated with a 512-dimensional text embedding and
    upsamples it through five transposed convolutions ending in a Sigmoid.

    A frozen CLIP ViT-B/32 model is held for embedding text prompts.

    Args:
        latent_dim (int): Size of the latent vector.
    """

    def __init__(self, latent_dim=128):
        super().__init__()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Frozen CLIP text encoder, used only by encode_condition.
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.latent_dim = latent_dim

        # encoder: 256 -> 64 -> 32 -> 16 -> 8 spatially
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 6, stride=4, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, stride=2, padding=1),
            nn.Flatten()
        )

        self.flatten_size = 128 * 8 * 8

        # NOTE(review): not used by encode/decode/forward; kept so existing
        # checkpoints (state_dict keys) still load.
        self.condition_processor = nn.Sequential(
            nn.Linear(512, self.flatten_size)
        )

        # Latent space
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 512, 8 * 8 * 128)

        # decoder: 8 -> 16 -> 32 -> 64 -> 128 -> 256 spatially
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode_condition(self, text):
        """Embed text with the frozen CLIP text encoder.

        Args:
            text: A string, or an iterable of strings (e.g. the sentences of
                one prompt).

        Returns:
            torch.Tensor: float32 mean of the per-string CLIP text embeddings,
            with a leading batch axis of size 1.
        """
        if isinstance(text, str):
            # A bare string would otherwise be iterated character by character.
            text = [text]
        # Use whatever device CLIP currently lives on instead of hard-coding
        # 'cuda', so the model also works on CPU-only machines.
        clip_device = next(self.clip_model.parameters()).device
        with torch.no_grad():
            embeddings = [
                self.clip_model.encode_text(clip.tokenize(sentence).to(clip_device)).type(torch.float32)
                for sentence in text
            ]
            return torch.mean(torch.stack(embeddings), dim=0)

    def encode(self, x, c):
        """Return ``(mu, log_var)`` for images ``x``.

        ``c`` is accepted for interface symmetry but is not used: the encoder
        is not conditioned on the text embedding.
        """
        x = self.encoder(x)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Decode latent ``z`` concatenated with condition ``c`` into an image batch."""
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 128, 8, 8)
        return self.decoder(z)

    def forward(self, x, c):
        """Return ``(reconstruction, mu, log_var)`` for images ``x`` and conditions ``c``."""
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var

def loss_function(recon_x, x, mu, log_var):
    """Return the VAE loss: summed binary cross-entropy plus KL divergence.

    Args:
        recon_x: Reconstructed images with values in [0, 1].
        x: Target images with values in [0, 1].
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        torch.Tensor: Scalar loss summed over all elements of the batch.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    # KL divergence between N(mu, exp(log_var)) and the standard normal prior.
    kl_inner = 1 + log_var - mu.pow(2) - log_var.exp()
    kl_term = -0.5 * torch.sum(kl_inner)
    return reconstruction_term + kl_term

def sample_images(model, prompts, num_samples=1, device='cuda'):
    """Generate ``num_samples`` images for each text prompt.

    Each prompt is embedded separately. Previously all prompts were averaged
    into a single embedding, so the condition batch no longer matched the
    latent batch as soon as more than one prompt was passed.

    Args:
        model: Conditional VAE exposing ``encode_condition``, ``decode`` and
            ``latent_dim``.
        prompts: One prompt string or a sequence of prompt strings.
        num_samples (int): Images to generate per prompt.
        device: Device for the latent samples and text embeddings.

    Returns:
        The output of ``model.decode`` for ``len(prompts) * num_samples``
        latent samples; the samples of one prompt are consecutive.
    """
    if isinstance(prompts, str):
        # A bare string would otherwise be counted and encoded per character.
        prompts = [prompts]

    model.eval()
    with torch.no_grad():
        # One embedding row per prompt.
        per_prompt = [model.encode_condition([prompt]) for prompt in prompts]
        text_features = torch.cat(per_prompt, dim=0).to(device)

        z = torch.randn(len(prompts) * num_samples, model.latent_dim).to(device)

        # Repeat each prompt's embedding so it lines up with its latent samples.
        text_features = text_features.repeat_interleave(num_samples, dim=0)
        samples = model.decode(z, text_features)
        return samples

## 128x128 Model

In [None]:
class CatcVAEMed(nn.Module):
    """Concatenation-conditioned VAE for 128x128 RGB images.

    The encoder maps an image to a flattened 128*8*8 feature vector from which
    the latent mean and log-variance are predicted. The decoder receives the
    latent sample concatenated with a 512-dimensional text embedding and
    upsamples it through four transposed convolutions ending in a Sigmoid.

    A frozen CLIP ViT-B/32 model is held for embedding text prompts.

    Args:
        latent_dim (int): Size of the latent vector.
    """

    def __init__(self, latent_dim=256):
        super().__init__()

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Frozen CLIP text encoder, used only by encode_condition.
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        self.latent_dim = latent_dim

        # encoder: 128 -> 64 -> 32 -> 16 -> 8 spatially
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, 4, stride=2, padding=1),
            nn.Flatten()
        )

        self.flatten_size = 128 * 8 * 8

        # NOTE(review): not used by encode/decode/forward; kept so existing
        # checkpoints (state_dict keys) still load.
        self.condition_processor = nn.Sequential(
            nn.Linear(512, self.flatten_size)
        )

        # Latent space
        self.fc_mu = nn.Linear(self.flatten_size, latent_dim)
        self.fc_var = nn.Linear(self.flatten_size, latent_dim)

        # Decoder
        self.decoder_input = nn.Linear(latent_dim + 512, 8 * 8 * 128)

        # decoder: 8 -> 16 -> 32 -> 64 -> 128 spatially
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, 4, stride=2, padding=1),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def encode_condition(self, text):
        """Embed text with the frozen CLIP text encoder.

        Args:
            text: A string, or an iterable of strings (e.g. the sentences of
                one prompt).

        Returns:
            torch.Tensor: float32 mean of the per-string CLIP text embeddings,
            with a leading batch axis of size 1.
        """
        if isinstance(text, str):
            # A bare string would otherwise be iterated character by character.
            text = [text]
        # Use whatever device CLIP currently lives on instead of hard-coding
        # 'cuda', so the model also works on CPU-only machines.
        clip_device = next(self.clip_model.parameters()).device
        with torch.no_grad():
            embeddings = [
                self.clip_model.encode_text(clip.tokenize(sentence).to(clip_device)).type(torch.float32)
                for sentence in text
            ]
            return torch.mean(torch.stack(embeddings), dim=0)

    def encode(self, x, c):
        """Return ``(mu, log_var)`` for images ``x``.

        ``c`` is accepted for interface symmetry but is not used: the encoder
        is not conditioned on the text embedding.
        """
        x = self.encoder(x)
        return self.fc_mu(x), self.fc_var(x)

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, c):
        """Decode latent ``z`` concatenated with condition ``c`` into an image batch."""
        z = torch.cat([z, c], dim=1)
        z = self.decoder_input(z)
        z = z.view(-1, 128, 8, 8)
        return self.decoder(z)

    def forward(self, x, c):
        """Return ``(reconstruction, mu, log_var)`` for images ``x`` and conditions ``c``."""
        mu, log_var = self.encode(x, c)
        z = self.reparameterize(mu, log_var)
        return self.decode(z, c), mu, log_var

# NOTE(review): same definition as the loss_function in the 256x256 model cell;
# whichever cell runs last provides the binding.
def loss_function(recon_x, x, mu, log_var):
    """Compute the summed VAE objective (reconstruction BCE + KL term).

    Args:
        recon_x: Reconstructed images with values in [0, 1].
        x: Target images with values in [0, 1].
        mu: Latent means.
        log_var: Latent log-variances.

    Returns:
        torch.Tensor: Scalar loss summed over the whole batch.
    """
    bce_sum = F.binary_cross_entropy(recon_x, x, reduction='sum')
    kld_sum = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    total = bce_sum + kld_sum
    return total

# NOTE(review): sample_images is also defined in the 256x256 model cell;
# whichever cell runs last provides the binding.
def sample_images(model, prompts, num_samples=1, device='cuda'):
    """Generate ``num_samples`` images for each text prompt.

    Each prompt is embedded separately. Previously all prompts were averaged
    into a single embedding, so the condition batch no longer matched the
    latent batch as soon as more than one prompt was passed.

    Args:
        model: Conditional VAE exposing ``encode_condition``, ``decode`` and
            ``latent_dim``.
        prompts: One prompt string or a sequence of prompt strings.
        num_samples (int): Images to generate per prompt.
        device: Device for the latent samples and text embeddings.

    Returns:
        The output of ``model.decode`` for ``len(prompts) * num_samples``
        latent samples; the samples of one prompt are consecutive.
    """
    if isinstance(prompts, str):
        # A bare string would otherwise be counted and encoded per character.
        prompts = [prompts]

    model.eval()
    with torch.no_grad():
        # One embedding row per prompt.
        per_prompt = [model.encode_condition([prompt]) for prompt in prompts]
        text_features = torch.cat(per_prompt, dim=0).to(device)

        z = torch.randn(len(prompts) * num_samples, model.latent_dim).to(device)

        # Repeat each prompt's embedding so it lines up with its latent samples.
        text_features = text_features.repeat_interleave(num_samples, dim=0)
        samples = model.decode(z, text_features)
        return samples

## 64x64 Model

In [None]:
class CatCVAESmall(nn.Module):
    """Concatenation-conditioned VAE for 64x64 images.

    Unlike the medium/large variants in this notebook, the text embedding is
    concatenated both to the flattened encoder features (before predicting the
    latent statistics) and to the latent sample (before decoding).
    """

    def __init__(self, text_embedding_dim=512, latent_dim=256, image_channels=3, image_size=64):
        """
        Args:
            text_embedding_dim (int): Dimension of text embeddings.
            latent_dim (int): Dimension of latent space.
            image_channels (int): Number of channels in the output image.
            image_size (int): Size (height and width) of the generated images (assumes square images).
        """
        super().__init__()
        self.text_embedding_dim = text_embedding_dim
        self.latent_dim = latent_dim
        self.image_channels = image_channels
        self.image_size = image_size

        # Pick the device locally (as the other model classes do) instead of
        # relying on a notebook-global `device` defined in another cell.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Frozen CLIP text encoder, used only by encode_condition.
        self.clip_model, _ = clip.load("ViT-B/32", device=device)
        self.clip_model.eval()
        for param in self.clip_model.parameters():
            param.requires_grad = False

        #### for 64x64 images
        # Encoder:
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, 64, kernel_size=4, stride=2, padding=1),  # 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),  # 16x16 -> 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Flatten()
        )

        # Add text embedding and map to latent space dimensions
        self.add_text_embedding = nn.Linear(256 * (image_size // 8) ** 2 + text_embedding_dim, 1024)
        self.mu = nn.Linear(1024, latent_dim)
        self.logvar = nn.Linear(1024, latent_dim)

        # Decoder:
        self.decoder_input = nn.Linear(latent_dim + text_embedding_dim, 256 * (image_size // 8) ** 2)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 16x16 -> 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, image_channels, kernel_size=4, stride=2, padding=1),  # 32x32 -> 64x64
            nn.Sigmoid()  # Outputs lie in [0, 1]
        )

    def encode_condition(self, text):
        """Embed text with the frozen CLIP text encoder.

        Args:
            text: A string, or an iterable of strings (e.g. the sentences of
                one prompt).

        Returns:
            torch.Tensor: float32 mean of the per-string CLIP text embeddings,
            with a leading batch axis of size 1.
        """
        if isinstance(text, str):
            # A bare string would otherwise be iterated character by character.
            text = [text]
        # Use whatever device CLIP currently lives on instead of hard-coding
        # 'cuda', so the model also works on CPU-only machines.
        clip_device = next(self.clip_model.parameters()).device
        with torch.no_grad():
            embeddings = [
                self.clip_model.encode_text(clip.tokenize(sentence).to(clip_device)).type(torch.float32)
                for sentence in text
            ]
            return torch.mean(torch.stack(embeddings), dim=0)

    def reparameterize(self, mu, logvar):
        """Reparameterization trick: z = mu + sigma * epsilon."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, images, text_embedding):
        """Encoder forward pass.

        Returns:
            tuple: ``(mu, logvar)`` of the latent distribution.
        """
        # Encode images
        image_features = self.encoder(images)

        # Combine image features with text embedding for encoder
        combined_features = torch.cat([image_features, text_embedding], dim=1)
        latent_space = self.add_text_embedding(combined_features)

        # Get mean and log-variance used to sample from the latent space
        mu = self.mu(latent_space)
        logvar = self.logvar(latent_space)

        return mu, logvar

    def decode(self, z, text_embedding):
        """Decoder forward pass."""
        # Combine latent space with text embedding for decoder
        decoder_input = torch.cat([z, text_embedding], dim=1)
        decoder_input = self.decoder_input(decoder_input)
        batch_size_dynamic = decoder_input.size(0)  # Dynamically get the batch size
        decoder_input = decoder_input.view(batch_size_dynamic, 256, self.image_size // 8, self.image_size // 8) # for 64x64
        reconstructed_images = self.decoder(decoder_input)
        return reconstructed_images

    def forward(self, images, text_embedding):
        """
        Forward pass through the cVAE.
        Args:
            images (torch.Tensor): Ground-truth images of shape (batch_size, image_channels, image_size, image_size).
            text_embedding (torch.Tensor): Text embeddings of shape (batch_size, text_embedding_dim).
        Returns:
            reconstructed_images (torch.Tensor): Generated images.
            mu (torch.Tensor): Mean of latent distribution.
            logvar (torch.Tensor): Log variance of latent distribution.
        """

        mu, logvar = self.encode(images, text_embedding)
        z = self.reparameterize(mu, logvar)
        reconstructed_images = self.decode(z, text_embedding)
        return reconstructed_images, mu, logvar


## Model Training

In [None]:
def train_cvae(model, train_loader, num_epochs=100, learning_rate=1e-4, device="cuda"):
    """Train a conditional VAE and checkpoint the best epoch.

    Uses Adam with gradient-norm clipping (max norm 1.0) and halves the
    learning rate when the average epoch loss has not improved for more than
    three epochs. Whenever the average epoch loss reaches a new minimum, a
    checkpoint is written to ``best_cvae_model.pth``.

    Args:
        model: Module whose forward returns ``(reconstruction, mu, log_var)``
            for ``(images, captions)``.
        train_loader: Iterable of ``(images, caption_embeddings)`` batches.
        num_epochs (int): Number of passes over ``train_loader``.
        learning_rate (float): Initial Adam learning rate.
        device: Device to train on.

    Returns:
        list[float]: Average per-batch loss for each epoch.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # `verbose=True` is deprecated in recent PyTorch releases; the current
    # learning rate is printed explicitly at the end of each epoch instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    train_losses = []
    best_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        epoch_loss = 0

        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

        for images, captions in pbar:
            images = images.to(device)
            captions = captions.to(device)

            optimizer.zero_grad()

            recon_batch, mu, log_var = model(images, captions)

            loss = loss_function(recon_batch, images, mu, log_var)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            pbar.set_postfix({'batch_loss': batch_loss})

        avg_epoch_loss = epoch_loss / len(train_loader)
        train_losses.append(avg_epoch_loss)

        scheduler.step(avg_epoch_loss)

        # Report the rate actually in effect: the scheduler may have reduced
        # it, whereas `learning_rate` is only the initial value.
        current_lr = optimizer.param_groups[0]['lr']
        print(f'====> Epoch: {epoch + 1} Average loss: {avg_epoch_loss:.4f} Learning Rate: {current_lr}')

        if avg_epoch_loss < best_loss:
            best_loss = avg_epoch_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_loss,
            }, 'best_cvae_model.pth')

    return train_losses

# Train the 64x64 model for 50 epochs with an initial learning rate of 1e-3.
# NOTE(review): the dataset transform resizes images to 64x64, so switching to
# the medium/large model presumably also needs a matching Resize — confirm.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CatCVAESmall()
# for medium: model = CatcVAEMed()
# for large: model = CatcVAELarge()
model.to(device)

losses = train_cvae(
    model=model,
    train_loader=train_dataloader,
    num_epochs=50,
    learning_rate=1e-3,
    device=device
)

Epoch 1/50: 100%|██████████| 20/20 [00:30<00:00,  1.51s/it, batch_loss=1.5e+6]


====> Epoch: 1 Average loss: 4275483.6125 Learning Rate: 0.001


Epoch 2/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=5.3e+8]


====> Epoch: 2 Average loss: 28767547.0125 Learning Rate: 0.001


Epoch 3/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=1.1e+6]


====> Epoch: 3 Average loss: 2128763.1000 Learning Rate: 0.001


Epoch 4/50: 100%|██████████| 20/20 [00:29<00:00,  1.50s/it, batch_loss=1.06e+6]


====> Epoch: 4 Average loss: 1968157.9000 Learning Rate: 0.001


Epoch 5/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=1.06e+6]


====> Epoch: 5 Average loss: 1944687.7125 Learning Rate: 0.001


Epoch 6/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=1.02e+6]


====> Epoch: 6 Average loss: 1889781.5531 Learning Rate: 0.001


Epoch 7/50: 100%|██████████| 20/20 [00:29<00:00,  1.48s/it, batch_loss=9.99e+5]


====> Epoch: 7 Average loss: 1868094.1719 Learning Rate: 0.001


Epoch 8/50: 100%|██████████| 20/20 [00:29<00:00,  1.45s/it, batch_loss=1.03e+6]


====> Epoch: 8 Average loss: 1860837.4344 Learning Rate: 0.001


Epoch 9/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=1.02e+6]


====> Epoch: 9 Average loss: 1847713.5031 Learning Rate: 0.001


Epoch 10/50: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it, batch_loss=9.98e+5]


====> Epoch: 10 Average loss: 1833786.0281 Learning Rate: 0.001


Epoch 11/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.8e+5]


====> Epoch: 11 Average loss: 1827870.0688 Learning Rate: 0.001


Epoch 12/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.92e+5]


====> Epoch: 12 Average loss: 1820021.0094 Learning Rate: 0.001


Epoch 13/50: 100%|██████████| 20/20 [00:29<00:00,  1.48s/it, batch_loss=9.93e+5]


====> Epoch: 13 Average loss: 1815246.1531 Learning Rate: 0.001


Epoch 14/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=9.98e+5]


====> Epoch: 14 Average loss: 1817746.7500 Learning Rate: 0.001


Epoch 15/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=9.91e+5]


====> Epoch: 15 Average loss: 1812603.4563 Learning Rate: 0.001


Epoch 16/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.88e+5]


====> Epoch: 16 Average loss: 1807243.2156 Learning Rate: 0.001


Epoch 17/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=9.73e+5]


====> Epoch: 17 Average loss: 1803559.6250 Learning Rate: 0.001


Epoch 18/50: 100%|██████████| 20/20 [00:29<00:00,  1.47s/it, batch_loss=9.81e+5]


====> Epoch: 18 Average loss: 1803987.9000 Learning Rate: 0.001


Epoch 19/50: 100%|██████████| 20/20 [00:32<00:00,  1.61s/it, batch_loss=9.69e+5]


====> Epoch: 19 Average loss: 1796161.0844 Learning Rate: 0.001


Epoch 20/50: 100%|██████████| 20/20 [00:30<00:00,  1.55s/it, batch_loss=9.78e+5]


====> Epoch: 20 Average loss: 1796071.2531 Learning Rate: 0.001


Epoch 21/50: 100%|██████████| 20/20 [00:31<00:00,  1.56s/it, batch_loss=9.71e+5]


====> Epoch: 21 Average loss: 1791819.8344 Learning Rate: 0.001


Epoch 22/50: 100%|██████████| 20/20 [00:29<00:00,  1.49s/it, batch_loss=9.8e+5]


====> Epoch: 22 Average loss: 1785513.8094 Learning Rate: 0.001


Epoch 23/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.59e+5]


====> Epoch: 23 Average loss: 1783177.3500 Learning Rate: 0.001


Epoch 24/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.61e+5]


====> Epoch: 24 Average loss: 1780144.5281 Learning Rate: 0.001


Epoch 25/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.59e+5]


====> Epoch: 25 Average loss: 1778638.8906 Learning Rate: 0.001


Epoch 26/50: 100%|██████████| 20/20 [00:30<00:00,  1.50s/it, batch_loss=9.72e+5]


====> Epoch: 26 Average loss: 1775663.9563 Learning Rate: 0.001


Epoch 27/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.8e+5]


====> Epoch: 27 Average loss: 1774180.2281 Learning Rate: 0.001


Epoch 28/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.79e+5]


====> Epoch: 28 Average loss: 1770203.3125 Learning Rate: 0.001


Epoch 29/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.6e+5]


====> Epoch: 29 Average loss: 1769306.7875 Learning Rate: 0.001


Epoch 30/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=9.51e+5]


====> Epoch: 30 Average loss: 1766178.6125 Learning Rate: 0.001


Epoch 31/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.57e+5]


====> Epoch: 31 Average loss: 1768978.0719 Learning Rate: 0.001


Epoch 32/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=9.53e+5]


====> Epoch: 32 Average loss: 1766347.2188 Learning Rate: 0.001


Epoch 33/50: 100%|██████████| 20/20 [00:28<00:00,  1.43s/it, batch_loss=9.75e+5]


====> Epoch: 33 Average loss: 1762905.5938 Learning Rate: 0.001


Epoch 34/50: 100%|██████████| 20/20 [00:28<00:00,  1.43s/it, batch_loss=9.59e+5]


====> Epoch: 34 Average loss: 1762111.7812 Learning Rate: 0.001


Epoch 35/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.64e+5]


====> Epoch: 35 Average loss: 1763300.7000 Learning Rate: 0.001


Epoch 36/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.67e+5]


====> Epoch: 36 Average loss: 1762293.1500 Learning Rate: 0.001


Epoch 37/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.61e+5]


====> Epoch: 37 Average loss: 1759938.7937 Learning Rate: 0.001


Epoch 38/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.47e+5]


====> Epoch: 38 Average loss: 1757622.4438 Learning Rate: 0.001


Epoch 39/50: 100%|██████████| 20/20 [00:29<00:00,  1.46s/it, batch_loss=9.56e+5]


====> Epoch: 39 Average loss: 1758525.2219 Learning Rate: 0.001


Epoch 40/50: 100%|██████████| 20/20 [00:28<00:00,  1.44s/it, batch_loss=9.58e+5]


====> Epoch: 40 Average loss: 1756915.1156 Learning Rate: 0.001


Epoch 41/50: 100%|██████████| 20/20 [00:28<00:00,  1.45s/it, batch_loss=9.51e+5]


====> Epoch: 41 Average loss: 1754413.6344 Learning Rate: 0.001


Epoch 42/50: 100%|██████████| 20/20 [00:29<00:00,  1.45s/it, batch_loss=9.43e+5]


====> Epoch: 42 Average loss: 1753185.6156 Learning Rate: 0.001


Epoch 43/50:  15%|█▌        | 3/20 [00:04<00:23,  1.41s/it, batch_loss=1.79e+6]

In [None]:
# Plot the average per-batch training loss recorded for each epoch.
plt.figure(figsize=(10, 5))
plt.plot(losses)
plt.title('Training Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')

In [None]:
# Persist the weights of the model as it stands after the last training step.
torch.save(model.state_dict(), 'catcvae_64.pt')

In [None]:
# Select the compute device (GPU when available) for the evaluation section.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Testing Metric and Image Generation

In [None]:
def compute_mse(model, data_loader, device=None):
    """Return the average per-image reconstruction error of ``model``.

    For every batch the squared error is *summed* over all pixels and
    channels; the grand total is then divided by the number of images. The
    result is therefore the mean summed-squared-error per image, not a
    per-pixel mean.

    Args:
        model: Module whose forward returns ``(reconstruction, mu, log_var)``
            for ``(images, conditions)``.
        data_loader: Iterable of ``(images, conditions)`` batches.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has none), so the function does not
            depend on a notebook-global ``device``.

    Returns:
        float: Summed squared error divided by the number of images.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    squared_error_total = 0.0
    total_samples = 0

    with torch.no_grad():
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)

            # reconstruct the data using C-VAE
            reconstructed_data, _, _ = model(data, target)

            # accumulate the summed squared error for the batch
            squared_error_total += F.mse_loss(reconstructed_data, data, reduction='sum').item()
            total_samples += data.size(0)

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples")

    # average over images
    return squared_error_total / total_samples


def compute_average_ssim(model, data_loader, image_size, device=None):
    """Return the sample-weighted average SSIM between inputs and reconstructions.

    Args:
        model: Module whose forward returns ``(reconstruction, mu, log_var)``
            for ``(images, conditions)``.
        data_loader: Iterable of ``(images, conditions)`` batches.
        image_size (int): Height/width the 3-channel images are viewed as.
        device: Device to run on. Defaults to the device of the model's first
            parameter (CPU if the model has none), so the function does not
            depend on a notebook-global ``device``.

    Returns:
        float: Per-batch SSIM values averaged with batch-size weights.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    total_samples = 0
    total_ssim = 0.0

    # Initialize the SSIM metric
    ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)  # Assumes input data is in [0, 1]

    with torch.no_grad():
        for data, labels in data_loader:
            data, labels = data.to(device), labels.to(device)

            reconstructed_data, _, _ = model(data, labels)

            batch_count = data.size(0)
            data = data.view(batch_count, 3, image_size, image_size)
            reconstructed_data = reconstructed_data.view(batch_count, 3, image_size, image_size)

            # Weight each batch's SSIM by its size so a smaller last batch
            # does not skew the average.
            batch_ssim = ssim(reconstructed_data, data)
            total_ssim += batch_ssim.item() * batch_count
            total_samples += batch_count

    if total_samples == 0:
        raise ValueError("data_loader yielded no samples")

    # Compute average SSIM
    return total_ssim / total_samples


def compute_fid(model, data_loader, image_size):
    """Return the FID between the loader's images and their reconstructions.

    Inputs are registered as the "real" distribution and the model's
    reconstructions as the "fake" one; the score is computed once after the
    whole loader has been consumed.

    Args:
        model: Called as ``model(data, labels)``; must return a 3-tuple whose
            first element is the reconstruction of ``data``.
        data_loader: Iterable yielding ``(data, labels)`` tensor pairs.
        image_size: Side length used to view both tensors as
            ``(batch, 3, image_size, image_size)``.

    Returns:
        float: the FID score.

    Note:
        Switches ``model`` to eval mode and uses the module-level ``device``.
    """
    model.eval()

    # feature=2048 selects the 2048-dimensional InceptionV3 feature layer.
    fid_metric = FrechetInceptionDistance(feature=2048).to(device)

    def to_uint8_images(batch, count):
        # View as an image batch, then scale to 0..255 and cast to uint8
        # (pixel values are assumed to be in [0, 1] beforehand).
        batch = batch.view(count, 3, image_size, image_size)
        return (batch * 255).clamp(0, 255).to(torch.uint8)

    with torch.no_grad():
        for images, conditions in data_loader:
            images, conditions = images.to(device), conditions.to(device)

            recon, _, _ = model(images, conditions)

            n_images = images.size(0)
            fid_metric.update(to_uint8_images(images, n_images), real=True)
            fid_metric.update(to_uint8_images(recon, n_images), real=False)

    return fid_metric.compute().item()

In [None]:
def show_reconstruction(model, val_dataloader, size, device=None):
    """Plot the first validation batch next to its C-VAE reconstruction.

    One batch is drawn from ``val_dataloader``, passed through ``model`` and
    shown as two image grids side by side (inputs on the left, reconstructions
    on the right). At most the first 32 images of the batch are displayed.

    Args:
        model: Called as ``model(data, caption)``; must return a 3-tuple whose
            first element is the reconstruction of ``data``.
        val_dataloader: Iterable yielding ``(data, caption)`` tensor pairs.
        size: Side length; displayed tensors are reshaped to
            ``(batch, 3, size, size)``.
        device: Device to run on. When omitted, the device holding the model's
            parameters is used, so the function also works without a GPU.

    Note:
        Switches ``model`` to eval mode, like the metric helpers in this file.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    def show(img1, img2):
        # make_grid returns (C, H, W); imshow expects (H, W, C).
        npimg1 = np.transpose(img1.cpu().numpy(), (1, 2, 0))
        npimg2 = np.transpose(img2.cpu().numpy(), (1, 2, 0))

        fig, axes = plt.subplots(1, 2, figsize=(20, 10))
        axes[0].imshow(npimg1, interpolation='nearest')
        axes[0].set_title('Original')
        axes[1].imshow(npimg2, interpolation='nearest')
        axes[1].set_title('Reconstruction')

    model.eval()

    # Only the first batch is displayed, so only that batch is reconstructed;
    # no_grad avoids building an autograd graph for a pure visualisation.
    with torch.no_grad():
        data, caption = next(iter(val_dataloader))
        data, caption = data.to(device), caption.to(device)
        reconstructed_data, _, _ = model(data, caption)

    data = data[:32]
    reconstructed_data = reconstructed_data[:32]

    # Bring both tensors into image layout before building the grids.
    data = data.reshape(data.size(0), 3, size, size)
    reconstructed_data = reconstructed_data.reshape(data.size(0), 3, size, size)

    show(make_grid(data), make_grid(reconstructed_data))


# results from text

def generate_from_text(model, text_prompt, device=None):
    """Sample one image from the C-VAE conditioned on a text prompt and plot it.

    The prompt is embedded with ``model.encode_condition``, a latent vector is
    drawn from a standard normal of width ``model.latent_dim``, and both are
    passed to ``model.decode``.

    Args:
        model: Object exposing ``encode_condition``, ``decode`` and
            ``latent_dim``.
        text_prompt: Prompt string used as the condition.
        device: Device to run on. When omitted, the device holding the model's
            parameters is used, so the function also works without a GPU.

    Returns:
        The generated image as a CPU tensor with the channel axis moved last
        (the decoder output with its leading dimension squeezed, permuted
        ``(1, 2, 0)``).

    Note:
        Switches ``model`` to eval mode.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    with torch.no_grad():
        # Encode text with the model's condition encoder
        text_embedding = model.encode_condition([text_prompt])
        text_embedding = text_embedding.to(device)

        # Sample from latent space
        z = torch.randn(1, model.latent_dim).to(device)

        # Generate image
        generated_img = model.decode(z, text_embedding)

        # Convert to displayable format: drop the batch axis, channels last
        generated_img = generated_img.squeeze(0).cpu().permute(1, 2, 0)

        # Plot
        plt.figure(figsize=(3, 10))
        plt.imshow(generated_img)
        plt.axis('off')
        plt.title(f'Generated image for: "{text_prompt}"')
        plt.show()

        return generated_img



## Testing Models

In [None]:
# load small model
model = CatCVAESmall()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load('/content/catcvae_64.pt', map_location=device))
model.to(device)
print('Small model loaded')


# get test stats for small model
# NOTE(review): assumes val_dataloader currently yields 64x64 images — confirm.
test_mse = compute_mse(model, val_dataloader)
print(f'Test MSE: {test_mse:.4f}')
test_ssim = compute_average_ssim(model, val_dataloader, 64)
print(f'Test SSIM: {test_ssim:.4f}')
test_fid = compute_fid(model, val_dataloader, 64)
print(f'Test FID: {test_fid:.4f}')


In [None]:
# show reconstruction results (pass the notebook-wide device instead of relying
# on the function's default, so this cell also runs without a GPU)
show_reconstruction(model, val_dataloader, 64, device=device)

In [None]:
# generate images with small model
test_prompts = [
    "a dog playing in the park",
    "a cat sleeping on a couch",
    "a sunset over the ocean",
    "a person riding a bicycle"
]

for prompt in test_prompts:
    # use the notebook-wide device so generation also works without a GPU
    generate_from_text(model, prompt, device=device)

In [None]:
# load medium model
model = CatcVAEMed()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load('/content/catcvae_128.pt', map_location=device))
model.to(device)
print('Medium model loaded')


# get test stats for med model
# NOTE(review): assumes val_dataloader currently yields 128x128 images — confirm.
test_mse = compute_mse(model, val_dataloader)
print(f'Test MSE: {test_mse:.4f}')
test_ssim = compute_average_ssim(model, val_dataloader, 128)
print(f'Test SSIM: {test_ssim:.4f}')
test_fid = compute_fid(model, val_dataloader, 128)
print(f'Test FID: {test_fid:.4f}')


In [None]:
# show reconstruction results (pass the notebook-wide device instead of relying
# on the function's default, so this cell also runs without a GPU)
show_reconstruction(model, val_dataloader, 128, device=device)

In [None]:
# generate images with med model
test_prompts = [
    "a dog playing in the park",
    "a cat sleeping on a couch",
    "a sunset over the ocean",
    "a person riding a bicycle"
]

for prompt in test_prompts:
    # use the notebook-wide device so generation also works without a GPU
    generate_from_text(model, prompt, device=device)

In [None]:
# load large model
# NOTE(review): this instantiates CatcVAEMed for the 256x256 checkpoint. If a
# dedicated large-model class exists it should probably be used here — confirm.
model = CatcVAEMed()
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU still loads on a CPU-only host.
model.load_state_dict(torch.load('/content/catcvae_256.pt', map_location=device))
model.to(device)
print('Large model loaded')


# get test stats for large model
# NOTE(review): assumes val_dataloader currently yields 256x256 images — confirm.
test_mse = compute_mse(model, val_dataloader)
print(f'Test MSE: {test_mse:.4f}')
test_ssim = compute_average_ssim(model, val_dataloader, 256)
print(f'Test SSIM: {test_ssim:.4f}')
test_fid = compute_fid(model, val_dataloader, 256)
print(f'Test FID: {test_fid:.4f}')

In [None]:
# show reconstruction results (pass the notebook-wide device instead of relying
# on the function's default, so this cell also runs without a GPU)
show_reconstruction(model, val_dataloader, 256, device=device)

In [None]:
# generate images with large model
test_prompts = [
    "a dog playing in the park",
    "a cat sleeping on a couch",
    "a sunset over the ocean",
    "a person riding a bicycle"
]

for prompt in test_prompts:
    # use the notebook-wide device so generation also works without a GPU
    generate_from_text(model, prompt, device=device)