In [1]:
from VAE import *

In [None]:
# Hyperparameters are exposed as command-line style options so the same code
# runs both as a script and inside a notebook; every option has a default.
parser = argparse.ArgumentParser(description='VAE Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--no-cuda', action='store_true', default=False,
                    help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
# parse_known_args() tolerates arguments this parser does not define. Inside a
# Jupyter kernel sys.argv carries the kernel's own flags (e.g. `-f <file>`),
# which would make a plain parse_args() abort with SystemExit.
args, _unknown_args = parser.parse_known_args()
# CUDA is used only when it is both requested (no --no-cuda) and available.
args.cuda = not args.no_cuda and torch.cuda.is_available()

# Seed torch's RNG so runs are repeatable.
torch.manual_seed(args.seed)

In [2]:


# Prefer the GPU when one is usable, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:

# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    """Return the VAE training loss: reconstruction error plus KL divergence.

    Args:
        recon_x: reconstructed tensor, same shape as ``x``.
        x: input tensor the reconstruction is compared against.
        mu: tensor of latent means.
        logvar: tensor of latent log-variances, same shape as ``mu``.

    Returns:
        A scalar tensor: summed squared error between ``recon_x`` and ``x``
        plus the KL term summed over all elements.
    """
    # Summed (not averaged) squared error. `reduction='sum'` replaces the
    # deprecated `size_average=False` spelling and yields the same value.
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')

    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + KLD


def train(epoch):
    """Run one optimisation pass over ``train_loader`` and print progress.

    Relies on the module-level ``model``, ``optimizer``, ``train_loader``,
    ``device``, ``args`` and ``loss_function``.

    Args:
        epoch: epoch number, used only in the printed messages.
    """
    model.train()
    running_loss = 0
    for step, (batch, _) in enumerate(train_loader):
        batch = batch.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        reconstruction, mu, logvar = model(batch)
        batch_loss = loss_function(reconstruction, batch, mu, logvar)
        batch_loss.backward()
        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value
        optimizer.step()

        # Only every `log_interval`-th batch is reported.
        if step % args.log_interval != 0:
            continue

        samples_seen = step * len(batch)
        total_samples = len(train_loader.dataset)
        percent_done = 100. * step / len(train_loader)
        per_sample_loss = batch_loss_value / len(batch)
        print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, samples_seen, total_samples, percent_done, per_sample_loss))

    # The loss is a sum over elements, so dividing by the dataset size gives
    # a per-sample average for the epoch.
    epoch_average = running_loss / len(train_loader.dataset)
    print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, epoch_average))


def test(epoch):
    """Evaluate the model on ``test_loader`` and print the per-sample loss.

    For the first batch only, up to eight inputs and their reconstructions
    are stacked (inputs first, reconstructions after) and written to
    ``results/reconstruction_<epoch>.csv``.

    Relies on the module-level ``model``, ``test_loader``, ``device``,
    ``loss_function`` and ``pd``.

    Args:
        epoch: epoch number, used only in the output file name.
    """
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, (data, _) in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                # Flatten each sample to one row, using the batch's actual
                # size rather than args.batch_size (the loader may use a
                # different batch size, and `None` is not a valid dimension).
                # NOTE(review): assumes inputs and reconstructions flatten to
                # the same number of features per sample — confirm against VAE.
                rows = data.size(0)
                comparison = torch.cat([data.view(rows, -1)[:n],
                                        recon_batch.view(rows, -1)[:n]])

                # pandas needs a host-side 2-D array, so move off the GPU first.
                pd.DataFrame(comparison.cpu().numpy()).to_csv(
                    'results/reconstruction_' + str(epoch) + '.csv', index = False)

    # The loss is a sum over elements; normalise by the number of samples.
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))



In [4]:
import pandas as pd
def _split_columns(names, matches):
    """Partition ``names`` into (matching, remaining), keeping their order."""
    picked = [name for name in names if matches(name)]
    remaining = [name for name in names if not matches(name)]
    return picked, remaining


# Only the header is needed to classify columns, so read just a few rows.
predata = pd.read_csv('context_plus_plus_fixed.csv', nrows = 10)

# Peel off one column family at a time; `cols` ends up holding whatever is
# left after targets, icd9 codes, ids and capitalised names are removed.
target_cols, cols = _split_columns(predata.columns.tolist(),
                                   lambda name: name.startswith('target_'))
print(len(target_cols))
icd9cols, cols = _split_columns(cols, lambda name: name.startswith('icd9'))
print(len(icd9cols))
idcols, cols = _split_columns(cols, lambda name: name.endswith('_id'))
print(len(idcols))
# A name whose first character is unchanged by upper() is put in categ_cols.
categ_cols, cols = _split_columns(cols, lambda name: name[0] == name[0].upper())

3
2356
2


In [7]:
# Build the batching loaders. `train` selects the dataset split and therefore
# goes to the dataset, while `batch_size` and `shuffle` are options of torch's
# DataLoader (which has no `train` keyword and raised the TypeError before).
# NOTE(review): assumes DataLoader_Bad accepts a `train` keyword — confirm
# against its definition in VAE.py.
train_loader = torch.utils.data.DataLoader(
    DataLoader_Bad('context_plus_plus_fixed.csv', cols, train=True),
    batch_size=28, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    DataLoader_Bad('context_plus_plus_fixed.csv', cols, train=False),
    batch_size=28, shuffle=True)

# NOTE(review): 2356 matches the printed number of icd9 columns, yet the
# datasets above are built from `cols` — confirm the intended input width.
model = VAE(2356).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)


for epoch in range(1, args.epochs + 1):
    train(epoch)
    test(epoch)
    with torch.no_grad():
        # Decode 64 random latent vectors to inspect what the model generates.
        # NOTE(review): assumes a 20-dimensional latent space — confirm
        # against the VAE definition.
        sample = torch.randn(64, 20).to(device)
        sample = model.decode(sample).cpu()

        # Hand numpy an explicit ndarray rather than a tensor.
        np.savetxt('results/sample_' + str(epoch) + '.txt', sample.numpy())

TypeError: __init__() got an unexpected keyword argument 'train'