In [0]:
!pip install -U -q PyDrive

[?25l[K     |▎                               | 10kB 23.3MB/s eta 0:00:01[K     |▋                               | 20kB 29.4MB/s eta 0:00:01[K     |█                               | 30kB 34.5MB/s eta 0:00:01[K     |█▎                              | 40kB 34.9MB/s eta 0:00:01[K     |█▋                              | 51kB 37.9MB/s eta 0:00:01[K     |██                              | 61kB 41.9MB/s eta 0:00:01[K     |██▎                             | 71kB 32.8MB/s eta 0:00:01[K     |██▋                             | 81kB 34.1MB/s eta 0:00:01[K     |███                             | 92kB 36.4MB/s eta 0:00:01[K     |███▎                            | 102kB 32.9MB/s eta 0:00:01[K     |███▋                            | 112kB 32.9MB/s eta 0:00:01[K     |████                            | 122kB 32.9MB/s eta 0:00:01[K     |████▎                           | 133kB 32.9MB/s eta 0:00:01[K     |████▋                           | 143kB 32.9MB/s eta 0:00:01[K     |█████        

In [0]:
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

In [0]:
# 1. Authenticate and create the PyDrive client.
auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

In [3]:
# choose a local (colab) directory to store the data.
local_download_path = os.path.expanduser('data')
try:
    os.makedirs(local_download_path)
except:
    print("error")

# 2. Auto-iterate using the query syntax
#    https://developers.google.com/drive/v2/web/search-parameters
file_list = drive.ListFile(
    {'q': "'18qyv3XEVWeQKyQNeCFfpUu--PkPaRdpM' in parents"}).GetList()

for f in file_list:
    # 3. Create & download by id.
    print('title: %s, id: %s' % (f['title'], f['id']))
    fname = os.path.join(local_download_path, f['title'])
    print('downloading to {}'.format(fname))
    f_ = drive.CreateFile({'id': f['id']})
    f_.GetContentFile(fname)

error
title: dsprites_ndarray_64x64.npz, id: 1sLoovx3oF6XYZ4m7Ol99Wb4ykBzyaKDv
downloading to data/dsprites_ndarray_64x64.npz
title: datasets.py, id: 1M_qXS4b7214yBATneK6GS9dgee58rAyo
downloading to data/datasets.py


In [0]:
! mv data/datasets.py datasets.py

In [0]:
import torch
from torch import nn, optim
from torch.nn import functional as F
import torchvision
from torchvision.utils import save_image
import numpy as np
import os
import datasets

In [0]:
class CONFIG(object):
    def __init__(self):
        self.image_size = 64
        self.device = torch.device("cuda") # or "cpu"
        self.batch_size = 64
        self.num_classes = 10
        self.latten_size = 10
        self.beta = 4
        self.use_BN = False
        self.version = 'B-VAE'
        self.KL_penalty = 'abs'
        self.use_label = False
        self.C_max = 20
        self.iter_increase_C = 2e4
        self.gamma = 10
        
config = CONFIG()
assert config.version in ['VAE', 'B-VAE', 'U-VAE']
assert config.KL_penalty in ['relu', 'abs']

In [0]:
dsprites_dataloader = datasets.get_dsprites_dataloader(batch_size=config.batch_size, 
                                                       path_to_data='data/dsprites_ndarray_64x64.npz', subsample=-1)
                                                      #subsample=256000)

dataloader = dsprites_dataloader

In [0]:
class VAE(torch.nn.Module):
    
    def __init__(self):
        super().__init__()
        # encoder
        self._conv1    = nn.Conv2d(1, 32, 4, stride=2, padding=1)
        self._conv1_BN = nn.BatchNorm2d(num_features=32)
        self._conv2    = nn.Conv2d(32, 32, 4, stride=2, padding=1)
        self._conv2_BN = nn.BatchNorm2d(num_features=32)
        self._conv3    = nn.Conv2d(32, 32, 4, stride=2, padding=1)
        self._conv3_BN = nn.BatchNorm2d(num_features=32)
        self._conv4    = nn.Conv2d(32, 32, 4, stride=2, padding=1)
        self._conv4_BN = nn.BatchNorm2d(num_features=32)
        self._fc5     = nn.Linear(512, 256)
        self._fc5_BN  = nn.BatchNorm1d(num_features=256)
        self._fc6     = nn.Linear(256, 256)
        self._fc6_BN  = nn.BatchNorm1d(num_features=256)
        self._fc71     = nn.Linear(256, config.latten_size)
        self._fc72     = nn.Linear(256, config.latten_size)
        self._fc8      = nn.Linear(config.latten_size * 2, config.latten_size * 2)

        # decoder
        if config.use_label:
            self.fc7_   = nn.Linear(config.latten_size + config.num_classes, 256)
        else:
            self.fc7_   = nn.Linear(config.latten_size, 256)
        self.fc7_BN     = nn.BatchNorm1d(num_features=256)
        self.fc6_      = nn.Linear(256, 256)
        self.fc6_BN     = nn.BatchNorm1d(num_features=256)
        self.fc5_      = nn.Linear(256, 512)
        self.fc5_BN     = nn.BatchNorm1d(num_features=512)
        
        self.conv4_    = nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1)
        self.conv4_BN  = nn.BatchNorm2d(num_features=32)
        self.conv3_    = nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1)
        self.conv3_BN  = nn.BatchNorm2d(num_features=32)
        self.conv2_    = nn.ConvTranspose2d(32, 32, 4, stride=2, padding=1)
        self.conv2_BN  = nn.BatchNorm2d(num_features=32)
        self.conv1_    = nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1)
        
        self.mask = torch.ones((config.latten_size * 2, config.latten_size * 2), device=config.device)
        for i in range(config.latten_size):
            self.mask[i, i] = self.mask[i, i + config.latten_size] = 0
            self.mask[i + config.latten_size, i + config.latten_size] = self.mask[i + config.latten_size, i] = 0
        self.mask.to(config.device)
        
    def encode(self, x):
        h = self._conv1(x)
        h = self._conv1_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self._conv2(h)
        h = self._conv2_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self._conv3(h)
        h = self._conv3_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self._conv4(h)
        h = self._conv4_BN(h) if config.use_BN else h
        h = F.relu(h)
            
        h = h.view(-1, 512)
        h = self._fc5(h)
        if config.use_BN:
            h = self._fc5_BN(h)
        h = F.relu(h)
        
        h = self._fc6(h)
        if config.use_BN:
            h = self._fc6_BN(h)
        h = F.relu(h)
        
        mean = self._fc71(h)
        var = self._fc72(h)

        return mean, var, h
    
    
    def decode(self, z, y=None):
        if config.use_label and y is not None:
            z = torch.cat((z, y), dim=1)
        
        h = self.fc7_(z)
        h = self.fc7_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self.fc6_(h)
        h = self.fc6_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self.fc5_(h)
        h = self.fc5_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = h.view(-1, 32, 4, 4)
        
        h = self.conv4_(h)
        h = self.conv4_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self.conv3_(h)
        h = self.conv3_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        h = self.conv2_(h)
        h = self.conv2_BN(h) if config.use_BN else h
        h = F.relu(h)
        
        recon_x = self.conv1_(h)
        
        return recon_x

    
    def compute_predict(self, mean, var):
        t = torch.cat((mean, var), dim=1)
        self.mask.to(config.device)
        self._fc8.weight = torch.nn.Parameter(self._fc8.weight.to(config.device) * self.mask)
        #print(t.shape)
        pred = self._fc8(t)
        return pred
    
        
    def forward(self, x, iteration, y=None):
        mean, var, h = self.encode(x.view(-1, 1, config.image_size, config.image_size))

        var = var.exp()
        samples = torch.randn_like(mean) # sample: Normal distribution
        z = mean + samples * torch.sqrt(var)

        recon_x = self.decode(z, y)
        
        return mean, var, recon_x

In [0]:
def compute_recon_loss(x, recon_x):
    return F.binary_cross_entropy_with_logits(recon_x, x, reduction='sum') / x.shape[0]

def compute_KL_loss(mean, var):
    KL_loss = -0.5 * torch.sum(1 + torch.log(var) - torch.pow(mean, 2) - var)
    return KL_loss / mean.shape[0]

In [0]:
model = VAE().to(config.device)
optimizer = optim.Adam(model.parameters(), lr=5e-4)

global_iter = 0

def train(epoch):
    model.train()
    
    train_loss_accum = recon_loss_accum = 0
    global global_iter
    print(global_iter)
    
    C_max = torch.autograd.Variable(torch.cuda.FloatTensor([config.C_max]), requires_grad=False)
    
    for batch_idx, (X, Y) in enumerate(dataloader):
        # load data to GPU
        X = X.to(config.device)
        Y_onehot = (Y.reshape(-1, 1) == torch.arange(config.num_classes).reshape(1, config.num_classes)).float()
        Y_onehot = Y_onehot.to(config.device)
        
        # update iteration counter and reset gradients
        global_iter += 1
        optimizer.zero_grad()
        
        # forward
        mean, var, recon_x = model(X, Y_onehot)
        
        # compute losses
        recon_loss = compute_recon_loss(X, recon_x)
        KL_loss = compute_KL_loss(mean, var)
        
        loss = recon_loss + KL_loss
        if config.version == 'B-VAE':
            loss = recon_loss + KL_loss * config.beta
        elif config.version == 'U-VAE':
            C = torch.clamp(C_max * global_iter / config.iter_increase_C, 0, config.C_max)
            if config.KL_penalty == 'abs':
                t = config.gamma * torch.abs((KL_loss - C))
            elif config.KL_penalty == 'relu':
                t = config.gamma * F.relu((KL_loss - C))
            loss = recon_loss + t
        
        train_loss_accum += loss.item()
        recon_loss_accum += recon_loss.item()
        
        # caculate gradients
        loss.backward()
        
        # update the weights
        optimizer.step()
        
        if batch_idx % 200 == 0:
            print(recon_loss.item(), KL_loss.item())
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(X), len(dataloader.dataset),
                100. * batch_idx / len(dataloader),
                loss.item()))
            
    epoch_loss = train_loss_accum / len(dataloader.dataset) * config.batch_size
    epoch_recon_loss = recon_loss_accum / len(dataloader.dataset) * config.batch_size
    
    print('====> Epoch: {} Average loss: {:.4f} \tRecon Loss: {:.4f}'.format(epoch, epoch_loss, epoch_recon_loss))

for epoch in range(0, 100):
    train(epoch)

0
2423.5849609375 0.012721960432827473
537.4815673828125 0.006772113032639027
558.64013671875 0.7754918932914734
256.086181640625 7.606057167053223
131.8016815185547 11.484402656555176
115.77548217773438 10.171117782592773
116.3048324584961 9.98398208618164
116.44170379638672 9.781201362609863
109.57201385498047 9.532739639282227
113.14666748046875 9.79623031616211
116.96055603027344 9.615743637084961
122.450927734375 10.246109008789062
100.75709533691406 9.13283920288086
100.8324203491211 9.568601608276367
114.76675415039062 9.270282745361328
113.58191680908203 9.858699798583984
116.76106262207031 9.469005584716797
106.5410385131836 9.082473754882812
109.9217758178711 9.35281753540039
113.15160369873047 9.519816398620605
114.13005828857422 9.856270790100098
112.34197998046875 9.734370231628418
111.13212585449219 9.418840408325195
119.17984771728516 9.439023971557617
102.49982452392578 9.184325218200684
110.82048034667969 8.874897956848145
109.17408752441406 9.36862850189209
107.073883

In [0]:
import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

uploaded = drive.CreateFile({'title': "model-epoch-50.pt",\
                           "parents": [{"kind": "drive#fileLink","id": '1wchfRZbfcdCQEHp7UJYNijVbnGOT1BU-'}]})
torch.save(model, 'model.pt')
uploaded.SetContentFile('model.pt')
uploaded.Upload()

  "type " + obj.__name__ + ". It won't be checked "


In [0]:
model = torch.load('model.pt')

In [0]:
tmp = iter(dataloader)
batch_x, batch_y  = tmp.next()

output_dir = "VAE_results/dsprites/add_label/beta_{}/latten_size_{}".format(config.beta, config.latten_size)
print(output_dir)

def hidden_travel(images, label, neuron_id):
    batch_y_one_hot = (label.reshape(-1, 1) == torch.arange(config.num_classes).reshape(1, config.num_classes)).float()
    batch_y_one_hot = batch_y_one_hot.to(config.device)
    
    samples = images.reshape(-1, 1, config.image_size, config.image_size).to(config.device)
    num_imgs = len(samples)
    means, var, h = model.encode(samples)
    result = torch.zeros((num_imgs, 22, 1, config.image_size, config.image_size))

    result[:, 0] = samples
    
    with torch.no_grad():
        for i, d in enumerate(np.linspace(-3, 3, 21)):
            means_t = torch.clone(means)
            means_t[:, neuron_id] = d
            samples = model.decode(means_t, y=batch_y_one_hot).cpu()
            samples = torch.sigmoid(samples)
            result[:, i + 1] = samples
    print(result.shape)
    save_image(result.view(-1, 1, config.image_size, config.image_size), 
               '{}/travel_neuron_{}.png'.format(output_dir, neuron_id), 
               nrow=22, pad_value=255)

VAE_results/dsprites/add_label/beta_10/latten_size_10


In [0]:
try:
    os.makedirs(output_dir)
except Exception as e:
    print(e)

model.eval()
    
# with torch.no_grad():
#     samples = torch.randn(64, config.latten_size).to(config.device)
#     samples = model2.decode(samples)
#     samples = torch.sigmoid(samples).cpu()
#     save_image(samples.view(64, 1, config.image_size, config.image_size),
#                '{}/samples.png'.format(output_dir), pad_value=255)
    
    
# with torch.no_grad():
#     save_image(batch_x[0:32].view(-1, 1, config.image_size, config.image_size), '{}/orgin.png'.format(output_dir), pad_value=255)
#     samples = batch_x[0:32].reshape(-1, 1, config.image_size, config.image_size).to(config.device)
#     means, var = model2.encode(samples)
    
# #     mean, var, recon_x = model(samples, None)
#     print(var)
# #     recon_x_loss, KL_loss = loss_function(samples, mean, var, recon_x)
# #     recon_x_loss /= len(samples)
# #     KL_loss /= len(samples)
# #     print(recon_x_loss.item(), KL_loss.item())
    
#     samples = model2.decode(means)
#     samples = torch.sigmoid(samples).cpu()
#     save_image(samples.view(-1, 1, config.image_size, config.image_size), '{}/reconstructed.png'.format(output_dir), pad_value=255)
    
    
for neuron_id in range(0, config.latten_size):
    hidden_travel(batch_x[0:20], batch_y[0:20], neuron_id)

[Errno 17] File exists: 'VAE_results/dsprites/add_label/beta_10/latten_size_10'
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])
torch.Size([20, 22, 1, 64, 64])


In [0]:
import matplotlib.pyplot as plt
import glob

import os
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

auth.authenticate_user()
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

for f in glob.glob("{}/*.png".format(output_dir)):
    uploaded = drive.CreateFile({'title': f.split('/')[-1],  "parents": [{"kind": "drive#fileLink","id": '1B8s6zgrNsCUcjSivvNcPY_ews2l717ui'}]})
    uploaded.SetContentFile(f)
    uploaded.Upload()
    print('Uploaded file with ID {}'.format(uploaded.get('id')))

Uploaded file with ID 1niDuE1eJ__7GW5cCYya0gZIw7VZlwma7
Uploaded file with ID 1kOq4mNdMsLlhdpgix7D7I-0Nt8JM128p
Uploaded file with ID 1Ft2TtVIdE4noP4VE7k_ZooKvF6qc8lNT
Uploaded file with ID 1fNp-ED_PGD8T5N_zajnjbQxb0j5ul81z
Uploaded file with ID 1Wsye9jsqFX_oUiA-Z_OAdim2nhs01LSr
Uploaded file with ID 1LUmZSgviGeVoRZNVlMGHSvII8GptmL3I
Uploaded file with ID 1mnH9IuXPsGZrM7-8GpvrFOIOpVform7b
Uploaded file with ID 1hJAcfPdMzdC2QkqOKx0WZPCGvaFjphtf
Uploaded file with ID 1u-Eu9WnB1HQDTuk2l1h9mGWcgcoCZr-o
Uploaded file with ID 19yhXfsqEMx9QeCu1BG0EBQIBDSyPP1lg
