# In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time
import os
from torch.utils.data import Subset
from torchvision.utils import save_image
from torchvision.datasets import DatasetFolder
from PIL import Image
from torch.utils.data import Dataset

print('vae.py is running')

# Pick the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print("Cuda is available. Using GPU" if _cuda_ok else "Warning CUDA not Found. Using CPU")

# Hyper-parameters
num_epochs = 20        # full passes over the training set
learning_rate = 1e-3   # step size for the Adam optimizer created below

# OASIS dataset
print("> Setup OASIS dataset")

class PNGDataset(Dataset):
    """Flat-directory dataset of ``.png`` images, loaded as single-channel.

    Every entry directly inside ``root`` whose name ends with ``.png`` is
    included (no recursion). The file list is sorted so that a given index
    always maps to the same file.

    Args:
        root: Directory containing the PNG files.
        transform: Optional callable applied to the grayscale PIL image
            before it is returned (e.g. a ``transforms.Compose``).

    Each item is ``(image, 0)``: the constant ``0`` is a placeholder label so
    the dataset fits loops that unpack ``(input, target)`` pairs.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # os.listdir order is arbitrary; sort for a deterministic index -> file map.
        self.files = sorted(f for f in os.listdir(root) if f.endswith(".png"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, self.files[idx])
        # Close the source file deterministically; convert() returns a new,
        # fully loaded image that no longer depends on the open handle.
        with Image.open(img_path) as src:
            img = src.convert("L")  # grayscale
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # dummy label

# Preprocessing: one channel, fixed 64x64 size, float tensor in [0, 1].
# (PNGDataset already converts to mode "L", so Grayscale() is effectively a
# safeguard rather than a real conversion.)
transform = transforms.Compose(
    [
        transforms.Grayscale(),
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ]
)

trainset = PNGDataset(
    "/home/groups/comp3710/OASIS/keras_png_slices_train", transform=transform
)
train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, batch_size=16)

testset = PNGDataset(
    "/home/groups/comp3710/OASIS/keras_png_slices_test", transform=transform
)
test_loader = torch.utils.data.DataLoader(testset, shuffle=False, batch_size=16)

for _caption, _split in (("Train set size:", trainset), ("Test set size:", testset)):
    print(_caption, len(_split))

# Sanity check: pull one shuffled training batch and report its shapes.
images, labels = next(iter(train_loader))
print("Batch shape:", images.shape)
print("Labels shape:", labels.shape)


class CNNVAE(nn.Module):
    """Convolutional VAE for 1x64x64 images with a ``latent_dim``-D Gaussian latent.

    Encoder: three [Conv3x3 -> BatchNorm -> ReLU -> Dropout2d -> MaxPool2]
    stages (1 -> 32 -> 64 -> 128 channels, 64x64 -> 8x8), flattened and fed to
    two linear heads producing the posterior mean and log-variance.

    Decoder: a linear projection back to a 128x8x8 map, then three stride-2
    transposed convolutions (8x8 -> 64x64) ending in a sigmoid, so outputs
    lie in [0, 1].
    """

    # (channels, height, width) of the feature map at the bottleneck.
    _BOTTLENECK = (128, 8, 8)

    def __init__(self, latent_dim=2):
        super().__init__()
        self.latent_dim = latent_dim
        channels, height, width = self._BOTTLENECK
        flat_features = channels * height * width

        # --- Encoder: each stage keeps H, W through the conv, then halves them.
        down = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            down += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout2d(0.2),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        down.append(nn.Flatten())
        self.encoder_conv = nn.Sequential(*down)

        # --- Posterior heads: mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # --- Decoder: each transposed conv doubles H, W (8 -> 16 -> 32 -> 64).
        up = [
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, self._BOTTLENECK),
        ]
        for c_in, c_out in ((128, 64), (64, 32)):
            up += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout2d(0.2),
            ]
        up += [
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # pixel intensities in [0, 1]
        ]
        self.decoder = nn.Sequential(*up)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.encoder_conv(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` in training mode; return ``mu`` itself in eval mode."""
        if not self.training:
            return mu  # deterministic latent for inference
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        """Map latent codes ``z`` to images in [0, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, draw (or take the mean of) ``z`` and decode it.

        Returns ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mu, logvar))
        return recon, mu, logvar


# -------------------
# VAE Loss Function
# -------------------
def vae_loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Negative ELBO: BCE reconstruction term plus ``beta``-weighted KL term.

    Args:
        recon_x: Decoder output with values in [0, 1], same shape as ``x``.
        x: Target images with values in [0, 1].
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        beta: Weight on the KL term (1.0 gives the standard VAE objective).

    Returns:
        ``(total, bce, kld)`` -- each summed (not averaged) over the batch.
    """
    # Reconstruction term. BCE needs [0, 1] targets; for data normalised to
    # [-1, 1] an MSE term would be the usual substitute.
    bce = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * torch.sum(kl_terms)

    total = bce + beta * kld
    return total, bce, kld

# Create every output directory once, before training starts.
for _out_dir in ("results", "results/recon", "results/samples"):
    os.makedirs(_out_dir, exist_ok=True)

# Model + optimizer
vae = CNNVAE(latent_dim=2).to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Training loop
beta = 1.0  # Beta parameter for beta-VAE (1.0 = standard VAE)

print("Training VAE...")
num_train = len(train_loader.dataset)
for epoch in range(num_epochs):
    # Re-enable dropout / batch-stat updates after the eval pass at the end
    # of the previous epoch.
    vae.train()
    total_loss, total_bce, total_kld = 0.0, 0.0, 0.0

    for images, _ in train_loader:
        images = images.to(device)

        # Forward pass
        recon_images, mu, logvar = vae(images)

        # Loss (summed over pixels, latent dims and the batch)
        loss, bce, kld = vae_loss_function(recon_images, images, mu, logvar, beta)

        # Backward + optimize
        vae_optimizer.zero_grad()
        loss.backward()
        vae_optimizer.step()

        total_loss += loss.item()
        total_bce += bce.item()
        total_kld += kld.item()

    # The losses are sums, so dividing by the number of training images
    # gives a per-image average.
    avg_loss = total_loss / num_train
    avg_bce = total_bce / num_train
    avg_kld = total_kld / num_train

    print(f'Epoch [{epoch+1}/{num_epochs}] '
          f'Loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, KLD: {avg_kld:.4f}')

    # Save reconstructions of the first test batch every epoch
    # (first row: inputs, second row: reconstructions).
    vae.eval()
    with torch.no_grad():
        test_images, _ = next(iter(test_loader))
        test_images = test_images.to(device)
        recon_images, _, _ = vae(test_images)
        save_image(torch.cat([test_images[:8].cpu(), recon_images[:8].cpu()], dim=0),
                   f"results/recon/epoch_{epoch+1:03d}.png", nrow=8, normalize=True)

# Decode 64 latent codes drawn from a standard normal.
vae.eval()
with torch.no_grad():
    z = torch.randn(64, vae.latent_dim, device=device)
    samples = vae.decode(z).cpu()
save_image(samples, "results/samples/final_samples.png", nrow=8, normalize=True)

# Collect posterior means for the whole test set. Only the encoder is needed,
# so call encode() directly instead of running the decoder via vae(x).
vae.eval()
latents = []
with torch.no_grad():
    for x, _ in test_loader:
        mu, _ = vae.encode(x.to(device))
        latents.append(mu.cpu())
latents = torch.cat(latents, dim=0)  # (num_test_images, latent_dim)
torch.save(latents, "results/latent_2d.pt")



# In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import time
import os
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from torchvision.utils import save_image
from torchvision.datasets import DatasetFolder
from PIL import Image
from torch.utils.data import Dataset

print('vae.py is running')

# Pick the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
_cuda_ok = torch.cuda.is_available()
device = torch.device('cuda' if _cuda_ok else 'cpu')
print("Cuda is available. Using GPU" if _cuda_ok else "Warning CUDA not Found. Using CPU")

# Hyper-parameters
num_epochs = 20        # full passes over the training set
learning_rate = 1e-3   # step size for the Adam optimizer created below

# OASIS dataset
print("> Setup OASIS dataset")

class PNGDataset(Dataset):
    """Flat-directory dataset of ``.png`` images, loaded as single-channel.

    Every entry directly inside ``root`` whose name ends with ``.png`` is
    included (no recursion). The file list is sorted so that a given index
    always maps to the same file.

    Args:
        root: Directory containing the PNG files.
        transform: Optional callable applied to the grayscale PIL image
            before it is returned (e.g. a ``transforms.Compose``).

    Each item is ``(image, 0)``: the constant ``0`` is a placeholder label so
    the dataset fits loops that unpack ``(input, target)`` pairs.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # os.listdir order is arbitrary; sort for a deterministic index -> file map.
        self.files = sorted(f for f in os.listdir(root) if f.endswith(".png"))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.root, self.files[idx])
        # Close the source file deterministically; convert() returns a new,
        # fully loaded image that no longer depends on the open handle.
        with Image.open(img_path) as src:
            img = src.convert("L")  # grayscale
        if self.transform is not None:
            img = self.transform(img)
        return img, 0  # dummy label

# Preprocessing: one channel, fixed 64x64 size, float tensor in [0, 1].
transform = transforms.Compose([
    transforms.Grayscale(),       # ensure single channel
    transforms.Resize((64, 64)),  # resize to fixed size
    transforms.ToTensor(),        # convert to [0,1]
])

# Maximum number of images kept from each split for quick experiments.
subset_size = 100

trainset = PNGDataset(
    root="/home/groups/comp3710/OASIS/keras_png_slices_train",
    transform=transform
)

# Reduced size for testing. Clamp to the real length: Subset does not check
# its indices, so range(100) over a shorter directory would report len 100
# and then raise IndexError from inside the DataLoader.
trainset = Subset(trainset, range(min(subset_size, len(trainset))))

train_loader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)

testset = PNGDataset(
    root="/home/groups/comp3710/OASIS/keras_png_slices_test",
    transform=transform
)

# Reduced size for testing (clamped for the same reason as the train split).
testset = Subset(testset, range(min(subset_size, len(testset))))

test_loader = torch.utils.data.DataLoader(testset, batch_size=16, shuffle=False)

print("Train set size:", len(trainset))
print("Test set size:", len(testset))

# Sanity check: pull one shuffled training batch and report its shapes.
images, labels = next(iter(train_loader))
print("Batch shape:", images.shape)
print("Labels shape:", labels.shape)

class CNNVAE(nn.Module):
    """Convolutional VAE for 1x64x64 images with a ``latent_dim``-D Gaussian latent.

    Encoder: three [Conv3x3 -> BatchNorm -> ReLU -> Dropout2d -> MaxPool2]
    stages (1 -> 32 -> 64 -> 128 channels, 64x64 -> 8x8), flattened and fed to
    two linear heads producing the posterior mean and log-variance.

    Decoder: a linear projection back to a 128x8x8 map, then three stride-2
    transposed convolutions (8x8 -> 64x64) ending in a sigmoid, so outputs
    lie in [0, 1].
    """

    # (channels, height, width) of the feature map at the bottleneck.
    _BOTTLENECK = (128, 8, 8)

    def __init__(self, latent_dim=32):
        super().__init__()
        self.latent_dim = latent_dim
        channels, height, width = self._BOTTLENECK
        flat_features = channels * height * width

        # --- Encoder: each stage keeps H, W through the conv, then halves them.
        down = []
        for c_in, c_out in ((1, 32), (32, 64), (64, 128)):
            down += [
                nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout2d(0.2),
                nn.MaxPool2d(kernel_size=2, stride=2),
            ]
        down.append(nn.Flatten())
        self.encoder_conv = nn.Sequential(*down)

        # --- Posterior heads: mean and log-variance of q(z|x).
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        # --- Decoder: each transposed conv doubles H, W (8 -> 16 -> 32 -> 64).
        up = [
            nn.Linear(latent_dim, flat_features),
            nn.ReLU(),
            nn.Unflatten(1, self._BOTTLENECK),
        ]
        for c_in, c_out in ((128, 64), (64, 32)):
            up += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
                nn.Dropout2d(0.2),
            ]
        up += [
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid(),  # pixel intensities in [0, 1]
        ]
        self.decoder = nn.Sequential(*up)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for batch ``x``."""
        features = self.encoder_conv(x)
        return self.fc_mu(features), self.fc_logvar(features)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + sigma * eps`` in training mode; return ``mu`` itself in eval mode."""
        if not self.training:
            return mu  # deterministic latent for inference
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        """Map latent codes ``z`` to images in [0, 1]."""
        return self.decoder(z)

    def forward(self, x):
        """Encode ``x``, draw (or take the mean of) ``z`` and decode it.

        Returns ``(reconstruction, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        recon = self.decode(self.reparameterize(mu, logvar))
        return recon, mu, logvar


# -------------------
# VAE Loss Function
# -------------------
def vae_loss_function(recon_x, x, mu, logvar, beta=1.0):
    """Negative ELBO: BCE reconstruction term plus ``beta``-weighted KL term.

    Args:
        recon_x: Decoder output with values in [0, 1], same shape as ``x``.
        x: Target images with values in [0, 1].
        mu: Posterior means from the encoder.
        logvar: Posterior log-variances from the encoder.
        beta: Weight on the KL term (1.0 gives the standard VAE objective).

    Returns:
        ``(total, bce, kld)`` -- each summed (not averaged) over the batch.
    """
    # Reconstruction term. BCE needs [0, 1] targets; for data normalised to
    # [-1, 1] an MSE term would be the usual substitute.
    bce = nn.functional.binary_cross_entropy(recon_x, x, reduction='sum')

    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over all elements.
    kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
    kld = -0.5 * torch.sum(kl_terms)

    total = bce + beta * kld
    return total, bce, kld

# Make sure results directory exists
os.makedirs("results", exist_ok=True)

# Model + optimizer
vae = CNNVAE(latent_dim=32).to(device)
vae_optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Training loop
beta = 1.0  # Beta parameter for beta-VAE (1.0 = standard VAE)

print("Training VAE...")
num_train = len(train_loader.dataset)
for epoch in range(num_epochs):
    vae.train()
    total_loss, total_bce, total_kld = 0.0, 0.0, 0.0

    for images, _ in train_loader:
        images = images.to(device)

        # Forward pass
        recon_images, mu, logvar = vae(images)

        # Loss (summed over pixels, latent dims and the batch)
        loss, bce, kld = vae_loss_function(recon_images, images, mu, logvar, beta)

        # Backward + optimize
        vae_optimizer.zero_grad()
        loss.backward()
        vae_optimizer.step()

        total_loss += loss.item()
        total_bce += bce.item()
        total_kld += kld.item()

    # The losses are sums, so dividing by the number of training images
    # gives a per-image average.
    avg_loss = total_loss / num_train
    avg_bce = total_bce / num_train
    avg_kld = total_kld / num_train

    print(f'Epoch [{epoch+1}/{num_epochs}] '
          f'Loss: {avg_loss:.4f}, BCE: {avg_bce:.4f}, KLD: {avg_kld:.4f}')

# Reconstruct the first test batch once, with the trained model, and show
# originals (top row) above their reconstructions (bottom row).
vae.eval()
with torch.no_grad():
    test_images, _ = next(iter(test_loader))
    test_images = test_images.to(device)
    recon_images, _, _ = vae(test_images)

n = min(8, test_images.size(0))  # number of images to display
originals = test_images[:n].cpu()
reconstructions = recon_images[:n].cpu()

plt.figure(figsize=(16, 4))
for i in range(n):
    # Original
    plt.subplot(2, n, i + 1)
    plt.imshow(originals[i].squeeze(), cmap="gray")
    plt.title("Original")
    plt.axis("off")

    # Reconstruction
    plt.subplot(2, n, i + 1 + n)
    plt.imshow(reconstructions[i].squeeze(), cmap="gray")
    plt.title("Reconstructed")
    plt.axis("off")

plt.show()