In [1]:
import torch
import torch.nn as nn
from torch.utils.data import IterableDataset, DataLoader
import matplotlib.pyplot as plt

In [2]:
def complex_func(x, t):
    return x**3*t + 2*x**2*torch.sqrt(t) + torch.sin(x) + 3*torch.cos(t)

In [3]:
class MLP(nn.Module):
    def __init__(self,
                 input_dim: int,
                 output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(1, 64)
        
        self.fc3 = nn.Linear(128, 128)
        self.fc4 = nn.Linear(128, 64)
        self.fc5 = nn.Linear(64, output_dim)
        self.act = nn.ReLU()

    def forward(self, x, t):
        x = self.act(self.fc1(x))
        t = self.act(self.fc2(t))
        xt = torch.cat((x, t), dim=1)
        xt = self.act(self.fc3(xt))
        xt = self.act(self.fc4(xt))
        return self.fc5(xt)

In [4]:
class MyIterableDataset(IterableDataset):
    def __init__(self):
        super().__init__()
        self.x_dim = 64
        self.t_dim = 1
    
    def generator(self):
        while True:
            x = torch.rand((self.x_dim))
            t = torch.rand((self.t_dim))
            y = complex_func(x, t)
            yield x, t, y
    
    def __iter__(self):
        return iter(self.generator())

In [6]:
dataset = MyIterableDataset()
ds_loader = DataLoader(dataset, batch_size=4)

mlp = MLP(input_dim=64, output_dim=64)
mlp.train()
optimizer = torch.optim.Adam(mlp.parameters(), lr=0.001)
loss_func = nn.MSELoss()

for epoch in range(100):
    for i, data in enumerate(ds_loader):
        x, t, y = data
        optimizer.zero_grad()
        pred = mlp(x, t)
        loss = loss_func(pred, y)
        loss.backward()
        optimizer.step()
        if i > 1000:
            break
    print(f"Epoch: {epoch}, Loss: {loss.item()}")

Epoch: 0, Loss: 0.8135509490966797
Epoch: 1, Loss: 0.5046282410621643
Epoch: 2, Loss: 0.6681536436080933
Epoch: 3, Loss: 0.6689860224723816
Epoch: 4, Loss: 0.40377163887023926
Epoch: 5, Loss: 0.5392351746559143
Epoch: 6, Loss: 0.6561574935913086
Epoch: 7, Loss: 0.497970312833786
Epoch: 8, Loss: 0.4703254997730255
Epoch: 9, Loss: 0.4264187216758728
Epoch: 10, Loss: 0.33536186814308167
Epoch: 11, Loss: 0.4799478352069855
Epoch: 12, Loss: 0.2943488359451294
Epoch: 13, Loss: 0.23191973567008972
Epoch: 14, Loss: 0.6100165843963623
Epoch: 15, Loss: 0.4604904353618622
Epoch: 16, Loss: 0.2953226566314697
Epoch: 17, Loss: 0.36538827419281006
Epoch: 18, Loss: 0.4199337065219879
Epoch: 19, Loss: 0.41898396611213684
Epoch: 20, Loss: 0.4573240876197815
Epoch: 21, Loss: 0.5226966142654419
Epoch: 22, Loss: 0.30232709646224976
Epoch: 23, Loss: 0.5510480403900146
Epoch: 24, Loss: 0.48824018239974976
Epoch: 25, Loss: 0.36434581875801086
Epoch: 26, Loss: 0.31575390696525574
Epoch: 27, Loss: 0.29422572255

In [7]:
x = torch.rand((1, 64))
t = torch.rand((1, 1))
y = complex_func(x.squeeze(), t.squeeze())
with torch.no_grad():
    pred = mlp(x, t)
    print(f"y: {y}, pred: {pred}")
    print(f"y-pred: {y-pred}")

y: tensor([3.1364, 3.1241, 3.1781, 4.0548, 4.1723, 4.1959, 4.2630, 3.1958, 3.4458,
        4.1777, 4.3326, 3.2465, 4.4264, 3.0197, 3.3294, 3.6239, 3.0075, 3.2230,
        4.2677, 4.0143, 4.3521, 3.1582, 3.2274, 3.4675, 3.8911, 4.1819, 3.5897,
        3.2023, 4.1943, 3.6479, 3.6529, 3.2884, 3.6018, 4.3290, 3.8149, 4.2488,
        3.0523, 3.1200, 3.1330, 4.2276, 3.2416, 3.8995, 3.4058, 4.1589, 3.6061,
        3.3893, 3.4294, 3.5833, 3.3441, 3.4656, 3.3275, 3.2203, 3.1620, 3.5418,
        3.8082, 4.4065, 4.0568, 3.2071, 3.2362, 3.2202, 3.0992, 4.3898, 3.0008,
        4.1357]), pred: tensor([[3.6153, 3.1494, 3.4698, 3.9832, 3.6665, 4.1834, 4.1734, 3.7083, 3.3788,
         4.7665, 3.8076, 3.1714, 3.9558, 3.3482, 3.6533, 3.4255, 2.8138, 3.2836,
         4.0509, 3.6887, 3.5982, 3.8802, 2.9503, 3.3829, 3.9513, 3.8507, 3.9052,
         3.0377, 4.2483, 3.6220, 3.7094, 3.5375, 3.7382, 3.6356, 3.9911, 3.8013,
         3.4487, 3.1081, 3.2295, 4.1512, 3.5629, 4.0895, 3.6865, 3.7793, 3.6243,
        