In [26]:
import argparse

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from apex import amp

except ImportError:
    amp = None

from dataset import LMDBDataset
from pixelsnail import PixelSNAIL
from scheduler import CycleScheduler
from torch.utils import data

In [33]:
def train( epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()

    for i, (top, bottom, label) in enumerate(loader):
        model.zero_grad()

        top = top.to(device)

        if hier == 'top':
            target = top
            out, _ = model(top)

        elif hier == 'bottom':
            bottom = bottom.to(device)
            target = bottom
            out, _ = model(bottom, condition=top)

        loss = criterion(out, target)
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()

        _, pred = out.max(1)
        correct = (pred == target).float()
        accuracy = correct.sum() / target.numel()

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
                f'acc: {accuracy:.5f}; lr: {lr:.5f}'
            )
        )
        
def evaluate( epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (top, bottom, label) in tqdm(enumerate(loader)):

            top = top.to(device)
            bottom = bottom.to(device)

            top=torch.squeeze(top,[1])
            bottom=torch.squeeze(bottom, [1])


            if hier == 'top':
                target = top
                out, _ = model(top)

            elif hier == 'bottom':
                target = bottom
                out, _ = model(bottom, condition=top)

            loss = criterion(out, target)

            _, pred = out.max(1)
            correct = (pred == target).float()
            accuracy = correct.sum() / target.numel()

            lr = optimizer.param_groups[0]['lr']

            print(f'Test epoch: {epoch + 1}; loss: {loss.item():.5f}; acc: {accuracy:.5f}; lr: {lr:.5f}')

            return round(loss.item()), accuracy



class PixelTransform:
    def __init__(self):
        pass

    def __call__(self, input):
        ar = np.array(input)

        return torch.from_numpy(ar).long()
    
def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)

In [40]:
n_res_block=4
n_res_channel=256
n_out_res_block=0
n_cond_res_block=3
amp='O0'
channel=256
dropout=0.1

device = 'cuda'

lr=3e-4
# hier='top'
hier='bottom'
epoch=420
batch=32
val_split=0.15
n_gpu=1

sched='cycle'

dataset_path='runs/embs_emb_dim_64_n_embed_512_bc_left_4x_768'

dataset = LMDBDataset(dataset_path)

train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_dataset_len, test_dataset_len],
                                                            # generator=torch.Generator().manual_seed(seed)
                                                           )

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

# fails when numw_workers!=0
train_loader = DataLoader(
    train_dataset, batch_size=batch // n_gpu, sampler=train_sampler, num_workers=0,
)
test_loader = DataLoader(
    test_dataset, batch_size=batch // n_gpu, sampler=test_sampler,num_workers=0,
)


if hier == 'top':
    model = PixelSNAIL(
        [32, 32],
        512,
        channel,
        5,
        4,
        n_res_block,
        n_res_channel,
        dropout=dropout,
        n_out_res_block=n_out_res_block,
    )

elif hier == 'bottom':
    model = PixelSNAIL(
        [64, 64],
        512,
        channel,
        5,
        4,
        n_res_block,
        n_res_channel,
        attention=False,
        dropout=dropout,
        n_cond_res_block=n_cond_res_block,
        cond_res_channel=n_res_channel,
    )

# if 'model' in ckpt:
#     model.load_state_dict(torch.load(model_path))

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# if amp is not None:
#     model, optimizer = amp.initialize(model, optimizer, opt_level=amp)

model = nn.DataParallel(model)
model = model.to(device)

scheduler = None
if sched == 'cycle':
    scheduler = CycleScheduler(
        optimizer, lr, n_iter=len(loader) * epoch, momentum=None
    )

In [49]:
folder_name=f'runs/pixelsnail_emb_dim_64_n_embed_512_bc_left_4x_768/{hier}/'

for i in range(epoch):
    train(i, loader, model, optimizer, scheduler, device)
    loss, acc = evaluate( i, test_loader, model, optimizer, scheduler, device)
    torch.save(model.state_dict(), f'{folder_name}/{str(i + 1)}_pixelsnail_{hier}_loss_{loss}_test_acc_{acc}.pt')

epoch: 1; loss: 5.29447; acc: 0.02037; lr: 0.00002: 100%|██████████████████████████████| 56/56 [17:21<00:00, 18.60s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 1; loss: 5.17356; acc: 0.02817; lr: 0.00002



epoch: 2; loss: 5.14059; acc: 0.02016; lr: 0.00002: 100%|██████████████████████████████| 56/56 [18:40<00:00, 20.00s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 2; loss: 5.18215; acc: 0.01852; lr: 0.00002



epoch: 3; loss: 5.14536; acc: 0.01131; lr: 0.00002: 100%|██████████████████████████████| 56/56 [22:30<00:00, 24.12s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 3; loss: 5.08532; acc: 0.02492; lr: 0.00002



epoch: 4; loss: 5.03521; acc: 0.02023; lr: 0.00002: 100%|██████████████████████████████| 56/56 [17:57<00:00, 19.24s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 4; loss: 5.06220; acc: 0.02106; lr: 0.00002



epoch: 5; loss: 4.95820; acc: 0.03265; lr: 0.00003: 100%|██████████████████████████████| 56/56 [14:36<00:00, 15.65s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 5; loss: 4.86964; acc: 0.03857; lr: 0.00003



epoch: 6; loss: 4.83153; acc: 0.03409; lr: 0.00003: 100%|██████████████████████████████| 56/56 [12:58<00:00, 13.90s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 6; loss: 5.12244; acc: 0.01340; lr: 0.00003



epoch: 7; loss: 4.98272; acc: 0.02036; lr: 0.00003: 100%|██████████████████████████████| 56/56 [12:08<00:00, 13.00s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 7; loss: 4.93303; acc: 0.03484; lr: 0.00003



epoch: 8; loss: 5.14516; acc: 0.01282; lr: 0.00003: 100%|██████████████████████████████| 56/56 [15:42<00:00, 16.83s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 8; loss: 4.93926; acc: 0.02746; lr: 0.00003



epoch: 9; loss: 4.89757; acc: 0.03419; lr: 0.00003: 100%|██████████████████████████████| 56/56 [18:11<00:00, 19.49s/it]
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
  0%|                                                                                            | 0/9 [00:00<?, ?it/s]

Test epoch: 9; loss: 4.91536; acc: 0.03450; lr: 0.00003



epoch: 10; loss: 4.88815; acc: 0.03467; lr: 0.00003:   4%|█                             | 2/56 [01:15<33:49, 37.58s/it]


KeyboardInterrupt: 

# Mnist

In [None]:
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

from pixelsnail import PixelSNAIL

In [None]:
def train(epoch, loader, model, optimizer, device):
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out = model(img)
        loss = criterion(out, img)
        loss.backward()

        optimizer.step()

        _, pred = out.max(1)
        correct = (pred == img).float()
        accuracy = correct.sum() / img.numel()

        loader.set_description(
            (f'epoch: {epoch + 1}; loss: {loss.item():.5f}; ' f'acc: {accuracy:.5f}')
        )


class PixelTransform:
    def __init__(self):
        pass

    def __call__(self, input):
        ar = np.array(input)

        return torch.from_numpy(ar).long()


if __name__ == '__main__':
    device = 'cuda'
    epoch = 10

    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for i in range(10):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(), f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt')