In [11]:
from torchvision.datasets import Food101
from torchvision.datasets import Flowers102
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
from torch.utils.data import DataLoader, Dataset
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
import torch.optim as optim

# Compute device shared by every later cell: a CUDA GPU when PyTorch reports
# one as available, the CPU otherwise.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [12]:
# Number of images per mini-batch; passed to both DataLoaders.
batch_size = 20
# Bit depth handed to the autoencoder as `quantize_bits`: the encoder output
# is floored onto a grid with 2**7 = 128 steps per unit. (The spelling is kept
# because the training cell refers to this exact name.)
down_level_quntize = 7
# Final (height, width) of every image: it is the target of the Resize step in
# the transform pipeline and the shape the Autoencoder is constructed with.
crop_shape = (64, 64)
# NOTE(review): not referenced anywhere in the visible part of the notebook —
# confirm it is unused before removing.
resize_shape = (128, 128)
# Worker processes used by each DataLoader.
num_workers = 2

In [13]:
# Training pipeline: random 256x256 patch -> resized to `crop_shape` ->
# single channel -> float tensor (ToTensor scales pixel values to [0, 1]).
transform = transforms.Compose([
    # pad_if_needed stops the crop from raising on an image that is smaller
    # than 256 px along either side; larger images are cropped as before.
    transforms.RandomCrop((256, 256), pad_if_needed=True),
    transforms.Resize(crop_shape),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# Evaluation pipeline: same steps, but always the central patch, so a given
# test image yields the same tensor every time it is read.
test_transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),
    transforms.Resize(crop_shape),
    transforms.Grayscale(),
    transforms.ToTensor()
])

# download=True fetches the dataset into the working directory when it is
# not already present there.
train_flower_dataset = Flowers102(root='.', split='train', download=True, transform=transform)
test_flower_dataset = Flowers102(root='.', split='test', download=True, transform=test_transform)

In [14]:
# Shuffled mini-batch loaders for both splits, looked up by name
# ('flower_train' / 'flower_test'); both share batch size and worker count.
data = {
    loader_name: DataLoader(split_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    for loader_name, split_dataset in (
        ('flower_train', train_flower_dataset),
        ('flower_test', test_flower_dataset),
    )
}

In [15]:
def round_bits(x, quantize_bits):
    """Floor ``x`` onto a grid with ``2**quantize_bits`` steps per unit.

    The forward value is ``floor(x * 2**quantize_bits) / 2**quantize_bits``.
    ``torch.floor`` has a zero gradient almost everywhere, so back-propagating
    through it directly would hand all-zero gradients to everything upstream
    of the quantizer. To keep those layers trainable the rounding is wrapped
    in a straight-through estimator: the backward pass treats the whole
    operation as the identity.

    Args:
        x: tensor to quantize (finite values expected).
        quantize_bits: number of fractional bits to keep.

    Returns:
        Tensor shaped like ``x`` holding the quantized values; its gradient
        with respect to ``x`` is 1 everywhere.
    """
    levels = 2 ** quantize_bits
    quantized = torch.floor(x * levels) / levels
    # `x - x.detach()` is exactly zero in the forward pass, so the returned
    # value equals `quantized`; in the backward pass it is the only term
    # attached to the graph and contributes an identity gradient.
    return quantized.detach() + (x - x.detach())

class Encoder(nn.Module):
    """Two fully connected layers that turn an image into a same-sized latent map.

    Each sample is flattened to ``image_shape[0] * image_shape[1]`` values
    before the first linear layer, so the input must contain exactly that
    many elements per sample.

    Args:
        image_shape: ``(height, width)`` of the input; also the spatial size
            of the returned latent map.
    """

    def __init__(self, image_shape):
        super().__init__()
        n_pixels = image_shape[0] * image_shape[1]
        # Kept as single-layer Sequentials so parameter names stay
        # `fc1.0.*` / `fc2.0.*`.
        self.fc1 = nn.Sequential(nn.Linear(n_pixels, n_pixels))
        self.fc2 = nn.Sequential(nn.Linear(n_pixels, n_pixels))
        self.image_shape = image_shape

    def forward(self, x, quantize_bits=None):
        """Encode a batch.

        Args:
            x: batch of images; everything after the first dimension is
                flattened.
            quantize_bits: when not ``None``, the sigmoid activations are
                passed through ``round_bits`` with this bit depth before
                being mapped back to logits.

        Returns:
            Tensor of shape ``(batch, 1, image_shape[0], image_shape[1])``.
        """
        n_samples = x.shape[0]
        flat = x.reshape(n_samples, -1)
        hidden = torch.sigmoid(self.fc1(flat))
        code = torch.sigmoid(self.fc2(hidden))
        if quantize_bits is not None:
            code = round_bits(code, quantize_bits)
        # eps clamps the code to [0.001, 0.999] so the logit stays finite
        # even for values quantized to exactly 0.
        latent = torch.logit(code, eps=0.001)
        height, width = self.image_shape[0], self.image_shape[1]
        return latent.reshape((n_samples, 1, height, width))
    
class Decoder(nn.Module):
    """Convolutional network with additive skip connections that rebuilds the image.

    Layout:
      * contracting path: three conv stages (64, 128, 256 channels), each
        followed by a 2x2 max-pool, then a 512-channel stage at the lowest
        resolution;
      * expanding path: three stride-2 transposed convolutions, each followed
        by a two-conv stack whose output is summed with the feature map saved
        from the contracting stage of the same resolution;
      * a final 3x3 convolution to one channel followed by ``tanh``.

    The expanding-path attributes are called ``encoder_block*``; the names are
    kept as they are because they are part of every parameter name.

    Args:
        input_shape: accepted for symmetry with ``Encoder``; the layers built
            here do not depend on it.
    """

    @staticmethod
    def _conv_stack(in_channels, out_channels, repeats=1):
        """Return ``repeats`` x (3x3 Conv2d -> BatchNorm2d -> ReLU) as one Sequential."""
        layers = []
        for index in range(repeats):
            # Only the first convolution of a stack changes the channel count.
            source_channels = in_channels if index == 0 else out_channels
            layers.append(nn.Conv2d(in_channels=source_channels, out_channels=out_channels,
                                    kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU())
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample(in_channels, out_channels):
        """Return a stride-2 ConvTranspose2d -> BatchNorm2d -> ReLU block (doubles H and W)."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2,
                               padding=1, output_padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def __init__(self, input_shape):
        super().__init__()
        conv_stack = self._conv_stack
        upsample = self._upsample

        # Contracting path. Submodules are registered in the same order as
        # the parameters appear in saved checkpoints.
        self.stage1_1 = conv_stack(1, 64)
        self.stage1_2 = conv_stack(64, 64, repeats=2)
        self.max_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.stage2_1 = conv_stack(64, 128)
        self.stage2_2 = conv_stack(128, 128, repeats=2)
        self.max_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.stage3_1 = conv_stack(128, 256)
        self.stage3_2 = conv_stack(256, 256, repeats=2)
        self.max_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Lowest-resolution stage.
        self.stage4_1 = conv_stack(256, 512)
        self.stage4_2 = conv_stack(512, 512, repeats=4)

        # Expanding path.
        self.encoder_block1_1 = upsample(512, 256)
        self.encoder_block1_2 = conv_stack(256, 256, repeats=2)
        self.encoder_block2_1 = upsample(256, 128)
        self.encoder_block2_2 = conv_stack(128, 128, repeats=2)
        self.encoder_block3_1 = upsample(128, 64)
        self.encoder_block3_2 = conv_stack(64, 64, repeats=2)

        # Projection back to a single channel (no normalisation or ReLU).
        self.encoder_block3_3 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=1, kernel_size=3, padding=1),
        )

    def forward(self, x):
        """Decode a ``(batch, 1, H, W)`` latent map into a ``(batch, 1, H, W)`` image.

        ``H`` and ``W`` have to be divisible by 8 for the skip-connection sums
        to line up (three halvings, three doublings). Output values lie in
        ``[-1, 1]`` because of the final ``tanh``.
        """
        # Contracting path; keep each stage's output for the skip sums.
        skip1 = self.stage1_2(self.stage1_1(x))
        skip2 = self.stage2_2(self.stage2_1(self.max_pool1(skip1)))
        skip3 = self.stage3_2(self.stage3_1(self.max_pool2(skip2)))
        bottom = self.stage4_2(self.stage4_1(self.max_pool3(skip3)))

        # Expanding path; the skip is added after the conv stack of each level.
        up = self.encoder_block1_2(self.encoder_block1_1(bottom)) + skip3
        up = self.encoder_block2_2(self.encoder_block2_1(up)) + skip2
        up = self.encoder_block3_2(self.encoder_block3_1(up)) + skip1

        return torch.tanh(self.encoder_block3_3(up))

class Autoencoder(nn.Module):
    """Encoder followed by decoder, exposed as a single module.

    Args:
        input_shape: ``(height, width)`` forwarded to both ``Encoder`` and
            ``Decoder``.
    """

    def __init__(self, input_shape):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(input_shape)
        self.decoder = Decoder(input_shape)

    def forward(self, x, quantize_bits=None):
        """Encode ``x`` (optionally quantizing the code) and decode it again.

        Args:
            x: batch of images accepted by the encoder.
            quantize_bits: bit depth passed on to ``Encoder.forward``. It
                defaults to ``None`` (no quantization), matching the
                encoder's own default, so the model can also be called as
                ``model(x)``.

        Returns:
            The decoder's reconstruction of ``x``.
        """
        latent = self.encoder(x, quantize_bits)
        return self.decoder(latent)

In [16]:
import torch.optim as optim

# Model under training, sized for the images produced by the transform pipeline.
autoencoder = Autoencoder(input_shape=crop_shape)

# Reconstruction objective: mean-squared error between output and input image.
criterion = nn.MSELoss()

# Adam over every encoder and decoder parameter.
optimizer = optim.Adam(params=autoencoder.parameters(), lr=1e-3)

# # load from file
# file_name = 'path_to_pt_file'
# checkpoint = torch.load(file_name)
# from collections import OrderedDict
# new_state_dict = OrderedDict()
# for k, v in checkpoint['model_state_dict'].items():
#     name = k[7:] # remove `module.`
#     new_state_dict[name] = v

# # load params
# autoencoder.load_state_dict(new_state_dict)

In [10]:
import random

# --- run configuration -------------------------------------------------------
save_every_epoch = 1000  # numbered checkpoint + preview figure every N epochs
model_name = 'new_cnn'   # embedded in checkpoint file names and figure titles
epochs = 0               # index of the first epoch to run
end_epochs = 20000       # the loop stops before this epoch index

train_loader = data["flower_train"]

autoencoder.to(device)
# Wrap only once: re-running this cell must not produce
# DataParallel(DataParallel(...)), which would add one more "module." prefix
# to every key of the saved state dict.
if not isinstance(autoencoder, torch.nn.DataParallel):
    autoencoder = torch.nn.DataParallel(autoencoder)

# One fixed test image, reused for every preview so figures are comparable.
t = transforms.ToPILImage()
real_image_to_test, _ = random.choice(test_flower_dataset)
image1 = t(real_image_to_test)
min_loss = 10000000


def _save_checkpoint(tag):
    """Write model and optimizer state to model_and_optimizer_<model_name>_<tag>.pt."""
    torch.save({
        'model_state_dict': autoencoder.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }, "model_and_optimizer_" + model_name + "_" + tag + ".pt")


def _show_reconstruction(epoch_number, epoch_loss):
    """Plot the fixed test image next to its current reconstruction."""
    # Inference settings for the preview: eval mode so BatchNorm uses its
    # running statistics (and does not update them from a single image), and
    # no autograd graph.
    autoencoder.eval()
    with torch.no_grad():
        output = autoencoder(real_image_to_test.unsqueeze(0).to(device), down_level_quntize)
    autoencoder.train()

    # The decoder ends in tanh, so values can be negative; clamp to the
    # [0, 1] range ToPILImage expects for float tensors.
    image2 = t(output[0].clamp(0, 1).cpu())

    fig, axs = plt.subplots(1, 2, figsize=(20, 8))
    fig.suptitle(model_name+"_Epoch"+str(epoch_number)+"_LOSS"+str(epoch_loss))

    axs[0].imshow(image1, cmap='gray')
    axs[0].axis("off")
    axs[0].set_title("Origin")
    axs[1].imshow(image2, cmap='gray')
    axs[1].axis("off")
    axs[1].set_title("Restored Image")
    plt.show()
    # Release the figure; otherwise every preview stays alive in memory.
    plt.close(fig)


for epoch in range(epochs, end_epochs):
    loss = 0
    for real_image, _ in train_loader:
        # Move the batch once; it is both the input and the target.
        real_image = real_image.to(device)

        optimizer.zero_grad()
        # compute reconstructions
        outputs = autoencoder(real_image, down_level_quntize)
        # compute training reconstruction loss
        train_loss = criterion(outputs, real_image)
        # compute accumulated gradients
        train_loss.backward()
        # perform parameter update based on current gradients
        optimizer.step()
        # add the mini-batch training loss to epoch loss
        loss += train_loss.item()

    # compute the epoch training loss (mean over mini-batches)
    loss = loss / len(train_loader)
    # display the epoch training loss
    print("epoch : {}/{}, loss = {:.6f}".format(epoch + 1, end_epochs, loss))

    # Keep the best-so-far weights (by training loss) in a fixed file.
    if min_loss > loss:
        print('save min loss epoch: ' + str(epoch + 1))
        _save_checkpoint("minLoss")
        min_loss = loss

    # Periodic numbered checkpoint plus a visual sanity check.
    if (epoch + 1) % save_every_epoch == 0:
        _save_checkpoint(str(epoch + 1))
        _show_reconstruction(epoch + 1, loss)

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x110418430>
Traceback (most recent call last):
  File "/Users/idoshitrit/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1478, in __del__
    self._shutdown_workers()
  File "/Users/idoshitrit/miniconda3/envs/torch-gpu/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1442, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/idoshitrit/miniconda3/envs/torch-gpu/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/idoshitrit/miniconda3/envs/torch-gpu/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/idoshitrit/miniconda3/envs/torch-gpu/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/idoshitrit/miniconda3/envs/torch-gpu/lib/p