In [None]:
import argparse

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from apex import amp

except ImportError:
    amp = None

# from dataset import LMDBDataset
from pixelsnail import PixelSNAIL
from scheduler import CycleScheduler

import argparse
import sys
import os
from skimage import io

import torch
import torch.nn.functional as F
from torch import nn, optim
from torch.utils.data import DataLoader

from torchvision import datasets, transforms, utils
import matplotlib.patches as mpatches


# from scheduler import CycleScheduler
from pt_utils import  Embeddings, Trainer, VQVAE, data_sampler, Vqvae2Adaptive
from torch.utils import data
from torch import distributed as dist

from umap import UMAP
from sklearn.model_selection import train_test_split
from matplotlib import pyplot as plt
from collections import Counter
from skimage import transform, metrics
import skimage
import pandas as pd
import numpy as np
import glob

from torchsummary import summary

from tqdm.notebook import trange, tqdm
import torchvision
from PIL import Image

import numpy as np
from sklearn.decomposition import PCA
from scipy.ndimage.filters import gaussian_filter
import scipy as sp
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import make_axes_locatable

import joblib
import pickle

from tensordict import TensorDict

# Global seed shared by numpy, torch and the seeded random_split generator further down.
seed = 51
for _seed_fn in (np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Create embeddings dataset

In [None]:
# Source image dataset (one sub-folder per class) and the folder that will receive one
# embedding file per image, mirroring that layout.
# dataset_path = '../data/dataset_512/'
# dataset_path = '../datasets/bc_right_sub_left_minmax_4x_360'
# dataset_path = '../datasets/bc_left_4x_360'
dataset_path = '../datasets/original/o_bc_left_9x_512_360'
# dataset_path = '../datasets/original/o_bc_left_4x_768'

new_dataset_path = '../datasets/original/emb_dim_1_n_embed_8192_bc_left_9x_512_360'
os.makedirs(new_dataset_path, exist_ok=True)  # creates missing parents; no-op on re-run

device = 'cuda'

resize_shape = (512, 512)  # only referenced by the disabled Resize/CenterCrop transforms below
# resize_shape = (1024, 1024)

batch_size = 4

# Named `vqvae_transform` (not `transform`) so the `skimage.transform` module imported in the
# first cell is not shadowed.
vqvae_transform = transforms.Compose(
    [
        # transforms.Resize(resize_shape),
        # transforms.CenterCrop(resize_shape),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
    ]
)
model_file = 'data/logs/emb_dim_1_n_embed_8192_bc_left_9x_512_360/vqvae_002_train_0.01976_test_0.01984.pt'

model = VQVAE(
    in_channel=3,
    channel=128,
    n_res_block=6,
    n_res_channel=32,
    embed_dim=1,
    n_embed=8192,
    decay=0.99,
).to(device)

# `model_file` was previously defined but never loaded, so embeddings were computed with
# randomly initialised weights. Accept either a saved state_dict or a whole pickled module.
checkpoint = torch.load(model_file, map_location=device)
if isinstance(checkpoint, nn.Module):
    model = checkpoint.to(device)
else:
    model.load_state_dict(checkpoint)
# Inference mode. NOTE(review): decay=0.99 suggests an EMA codebook that may update in
# training mode — confirm in pt_utils.VQVAE.
model.eval()

dataset = datasets.ImageFolder(dataset_path, transform=vqvae_transform)

images_embs_t = []
images_embs_b = []

dataset_path = dataset.root
# Take folders and files from ImageFolder's own scan: only class sub-directories (stray
# files at the top level are ignored) and only files it accepted as images.
classes_folders = list(dataset.classes)
classes_folders_images = [[] for _ in classes_folders]
for sample_path, class_idx in dataset.samples:
    classes_folders_images[class_idx].append(os.path.basename(sample_path))
classes_folders_images_num = [len(images) for images in classes_folders_images]

img_transform = dataset.transform

In [None]:
# Scratch check: a TensorDict can hold entries with different trailing shapes as long as
# they share the declared leading batch dimension.
tensordict = TensorDict(
    {
        "a": torch.zeros(1, 32, 32),
        "b": torch.ones(1, 64, 64),
    },
    batch_size=[1],
)

In [None]:
# Display the "b" entry of the scratch TensorDict built in the previous cell.
tensordict['b']

In [None]:
# Encode every image with the VQ-VAE and save its top/bottom quantized maps as one .pt file
# per image, mirroring the class-folder layout of the source dataset.
# NOTE(review): these are model.encode's first two outputs (quantized embeddings), not the
# integer code ids — confirm this is what the downstream PixelSNAIL cells expect.
model.eval()  # NOTE(review): presumably stops EMA codebook updates (decay=0.99) — confirm
with torch.no_grad():  # inference only: no autograd graph, far less GPU memory
    for i, folder in enumerate(classes_folders):
        print(f'Number of folders {i + 1}/{len(classes_folders)}')

        out_dir = os.path.join(new_dataset_path, folder)
        os.makedirs(out_dir, exist_ok=True)  # re-runnable: existing folders are fine

        for image_name in tqdm(classes_folders_images[i], desc=f'Folder {folder}'):
            image_path = os.path.join(dataset_path, folder, image_name)

            # Context manager closes the underlying file handle.
            with Image.open(image_path) as pil_image:
                image = img_transform(pil_image.convert("RGB"))
            image = image.unsqueeze(0).to(device)

            quant_t, quant_b, _, _, _ = model.encode(image)

            # splitext handles any extension length (the old [:-5] only worked for 5-char
            # suffixes such as ".tiff"/".jpeg"); tensors go to CPU so the files load anywhere.
            stem = os.path.splitext(image_name)[0]
            torch.save(
                TensorDict({"top": quant_t.cpu(), "bottom": quant_b.cpu()}, batch_size=[1]),
                os.path.join(out_dir, stem + '.pt'),
            )

In [None]:
import argparse
import pickle

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import lmdb
from tqdm import tqdm

from dataset import ImageFileDataset, CodeRow
from vqvae import VQVAE


def extract(lmdb_env, loader, model, device):
    """Encode every batch from ``loader`` and store its top/bottom code maps in LMDB.

    Each image is stored under its running integer index (as a UTF-8 string key) as a
    pickled ``CodeRow(top, bottom, filename)``. The total count is written under the key
    ``'length'``. All writes happen inside a single write transaction.

    Args:
        lmdb_env: an open ``lmdb`` environment.
        loader: iterable yielding ``(img, _, filename)`` batches.
        model: VQ-VAE whose ``encode`` returns a 5-tuple ending with ``(id_t, id_b)``.
        device: device the images are moved to before encoding.
    """
    index = 0

    # Inference only: without no_grad every batch would build (and hold) an autograd graph.
    with torch.no_grad(), lmdb_env.begin(write=True) as txn:
        pbar = tqdm(loader)

        for img, _, filename in pbar:
            _, _, _, id_t, id_b = model.encode(img.to(device))
            id_t = id_t.detach().cpu().numpy()
            id_b = id_b.detach().cpu().numpy()

            for file, top, bottom in zip(filename, id_t, id_b):
                row = CodeRow(top=top, bottom=bottom, filename=file)
                # Rows are pickled: only read this LMDB back from a trusted source.
                txn.put(str(index).encode('utf-8'), pickle.dumps(row))
                index += 1
                pbar.set_description(f'inserted: {index}')

        txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))


import sys

# Command-line entry point. Inside Jupyter `__name__` is also '__main__'. There argparse
# would parse the kernel's own argv (e.g. `-f kernel.json`) and exit, and the body would
# overwrite the notebook's `model`/`dataset`/`transform`/`device`. So it is skipped under an
# IPython kernel.
if __name__ == '__main__' and 'ipykernel' not in sys.modules:
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=256)
    # Both are used unconditionally below, so fail early with a usage message if missing.
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('path', type=str)

    args = parser.parse_args()

    device = 'cuda'

    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = ImageFileDataset(args.path, transform=transform)
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

    model = VQVAE()
    # Load on CPU first so a checkpoint saved on another GPU still loads.
    model.load_state_dict(torch.load(args.ckpt, map_location='cpu'))
    model = model.to(device)
    model.eval()

    map_size = 100 * 1024 * 1024 * 1024  # 100 GiB maximum LMDB map size

    env = lmdb.open(args.name, map_size=map_size)
    try:
        extract(env, loader, model, device)
    finally:
        env.close()  # flush and release the environment even if extraction fails

# Train PixelSNAIL

In [None]:
# CodeRow = namedtuple('CodeRow', ['top', 'bottom', 'filename'])


# class ImageFileDataset(datasets.ImageFolder):
#     def __getitem__(self, index):
#         sample, target = super().__getitem__(index)
#         path, _ = self.samples[index]
#         dirs, filename = os.path.split(path)
#         _, class_name = os.path.split(dirs)
#         filename = os.path.join(class_name, filename)

#         return sample, target, filename


class LMDBDataset(datasets.DatasetFolder):
    """Read-only dataset over an LMDB file written by ``extract``.

    Keys are stringified integer indices mapping to pickled ``CodeRow`` records, plus a
    ``'length'`` key holding the record count. Items are
    ``(top_codes, bottom_codes, filename)``.

    NOTE(review): the previous ``__init__``/``__len__`` were commented out, so ``self.env``
    was never set and every ``__getitem__`` raised AttributeError. They are restored here.
    The ``DatasetFolder`` base is kept only for backward compatibility. Its ``__init__`` is
    intentionally not called, because there is no image folder to scan.
    """

    def __init__(self, path):
        import lmdb  # local import: only needed when this dataset is actually used

        self.root = path  # VisionDataset.__repr__ reads this
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # An out-of-range key would make txn.get return None and pickle.loads fail obscurely.
        if not 0 <= index < self.length:
            raise IndexError(index)

        with self.env.begin(write=False) as txn:
            key = str(index).encode('utf-8')
            # Security: pickle.loads can execute arbitrary code — only open trusted LMDB files.
            row = pickle.loads(txn.get(key))

        return torch.from_numpy(row.top), torch.from_numpy(row.bottom), row.filename
    
    
class EmbsDataset(torchvision.datasets.DatasetFolder):
    """DatasetFolder over per-image ``.pt`` embedding files.

    Each file is read with the ``loader`` passed to the constructor and must yield a mapping
    with ``'top'`` and ``'bottom'`` entries. ``__getitem__`` returns ``(top, bottom)`` and
    drops the class label.
    """

    def __getitem__(self, index):
        path, _ = self.samples[index]
        # Honour the loader given to the constructor (torch.load in this notebook) instead of
        # hard-coding torch.load, so e.g. a map_location-bound loader takes effect.
        pt_dict = self.loader(path)

        return pt_dict['top'], pt_dict['bottom']

In [None]:
# Per-image embedding files written by the extraction cell (one sub-folder per class).
dataset_path = '../datasets/original/emb_dim_1_n_embed_8192_bc_left_9x_512_360/'

# torchvision types `extensions` as a tuple of suffixes. A tuple also works with releases
# that hand it straight to str.endswith, which rejects a list.
dataset = EmbsDataset(dataset_path, loader=torch.load, extensions=('.pt',))

In [None]:
def train(epoch, loader, model, optimizer, scheduler, device):
    """Run one training epoch of PixelSNAIL on ``(top, bottom)`` batches.

    The level being modelled is read from the notebook-global ``hier`` ('top' or 'bottom').

    Args:
        epoch: zero-based epoch index, used only for the progress description.
        loader: iterable of ``(top, bottom)`` batches (e.g. from ``EmbsDataset``).
        model: PixelSNAIL returning ``(logits, cache)``.
        optimizer: optimizer stepped once per batch.
        scheduler: optional; if not None it is stepped before ``optimizer.step()``.
        device: device the batches are moved to.

    Raises:
        ValueError: if ``hier`` is neither 'top' nor 'bottom'.
    """
    model.train()  # re-enable dropout in case a previous evaluation switched it off
    criterion = nn.CrossEntropyLoss()
    pbar = tqdm(loader)  # a single progress bar (the loader used to be wrapped twice)

    for top, bottom in pbar:
        model.zero_grad()

        # NOTE(review): dims 1 and 2 are presumably the saved per-file batch axis and
        # embed_dim=1 — confirm against the stored tensor shapes.
        top = torch.squeeze(top.to(device), [1, 2])
        bottom = torch.squeeze(bottom.to(device), [1, 2])

        # NOTE(review): nn.CrossEntropyLoss needs integer class-index targets. The extraction
        # cell stores model.encode's quantized outputs, not the code ids — confirm dtype.
        if hier == 'top':
            target = top
            out, _ = model(top)
        elif hier == 'bottom':
            target = bottom
            out, _ = model(bottom, condition=top)
        else:
            raise ValueError(f"hier must be 'top' or 'bottom', got {hier!r}")

        loss = criterion(out, target)
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        _, pred = out.max(1)
        accuracy = ((pred == target).float().sum() / target.numel()).item()

        lr = optimizer.param_groups[0]['lr']

        # Update the progress bar instead of printing a new line for every step.
        pbar.set_description(
            f'Train epoch: {epoch + 1}; loss: {loss.item():.5f}; acc: {accuracy:.5f}; lr: {lr:.5f}'
        )

        
def evaluate(epoch, loader, model, optimizer, scheduler, device):
    """Evaluate PixelSNAIL on every batch of ``loader`` and return its mean loss and accuracy.

    The level being evaluated comes from the notebook-global ``hier`` ('top' or 'bottom'),
    the same global the training cell uses. The previous body read ``args.hier``, which is
    undefined in this notebook.

    Args:
        epoch: zero-based epoch index, used only for logging.
        loader: iterable of ``(top, bottom)`` batches (e.g. from ``EmbsDataset``).
        model: PixelSNAIL returning ``(logits, cache)``.
        optimizer: used only to report the current learning rate.
        scheduler: unused; kept for signature parity with ``train``.
        device: device the batches are moved to.

    Returns:
        ``(loss, accuracy)``: element-weighted mean loss over the whole loader rounded to 5
        decimals, and mean per-position accuracy rounded to 3 decimals. Previously only the
        first batch was evaluated and the loss was rounded to an integer.

    Raises:
        ValueError: if ``hier`` is not 'top'/'bottom' or the loader yields no batches.
    """
    criterion = nn.CrossEntropyLoss()
    lr = optimizer.param_groups[0]['lr']
    total_loss = 0.0
    total_correct = 0
    total_count = 0

    was_training = model.training
    model.eval()  # disable dropout while measuring
    try:
        with torch.no_grad():
            pbar = tqdm(loader)
            for top, bottom in pbar:
                # Same squeeze as in training so the model sees identically shaped inputs.
                top = torch.squeeze(top.to(device), [1, 2])
                bottom = torch.squeeze(bottom.to(device), [1, 2])

                if hier == 'top':
                    target = top
                    out, _ = model(top)
                elif hier == 'bottom':
                    target = bottom
                    out, _ = model(bottom, condition=top)
                else:
                    raise ValueError(f"hier must be 'top' or 'bottom', got {hier!r}")

                loss = criterion(out, target)
                _, pred = out.max(1)

                # CrossEntropyLoss (mean reduction) averages over every target element, so
                # weight each batch by its element count to get the exact split-wide mean.
                n_elems = target.numel()
                total_loss += loss.item() * n_elems
                total_correct += (pred == target).sum().item()
                total_count += n_elems

                pbar.set_description(
                    f'Test epoch: {epoch + 1}; loss: {loss.item():.5f}; lr: {lr:.5f}'
                )
    finally:
        model.train(was_training)  # restore the caller's mode

    if total_count == 0:
        raise ValueError('evaluate() received an empty loader')

    mean_loss = total_loss / total_count
    mean_acc = total_correct / total_count
    print(f'Test epoch: {epoch + 1}; loss: {mean_loss:.5f}; acc: {mean_acc:.5f}; lr: {lr:.5f}')

    return round(mean_loss, 5), round(mean_acc, 3)


In [None]:
# Spot-check one saved embedding file (relative to the embeddings dataset root).
file_name = os.path.join('Ultra_Co6_2', 'Ultra_Co6_2-001_part_2_angle_270.pt')
torch.load(os.path.join(dataset_path, file_name))

In [None]:
dataset_path = '../datasets/original/emb_dim_1_n_embed_8192_bc_left_9x_512_360'

n_gpu = 1
batch_size = 4
val_split = 0.15  # fraction of embedding files held out for evaluation

# Tuple `extensions` matches torchvision's documented type (a list breaks str.endswith).
dataset = EmbsDataset(dataset_path, loader=torch.load, extensions=('.pt',))

train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len

# Seeded generator so the train/test split is identical across kernel restarts.
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset,
    [train_dataset_len, test_dataset_len],
    generator=torch.Generator().manual_seed(seed),
)

# Only the training data needs shuffling. A fixed evaluation order keeps per-batch logs
# comparable between epochs and does not affect mean metrics.
train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=False, distributed=False)

per_device_batch = batch_size // n_gpu
train_loader = DataLoader(train_dataset, batch_size=per_device_batch, sampler=train_sampler)
test_loader = DataLoader(test_dataset, batch_size=per_device_batch, sampler=test_sampler)


In [None]:
# PixelSNAIL hyper-parameters.
batch = 32  # NOTE(review): unused — the loaders use `batch_size` from the data cell
hier = 'top'  # which code level to model: 'top' or 'bottom'
lr = 3e-4
channel = 256
n_res_block = 4
n_res_channel = 256
n_out_res_block = 0
n_cond_res_block = 3
dropout = 0.1
# Renamed from `amp`, which overwrote the apex `amp` module imported in the first cell.
amp_opt_level = 'O0'
sched = None  # NOTE(review): unused — the training cell always builds a CycleScheduler
ckpt = None  # optional path to a PixelSNAIL checkpoint to resume from
path = None

device = 'cuda'

# Arguments shared by the top and bottom priors.
# NOTE(review): n_class=512 while the VQ-VAE above was built with n_embed=8192, and the
# shapes [32, 32] / [64, 64] are hard-coded — confirm both match the stored maps.
_pixelsnail_kwargs = dict(
    n_class=512,
    channel=channel,
    kernel_size=5,
    n_block=4,
    n_res_block=n_res_block,
    res_channel=n_res_channel,
    dropout=dropout,
)

if hier == 'top':
    model = PixelSNAIL(shape=[32, 32], n_out_res_block=n_out_res_block, **_pixelsnail_kwargs)
elif hier == 'bottom':
    # The bottom prior is conditioned on the top codes and runs without attention.
    model = PixelSNAIL(
        shape=[64, 64],
        attention=False,
        n_cond_res_block=n_cond_res_block,
        cond_res_channel=n_res_channel,
        **_pixelsnail_kwargs,
    )
else:
    # Without this, `model` would silently keep whatever an earlier cell bound to it.
    raise ValueError(f"hier must be 'top' or 'bottom', got {hier!r}")

if ckpt is not None:
    state = torch.load(ckpt, map_location='cpu')
    # The training cell saves a bare state_dict; also accept a {'model': state_dict} wrapper.
    if 'model' in state:
        state = state['model']
    model.load_state_dict(state)

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [None]:
epoch = 420
device = 'cuda'
checkpoint_dir = 'checkpoint'
os.makedirs(checkpoint_dir, exist_ok=True)  # torch.save does not create missing folders

# One-cycle learning-rate schedule spanning every optimizer step of the full run.
scheduler = CycleScheduler(
    optimizer, lr, n_iter=len(train_loader) * epoch, momentum=None
)

for i in range(epoch):
    train(i, train_loader, model, optimizer, scheduler, device)
    loss, acc = evaluate(i, test_loader, model, optimizer, scheduler, device)
    torch.save(
        model.state_dict(),
        os.path.join(checkpoint_dir, f'pixelsnail_{hier}_{i + 1}_loss_{loss}_acc_{acc}.pt'),
    )
         

In [None]:
import torch
import torch.nn as nn
import torchvision

from tqdm.notebook import tqdm
from model import PixelSnail

import time


import os

TRY_CUDA = True
MODEL_SAVING = True  # save a checkpoint after every epoch
NB_EPOCHS = 200
BATCH_SIZE = 32  # nominal batch size; the last batch of a split may be smaller
IMGS_DIR = 'imgs'
CHECKPOINTS_DIR = 'checkpoints'


def mnist_batch_to_codes(images, labels, device):
    """Convert an MNIST batch into PixelSnail inputs.

    Args:
        images: float batch of shape (B, 1, H, W) with values in [0, 1] (ToTensor output).
        labels: integer class labels of shape (B,).
        device: target device.

    Returns:
        ``(x, c)``: integer pixel values in [0, 255] of shape (B, H, W), and the labels
        broadcast to a (B, 7, 7) conditioning map.
    """
    # round() before casting so float error (e.g. 254.99998) does not truncate a pixel
    # value down by one; squeeze(1) keeps the batch axis even when B == 1.
    x = (images * 255).round().long().squeeze(1).to(device)
    c = labels.view(-1, 1, 1).expand(-1, 7, 7).to(device)
    return x, c


if __name__ == "__main__":
    device = torch.device('cuda' if TRY_CUDA and torch.cuda.is_available() else 'cpu')
    print(f"> Using device {device}")

    print(f"> Instantiating PixelSnail")
    # model = PixelSnail([28, 28], 256, 32, 5, 3, 2, 16, nb_out_res_block=2).to(device)
    model = PixelSnail([28, 28], 256, 32, 5, 3, 2, 16, nb_cond_res_block=2, cond_res_channel=16, nb_out_res_block=2).to(device)
    print(f"> Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}\n")

    print("> Loading dataset")
    train_dataset = torchvision.datasets.MNIST('data', train=True, download=True, transform=torchvision.transforms.ToTensor())
    test_dataset = torchvision.datasets.MNIST('data', train=False, download=True, transform=torchvision.transforms.ToTensor())

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Named `optimizer`, not `optim`, so the `torch.optim` alias from the first cell survives.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    crit = nn.CrossEntropyLoss()

    os.makedirs(IMGS_DIR, exist_ok=True)
    if MODEL_SAVING:
        os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

    save_id = int(time.time())

    for ei in range(NB_EPOCHS):
        print(f"\n> Epoch {ei+1}/{NB_EPOCHS}")
        train_loss = 0.0
        eval_loss = 0.0

        model.train()
        for x_input, c in tqdm(train_loader):
            optimizer.zero_grad()
            x, c = mnist_batch_to_codes(x_input, c, device)

            pred, _ = model(x, c=c)
            # Use the real batch size: a partial final batch (e.g. 10000 % 32 = 16 test
            # images) made the old view(BATCH_SIZE, ...) raise.
            loss = crit(pred.reshape(x.size(0), 256, -1), x.reshape(x.size(0), -1))
            train_loss += loss.item()

            loss.backward()
            optimizer.step()

        model.eval()
        with torch.no_grad():
            for i, (x, c) in enumerate(tqdm(test_loader)):
                x, c = mnist_batch_to_codes(x, c, device)

                pred, _ = model(x, c=c)
                loss = crit(pred.reshape(x.size(0), 256, -1), x.reshape(x.size(0), -1))
                eval_loss += loss.item()

                if i == 0:
                    # Inputs stacked above their argmax reconstructions, scaled back to [0, 1].
                    img = torch.cat([x, torch.argmax(pred, dim=1)], dim=0) / 255.
                    torchvision.utils.save_image(img.unsqueeze(1), os.path.join(IMGS_DIR, f"pixelcnn-{ei}.png"))
        if MODEL_SAVING:
            torch.save(model.state_dict(), os.path.join(CHECKPOINTS_DIR, f"{save_id}-{ei}-pixelcnn.pt"))
        print(f"> Training Loss: {train_loss / len(train_loader)}")
        print(f"> Evaluation Loss: {eval_loss / len(test_loader)}")

In [None]:
# Scratch: repeat the MNIST pixel conversion on the last training batch left in the
# namespace, then one-hot encode it to channels-first (256 pixel-value channels).
x = x_input
x1 = x.mul(255).long().squeeze().to(device)
x2 = F.one_hot(x1, num_classes=256).permute(0, 3, 1, 2)

In [None]:
# Display the integer pixel map of the first image in the converted batch.
x1[0]