In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import torch

In [None]:
# Choose the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

In [None]:
# Echo the selected backend string ('cuda', 'mps' or 'cpu') as the cell output.
device

In [None]:
# Build the synthetic training set: column 0 holds x drawn uniformly from
# [0, 50), column 1 holds y = 1.08 ** x. The seed makes the draw repeatable.
torch.manual_seed(0)
observations = 2048
train_data = torch.zeros(observations, 2)
x_column = train_data[:, 0]  # view into train_data; writes land in place
y_column = train_data[:, 1]  # view into train_data; writes land in place
x_column.copy_(50 * torch.rand(observations))
y_column.copy_(1.08 ** x_column)

In [None]:
# Scatter the (x, 1.08**x) training pairs to confirm the shape of the curve.
fig = plt.figure(dpi=100, figsize=(8, 6))
plt.plot(train_data[:, 0], train_data[:, 1], '.', c='r')
plt.xlabel('values of x', fontsize=15)
plt.ylabel('values of $y=1.08^x$', fontsize=15)
# The title was misspelled ("expotential") in the rendered figure.
plt.title('An exponential growth shape', fontsize=20)
# Render explicitly so the cell output is the figure, not the Text repr.
plt.show()

In [None]:
from torch.utils.data.dataloader import DataLoader

In [None]:
batch_size = 128
# The label tensors used for the BCE loss are pre-built with exactly
# `batch_size` rows, so every batch must be full. drop_last=True discards a
# trailing partial batch (none exists for 2048 / 128) instead of letting a
# shape mismatch crash training if either size is changed later.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [None]:
# Pull a single shuffled batch from the loader for inspection.
batch0=next(iter(train_loader))

In [None]:
# Display the sampled batch as the cell output.
batch0

In [None]:
import torch.nn as nn
# Discriminator: takes a 2-feature sample and returns one sigmoid output.
# Three hidden stages (Linear -> ReLU -> Dropout(0.3)) narrow 256 -> 128 -> 64
# before the final Linear(64, 1) + Sigmoid head. Layers are created in the
# same order as a literal listing, so seeded initialisation is unchanged.
D = nn.Sequential(
    *(
        layer
        for fan_in, fan_out in ((2, 256), (256, 128), (128, 64))
        for layer in (nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3))
    ),
    nn.Linear(64, 1),
    nn.Sigmoid(),
).to(device)

In [None]:
# Generator: maps a 2-dimensional noise vector to a 2-feature sample through
# two ReLU hidden layers (16 then 32 units) and a linear output layer with no
# activation.
G = nn.Sequential(
    *(
        layer
        for fan_in, fan_out in ((2, 16), (16, 32))
        for layer in (nn.Linear(fan_in, fan_out), nn.ReLU())
    ),
    nn.Linear(32, 2),
).to(device)

In [None]:
# Binary cross-entropy scores D's sigmoid output against the 0/1 labels.
loss_fn = nn.BCELoss()
# One learning rate shared by both Adam optimizers.
lr = 5e-4
optimD, optimG = (
    torch.optim.Adam(network.parameters(), lr=lr) for network in (D, G)
)

In [None]:
mse = nn.MSELoss()


def performance(fake_samples):
    """Measure how far generated points lie from the curve y = 1.08 ** x.

    Args:
        fake_samples: 2-D tensor whose column 0 is x and column 1 is the
            generated y.

    Returns:
        Scalar tensor: mean squared error between column 1 and 1.08 ** column 0.
    """
    xs, ys = fake_samples[:, 0], fake_samples[:, 1]
    expected_ys = 1.08 ** xs
    return mse(ys, expected_ys)

In [None]:
class EarlyStop:
    """Patience-based early stopping on a metric where lower is better.

    Tracks the best (smallest) value seen so far and counts consecutive calls
    that fail to improve on it.
    """

    def __init__(self, patience=1000):
        """
        Args:
            patience: number of consecutive non-improving calls to `stop`
                after which `stop` returns True.
        """
        self.patience = patience
        self.steps = 0  # consecutive calls without improvement
        self.min_gdif = float('inf')  # best metric value seen so far

    def stop(self, gdif):
        """Record one metric value and report whether training should stop.

        Args:
            gdif: latest metric value (smaller is better).

        Returns:
            True once `patience` consecutive calls have failed to beat the
            best value, otherwise False.
        """
        if gdif < self.min_gdif:
            self.min_gdif = gdif
            self.steps = 0
        else:
            # A plain `else` also counts NaN as "no improvement"; with
            # `elif gdif >= ...` a NaN metric matched neither branch, so a
            # diverged run could never exhaust its patience.
            self.steps += 1
        return self.steps >= self.patience


stopper=EarlyStop()

In [None]:
# Fixed BCE targets, one row per sample in a full batch: 1 = real, 0 = fake.
_label_shape = (batch_size, 1)
real_labels = torch.ones(_label_shape).to(device)
fake_labels = torch.zeros(_label_shape).to(device)

In [None]:
def train_D_on_real(real_samples):
    """Run one discriminator update on a batch of real samples.

    Args:
        real_samples: batch of real samples; it is moved to `device` here.

    Returns:
        The BCE loss tensor of D's predictions against `real_labels`.
    """
    optimD.zero_grad()
    predictions = D(real_samples.to(device))
    real_loss = loss_fn(predictions, real_labels)
    real_loss.backward()
    optimD.step()
    return real_loss

In [None]:
def train_D_on_fake():
    """Run one discriminator update on a freshly generated fake batch.

    Draws `batch_size` 2-dimensional standard-normal noise vectors, passes
    them through G, and trains D to score the result as fake.

    Returns:
        The BCE loss tensor of D's predictions against `fake_labels`.
    """
    noise = torch.randn((batch_size, 2)).to(device)
    # Only D is updated in this step, so G's forward pass needs no autograd
    # graph. Without this, backward() would also compute gradients for G's
    # parameters, which are never used by an optimizer step here.
    with torch.no_grad():
        fake_samples = G(noise)
    optimD.zero_grad()
    output_D = D(fake_samples)
    loss_D = loss_fn(output_D, fake_labels)
    loss_D.backward()
    optimD.step()
    return loss_D

In [None]:
def train_G():
    """Run one generator update.

    Generates a batch from standard-normal noise and steps G so that D's
    output on that batch moves toward the "real" label.

    Returns:
        Tuple (loss_G, fake_samples): the BCE loss tensor against
        `real_labels`, and the generated batch (still attached to the graph).
    """
    optimG.zero_grad()
    latent = torch.randn((batch_size, 2)).to(device)
    generated = G(latent)
    generator_loss = loss_fn(D(generated), real_labels)
    generator_loss.backward()
    optimG.step()
    return generator_loss, generated

In [None]:
import os
# Ensure the output directory for saved plots and the exported model exists;
# exist_ok=True makes re-running this cell a no-op.
os.makedirs('files', exist_ok=True)

In [None]:
def test_epoch(epoch, gloss, dloss, n, fake_samples):
    """Periodically report losses and plot generated versus real samples.

    Acts only on the first epoch and on every 100th epoch; otherwise it does
    nothing.

    Args:
        epoch: zero-based epoch index.
        gloss: accumulated generator loss for the epoch (must support `.item()`).
        dloss: accumulated discriminator loss for the epoch (must support `.item()`).
        n: divisor used to average the accumulated losses.
        fake_samples: generated batch to plot; columns 0 and 1 are x and y.

    Side effects:
        Prints the averaged losses, shows a figure, and saves it to
        files/p{epoch+1}.png.
    """
    if epoch == 0 or (epoch + 1) % 100 == 0:
        # Guard against a zero divisor (e.g. a single-batch epoch).
        divisor = max(n, 1)
        g = gloss.item() / divisor
        d = dloss.item() / divisor
        # The report line was previously printed twice.
        print(f"at epoch {epoch+1}, G loss: {g}, D loss {d}")
        fake = fake_samples.detach().cpu().numpy()
        fig = plt.figure(dpi=80)
        plt.plot(fake[:, 0], fake[:, 1], "*", c="g",
                 label="generated samples")
        plt.plot(train_data[:, 0], train_data[:, 1], ".", c="r",
                 alpha=0.1, label="real samples")
        plt.title(f"epoch {epoch+1}")
        plt.xlim(0, 50)
        plt.ylim(0, 50)
        plt.legend()
        plt.savefig(f"files/p{epoch+1}.png")
        plt.show()
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

In [None]:
# Adversarial training: for every real batch, update D on real data, update D
# on generated data, then update G. After each epoch, report progress and
# check the early-stopping criterion on the last generated batch.
for epoch in range(10000):
    gloss = 0
    dloss = 0
    # start=1 leaves `n` equal to the number of batches after the loop, so the
    # averages computed in test_epoch divide by the batch count, not count - 1.
    for n, real_samples in enumerate(train_loader, start=1):
        # detach() keeps the running totals from holding on to each step's
        # autograd history.
        dloss += train_D_on_real(real_samples).detach()
        dloss += train_D_on_fake().detach()
        loss_G, fake_samples = train_G()
        gloss += loss_G.detach()
    test_epoch(epoch, gloss, dloss, n, fake_samples)
    # The metric is only read as a float, so no graph is needed.
    with torch.no_grad():
        gdif = performance(fake_samples).item()
    if stopper.stop(gdif):
        break

In [None]:
# Compile the generator to TorchScript so it can be saved and reloaded
# without the Python code that defines it.
scripted = torch.jit.script(G)

In [None]:
# Write the scripted generator to the 'files' output directory.
scripted.save('files/my_exponential.pt')