In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.mixture import GaussianMixture
import numpy as np
from torch.utils.data import Subset

In [2]:
transform = transforms.Compose([
    transforms.Resize(64),  # Resize to 64x64 to match generator output size
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
from torchvision.datasets import FashionMNIST
dataset  = FashionMNIST(root='./data', train=True, download=True, transform=transform)
subset_indices = torch.randperm(len(dataset))[:10000]  
sub_dataset = Subset(dataset, subset_indices)

dataloader = DataLoader(sub_dataset, batch_size=128, shuffle=True)


In [31]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),  # Normalize to range [-1, 1]
])
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(64*64, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU()
        )
        self.fc_mu = nn.Linear(128, 64)
        self.fc_logvar = nn.Linear(128, 64)
        self.decoder = nn.Sequential(
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 64*64),  # Adjusted to match the size of z
            nn.Sigmoid()
        )

    def encode(self, x):
        h = self.encoder(x)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        x = x.view(-1, 64*64)  # Reshape input to [batch_size, 64*64]
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

# Loss function
def loss_function(recon_x, x, mu, logvar):
    BCE = nn.functional.binary_cross_entropy(recon_x, x.view(-1, 64*64), reduction='sum')
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE + KLD

# Initialize model and optimizer
model = VAE().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    for batch_idx, (data, _) in enumerate(dataloader):
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss = loss_function(recon_batch, data, mu, logvar)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()

    print('Epoch: {}, Loss: {:.4f}'.format(epoch, train_loss / len(dataloader.dataset)))

# Save the trained model
torch.save(model.state_dict(), 'vae.pth')

Epoch: 0, Loss: -164757.3998
Epoch: 1, Loss: -197703.5263
Epoch: 2, Loss: -200496.9862
Epoch: 3, Loss: -203201.1839
Epoch: 4, Loss: -206488.8753
Epoch: 5, Loss: -208809.1276
Epoch: 6, Loss: -210236.3160
Epoch: 7, Loss: -211443.2987
Epoch: 8, Loss: -212372.2654
Epoch: 9, Loss: -213177.3819


In [8]:

model.state_dict() = torch.load(vae.pth)
model.eval()

SyntaxError: cannot assign to function call here. Maybe you meant '==' instead of '='? (3993826934.py, line 1)