In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import grad
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import numpy as np
from matplotlib.patches import Rectangle
import pylab as plt
import os
import urllib
import pandas as pd

np.set_printoptions(precision=2)
torch.set_printoptions(precision=2)
torch.set_printoptions(sci_mode=False)
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

torch.__version__

'1.6.0+cu101'

In [2]:
# Load data
DATA_PATH = "https://raw.githubusercontent.com/ch00226855/ImputationGAN/master/Letter.csv"
DATA_FILE = "Letter.csv"

if not os.path.exists(DATA_FILE):
    urllib.request.urlretrieve(DATA_PATH, DATA_FILE)

raw_data = pd.read_csv(DATA_FILE, sep=",").to_numpy().astype(np.float32)
print("Data shape:", raw_data.shape)

#divide the raw data into 4 subsets
# Split Data into four pieces
full_size = len(raw_data)
n_subsets = 4
sub_size = full_size // n_subsets

subsets = []
for k in range(4):
    #subsets.append(raw_data[k*sub_size:(k+1)*sub_size, k:k+13])
    subsets.append(raw_data[k*sub_size:(k+1)*sub_size, k*3:(k*3)+7])
    
print("Split data into", n_subsets, "subsets. Size of one subset:", sub_size)
subset0=subsets[0]
subset1=subsets[1]
subset2=subsets[2]
subset3=subsets[3]
# Normalization (0 to 1)
Dim = subset0.shape[1]
Min_Val = np.zeros(Dim)
Max_Val = np.zeros(Dim)

for i in range(Dim):
    Min_Val[i] = np.min(subset3[:,i])
    subset3[:,i] = subset3[:,i] - np.min(subset3[:,i])
    Max_Val[i] = np.max(subset3[:,i])
    subset3[:,i] = subset3[:,i] / (np.max(subset3[:,i]) + 1e-6)   

subset3[:3]

Data shape: (20000, 16)
Split data into 4 subsets. Size of one subset: 5000


array([[0.6 , 0.53, 0.62, 0.13, 0.69, 0.27, 0.6 ],
       [0.27, 0.47, 0.31, 0.4 , 0.38, 0.13, 0.47],
       [0.67, 0.6 , 0.69, 0.2 , 0.54, 0.2 , 0.4 ]], dtype=float32)

In [3]:
class MaskedDataset(Dataset):
    def __init__(self, subset, block_len, random_seed=0):
        self.block_len = block_len
        self.rnd = np.random.RandomState(random_seed)
        self.data_size = len(subset)
        self.num_features = len(subset[0])
        self.generate_incomplete_data(subset)

    def __getitem__(self, index):
        # return index so we can retrieve the mask location from self.mask_loc
        return self.data[index], self.mask[index], index

    def __len__(self):
        return self.data_size

    def generate_incomplete_data(self, subset):
        n_masks = self.data_size
        self.data = [None] * n_masks
        self.mask = [None] * n_masks
        self.mask_loc = [None] * n_masks
        for i in range(n_masks):
            d0 = self.rnd.randint(0, self.num_features - self.block_len + 1)
            mask = torch.zeros((self.num_features), dtype=torch.uint8)
            mask[d0:(d0 + self.block_len)] = 1
            self.mask[i] = mask.unsqueeze(0)   # add an axis for channel
            self.mask_loc[i] = d0, self.block_len
            # Mask out missing pixels by zero
            self.data[i] = torch.from_numpy(subset[i]) * mask.float()

In [4]:
# Create masked data
masked_data = MaskedDataset(subset3, block_len=4)
batch_size = 64
data_loader = DataLoader(masked_data, batch_size=batch_size, shuffle=True,
                         drop_last=True)

Compare raw data and masked data:

## MisGAN on Numerical Data

In [5]:
def mask_data(data, mask, tau=0):
    return mask * data + (1 - mask) * tau

In [6]:
# Generator
class Generator(nn.Module):
    def __init__(self, latent_size, DIM, num_features):
        super().__init__()

        self.DIM = DIM
        self.latent_size = latent_size
        self.num_features = num_features

        self.preprocess=nn.Linear(latent_size, 4 * self.DIM)
        self.block1=nn.Linear(4 * self.DIM, 2 * self.DIM)
        self.block2=nn.Linear(2 * self.DIM, self.DIM+15)
        self.final = nn.Linear(self.DIM+15, num_features)

    def forward(self, input):
        net = self.preprocess(input).clamp(min=0)
        # net = net.view(-1, 4 * self.DIM, 4, 4)
        net = self.block1(net).clamp(min=0)
        # net = net[:, :, :7, :7]
        net = self.block2(net).clamp(min=0)
        # net = self.deconv_out(net)
        net = self.final(net)
        # return self.transform(net).view(-1, 1, 28, 28)
        # return self.transform(net).view(-1, 1, self.num_features)
        # return self.transform(net).view(-1, self.num_features)
        return net


class DataGenerator(Generator):
    def __init__(self, latent_size, DIM, num_features):
        super().__init__(latent_size, DIM, num_features)
        self.transform = lambda x: torch.sigmoid(x)


class MaskGenerator(Generator):
    def __init__(self, latent_size, DIM, num_features, temperature=.66):
        super().__init__(latent_size, DIM, num_features)
        self.transform = lambda x: torch.sigmoid(x / temperature)

In [7]:
class Critic(nn.Module):
    def __init__(self, DIM, num_features):
        super().__init__()

        self.DIM = DIM
        self.num_features = num_features
        
        self.preprocess=nn.Linear(num_features, self.DIM)
        self.block1=nn.Linear(self.DIM, 2 * self.DIM)
        self.block2=nn.Linear(2 * self.DIM, self.DIM+15)
        self.final = nn.Linear(self.DIM+15, 1)
        

    def forward(self, input):
        net = self.preprocess(input).clamp(min=0)
        net = self.block1(net).clamp(min=0)
        net = self.block2(net).clamp(min=0)
        net = self.final(net)

        return net.view(-1)

In [8]:
class CriticUpdater:
    def __init__(self, critic, critic_optimizer, batch_size=64, gp_lambda=10):
        self.critic = critic
        self.critic_optimizer = critic_optimizer
        self.gp_lambda = gp_lambda
        # Interpolation coefficient
        self.eps = torch.empty(batch_size, 1, device=device)
        # For computing the gradient penalty
        self.ones = torch.ones(batch_size).to(device)

    def __call__(self, real, fake):
        real = real.detach()
        fake = fake.detach()
        # print("real shape:", real.shape)
        # print("fake shape:", fake.shape)
        # print('eps shape:', self.eps.shape)
        self.critic.zero_grad()
        self.eps.uniform_(0, 1)
        interp = (self.eps * real + (1 - self.eps) * fake).requires_grad_()
        # print('interp shape:', interp.shape)
        # print('critic(interp) shape:', self.critic(interp).shape)
        # print('self.ones shape:', self.ones.shape)
        grad_d = grad(self.critic(interp), interp, grad_outputs=self.ones,
                      create_graph=True)[0]
        grad_d = grad_d.view(real.shape[0], -1)
        grad_penalty = ((grad_d.norm(dim=1) - 1)**2).mean() * self.gp_lambda
        w_dist = self.critic(fake).mean() - self.critic(real).mean()
        loss = w_dist + grad_penalty
        loss.backward()
        self.critic_optimizer.step()

In [9]:
n_critic = 5
alpha = .2
num_features = masked_data.num_features
DIM = num_features # DIM determines layer sizes
nz = num_features # dimensionality of the latent code
latent_size = nz # number of random noise

data_gen = DataGenerator(latent_size, DIM, num_features).to(device)
mask_gen = MaskGenerator(latent_size, DIM, num_features).to(device)

data_critic = Critic(DIM, num_features).to(device)
mask_critic = Critic(DIM, num_features).to(device)

data_noise = torch.empty(batch_size, nz, device=device)
mask_noise = torch.empty(batch_size, nz, device=device)

lrate = 1e-4
data_gen_optimizer = optim.Adam(
    data_gen.parameters(), lr=lrate, betas=(.5, .9))
mask_gen_optimizer = optim.Adam(
    mask_gen.parameters(), lr=lrate, betas=(.5, .9))

data_critic_optimizer = optim.Adam(
    data_critic.parameters(), lr=lrate, betas=(.5, .9))
mask_critic_optimizer = optim.Adam(
    mask_critic.parameters(), lr=lrate, betas=(.5, .9))

update_data_critic = CriticUpdater(
    data_critic, data_critic_optimizer, batch_size)
update_mask_critic = CriticUpdater(
    mask_critic, mask_critic_optimizer, batch_size)

In [10]:
plot_interval = 5
critic_updates = 0

for epoch in range(100):
    for real_data, real_mask, _ in data_loader:

        real_data = real_data.to(device).float()
        real_mask = real_mask.view(real_data.shape).to(device).float()

        # Update discriminators' parameters
        data_noise.normal_()
        mask_noise.normal_()

        fake_data = data_gen(data_noise)
        fake_mask = mask_gen(mask_noise)

        masked_fake_data = mask_data(fake_data, fake_mask)
        masked_real_data = mask_data(real_data, real_mask)

        # print("masked_real_data shape:", masked_real_data.shape)
        # print("masked_fake_data shape:", masked_fake_data.shape)
        # print("real_mask shape:", real_mask.shape)
        # print("fake_mask shape:", fake_mask.shape)

        update_data_critic(masked_real_data, masked_fake_data)
        update_mask_critic(real_mask, fake_mask)

        critic_updates += 1

        if critic_updates == n_critic:
            critic_updates = 0

            # Update generators' parameters
            for p in data_critic.parameters():
                p.requires_grad_(False)
            for p in mask_critic.parameters():
                p.requires_grad_(False)

            data_gen.zero_grad()
            mask_gen.zero_grad()

            data_noise.normal_()
            mask_noise.normal_()

            fake_data = data_gen(data_noise)
            fake_mask = mask_gen(mask_noise)
            masked_fake_data = mask_data(fake_data, fake_mask)

            data_loss = -data_critic(masked_fake_data).mean()
            mask_loss = -mask_critic(fake_mask).mean()
            data_loss.backward(retain_graph=True)
            (mask_loss + data_loss * alpha).backward()
            data_gen_optimizer.step()
            mask_gen_optimizer.step()

            for p in data_critic.parameters():
                p.requires_grad_(True)
            for p in mask_critic.parameters():
                p.requires_grad_(True)

    if plot_interval > 0 and (epoch + 1) % plot_interval == 0:
        print("Iteration:", epoch + 1)
        # Although it makes no difference setting eval() in this example, 
        # you will need those if you are going to use modules such as 
        # batch normalization or dropout in the generators.
        data_gen.eval()
        mask_gen.eval()

        with torch.no_grad():
            # print('Epoch:', epoch)
            # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 2.5))
            
            data_noise.normal_()
            data_samples = data_gen(data_noise)
            # plot_grid(ax1, data_samples, title='generated complete data')
            #print("Generated data:")
            #print(data_samples)
            
            mask_noise.normal_()
            mask_samples = mask_gen(mask_noise)
            # plot_grid(ax2, mask_samples, title='generated masks')
            #print("Generated mask:")
            #print(mask_samples)
            
            # plt.show()
            # plt.close(fig)

        data_gen.train()
        mask_gen.train()

Iteration: 5
Iteration: 10
Iteration: 15
Iteration: 20
Iteration: 25
Iteration: 30
Iteration: 35
Iteration: 40
Iteration: 45
Iteration: 50
Iteration: 55
Iteration: 60
Iteration: 65
Iteration: 70
Iteration: 75
Iteration: 80
Iteration: 85
Iteration: 90
Iteration: 95
Iteration: 100


In [11]:
class Imputer(nn.Module):
    def __init__(self, num_features, arch):
        super().__init__()
        self.preprocess=nn.Linear(num_features, arch[0])
        self.block1=nn.Linear(arch[0], arch[1])
        self.block2=nn.Linear(arch[1], arch[0]+15)
        self.final = nn.Linear(arch[0]+15, num_features)
        
    def forward(self, data, mask, noise):
        net = data * mask + noise * (1 - mask)
        net = net.view(data.shape[0], -1)
        net = self.preprocess(net).clamp(min=0)
        net = self.block1(net).clamp(min=0)
        net = self.block2(net).clamp(min=0)
        net = self.final(net)

        net = torch.sigmoid(net).view(data.shape)
        return data * mask + net * (1 - mask)

In [12]:
imputer = Imputer(num_features, (num_features, num_features)).to(device)
impu_critic = Critic(DIM, num_features).to(device)
impu_noise = torch.empty(batch_size, num_features, device=device)

imputer_lrate = 2e-4
imputer_optimizer = optim.Adam(
    imputer.parameters(), lr=imputer_lrate, betas=(.5, .9))
impu_critic_optimizer = optim.Adam(
    impu_critic.parameters(), lr=imputer_lrate, betas=(.5, .9))
update_impu_critic = CriticUpdater(
    impu_critic, impu_critic_optimizer, batch_size)

In [15]:
beta = .1
plot_interval = 10
critic_updates = 0
import time
t1=time.time()
for epoch in range(150):
    for real_data, real_mask, index in data_loader:

        real_data = real_data.to(device).float()
        real_mask = real_mask.view(real_data.shape).to(device).float()

        masked_real_data = mask_data(real_data, real_mask)

        # Update discriminators' parameters
        data_noise.normal_()
        fake_data = data_gen(data_noise)

        mask_noise.normal_()
        fake_mask = mask_gen(mask_noise)
        masked_fake_data = mask_data(fake_data, fake_mask)

        impu_noise.uniform_()
        imputed_data = imputer(real_data, real_mask, impu_noise)

        update_data_critic(masked_real_data, masked_fake_data)
        update_mask_critic(real_mask, fake_mask)
        update_impu_critic(fake_data, imputed_data)

        critic_updates += 1
        if critic_updates == n_critic:
            critic_updates = 0

            # Update generators' parameters
            for p in data_critic.parameters():
                p.requires_grad_(False)
            for p in mask_critic.parameters():
                p.requires_grad_(False)
            for p in impu_critic.parameters():
                p.requires_grad_(False)

            data_noise.normal_()
            fake_data = data_gen(data_noise)

            mask_noise.normal_()
            fake_mask = mask_gen(mask_noise)
            masked_fake_data = mask_data(fake_data, fake_mask)

            impu_noise.uniform_()
            imputed_data = imputer(real_data, real_mask, impu_noise)

            data_loss = -data_critic(masked_fake_data).mean()
            mask_loss = -mask_critic(fake_mask).mean()
            impu_loss = -impu_critic(imputed_data).mean()

            mask_gen.zero_grad()
            (mask_loss + data_loss * alpha).backward(retain_graph=True)
            data_gen.zero_grad()
            (data_loss + impu_loss * beta).backward(retain_graph=True)
            mask_gen_optimizer.step()
            data_gen_optimizer.step()

            imputer.zero_grad()
            impu_loss.backward()
            imputer_optimizer.step()

            for p in data_critic.parameters():
                p.requires_grad_(True)
            for p in mask_critic.parameters():
                p.requires_grad_(True)
            for p in impu_critic.parameters():
                p.requires_grad_(True)

    if plot_interval > 0 and (epoch ) % plot_interval == 0:
        print("Iteration:", epoch)

        with torch.no_grad():
            imputer.eval()

            val_size = 3000
            val_noise = torch.empty(val_size, num_features, device=device).uniform_()
            val_data = torch.stack(masked_data.data)[:val_size].float()
            val_mask = torch.stack(masked_data.mask)[:val_size].view(val_data.shape)
            imputed_data = imputer(val_data.to(device),
                                   val_mask.to(device).float(),
                                   val_noise)
            rmse = np.sqrt(np.mean((subset3[:val_size]-imputed_data.numpy())**2))
            print("Validation RMSE:", rmse)


            test_size = 3000
            # Show imputation results
            # impu_noise.uniform_()
            test_noise = torch.empty(test_size, num_features, device=device).uniform_()
            # imputed_data = imputer(real_data, real_mask, impu_noise)
            test_data = torch.stack(masked_data.data)[val_size:(val_size + test_size)].float()
            test_mask = torch.stack(masked_data.mask)[val_size:(val_size + test_size)].view(test_data.shape)
            imputed_data = imputer(test_data.to(device),
                                   test_mask.to(device).float(),
                                   test_noise)
            #print("Real data:")
            #print(subset0[:test_size])
            #print("Imputed data:")
            #print(imputed_data)
            #print("Mask:")
            #print(torch.stack(masked_data.mask)[:test_size])
            rmse = np.sqrt(np.mean((subset3[val_size:(val_size + test_size)]-imputed_data.numpy())**2))
            print("Test RMSE:", rmse)
            # print('Epoch:', epoch)
            # fig, ax = plt.subplots(figsize=(6, 3))
            # plot_grid(ax, imputed_data, bbox, gap=2)
            # plt.show()
            # plt.close(fig)

            imputer.train()
t2=time.time()
print("Run time: ",t2-t1)

Iteration: 0
Validation RMSE: 0.10677422
Test RMSE: 0.1139572
Iteration: 10
Validation RMSE: 0.09948741
Test RMSE: 0.10861789
Iteration: 20
Validation RMSE: 0.097926855
Test RMSE: 0.109125584
Iteration: 30
Validation RMSE: 0.10959242
Test RMSE: 0.12131825
Iteration: 40
Validation RMSE: 0.13550365
Test RMSE: 0.14653997
Iteration: 50
Validation RMSE: 0.17662743
Test RMSE: 0.18683586
Iteration: 60
Validation RMSE: 0.22640716
Test RMSE: 0.23547938
Iteration: 70
Validation RMSE: 0.27672377
Test RMSE: 0.28371757
Iteration: 80
Validation RMSE: 0.31930196
Test RMSE: 0.32593176
Iteration: 90
Validation RMSE: 0.34835127
Test RMSE: 0.3552013
Iteration: 100
Validation RMSE: 0.36006895
Test RMSE: 0.36680388
Iteration: 110
Validation RMSE: 0.36232227
Test RMSE: 0.36897057
Iteration: 120
Validation RMSE: 0.3563582
Test RMSE: 0.3639042
Iteration: 130
Validation RMSE: 0.33881697
Test RMSE: 0.34891152
Iteration: 140
Validation RMSE: 0.31841174
Test RMSE: 0.3323738
Run time:  111.46116042137146
