In [70]:
import numpy as np
import math
import torch
import torch.nn as nn
from torch.optim import Adam
from pathlib import Path
from sklearn.model_selection import train_test_split
from torch.optim import Adam, RMSprop
from torch.utils.data import DataLoader, TensorDataset
from utils import count_time, tensor
from train_gan_adjusted import train_gan, generate_dataset_from_gan
from classifier import Classifier

"""
np.srandn(0)
x0 = np.randn()
print(x0)
"""
x_dim = 2
hiddens = [x_dim, 32, 64, 1]
test_size = 0.2
valid_size = 0.125
batch_size = 10
c_hiddens = [x_dim, 32, 64, 1]
g_hidden_size = 64
g_num_layers = 2
d_hidden_size = 64
d_num_layers = 2
gan_epochs = 3000

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

def to_tensor(z, x, y=None):
    if torch.is_tensor(x):
        zx = torch.cat([z, x], dim=1)
    else:
        zx = np.concatenate([z, x], axis=1)
        zx = torch.FloatTensor(zx)
    if isinstance(y, np.ndarray):
        y = torch.FloatTensor(y)
        return zx, y
    return zx

class TrueModel(nn.Module):

    def __init__(self, hiddens, seed=0):
        super().__init__()
        layers = []
        for in_dim, out_dim in zip(hiddens[:-1], hiddens[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.ReLU(inplace=True))
        layers.pop()
        layers.append(nn.Sigmoid())
        self.model = nn.Sequential(*layers)

        self.loss_fn = nn.BCELoss()
        self.optim = Adam(self.parameters())

    def forward(self, zx):
        return self.model(zx)

    def predict(self, z, x):
        zx = to_tensor(z, x)
        pred = self(zx)
        pred_y = pred.detach().round().cpu().numpy()
        return pred_y

    def fit(self, z, x, y, patience=10):
        zx, y = to_tensor(z, x, y)

        epoch, counter = 0, 0
        best_loss = float('inf')
        while True:
            pred = self(zx)
            loss = self.loss_fn(pred, y)

            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
            
            epoch += 1
            if loss.item() <= best_loss:
                torch.save(self.state_dict(), self.path)
                best_loss = loss.item()
                counter = 0
            else:
                counter += 1
                if counter == patience:
                    break
        print(f"TrueModel Fit Done in {epoch} epochs!")


def gen_initial_data(n):
    np.random.seed(0)
    x0 = np.random.randn()
    z0 = math.cos(x0) + np.random.randn()
    x = torch.from_numpy(np.ones((n, 1)))
    z = torch.from_numpy(np.ones((n, 1)))
    y = torch.from_numpy(np.ones((n, 1)))
    x[0] = x0
    z[0] = z0
    for i in range(1, n):
        x[i] = np.random.randn()
        z[i] = math.cos(x[i]) + np.random.randn()
    y = torch.bernoulli(1 /(1+  np.exp(-x +z)))
    return x, z, y

def gen_data(n, l):
    x, z, y = gen_initial_data(100)
    model = TrueModel(hiddens, seed = 0)
    zx = to_tensor(z, x)
    zx.requires_grad = True
    zx = zx.to(dtype=torch.float32)
    prob = model(zx)
    loss = nn.BCELoss()(prob, torch.ones_like(prob))
    loss.backward()
    for i in range(1, n):
        loss = nn.BCELoss()(prob, torch.ones_like(prob))
        loss.backward()
        x[i] = np.random.randn() + l*(x[1-i] - y[i-1]*loss)
        z[i] = math.cos(x[i]) + np.random.randn() + l*(z[i-1] - y[i-1]*loss)
    return x, z, y
x, z, y = gen_initial_data(100)
x = x.to(dtype = torch.float32)
y = y.to(dtype = torch.float32)
z = z.to(dtype = torch.float32)
print(x.dtype)

torch.float32


In [33]:
def train_discriminator(clf, G, D, optim, loss_fn, xs, zs, ss):
    xs_fake, _, _ = G(xs[:, 0], zs, ss, clf)
    fake = D(xs_fake.detach())
    loss_fake = loss_fn(fake, torch.zeros_like(fake))

    real = D(xs)
    loss_real = loss_fn(real, torch.ones_like(real))

    loss = loss_fake + loss_real

    optim.zero_grad()
    loss.backward()
    optim.step()

    return loss


def get_moment_loss(x_pred, x_true):
    m1 = torch.mean(torch.abs(x_pred.mean(dim=0) - x_true.mean(dim=0)))
    m2 = torch.mean(torch.abs(
        torch.sqrt(x_pred.var(dim=0, unbiased=False) + 1e-6) -
        torch.sqrt(x_true.var(dim=0, unbiased=False) + 1e-6)
    ))
    return m1 + m2


def train_generator(clf, G, D, optim, loss_fn, xs, zs, ss, gamma=100):
    xs_fake, _, _ = G(xs[:, 0], zs, ss, clf)
    fake = D(xs_fake)

    loss1 = loss_fn(fake, torch.ones_like(fake))
    loss2 = get_moment_loss(xs_fake, xs)
    loss = loss1 + gamma * loss2

    optim.zero_grad()
    loss.backward()
    optim.step()
    
    return loss


@count_time
def train_gan(loader, clf, G, D, n_epochs, device):
    g_optim = Adam(G.parameters())
    d_optim = Adam(D.parameters())
    loss_fn = nn.BCELoss()

    for epoch in range(n_epochs):
        
        for i, (s_mb, x_mb, y_mb) in enumerate(loader, start=1):
            batch, seq, dim = x_mb.size()
            x_mb = x_mb.to(device)
            z_mb = torch.rand(batch, seq-1, dim).to(device)

            for _ in range(2):
                g_loss = train_generator(clf, G, D, g_optim, loss_fn, x_mb, z_mb, s_mb)

            for _ in range(1):
                d_loss = train_discriminator(clf, G, D, d_optim, loss_fn, x_mb, z_mb, s_mb)

            step = epoch * len(loader) + i
            if step % 1000 == 0:
                print(f'Epoch: {epoch: 6.0f} | step: {step:6.0f} | d_loss: {d_loss:6.4f} | g_loss: {g_loss: 6.4f}')
            


def generate_dataset_from_gan(loader, clf, G, device, extra_seq=0):
    gen_s, gen_x, gen_y = [], [], []

    batch_size = None
    for s_mb, x_mb, y_mb in loader:
        batch, seq_len, x_dim = x_mb.shape
        if batch_size is None:
            batch_size = batch

        x_mb = x_mb.to(device)
        z_mb = torch.randn(batch, seq_len + extra_seq - 1, x_dim).to(device)

        gen_x_mb, _, gen_y_mb = G(x_mb[:, 0], z_mb, s_mb, clf)
        
        gen_s.append(s_mb)
        gen_x.append(gen_x_mb)
        gen_y.append(gen_y_mb)

    gen_s = torch.cat(gen_s, dim=0).detach().cpu().numpy()
    gen_x = torch.cat(gen_x, dim=0).detach().cpu().numpy()
    gen_y = torch.cat(gen_y, dim=0).detach().cpu().numpy()

    gen_data = TensorDataset(tensor(gen_s), tensor(gen_x), tensor(gen_y))
    gen_loader = DataLoader(gen_data, batch_size=batch_size, shuffle=False)

    return gen_loader, gen_s, gen_x, gen_y

In [49]:
class Generator(nn.Module):

    def __init__(self, in_size, hidden_size, num_layers):
        super().__init__()

        self.num_layers = num_layers
        self.h0_linear = nn.Linear(in_size, hidden_size)
        self.rnn = nn.GRU(in_size + 3, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, in_size)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x0, noise, z0, clf):
        zz = torch.clone(z0)
        zz = ss.to(x0.device)
        z0 = torch.zeros(z0.size(0), 2).scatter_(1, z0.long(), torch.ones_like(z0))
        z0 = z0.to(x0.device)

        h0 = self.h0_linear(x0)
        h0 = h0.unsqueeze(0).repeat(self.num_layers, 1, 1)
        yt = clf(zz, x0)
        
        xz, yz = [x0], [yt]
        for i in range(noise.size(1)):
            y_noise = torch.cat([z0, yt, noise[:, i]], dim=-1).unsqueeze(1)
            output, h0 = self.rnn(y_noise, h0)
            xt = self.sigmoid(self.linear(output).squeeze())
            # xt = self.linear(output).squeeze()
            yt = clf(zz, xt)

            xz.append(xt)
            yz.append(yt)

        xz = torch.stack(xz, dim=1)
        yz = torch.stack(yz, dim=1)
        return xz, yz, yz.round().detach()


class Discriminator(nn.Module):

    def __init__(self, in_size, hidden_size, num_layers):
        super().__init__()

        self.rnn = nn.GRU(in_size, hidden_size, num_layers, batch_first=True)
        self.linear = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        output, hn = self.rnn(x)
        output = self.linear(output)
        output = self.sigmoid(output)
        return output


class DistributionDiscriminator(nn.Module):

    def __init__(self, hiddens):
        super().__init__()

        layers = []
        for in_dim, out_dim in zip(hiddens[:-1], hiddens[1:]):
            layers.append(nn.Linear(in_dim, out_dim))
            layers.append(nn.LeakyReLU(0.2))
        layers.pop()
        self.model = nn.Sequential(*layers)
    
    def forward(self, x):
        output = self.model(x)
        return output

In [43]:
print(y.dtype)

torch.float32


In [50]:
z_train, z_test, x_train, x_test, y_train, y_test = train_test_split(z, x, y, test_size=test_size, random_state=10)
z_train, z_valid, x_train, x_valid, y_train, y_valid = train_test_split(z_train, x_train, y_train, test_size=valid_size, random_state=10)
train_data = TensorDataset(tensor(z_train), tensor(x_train), tensor(y_train))
valid_data = TensorDataset(tensor(z_valid), tensor(x_valid),tensor(y_valid))
test_data = TensorDataset(tensor(z_test), tensor(x_test), tensor(y_test))
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False)
valid_loader = DataLoader(valid_data, batch_size=len(valid_data), shuffle=False)
test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)

In [57]:
generator = Generator(x_dim, g_hidden_size, g_num_layers)
generator.to(device)

Generator(
  (h0_linear): Linear(in_features=2, out_features=64, bias=True)
  (rnn): GRU(5, 64, num_layers=2, batch_first=True)
  (linear): Linear(in_features=64, out_features=2, bias=True)
  (sigmoid): Sigmoid()
)

In [58]:
discriminator = Discriminator(x_dim, d_hidden_size, d_num_layers)
discriminator.to(device)

Discriminator(
  (rnn): GRU(2, 64, num_layers=2, batch_first=True)
  (linear): Linear(in_features=64, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [54]:
clf = Classifier(c_hiddens)
clf.to(device)

In [62]:
enumerate(train_loader, start=0)

<enumerate at 0x20887dee3c0>

In [72]:
train_gan(train_loader, clf, generator, discriminator, gan_epochs, device)

ValueError: not enough values to unpack (expected 3, got 2)