In [12]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils
import torchvision

from pt_utils import  Embeddings, Trainer, VQVAE, data_sampler
from torchsummary import summary

import numpy as np

In [3]:
# Seed the NumPy and PyTorch global RNGs so repeated runs draw the same numbers.
# `seed` is reused further down to seed the generator of the train/test split.
seed = 51
np.random.seed(seed)
torch.manual_seed(seed)

<torch._C.Generator at 0x112b6acf0>

In [5]:
# Device the model is moved to; switch to "cuda" to place it on a GPU.
device = "cpu"

# Constructor arguments of the VQ-VAE, gathered in one place so they are easy to tweak.
# An earlier variant of this cell built Vqvae2Adaptive (in_channel=3, embed_dim=1)
# with otherwise identical settings; Vqvae2Adaptive is not imported by this notebook.
vqvae_config = dict(
    in_channel=1,
    channel=128,
    n_res_block=6,
    n_res_channel=32,
    embed_dim=8,
    n_embed=8192,
    decay=0.99,
)

model = VQVAE(**vqvae_config)
model = model.to(device)

In [10]:
# summary(model, input_size=(1, 512, 512))

# MSE loss

In [8]:
# Dataset root read by ImageFolder. Other roots this cell has been pointed at:
#   '../data/dataset_512/'
#   '../datasets/bc_right_sub_left_minmax_4x_360'
#   '../datasets/bc_left_4x_360'
dataset_path = '../datasets/original/o_bc_left_9x_512_360'

# Only consumed by the Resize/CenterCrop transforms, which are currently switched off;
# (1024, 1024) is the alternative that was tried.
resize_shape = (512, 512)

n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction of the images held out as the test split

# Convert each image to a tensor, then normalise every channel with mean 0.5 / std 0.5.
# NOTE(review): three-channel statistics are used here while the model cell builds
# VQVAE(in_channel=1) - confirm the channel count the network actually receives.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

dataset = datasets.ImageFolder(dataset_path, transform=transform)

# Split sizes: the train share is truncated to an int, the test split takes the rest.
n_images = len(dataset)
train_dataset_len = int(n_images * (1 - val_split))
test_dataset_len = n_images - train_dataset_len

# A dedicated generator seeded with the notebook seed keeps the split repeatable.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_dataset_len, test_dataset_len], generator=split_generator
)

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

# Both loaders share the per-GPU batch size and the worker count.
loader_options = dict(batch_size=batch_size // n_gpu, num_workers=2)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_options)
test_loader = DataLoader(test_dataset, sampler=test_sampler, **loader_options)

In [None]:
# Checkpoint to resume from. Alternative checkpoint used previously:
# 'data/logs/emb_dim_1_n_embed_8192_bc_left_sub_right_minmax_4x_360/vqvae_001_train_0.04914_test_0.04206.pt'
model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360/vqvae_003_train_0.04287_test_0.04129.pt'

# Map the stored tensors onto the device the model was moved to. The previous
# hard-coded 'cuda' target contradicted `device = "cpu"` and makes torch.load fail
# on machines without CUDA.
# NOTE(review): the checkpoint directory is named emb_dim_1 while the model cell uses
# embed_dim=8 - confirm that these weights belong to the architecture built above.
checkpoint_state = torch.load(model_file, map_location=torch.device(device))
model.load_state_dict(checkpoint_state)

In [None]:
# Settings for the run under the "MSE loss" heading.
epochs = 100
lr = 1e-4

# NOTE(review): latent_loss_weight, sample_size and `device` are not forwarded to
# Trainer.train below - confirm that Trainer.train's own defaults are the intended ones.
latent_loss_weight = 0.25
sample_size = 25

# Adam with AMSGrad enabled and a small weight decay.
# (RMSprop with weight_decay=1e-6 and centered=True was the alternative tried.)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7, amsgrad=True)

# Directory passed to the trainer as `model_path`. Earlier runs used
# 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360'.
model_path = 'data/logs/emb_dim_2_n_embed_8192_bc_left_9x_512_360'

train_args = dict(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    test_loader=test_loader,
    model_path=model_path,
    epochs=epochs,
)
Trainer.train(**train_args)


# Triplet loss

In [17]:
class TripletImageFolder(torchvision.datasets.ImageFolder):
    """ImageFolder whose items are (anchor, positive, negative) image triplets.

    A triplet pairs two different images of one class (anchor and positive) with one
    image of any other class (negative). One triplet per sample in the folder is
    generated at construction time using NumPy's global RNG, so seed ``np.random``
    beforehand for repeatable triplets.
    """
    def __init__(self, *arg, **kw):
        super(TripletImageFolder, self).__init__(*arg, **kw)

        self.n_triplets = len(self.samples)
        self.train_triplets = self.generate_triplets()

    def __len__(self):
        # Items are indexed through the triplet table, which set_triplets() may replace
        # with a table of a different length than the sample list.
        return len(self.train_triplets)

    def generate_triplets(self):
        """Draw ``self.n_triplets`` random triplets of sample indices.

        Returns:
            Integer array with one ``[anchor, positive, negative]`` row per triplet;
            the entries index into ``self.samples``.

        Raises:
            ValueError: if the folder has fewer than two classes (no negative can be
                drawn) or no class with at least two images (no positive can be drawn).
        """
        triplets = []
        if self.n_triplets == 0:
            return np.array(triplets)

        labels = np.asarray(self.targets)
        classes, counts = np.unique(labels, return_counts=True)
        if classes.size < 2:
            raise ValueError("Triplets need at least two classes to draw a negative sample.")

        pairable_classes = classes[counts >= 2]
        if pairable_classes.size == 0:
            raise ValueError("Triplets need at least one class with two or more images.")

        # Anchors are drawn only from classes that can supply a distinct positive.
        # When every class qualifies this pool is simply 0..n-1, so the random stream
        # matches a plain randint over all samples.
        anchor_pool = np.flatnonzero(np.isin(labels, pairable_classes))

        # Per-class index lists are computed once instead of rescanning all labels
        # for every triplet.
        same_class = {}
        other_class = {}

        for _ in range(self.n_triplets):
            idx = anchor_pool[np.random.randint(0, anchor_pool.size)]
            label = int(labels[idx])
            if label not in same_class:
                same_class[label] = np.flatnonzero(labels == label)
                other_class[label] = np.flatnonzero(labels != label)

            # The drawn index only selects the class; anchor and positive are two
            # distinct images sampled from that class.
            idx_a, idx_p = np.random.choice(same_class[label], 2, replace=False)
            idx_n = np.random.choice(other_class[label], 1)[0]
            triplets.append([idx_a, idx_p, idx_n])
        return np.array(triplets)

    def set_triplets(self, triplets):
        """Replace the triplet table with externally supplied index triplets."""
        self.train_triplets = triplets

    def __getitem__(self, index):
        """Load and transform the three images of triplet ``index``.

        Returns:
            ``(anchor, positive, negative)`` images; class labels are not returned.
        """
        t = self.train_triplets[index]

        path_a, _ = self.samples[t[0]]
        path_p, _ = self.samples[t[1]]
        path_n, _ = self.samples[t[2]]

        img_a = self.loader(path_a)
        img_p = self.loader(path_p)
        img_n = self.loader(path_n)

        if self.transform is not None:
            img_a = self.transform(img_a)
            img_p = self.transform(img_p)
            img_n = self.transform(img_n)

        return img_a, img_p, img_n

In [None]:
# Dataset root read by TripletImageFolder. Other roots this cell has been pointed at:
#   '../data/dataset_512/'
#   '../datasets/bc_right_sub_left_minmax_4x_360'
#   '../datasets/bc_left_4x_360'
dataset_path = '../datasets/original/o_bc_left_9x_512_360'

# Only consumed by the Resize/CenterCrop transforms, which are currently switched off;
# (1024, 1024) is the alternative that was tried.
resize_shape = (512, 512)

n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction of the triplets held out as the test split

# Convert each image to a tensor, then normalise every channel with mean 0.5 / std 0.5.
# NOTE(review): three-channel statistics are used here while the model cell builds
# VQVAE(in_channel=1) - confirm the channel count the network actually receives.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Triplet-producing dataset; plain datasets.ImageFolder(dataset_path, transform=transform)
# is the non-triplet alternative.
dataset = TripletImageFolder(dataset_path, transform=transform)

# Split sizes: the train share is truncated to an int, the test split takes the rest.
n_items = len(dataset)
train_dataset_len = int(n_items * (1 - val_split))
test_dataset_len = n_items - train_dataset_len

# NOTE: both splits come from the same TripletImageFolder, so the test loader also
# yields (anchor, positive, negative) batches. The split partitions triplet rows, and
# each row indexes into the full sample list, so an image can occur in both splits.
split_generator = torch.Generator().manual_seed(seed)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_dataset_len, test_dataset_len], generator=split_generator
)

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

# Both loaders share the per-GPU batch size and the worker count.
loader_options = dict(batch_size=batch_size // n_gpu, num_workers=2)
train_loader = DataLoader(train_dataset, sampler=train_sampler, **loader_options)
test_loader = DataLoader(test_dataset, sampler=test_sampler, **loader_options)

In [14]:
class TripletLoss(nn.Module):
    """Hinge-style triplet objective built on squared distances.

    For each element the loss is ``relu(d(anchor, positive) - d(anchor, negative) + margin)``
    where ``d`` is the sum of squared differences over dimension 1; the element-wise
    losses are then averaged.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def calc_euclidean(self, x1, x2):
        """Sum of squared differences between ``x1`` and ``x2`` along dimension 1.

        Despite the name no square root is taken, so this is a squared distance.
        """
        difference = x1 - x2
        return torch.pow(difference, 2).sum(dim=1)

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        """Return the mean hinge loss over all anchor/positive/negative elements."""
        to_positive = self.calc_euclidean(anchor, positive)
        to_negative = self.calc_euclidean(anchor, negative)
        violation = to_positive - to_negative + self.margin
        return nn.functional.relu(violation).mean()

In [18]:
def train_triplet(model, optimizer, train_loader, test_loader, model_path, epochs=100, device='cuda',
          latent_loss_weight=0.25, sample_size=25):
    """Train ``model`` with a triplet objective, evaluating and checkpointing every epoch.

    Runs in a single process. After each epoch the model is evaluated on
    ``test_loader``, a grid of test images and their model outputs is written to
    ``model_path`` and the weights are saved there as a ``.pt`` file.

    Args:
        model: module whose forward pass returns ``(output, latent_loss)``.
        optimizer: optimizer stepping the parameters of ``model``.
        train_loader: iterable of ``(anchor, positive, negative)`` image batches.
        test_loader: iterable of either triplet batches (scored with the triplet
            objective) or ``(image, label)`` batches (scored with an MSE between the
            model output and the image).
        model_path: output directory; created, including parents, when missing.
        epochs: number of passes over ``train_loader``.
        device: device every batch is moved to; it must be the device the model is on.
        latent_loss_weight: multiplier applied to the latent loss in the total loss.
        sample_size: maximum number of images of the last test batch shown in the grid.
    """
    import os  # local import: the notebook's import cell does not provide `os`

    try:  # the progress bar is optional; without tqdm the loader is iterated directly
        from tqdm.auto import tqdm
    except ImportError:
        tqdm = None

    # makedirs (unlike mkdir) also creates the missing parents of a nested model_path.
    os.makedirs(model_path, exist_ok=True)

    criterion = nn.TripletMarginLoss()
    recon_criterion = nn.MSELoss()

    def triplet_losses(batch):
        """Run all three images through the model; return both losses, anchor and its output."""
        anchor_img, positive_img, negative_img = (part.to(device) for part in batch)

        anchor_out, anchor_latent = model(anchor_img)
        positive_out, positive_latent = model(positive_img)
        negative_out, negative_latent = model(negative_img)

        # TripletMarginLoss is documented for (N, D) inputs, so every output is
        # flattened to one vector per sample before distances are taken.
        recon = criterion(anchor_out.flatten(1), positive_out.flatten(1), negative_out.flatten(1))
        # Average the latent losses of all three forward passes.
        latent = torch.stack(
            [anchor_latent.mean(), positive_latent.mean(), negative_latent.mean()]
        ).mean()
        return recon, latent, anchor_img, anchor_out

    def rounded_mean(values):
        """Mean rounded to 5 decimals; NaN when nothing was recorded."""
        return round(float(np.mean(values)), 5) if values else float('nan')

    for epoch in range(epochs):
        # Training mode is set before the batches so that every step of the epoch
        # runs in the same mode (the eval pass below switches it off).
        model.train()

        # Wrap the untouched loader anew each epoch instead of re-wrapping the wrapper.
        progress = tqdm(train_loader) if tqdm is not None else None

        mse_sum = 0.0
        mse_n = 0
        test_mean_loss = []
        train_mean_loss = []

        for batch in (train_loader if progress is None else progress):
            model.zero_grad()

            recon_loss, latent_loss, anchor_img, _ = triplet_losses(batch)
            loss = recon_loss + latent_loss_weight * latent_loss
            loss.backward()
            optimizer.step()
            train_mean_loss.append(loss.item())

            # Running average of the triplet loss, weighted by batch size.
            batch_n = anchor_img.shape[0]
            mse_sum += recon_loss.item() * batch_n
            mse_n += batch_n

            if progress is not None:
                lr = optimizer.param_groups[0]["lr"]

                progress.set_description(
                    (
                        f"epoch: {epoch + 1}; loss: {str(round(np.mean(train_mean_loss), 5))}; mse: {recon_loss.item():.5f}; "
                        f"latent: {latent_loss.item():.3f}; avg mse: {mse_sum / mse_n:.5f}; "
                        f"lr: {lr:.5f}"
                    )
                )

        model.eval()

        sample = None
        sample_out = None

        with torch.no_grad():
            for batch in test_loader:
                if len(batch) == 3:
                    # Triplet batch: same objective as in training.
                    test_recon_loss, test_latent_loss, img, out = triplet_losses(batch)
                else:
                    # (image, label) batch: compare the model output with its input.
                    img = batch[0].to(device)
                    out, raw_latent_loss = model(img)
                    test_recon_loss = recon_criterion(out, img)
                    test_latent_loss = raw_latent_loss.mean()

                test_loss = test_recon_loss + latent_loss_weight * test_latent_loss
                test_mean_loss.append(round(test_loss.item(), 5))

                # Keep inputs and outputs trimmed to the same length for the grid.
                sample = img[:sample_size]
                sample_out = out[:sample_size]

        if sample is not None:
            utils.save_image(
                torch.cat([sample, sample_out], 0),
                f"{model_path}/{str(epoch + 1).zfill(5)}.png",
                # One row of inputs above one row of outputs.
                nrow=sample.shape[0],
                normalize=True,
            )

        train_mean = rounded_mean(train_mean_loss)
        test_mean = rounded_mean(test_mean_loss)

        print(f'test elbo: {str(test_mean)}')
        torch.save(model.state_dict(),
                   f"{model_path}/vqvae_{str(epoch + 1).zfill(3)}_train_{str(train_mean)}_test_{str(test_mean)}.pt")

In [None]:
# Checkpoint to resume from. Alternative checkpoint used previously:
# 'data/logs/emb_dim_1_n_embed_8192_bc_left_sub_right_minmax_4x_360/vqvae_001_train_0.04914_test_0.04206.pt'
model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360/vqvae_003_train_0.04287_test_0.04129.pt'

# Map the stored tensors onto the device the model was moved to. The previous
# hard-coded 'cuda' target contradicted `device = "cpu"` and makes torch.load fail
# on machines without CUDA.
# NOTE(review): the checkpoint directory is named emb_dim_1 while the model cell uses
# embed_dim=8 - confirm that these weights belong to the architecture built above.
checkpoint_state = torch.load(model_file, map_location=torch.device(device))
model.load_state_dict(checkpoint_state)

In [None]:
# Settings for the triplet-loss run.
epochs = 100
lr = 1e-4

latent_loss_weight = 0.25  # multiplier of the latent loss inside train_triplet
sample_size = 25           # upper bound on the images shown in each saved sample grid

# Adam with AMSGrad enabled and a small weight decay.
# (RMSprop with weight_decay=1e-6 and centered=True was the alternative tried.)
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-7, amsgrad=True)

# Output directory for sample grids and checkpoints. Earlier runs used
# 'data/logs/emb_dim_1_n_embed_8192_bc_right_sub_left_minmax_4x_360'.
model_path = 'data/logs/emb_dim_2_n_embed_8192_bc_left_9x_512_360'

# `device` must be passed explicitly: train_triplet defaults to 'cuda', which would
# move the batches to a different device than the model when `device` is "cpu".
# latent_loss_weight and sample_size are forwarded so the values set above take effect.
train_triplet(model=model, optimizer=optimizer, train_loader=train_loader, test_loader=test_loader,
              model_path=model_path, epochs=epochs, device=device,
              latent_loss_weight=latent_loss_weight, sample_size=sample_size)
