# Approximating the Triangle Function

f(x) = max(1 - abs(x), 0)


In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from functions.hat_function import hat_function, HatApproxNet 

In [None]:
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
path = "saved_models/"
model_name = "hat_function.pth"
load_model = True

# A hidden_dim argument can be passed to HatApproxNet, but be careful when
# loading and saving weights (presumably the checkpoint only fits a model
# built with the same hidden_dim -- TODO confirm against HatApproxNet).
model = HatApproxNet().to(device)

# Xavier-normal initialisation for the weight matrix of every linear layer;
# biases keep their default initialisation.
for m in model.modules():
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight)

# Restore previously saved weights when requested. os.path.isfile() is used
# rather than os.listdir() so that a missing checkpoint directory simply means
# "no checkpoint yet" instead of raising FileNotFoundError.
checkpoint_file = os.path.join(path, model_name)
if load_model and os.path.isfile(checkpoint_file):
    model.load_state_dict(torch.load(checkpoint_file, map_location=device))

optimizer = torch.optim.Adam(model.parameters(), lr=0.0004)

In [None]:
iterations = 300000

# The evaluation cells further down switch the model to eval() mode; put it
# back into training mode so re-running this cell always trains correctly.
model.train()

# range(iterations + 1) so that the final iteration is also reported below.
for iteration in range(iterations + 1):
    # Fresh batch of 32 inputs drawn uniformly from [-3.01, 3.01).
    data = torch.Tensor(np.random.uniform(-3.01, 3.01, size=(32, 1))).to(device)
    # Ground-truth values, computed one sample at a time by hat_function.
    targets = torch.Tensor([hat_function(i) for i in data]).view(-1, 1).to(device)

    model.zero_grad()
    predictions = model(data)

    # Mean squared error between targets and predictions.
    loss = torch.mean((targets - predictions) ** 2)
    loss.backward()
    optimizer.step()

    if iteration % 2000 == 0:
        # .item() prints a plain Python float instead of the tensor repr
        # (device, grad_fn, ...).
        print(iteration, loss.item())

In [None]:
# Make sure the checkpoint directory exists before writing; torch.save does
# not create missing parent directories.
if path:
    os.makedirs(path, exist_ok=True)
torch.save(model.state_dict(), path + model_name)

Here's a quick plot comparing the model's predictions with the actual function.

In [None]:
# Evaluate the model on 1024 random inputs and plot its predictions against
# the true function values.
with torch.no_grad():
    model.eval()
    # Use the configured `device` rather than a hard-coded .cuda() so this
    # cell also runs on CPU-only machines.
    A = torch.Tensor(np.random.uniform(-3.01, 3.01, size=(1024, 1))).to(device)
    T = torch.Tensor([hat_function(i) for i in A]).view(-1, 1).to(device)
    B = model(A)

# Mean squared error between the true values and the predictions.
MSE = torch.mean((T - B) ** 2)
f = plt.figure(figsize=(14, 14))
# Tensors are moved to the CPU before being handed to matplotlib.
plt.plot(A.cpu(), B.cpu(), 'C1.', label="Model")
plt.plot(A.cpu(), T.cpu(), 'C0.', alpha=0.4, label="Actual")
plt.title("MSE error of " + str(MSE.cpu().numpy()))
plt.legend()
plt.show()


The cell below measures how much slower a forward pass through the trained network is than a built-in torch functional such as `F.relu` or `tanh`.
The learned triangle/hat approximation takes around **4** times as long as the built-in call, but this overhead may be negligible in the larger scheme of things.

In [None]:
import time
import datetime

# Compare the wall-clock time of one forward pass through the model with one
# call to F.relu on the same batch of 1024 inputs.
with torch.no_grad():
    model.eval()

    A = torch.Tensor(np.random.uniform(-3.01, 3.01, size=(1024, 1))).to(device)
    # T is not used by the timings below; it is kept so A and T stay paired.
    T = torch.Tensor([hat_function(i) for i in A]).view(-1, 1).to(device)

    # CUDA kernels are launched asynchronously, so the device has to be
    # synchronised before reading the clock or only launch overhead is timed.
    on_gpu = device.type == "cuda"

    if on_gpu:
        torch.cuda.synchronize()
    s = time.perf_counter()
    # Time the batch created above, not the leftover `data` batch from the
    # training loop (which may not even exist if training was skipped).
    B = model(A)
    if on_gpu:
        torch.cuda.synchronize()
    print("Time taken by model :", datetime.timedelta(seconds=time.perf_counter() - s))

    if on_gpu:
        torch.cuda.synchronize()
    s = time.perf_counter()
    B = F.relu(A)
    if on_gpu:
        torch.cuda.synchronize()
    print("Time taken by F :", datetime.timedelta(seconds=time.perf_counter() - s))