In [1]:
from torch.nn import functional as F
from torchsummary import summary
import src.torch_utils as dist_fn

import argparse
import sys
import os
from skimage import io

import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils

from tqdm import tqdm

# from scheduler import CycleScheduler
import src.torch_utils as dist
from umap import UMAP
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from collections import Counter
from skimage import transform, metrics
import skimage

import numpy as np


In [2]:
class Quantize(nn.Module):
    """Vector-quantization layer with an exponential-moving-average (EMA) codebook.

    Every ``dim``-sized vector along the last axis of the input is replaced by
    its nearest codebook entry (squared Euclidean distance). The codebook is
    stored in buffers, not Parameters, so the optimizer never updates it;
    instead ``forward`` updates it from EMA statistics while ``self.training``
    is True.

    Args:
        dim: size of each vector to quantize (the input's last axis).
        n_embed: number of codebook entries.
        decay: EMA decay factor for the codebook statistics.
        eps: additive smoothing constant used when normalizing cluster sizes.
    """
    def __init__(self, dim, n_embed, decay=0.99, eps=1e-5):
        super().__init__()

        self.dim = dim
        self.n_embed = n_embed
        self.decay = decay
        self.eps = eps

        # Codebook stored column-wise: embed[:, k] is code k, shape (dim, n_embed).
        embed = torch.randn(dim, n_embed)
        # Buffers are saved in state_dict but are not returned by parameters().
        self.register_buffer("embed", embed)
        # EMA of the number of vectors assigned to each code.
        self.register_buffer("cluster_size", torch.zeros(n_embed))
        # EMA of the per-code sum of assigned vectors, shape (dim, n_embed).
        self.register_buffer("embed_avg", embed.clone())

    def forward(self, input):
        """Quantize ``input`` to its nearest codebook vectors.

        Args:
            input: tensor whose last axis has size ``dim`` (VQVAE passes
                channels-last ``(N, H, W, dim)`` tensors).

        Returns:
            quantize: tensor shaped like ``input``; its value is the selected
                codes, while gradients pass straight through to ``input``.
            diff: scalar commitment loss, mean of ``(codes - input) ** 2`` with
                the codes detached.
            embed_ind: integer code indices with shape ``input.shape[:-1]``.

        Side effects:
            In training mode, updates ``cluster_size``, ``embed_avg`` and
            ``embed`` in place. ``torch.no_grad()`` does not prevent this, so
            call ``.eval()`` before inference.
        """
        flatten = input.reshape(-1, self.dim)
        # Squared distance ||x||^2 - 2 x.e + ||e||^2 between every flattened
        # vector (rows) and every code (columns).
        dist = (
                flatten.pow(2).sum(1, keepdim=True)
                - 2 * flatten @ self.embed
                + self.embed.pow(2).sum(0, keepdim=True)
        )
        # Nearest code = largest negative distance.
        _, embed_ind = (-dist).max(1)
        embed_onehot = F.one_hot(embed_ind, self.n_embed).type(flatten.dtype)
        embed_ind = embed_ind.view(*input.shape[:-1])
        # Looked up with the codebook as it is *before* this step's EMA update.
        quantize = self.embed_code(embed_ind)

        if self.training:
            # Per-code assignment counts and per-code sums of assigned vectors.
            embed_onehot_sum = embed_onehot.sum(0)
            embed_sum = flatten.transpose(0, 1) @ embed_onehot

            # NOTE(review): project helper from src.torch_utils; presumably sums
            # these statistics across distributed workers — confirm there.
            dist_fn.all_reduce(embed_onehot_sum)
            dist_fn.all_reduce(embed_sum)

            # EMA updates: stat = decay * stat + (1 - decay) * new_stat.
            self.cluster_size.data.mul_(self.decay).add_(
                embed_onehot_sum, alpha=1 - self.decay
            )
            self.embed_avg.data.mul_(self.decay).add_(embed_sum, alpha=1 - self.decay)
            n = self.cluster_size.sum()
            # Additive smoothing keeps the counts of unused codes away from zero
            # before dividing; the smoothed counts still sum to n.
            cluster_size = (
                    (self.cluster_size + self.eps) / (n + self.n_embed * self.eps) * n
            )
            # New code vectors = EMA sum / EMA count, column by column.
            embed_normalized = self.embed_avg / cluster_size.unsqueeze(0)
            self.embed.data.copy_(embed_normalized)

        # Commitment loss: only ``input`` (the encoder output) receives gradient.
        diff = (quantize.detach() - input).pow(2).mean()
        # Straight-through estimator: forward value is ``quantize``, backward is
        # the identity with respect to ``input``.
        quantize = input + (quantize - input).detach()

        return quantize, diff, embed_ind

    def embed_code(self, embed_id):
        """Look up code vectors for integer indices.

        Returns a tensor of shape ``embed_id.shape + (dim,)``.
        """
        return F.embedding(embed_id, self.embed.transpose(0, 1))


class ResBlock(nn.Module):
    """Residual block computing ``x + Conv1x1(ReLU(Conv3x3(ReLU(x))))``.

    Args:
        in_channel: channels of the input and of the output.
        channel: hidden width of the 3x3 convolution.
    """

    def __init__(self, in_channel, channel):
        super().__init__()

        # The attribute name ``conv`` and this layer order define the checkpoint
        # keys (conv.1.*, conv.3.*), so both must stay as they are.
        layers = [
            # Not in-place: ``input`` is reused below for the skip connection.
            nn.ReLU(),
            nn.Conv2d(in_channel, channel, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channel, in_channel, 1),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, input):
        """Return ``self.conv(input) + input`` (same shape as ``input``)."""
        residual = self.conv(input)
        return residual + input


class Encoder(nn.Module):
    """Convolutional encoder that downsamples the spatial size by ``stride``.

    Args:
        in_channel: number of input channels.
        channel: number of output channels (the first strided conv outputs
            ``channel // 2``).
        n_res_block: number of ``ResBlock`` modules appended after the convs.
        n_res_channel: hidden width of each ``ResBlock``.
        stride: total downsampling factor; only 2 and 4 are supported.

    Raises:
        ValueError: if ``stride`` is not 2 or 4.
    """

    def __init__(self, in_channel, channel, n_res_block, n_res_channel, stride):
        super().__init__()

        if stride == 4:
            # Two 4x4 / stride-2 / pad-1 convs each halve H and W.
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel, channel, 3, padding=1),
            ]

        elif stride == 2:
            blocks = [
                nn.Conv2d(in_channel, channel // 2, 4, stride=2, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 2, channel, 3, padding=1),
            ]

        else:
            # Without this branch `blocks` was never bound and construction
            # failed with an opaque UnboundLocalError.
            raise ValueError(f"Encoder supports stride 2 or 4, got {stride!r}")

        blocks.extend(ResBlock(channel, n_res_channel) for _ in range(n_res_block))
        blocks.append(nn.ReLU(inplace=True))

        # Layer order is unchanged so existing checkpoints (blocks.N.*) still load.
        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        """Map (N, in_channel, H, W) to (N, channel, H/stride, W/stride).

        Assumes H and W are divisible by ``stride``.
        """
        return self.blocks(input)


class Decoder(nn.Module):
    """Convolutional decoder that upsamples the spatial size by ``stride``.

    Args:
        in_channel: number of input channels.
        out_channel: number of output channels.
        channel: hidden width.
        n_res_block: number of ``ResBlock`` modules after the first conv.
        n_res_channel: hidden width of each ``ResBlock``.
        stride: total upsampling factor; only 2 and 4 are supported.

    Raises:
        ValueError: if ``stride`` is not 2 or 4.
    """

    def __init__(
            self, in_channel, out_channel, channel, n_res_block, n_res_channel, stride
    ):
        super().__init__()

        # Any other stride used to build a decoder with no upsampling that also
        # ignored `out_channel`; fail loudly instead.
        if stride not in (2, 4):
            raise ValueError(f"Decoder supports stride 2 or 4, got {stride!r}")

        blocks = [nn.Conv2d(in_channel, channel, 3, padding=1)]

        blocks.extend(ResBlock(channel, n_res_channel) for _ in range(n_res_block))

        blocks.append(nn.ReLU(inplace=True))

        if stride == 4:
            # Two 4x4 / stride-2 / pad-1 transposed convs each double H and W.
            blocks.extend(
                [
                    nn.ConvTranspose2d(channel, channel // 2, 4, stride=2, padding=1),
                    nn.ReLU(inplace=True),
                    nn.ConvTranspose2d(
                        channel // 2, out_channel, 4, stride=2, padding=1
                    ),
                ]
            )

        else:
            blocks.append(
                nn.ConvTranspose2d(channel, out_channel, 4, stride=2, padding=1)
            )

        # Layer order is unchanged so existing checkpoints (blocks.N.*) still load.
        self.blocks = nn.Sequential(*blocks)

    def forward(self, input):
        """Map (N, in_channel, H, W) to (N, out_channel, H*stride, W*stride)."""
        return self.blocks(input)


class VQVAE(nn.Module):
    """Two-level (top/bottom) VQ-VAE-2 style autoencoder.

    The bottom encoder downsamples the image by 4 and the top encoder by a
    further 2. Top latents are quantized first. Their decoding is concatenated
    with the bottom features before the bottom level is quantized. The final
    decoder consumes the upsampled top codes together with the bottom codes.

    Args:
        in_channel: image channels.
        channel: hidden width of the encoders/decoders.
        n_res_block: residual blocks per encoder/decoder.
        n_res_channel: hidden width inside each residual block.
        embed_dim: size of each codebook vector.
        n_embed: number of codebook entries per level.
        decay: EMA decay used by both codebooks.
    """

    def __init__(
            self,
            in_channel=3,
            channel=128,
            n_res_block=5,
            n_res_channel=32,
            embed_dim=64,
            n_embed=512,
            decay=0.99,
    ):
        super().__init__()

        self.enc_b = Encoder(in_channel, channel, n_res_block, n_res_channel, stride=4)
        self.enc_t = Encoder(channel, channel, n_res_block, n_res_channel, stride=2)
        self.quantize_conv_t = nn.Conv2d(channel, embed_dim, 1)
        # Pass `decay` through; it was previously accepted but never used, so
        # both codebooks always ran with Quantize's own default.
        self.quantize_t = Quantize(embed_dim, n_embed, decay=decay)
        self.dec_t = Decoder(
            embed_dim, embed_dim, channel, n_res_block, n_res_channel, stride=2
        )
        self.quantize_conv_b = nn.Conv2d(embed_dim + channel, embed_dim, 1)
        self.quantize_b = Quantize(embed_dim, n_embed, decay=decay)
        self.upsample_t = nn.ConvTranspose2d(
            embed_dim, embed_dim, 4, stride=2, padding=1
        )
        self.dec = Decoder(
            embed_dim + embed_dim,
            in_channel,
            channel,
            n_res_block,
            n_res_channel,
            stride=4,
        )

    def forward(self, input):
        """Reconstruct ``input``; returns ``(reconstruction, latent_loss)``.

        ``latent_loss`` is the sum of both levels' commitment losses, shape (1,).
        """
        quant_t, quant_b, diff, _, _ = self.encode(input)
        dec = self.decode(quant_t, quant_b)

        return dec, diff

    def encode(self, input):
        """Encode images into quantized top and bottom latents.

        Returns:
            quant_t: (N, embed_dim, H/8, W/8) quantized top latents.
            quant_b: (N, embed_dim, H/4, W/4) quantized bottom latents.
            diff: summed commitment loss of both levels, shape (1,).
            id_t, id_b: integer code indices, (N, H/8, W/8) and (N, H/4, W/4).
        """
        enc_b = self.enc_b(input)
        enc_t = self.enc_t(enc_b)

        # Quantize expects channels-last input.
        quant_t = self.quantize_conv_t(enc_t).permute(0, 2, 3, 1)
        quant_t, diff_t, id_t = self.quantize_t(quant_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        diff_t = diff_t.unsqueeze(0)

        # Condition the bottom level on the decoded top latents.
        dec_t = self.dec_t(quant_t)
        enc_b = torch.cat([dec_t, enc_b], 1)

        quant_b = self.quantize_conv_b(enc_b).permute(0, 2, 3, 1)
        quant_b, diff_b, id_b = self.quantize_b(quant_b)
        quant_b = quant_b.permute(0, 3, 1, 2)
        diff_b = diff_b.unsqueeze(0)

        return quant_t, quant_b, diff_t + diff_b, id_t, id_b

    def decode(self, quant_t, quant_b):
        """Decode channels-first quantized latents back to image space."""
        upsample_t = self.upsample_t(quant_t)
        quant = torch.cat([upsample_t, quant_b], 1)
        dec = self.dec(quant)

        return dec

    def decode_code(self, code_t, code_b):
        """Decode integer code maps (N, h, w) from both levels into images."""
        quant_t = self.quantize_t.embed_code(code_t)
        quant_t = quant_t.permute(0, 3, 1, 2)
        quant_b = self.quantize_b.embed_code(code_b)
        quant_b = quant_b.permute(0, 3, 1, 2)

        dec = self.decode(quant_t, quant_b)

        return dec

In [3]:
device = "cuda"
model = VQVAE().to(device)


In [None]:
# summary(model, input_size=(3, 512, 512))

In [4]:
dataset_path = 'data/dataset_t/'
# resize_shape=(1024,1024)
resize_shape=(512,512)
# resize_shape = (256, 256)
n_gpu = 1
batch_size = 8
val_split = 0.25

# Training pipeline; the random flips are data augmentation.
# NOTE: this rebinds the name `transform` (skimage.transform is imported above).
transform = transforms.Compose(
    [
        transforms.Resize(resize_shape),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.CenterCrop(resize_shape),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

# Evaluation pipeline: identical preprocessing without random flips, so test
# losses and saved samples are deterministic.
eval_transform = transforms.Compose(
    [
        transforms.Resize(resize_shape),
        transforms.CenterCrop(resize_shape),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

dataset = datasets.ImageFolder(dataset_path, transform=transform)
# Second view of the same folder; indices must line up with `dataset`.
eval_dataset = datasets.ImageFolder(dataset_path, transform=eval_transform)
assert eval_dataset.samples == dataset.samples

train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len

train_dataset, test_split = torch.utils.data.random_split(dataset, [train_dataset_len, test_dataset_len],
                                                          generator=torch.Generator().manual_seed(42))
# Same seeded test indices, but served through the non-augmented view
# (random_split subsets share their parent's transform).
test_dataset = torch.utils.data.Subset(eval_dataset, test_split.indices)

train_sampler = dist.data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = dist.data_sampler(test_dataset, shuffle=True, distributed=False)

train_loader = DataLoader(
    train_dataset, batch_size=batch_size // n_gpu, sampler=train_sampler, num_workers=2
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size // n_gpu, sampler=test_sampler, num_workers=2
)

In [None]:
model_file='data/logs/vq-vae-2/weights/vqvae_002_train_0.00976_test_0.00967.pt'

# Load onto the device the model lives on instead of hard-coding CUDA.
model.load_state_dict(torch.load(model_file, map_location=device))

In [None]:
epochs = 100
lr = 1e-4
sample_path = 'data/logs/vq-vae-2/samples'
model_path = 'data/logs/vq-vae-2/weights'

# Loss and logging settings do not change between epochs; define them once.
criterion = nn.MSELoss()
latent_loss_weight = 0.25  # weight of the quantizer commitment loss
sample_size = 25  # max number of test images in each saved sample grid

# Make sure the output folders exist before the first save.
os.makedirs(sample_path, exist_ok=True)
os.makedirs(model_path, exist_ok=True)

optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7, amsgrad=True)
# optimizer = optim.RMSprop(model.parameters(), lr=lr,weight_decay=1e-6,centered=True)

for epoch in range(epochs):
    # The evaluation pass below leaves the model in eval mode. Re-enable training
    # mode *before* the first batch; previously this happened only after it, so
    # that batch skipped the codebook EMA update.
    model.train()

    # Wrap a fresh progress bar per epoch. Re-assigning `train_loader` to its own
    # tqdm wrapper nested one more progress bar on every epoch.
    progress = tqdm(train_loader) if dist.is_primary() else train_loader

    mse_sum = 0
    mse_n = 0
    test_mean_loss = []
    train_mean_loss = []

    for img, _ in progress:
        model.zero_grad()

        img = img.to(device)

        out, latent_loss = model(img)
        recon_loss = criterion(out, img)
        latent_loss = latent_loss.mean()
        loss = recon_loss + latent_loss_weight * latent_loss
        loss.backward()
        train_mean_loss.append(loss.item())
        optimizer.step()

        # Running reconstruction MSE; dist.all_gather yields one dict per process.
        part_mse_sum = recon_loss.item() * img.shape[0]
        part_mse_n = img.shape[0]
        comm = {"mse_sum": part_mse_sum, "mse_n": part_mse_n}
        comm = dist.all_gather(comm)

        for part in comm:
            mse_sum += part["mse_sum"]
            mse_n += part["mse_n"]

        if dist.is_primary():
            lr = optimizer.param_groups[0]["lr"]

            progress.set_description(
                (
                    f"epoch: {epoch + 1}; loss: {str(round(np.mean(train_mean_loss), 5))}; mse: {recon_loss.item():.5f}; "
                    f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                    f"lr: {lr:.5f}"
                )
            )

    model.eval()

    with torch.no_grad():
        for img, _ in test_loader:
            img = img.to(device)
            out, latent_loss = model(img)
            test_recon_loss = criterion(out, img)
            test_latent_loss = latent_loss.mean()
            test_loss = test_recon_loss + latent_loss_weight * test_latent_loss
            # Keep full precision; values are rounded only when reported.
            test_mean_loss.append(test_loss.item())

        # Inputs of the last test batch and their reconstructions.
        sample = img[:sample_size]
        sample_out = out[:sample_size]

    # nrow = number of inputs, so inputs fill the first row and their
    # reconstructions the row below, whatever the batch size.
    utils.save_image(
        torch.cat([sample, sample_out], 0),
        f"{sample_path}/{str(epoch + 1).zfill(5)}.png",
        nrow=sample.shape[0],
        normalize=True,
        range=(-1, 1),
    )

    train_loss_avg = round(np.mean(train_mean_loss), 5)
    test_loss_avg = round(np.mean(test_mean_loss), 5)
    print(f'test elbo: {str(test_loss_avg)}')
    torch.save(model.state_dict(),
               f"{model_path}/vqvae_{str(epoch + 1).zfill(3)}_train_{str(train_loss_avg)}_test_{str(test_loss_avg)}.pt")

In [None]:
sample_path = 'data/logs/vq-vae-2/samples'
model_path = 'data/logs/vq-vae-2/weights'
sample_size = 25  # max number of inputs per saved grid

os.makedirs(sample_path, exist_ok=True)

# Eval mode is required: in training mode Quantize.forward updates the codebook
# buffers in place, and torch.no_grad() does not prevent that.
model.eval()

with torch.no_grad():
    for j, (img, label) in enumerate(test_loader):

        img = img.to(device)
        out, latent_loss = model(img)

        sample = img[:sample_size]

        # Inputs on the first row, their reconstructions on the row below.
        utils.save_image(
            torch.cat([sample, out[:sample_size]], 0),
            f"{sample_path}/{j}_vq_vae_2_test.png",
            nrow=sample.shape[0],
            normalize=True,
            range=(-1, 1),
        )

In [7]:
model_file='data/logs/vq-vae-2/vqvae_100_train_0.00409_test_0.00401_512.pt'

# Load onto the device the model lives on instead of hard-coding CUDA.
model.load_state_dict(torch.load(model_file, map_location=device))

<All keys matched successfully>

In [8]:
criterion = nn.MSELoss()
latent_loss_weight = 0.25
images_embs_t = []
images_embs_b = []

# batch_size=1 so each appended vector is one image. The loader does not
# shuffle, so the list order follows `dataset.samples`.
# NOTE(review): `dataset` applies random flips, so these are embeddings of
# randomly flipped images — consider a non-augmented view for stable results.
visual_loader = DataLoader(
    dataset, batch_size=1, num_workers=2
)

# Eval mode + no_grad: in training mode Quantize.forward would update the
# codebook on every image, and no gradients are needed here.
model.eval()
with torch.no_grad():
    for img, label in visual_loader:
        img = img.to(device)

        # Only the quantized latents are needed, so skip the decoder pass.
        quant_t, quant_b, diff, _, _ = model.encode(img)
        images_embs_t.append(quant_t.cpu().numpy().flatten())
        images_embs_b.append(quant_b.cpu().numpy().flatten())


In [11]:
np.shape(images_embs_b[0])  # length of one flattened bottom-level latent vector

(1048576,)

In [None]:
# Reconstruct one held-out image and display it in grayscale.
# Eval mode avoids the in-place codebook update Quantize does in training mode.
model.eval()
sample = test_loader.dataset[45][0]
img = torch.unsqueeze(sample, 0)
img = img.to(device)
with torch.no_grad():
    model_out = model(img)[0]
out = model_out.cpu().numpy()[0]
out = np.moveaxis(out, 0, -1)
# The model works in the Normalize(0.5, 0.5) space (~[-1, 1]). Map back to
# [0, 1] and clip first; otherwise negative values wrap in the uint8 cast.
out = np.clip((out + 1) / 2, 0, 1)
out = skimage.color.rgb2gray(out) * 255
out = out.astype(np.uint8)

utils.save_image(
    model_out,
    "vq_vae_2_test.png",
    nrow=1,
    normalize=True,
    range=(-1, 1),
)

plt.figure(figsize=(10, 10))
plt.imshow(out, cmap='gray');

In [None]:
seed = 51
umap_2d = UMAP(random_state=seed)
# One 2-D projection per latent level (top first, then bottom). The same
# estimator is refit by each fit_transform call.
umaped_vct_2d_t, umaped_vct_2d_b = [
    umap_2d.fit_transform(level_embs) for level_embs in (images_embs_t, images_embs_b)
]
# Legend labels in class-index order.
legend = [
    'Ultra_Co8\nсредние зерна',
    'Ultra_Co11\nмелкие зерна',
    'Ultra_Co6_2\nмелкие зерна',
    'Ultra_Co15\nкрупные зерна',
    'Ultra_Co25\nсредне-мелкие зерна',
]

In [None]:
# 2-D UMAP projection of the top-level latents, one colour/marker per class.
N = 15
M = 15
fontsize = 15
dot_size = 40

embs_scatter = umaped_vct_2d_t

fig, ax = plt.subplots(figsize=(N, M))

colors = ['b', 'g', 'y', 'm', 'c']
markers = ['8', 'v', 's', 'd', '*', ]

# Class index per image, in the same order as the embeddings (they were
# extracted by iterating `dataset` without shuffling).
labels = np.array(dataset.targets)
assert len(labels) == len(embs_scatter)

# Select each class with a boolean mask instead of computing contiguous
# offsets, so the plot does not depend on samples being grouped by class.
for key in sorted(set(labels.tolist())):
    mask = labels == key
    ax.scatter(embs_scatter[mask, 0], embs_scatter[mask, 1], color=colors[key], s=dot_size,
               marker=markers[key], label=legend[key])

ax.legend(fontsize=fontsize)
fig.savefig(f'embs_space_seed={seed}_t_512.png')
plt.show()

In [None]:
mse_losses = []
ssim_losses = []


visual_loader = DataLoader(
    dataset, batch_size=1 , num_workers=2
)

# Eval mode + no_grad: in training mode Quantize.forward updates the codebook
# on every call, and no gradients are needed for metrics.
model.eval()
with torch.no_grad():
    for img, label in visual_loader:
        img = img.to(device)

        out, latent_loss = model(img)
        # (C, H, W) -> (H, W, C): skimage expects channels last.
        predicted_image = np.moveaxis(out.cpu().numpy()[0], 0, -1)
        original_image = np.moveaxis(img.cpu().numpy()[0], 0, -1)

        mse_losses.append(metrics.mean_squared_error(original_image, predicted_image))
        # Images live in the normalized [-1, 1] range, so state data_range
        # explicitly instead of relying on dtype-based inference.
        ssim_losses.append(metrics.structural_similarity(original_image, predicted_image,
                                                         multichannel=True, data_range=2.0))

mse_losses = np.array(mse_losses)
ssim_losses = np.array(ssim_losses)

In [None]:
# Set before the figure is created so it applies to all of its text.
plt.rcParams['font.size'] = '20'

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(30, 15))

names = ['Ultra_Co8\nсредние зерна', 'Ultra_Co11\nмелкие зерна', 'Ultra_Co6_2\nмелкие зерна',
         'Ultra_Co15\nкрупные зерна', 'Ultra_Co25\nсредне-мелкие зерна']

colors = ['b', 'g', 'y', 'm', 'c']
markers = ['8', 'v', 's', 'd', '*', ]

# Class index per image, aligned with mse_losses / ssim_losses (both were
# filled by iterating `dataset` in order).
labels = np.array(dataset.targets)
assert len(labels) == len(mse_losses) == len(ssim_losses)
image_numbers = np.arange(len(labels))

# Boolean masks instead of contiguous offsets, with the same per-class
# colour/marker scheme as the UMAP plot.
for key in sorted(set(labels.tolist())):
    mask = labels == key
    style = dict(color=colors[key], marker=markers[key], label=names[key])
    ax1.scatter(image_numbers[mask], mse_losses[mask], **style)
    ax2.scatter(image_numbers[mask], ssim_losses[mask], **style)

ax1.legend(fontsize=15)
ax1.set_title('MSE image losses', fontsize=20)
ax1.set_xlabel('image number', fontsize=20)
ax1.set_ylabel('MSE loss', fontsize=20)

ax2.legend(fontsize=15)
ax2.set_title('Structural image similarity', fontsize=20)
ax2.set_xlabel('image number', fontsize=20)
ax2.set_ylabel('SSIM similarity', fontsize=20)

name = 'vq-vae-2'
fig.savefig(f'mse_ssim_losses_{name}_512.png')

plt.show()