In [8]:
import sys
import os

# Put the parent of the current working directory on the import path so that the
# project-local `src.*` modules can be imported (presumably the project root, with
# this notebook one level below it — confirm if the notebook is moved).
module_path = os.path.abspath(os.pardir)
already_importable = module_path in sys.path
if not already_importable:
    sys.path.append(module_path)

In [9]:
import numpy as np
import torch
import torch.optim as optim
from tqdm import tqdm
from src.EquivariantNetwork import EquivariantAutoEncoder
from src.GenerateData import SchoolGenerator, SyntheticData

In [10]:
# Experiment configuration: every tunable value lives here.
N_STUDENT = 200
N_COURSE = 200
N_PROFESSOR = 200
EMBEDDING_DIMS = 5
SPARSITY = 0.5
BATCH_SIZE = 1  # NOTE(review): not used by any cell visible here
SEED = 0

# Seed the global RNGs before generating data and building the network so runs are
# reproducible. NOTE(review): assumes SyntheticData draws from NumPy's and/or torch's
# global RNG — confirm in src.GenerateData.
np.random.seed(SEED)
torch.manual_seed(SEED)

data_generator = SyntheticData((N_STUDENT, N_COURSE, N_PROFESSOR), sparsity=SPARSITY, embedding_dims=EMBEDDING_DIMS)
data = data_generator.data
observed = data_generator.observed
# Complement of the observation masks: the held-out entries used for validation.
# NOTE(review): `~` is a logical NOT only for boolean masks; on 0/1 integer masks it
# would produce -1/-2 — confirm the masks are bool.
missing = {key: ~val for key, val in observed.items()}
schema = data.schema
relations = schema.relations

In [11]:
# Normalize the data, then hide the unobserved entries to form the network's input.
# Normalizing `data_generator.data` (instead of re-assigning `data = data.normalize_data()`)
# keeps this cell idempotent when it is re-run on its own. NOTE(review): this assumes
# `normalize_data()` returns a new object rather than mutating in place — confirm.
data = data_generator.data.normalize_data()
data_hidden = data.mask_data(observed)

In [12]:
# Build the equivariant autoencoder for this schema (it is trained in the loop further down).
net = EquivariantAutoEncoder(schema)

In [13]:
# Loss functions:
def loss_fcn(data_pred, data_true, indices):
    """Masked sum of squared errors per relation, averaged over relations.

    For each relation in the module-level ``relations`` list, the residual
    ``data_pred[rel_id] - data_true[rel_id]`` is multiplied by the mask
    ``indices[rel_id]``, squared and summed; the per-relation sums are then
    averaged over the number of relations.

    Args:
        data_pred: predictions, indexable by relation id.
        data_true: targets, indexable by relation id.
        indices: per-relation masks keyed by relation id (e.g. ``observed`` or
            ``missing``); assumed boolean or 0/1 and broadcastable against the
            data — TODO confirm.

    Returns:
        A 0-dim tensor with the dtype/device of the residuals; differentiable
        when ``data_pred`` requires grad (``.item()`` / ``.backward()`` work).

    Raises:
        ValueError: if ``relations`` is empty (the average would be 0/0 = NaN).
    """
    if len(relations) == 0:
        raise ValueError("loss_fcn: no relations to average over")
    per_relation = []
    for relation in relations:
        rel_id = relation.id
        # Mask the residual once rather than masking prediction and target separately.
        masked_residual = indices[rel_id] * (data_pred[rel_id] - data_true[rel_id])
        per_relation.append(torch.sum(masked_residual ** 2))
    # Averaging the stacked sums keeps the result on the data's own device and dtype,
    # instead of accumulating into a fresh CPU float32 `torch.zeros(1)`.
    return torch.stack(per_relation).mean()

def per_rel_loss(data_pred, data_true):
    """Unmasked mean squared error of each relation, keyed by relation id.

    Iterates the module-level ``relations`` list; every entry of a relation
    (observed or not) contributes to that relation's mean.
    """
    return {
        rel.id: torch.mean((data_pred[rel.id] - data_true[rel.id]) ** 2)
        for rel in relations
    }


# Adam over all of the autoencoder's trainable parameters.
learning_rate = 1e-4
optimizer = optim.Adam(params=net.parameters(), lr=learning_rate)

In [14]:
# Full-batch training: fit the observed entries, monitor the held-out (missing) ones.
epochs = 1000
progress = tqdm(range(epochs), desc="Loss: ", position=0, leave=True)
for i in progress:
    optimizer.zero_grad()
    # Call the module itself (not .forward) so any nn.Module hooks run.
    data_out = net(data_hidden)
    # Reconstruction error on the entries the network was allowed to see.
    train_loss = loss_fcn(data_out, data_hidden, observed)
    # Error on the hidden entries is for monitoring only: no autograd graph needed.
    with torch.no_grad():
        val_loss = loss_fcn(data_out, data, missing)
    train_loss.backward()
    optimizer.step()
    progress.set_description("Train: {:.4f}, Val: {:.4f}".format(
            train_loss.item(), val_loss.item()))

Train: 40010.6289, Val: 40718.2500:   0%|          | 2/1000 [00:06<51:30,  3.10s/it]


KeyboardInterrupt: 