In [1]:
import torch
import torch.nn as nn

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Convolutional encoder producing the parameters of a latent Gaussian.

    Six ``Conv2d(k=3, s=2, p=1) -> BatchNorm2d -> LeakyReLU(0.01)`` stages each
    halve the spatial size (a 3x128x128 input ends as a 128x2x2 feature map).
    The flattened 512 features then feed two linear heads: a 200-d mean
    (``z_mu``) and a 200-d log-variance (``z_logvar``).

    Submodules are registered stage by stage as ``c{i}``, ``bn{i}``, ``lru{i}``
    for i = 1..6, followed by ``flatten_layer``, ``z_mu`` and ``z_logvar``.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) for each stride-2 conv stage.
        stage_channels = [
            (3, 64),
            (64, 128),
            (128, 128),
            (128, 128),
            (128, 128),
            (128, 128),
        ]
        for stage, (c_in, c_out) in enumerate(stage_channels, start=1):
            setattr(self, f"c{stage}", nn.Conv2d(c_in, c_out, 3, 2, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))
            setattr(self, f"lru{stage}", nn.LeakyReLU(1e-2))

        # 128 channels on the final 2x2 grid.
        feature_size = 128 * 2 * 2
        self.flatten_layer = nn.Flatten(start_dim=1, end_dim=-1)
        self.z_mu = nn.Linear(feature_size, 200)
        self.z_logvar = nn.Linear(feature_size, 200)

    def forward(self, x):
        """Return ``(z_mu, z_logvar)``, each of shape (B, 200), for image batch *x*."""
        for stage in range(1, 7):
            conv = getattr(self, f"c{stage}")
            norm = getattr(self, f"bn{stage}")
            act = getattr(self, f"lru{stage}")
            x = act(norm(conv(x)))

        features = self.flatten_layer(x)
        return self.z_mu(features), self.z_logvar(features)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping a 200-d latent vector to an image.

    ``fc`` expands the latent to 512 features, reshaped to a 128x2x2 map.
    Five ``ConvTranspose2d(k=3, s=2, p=1, output_padding=1) -> BatchNorm2d ->
    LeakyReLU(0.01)`` stages then double the spatial size (2 -> 64) and end
    with 64 channels. A final transposed conv (``deconv6``) produces a
    3x128x128 output, and a sigmoid squashes it into (0, 1).

    Submodules are registered as ``fc``, then ``deconv{i}``, ``bn{i}``,
    ``lru{i}`` for i = 1..5, then ``deconv6`` and ``output_activation``.
    """

    def __init__(self):
        super().__init__()
        # Latent (200) -> flat 128x2x2 feature map.
        self.fc = nn.Linear(200, 128 * 2 * 2)

        # Channel widths between consecutive upsampling stages 1..5.
        widths = (128, 128, 128, 128, 128, 64)
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:]), start=1):
            upsample = nn.ConvTranspose2d(c_in, c_out, 3, 2, padding=1, output_padding=1)
            setattr(self, f"deconv{stage}", upsample)
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))
            setattr(self, f"lru{stage}", nn.LeakyReLU(1e-2))

        # Final stage: 64 channels at 64x64 -> 3 channels at 128x128.
        self.deconv6 = nn.ConvTranspose2d(64, 3, 3, 2, padding=1, output_padding=1)
        self.output_activation = nn.Sigmoid()

    def forward(self, z):
        """Decode latent batch *z* into a (B, 3, 128, 128) tensor with values in (0, 1)."""
        x = self.fc(z).view(-1, 128, 2, 2)
        for stage in range(1, 6):
            x = getattr(self, f"deconv{stage}")(x)
            x = getattr(self, f"bn{stage}")(x)
            x = getattr(self, f"lru{stage}")(x)
        return self.output_activation(self.deconv6(x))

In [None]:
class Sampling(nn.Module):
    """Reparameterisation-trick sampler: ``z = mean + exp(0.5 * logvar) * eps``.

    ``eps`` is standard-normal noise with the same shape, dtype and device as
    ``z_mean``. Gradients therefore flow through ``z_mean`` and ``z_logvar``
    but not through the noise.
    """

    def __init__(self):
        super().__init__()

    def forward(self, z_mean, z_logvar):
        """Draw one sample per element of ``z_mean``.

        Args:
            z_mean: latent means (any shape).
            z_logvar: latent log-variances, broadcastable against ``z_mean``.

        Returns:
            A tensor shaped like ``z_mean``.
        """
        # randn_like matches z_mean's shape, dtype and device. The noise is
        # therefore not forced to the default dtype, and latents of any rank
        # work, not only (batch, dim).
        epsilon = torch.randn_like(z_mean)
        return z_mean + torch.exp(0.5 * z_logvar) * epsilon

In [None]:
def KL_Divergence_Loss(z_mean, z_logvar, *, reduction="sum"):
    """Return KL(N(z_mean, exp(z_logvar)) || N(0, I)) as a 0-d tensor.

    Uses the closed form ``-0.5 * (1 + logvar - mean**2 - exp(logvar))``
    evaluated per latent element.

    Args:
        z_mean: latent means.
        z_logvar: latent log-variances, same shape as ``z_mean``.
        reduction: ``"sum"`` (default) sums over every element, batch included.
            This matches the ``MSELoss(reduction='sum')`` reconstruction loss
            below. ``"mean"`` sums over all elements and divides by the batch
            size ``z_mean.shape[0]``.

    Raises:
        ValueError: if ``reduction`` is not ``"sum"`` or ``"mean"``.
    """
    per_element = 1 + z_logvar - z_mean.pow(2) - z_logvar.exp()
    total = -0.5 * torch.sum(per_element)
    if reduction == "sum":
        return total
    if reduction == "mean":
        # Per-sample KL averaged over the batch (dim 0).
        return total / per_element.shape[0]
    raise ValueError(f"unsupported reduction: {reduction!r}")


# Reconstruction loss, summed over all elements and the batch, like the KL term above.
loss = nn.MSELoss(reduction='sum')

In [None]:
# Pick the best available accelerator: CUDA first, then Apple MPS, otherwise
# the CPU. Do not assume MPS exists just because CUDA is missing; on machines
# without it, moving tensors/modules to "mps" is expected to fail.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
encoder = Encoder().to(device)
decoder = Decoder().to(device)
sampling_layer = Sampling().to(device)


class AutoEncoder(nn.Module):
    """Variational autoencoder: encoder -> reparameterised sample -> decoder.

    By default ``AutoEncoder()`` wires in the module-level ``encoder``,
    ``decoder`` and ``sampling_layer`` objects, looked up at construction
    time. Every default-constructed model therefore shares those weights.
    Pass modules explicitly to build an independent model.

    Args:
        encoder_net: module returning ``(z_mean, z_logvar)`` for an input batch.
        decoder_net: module mapping a latent sample to the reconstruction.
        sampler: module drawing a latent sample from ``(z_mean, z_logvar)``.
    """

    def __init__(self, *, encoder_net=None, decoder_net=None, sampler=None):
        super().__init__()
        self.encoder = encoder if encoder_net is None else encoder_net
        self.decoder = decoder if decoder_net is None else decoder_net
        self.sampling_layer = sampling_layer if sampler is None else sampler

    def forward(self, x):
        """Return ``(kl_loss, reconstruction)`` for input batch *x*."""
        z_mean, z_logvar = self.encoder(x)
        kl_loss = KL_Divergence_Loss(z_mean, z_logvar)
        z_sample = self.sampling_layer(z_mean, z_logvar)
        out = self.decoder(z_sample)
        return kl_loss, out


model = AutoEncoder().to(device)

In [None]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

# AdamW over every model parameter (decoupled weight decay 1e-4).
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-2,
    weight_decay=1e-4,
)
# The LR is multiplied by 0.1 when the monitored (minimised) metric has not
# improved by a relative 0.1% for more than 3 consecutive
# `scheduler.step(metric)` calls.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.1,
    patience=3,
    threshold=1e-3,
)