In [6]:
import argparse

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from apex import amp

except ImportError:
    amp = None

# from dataset import LMDBDataset
from pixelsnail import PixelSNAIL
from scheduler import CycleScheduler

In [None]:
import argparse
import pickle

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import lmdb
from tqdm import tqdm

from dataset import ImageFileDataset, CodeRow
from vqvae import VQVAE


def extract(lmdb_env, loader, model, device):
    """Encode every batch from ``loader`` with ``model`` and store the codes in LMDB.

    Each sample is written as a pickled ``CodeRow(top, bottom, filename)`` under
    the key ``str(index)`` encoded as UTF-8, where ``index`` counts samples from
    0 in loader order. After the last batch the total number of samples is
    stored under the key ``b'length'``. All writes happen inside one write
    transaction opened on ``lmdb_env``.

    Args:
        lmdb_env: Object whose ``begin(write=True)`` returns a context-managed
            transaction exposing ``put(key, value)``.
        loader: Iterable yielding ``(img, _, filename)`` batches; ``filename``
            is iterated in step with the per-sample codes of the batch.
        model: Object whose ``encode(img)`` returns a 5-tuple; the last two
            items are used as the top and bottom code tensors.
        device: Device the image batch is moved to before encoding.
    """
    index = 0

    # Encoding is inference only: with autograd disabled no graph is recorded
    # for the batches, which the original code built and then threw away.
    with lmdb_env.begin(write=True) as txn, torch.no_grad():
        pbar = tqdm(loader)

        for img, _, filename in pbar:
            img = img.to(device)

            _, _, _, id_t, id_b = model.encode(img)
            id_t = id_t.detach().cpu().numpy()
            id_b = id_b.detach().cpu().numpy()

            for file, top, bottom in zip(filename, id_t, id_b):
                row = CodeRow(top=top, bottom=bottom, filename=file)
                txn.put(str(index).encode('utf-8'), pickle.dumps(row))
                index += 1

            # Refresh the progress label once per batch instead of once per row.
            pbar.set_description(f'inserted: {index}')

        # Record how many rows were written so readers can size the dataset.
        txn.put('length'.encode('utf-8'), str(index).encode('utf-8'))


if __name__ == '__main__':
    # Command-line entry point: encode the images under `path` with a trained
    # VQVAE checkpoint and write the resulting codes into an LMDB database.
    parser = argparse.ArgumentParser()
    parser.add_argument('--size', type=int, default=256)
    # Both options are consumed unconditionally below (torch.load / lmdb.open),
    # so a missing value is rejected by argparse instead of failing later.
    parser.add_argument('--ckpt', type=str, required=True)
    parser.add_argument('--name', type=str, required=True)
    parser.add_argument('path', type=str)

    args = parser.parse_args()

    # Prefer the GPU, but stay usable on machines without CUDA.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Resize + center-crop to a square of `--size`, convert to a tensor, then
    # normalize each of the three channels with mean 0.5 and std 0.5.
    transform = transforms.Compose(
        [
            transforms.Resize(args.size),
            transforms.CenterCrop(args.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dataset = ImageFileDataset(args.path, transform=transform)
    # shuffle=False keeps the stored row indices in dataset order.
    loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=4)

    model = VQVAE()
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(args.ckpt, map_location=device))
    model = model.to(device)
    model.eval()

    # 100 GiB, passed to lmdb.open as map_size.
    map_size = 100 * 1024 * 1024 * 1024

    # The environment is a context manager; leaving the block closes it even
    # when extraction raises.
    with lmdb.open(args.name, map_size=map_size) as env:
        extract(env, loader, model, device)

In [7]:
class PixelTransform:
    """Callable that converts array-like input into an int64 ``torch`` tensor."""

    def __call__(self, input):
        """Copy ``input`` into a NumPy array and return it as a ``long`` tensor."""
        as_array = np.array(input)
        as_tensor = torch.from_numpy(as_array)

        return as_tensor.long()

def train(args, epoch, loader, model, optimizer, scheduler, device):
    """Run one pass over ``loader``, updating ``model`` after every batch.

    For ``args.hier == 'top'`` the model is called on the top codes and trained
    to predict them; for ``args.hier == 'bottom'`` it is called on the bottom
    codes with the top codes passed as ``condition``. The loss is
    ``nn.CrossEntropyLoss`` between the first element of the model output and
    the selected target. Loss, accuracy and the first parameter group's
    learning rate are shown in the progress-bar description.

    Args:
        args: Namespace-like object; only ``args.hier`` is read here.
        epoch: Zero-based epoch number, used only for the progress label.
        loader: Iterable yielding ``(top, bottom, label)`` batches; the third
            item is ignored.
        model: Callable returning a 2-tuple whose first item holds the class
            scores along dimension 1.
        optimizer: Optimizer stepped once per batch.
        scheduler: Optional object with ``step()``; when not ``None`` it is
            stepped once per batch, before the optimizer.
        device: Device the code tensors are moved to.

    Raises:
        ValueError: If ``args.hier`` is neither ``'top'`` nor ``'bottom'``.
    """
    # Fail fast with a clear message: previously an unsupported value skipped
    # both branches and surfaced later as an unbound-variable error.
    if args.hier not in ('top', 'bottom'):
        raise ValueError(
            f"unsupported hier {args.hier!r}; expected 'top' or 'bottom'"
        )

    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()

    for top, bottom, _ in loader:
        model.zero_grad()

        top = top.to(device)

        if args.hier == 'top':
            target = top
            out, _ = model(top)

        else:
            # Bottom prior: predict the bottom codes conditioned on the top codes.
            bottom = bottom.to(device)
            target = bottom
            out, _ = model(bottom, condition=top)

        loss = criterion(out, target)
        loss.backward()

        # The scheduler is stepped before the optimizer, as in the original code.
        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        # Accuracy: fraction of positions where the highest-scoring class
        # (along dimension 1) equals the target.
        _, pred = out.max(1)
        correct = (pred == target).float()
        accuracy = correct.sum() / target.numel()

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
                f'acc: {accuracy:.5f}; lr: {lr:.5f}'
            )
        )

In [None]:

# Training entry point for the PixelSNAIL prior: parse options, build the
# dataset and model, then train and checkpoint once per epoch.
import os

# This cell uses LMDBDataset, but the notebook's import of it is commented out.
from dataset import LMDBDataset

parser = argparse.ArgumentParser()
parser.add_argument('--batch', type=int, default=32)
parser.add_argument('--epoch', type=int, default=420)
parser.add_argument('--hier', type=str, default='top')
parser.add_argument('--lr', type=float, default=3e-4)
parser.add_argument('--channel', type=int, default=256)
parser.add_argument('--n_res_block', type=int, default=4)
parser.add_argument('--n_res_channel', type=int, default=256)
parser.add_argument('--n_out_res_block', type=int, default=0)
parser.add_argument('--n_cond_res_block', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--amp', type=str, default='O0')
parser.add_argument('--sched', type=str)
parser.add_argument('--ckpt', type=str)
parser.add_argument('path', type=str)

# The parse step was missing, so every later use of `args` raised NameError.
# None means "read sys.argv". NOTE(review): inside a notebook kernel sys.argv
# holds the kernel's own flags, so set cli_args to an explicit list there,
# e.g. ['<lmdb path>', '--hier', 'top'].
cli_args = None
args = parser.parse_args(cli_args)

device = 'cuda'

dataset = LMDBDataset(args.path)
loader = DataLoader(
    dataset, batch_size=args.batch, shuffle=True, num_workers=4, drop_last=True
)

ckpt = {}

if args.ckpt is not None:
    ckpt = torch.load(args.ckpt)
    # Resuming replaces all command-line options with the ones saved in the
    # checkpoint (the loader above was already built from the CLI values).
    args = ckpt['args']

# NOTE(review): the positional PixelSNAIL arguments (shape, 512, 5, 4) are
# passed through unchanged; their meaning is defined in pixelsnail.py.
if args.hier == 'top':
    model = PixelSNAIL(
        [32, 32],
        512,
        args.channel,
        5,
        4,
        args.n_res_block,
        args.n_res_channel,
        dropout=args.dropout,
        n_out_res_block=args.n_out_res_block,
    )

elif args.hier == 'bottom':
    model = PixelSNAIL(
        [64, 64],
        512,
        args.channel,
        5,
        4,
        args.n_res_block,
        args.n_res_channel,
        attention=False,
        dropout=args.dropout,
        n_cond_res_block=args.n_cond_res_block,
        cond_res_channel=args.n_res_channel,
    )

else:
    # Previously an unsupported value left `model` undefined and failed below
    # with a NameError.
    raise ValueError(f"unsupported hier {args.hier!r}; expected 'top' or 'bottom'")

if 'model' in ckpt:
    model.load_state_dict(ckpt['model'])

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)

# `amp` is None when apex could not be imported; mixed precision is then skipped.
if amp is not None:
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp)

model = nn.DataParallel(model)
model = model.to(device)

scheduler = None
if args.sched == 'cycle':
    # train() steps the scheduler once per batch, hence batches-per-epoch
    # times the number of epochs.
    scheduler = CycleScheduler(
        optimizer, args.lr, n_iter=len(loader) * args.epoch, momentum=None
    )

# Create the output directory up front so the first save cannot fail after a
# whole epoch of training.
os.makedirs('checkpoint', exist_ok=True)

for i in range(args.epoch):
    train(args, i, loader, model, optimizer, scheduler, device)
    # Save the wrapped module's weights (not the DataParallel wrapper) together
    # with the options needed to rebuild the model.
    torch.save(
        {'model': model.module.state_dict(), 'args': args},
        f'checkpoint/pixelsnail_{args.hier}_{str(i + 1).zfill(3)}.pt',
    )