In [11]:
import argparse

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm


# from apex import amp


from pixelsnail import PixelSNAIL
import os
import pickle
from collections import namedtuple

import torch
from torch.utils.data import Dataset
from torchvision import datasets
import lmdb
from scheduler import CycleScheduler

In [8]:
class LMDBDataset(Dataset):
    def __init__(self, path):
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            self.length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            key = str(index).encode('utf-8')

            row = pickle.loads(txn.get(key))

        return torch.from_numpy(row.top), torch.from_numpy(row.bottom), row.filename

In [9]:
def train( epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()

    for i, (top, bottom, label) in enumerate(loader):
        model.zero_grad()

        top = top.to(device)

        if hier == 'top':
            target = top
            out, _ = model(top)

        elif hier == 'bottom':
            bottom = bottom.to(device)
            target = bottom
            out, _ = model(bottom, condition=top)

        loss = criterion(out, target)
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        _, pred = out.max(1)
        correct = (pred == target).float()
        accuracy = correct.sum() / target.numel()

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
                f'acc: {accuracy:.5f}; lr: {lr:.5f}'
            )
        )


class PixelTransform:
    def __init__(self):
        pass

    def __call__(self, input):
        ar = np.array(input)

        return torch.from_numpy(ar).long()

In [12]:
batch=32
epoch=420
hier='top'
lr=3e-4
channel=256
n_res_block=4
n_res_channel=256
n_out_res_block=0
n_cond_res_block=3
dropout=0.1
amp='O0'
sched=None
ckpt=None
path=None

device = 'cpu'

dataset = LMDBDataset(path)
loader = DataLoader(
    dataset, batch_size=batch, shuffle=True, num_workers=4, drop_last=True
)

if hier == 'top':
    model = PixelSNAIL(
        [32, 32],
        512,
        channel,
        5,
        4,
        n_res_block,
        n_res_channel,
        dropout=dropout,
        n_out_res_block=n_out_res_block,
    )

elif hier == 'bottom':
    model = PixelSNAIL(
        [64, 64],
        512,
        channel,
        5,
        4,
        n_res_block,
        n_res_channel,
        attention=False,
        dropout=dropout,
        n_cond_res_block=n_cond_res_block,
        cond_res_channel=n_res_channel,
    )


if 'model' in ckpt:
    model.load_state_dict(torch.load(ckpt))

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

model = model.to(device)

scheduler = CycleScheduler(
    optimizer, lr, n_iter=len(loader) * epoch, momentum=None
)

TypeError: 'path' argument required

In [None]:
for i in range(epoch):
    train( i, loader, model, optimizer, scheduler, device)
    torch.save(
        {'model': model.module.state_dict(), ': ,
        f'checkpoint/pixelsnail_{hier}_{str(i + 1).zfill(3)}.pt',
    )