In [None]:
# Install OpenAI's CLIP package straight from its GitHub repository (%pip targets the kernel's environment).
# NOTE(review): the install is unpinned, so it follows whatever the repository currently serves —
# consider pinning a specific commit for reproducible runs.
%pip install git+https://github.com/openai/CLIP.git

In [None]:
import torch
from torch.utils.tensorboard import SummaryWriter
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image, make_grid
import ssl
import os
import clip
import time

# NOTE: installing the unverified context factory turns off TLS certificate
# checks for default HTTPS requests made by this process from here on.
ssl._create_default_https_context = ssl._create_unverified_context

# Device selection: prefer CUDA, then Apple MPS, and fall back to the CPU
if torch.cuda.is_available():
    _backend, _label = "cuda", "GPU"
elif torch.backends.mps.is_available():
    _backend, _label = "mps", "MPS"
else:
    _backend, _label = "cpu", "CPU"
print(f"Using {_label}")
device = torch.device(_backend)

# DataLoader options: one worker process and pinned host memory only on CUDA
if device.type == 'cuda':
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    kwargs = {'num_workers': 0}

# Hyperparameters
seed = 42  # seeds PyTorch's RNG so weight init, shuffling and sampling draws repeat across runs
torch.manual_seed(seed)

batch_size = 16 if device.type == 'cuda' else 8
latent_size = 64
# CLIP ViT-B/32 (the model loaded for text prompts further down) produces
# 512-dimensional embeddings; this must match the width of the pre-computed
# embeddings file.
clip_dim = 512
epochs = 20
beta = 0.5  # weight of the KL term in the loss
init_channels = 32
image_size = 64  # downsample CelebA to 64x64 for manageable training

# Reconstruction and sample images are written here
os.makedirs('output', exist_ok=True)

# Logging to Tensorboard: one timestamped run directory per execution
current_time = time.strftime("%Y%m%d-%H%M%S")
run_name = f'runs/celeba_cvae_v1_{current_time}'
writer = SummaryWriter(log_dir=run_name)

In [None]:
class CelebAWithCLIPEmbeddings(torch.utils.data.Dataset):
    """CelebA images paired, by index, with pre-computed CLIP embeddings.

    Item ``idx`` is ``(image, attributes, clip_embedding)`` where the image and
    attributes come from torchvision's CelebA dataset and the embedding is
    ``embeddings[idx]`` from the loaded file, i.e. the file is assumed to be
    stored in the same order as the CelebA split.

    Args:
        split: CelebA split name ('train', 'valid', or 'test'); also the key
            under which the embeddings are stored in the file.
        embeddings_path: File readable by ``torch.load`` holding the entries
            ``split`` and ``f'{split}_labels'``.
        data_dir: Root directory for the CelebA download.
    """

    def __init__(self, split, embeddings_path, data_dir='./data'):
        self.split = split
        self.embeddings_path = embeddings_path

        self.celeba = datasets.CelebA(
            root=data_dir,
            split=split,  # 'train', 'valid', or 'test'
            download=True,
            transform=transforms.Compose([
                transforms.Resize((image_size, image_size)),  # Resize to manageable size
                transforms.ToTensor()
            ])
        )

        # NOTE(review): torch.load unpickles the file — only load embedding
        # files that you generated yourself or otherwise trust.
        self.embeddings_data = torch.load(embeddings_path, map_location='cpu')
        self.embeddings = self.embeddings_data[split]
        self.labels = self.embeddings_data[f'{split}_labels']  # These are 40 attributes

        print(f"Loaded {len(self.embeddings)} CLIP embeddings for {split} set")
        print(f"CelebA attributes shape: {self.labels.shape if hasattr(self.labels, 'shape') else 'N/A'}")

        # Images and embeddings are matched purely by position, so a length
        # mismatch means some items have no partner; surface that up front.
        if len(self.embeddings) != len(self.celeba):
            print(f"WARNING: {len(self.celeba)} images but {len(self.embeddings)} embeddings "
                  f"for {split} set; only the first {len(self)} items will be used")

    def __len__(self):
        # Only indices present in BOTH sequences are valid; reporting the image
        # count alone would let a DataLoader request an embedding that does not exist.
        return min(len(self.celeba), len(self.embeddings))

    def __getitem__(self, idx):
        image, attributes = self.celeba[idx]  # attributes is tensor of shape [40]
        clip_embedding = self.embeddings[idx]
        return image, attributes, clip_embedding


In [None]:
class CelebaCVAE(nn.Module):
    """Convolutional conditional VAE whose condition is a fixed-size vector.

    The condition ``c`` (a CLIP embedding in this notebook) is concatenated to
    the flattened encoder features before the latent heads, and to the latent
    sample ``z`` before the decoder. The convolution stacks are written for
    64x64 inputs (six stages down to 1x1 and back up).

    Args:
        image_channels: Number of channels in the input/output images.
        init_channels: Channel count of the first conv stage; later stages use
            multiples of it.
        latent_size: Dimensionality of the latent code ``z``.
        class_size: Dimensionality of the condition vector ``c``.
        image_size: Stored for reference; the layer stack itself is fixed to 64x64.
    """

    def __init__(self, image_channels, init_channels, latent_size, class_size, image_size=64):
        super(CelebaCVAE, self).__init__()
        self.image_channels = image_channels
        self.latent_size = latent_size
        self.class_size = class_size  # clip_dim for CLIP embeddings
        self.init_channels = init_channels
        self.image_size = image_size

        conv_output_size = init_channels * 8  # Final channel count

        # Encoder
        self.encoder = nn.Sequential(
            # 64x64 -> 32x32
            nn.Conv2d(self.image_channels, init_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(),

            # 32x32 -> 16x16
            nn.Conv2d(init_channels, init_channels*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*2),
            nn.ReLU(),

            # 16x16 -> 8x8
            nn.Conv2d(init_channels*2, init_channels*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*4),
            nn.ReLU(),

            # 8x8 -> 4x4
            nn.Conv2d(init_channels*4, init_channels*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*8),
            nn.ReLU(),

            # 4x4 -> 2x2
            nn.Conv2d(init_channels*8, init_channels*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*8),
            nn.ReLU(),

            # 2x2 -> 1x1
            nn.Conv2d(init_channels*8, conv_output_size, kernel_size=2, stride=1, padding=0),
            nn.ReLU()
        )

        # FC layers to get mu and logvar.
        # encode() feeds fc1 the flattened conv features concatenated with the
        # condition, and decode() feeds fc2 the latent code concatenated with
        # the condition, so both input widths include class_size.
        self.fc1 = nn.Linear(conv_output_size + self.class_size, 512)
        self.fc_mu = nn.Linear(512, self.latent_size)
        self.fc_logvar = nn.Linear(512, self.latent_size)
        self.fc2 = nn.Linear(self.latent_size + self.class_size, conv_output_size)

        # Decoder for 64x64 images
        self.decoder = nn.Sequential(
            # 1x1 -> 2x2
            nn.ConvTranspose2d(conv_output_size, init_channels*8, kernel_size=2, stride=1, padding=0),
            nn.BatchNorm2d(init_channels*8),
            nn.ReLU(),

            # 2x2 -> 4x4
            nn.ConvTranspose2d(init_channels*8, init_channels*8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*8),
            nn.ReLU(),

            # 4x4 -> 8x8
            nn.ConvTranspose2d(init_channels*8, init_channels*4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*4),
            nn.ReLU(),

            # 8x8 -> 16x16
            nn.ConvTranspose2d(init_channels*4, init_channels*2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels*2),
            nn.ReLU(),

            # 16x16 -> 32x32
            nn.ConvTranspose2d(init_channels*2, init_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(),

            # 32x32 -> 64x64
            nn.ConvTranspose2d(init_channels, self.image_channels, kernel_size=4, stride=2, padding=1),
            nn.Sigmoid()  # output values between 0 and 1
        )

    def encode(self, x, c):
        """Map images ``x`` and conditions ``c`` to the latent ``(mu, logvar)``."""
        h = self.encoder(x)
        h = h.view(h.size(0), -1)  # flatten the 1x1 feature map to (batch, features)

        inputs = torch.cat([h, c], 1)

        h_fc = F.relu(self.fc1(inputs))
        mu = self.fc_mu(h_fc)
        logvar = self.fc_logvar(h_fc)
        return mu, logvar

    def reparameterize(self, mu, log_var):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, I)`` so sampling stays differentiable."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        sample = mu + eps * std
        return sample

    def decode(self, z, c):
        """Generate images from latent codes ``z`` under conditions ``c``."""
        inputs = torch.cat([z, c], 1)

        h = F.relu(self.fc2(inputs))
        h = h.view(-1, self.init_channels * 8, 1, 1)  # back to a 1x1 feature map for the deconvs

        return self.decoder(h)

    def forward(self, x, c):
        """Encode, sample and decode; returns ``(recon_x, mu, logvar)``."""
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, c)
        return recon_x, mu, logvar


In [None]:
def loss_function(recon_x, x, mu, logvar, kl_weight=None):
    """Beta-weighted VAE loss, summed over the batch.

    Args:
        recon_x: Reconstructed images with values in [0, 1].
        x: Target images, same shape as ``recon_x``.
        mu: Latent means.
        logvar: Latent log-variances.
        kl_weight: Weight applied to the KL term in the total. Defaults to the
            module-level ``beta`` hyperparameter when omitted.

    Returns:
        ``(total, BCE, KLD)`` where ``BCE`` is the summed binary cross-entropy,
        ``KLD`` is the unweighted KL divergence between N(mu, exp(logvar)) and
        N(0, I), and ``total = BCE + kl_weight * KLD``.
    """
    if kl_weight is None:
        kl_weight = beta

    BCE = F.binary_cross_entropy(recon_x.view(recon_x.size(0), -1), x.view(x.size(0), -1), reduction='sum')
    # Closed-form KL for a diagonal Gaussian against the standard normal. The
    # weight is kept out of this term so the reported KLD is the actual
    # divergence for any weight, not just when the weight happens to be 0.5.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + kl_weight * KLD, BCE, KLD

def train(epoch, model, optimizer, train_loader):
    """Run one training epoch and log the average per-sample loss.

    Args:
        epoch: Epoch number, used in console output and as the Tensorboard step.
        model: Conditional VAE called as ``model(images, clip_embeddings)`` and
            returning ``(recon, mu, logvar)``.
        optimizer: Optimizer over the model's parameters.
        train_loader: Loader whose batches unpack as
            ``(images, attributes, clip_embeddings)``.
    """
    model.train()
    train_loss = 0.0  # Python float, so the running sum accumulates in double precision

    for batch_idx, (data, _attributes, clip_embeddings) in enumerate(train_loader):
        # The attribute labels are not used for training, so they stay on the host.
        data = data.to(device)
        clip_embeddings = clip_embeddings.to(device)

        optimizer.zero_grad()
        # Condition directly on the CLIP embedding (no one-hot encoding as with MNIST labels).
        recon_batch, mu, logvar = model(data, clip_embeddings)
        loss = loss_function(recon_batch, data, mu, logvar)[0]
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss

        if batch_idx % 100 == 0:  # less frequent logging for larger dataset
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                batch_loss / len(data)))

    avg_loss = train_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))

    # Log to Tensorboard
    writer.add_scalar('Loss/train', avg_loss, epoch)
    
def test(epoch, model, test_loader):
    """Evaluate the model on a loader and log losses plus a reconstruction grid.

    Args:
        epoch: Epoch number, used in file names and as the Tensorboard step.
        model: Conditional VAE called as ``model(images, clip_embeddings)`` and
            returning ``(recon, mu, logvar)``.
        test_loader: Loader whose batches unpack as
            ``(images, attributes, clip_embeddings)``.
    """
    model.eval()
    # Python floats, so the running sums accumulate in double precision.
    test_loss = 0.0
    total_bce = 0.0
    total_kld = 0.0

    with torch.no_grad():
        for i, (data, _attributes, clip_embeddings) in enumerate(test_loader):
            # The attribute labels are not used for evaluation, so they stay on the host.
            data = data.to(device)
            clip_embeddings = clip_embeddings.to(device)

            recon_batch, mu, logvar = model(data, clip_embeddings)
            loss, bce, kld = loss_function(recon_batch, data, mu, logvar)

            test_loss += loss.item()
            total_bce += bce.item()
            total_kld += kld.item()

            if i == 0:
                # First batch only: originals on the top row, reconstructions below.
                n = min(data.size(0), 8)  # Show more samples for faces
                comparison = torch.cat([data[:n], recon_batch[:n]]).cpu()
                grid = make_grid(comparison, nrow=n, normalize=True)
                writer.add_image('Reconstruction', grid, epoch)
                save_image(comparison,
                           f'output/reconstruction_{epoch:02d}.png', nrow=n, normalize=True)

    num_samples = len(test_loader.dataset)
    avg_loss = test_loss / num_samples
    avg_bce = total_bce / num_samples
    avg_kld = total_kld / num_samples

    print('====> Test set loss: {:.4f}'.format(avg_loss))

    # Log to Tensorboard
    writer.add_scalar('Loss/test', avg_loss, epoch)
    writer.add_scalar('Loss/BCE', avg_bce, epoch)
    writer.add_scalar('Loss/KLD', avg_kld, epoch)

In [None]:
# TODO: Specify path to pre-computed CLIP embeddings
# Note: We ran setup_clip_embeddings.ipynb first to generate this file
embeddings_path = None

# Fail fast with a clear message: otherwise the problem only surfaces inside
# torch.load, after the CelebA dataset object (download=True) has been built.
if embeddings_path is None or not os.path.isfile(embeddings_path):
    raise FileNotFoundError(
        f"CLIP embeddings file not found (embeddings_path={embeddings_path!r}). "
        "Set embeddings_path to the file produced by setup_clip_embeddings.ipynb."
    )

train_dataset = CelebAWithCLIPEmbeddings('train', embeddings_path, './data')
valid_dataset = CelebAWithCLIPEmbeddings('valid', embeddings_path, './data')
test_dataset = CelebAWithCLIPEmbeddings('test', embeddings_path, './data')

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
valid_loader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=batch_size, shuffle=False, **kwargs)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, **kwargs)

print(f"Train: {len(train_dataset)}, Valid: {len(valid_dataset)}, Test: {len(test_dataset)}")

In [None]:
# CelebA images are RGB, so the model works on 3 channels (MNIST had 1)
image_channels = 3

model = CelebaCVAE(image_channels, init_channels, latent_size, clip_dim, image_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-4, betas=(0.5, 0.999))  # Better params for faces

try:
    for epoch in range(1, epochs + 1):
        train(epoch, model, optimizer, train_loader)
        test(epoch, model, valid_loader)  # Use validation set for monitoring

        # Generate sample faces every few epochs
        if epoch % 5 == 0:
            model.eval()  # explicit: sampling must use BatchNorm running statistics
            with torch.no_grad():
                # Condition on the CLIP embeddings of 8 randomly chosen test images
                sample_indices = torch.randint(0, len(test_dataset), (8,)).tolist()
                sample_embeddings = [test_dataset.embeddings[idx] for idx in sample_indices]

                sample_clip_embeddings = torch.stack(sample_embeddings).to(device)
                sample_z = torch.randn(8, latent_size).to(device)

                # Decode prior samples z under the chosen conditions
                sample = model.decode(sample_z, sample_clip_embeddings).cpu()

                # Log images to Tensorboard
                sample_grid = make_grid(sample, nrow=4, normalize=True)
                writer.add_image('Generated/from_embeddings', sample_grid, epoch)

                save_image(sample, f'output/sample_{epoch:02d}.png', nrow=4, normalize=True)

    # Save the model
    torch.save(model.state_dict(), 'celeba_cvae_model.pth')
finally:
    # Close the Tensorboard writer even if training is interrupted, so
    # already-logged events are flushed to disk.
    writer.close()


In [None]:
# Text prompt face generation

def generate_faces_with_text(model, text_prompt, num_samples=8, clip_model=None):
    """Generate faces conditioned on the CLIP text embedding of a prompt.

    Args:
        model: Trained conditional VAE exposing ``decode(z, c)``.
        text_prompt: Text to encode with CLIP and use as the condition.
        num_samples: Number of images to generate.
        clip_model: Optional already-loaded CLIP model to reuse across calls.
            When omitted, ViT-B/32 is loaded on ``device`` for this call.

    Returns:
        The generated image batch as a CPU tensor. The same batch is also saved
        as a PNG grid named after the prompt in the current directory.
    """
    print(f"Generating faces for text: '{text_prompt}'")

    if clip_model is None:
        clip_model, _ = clip.load("ViT-B/32", device=device)

    text = clip.tokenize([text_prompt]).to(device)
    with torch.no_grad():
        # CLIP may run in half precision on CUDA; cast so the condition matches
        # the float32 latent sample it is concatenated with.
        text_features = clip_model.encode_text(text).float()
    # NOTE(review): if the pre-computed image embeddings used for training were
    # L2-normalized, text_features should be normalized the same way — confirm
    # against the embedding setup notebook.

    model.eval()
    with torch.no_grad():
        sample_z = torch.randn(num_samples, latent_size).to(device)

        # One copy of the single prompt embedding per generated sample
        text_condition = text_features.repeat(num_samples, 1)

        sample = model.decode(sample_z, text_condition).cpu()

        filename = f'text_generation_{text_prompt.replace(" ", "_").replace(",", "")}.png'
        save_image(sample, filename, nrow=4, normalize=True)

        print(f"Generated faces saved to {filename}")
        return sample

# Text prompts to render; each one yields its own saved grid of 8 faces
face_prompts = [
    "smiling woman", "man with beard",
    "young person", "person with glasses",
    "blonde hair", "dark hair",
]

for prompt in face_prompts:
    generate_faces_with_text(model, prompt, num_samples=8)
