In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils
import torchvision

from pt_utils import  Embeddings, Trainer, VQVAE, data_sampler, Vqvae2AdaptiveVae, VanillaVAE
from torchsummary import summary
import os
from torch import distributed as dist
from tqdm.notebook import trange, tqdm
from torchvision.datasets import ImageFolder
import numpy as np
from typing import Any, Callable, cast, Dict, List, Optional, Tuple

# One seed drives numpy, torch and the train/validation split generators used further down.
seed = 51
torch.manual_seed(seed)
np.random.seed(seed)

In [None]:
# Compute device for this section; set to "cuda" to run on a GPU.
device = "cpu"

# VQVAE from pt_utils. n_embed is the codebook size and embed_dim the code width,
# per the parameter names. A Vqvae2Adaptive model with the same arguments but
# embed_dim=1 was tried as an alternative.
model = VQVAE(
    in_channel=3,
    channel=128,
    n_res_block=6,
    n_res_channel=32,
    embed_dim=8,
    n_embed=8192,
    decay=0.99,
).to(device)

In [None]:
# (index, name) of every trainable parameter; the cell displays the last entry.
l = [
    (idx, param_name)
    for idx, (param_name, p) in enumerate(model.named_parameters())
    if p.requires_grad
]
l[-1]

In [None]:
# List parameter names; a bare `model.named_parameters()` only displays a generator object.
[param_name for param_name, _ in model.named_parameters()]

In [None]:
# torchsummary runs a dummy forward pass on its `device` argument, which defaults to "cuda".
# Pass the device the model actually lives on; otherwise a CPU model fails with a device mismatch.
summary(model, input_size=(3, 512, 512), device=device)

# MSE loss

In [None]:
# Dataset root. Roots used in earlier experiments:
#   '../data/dataset_512/', '../datasets/bc_right_sub_left_minmax_4x_360',
#   '../datasets/bc_left_4x_360', '../datasets/original/o_bc_left_9x_512_360'
dataset_path = '../datasets/original/o_bc_left_4x_768'

# Target size for transforms.Resize. Resize is disabled below, so images keep their stored size.
resize_shape = (512, 512)

n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction of the dataset held out for validation

# ToTensor yields values in [0, 1]; Normalize with mean=std=0.5 rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = datasets.ImageFolder(dataset_path, transform=transform)

# Deterministic train/validation split seeded by the global `seed`.
train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len
split_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_dataset_len, test_dataset_len], generator=split_generator
)

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

per_gpu_batch = batch_size // n_gpu
train_loader = DataLoader(train_dataset, batch_size=per_gpu_batch, sampler=train_sampler, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=per_gpu_batch, sampler=test_sampler, num_workers=2)

In [None]:
# model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_left_sub_right_minmax_4x_360/vqvae_001_train_0.04914_test_0.04206.pt'
model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360/vqvae_003_train_0.04287_test_0.04129.pt'

# Map stored tensors onto the device the model lives on. A hard-coded 'cuda' fails on
# CPU-only machines and disagrees with `device` when that is "cpu".
# NOTE(review): this checkpoint's directory name says emb_dim_1, but the VQVAE built above
# uses embed_dim=8. load_state_dict raises on shape mismatches, so confirm they really match.
model.load_state_dict(torch.load(model_file, map_location=torch.device(device)))

In [None]:
# Hyper-parameters for the MSE-reconstruction run.
epochs = 100
lr = 1e-4
latent_loss_weight = 0.25
sample_size = 25
# NOTE(review): latent_loss_weight and sample_size are not passed to Trainer.train below;
# check Trainer.train's signature if they are meant to take effect.

# Adam with AMSGrad and light weight decay.
# Also tried: optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-6, centered=True).
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7, amsgrad=True)

# Output directory handed to Trainer.train (presumably for checkpoints/samples).
# Earlier choices: '../data/logs/vq-vae-2/4x/weights',
# 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360'.
model_path = 'data/logs/emb_dim_2_n_embed_8192_bc_left_9x_512_360'

Trainer.train(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    model_path=model_path,
    epochs=epochs,
)


# Triplet loss

In [None]:
# Dataset root for the triplet experiment. Other roots used before:
#   '../data/dataset_512/', '../datasets/bc_right_sub_left_minmax_4x_360',
#   '../datasets/bc_left_4x_360', '../datasets/original/o_bc_left_9x_512_360'
dataset_path = '../datasets/original/o_bc_left_4x_768'

resize_shape = (512, 512)  # every image is resized to this before ToTensor

n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction held out for validation

# Resize, convert to a [0, 1] tensor, then shift to [-1, 1].
# transforms.Grayscale() was tried as an extra step.
transform = transforms.Compose([
    transforms.Resize(resize_shape),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# NOTE(review): TripletFolder is neither imported nor defined in this notebook. train_triplet
# expects each item to yield (anchor, positive, negative) image tensors; confirm where it comes from.
dataset = TripletFolder(dataset_path, transform=transform)

n_total = len(dataset)
train_dataset_len = int(n_total * (1 - val_split))
test_dataset_len = n_total - train_dataset_len
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_dataset_len, test_dataset_len],
    generator=torch.Generator().manual_seed(seed),
)

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

loader_batch = batch_size // n_gpu
train_loader = DataLoader(train_dataset, batch_size=loader_batch, sampler=train_sampler, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=loader_batch, sampler=test_sampler, num_workers=2)

In [None]:
class TripletLoss(nn.Module):
    """Hinge triplet loss on squared Euclidean distances with a symmetric negative term.

    loss = mean(relu(d(a, p) - (d(a, n) + d(p, n)) / 2 + margin)), where d is the squared
    difference summed over dimension 1 only. For 4-D image batches this sums the channel
    axis and leaves a per-pixel distance map, which the final mean averages.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin  # required gap between the positive and averaged negative distance

    def calc_euclidean(self, x1, x2):
        """Squared Euclidean distance between x1 and x2, reduced over dim 1."""
        return torch.sum(torch.pow(x1 - x2, 2), dim=1)

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        pos_dist = self.calc_euclidean(anchor, positive)
        # The negative is pushed away from both the anchor and the positive.
        neg_dist = (self.calc_euclidean(anchor, negative) + self.calc_euclidean(positive, negative)) / 2.0
        hinge = torch.clamp(pos_dist - neg_dist + self.margin, min=0)
        return hinge.mean()

In [None]:
def is_primary():
    """True for the process of rank 0 (and always when not running distributed)."""
    rank = get_rank()
    return rank == 0

def get_rank():
    """Rank of this process in the default process group, or 0 without torch.distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    return 0

def all_gather(data):
    """Gather a picklable Python object from every process in the default group.

    Returns a list with one entry per rank, or just `[data]` when the world size is 1.
    Each object is pickled into a uint8 CUDA tensor and padded to the largest payload,
    because `dist.all_gather` needs equal sizes. The payloads are exchanged, then each one
    is trimmed back to its true length and unpickled.
    """
    # `pickle` was previously used without being imported (NameError in multi-process runs).
    import pickle

    world_size = get_world_size()
    if world_size == 1:
        return [data]

    payload = pickle.dumps(data)
    local_tensor = torch.from_numpy(np.frombuffer(payload, dtype=np.uint8).copy()).to("cuda")
    local_numel = local_tensor.numel()

    # Exchange payload lengths first so every rank knows how much padding to expect.
    local_size = torch.tensor([local_numel], dtype=torch.int32, device="cuda")
    size_list = [torch.zeros(1, dtype=torch.int32, device="cuda") for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    # Pad using plain ints. Previously `max_size - local_size` produced a tensor as the size.
    if local_numel < max_size:
        padding = torch.zeros(max_size - local_numel, dtype=torch.uint8, device="cuda")
        local_tensor = torch.cat((local_tensor, padding), 0)

    tensor_list = [torch.empty(max_size, dtype=torch.uint8, device="cuda") for _ in size_list]
    dist.all_gather(tensor_list, local_tensor)

    # Unpickling executes code from peers; only use this between trusted processes.
    return [
        pickle.loads(gathered.cpu().numpy().tobytes()[:size])
        for size, gathered in zip(size_list, tensor_list)
    ]

def get_world_size():
    """Number of processes in the default process group, or 1 without torch.distributed."""
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    return 1

def train_triplet(model, optimizer, train_loader, test_loader, model_path, epochs=100, device='cuda',
          latent_loss_weight=0.25, sample_size=25):
    """Train `model` with a triplet-margin loss on its reconstructions plus a latent penalty.

    Each batch is unpacked as (anchor, positive, negative) image tensors, and `model(x)` must
    return `(reconstruction, latent_loss)`. The objective is
    `TripletMarginLoss(anchor_out, positive_out, negative_out) + latent_loss_weight * latent`,
    where `latent` averages the mean latent loss of the three branches.

    After every epoch the model is evaluated on `test_loader`. A grid of up to `sample_size`
    anchors and their reconstructions from the last test batch is saved under `model_path`,
    and a checkpoint named with the mean train/test losses is written there.

    Args:
        model: module returning `(reconstruction, latent_loss)`.
        optimizer: optimizer over `model`'s parameters.
        train_loader: iterable of (anchor, positive, negative) batches used for training.
        test_loader: iterable of (anchor, positive, negative) batches used for evaluation.
        model_path: output directory; created (with parents) if missing.
        epochs: number of passes over `train_loader`.
        device: device the batches are moved to; must match the device holding `model`.
        latent_loss_weight: weight of the latent term in the objective.
        sample_size: number of anchors/reconstructions saved per epoch (also the grid row width).
    """
    os.makedirs(model_path, exist_ok=True)
    criterion = nn.TripletMarginLoss()

    def forward_triplet(batch):
        """Run all three branches; return (anchor_img, anchor_out, triplet_loss, latent_loss)."""
        anchor_img, positive_img, negative_img = (t.to(device) for t in batch)
        anchor_out, anchor_latent_out = model(anchor_img)
        positive_out, positive_latent_out = model(positive_img)
        negative_out, negative_latent_out = model(negative_img)
        triplet_loss = criterion(anchor_out, positive_out, negative_out)
        # Average over all three branches (the old code averaged the anchor's term three times).
        latent_loss = torch.stack([anchor_latent_out.mean(),
                                   positive_latent_out.mean(),
                                   negative_latent_out.mean()]).mean()
        return anchor_img, anchor_out, triplet_loss, latent_loss

    for epoch in range(epochs):
        # Evaluation at the end of the previous epoch left the model in eval mode.
        model.train()
        # Local wrapper: re-assigning `train_loader` nested one extra tqdm per epoch.
        progress = tqdm(train_loader) if is_primary() else train_loader

        mse_sum = 0
        mse_n = 0
        train_mean_loss = []

        for batch in progress:
            model.zero_grad()
            anchor_img, _, recon_loss, latent_loss = forward_triplet(batch)

            loss = recon_loss + latent_loss_weight * latent_loss
            loss.backward()
            train_mean_loss.append(loss.item())
            optimizer.step()

            batch_n = anchor_img.shape[0]
            for part in all_gather({"mse_sum": recon_loss.item() * batch_n, "mse_n": batch_n}):
                mse_sum += part["mse_sum"]
                mse_n += part["mse_n"]

            if is_primary():
                lr = optimizer.param_groups[0]["lr"]

                progress.set_description(
                    (
                        f"epoch: {epoch + 1}; loss: {str(round(np.mean(train_mean_loss), 5))}; mse: {recon_loss.item():.5f}; "
                        f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                        f"lr: {lr:.5f}"
                    )
                )

        model.eval()
        test_mean_loss = []
        sample = recon = None

        with torch.no_grad():
            for batch in test_loader:
                anchor_img, anchor_out, recon_loss, latent_loss = forward_triplet(batch)
                test_loss = recon_loss + latent_loss_weight * latent_loss
                test_mean_loss.append(round(test_loss.item(), 5))
                # Slice inputs and outputs alike so the grid pairs each input with its reconstruction.
                sample = anchor_img[:sample_size]
                recon = anchor_out[:sample_size]

        # Skip the grid when the test loader produced no batches.
        if sample is not None:
            utils.save_image(
                torch.cat([sample, recon], 0),
                f"{model_path}/{str(epoch + 1).zfill(5)}.png",
                nrow=sample_size,
                normalize=True,
            )

        print(f'test elbo: {str(round(np.mean(test_mean_loss), 5))}')
        torch.save(model.state_dict(),
                   f"{model_path}/vqvae_{str(epoch + 1).zfill(3)}_train_{str(round(np.mean(train_mean_loss), 5))}_test_{str(round(np.mean(test_mean_loss), 5))}.pt")

In [None]:
# model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_left_sub_right_minmax_4x_360/vqvae_001_train_0.04914_test_0.04206.pt'
model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360/vqvae_003_train_0.04287_test_0.04129.pt'

# Load onto the model's device instead of a hard-coded 'cuda', which fails without a GPU.
# NOTE(review): the checkpoint directory says emb_dim_1. Confirm it matches the embed_dim of
# the `model` currently in scope, or load_state_dict will report size mismatches.
model.load_state_dict(torch.load(model_file, map_location=torch.device(device)))

In [None]:
epochs = 100
lr = 1e-4

latent_loss_weight = 0.25
sample_size = 25

# Adam with AMSGrad and light weight decay.
# Also tried: optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-6, centered=True).
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7, amsgrad=True)

# Earlier output dirs: '../data/logs/vq-vae-2/4x/weights',
# 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360',
# 'data/logs/emb_dim_2_n_embed_8192_bc_left_9x_512_360'
model_path = 'data/logs/emb_dim_1_n_embed_8192o_o_bc_left_4x_512_triplet'

# Forward every configured value. `device` in particular was not passed before, so
# train_triplet fell back to 'cuda' even though the model had been moved to `device`.
train_triplet(model=model, optimizer=optimizer, train_loader=train_loader, test_loader=test_loader,
              model_path=model_path, epochs=epochs, device=device,
              latent_loss_weight=latent_loss_weight, sample_size=sample_size)


# VQ-VAE-2 + VAE

In [None]:
# Dataset root for the VQ-VAE-2 + VAE experiment. Alternatives used before:
#   '../data/dataset_512/', '../datasets/bc_right_sub_left_minmax_4x_360',
#   '../datasets/bc_left_4x_360', '../datasets/original/o_bc_left_9x_512_360'
dataset_path = '../datasets/original/o_bc_left_4x_768'

resize_shape = (512, 512)  # images are resized to this before conversion

n_gpu = 1
batch_size = 4
val_split = 0.15  # validation fraction

transform = transforms.Compose([
    transforms.Resize(resize_shape),
    transforms.ToTensor(),
    # mean=std=0.5 maps [0, 1] tensors to [-1, 1]
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = datasets.ImageFolder(dataset_path, transform=transform)

train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len
lengths = [train_dataset_len, test_dataset_len]
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, lengths, generator=torch.Generator().manual_seed(seed)
)

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

loader_kwargs = dict(batch_size=batch_size // n_gpu, num_workers=2)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_kwargs)
test_loader = DataLoader(test_dataset, sampler=test_sampler, **loader_kwargs)

In [None]:
class Vqvae2AdaptiveVae1(Vqvae2AdaptiveVae):
    """Vqvae2AdaptiveVae variant whose forward pass also returns per-level ELBO terms.

    The quantized top and bottom codes from `encode` each go through a VAE (`vae_top`,
    `vae_bottom`). The VAE outputs are decoded into the image, and each VAE's ELBO against
    its quantized input is returned alongside the decoded image and the quantization diff.
    """

    def forward(self, input, n_embedded_l=None, dim_l=None):
        """Return (decoded, diff, elbo_top, elbo_bottom) for `input`."""
        quant_top, quant_bottom, diff, _, _ = self.encode(input, n_embedded_l=n_embedded_l, dim_l=dim_l)

        top_recon, top_mu, top_log_var = self.vae_top(quant_top)
        bottom_recon, bottom_mu, bottom_log_var = self.vae_bottom(quant_bottom)

        decoded = self.decode(top_recon, bottom_recon)
        elbo_top = self.elbo_loss(top_recon, quant_top, top_mu, top_log_var)
        elbo_bottom = self.elbo_loss(bottom_recon, quant_bottom, bottom_mu, bottom_log_var)
        return decoded, diff, elbo_top, elbo_bottom

    def elbo_loss(self, recon_x, x, mu, logvar, beta=1):
        """Negative ELBO for a Gaussian posterior: summed squared error + beta * KL divergence.

        KL term from https://arxiv.org/abs/1312.6114 (Appendix B):
        -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), with logvar = log(sigma^2).
        """
        sse = nn.MSELoss(reduction='sum')(recon_x, x)
        kld_terms = (1 - (mu.pow(2) + logvar.exp())) + logvar
        kld = torch.sum(kld_terms) * -0.5
        return sse + beta * kld


In [None]:
# This section runs on the GPU (switch to "cpu" to run without one).
device = "cuda"

# Per the parameter names: codebook of n_embed entries with embed_dim=1. latent_dims
# presumably sets the size of the inner VAEs' latent space; confirm in pt_utils.
model = Vqvae2AdaptiveVae1(
    in_channel=3,
    channel=128,
    n_res_block=6,
    n_res_channel=32,
    embed_dim=1,
    n_embed=8192,
    decay=0.99,
    latent_dims=200,
).to(device)


In [None]:
# (index, name) of trainable parameters of the VAE model; the cell shows the last entry.
l = [(idx, pname) for idx, (pname, tensor) in enumerate(model.named_parameters()) if tensor.requires_grad]
l[-1]

In [None]:
def train(model, optimizer, train_loader, test_loader, model_path, epochs=100, device='cuda',
          latent_loss_weight=0.25, sample_size=25):
    """Train a model that returns reconstruction, latent loss and two ELBO terms.

    `model(img)` must return `(reconstruction, latent_loss, elbo_top, elbo_bottom)`. The
    objective is `mse(reconstruction, img) + latent_loss_weight * latent_loss.mean()
    + elbo_top + elbo_bottom`, used identically for training and evaluation.

    After every epoch the model is evaluated on `test_loader`. A grid of up to `sample_size`
    inputs and their reconstructions from the last test batch is saved under `model_path`,
    and a checkpoint named with the mean train/test losses is written there.

    Args:
        model: module with the output signature above.
        optimizer: optimizer over `model`'s parameters.
        train_loader: iterable of (image, label) batches; labels are ignored.
        test_loader: iterable of (image, label) batches; labels are ignored.
        model_path: output directory; created (with parents) if missing.
        epochs: number of passes over `train_loader`.
        device: device the images are moved to; must match the device holding `model`.
        latent_loss_weight: weight of the latent term.
        sample_size: number of inputs/reconstructions saved per epoch (also the grid row width).
    """
    os.makedirs(model_path, exist_ok=True)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        # Evaluation at the end of the previous epoch left the model in eval mode.
        model.train()
        # Local wrapper: re-assigning `train_loader` nested one extra tqdm per epoch.
        progress = tqdm(train_loader) if is_primary() else train_loader

        mse_sum = 0
        mse_n = 0
        train_mean_loss = []

        for img, _label in progress:
            model.zero_grad()
            img = img.to(device)

            out, latent_loss, elbo_t, elbo_b = model(img)
            recon_loss = criterion(out, img)
            latent_loss = latent_loss.mean()
            loss = recon_loss + latent_loss_weight * latent_loss + elbo_t + elbo_b
            loss.backward()
            train_mean_loss.append(loss.item())
            optimizer.step()

            batch_n = img.shape[0]
            for part in all_gather({"mse_sum": recon_loss.item() * batch_n, "mse_n": batch_n}):
                mse_sum += part["mse_sum"]
                mse_n += part["mse_n"]

            if is_primary():
                lr = optimizer.param_groups[0]["lr"]

                progress.set_description(
                    (
                        f"epoch: {epoch + 1}; loss: {str(round(np.mean(train_mean_loss), 5))}; mse: {recon_loss.item():.5f}; "
                        f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                        f"lr: {lr:.5f}"
                    )
                )

        model.eval()
        test_mean_loss = []
        sample = recon = None

        with torch.no_grad():
            for img, _label in test_loader:
                img = img.to(device)
                out, latent_loss, elbo_t, elbo_b = model(img)
                test_recon_loss = criterion(out, img)
                # Use the reduced latent loss as in training. The raw `latent_loss` was used before,
                # which breaks `.item()` whenever the model returns a non-scalar latent loss.
                test_latent_loss = latent_loss.mean()
                test_loss = test_recon_loss + latent_loss_weight * test_latent_loss + elbo_t + elbo_b
                test_mean_loss.append(round(test_loss.item(), 5))
                # Slice inputs and outputs alike so the grid pairs each input with its reconstruction.
                sample = img[:sample_size]
                recon = out[:sample_size]

        # Skip the grid when the test loader produced no batches.
        if sample is not None:
            utils.save_image(
                torch.cat([sample, recon], 0),
                f"{model_path}/{str(epoch + 1).zfill(5)}.png",
                nrow=sample_size,
                normalize=True,
            )

        print(f'test elbo: {str(round(np.mean(test_mean_loss), 5))}')
        torch.save(model.state_dict(),
                   f"{model_path}/vqvae_{str(epoch + 1).zfill(3)}_train_{str(round(np.mean(train_mean_loss), 5))}_test_{str(round(np.mean(test_mean_loss), 5))}.pt")


In [None]:
epochs = 100
lr = 1e-4

latent_loss_weight = 0.25
sample_size = 25

# Adam with AMSGrad and light weight decay.
# Also tried: optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-6, centered=True).
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7, amsgrad=True)

# Earlier output dirs: '../data/logs/vq-vae-2/4x/weights',
# 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360',
# 'data/logs/emb_dim_1_n_embed_8192_bc_left_9x_512_360'
model_path = 'test'

# Forward the configured values explicitly instead of relying on train()'s defaults,
# so changing device / latent_loss_weight / sample_size above actually takes effect.
train(model=model, optimizer=optimizer, train_loader=train_loader, test_loader=test_loader,
      model_path=model_path, epochs=epochs, device=device,
      latent_loss_weight=latent_loss_weight, sample_size=sample_size)


In [None]:
# Inspect the standalone VanillaVAE on single-channel 128x128 inputs.
# For the 64x64 configuration, summarise with (1, 64, 64) instead.
model = VanillaVAE(1, 200, flag_128=True, hidden_dims=[16, 32, 64, 128, 256, 512]).to('cuda')
summary(model, (1, 128, 128), device='cuda')