In [1]:
import os
import uuid
import torch
import requests
import glob
import cv2

import imageio.v2 as io
import numpy as np
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import torch.nn.functional as F

from PIL import Image
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset

# Allow reduced-precision (fp16) accumulation inside fp16 matrix multiplications on CUDA.
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True

ModuleNotFoundError: No module named 'imageio'

In [None]:
def show(images, N=4):
    """Plot input images (top row) above the model's reconstructions (bottom row).

    Uses the notebook-level globals ``model`` and ``device``.

    Args:
        images: Batch tensor indexed as (batch, channel, height, width); each
            sample is transposed to (height, width, channel) before plotting.
        N: Number of columns in the figure. At most ``N`` samples are drawn;
            columns without a sample are left blank.
    """
    x = images.to(device)
    # Visualisation only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        y = model(x)

    # Convert once instead of once per plotted sample.
    inputs = x.detach().cpu().numpy()
    preds = y.detach().cpu().numpy()

    # squeeze=False keeps ``axes`` two-dimensional even when N == 1.
    fig, axes = plt.subplots(2, N, figsize=(20, 10), squeeze=False)
    for ax in axes.ravel():
        ax.axis('off')  # also blanks the columns that receive no sample

    for i in range(min(len(inputs), N)):
        axes[0, i].imshow(inputs[i].transpose(1, 2, 0), cmap='gray')
        axes[0, i].set_title("input")

        # Rescale the prediction to [0, 1] for display.
        img = preds[i].transpose(1, 2, 0)
        img = img - np.min(img)
        peak = np.max(img)
        if peak > 0:  # a constant prediction would otherwise divide by zero
            img = img / peak
        axes[1, i].imshow(img, cmap='gray')
        axes[1, i].set_title("pred")
    fig.tight_layout()
    plt.show()

In [None]:
class Resize:
    """Callable that square-crops, resizes and channel-expands an image tensor.

    The input is a 3-D tensor whose first axis is the channel axis. The two
    remaining axes are cropped, starting at index 0, to the shorter of the
    two, then bilinearly interpolated to ``(H, W)``. A single-channel input
    is repeated to three channels.
    """

    def __init__(self, H, W):
        self.H, self.W = H, W

    def __call__(self, img):
        channels, dim_a, dim_b = img.shape
        side = dim_a if dim_a < dim_b else dim_b

        # F.interpolate expects a leading batch axis; add it, resize, drop it.
        batched = img[:, :side, :side].unsqueeze(0)
        resized = F.interpolate(batched, (self.H, self.W), mode="bilinear")[0]

        return resized.repeat(3, 1, 1) if channels == 1 else resized

In [None]:
class Laion400MDataset(Dataset):
    """Image dataset over the files matched by ``<root_dir>/imgs/*``.

    Files that cannot be opened as images, or that are palette ("P") mode,
    are dropped when the dataset is constructed. The sorted remainder is
    divided into a leading training part and a trailing held-out part.

    Args:
        root_dir: Directory that contains an ``imgs`` sub-directory.
        img_size: Side length of the square images returned by ``__getitem__``.
        transform: Stored on the instance; this class does not apply it.
        split: Fraction of the usable files assigned to the training part.
        train: ``True`` selects the training part, ``False`` the remainder.

    Raises:
        ValueError: If ``split`` is outside ``[0, 1]``.
    """

    def __init__(self, root_dir, img_size=256, transform=None, split=0.8, train=True):
        if not 0.0 <= split <= 1.0:
            raise ValueError(f"split must be within [0, 1], got {split!r}")

        self.root_dir = root_dir
        self.transform = transform
        self.img_size = img_size
        self.split = split
        self.train = train

        candidates = sorted(glob.glob(os.path.join(root_dir, "imgs", "*")))
        usable = [path for path in candidates if self._is_usable(path)]

        # Honour the ``split`` argument (default 0.8) instead of a fixed constant.
        n = int(len(usable) * split)
        self.img_files = usable[:n] if train else usable[n:]

    @staticmethod
    def _is_usable(path):
        """Return True when ``path`` opens as an image that is not palette mode."""
        try:
            # The context manager closes the file handle after the mode check.
            with Image.open(path) as img:
                return img.mode != "P"
        except Exception:
            # Best-effort scan: unreadable or non-image files are skipped.
            return False

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, idx):
        """Load one image and return ``{"img": array}``.

        The image is converted to RGB, zero-padded to a centred square,
        resized to ``img_size`` x ``img_size``, moved to channel-first layout,
        cast to float32 and min-max scaled (all-constant images become zeros).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_path = self.img_files[idx]

        with Image.open(img_path) as handle:
            img = np.asarray(handle.convert('RGB'))

        h, w, c = img.shape
        if h != w:
            # Centre the image on a zero-filled square canvas; only the
            # shorter axis gets a non-zero offset.
            side = max(h, w)
            top = (side - h) // 2
            left = (side - w) // 2
            canvas = np.zeros((side, side, c))
            canvas[top:top + h, left:left + w] = img
            img = canvas

        img = cv2.resize(img.copy(), (self.img_size, self.img_size))
        img = img.transpose(2, 0, 1)
        img = np.asarray(img, dtype=np.float32)

        # Min-max scale; the guard avoids dividing by zero for constant images.
        img = img - img.min()
        peak = np.max(img)
        if peak != 0:
            img = img / peak

        element = dict()
        element["img"] = img

        return element

In [None]:
# Training configuration.
num_epochs = 100
learning_rate = 1e-3
# Fall back to the CPU so the notebook still runs on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
epoch = 0  # epoch index the training loop starts from
batch_size = 4  # samples per DataLoader batch
N_BS = 32  # DataLoader batches accumulated per optimizer step
IMG_SIZE = 2048  # side length (pixels) of the square training images
file_path = "data/autoencoder.pth"  # checkpoint location
ROOT_PATH = "/mnt/data/laion400M/"  # dataset root; images are read from ROOT_PATH/imgs

In [None]:
# torch.save(model.state_dict(), file_path)

In [None]:
# NOTE: `transform` is built here but is not handed to the dataset below.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        Resize(IMG_SIZE, IMG_SIZE),
    ]
)
dataset = Laion400MDataset(root_dir=ROOT_PATH, img_size=IMG_SIZE)
data_loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Number of image files kept in the selected (training) split.
len(dataset)

In [None]:
class Encoder(nn.Module):
    """Seven-stage convolutional encoder.

    Each stage applies a 3x3 convolution (stride 1, no padding), a ReLU and a
    2x2 max-pool with stride 2. The first stage maps ``ch_in`` to ``ch``
    channels, the last maps ``ch`` to ``ch_out``, and the five in between keep
    ``ch`` channels.
    """

    NUM_STAGES = 7

    def __init__(self, ch_in=3, ch=32, ch_out=64):
        super().__init__()

        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU()

        # Channel widths at every stage boundary. The layers are registered as
        # conv1 .. conv7, in that order, so state-dict keys are unchanged.
        widths = [ch_in] + [ch] * (self.NUM_STAGES - 1) + [ch_out]
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, 3, 1))

    def forward(self, x):
        for stage in range(1, self.NUM_STAGES + 1):
            conv = getattr(self, f"conv{stage}")
            x = self.pool(self.relu(conv(x)))
        return x
class Decoder(nn.Module):
    """Seven-stage transposed-convolution decoder.

    Each stage is a 4x4 transposed convolution with stride 2; all but the last
    are followed by a ReLU. The result is bilinearly resized to the requested
    output size and passed through tanh.

    Args:
        ch_in: Channels of the latent input.
        ch: Channels of the intermediate feature maps.
        ch_out: Channels of the decoded output.
        out_size: Optional default output size, an int (square) or ``(H, W)``.
            When neither this nor the ``forward`` argument is given, the
            notebook-level ``IMG_SIZE`` is used, as before.
    """

    def __init__(self, ch_in=64, ch=32, ch_out=3, out_size=None):
        super().__init__()

        self.relu = nn.ReLU()
        self.out_size = out_size

        self.up1 = nn.ConvTranspose2d(ch_in, ch, 4, 2)
        self.up2 = nn.ConvTranspose2d(ch, ch, 4, 2)
        self.up3 = nn.ConvTranspose2d(ch, ch, 4, 2)
        self.up4 = nn.ConvTranspose2d(ch, ch, 4, 2)
        self.up5 = nn.ConvTranspose2d(ch, ch, 4, 2)
        self.up6 = nn.ConvTranspose2d(ch, ch, 4, 2)
        self.up7 = nn.ConvTranspose2d(ch, ch_out, 4, 2)

    def forward(self, x0, out_size=None):
        """Decode ``x0``; ``out_size`` overrides the size set at construction."""
        x = x0
        for up in (self.up1, self.up2, self.up3, self.up4, self.up5, self.up6):
            x = self.relu(up(x))
        x = self.up7(x)  # no ReLU after the last stage

        # Resolve the target size: call argument, then constructor value, then
        # the notebook global (the original hard-coded behaviour).
        size = out_size if out_size is not None else self.out_size
        if size is None:
            size = IMG_SIZE
        if isinstance(size, int):
            size = (size, size)

        x = F.interpolate(x, tuple(size), mode="bilinear")

        # Functional tanh avoids constructing an nn.Tanh module on every call.
        return torch.tanh(x)
class Autoencoder(nn.Module):
    """Encoder followed by a decoder (3 -> 64 -> 128 channels and back).

    The most recent encoder output is kept on the instance as ``latent``.
    """

    def __init__(self):
        super().__init__()

        self.encoder = Encoder(ch_in=3, ch=64, ch_out=128)
        self.decoder = Decoder(ch_in=128, ch=64, ch_out=3)

    def forward(self, x):
        latent = self.encoder(x)
        self.latent = latent  # exposed for inspection after the call
        return self.decoder(latent)

In [None]:
# Network on the target device, mean-squared reconstruction loss, Adam optimizer.
model = Autoencoder()
model = model.to(device)
criterion = nn.MSELoss(reduction="mean")
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Trainable parameter counts: encoder, decoder, then the whole model.
for module in (model.encoder, model.decoder, model):
    trainable = (p.numel() for p in module.parameters() if p.requires_grad)
    print(sum(trainable))

In [None]:
# model.load_state_dict(torch.load(file_path))

In [None]:
# Shape and compression-ratio check on a single sample.
element = next(iter(data_loader))

x = element["img"]
x = x[0:1]
x = x.to(device)
print(x.shape, x.numel())

# Inspection only: run the forward passes without building an autograd graph.
with torch.no_grad():
    latent = model.encoder(x)
    print(latent.shape, latent.numel())
    print("Compression: ", x.numel() / latent.numel())
    y = model.decoder(latent)
    print(y.shape, y.numel())

In [None]:
# Preview one batch next to the model's current reconstructions.
element = next(iter(data_loader))
images = element["img"]
show(images, N=4)

In [None]:
# Mixed-precision training with gradient accumulation.
#
# Each optimizer step accumulates the gradients of up to N_BS DataLoader
# batches (the per-batch losses are summed, not averaged). An epoch performs
# len(data_loader) optimizer steps; the DataLoader iterator is restarted
# whenever it runs out.
scaler = torch.cuda.amp.GradScaler()
start_epoch = epoch  # non-zero when resuming
for epoch in range(start_epoch, num_epochs):
    losses = list()
    iterator = iter(data_loader)
    for batch in range(len(data_loader)):
        optimizer.zero_grad()

        accumulated = 0
        for _ in range(N_BS):
            try:
                element = next(iterator)
            except StopIteration:
                # Only iterator exhaustion is handled here; any other error
                # (bad sample, out-of-memory, ...) propagates to the caller.
                iterator = iter(data_loader)
                break

            images = element["img"].to(device)
            with torch.cuda.amp.autocast():
                y = model(images)
                loss = criterion(y, images)
            losses.append(loss.item())

            scaler.scale(loss).backward()
            accumulated += 1

        # GradScaler.step() needs at least one backward pass since the last
        # update, so skip the step when the iterator was already exhausted.
        if accumulated:
            scaler.step(optimizer)
            scaler.update()

        if (batch + 1) % 500 == 0:
            show(images)

        if (batch + 1) % 100 == 0:
            print(f'\rEpoch [{epoch + 1}/{num_epochs}], Batch [{batch + 1}/{len(data_loader)}], loss: {np.mean(losses):.6f}', end="")
            losses = list()

    # Show reconstructions at the end of every epoch.
    show(images)