In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

In [2]:
def make_moons(n_samples=1000, noise=0.1, generator=None):
    """Generate a shuffled 2-D "two moons" binary classification dataset.

    The upper moon is a unit half-circle (cos t, sin t) for t in [0, pi];
    the lower moon is the mirrored half-circle (1 - cos t, 0.5 - sin t).

    Args:
        n_samples: Total number of points. The upper moon receives
            ``n_samples // 2`` points and the lower moon the remainder, so
            odd counts are supported.
        noise: Scale factor applied to standard-normal noise that is added
            to every coordinate.
        generator: Optional ``torch.Generator`` used for both the noise and
            the shuffle. ``None`` (the default) uses PyTorch's global RNG,
            which matches the previous behaviour.

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape ``(n_samples, 2)`` and ``y`` is
        an int64 tensor of shape ``(n_samples,)`` holding 0 for the upper moon
        and 1 for the lower moon. Rows are returned in a random order.
    """
    n_upper = n_samples // 2
    n_lower = n_samples - n_upper

    # First moon (upper)
    theta_upper = torch.linspace(0, torch.pi, n_upper)
    upper = torch.column_stack([torch.cos(theta_upper), torch.sin(theta_upper)])

    # Second moon (lower, shifted right by 1 and up by 0.5)
    theta_lower = torch.linspace(0, torch.pi, n_lower)
    lower = torch.column_stack([1 - torch.cos(theta_lower), 0.5 - torch.sin(theta_lower)])

    X = torch.vstack([upper, lower])
    y = torch.cat([torch.zeros(n_upper), torch.ones(n_lower)])

    # Add noise. torch.randn (rather than randn_like) is used because it
    # accepts an explicit generator.
    X = X + torch.randn(X.shape, dtype=X.dtype, generator=generator) * noise

    # Shuffle points and labels with the same permutation.
    idx = torch.randperm(n_samples, generator=generator)
    return X[idx], y[idx].long()

# X, y = make_moons(n_samples=1000, noise=0.1)
# plt.scatter(X[:, 0].numpy(), X[:, 1].numpy(), s=40, c=y.numpy(), cmap='viridis')
# plt.xlabel("X")
# plt.ylabel("Y")
# plt.show()


In [3]:
class MLP(nn.Module):
    """One-hidden-layer perceptron: Linear -> Tanh -> Linear.

    ``forward`` returns raw logits and, when targets are supplied, the
    binary cross-entropy loss computed directly from those logits.
    """

    def __init__(self, n_in, n_hid, n_out):
        """Build the layers.

        Args:
            n_in: Number of input features.
            n_hid: Width of the hidden layer.
            n_out: Number of output logits.
        """
        super().__init__()
        self.ln1 = nn.Linear(n_in, n_hid)
        self.act1 = nn.Tanh()
        self.ln2 = nn.Linear(n_hid, n_out)

    def forward(self, x, targets=None):
        """Run the network and optionally compute the loss.

        Args:
            x: Input tensor whose last dimension is ``n_in``.
            targets: Optional tensor passed as the target to
                ``F.binary_cross_entropy_with_logits`` together with the
                logits.

        Returns:
            ``(logits, loss)``; ``loss`` is ``None`` when ``targets`` is None.
        """
        hidden = self.act1(self.ln1(x))
        logits = self.ln2(hidden)

        loss = None
        if targets is not None:
            loss = F.binary_cross_entropy_with_logits(logits, targets)
        return logits, loss

In [4]:
class DataLoader:
    """Serves consecutive mini-batches from a fixed two-moons dataset.

    The dataset is generated once at construction time. Batches are read in
    order; once a further full batch would run past the end of the data, the
    read position rewinds to the start (any leftover tail is skipped).
    """

    def __init__(self, batch_size, dataset_size):
        """Generate ``dataset_size`` points and start reading at index 0."""
        self.batch_size = batch_size
        self.dataset_size = dataset_size
        points, labels = make_moons(dataset_size, noise=0.05)
        self.X = points
        # Labels become a float column of shape (dataset_size, 1).
        self.Y = labels.float().unsqueeze(1)
        self.pos = 0

    def get_batch(self):
        """Return the next ``(x, y)`` slice and advance the read position."""
        start = self.pos
        stop = start + self.batch_size
        window = slice(start, stop)
        batch = (self.X[window], self.Y[window])

        # Keep advancing only while another full batch still fits.
        self.pos = stop if stop + self.batch_size <= self.dataset_size else 0
        return batch

In [5]:
# Reproducibility: seed PyTorch's global RNG before anything random happens
# in this cell (weight initialisation, dataset noise and shuffle), so that
# Restart & Run All reproduces the same run.
torch.manual_seed(0)

# Model: 2 input features -> 16 tanh hidden units -> 1 output logit
model = MLP(2, 16, 1)

# Data Loader: 2048 points in batches of 32 (64 batches per pass)
data_loader = DataLoader(batch_size=32, dataset_size=2048)

# LR
lr = 0.1
# 10 * 1024 optimisation steps = 160 passes over the 64 batches
max_steps = 10*1024

# Optimizer
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

In [6]:
# Training: one mini-batch per optimisation step. The loss of the current
# batch is printed every 1024 steps, starting with step 0.
for step in range(max_steps):
    x_batch, y_batch = data_loader.get_batch()

    # Clear the gradients accumulated by the previous step before backprop.
    optimizer.zero_grad()
    logits, loss = model(x_batch, y_batch)
    loss.backward()
    optimizer.step()

    if not step % 1024:
        print(step, loss.item())

0 0.6722705364227295
1024 0.24732737243175507
2048 0.10327691584825516
3072 0.05463967099785805
4096 0.032960448414087296
5120 0.021808452904224396
6144 0.015560531988739967
7168 0.01176547072827816
8192 0.009294457733631134
9216 0.007591884583234787


In [None]:
# Evaluate on the full dataset without tracking gradients.
with torch.no_grad():
    x_batch, y_batch = data_loader.X, data_loader.Y
    logits, loss = model(x_batch, y_batch)
    # Sigmoid of the logits, flattened to one value per point.
    probs = torch.sigmoid(logits).squeeze(-1)
    # Fraction of points whose thresholded probability matches the label.
    accuracy = ((probs > 0.5) == (y_batch.squeeze(-1) > 0.5)).float().mean().item()

# Colour each point by its sigmoid output. The colour scale is pinned to
# [0, 1] so colours mean the same thing from run to run.
fig, ax = plt.subplots()
points = ax.scatter(
    x_batch[:, 0].numpy(),
    x_batch[:, 1].numpy(),
    s=40,
    c=probs.numpy(),
    cmap='viridis',
    vmin=0.0,
    vmax=1.0,
)
fig.colorbar(points, ax=ax, label="sigmoid(logit)")
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_title(f"MLP predictions (loss {loss.item():.4f}, accuracy {accuracy:.1%})")
plt.show()
print(loss.item())