In [1]:
import sys
sys.path.insert(0, '..')

In [2]:
import torch
from torch.utils.data import DataLoader

from test.artificial_dataset_generator import ArtificialDataset
from linear_model import LinearModel
from exploded_logit import ExplodedLogitLoss

# Seed torch's global RNG first, before anything below can draw random numbers,
# so that generated data and model initialisation are reproducible.
torch.manual_seed(24637882)

# Experiment sizes.
dataset_size = 8000       # training steps (one loader sample per step)
test_dataset_size = 1000  # samples drawn afterwards for evaluation
data_columns = 3          # input feature columns passed to the linear model
competitors = 8           # presumably the number of items ranked per sample — TODO confirm

# rand_eps is forwarded verbatim to ArtificialDataset (its meaning is defined there).
dataset_generator = ArtificialDataset(dataset_size, competitors, data_columns, rand_eps=1e-4)
loader_iterator = iter(DataLoader(dataset_generator))

In [3]:
%load_ext tensorboard

In [4]:
def get_sort_order(scores):
    """Convert scores into 1-based ranks, where rank 1 is the highest score.

    Ranking is done independently along the last dimension, so inputs with any
    number of leading (e.g. batch) dimensions are supported, including 1-D.

    Args:
        scores: tensor of scores; ranks are taken over the last dimension.

    Returns:
        A ``torch.long`` tensor with the same shape as ``scores``. Each entry is
        the 1-based position of that element when its row is sorted in
        descending order.
    """
    # argsort(descending) maps rank -> element index. That is a permutation, so
    # argsort of it gives the inverse mapping: element index -> rank.
    order = torch.argsort(scores, dim=-1, descending=True)
    return torch.argsort(order, dim=-1) + 1

In [7]:
# Loss variant passed to ExplodedLogitLoss; also used as the TensorBoard run name.
loss_type = 'nll'

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/fully_artificial_test/' + loss_type)

linear_model = LinearModel(data_columns, 1)  # number of columns to score
optimizer = torch.optim.Adam(params=linear_model.parameters())
loss = ExplodedLogitLoss(loss_type=loss_type)

log_every = 1000  # print the training loss every `log_every` steps

# --- Training: one loader sample per optimisation step. ---
for step in range(dataset_size):
    data, order = next(loader_iterator)
    optimizer.zero_grad()

    # Drop the trailing size-1 output dim (LinearModel(..., 1) presumably emits
    # one value per row — TODO confirm) so the loss receives a flat score vector.
    score = linear_model(data).squeeze(-1)

    loss_value = loss(score, order)
    loss_value.backward()
    optimizer.step()

    loss_scalar = loss_value.item()
    writer.add_scalar('training loss', loss_scalar, step)

    if step % log_every == 0:
        print("Loss value: {0}".format(loss_scalar))

# Flush buffered events to disk so TensorBoard sees the complete run.
writer.close()

# --- Evaluation: predicted ranking must exactly match the expected ranking. ---
# NOTE(review): test samples are drawn from the same iterator used for training,
# so this relies on the dataset yielding more than dataset_size items — confirm.
mismatches = 0
with torch.no_grad():
    for _ in range(test_dataset_size):
        data, expected_order = next(loader_iterator)

        score = linear_model(data).squeeze(-1)
        actual_order = get_sort_order(score)

        if not torch.equal(actual_order, expected_order):
            mismatches += 1
            # Was `println`, which does not exist in Python (NameError on first mismatch).
            print("Order not equal:\n{0}\n{1}".format(actual_order, expected_order))

print("Mismatched orders: {0}/{1}".format(mismatches, test_dataset_size))
print("Finished")



Loss value: 42.92835979078594
Loss value: 11.272154385194849
Loss value: 4.042522546918535
Loss value: 2.186030019323485
Loss value: 1.3199677889380097
Loss value: 0.8289848021773595
Loss value: 0.5281215455616933
Loss value: 0.33796510521209966
Finished


In [8]:
# Use the explicit % prefix rather than relying on IPython automagic, which is
# disabled in some setups and is shadowed by any variable named `tensorboard`.
%tensorboard --logdir=runs/fully_artificial_test

Reusing TensorBoard on port 6006 (pid 22267), started 0:00:26 ago. (Use '!kill 22267' to kill it.)