In [52]:
import torch
from torch import nn
from sklearn.metrics import r2_score

In [53]:
import kfp

In [54]:
class MyMachine(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(2,5),
            nn.ReLU(),
            nn.Linear(5,1)
        )

    def forward(self, x):
        x = self.fc(x)
        return x


def get_dataset():
        X = torch.rand((1000,2))
        x1 = X[:,0]
        x2 = X[:,1]
        y = x1 * x2
        return X, y


def train():
    model = MyMachine()
    model.train()
    X, y = get_dataset()
    NUM_EPOCHS = 1000
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=1e-5)
    criterion = torch.nn.MSELoss(reduction='mean')

    for epoch in range(NUM_EPOCHS):
        optimizer.zero_grad()
        y_pred = model(X)
        y_pred = y_pred.reshape(1000)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        # print(f'Epoch:{epoch}, Loss:{loss.item()}')
    torch.save(model.state_dict(), 'model.h5')
    model_scripted = torch.jit.script(model) # Export to TorchScript
    model_scripted.save('model_scripted.pt') # Save

def test():
    model = MyMachine()
    model.load_state_dict(torch.load("model.h5"))
    model.eval()
    X, y = get_dataset()

    with torch.no_grad():
        y_pred = model(X)
        print(r2_score(y, y_pred))


In [55]:
train()

In [56]:
test()

0.958365948606875


In [57]:
model = MyMachine()
model.load_state_dict(torch.load("model.h5"))
model.eval()

MyMachine(
  (fc): Sequential(
    (0): Linear(in_features=2, out_features=5, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5, out_features=1, bias=True)
  )
)

In [58]:
print(model.state_dict)

<bound method Module.state_dict of MyMachine(
  (fc): Sequential(
    (0): Linear(in_features=2, out_features=5, bias=True)
    (1): ReLU()
    (2): Linear(in_features=5, out_features=1, bias=True)
  )
)>


In [59]:
x = torch.tensor([2.0, 2.0])

In [60]:
print(model(x))

tensor([1.5809], grad_fn=<ViewBackward0>)
