In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import math

In [3]:
# g : R^2 -> R
# Here, each x \in X are \in R^2
def true_g(X):
    return np.sum(X, axis=1)
    # return np.sum(np.square(X), axis=1)
    # return np.where(np.linalg.norm(X, axis=1) < 4, 0, 1)

# Given g, each x \in X are \in R^3
def true_f(g, X):
    g_result = g(X[:,:2])
    return g_result * (1 - X[:,2]) + (10 * g_result) * X[:,2]

In [4]:
# Number of training points will be split_sizes[0] + split_sizes[1]
def get_train_data(split_sizes):
    n0 = split_sizes[0]
    n1 = split_sizes[1]

    X1_01 = np.around(np.random.uniform(0, 5, size=(n0,2)), 6)
    X1_2 = np.zeros((n0, 1))
    X1 = np.hstack((X1_01, X1_2))

    X2_01 = np.around(np.random.uniform(0, 5, size=(n1,2)), 6)
    X2_2 = np.ones((n1, 1))
    X2 = np.hstack((X2_01, X2_2))

    X_train = np.vstack((X1, X2))
    y_train = true_f(true_g, X_train)
    return X_train, y_train

X_train, y_train = get_train_data([5000,5000])
train_data = []
for i in range(len(X_train)):
    train_data.append([X_train[i], y_train[i]])

trainloader = torch.utils.data.DataLoader(train_data, shuffle=True, batch_size=20)

In [5]:
embedding = nn.Embedding(2,1)

In [6]:
with torch.no_grad():
    input = embedding(torch.from_numpy(X_train[:, 2]).long())

input = torch.from_numpy(X_train[:, 2]).long()

In [8]:
model = nn.Sequential(
    nn.Embedding(2,1),
    nn.LayerNorm(1),
    nn.Linear(1,1)
)

labels = torch.from_numpy(X_train[:,2]).unsqueeze(dim=1)
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-6)
loss_values = []
num_epochs = 1000
for epoch in range(num_epochs):
    correct = 0
    optimizer.zero_grad()
    outputs = model(input)
    loss = criterion(outputs, labels)

    loss.backward(retain_graph=True)
    optimizer.step()
    
    output = (outputs > 0.5).float()
    correct += (output == labels).float().sum()

    acc = correct / len(train_data)
    print(f"Epoch {epoch}/{num_epochs}, Loss: {loss.item():.3f}, Accuracy: {acc:.3f}")
    loss_values.append(loss.item())

plt.plot(np.array(np.arange(0, num_epochs)), loss_values)

Epoch 0/1000, Loss: 0.731, Accuracy: 0.500
Epoch 1/1000, Loss: 0.731, Accuracy: 0.500
Epoch 2/1000, Loss: 0.731, Accuracy: 0.500
Epoch 3/1000, Loss: 0.731, Accuracy: 0.500
Epoch 4/1000, Loss: 0.731, Accuracy: 0.500
Epoch 5/1000, Loss: 0.731, Accuracy: 0.500
Epoch 6/1000, Loss: 0.731, Accuracy: 0.500
Epoch 7/1000, Loss: 0.731, Accuracy: 0.500
Epoch 8/1000, Loss: 0.731, Accuracy: 0.500
Epoch 9/1000, Loss: 0.731, Accuracy: 0.500
Epoch 10/1000, Loss: 0.731, Accuracy: 0.500
Epoch 11/1000, Loss: 0.731, Accuracy: 0.500
Epoch 12/1000, Loss: 0.731, Accuracy: 0.500
Epoch 13/1000, Loss: 0.731, Accuracy: 0.500
Epoch 14/1000, Loss: 0.731, Accuracy: 0.500
Epoch 15/1000, Loss: 0.731, Accuracy: 0.500
Epoch 16/1000, Loss: 0.731, Accuracy: 0.500
Epoch 17/1000, Loss: 0.731, Accuracy: 0.500
Epoch 18/1000, Loss: 0.731, Accuracy: 0.500
Epoch 19/1000, Loss: 0.731, Accuracy: 0.500
Epoch 20/1000, Loss: 0.731, Accuracy: 0.500
Epoch 21/1000, Loss: 0.731, Accuracy: 0.500
Epoch 22/1000, Loss: 0.731, Accuracy: 0.50

KeyboardInterrupt: 