In [10]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from skimage.metrics import structural_similarity as ssim
import numpy as np
import sys

if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Define the Encoder Network
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU(True)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        return x

# Define the Decoder Network
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        self.deconv4 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1)
        self.deconv5 = nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1)
        self.relu = nn.ReLU(True)
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.deconv4(x))
        x = self.tanh(self.deconv5(x))
        return x

# Instantiate the networks and move them to the appropriate device
encoder = Encoder().to(device)
decoder = Decoder().to(device)

# Freeze the encoder parameters
for param in encoder.parameters():
    param.requires_grad = False

# Define the loss function and optimizer
criterion = nn.MSELoss().to(device)
optimizer = optim.Adam(decoder.parameters(), lr=0.0002)

# Data loading and transformations
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
])

# Using CIFAR-10 dataset for demonstration
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    for batch_idx, (data, _) in enumerate(train_loader):
        # Move data to the appropriate device
        data = data.to(device)

        # Encode and then Decode
        with torch.no_grad():
            latent_rep = encoder(data)
        output = decoder(latent_rep)

        # Calculate loss
        loss = criterion(output, data)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{batch_idx}/{len(train_loader)}], Loss: {loss.item():.4f}')

print("Training finished.")


Files already downloaded and verified
Epoch [1/10], Step [0/782], Loss: 0.3865
Epoch [1/10], Step [100/782], Loss: 0.0320
Epoch [1/10], Step [200/782], Loss: 0.0245
Epoch [1/10], Step [300/782], Loss: 0.0133
Epoch [1/10], Step [400/782], Loss: 0.0075
Epoch [1/10], Step [500/782], Loss: 0.0059
Epoch [1/10], Step [600/782], Loss: 0.0051
Epoch [1/10], Step [700/782], Loss: 0.0039
Epoch [2/10], Step [0/782], Loss: 0.0037
Epoch [2/10], Step [100/782], Loss: 0.0033
Epoch [2/10], Step [200/782], Loss: 0.0040
Epoch [2/10], Step [300/782], Loss: 0.0027
Epoch [2/10], Step [400/782], Loss: 0.0032
Epoch [2/10], Step [500/782], Loss: 0.0023
Epoch [2/10], Step [600/782], Loss: 0.0021
Epoch [2/10], Step [700/782], Loss: 0.0019
Epoch [3/10], Step [0/782], Loss: 0.0021
Epoch [3/10], Step [100/782], Loss: 0.0017
Epoch [3/10], Step [200/782], Loss: 0.0017
Epoch [3/10], Step [300/782], Loss: 0.0015
Epoch [3/10], Step [400/782], Loss: 0.0014
Epoch [3/10], Step [500/782], Loss: 0.0012
Epoch [3/10], Step [60

In [11]:
def get_model_size(model):
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    
    size_in_bytes = param_size + buffer_size
    size_in_megabytes = size_in_bytes / (1024 ** 2)
    
    return size_in_megabytes

In [12]:
def calculate_ssim(im1, im2):
    ssim_value = np.mean([
        ssim(im1, im2, data_range=im2.max() - im2.min())
    ])
    
    return ssim_value

In [15]:
ssim_list = []
for i in data:
    latent = encoder(i)
    output = decoder(latent)
    for j in range(3):
        ssim_output = calculate_ssim(i.to('cpu').detach().numpy()[j], output.to('cpu').detach().numpy()[j])
        ssim_list.append(ssim_output)

encoder_size = get_model_size(encoder)
decoder_size = get_model_size(decoder)
ssim_list = np.array(ssim_list)
print("Average SSIM: ", ssim_list.mean())
encoder_size = get_model_size(encoder)
decoder_size = get_model_size(decoder)
print('Size of encoder {} MB'. format(encoder_size))
print('Size of decoder {} MB'. format(decoder_size))
print('Size of latent representation of image: {} KB'.format(sys.getsizeof(latent)/1024))

Average SSIM:  0.9763615762057603
Size of encoder 0.512451171875 MB
Size of decoder 0.5119743347167969 MB
Size of latent representation of image: 0.078125 KB
