In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils as utils
import torch.utils.data
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.datasets as dsets
from torchvision import models
from torchvision.utils import save_image
import pdb


from tqdm import tqdm

from flows import PlanarFlow
from utils import Binarize
import codes

from torchmeta.datasets.helpers import cifar_fs
from torchmeta.utils.data import BatchMetaDataLoader


#from __future__ import print_function
import argparse
import cv2
import matplotlib.pyplot as plt

import os
# Project root on the author's machine. Every relative path used later
# (e.g. './data/CIFAR10/' and "results/") is resolved against it after chdir.
cur_dir = "C:/Users/KJH/OneDrive - skku.edu/KJH/Projects/2019winter_research"
#cur_dir = "C:/Users/KJH-Laptop/OneDrive - skku.edu/KJH/Projects/2019winter_research/"
os.chdir(cur_dir)
#os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# CUDA_LAUNCH_BLOCKING=1 requests synchronous kernel launches (debugging aid).
# NOTE(review): torch is already imported above; confirm the variable is still
# honoured when it is set this late.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import time
import copy
import random as rd

# Global device read by the layers and helper functions defined below.
# Hard-coded to CUDA: the notebook is not runnable on a CPU-only machine.
device = torch.device('cuda')

# Per-channel normalisation statistics shared by the train and test pipelines.
# NOTE(review): these values resemble the commonly quoted CIFAR-100 statistics
# rather than CIFAR-10 ones -- confirm they are intended for this dataset.
_NORM_MEAN = [0.507, 0.487, 0.441]
_NORM_STD = [0.267, 0.256, 0.276]
_CIFAR10_ROOT = './data/CIFAR10/'

# Training images: random crop + horizontal flip augmentation, then normalise.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Test images: same normalisation, no augmentation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

batch_size_train = 128
batch_size_test = 128

trainset = torchvision.datasets.CIFAR10(
    _CIFAR10_ROOT, train=True, download=True, transform=transform_train)
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=8,
)

testset = torchvision.datasets.CIFAR10(
    _CIFAR10_ROOT, train=False, download=True, transform=transform_test)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=batch_size_test,
    shuffle=False,
    num_workers=8,
)

class Conv2d_flipout(nn.Conv2d):
    """2-D convolution whose kernel is sampled from a learned Gaussian.

    The variational parameters are ``weight_mean`` and ``weight_logvar`` (log
    of the per-element variance).  On every forward pass a kernel is drawn as
    ``weight_mean + eps * exp(weight_logvar / 2)`` with ``eps ~ N(0, 1)`` and
    the input is multiplied element-wise by a random +/-1 sign tensor before
    the convolution.

    After each forward pass ``self.kld`` holds
    ``sum(mean**2 - logvar + exp(logvar) - 1) / 2`` divided by the batch size
    ``x.shape[0]``, i.e. the KL divergence of the weight posterior from a
    standard normal, scaled per example.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups, bias, padding_mode: forwarded unchanged to ``nn.Conv2d``.
            Note that ``bias`` defaults to ``False`` here.

    The variational parameters are created on the module-level ``device``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False, padding_mode='zeros'):
        super(Conv2d_flipout, self).__init__(
            in_channels=in_channels, out_channels=out_channels,
            kernel_size=kernel_size, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias,
            padding_mode=padding_mode)

        weight_shape = self.weight.data.shape
        self.weight_mean = nn.Parameter(
            torch.empty(weight_shape, device=device).normal_(0, 1))
        # Initial variance is the reciprocal of a Gamma(1, 1) precision draw.
        weight_prec_prior = torch.distributions.gamma.Gamma(
            torch.ones(weight_shape), torch.ones(weight_shape))
        self.weight_logvar = nn.Parameter(
            weight_prec_prior.sample().reciprocal().log().to(device))

        # Populated by forward(); None until the first call.
        self.kld = None

    def forward(self, x):
        """Convolve a randomly sign-flipped ``x`` with a freshly sampled kernel.

        Args:
            x: input batch; ``x.shape[0]`` is used as the batch size when
                scaling ``self.kld``.

        Returns:
            The convolution output.  ``self.kld`` is updated as a side effect.
        """
        std = self.weight_logvar.div(2).exp()
        # Reparameterised sample kept as a graph tensor.  Writing it into
        # ``self.weight.data`` (as was done before) detaches it, so the task
        # loss could never reach weight_mean / weight_logvar.
        weight = self.weight_mean + torch.randn_like(self.weight_mean) * std
        # Mirror the latest sample into the inherited parameter so code that
        # inspects ``layer.weight`` still sees the kernel that was used.
        self.weight.data = weight.detach()

        signs = torch.empty_like(x).uniform_(-1, 1).sign()
        x = signs * x

        self.kld = torch.sum(
            self.weight_mean ** 2 - self.weight_logvar
            + self.weight_logvar.exp() - 1) / 2 / x.shape[0]

        if self.padding_mode != 'zeros':
            # Non-zero padding modes: pad explicitly, then convolve unpadded.
            # F.pad expects (left, right, top, bottom), i.e. last dim first.
            pad = [p for p in reversed(self.padding) for _ in range(2)]
            x = F.pad(x, pad, mode=self.padding_mode)
            return F.conv2d(x, weight, self.bias, self.stride, 0,
                            self.dilation, self.groups)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)



class Linear_flipout(nn.Module):
    """Linear layer with Gaussian variational weights and flipout-style signs.

    Parameters are stored as means and log-variances.  ``weight_mean`` has
    shape ``[in_features, out_features]`` (the transpose of ``nn.Linear``).

    The forward pass returns

        x @ weight_mean
        + ((s_in * x) @ (eps * weight_mean * std_w)) * s_out
        [+ (1 + eps_b * std_b) * bias_mean]

    where ``eps`` / ``eps_b`` are standard-normal draws, ``std = exp(logvar/2)``
    and ``s_in`` / ``s_out`` are random +/-1 sign tensors shaped like the input
    and the output.  The perturbation is therefore multiplicative in the mean.

    After each forward pass ``self.kld`` holds
    ``sum(mean**2 - logvar + exp(logvar) - 1) / 2`` over weights (and bias, if
    enabled), divided by ``x.shape[0]``.
    NOTE(review): that expression is the KL of an additive N(mean, var)
    posterior from N(0, 1), while the sampled perturbation is multiplicative --
    confirm this pairing is intended.

    Args:
        in_features: size of the last input dimension.
        out_features: size of the last output dimension.
        bias: whether to add a stochastic bias term.

    All parameters are created on the module-level ``device``.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(Linear_flipout, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.weight_mean = nn.Parameter(
            torch.empty([self.in_features, self.out_features], device=device).normal_(0, 1))
        weight_shape = self.weight_mean.data.shape
        # Initial variance is the reciprocal of a Gamma(1/2, 1/2) precision draw.
        weight_prec_prior = torch.distributions.gamma.Gamma(
            torch.ones(weight_shape) / 2, torch.ones(weight_shape) / 2)
        self.weight_logvar = nn.Parameter(
            weight_prec_prior.sample().reciprocal().log().to(device))

        # Populated by forward(); None until the first call.
        self.kld = None

        # Boolean flag (not a tensor): selects the bias branch in forward().
        self.bias = bias
        if bias:
            self.bias_mean = nn.Parameter(
                torch.empty(self.out_features, device=device).normal_(0, 1))
            bias_shape = self.bias_mean.data.shape
            bias_prec_prior = torch.distributions.gamma.Gamma(
                torch.ones(bias_shape) / 2, torch.ones(bias_shape) / 2)
            self.bias_logvar = nn.Parameter(
                bias_prec_prior.sample().reciprocal().log().to(device))
        else:
            self.bias_mean = None
            self.bias_logvar = None

    def forward(self, x):
        """Apply the stochastic linear map to ``x`` and refresh ``self.kld``.

        Args:
            x: tensor whose last dimension is ``in_features``;
                ``x.shape[0]`` is used as the batch size when scaling the KL.

        Returns:
            Tensor with the last dimension replaced by ``out_features``.
        """
        weight_std = self.weight_logvar.div(2).exp()
        weight_noise = torch.randn_like(self.weight_mean)

        mean_out = torch.matmul(x, self.weight_mean)

        in_sign = torch.empty_like(x).uniform_(-1, 1).sign()
        out_sign = torch.empty_like(mean_out).uniform_(-1, 1).sign()
        perturbation = weight_noise * self.weight_mean * weight_std
        output = mean_out + torch.matmul(in_sign * x, perturbation) * out_sign

        kld = torch.sum(
            self.weight_mean ** 2 - self.weight_logvar
            + self.weight_logvar.exp() - 1) / 2 / x.shape[0]

        if self.bias:
            bias_std = self.bias_logvar.div(2).exp()
            bias_noise = torch.randn_like(self.bias_mean)
            # Gradients of the bias parameters are left entirely to autograd.
            # Assigning ``.grad`` here (as was done before) made backward()
            # accumulate on top of the hand-written values and double-counted
            # the bias gradients.
            output = output + (1 + bias_noise * bias_std) * self.bias_mean

            kld = kld + torch.sum(
                self.bias_mean ** 2 - self.bias_logvar
                + self.bias_logvar.exp() - 1) / 2 / x.shape[0]

        self.kld = kld
        return output



def loss_fn(pred, label, model, progress):
    """Cross-entropy plus a KL penalty that is annealed in during training.

    Args:
        pred: class logits passed to ``F.cross_entropy``.
        label: target class indices.
        model: object exposing a ``kld()`` method that returns the KL term.
        progress: training progress; below 0.5 the KL term is weighted by
            ``(2 * progress) ** 2``, from 0.5 onwards it has full weight.

    Returns:
        The combined scalar loss.
    """
    total = F.cross_entropy(pred, label, reduction='mean')
    penalty = model.kld()
    if progress < 0.5:
        # Quadratic warm-up: weight rises from 0 to 1 over the first half.
        penalty = (2 * progress) ** 2 * penalty
    total += penalty
    return total

Files already downloaded and verified
Files already downloaded and verified


In [2]:
def _print_posterior(layer):
    """Print posterior means and standard deviations of a Linear_flipout layer."""
    print(layer.weight_mean)
    print(layer.weight_logvar.div(2).exp())
    print(layer.bias_mean)
    print(layer.bias_logvar.div(2).exp())


# Toy problem: fit a 2-feature linear regression with the stochastic layer.
reg = Linear_flipout(2, 1, bias=True).cuda()
x = torch.empty(1000, 2, device=device).normal_(0, 1)

# Ground truth: two random slopes, unit-variance observation noise, intercept 1.
# The three draws happen in the same order as the terms of the sum below.
slope_first = torch.empty(1, device=device).normal_(2, 0.7)
slope_second = torch.empty(1, device=device).normal_(-3, 0.3)
obs_noise = torch.empty(1000, 1, device=device).normal_(0, 1)
y = slope_first * x[:, 0:1] + slope_second * x[:, 1:] + obs_noise + 1

opt = optim.Adam(reg.parameters(), lr=1)

# State before training; kld is still None because forward() has not run.
_print_posterior(reg)
print(reg.kld)

for rep in range(500):
    opt.zero_grad()
    residual = y - reg(x)
    # Sum of squared errors plus the KL term refreshed by the forward pass.
    loss = residual.pow(2).sum() + reg.kld
    loss.backward()
    opt.step()
    print(reg.kld)

# State after training.
_print_posterior(reg)


Parameter containing:
tensor([[ 1.8994],
        [-0.1946]], device='cuda:0', requires_grad=True)
tensor([[1.0562],
        [1.8276]], device='cuda:0', grad_fn=<ExpBackward>)
Parameter containing:
tensor([-1.1346], device='cuda:0', requires_grad=True)
tensor([112.5163], device='cuda:0', grad_fn=<ExpBackward>)
None
tensor(6.3278, device='cuda:0', grad_fn=<AddBackward0>)
tensor(2.3252, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.9517, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.4784, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2629, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1588, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.1034, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0719, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0531, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0414, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0339, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.0290, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.025

In [None]:
def get_context_idx(N, num_points=784):
    """Draw the flat indices of ``N`` distinct context points.

    Args:
        N: number of context points to sample (without replacement).
        num_points: size of the flattened image to sample from; defaults to
            784, the value that was previously hard-coded.

    Returns:
        1-D integer tensor of length ``N`` on the module-level ``device``.

    Raises:
        ValueError: from the sampler if ``N`` exceeds ``num_points`` or is
            negative.
    """
    # The file imports the stdlib module as ``rd`` ("import random as rd");
    # the bare name ``random`` is not bound here.
    idx = rd.sample(range(num_points), N)
    return torch.tensor(idx, device=device)


def generate_grid(h, w):
    """Build normalised (row, col) coordinates for every pixel of an h x w image.

    Entry ``k`` of the result corresponds to flat row-major index ``k``, i.e.
    row ``k // w`` and column ``k % w``, with both coordinates scaled to
    ``[0, 1]``.  For square grids the output is identical to the previous
    implementation; for ``h != w`` the previous version paired the axes as if
    the image were w x h, which did not match row-major flattening.

    Args:
        h: number of rows.
        w: number of columns.

    Returns:
        Tensor of shape ``(1, h * w, 2)`` on the module-level ``device``.
    """
    rows = torch.linspace(0, 1, h, device=device)
    cols = torch.linspace(0, 1, w, device=device)
    # Broadcast each axis over the full h x w grid, then flatten row-major.
    row_coord = rows.unsqueeze(1).expand(h, w).reshape(-1)
    col_coord = cols.unsqueeze(0).expand(h, w).reshape(-1)
    grid = torch.stack([row_coord, col_coord], dim=1)
    return grid.unsqueeze(0)


def idx_to_y(idx, data):
    """Gather the entries of ``data`` at positions ``idx`` along dimension 1.

    Args:
        idx: 1-D integer tensor of positions.
        data: tensor with at least two dimensions.

    Returns:
        ``data`` restricted to the selected positions of dimension 1.
    """
    return data.index_select(1, idx)


def idx_to_x(idx, batch_size):
    """Look up grid coordinates for flat pixel indices, repeated per batch item.

    Reads the module-level ``x_grid`` and selects the rows at ``idx`` along
    dimension 1 (e.g. on the 28x28 grid, flat index 35 is grid position
    (1, 7), as with ``np.unravel_index``).

    Args:
        idx: 1-D integer tensor of flat indices.
        batch_size: size to expand the leading dimension to.

    Returns:
        An expanded view of the selected coordinates with leading dimension
        ``batch_size``.
    """
    coords = x_grid.index_select(1, idx)
    return coords.expand(batch_size, -1, -1)


class NP(nn.Module):
    """Latent-variable model over (coordinate, value) pairs.

    NP presumably stands for "neural process".  An encoder ``h`` maps each
    concatenated (x, y) point to a representation, the representations are
    averaged, and two linear heads turn the average into the mean and
    log-variance of a latent ``z``.  A decoder ``g`` maps ``z`` concatenated
    with a target coordinate to a value in (0, 1) via a sigmoid.

    Args:
        args: object with integer attributes ``r_dim`` (representation size)
            and ``z_dim`` (latent size).

    ``forward`` reads the module-level ``x_grid`` as the set of target
    coordinates.
    """

    def __init__(self, args):
        super(NP, self).__init__()
        self.r_dim = args.r_dim
        self.z_dim = args.z_dim

        # Encoder: 3 inputs per point (x and y concatenated on the last dim).
        self.h_1 = nn.Linear(3, 400)
        self.h_2 = nn.Linear(400, 400)
        self.h_3 = nn.Linear(400, self.r_dim)

        # Heads producing the latent Gaussian parameters.
        self.r_to_z_mean = nn.Linear(self.r_dim, self.z_dim)
        self.r_to_z_logvar = nn.Linear(self.r_dim, self.z_dim)

        # Decoder: latent sample plus a 2-D target coordinate -> one value.
        self.g_1 = nn.Linear(self.z_dim + 2, 400)
        self.g_2 = nn.Linear(400, 400)
        self.g_3 = nn.Linear(400, 400)
        self.g_4 = nn.Linear(400, 400)
        self.g_5 = nn.Linear(400, 1)

    def h(self, x_y):
        """Encode each (x, y) point with three ReLU layers."""
        x_y = F.relu(self.h_1(x_y))
        x_y = F.relu(self.h_2(x_y))
        x_y = F.relu(self.h_3(x_y))
        return x_y

    def aggregate(self, r):
        """Average the per-point representations over dimension 1."""
        return torch.mean(r, dim=1)

    def reparameterise(self, z, num_targets=784):
        """Sample ``z`` and repeat the sample for every target point.

        Args:
            z: ``(mu, logvar)`` pair of 2-D tensors.
            num_targets: number of target points to expand to; defaults to
                784, the value that was previously hard-coded.

        Returns:
            An expanded view with a new dimension 1 of size ``num_targets``.
        """
        mu, logvar = z
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z_sample = eps.mul(std).add_(mu)
        return z_sample.unsqueeze(1).expand(-1, num_targets, -1)

    def g(self, z_sample, x_target):
        """Decode latent samples at the target coordinates into (0, 1) values."""
        hidden = torch.cat([z_sample, x_target], dim=2)
        hidden = F.relu(self.g_1(hidden))
        hidden = F.relu(self.g_2(hidden))
        hidden = F.relu(self.g_3(hidden))
        hidden = F.relu(self.g_4(hidden))
        return torch.sigmoid(self.g_5(hidden))

    def xy_to_z_params(self, x, y):
        """Map a set of points to the ``(mu, logvar)`` of the latent."""
        x_y = torch.cat([x, y], dim=2)
        r_i = self.h(x_y)
        r = self.aggregate(r_i)

        mu = self.r_to_z_mean(r)
        logvar = self.r_to_z_logvar(r)

        return mu, logvar

    def forward(self, x_context, y_context, x_all=None, y_all=None):
        """Predict values at every point of ``x_grid``.

        Args:
            x_context, y_context: coordinates and values of the context points.
            x_all, y_all: coordinates and values of all points; required in
                training mode, ignored in eval mode.

        Returns:
            ``(y_hat, z_all, z_context)`` where the last two are
            ``(mu, logvar)`` pairs.  In eval mode ``z_all`` is ``z_context``.

        Raises:
            TypeError: in training mode when ``x_all`` or ``y_all`` is missing.
        """
        z_context = self.xy_to_z_params(x_context, y_context)  # (mu, logvar) of z
        if self.training:  # loss function will try to keep z_context close to z_all
            if x_all is None or y_all is None:
                raise TypeError(
                    "x_all and y_all are required when the model is in training mode")
            z_all = self.xy_to_z_params(x_all, y_all)
        else:  # at test time we don't have the image so we use only the context
            z_all = z_context

        # Reconstruct the whole image, including the provided context points.
        x_target = x_grid.expand(y_context.shape[0], -1, -1)
        # Size the latent expansion from the grid so the two cannot disagree.
        z_sample = self.reparameterise(z_all, x_target.shape[1])
        y_hat = self.g(z_sample, x_target)

        return y_hat, z_all, z_context


def kl_div_gaussians(mu_q, logvar_q, mu_p, logvar_p):
    """Return KL(q || p) for diagonal Gaussians, summed over all elements.

    Args:
        mu_q, logvar_q: mean and log-variance of q.
        mu_p, logvar_p: mean and log-variance of p.

    Returns:
        Scalar tensor: half the sum of
        ``(var_q + (mu_q - mu_p)^2) / var_p - 1 + logvar_p - logvar_q``.
    """
    var_p = logvar_p.exp()
    var_q = logvar_q.exp()
    sq_mean_gap = (mu_q - mu_p) ** 2
    per_element = (var_q + sq_mean_gap) / var_p - 1.0 + logvar_p - logvar_q
    return 0.5 * per_element.sum()


def np_loss(y_hat, y, z_all, z_context):
    """Summed binary cross-entropy plus KL between the two latent Gaussians.

    Args:
        y_hat: predicted values in [0, 1].
        y: target values.
        z_all: ``(mu, logvar)`` of the latent inferred from all points.
        z_context: ``(mu, logvar)`` of the latent inferred from the context.

    Returns:
        ``BCE(y_hat, y) + KL(z_all || z_context)`` as a scalar tensor.
    """
    reconstruction = F.binary_cross_entropy(y_hat, y, reduction="sum")
    mu_all, logvar_all = z_all[0], z_all[1]
    mu_ctx, logvar_ctx = z_context[0], z_context[1]
    divergence = kl_div_gaussians(mu_all, logvar_all, mu_ctx, logvar_ctx)
    return reconstruction + divergence


# Build the NP model on the global device.
# NOTE(review): `args` is not defined anywhere in the visible notebook; it must
# provide `r_dim` and `z_dim` (read by NP.__init__) before this cell can run.
model = NP(args).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Module-level coordinate grid for a 28x28 image; read as a global by
# idx_to_x() and NP.forward().
x_grid = generate_grid(28, 28)
# Output directory, resolved relative to the current working directory.
os.makedirs("results/", exist_ok=True)

In [None]:
# Train and evaluate a classifier on the CIFAR-10 loaders defined earlier.
# NOTE(review): the assignment below has no right-hand side, so this cell is a
# SyntaxError as written. The loop requires a model that maps an image batch to
# class logits and exposes a `kld()` method (used by loss_fn and in the test loop).
model = 
epoch = 200
lr = 3e-4
#optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Learning-rate schedule notes; no scheduler is applied in the loop below:
#0.1 for epoch [0,150)
#0.01 for epoch [150,250)
#0.001 for epoch [250,350)

running_test_loss = 0.0
for run in range(epoch):
    start = time.time()
    
    #Training
    model.train()
    train_loss = 0.0
    for ind, data in enumerate(trainloader):
        optimizer.zero_grad()
        img, label = data
        pred = model(img.cuda())
        # run/epoch is the training progress that loss_fn uses to weight the KL term.
        loss = loss_fn(pred, label.cuda(), model, run/epoch)
        # Weight the batch-mean loss by batch size so the epoch average is per sample.
        train_loss += loss.detach() * img.shape[0] 
        loss.backward()
        optimizer.step()
    train_loss /= len(trainset)
    
    #Test
    model.eval()
    with torch.no_grad():
        test_loss = 0.0
        acc = 0.0
        sum_kl = 0
        for ind, data in enumerate(testloader):
            img, label = data
            pred = model(img.cuda())
            sum_kl += model.kld().detach()
            # NOTE(review): with the next line commented out, test_loss is never
            # accumulated, so test_loss and running_test_loss stay at 0.
            #test_loss += loss_fn(pred, label.cuda(), model).detach() * img.shape[0]
            # Count correct top-1 predictions (builtin sum iterates the bool tensor).
            acc += sum(pred.argmax(1) == label.cuda()).item()
        test_loss /= len(testset)
        acc /= len(testset)
        running_test_loss += test_loss
    #print("epoch : %d, train loss = %5.6f, test loss = %5.6f, running_test_loss = %5.6f, acc = %.3f, reg = %.4f, time: %f sec"
    #      %(run, train_loss, test_loss, running_test_loss / (run + 1), acc, sum_kl/(ind + 1), time.time() - start))
    # "reg" is the KL term averaged over test batches (ind + 1 == number of batches).
    print("epoch : %d, train loss = %5.6f, acc = %.3f, reg = %.4f, time: %f sec"
          %(run, train_loss, acc, sum_kl/(ind + 1), time.time() - start))

Files already downloaded and verified
Files already downloaded and verified
epoch : 0, train loss = 2.357553, acc = 0.102, reg = 91.5498, time: 124.659379 sec
epoch : 1, train loss = 2.366360, acc = 0.108, reg = 81.0944, time: 123.171208 sec
epoch : 2, train loss = 2.381727, acc = 0.097, reg = 70.5443, time: 121.357049 sec
epoch : 3, train loss = 2.401616, acc = 0.100, reg = 61.5948, time: 120.932984 sec
epoch : 4, train loss = 2.426486, acc = 0.102, reg = 54.0632, time: 121.403174 sec
