In [1]:
import argparse

import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

try:
    from apex import amp

except ImportError:
    amp = None

from dataset import LMDBDataset
from pixelsnail import PixelSNAIL
from scheduler import CycleScheduler
from torch.utils import data

# from torchsummary import summary

In [2]:
def train( epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)
    
    criterion = nn.CrossEntropyLoss()
    for i, (top, bottom, label) in enumerate(loader):
        model.zero_grad()

        top = top.to(device)
        bottom = bottom.to(device)

        print(bottom.shape)
        
        if hier == 'top':
            target = top
            out, _ = model(top)

        elif hier == 'bottom':
            print('-1')
            target = bottom
            out, _ = model(bottom, condition=top)
        
            print('0')
        
        loss = criterion(out, target)
        print(f'computed loss {loss}')
        loss.backward()

        if scheduler is not None:
            scheduler.step()
        optimizer.step()
        
        print('1')

        _, pred = out.max(1)
        correct = (pred == target).float()
        accuracy = correct.sum() / target.numel()
        print('2')

        lr = optimizer.param_groups[0]['lr']

        loader.set_description(
            (
                f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
                f'acc: {accuracy:.5f}; lr: {lr:.5f}'
            )
        )
    print(3)
    return loss.item(), accuracy
        
def evaluate( epoch, loader, model, optimizer, scheduler, device):
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()
    with torch.no_grad():
        for i, (top, bottom, label) in tqdm(enumerate(loader)):

            top = top.to(device)
            bottom = bottom.to(device)

            top=torch.squeeze(top,[1])
            bottom=torch.squeeze(bottom, [1])


            if hier == 'top':
                target = top
                out, _ = model(top)

            elif hier == 'bottom':
                target = bottom
                out, _ = model(bottom, condition=top)

            loss = criterion(out, target)

            _, pred = out.max(1)
            correct = (pred == target).float()
            accuracy = correct.sum() / target.numel()

            lr = optimizer.param_groups[0]['lr']

            print(f'Test epoch: {epoch + 1}; loss: {loss.item():.5f}; acc: {accuracy:.5f}; lr: {lr:.5f}')

            return loss.item(), accuracy



class PixelTransform:
    def __init__(self):
        pass

    def __call__(self, input):
        ar = np.array(input)

        return torch.from_numpy(ar).long()
    
def data_sampler(dataset, shuffle, distributed):
    if distributed:
        return data.distributed.DistributedSampler(dataset, shuffle=shuffle)

    if shuffle:
        return data.RandomSampler(dataset)

    else:
        return data.SequentialSampler(dataset)

In [4]:
device = 'cuda'

lr=3e-4
# hier='top'
hier='bottom'
epoch=420
batch=4
val_split=0.15
n_gpu=1

sched='cycle'

dataset_path='runs/embs_emb_dim_2_n_embed_512_bc_left_4x_768'

dataset = LMDBDataset(dataset_path)

train_dataset_len = int(len(dataset) * (1 - val_split))
test_dataset_len = len(dataset) - train_dataset_len

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_dataset_len, test_dataset_len],
                                                            # generator=torch.Generator().manual_seed(seed)
                                                           )

train_sampler = data_sampler(train_dataset, shuffle=True, distributed=False)
test_sampler = data_sampler(test_dataset, shuffle=True, distributed=False)

# fails when numw_workers!=0
train_loader = DataLoader(
    train_dataset, batch_size=batch // n_gpu, sampler=train_sampler, num_workers=0,
)
test_loader = DataLoader(
    test_dataset, batch_size=batch // n_gpu, sampler=test_sampler,num_workers=0,
)


channel=256
dropout=0.1
    
    
if hier == 'top':
    
    # original config
    # n_res_channel=256
    # n_out_res_block=0
    # n_res_block=4
    
    # imagenet config
    n_out_res_block=20
    n_res_channel=2048
    n_res_block=20

    model = PixelSNAIL(
        shape=[32, 32],
        n_class=512,
        channel=channel,
        kernel_size=5,
        n_block=4,
        n_res_block=n_res_block,
        res_channel=n_res_channel,
        dropout=dropout,
        cond_res_kernel=3,
        attention=True,
        cond_res_channel=0,
        n_cond_res_block=0,
        n_out_res_block=n_out_res_block,
    )

elif hier == 'bottom':
    
    # original config
    # n_res_channel=256
    # n_cond_res_block=3
    # n_res_block=4
    
    # imagenet config
    n_cond_res_block=20
    n_res_channel=1024
    n_res_block=20
    
    model = PixelSNAIL(
        shape=[64, 64],
        n_class=512,
        channel=channel,
        kernel_size=5,
        n_block=4,
        n_res_block=n_res_block,
        res_channel=n_res_channel,
        dropout=dropout,
        cond_res_kernel=3,
        attention=False,
        cond_res_channel=n_res_channel,
        n_cond_res_block=n_cond_res_block,
        n_out_res_block=0,
 
    )

# if 'model' in ckpt:
#     model.load_state_dict(torch.load(model_path))

model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# if amp is not None:
#     model, optimizer = amp.initialize(model, optimizer, opt_level=amp)

# model = nn.DataParallel(model)
# model = model.to(device)

scheduler = None
if sched == 'cycle':
    scheduler = CycleScheduler(
        optimizer, lr, n_iter=len(train_loader) * epoch, momentum=None
    )

In [None]:
folder_name=f'runs/pixelsnail_emb_dim_2_n_embed_512_bc_left_4x_768_zeros/{hier}/'

for i in range(epoch):
    train_loss, train_acc=train(i, train_loader, model, optimizer, scheduler, device)
    loss, acc = evaluate( i, test_loader, model, optimizer, scheduler, device)
    torch.save(model.state_dict(), f'{folder_name}/{str(i + 1)}_pixelsnail_{hier}_train_loss_{train_loss:2f}_acc_{train_acc:2f}_test_loss_{loss:2f}_acc_{acc:2f}.pt')

  0%|                                                                                          | 0/383 [00:00<?, ?it/s]

torch.Size([4, 64, 64])
-1


  return F.conv2d(input, weight, bias, self.stride,


In [None]:
folder_name=f'runs/pixelsnail_emb_dim_64_n_embed_512_bc_left_4x_768/{hier}/'

torch.save(model.state_dict(), f'{folder_name}/{str(i + 1)}_pixelsnail_{hier}_loss_{loss}_test_acc_{acc}.pt')

# Mnist

In [None]:
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

from pixelsnail import PixelSNAIL

In [None]:
def train(epoch, loader, model, optimizer, device):
    loader = tqdm(loader)

    criterion = nn.CrossEntropyLoss()

    for i, (img, label) in enumerate(loader):
        model.zero_grad()

        img = img.to(device)

        out = model(img)
        loss = criterion(out, img)
        loss.backward()

        optimizer.step()

        _, pred = out.max(1)
        correct = (pred == img).float()
        accuracy = correct.sum() / img.numel()

        loader.set_description(
            (f'epoch: {epoch + 1}; loss: {loss.item():.5f}; ' f'acc: {accuracy:.5f}')
        )


class PixelTransform:
    def __init__(self):
        pass

    def __call__(self, input):
        ar = np.array(input)

        return torch.from_numpy(ar).long()


if __name__ == '__main__':
    device = 'cuda'
    epoch = 10

    dataset = datasets.MNIST('.', transform=PixelTransform(), download=True)
    loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

    model = PixelSNAIL([28, 28], 256, 128, 5, 2, 4, 128)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    for i in range(10):
        train(i, loader, model, optimizer, device)
        torch.save(model.state_dict(), f'checkpoint/mnist_{str(i + 1).zfill(3)}.pt')