In [9]:
import numpy as np
import plotly.graph_objects as go
import torch
import scipy as sp
from torch.utils.data import Dataset
import os
ROOT = os.path.join("./")

In [10]:
# Toy Dataset
def N(u, var):
  return np.random.normal(u, np.sqrt(var))

def sample(func, var, minx, maxx, pts, sparsity=0.5):
  rng = np.random.default_rng()
  res = (maxx-minx) / pts

  all_x = []
  for i in range(pts):
    if rng.random() <= sparsity:
      continue

    all_x.append(minx + i*res)



  all_y = [func(x) for x in all_x]

  valrange = max(all_y) - min(all_y)
  all_y = [N(y, var*valrange) for y in all_y]

  all_y = sp.stats.zscore(all_y)
  return (all_x, all_y)

def pos_encoding(unsorted):
  unsorted = np.array(unsorted)
  sorted = unsorted[unsorted[:, -1].argsort()] # sort ascending

  times = [0] + list(sorted[:, -1])
  pos = np.array([times[i] - times[i-1] for i in range(1, len(times))])

  sorted[:, -1] = pos # change abs time to positional encoding
  return torch.tensor(sorted)
def n_examples(fn_gen, var, minx, maxx, pts, sparsity, N):
  ex = []

  fn = fn_gen()

  for i in range(N):
    x, y = sample(fn, var, minx, maxx, pts, sparsity)
    ts = []
    for j in range(len(x)):
      xpt = x[j]
      ypt = y[j]
      ts.append(np.array([ypt - np.random.rand()*var*10, np.random.rand()*var*10, ypt + np.random.rand()*var*10, np.random.rand()*var*10, xpt/np.max(x), xpt]))

    ts = pos_encoding(ts)

    ts = torch.tensor(ts)

    ex.append(ts)

  return ex

class PseudoSet(Dataset):
  def __init__(self, classes):
    self.classes = classes

    self.all = []
    for i, class_ in enumerate(self.classes):
      label = torch.zeros(len(self.classes))
      label[i] = 1
      for ex in class_:
        self.all.append((ex, label))


  def __getitem__(self, idx):
    return self.all[idx]

  def __len__(self):
    return len(self.all)
num_ex = 100
var = 0.005

s, u = np.random.rand()*75, np.random.rand()*310
norm = lambda s, u, x: (1 / np.sqrt(s**2 * 2 * 3.141)) * np.exp(-0.5 * ((x - u) / s)**2)

def linefn_gen():
  m = np.random.rand()*100
  return lambda x: m * x

def logfn_gen():
  a = np.random.rand()*2
  b = np.random.rand()*75
  return lambda x: a * np.log(b * x)

def bellfn_gen():
  s = np.random.rand()*12
  u = (np.random.rand() - 0.5) * 6
  norm = lambda s, u, x: (1 / np.sqrt(s**2 * 2 * 3.141)) * np.exp(-0.5 * ((x - u) / s)**2)
  return lambda x: norm(s, u, x)


lines = n_examples(linefn_gen, 5*var, 0, 1000, 267, 0.65, num_ex)
logs = n_examples(logfn_gen, var, 0.1, 4, 267, 0.65, num_ex)
bells = n_examples(bellfn_gen, 0.01*var, -10, 10, 267, 0.65, num_ex)

def displayex(ex):
  fig = go.Figure()
  tr1 = go.Scatter(
      x=ex[:, -2],
      y=ex[:, 0],
      mode='markers',
      marker=dict(size=5, opacity=.75)
  )
  tr2 = go.Scatter(
      x=ex[:, -2],
      y=ex[:, 2],
      mode='markers',
      marker=dict(size=5, opacity=.75)
  )

  fig.add_trace(tr1)
  fig.add_trace(tr2)

  fig.show()

displayex(lines[15])
displayex(logs[15])
displayex(bells[15])


trainsplit = 0.7

splitidx = int(0.7 * num_ex)

train = PseudoSet((lines[:splitidx], logs[:splitidx], bells[:splitidx]))
valid = PseudoSet((lines[splitidx:], logs[splitidx:], bells[splitidx:]))

print(train[0][0].shape)

with open(ROOT + "processed_datasets/toy_data_train.pt", "wb") as f:
  torch.save(train, f)
with open(ROOT + "processed_datasets/toy_data_valid.pt", "wb") as f:
  torch.save(valid, f)




To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



torch.Size([91, 6])
