In [1]:
import sys
sys.path.append('../')

In [2]:
from modules.modules import VectorQuantizedVAE

In [3]:
import numpy as np
import torch

In [4]:
%matplotlib inline
import matplotlib.pyplot as plt
from datasets import datasets

In [5]:
import torch.optim as optim

from tqdm import tqdm_notebook as tqdm

In [6]:
from torch import nn
class Classifier(nn.Module):
    def __init__(self, in_f, out_f):
        super(Classifier, self).__init__()
        self.fc = nn.Linear(in_f, out_f)
        self.loss = nn.CrossEntropyLoss()
    
    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        x = self.fc(x)
        return x

In [7]:
model = VectorQuantizedVAE(3, 256, 256)

In [8]:
model.load_state_dict(torch.load('/home/genta/data2/vqvae/models/vqvae_im128_k256/best.pt'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [9]:
dataset = datasets.get_dataset('imagenet', '/home/genta/dataset/', image_size=128)

In [10]:
train_dataset = dataset['train']
test_dataset = dataset['test']
valid_dataset = dataset['valid']
num_channels = dataset['num_channels']

In [11]:
import multiprocessing as mp

In [12]:
batch_size = 256
num_workers = mp.cpu_count() - 1
train_loader = torch.utils.data.DataLoader(train_dataset,
    batch_size=batch_size, shuffle=False,
    num_workers=num_workers, pin_memory=True)
valid_loader = torch.utils.data.DataLoader(valid_dataset,
    batch_size=batch_size, shuffle=False, drop_last=True,
    num_workers=num_workers, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test_dataset,
    batch_size=16, shuffle=True)

In [13]:
num_workers

11

In [14]:
predictor = Classifier(int(256*32*32), len(train_dataset.classes))

In [15]:
predictor.cuda()
model.cuda()

VectorQuantizedVAE(
  (codebook): VQEmbedding(
    (embedding): Embedding(256, 256)
  )
  (encoder): Sequential(
    (0): Conv2d(3, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace)
    (3): Conv2d(256, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): ResBlock(
      (block): Sequential(
        (0): ReLU(inplace)
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ReLU(inplace)
        (4): Conv2d(256, 256, kernel_size=(1, 1), stride=(1, 1))
        (5): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (5): ResBlock(
      (block): Sequential(
        (0): ReLU(inplace)
        (1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (2): BatchNo

In [16]:
optimizer = optim.SGD(predictor.parameters(), lr=0.001, momentum=0.9)

In [17]:
loss_fn = nn.CrossEntropyLoss()

In [18]:

def train(data_loader, model, clfy, optimizer, args=None, writer=None, loss_fn=None):
    if loss_fn is None:
        loss_fn = nn.CrossEntropyLoss()
    for images, labels in tqdm(data_loader, total=len(data_loader)):
        # print(images.shape)
        images = images.to('cuda')
        labels = labels.to('cuda')

        optimizer.zero_grad()
        with torch.no_grad():
            latents = model.encode(images)
            latents = model.codebook.embedding(latents).permute(0, 3, 1, 2)
        out = clfy(latents)
        loss = loss_fn(out, labels)
        loss.backward()

#         if writer is not None:
#             # Logs
#             writer.add_scalar('loss/train/reconstruction', loss_recons.item(), args.steps)
#             writer.add_scalar('loss/train/quantization', loss_vq.item(), args.steps)

        optimizer.step()
#         args.steps += 1

In [19]:
%%time
train(train_loader, model, predictor, optimizer, loss_fn=loss_fn)

torch.save(predictor.state_dict(), './clfy.model')

HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




CPU times: user 18min 5s, sys: 17min 57s, total: 36min 2s
Wall time: 29min 39s


In [20]:
predictor.load_state_dict(torch.load('./clfy.model'))

IncompatibleKeys(missing_keys=[], unexpected_keys=[])

In [21]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [22]:
def test(data_loader, model, clfy, args, writer=None):
    with torch.no_grad():
        loss_total = 0.
        acc_total = 0.
        for images, labels in tqdm(data_loader, total=len(data_loader)):
            # print(images.shape)
            images = images.to('cuda')
            labels = labels.to('cuda')

            latents = model.encode(images)
            latents = model.codebook.embedding(latents).permute(0, 3, 1, 2)
            out = clfy(latents)
            loss_total += loss_fn(out, labels)
            acc, = accuracy(out, labels)
            acc_total += acc

        loss_total /= len(data_loader)
        acc_total /= len(data_loader)
        
    

    return loss_total.item(), acc_total.item()

In [23]:
result = test(test_loader, model, predictor, None)

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))




In [24]:
result

(588.74560546875, 0.09999999403953552)

In [25]:
result_train = test(train_loader, model, predictor, None)

HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))






In [26]:
result_train

(594.8475952148438, 0.10146103799343109)

In [27]:
%%time
n_epoch = 9
for epoch in range(n_epoch):
    print(epoch)
    train(train_loader, model, predictor, optimizer, loss_fn=loss_fn)

0


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




1


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




2


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




3


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




4


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




5


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




6


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




7


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




8


HBox(children=(IntProgress(value=0, max=5005), HTML(value='')))




CPU times: user 2h 11min 36s, sys: 2h 44min 10s, total: 4h 55min 46s
Wall time: 4h 26min 26s


In [29]:
torch.save(predictor.state_dict(), './clfy10epoch.model')

In [30]:
result = test(test_loader, model, predictor, None)

HBox(children=(IntProgress(value=0, max=3125), HTML(value='')))




In [31]:
result

(583.0155639648438, 0.1939999908208847)