In [1]:
# train_vae_fashion.py
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os

# --- Configuration ---
# Output folders for generated sample grids and the trained checkpoint;
# created up front so the later save calls never hit a missing directory.
os.makedirs('results_fashion', exist_ok=True)
os.makedirs('models_fashion', exist_ok=True)

# Hyperparameters (sized for 28x28 grayscale Fashion-MNIST)
IMG_SIZE = 28           # height == width of every image
CHANNELS = 1            # grayscale -> a single input channel
LATENT_DIM = 50         # dimensionality of the latent code z
EPOCHS = 30             # full passes over the training split
BATCH_SIZE = 128        # images per optimisation step
LEARNING_RATE = 1e-3    # Adam step size
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Data Loading ---
# ToTensor is the only transform: the images are already 28x28, so no
# resizing or cropping is applied.
transform = transforms.Compose([transforms.ToTensor()])

# Fashion-MNIST training split stored under 'data' (download=True fetches it
# if it is not already there), served in shuffled mini-batches.
dataset = datasets.FashionMNIST(
    root='data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# --- VAE Model (Adjusted for 28x28 images) ---
class VAE(nn.Module):
    """Convolutional variational autoencoder for small square images.

    The encoder halves the spatial size twice (img_size -> img_size/2 ->
    img_size/4) with stride-2 convolutions, then two linear heads produce the
    posterior mean and log-variance. The decoder mirrors this with a linear
    layer and two stride-2 transposed convolutions; a final sigmoid keeps the
    outputs between 0 and 1.

    Layer attribute names are the same as in the original 28x28-only version,
    so previously saved ``state_dict`` checkpoints still load.

    Args:
        channels: Number of image channels. Defaults to the module-level
            ``CHANNELS`` when omitted.
        latent_dim: Size of the latent vector. Defaults to ``LATENT_DIM``.
        img_size: Side length of the square input images. Defaults to
            ``IMG_SIZE``. Must be a positive multiple of 4 so the decoder
            output has the same resolution as the input.

    Raises:
        ValueError: If ``img_size`` is not a positive multiple of 4.
    """

    def __init__(self, channels=None, latent_dim=None, img_size=None):
        super().__init__()

        # Resolve defaults from the module-level config at construction time
        # (not definition time), as the original implementation did.
        channels = CHANNELS if channels is None else channels
        latent_dim = LATENT_DIM if latent_dim is None else latent_dim
        img_size = IMG_SIZE if img_size is None else img_size
        if img_size <= 0 or img_size % 4 != 0:
            raise ValueError(f"img_size must be a positive multiple of 4, got {img_size}")

        # Feature-map geometry after the two stride-2 convolutions
        # (64 x 7 x 7 for 28x28 input, as previously hard-coded).
        self._hidden_channels = 64
        self._feat_size = img_size // 4
        flat_dim = self._hidden_channels * self._feat_size * self._feat_size

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(channels, 32, kernel_size=4, stride=2, padding=1),  # -> img_size/2
            nn.ReLU(),
            nn.Conv2d(32, self._hidden_channels, kernel_size=4, stride=2, padding=1),  # -> img_size/4
            nn.ReLU(),
            nn.Flatten(),
        )

        # Latent space layers
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(flat_dim, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, flat_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self._hidden_channels, 32, kernel_size=4, stride=2, padding=1),  # -> img_size/2
            nn.ReLU(),
            nn.ConvTranspose2d(32, channels, kernel_size=4, stride=2, padding=1),  # -> img_size
            nn.Sigmoid(),  # Output pixel values between 0 and 1
        )

    def encode(self, x):
        """Encode a batch of images into posterior parameters.

        Args:
            x: Image batch, expected as (N, channels, img_size, img_size).

        Returns:
            Tuple ``(mu, logvar)``, each of shape (N, latent_dim).
        """
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``.

        Sampling through ``eps`` keeps ``z`` differentiable with respect to
        ``mu`` and ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Decode latent codes of shape (N, latent_dim) into images.

        Returns:
            Tensor of shape (N, channels, img_size, img_size) with values
            between 0 and 1.
        """
        h = self.decoder_fc(z)
        # Reshape the flat vector into the feature map the decoder expects.
        h = h.view(-1, self._hidden_channels, self._feat_size, self._feat_size)
        return self.decoder(h)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for an image batch ``x``."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# --- Loss Function ---
def loss_function(recon_x, x, mu, logvar):
    """Return the batch-summed negative ELBO of a VAE.

    The loss is the pixel-wise binary cross-entropy between the
    reconstruction and the target, plus the closed-form KL divergence between
    the diagonal Gaussian posterior N(mu, exp(logvar)) and a standard normal
    prior. Both terms are summed (not averaged) over the batch.

    Args:
        recon_x: Decoder output with values in [0, 1].
        x: Target images. Any layout holding the same number of elements as
            ``recon_x`` is accepted (e.g. flattened pixels); it is reshaped to
            match ``recon_x``.
        mu: Posterior means.
        logvar: Posterior log-variances, same shape as ``mu``.

    Returns:
        Scalar tensor: reconstruction loss + KL divergence.
    """
    # Reshape the target to the reconstruction's layout instead of
    # hard-coding (-1, 1, 28, 28), so other channel counts / image sizes work.
    bce = nn.functional.binary_cross_entropy(recon_x, x.reshape_as(recon_x), reduction='sum')
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return bce + kld

# --- Training Loop ---
model = VAE().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
NUM_SAMPLES = 64  # images per saved sample grid

print("Starting training on Fashion-MNIST...")
for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(DEVICE)
        optimizer.zero_grad()

        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)

        loss.backward()
        train_loss += loss.item()
        optimizer.step()

        if batch_idx % 100 == 0:
            # loss_function sums over the batch; divide to report a per-sample value.
            print(f"Epoch: {epoch+1}/{EPOCHS} | Batch: {batch_idx}/{len(dataloader)} | Loss: {loss.item()/len(data):.4f}")

    # Per-sample average over every example the loader served this epoch.
    avg_loss = train_loss / len(dataloader.dataset)
    print(f"====> Epoch {epoch+1} complete: Average loss = {avg_loss:.4f}")

    # Every 5 epochs, decode random draws from the standard-normal prior and
    # save them as an image grid (generated samples, not reconstructions).
    if (epoch + 1) % 5 == 0:
        model.eval()
        with torch.no_grad():
            # Allocate the noise directly on DEVICE instead of CPU + copy.
            sample = torch.randn(NUM_SAMPLES, LATENT_DIM, device=DEVICE)
            generated = model.decode(sample).cpu()
            save_image(generated.view(NUM_SAMPLES, CHANNELS, IMG_SIZE, IMG_SIZE), f'results_fashion/sample_{epoch+1}.png')

print("Training finished.")
torch.save(model.state_dict(), 'models_fashion/vae_fashion_mnist_final.pth')

ModuleNotFoundError: No module named 'torchvision'