In [None]:
# Import libraries

import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import matplotlib.pyplot as plt


In [None]:
# Device selection: CUDA first, then Apple MPS, otherwise fall back to the CPU.
cuda_ready = torch.cuda.is_available()
if cuda_ready:
    device = torch.device("cuda")
    print("Using GPU")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS")
else:
    device = torch.device("cpu")
    print("Using CPU")

# Extra keyword arguments forwarded to every DataLoader;
# pin_memory is requested only when CUDA is available.
kwargs = dict(num_workers=1)
if cuda_ready:
    kwargs.update(pin_memory=True)


In [None]:
# Hyperparameters

batch_size = 64     # samples per mini-batch for both data loaders
latent_size = 32    # dimensionality of the latent vector z
init_channels = 8   # output channels of the first encoder convolution
class_size = 10     # number of MNIST classes (digits 0-9); width of the one-hot condition
epochs = 10         # number of full passes over the training set


In [None]:
# Data loaders: MNIST images are converted to tensors by ToTensor().
to_tensor = transforms.ToTensor()
train_dataset = datasets.MNIST('data', train=True, download=True, transform=to_tensor)
# The test split is not downloaded here; it relies on the files fetched above.
test_dataset = datasets.MNIST('data', train=False, transform=to_tensor)

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, **kwargs)


In [None]:
def one_hot(labels, class_size):
    """
    Convert integer class labels into one-hot row vectors.

    Args:
        labels: 1-D tensor (or sequence) of integer class indices, each in
            the range [0, class_size).
        class_size: Number of classes, i.e. the width of each one-hot row.

    Returns:
        Float tensor of shape (len(labels), class_size), moved to the global
        `device`, with a single 1 per row in the column given by the label.
    """
    labels = torch.as_tensor(labels).long().view(-1)
    targets = torch.zeros(labels.size(0), class_size, device=labels.device)

    # scatter_ writes a 1 into column labels[i] of row i for the whole batch
    # at once, instead of one indexed write per sample in a Python loop.
    targets.scatter_(1, labels.unsqueeze(1), 1.0)

    return targets.to(device)


In [None]:
class CVAE(nn.Module):
    """
    Convolutional conditional VAE; the layer sizes below are laid out for
    28x28 inputs (the encoder then flattens to exactly 64 features).

    Key differences from regular VAE:
    1. Constructor takes an additional 'class_size' parameter
    2. encode(), decode() and forward() take an additional condition
       parameter 'c' of shape (batch, class_size), e.g. one-hot labels
    3. FC layers on both sides of the latent space receive 'c' concatenated
       to their usual input
    """
    def __init__(self, image_channels, init_channels, latent_size, class_size):
        """
        Args:
            image_channels: Number of channels of the input/output images.
            init_channels: Output channels of the first encoder convolution.
            latent_size: Dimensionality of the latent vector z.
            class_size: Width of the condition vector c.
        """
        super(CVAE, self).__init__()
        self.image_channels = image_channels
        self.latent_size = latent_size
        self.class_size = class_size
        self.init_channels = init_channels

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(image_channels, init_channels, kernel_size=3, stride=2, padding=1),     # (1, 28, 28) -> (8, 14, 14)
            nn.ReLU(),
            nn.Conv2d(init_channels, init_channels*2, kernel_size=3, stride=2, padding=1),    # (8, 14, 14) -> (16, 7, 7)
            nn.ReLU(),
            nn.Conv2d(init_channels*2, init_channels*4, kernel_size=3, stride=2, padding=1),  # (16, 7, 7) -> (32, 4, 4)
            nn.ReLU(),
            nn.Conv2d(init_channels*4, 64, kernel_size=3, stride=1, padding=0),               # (32, 4, 4) -> (64, 2, 2)
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=2, stride=1, padding=0),                          # (64, 2, 2) -> (64, 1, 1)
            nn.ReLU()
        )

        # The condition vector is concatenated to the flattened encoder
        # features (fc1) and to the latent sample (fc2), so both layers are
        # widened by class_size.
        self.fc1 = nn.Linear(64 + class_size, 128)
        self.fc_mu = nn.Linear(128, latent_size)
        self.fc_logvar = nn.Linear(128, latent_size)
        self.fc2 = nn.Linear(latent_size + class_size, 64)

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=1, padding=0),  # (64, 1, 1) -> (64, 2, 2)
            nn.ReLU(),
            nn.ConvTranspose2d(64, init_channels*4, kernel_size=3, stride=1, padding=0),  # (64, 2, 2) -> (32, 4, 4)
            nn.ReLU(),
            nn.ConvTranspose2d(init_channels*4, init_channels*2, kernel_size=3, stride=2, padding=1),  # (32, 4, 4) -> (16, 7, 7)
            nn.ReLU(),
            nn.ConvTranspose2d(init_channels*2, init_channels, kernel_size=3, stride=2, padding=1, output_padding=1),  # (16, 7, 7) -> (8, 14, 14)
            nn.ReLU(),
            nn.ConvTranspose2d(init_channels, image_channels, kernel_size=4, stride=2, padding=1)  # (8, 14, 14) -> (1, 28, 28)
        )

    def encode(self, x, c):
        """
        Map images and their conditions to the posterior parameters.

        Args:
            x: Image batch of shape (batch, image_channels, 28, 28).
            c: Float condition batch of shape (batch, class_size) on the
                same device as x.

        Returns:
            (mu, logvar), each of shape (batch, latent_size).
        """
        h = self.encoder(x)
        h = h.view(h.size(0), -1)

        # Append the condition to the image features along the feature axis.
        inputs = torch.cat([h, c], 1)

        h_fc = F.relu(self.fc1(inputs))
        mu = self.fc_mu(h_fc)
        logvar = self.fc_logvar(h_fc)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Draw z = mu + eps * std with eps ~ N(0, I), keeping the sample differentiable w.r.t. mu and logvar."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        sample = mu + eps * std
        return sample

    def decode(self, z, c):
        """
        Reconstruct images from latent vectors and conditions.

        Args:
            z: Latent batch of shape (batch, latent_size).
            c: Float condition batch of shape (batch, class_size) on the
                same device as z.

        Returns:
            Tensor of shape (batch, image_channels, 28, 28) with values in
            [0, 1] (sigmoid output).
        """
        # Append the condition to the latent vector along the feature axis.
        inputs = torch.cat([z, c], 1)

        h = F.relu(self.fc2(inputs))
        h = h.view(-1, 64, 1, 1)
        return torch.sigmoid(self.decoder(h))

    def forward(self, x, c):
        """
        Encode x under condition c, sample z, and decode under the same c.

        Returns:
            (recon_x, mu, logvar).
        """
        mu, logvar = self.encode(x, c)
        z = self.reparameterize(mu, logvar)
        recon_x = self.decode(z, c)
        return recon_x, mu, logvar


In [None]:
# Loss function - same as original VAE
def loss_function(recon_x, x, mu, logvar):
    """
    Negative ELBO for a VAE, summed over every pixel and every sample.

    Args:
        recon_x: Reconstructed images; reshaped to rows of 784 values.
        x: Target images; reshaped to rows of 784 values.
        mu: Posterior means.
        logvar: Posterior log-variances.

    Returns:
        Scalar tensor: summed binary cross-entropy plus KL divergence.
    """
    flat_recon = recon_x.view(-1, 784)
    flat_target = x.view(-1, 784)
    reconstruction = F.binary_cross_entropy(flat_recon, flat_target, reduction='sum')

    # KL(q(z|x) || N(0, I)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * torch.sum(kl_terms)

    return reconstruction + divergence


In [None]:
def train(model, optimizer, epoch, losses):
    """
    Run one training epoch over the global `train_loader`.

    Args:
        model: Conditional VAE called as model(images, one_hot_labels) and
            returning (reconstruction, mu, logvar).
        optimizer: Optimizer that updates the model's parameters.
        epoch: Epoch number, used only in the log messages.
        losses: List that receives this epoch's average per-sample loss.
    """
    model.train()
    train_loss = 0.0
    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(device), labels.to(device)

        # Condition the model on the digit label of every image.
        labels_onehot = one_hot(labels, class_size)

        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data, labels_onehot)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        optimizer.step()

        # .item() yields a Python float, so the running sum is accumulated in
        # double precision and `losses` ends up holding plain floats.
        train_loss += loss.item()

        if batch_idx % 20 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    avg_loss = train_loss / len(train_loader.dataset)
    losses.append(avg_loss)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, avg_loss))

def test(model, epoch, losses):
    """
    Evaluate the model on the global `test_loader` without gradient tracking.

    Args:
        model: Conditional VAE called as model(images, one_hot_labels) and
            returning (reconstruction, mu, logvar).
        epoch: Epoch number, used in the name of the saved comparison image.
        losses: List that receives the average per-sample test loss.

    Side effects:
        Writes 'cvae_reconstruction_{epoch:02}.png' showing up to five
        originals from the first batch (top row) above their
        reconstructions (bottom row).
    """
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for i, (data, labels) in enumerate(test_loader):
            data, labels = data.to(device), labels.to(device)

            # Condition the model on the digit label of every image.
            labels_onehot = one_hot(labels, class_size)
            recon_batch, mu, logvar = model(data, labels_onehot)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()

            # Save reconstruction comparison for first batch
            if i == 0:
                n = min(data.size(0), 5)
                comparison = torch.cat([data[:n],
                                      recon_batch.view(-1, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'cvae_reconstruction_' + f"{epoch:02}" + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    losses.append(test_loss)
    print('====> Test set loss: {:.4f}'.format(test_loss))


In [None]:
# Initialize and train the CVAE model
model = CVAE(image_channels=1, init_channels=init_channels, latent_size=latent_size, class_size=class_size).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

train_losses = []
test_losses = []

print("Starting CVAE training...")
for epoch in range(1, epochs + 1):
    train(model, optimizer, epoch, train_losses)
    test(model, epoch, test_losses)

    # Generate one sample per class after every epoch.
    with torch.no_grad():
        # Row k of the identity matrix is the one-hot code of class k.
        c = torch.eye(class_size, class_size).to(device)
        sample = torch.randn(class_size, latent_size).to(device)
        sample = model.decode(sample, c).cpu()
        save_image(sample.view(class_size, 1, 28, 28),
                  f"cvae_sample_{epoch:02}.png")
print("Training completed!")


In [None]:
# Visualize the training and test losses
print("Train losses:", train_losses)
print("Test losses:", test_losses)

# One curve per split on a shared axis; the x position is the index in the loss list.
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(train_losses, label='Train Loss', marker='o')
ax.plot(test_losses, label='Test Loss', marker='s')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('CVAE Training and Test Losses')
ax.grid(True)
ax.legend()
plt.show()


In [None]:
# Persist only the learned parameters (state_dict), not the whole module object.
checkpoint_path = 'cvae_model_mnist.pth'
weights = model.state_dict()
torch.save(weights, checkpoint_path)
print("Model saved as 'cvae_model_mnist.pth'")


In [None]:
# Load the model (skip training for future use)
# model = CVAE(image_channels=1, init_channels=init_channels, latent_size=latent_size, class_size=class_size).to(device)
# model.load_state_dict(torch.load('cvae_model_mnist.pth'))
# print("Model loaded successfully!")

In [None]:
def generate_digit(model, digit, num_samples=10):
    """
    Generate images of a single digit class and save them as a PNG grid.

    Args:
        model: CVAE exposing `class_size`, `latent_size` and decode(z, c).
        digit: Class index to condition on, in [0, model.class_size).
        num_samples: Number of images to generate.

    Raises:
        ValueError: If `digit` is not a valid class index.

    Side effects:
        Writes "cvae_generated_digit_{digit}.png" with five images per row.
    """
    if not 0 <= digit < model.class_size:
        raise ValueError(
            f"digit must be in [0, {model.class_size}), got {digit}")

    model.eval()
    with torch.no_grad():
        # Every row gets the same one-hot condition for the requested digit.
        c = torch.zeros(num_samples, model.class_size).to(device)
        c[:, digit] = 1
        # Independent draws from the standard-normal prior give the variety.
        z = torch.randn(num_samples, model.latent_size).to(device)

        sample = model.decode(z, c).cpu()
        save_image(sample.view(num_samples, 1, 28, 28),
                  f"cvae_generated_digit_{digit}.png", nrow=5)
        print(f"Generated {num_samples} samples of digit {digit}")

def generate_all_digits(model, num_samples_per_digit=8):
    """
    Generate several images for each digit 0-9 and save them in one grid.

    Args:
        model: CVAE exposing `class_size` (at least 10), `latent_size` and
            decode(z, c).
        num_samples_per_digit: Number of images generated per digit.

    Side effects:
        Writes "cvae_all_digits_generated.png"; with nrow equal to
        num_samples_per_digit, each grid row holds the samples of one digit.
    """
    model.eval()
    with torch.no_grad():
        all_samples = []

        for digit in range(10):
            # Same one-hot condition for every sample of this digit.
            c = torch.zeros(num_samples_per_digit, model.class_size).to(device)
            c[:, digit] = 1
            # Fresh latent draws so the samples of one digit differ.
            z = torch.randn(num_samples_per_digit, model.latent_size).to(device)

            sample = model.decode(z, c).cpu()
            all_samples.append(sample)

        all_samples = torch.cat(all_samples, dim=0)
        save_image(all_samples.view(-1, 1, 28, 28),
                  "cvae_all_digits_generated.png", nrow=num_samples_per_digit)
        print(f"Generated {num_samples_per_digit} samples for each digit (0-9)")

# Exercise both conditional generation helpers
print("Testing conditional generation...")
generate_digit(model, 3, num_samples=10)  # example: generate 10 threes
generate_all_digits(model, num_samples_per_digit=8)  # grid with one row per digit


In [None]:
def interpolate_between_classes(model, class1, class2, num_steps=10):
    """
    Decode a linear path from one class to another and save it as an image row.

    Both the latent vector and the class condition are interpolated: step i
    uses (1 - alpha) * start + alpha * end with alpha running from 0 to 1.

    Args:
        model: CVAE exposing `class_size`, `latent_size` and decode(z, c).
        class1: Class index at the start of the path.
        class2: Class index at the end of the path.
        num_steps: Number of images along the path (at least 1).

    Raises:
        ValueError: If `num_steps` is smaller than 1.

    Side effects:
        Writes "cvae_interpolation_{class1}_to_{class2}.png" with all steps
        in a single row.
    """
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    model.eval()
    with torch.no_grad():
        # One-hot encodings of the two endpoint classes.
        c1 = torch.zeros(1, model.class_size).to(device)
        c1[0, class1] = 1
        c2 = torch.zeros(1, model.class_size).to(device)
        c2[0, class2] = 1

        # Random latent endpoints drawn from the standard-normal prior.
        z1 = torch.randn(1, model.latent_size).to(device)
        z2 = torch.randn(1, model.latent_size).to(device)

        interpolated_samples = []

        for i in range(num_steps):
            # Linear interpolation parameter (0 to 1); a single step stays at
            # the start point instead of dividing by zero.
            alpha = i / (num_steps - 1) if num_steps > 1 else 0.0

            # Interpolate in latent space
            z_interp = (1 - alpha) * z1 + alpha * z2
            # Interpolate the condition the same way (a soft mix of both classes)
            c_interp = (1 - alpha) * c1 + alpha * c2

            sample = model.decode(z_interp, c_interp).cpu()
            interpolated_samples.append(sample)

        interpolated_samples = torch.cat(interpolated_samples, dim=0)
        save_image(interpolated_samples.view(-1, 1, 28, 28),
                  f"cvae_interpolation_{class1}_to_{class2}.png", nrow=num_steps)
        print(f"Generated interpolation from class {class1} to class {class2}")

# Example run: morph digit 0 into digit 1 across ten steps.
source_digit, target_digit = 0, 1
interpolate_between_classes(model, source_digit, target_digit, num_steps=10)
