<a href="https://colab.research.google.com/github/jhmuller/nn_sort/blob/main/sorting.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import numpy as np
from torch import nn
import math
from itertools import permutations
from torch.utils.data import Dataset, DataLoader


In [37]:


def sort_seq(x):
  lst = x.copy()
  N = len(lst)
  swaped = True
  seq = []
  while swaped:
    swaped = False
    for i in range(N-1):
      if lst[i] > lst[i+1]:
        lst[i], lst[i+1] = lst[i+1], lst[i]
        seq.append(i)
        swaped = True
  while len(seq) < N*(N-1)/2:
    seq.append(N-1)
  return lst, seq

def apply_seq(x_in, seq):
  x_out = x_in.copy()
  N = len(x_in)
  for i in seq:
    if i >= N-1:
      continue
    x_out[i], x_out[i+1] = x_out[i+1], x_out[i]
  return x_out

def lst_diff(true, pred):
  diff = 0
  for i, p in enumerate(pred):
    diff += abs(p - true[i])
  return diff

def factorial(n):
  res = n
  for i in range(2, n):
    res = res*i
  return res

def generate_input_data(m=5):
  res = []
  w = int((m*(m-1))/2)
  first = list(range(m))
  iter = permutations(first)
  for i, x in enumerate(iter):
    _, y = sort_seq(list(x))
    mat = np.zeros((w, m))
    for j, p in enumerate(y):
      mat[j, p] = 1
    item = {}
    item["list"] = torch.Tensor(x)
    item["mat"] = mat
    res.append(item)
  print(f" # items: {len(res)}")
  return res


class SortDataset(Dataset):
    """Sort dataset."""

    def __init__(self, data):
      self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return(self.data[idx])

sdata = generate_input_data(m=8)
N = len(sdata)
cutoff = int(N*.8)
train = sdata[:cutoff]
test = sdata[cutoff:]
train_dataset = SortDataset(train)
train_dataloader = DataLoader(train_dataset, batch_size=4,
                        shuffle=True, num_workers=0)
test_dataset = SortDataset(test)
test_dataloader = DataLoader(test_dataset, batch_size=4,
                        shuffle=True, num_workers=0)



 # items: 40320


In [32]:
len(train)
len(test)
factorial(8)
train[3]
next(iter(train_dataloader))
train_dataset[2]
for i in [1, 2]:
  print(next(iter(train_dataloader)))

{'list': tensor([[1., 5., 2., 3., 6., 4., 0., 7.]]), 'mat': tensor([[[0., 1., 0., 0., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0., 0.],
         [0., 0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 1., 0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0., 0., 0.,

In [38]:
import torch.nn.functional as F

class SortNet(nn.Module):
    def __init__(self, lst_len, hidden=100):
        """
        """
        self.lst_len = lst_len
        self.seq_len = int((lst_len * (lst_len-1))/2)
        out_len = (lst_len) * self.seq_len        
        print(lst_len, out_len)
        super().__init__()
        self.linear1 = nn.Linear(lst_len, hidden) 
        self.linear2 = nn.Linear(hidden, hidden)        
        self.linear3 = nn.Linear(hidden, out_len)

    def forward(self, x):
            """
            In the forward function we accept a Variable of input data and we must 
            return a Variable of output data. We can use Modules defined in the 
            constructor as well as arbitrary operators on Variables.
            """
            h_relu = F.relu(self.linear1(x))
            y = torch.sigmoid(self.linear3(h_relu))
            y_pred = y.view((y.shape[0], self.seq_len , self.lst_len))
            y_pred = torch.softmax(y_pred, dim=2)
            return y_pred

In [39]:
print_freq = 20
model = SortNet(8)
loss_fn = nn.MSELoss() #nn.BCELoss() # CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0001)
for ei in range(101):
  #print(f"epoch {ei}")
  for bi, sample in enumerate(train_dataloader):
    #print(sample, "----")
    X = torch.Tensor(sample["list"])

    X = X.to(torch.float32)
 
    y = torch.Tensor(sample["mat"])
    y = y.to(torch.float32)

    y_pred = model(X)           # compute model output
    if ei == 50 and (bi % print_freq == 0):
      pass
    optimizer.zero_grad()       
    loss = loss_fn(y, y_pred)  

    loss.backward()        
    optimizer.step()
  if ei % print_freq == 0: 
    print(f"epoch {ei} batch {bi} loss {loss}")

8 224
epoch 0 batch 8063 loss 0.04185900464653969
epoch 20 batch 8063 loss 0.020469803363084793
epoch 40 batch 8063 loss 0.02370063029229641
epoch 60 batch 8063 loss 0.02776983380317688
epoch 80 batch 8063 loss 0.029141241684556007
epoch 100 batch 8063 loss 0.016244884580373764


In [41]:
with torch.no_grad():
  sum_loss = 0
  model.eval()
  for bi, sample in enumerate(test_dataloader):
    X = torch.Tensor(sample["list"])
    X = X.to(torch.float32) 
    y = torch.Tensor(sample["mat"])
    y = y.to(torch.float32)
    y_pred = model(X)           # compute model output
    if ei == 50 and (bi % print_freq == 0):
      pass     
    loss = loss_fn(y, y_pred)  # calculate loss
    sum_loss += loss
    if bi % print_freq == 0: 
      print(f"batch {bi} loss {loss} sum_loss {sum_loss}")

batch 0 loss 0.06636456400156021 sum_loss 0.06636456400156021
batch 20 loss 0.07128918915987015 sum_loss 1.5038270950317383
batch 40 loss 0.051932137459516525 sum_loss 2.8890087604522705
batch 60 loss 0.09167414158582687 sum_loss 4.260381698608398
batch 80 loss 0.0566142201423645 sum_loss 5.769962310791016
batch 100 loss 0.07950929552316666 sum_loss 7.11648416519165
batch 120 loss 0.0736360102891922 sum_loss 8.588957786560059
batch 140 loss 0.06706433743238449 sum_loss 10.097452163696289
batch 160 loss 0.07332582026720047 sum_loss 11.583922386169434
batch 180 loss 0.07947786897420883 sum_loss 13.059700965881348
batch 200 loss 0.0710444375872612 sum_loss 14.486085891723633
batch 220 loss 0.060264237225055695 sum_loss 15.855779647827148
batch 240 loss 0.06487880647182465 sum_loss 17.207229614257812
batch 260 loss 0.08064998686313629 sum_loss 18.72004508972168
batch 280 loss 0.0693620815873146 sum_loss 20.167112350463867
batch 300 loss 0.06420190632343292 sum_loss 21.640092849731445
batch

In [47]:
test = torch.Tensor([4, 3, 2, 1])
test = test.unsqueeze(0)
sample = test_dataset[0]
X = sample["list"]
X = X.unsqueeze(0)
res = model(X)
print(X)
print(res)

tensor([[6., 2., 5., 7., 0., 1., 3., 4.]])
tensor([[[1.0000e+00, 5.6751e-34, 0.0000e+00, 1.8937e-23, 0.0000e+00,
          0.0000e+00, 1.6461e-34, 3.6827e-10],
         [2.9217e-10, 1.0000e+00, 3.9654e-19, 4.2872e-03, 1.9867e-14,
          3.0894e-25, 1.6978e-29, 0.0000e+00],
         [7.9651e-27, 2.4799e-21, 3.2102e-07, 4.4458e-01, 4.4394e-02,
          3.5389e-08, 1.1289e-17, 0.0000e+00],
         [7.8460e-22, 1.6747e-17, 1.3335e-09, 1.5576e-06, 3.3859e-01,
          4.5228e-02, 2.0215e-06, 0.0000e+00],
         [2.8726e-15, 2.9572e-12, 1.3286e-06, 1.8169e-11, 1.4406e-06,
          4.7782e-01, 1.3665e-01, 0.0000e+00],
         [2.4061e-12, 2.6352e-10, 2.5904e-02, 4.4076e-05, 2.6148e-09,
          2.7311e-10, 9.2863e-01, 1.1175e-29],
         [7.0916e-13, 3.6836e-09, 3.9730e-01, 7.8059e-02, 3.4024e-04,
          1.7362e-06, 9.2499e-10, 4.6799e-25],
         [1.0741e-07, 1.2476e-22, 3.0061e-02, 5.0517e-01, 7.0640e-02,
          2.5558e-03, 7.9308e-09, 3.0764e-21],
         [1.0634e-04,

In [48]:
torch.sum(res, dim=2)

tensor([[1.0000, 1.0043, 0.4890, 0.3838, 0.6145, 0.9546, 0.4757, 0.6084, 0.7524,
         1.1967, 0.9953, 0.9549, 0.9978, 1.0657, 0.5392, 0.6771, 1.1568, 1.2143,
         1.0220, 1.0006, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
         1.0000]], grad_fn=<SumBackward1>)

In [11]:
class Square(nn.Module):
    """ Custom Linear layer but mimics a standard linear layer """
    def __init__(self, size_in, size_out):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        weights = torch.Tensor(size_out, size_in)
        self.weights = nn.Parameter(weights)  # nn.Parameter is a Tensor that's a module parameter.
        bias = torch.Tensor(size_out)
        self.bias = nn.Parameter(bias)

        # initialize weights and biases
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5)) # weight init
        fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weights)
        bound = 1 / math.sqrt(fan_in)
        nn.init.uniform_(self.bias, -bound, bound)  # bias init

    def forward(self, x):
        w_times_x= torch.mm(x, self.weights.t())
        # 
        print(f" x shape {x.shape} weights shape {self.weights.shape} res shape {w_timex_x.shape}")
        return torch.add(w_times_x, self.bias) 