In [142]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal, kl_divergence
import torch
import torchvision
import torchvision.transforms as transforms
from tqdm import tqdm
from datasets import load_dataset
import torch.utils.data as data



In [146]:
# Prefer the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class VAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU``
    stages, one per entry of ``hidden_dims``, each halving the spatial size.
    The decoder mirrors it with stride-2 ``ConvTranspose2d`` stages and ends in a
    ``Sigmoid``, so reconstructions lie in ``[0, 1]``.

    Args:
        in_channels: Channels of the input images. The reconstruction has the
            same number of channels.
        latent_dim: Size of the latent vector ``z``.
        hidden_dims: Channel widths of the encoder stages (default
            ``[32, 64, 128, 256, 512]``). The list passed in is not modified.
        input_size: Height/width of the square input images. Must be divisible
            by ``2 ** len(hidden_dims)``. The default (64) with five stages gives
            the 2x2 bottleneck the architecture was originally hard-coded for.
        kld_weight: Fixed weight for the KL term, stored as a non-trainable
            buffer named ``kld_weight`` (shape ``(1,)``).
        **kwargs: Accepted and ignored.
    """

    def __init__(self, in_channels: int, latent_dim: int, hidden_dims: list = None,
                 input_size: int = 64, kld_weight: float = 1.0, **kwargs) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        # Work on a copy so the caller's list is never reversed in place.
        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        else:
            hidden_dims = list(hidden_dims)
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one channel width")
        downscale = 2 ** len(hidden_dims)
        if input_size % downscale != 0:
            raise ValueError(
                f"input_size={input_size} must be divisible by 2**len(hidden_dims)={downscale}")
        # Side length of the square encoder output (2 for 64x64 inputs, 5 stages).
        self._bottleneck_size = input_size // downscale
        flat_dim = hidden_dims[-1] * self._bottleneck_size ** 2
        out_channels = in_channels

        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(nn.Conv2d(in_channels, out_channels=h_dim, kernel_size=3, stride=2, padding=1),
                              nn.BatchNorm2d(h_dim),
                              nn.LeakyReLU()))
            in_channels = h_dim
        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

        self.decode_input = nn.Linear(latent_dim, flat_dim)
        hidden_dims = hidden_dims[::-1]
        # Channel count of the tensor fed into the decoder (previously hard-coded 512).
        self._decoder_channels = hidden_dims[0]
        modules = []
        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(nn.ConvTranspose2d(hidden_dims[i], hidden_dims[i + 1], kernel_size=3, stride=2,
                                                 padding=1, output_padding=1),
                              nn.BatchNorm2d(hidden_dims[i + 1]),
                              nn.LeakyReLU()))
        self.decoder = nn.Sequential(*modules)
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1], hidden_dims[-1], kernel_size=3, stride=2, padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=out_channels, kernel_size=3, padding=1),
            nn.Sigmoid())
        # Fixed loss coefficient. As an nn.Parameter it was optimised with the model:
        # d(loss)/d(weight) = KL >= 0, so the optimiser drives it negative, after which
        # inflating the KL term *lowers* the loss. A buffer keeps the same state_dict key.
        self.register_buffer("kld_weight", torch.full((1,), float(kld_weight)))

    def encode(self, x: torch.Tensor) -> list[torch.Tensor]:
        """Encode images into ``[mu, log_var]`` of the approximate posterior."""
        x = self.encoder(x)
        x = torch.flatten(x, start_dim=1)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        return [mu, log_var]

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent vectors back to images with values in ``[0, 1]``."""
        x = self.decode_input(z)
        side = self._bottleneck_size
        x = x.view(-1, self._decoder_channels, side, side)
        x = self.decoder(x)
        return self.final_layer(x)

    def reparam_trick(self, mu: torch.Tensor, log_var: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * exp(0.5 * log_var)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: torch.Tensor) -> list[torch.Tensor]:
        """Return ``[reconstruction, mu, log_var]`` for a batch of images."""
        mu, log_var = self.encode(input)
        z = self.reparam_trick(mu, log_var)
        return [self.decode(z), mu, log_var]

    def sample(self, num_samples: int) -> torch.Tensor:
        """Decode ``num_samples`` draws from the standard-normal prior.

        ``z`` is created on the device of the model's parameters. In train mode,
        BatchNorm uses (and updates running stats from) this batch; call
        ``eval()`` first for inference.
        """
        model_device = next(self.parameters()).device
        z = torch.randn(num_samples, self.latent_dim, device=model_device)
        return self.decode(z)

    def generate(self, x: torch.Tensor) -> torch.Tensor:
        """Return a (stochastic) reconstruction of ``x``."""
        return self(x)[0]

In [147]:
def training_loop(epochs, train_dl=None, val_dl=None, *, latent_dim=72, lr=1e-3, save_path='VAE.pth'):
    """Train a freshly initialised VAE, report losses, and save its weights.

    Args:
        epochs: Number of full passes over the training data.
        train_dl: Training DataLoader yielding ``(images, labels)``. Defaults to
            the notebook-level ``train_loader``.
        val_dl: Validation DataLoader. Defaults to the notebook-level ``val_loader``.
        latent_dim: Latent size of the VAE.
        lr: AdamW learning rate.
        save_path: File that ``model.state_dict()`` is written to after training.

    Returns:
        The trained VAE.
    """
    # Resolve notebook globals at call time so explicit loaders can be passed instead.
    if train_dl is None:
        train_dl = train_loader
    if val_dl is None:
        val_dl = val_loader

    model = VAE(in_channels=3, latent_dim=latent_dim).to(device)
    for p in model.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    # kld_weight is a fixed loss coefficient, never an optimisation target.
    # Excluding it and reading it as a float keeps it constant even when it is an nn.Parameter.
    params = [p for name, p in model.named_parameters() if name != 'kld_weight']
    optimizer = torch.optim.AdamW(params, lr=lr)
    kld_weight = float(model.kld_weight)

    for step in range(epochs):
        # evaluation() puts the model in eval mode; switch back so BatchNorm trains.
        model.train()
        running_loss, n_seen = 0.0, 0
        batch_iterator = tqdm(train_dl, desc=f'Processing epoch{step:02d}')
        for x, _ in batch_iterator:
            x = x.to(device)
            optimizer.zero_grad()
            out, mu, log_var = model(x)
            recon_loss = F.mse_loss(out, x)
            # Closed-form KL(q(z|x) || N(0, I)): summed over latent dims, averaged over the batch.
            kl_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
            loss = recon_loss + kld_weight * kl_loss
            loss.backward()
            optimizer.step()
            running_loss += loss.detach() * x.size(0)
            n_seen += x.size(0)
        if step % 2 == 0 or step == epochs - 1:
            val_loss = evaluation(model, val_dl)
            epoch_loss = float(running_loss) / max(n_seen, 1)
            print(f'Epoch {step} Loss {epoch_loss} val_loss {val_loss}')
    torch.save(model.state_dict(), save_path)
    return model

In [143]:
def managing_data(data_name: str):
    """Load a dataset with ``load_dataset`` and split out its image/label columns.

    Returns:
        ``(train_images, train_labels, valid_images, valid_labels)`` taken from the
        dataset's ``'train'`` and ``'valid'`` splits.
    """
    dataset = load_dataset(data_name)
    train_split = dataset['train']
    valid_split = dataset['valid']
    return (train_split['image'], train_split['label'],
            valid_split['image'], valid_split['label'])
# Load tiny-imagenet from the Hugging Face Hub and unpack its image/label columns.
(train_imgs, train_labels,
 val_imgs, val_labels) = managing_data(data_name="zh-plus/tiny-imagenet")
def data_cleaning(images, labels, drop_grayscale: bool = True):
    """Convert images to tensors and remove (or expand) single-channel ones.

    ``transforms.ToTensor()`` is applied to every image. By default, any image whose
    tensor has one channel is dropped together with its label, so all remaining
    tensors share a channel count and can be ``torch.stack``-ed.

    Args:
        images: Sequence of images accepted by ``transforms.ToTensor()``.
        labels: Sequence of labels aligned index-for-index with ``images``.
        drop_grayscale: If False, single-channel tensors are repeated to three
            channels and kept instead of being dropped.

    Returns:
        ``(image_tensors, label_tensor)``: a new list of image tensors and a
        ``torch.tensor`` of the matching labels. The inputs are not modified.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length.
    """
    if len(images) != len(labels):
        raise ValueError(f"got {len(images)} images but {len(labels)} labels")
    to_tensor = transforms.ToTensor()
    kept_images, kept_labels = [], []
    # Single pass building new lists: avoids mutating the caller's lists and the
    # index-shifting pop() loop of the previous version.
    for image, label in zip(images, labels):
        tensor = to_tensor(image)
        if tensor.size(0) == 1:
            if drop_grayscale:
                continue
            tensor = tensor.repeat(3, 1, 1)
        kept_images.append(tensor)
        kept_labels.append(label)
    return kept_images, torch.tensor(kept_labels)
# Clean both splits: images become tensors and single-channel images are filtered out.
# NOTE: this rebinds train_imgs/val_imgs to tensors; re-running this cell alone fails
# because ToTensor expects PIL images/ndarrays. Re-run the loading cell first.
train_imgs, train_labels = data_cleaning(train_imgs, train_labels)
val_imgs, val_labels = data_cleaning(val_imgs, val_labels)

TRAIN_BATCH_SIZE = 120
VAL_BATCH_SIZE = 50

# Stack the per-image tensors and wrap each split in a DataLoader.
train_dataset = data.TensorDataset(torch.stack(train_imgs), train_labels)
train_loader = data.DataLoader(train_dataset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
val_dataset = data.TensorDataset(torch.stack(val_imgs), val_labels)
# No shuffling for validation: with a partial final batch, shuffling changes the batch
# composition every run and makes a batch-averaged validation loss nondeterministic.
val_loader = data.DataLoader(val_dataset, batch_size=VAL_BATCH_SIZE, shuffle=False)

In [126]:
def evaluation(model, test_loader):
    """Mean per-pixel reconstruction MSE of ``model`` over ``test_loader``.

    Each batch's mean MSE is weighted by its batch size, so a smaller final batch
    does not skew the average. The model's train/eval mode is restored on exit,
    so calling this mid-training does not leave BatchNorm frozen. The model is
    not moved: batches are sent to the device of its first parameter (CPU if it
    has none).

    Args:
        model: Module whose forward returns ``(reconstruction, mu, log_var)``.
        test_loader: Iterable of ``(images, labels)`` batches.

    Returns:
        float: sample-weighted mean reconstruction MSE.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    was_training = model.training
    model.eval()
    loss_fn = nn.MSELoss()
    total_loss = 0.0
    n_samples = 0
    try:
        with torch.no_grad():
            for x, _ in test_loader:
                x = x.to(model_device).float()
                # NOTE: model(x) samples z, so this loss is stochastic across calls.
                out, _, _ = model(x)
                batch_size = x.size(0)
                total_loss += loss_fn(out, x).item() * batch_size
                n_samples += batch_size
    finally:
        model.train(was_training)
    if n_samples == 0:
        raise ValueError("test_loader yielded no samples")
    return total_loss / n_samples
