In [4]:
import random
import torch
import torch.nn.functional as F
from tqdm import tqdm
from model import LILayer_ty as LILayer

import numpy as np
from dataloader import DataCLUTRR
import random

In [27]:
# Build the project's DataCLUTRR loader for the dataset directory below.
dataloader = DataCLUTRR('data/data_db9b8f04/')
# Training examples. The training cell iterates this object as (r, q, G)
# triples and shuffles it in place with random.shuffle, so it is presumably a
# mutable list — TODO confirm against DataCLUTRR.get_data.
data = dataloader.get_data('train.csv')
def to_sparse(A):
    """Convert a dense 3-D adjacency tensor into an edge list.

    ``A`` is indexed as ``A[edge_type, source, target]``; every non-zero entry
    is one typed edge.

    Args:
        A: 3-D tensor whose first axis enumerates edge types and whose last
            two axes enumerate the two endpoints of an edge.

    Returns:
        A tuple ``(edge_index, edge_type)`` of ``torch.long`` tensors:
        ``edge_index`` has shape ``(2, E)`` with the source indices in row 0
        and the target indices in row 1, and ``edge_type`` has shape ``(E,)``.
        Edges are ordered by (source, target, edge_type). A tensor without any
        non-zero entry yields well-formed empty outputs of shape ``(2, 0)``
        and ``(0,)``.
    """
    # Moving the type axis last makes nonzero() enumerate entries in
    # (source, target, type) lexicographic order, which is the order the
    # previous nested Python loop produced.
    triples = A.permute(1, 2, 0).nonzero()
    edge_index = triples[:, :2].t().contiguous()
    edge_type = triples[:, 2].contiguous()
    return edge_index, edge_type

In [24]:
# Seed every RNG this notebook touches before the model is built, so that
# parameter initialisation (and any later random.shuffle, when the cells are
# run top to bottom on a fresh kernel) is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Instantiate the layer and (re-)initialise its parameters.
# NOTE(review): the first positional argument (20) is not named here; its
# meaning is defined by LILayer — confirm in model.py.
model = LILayer(20, hidden_dim=20, out_dim=20, bias=True)
model._reset_parameters()
model.train()  # put the module in training mode for the training cell
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-6)

In [25]:
# train sparse model
#
# Each mini-batch accumulates per-sample gradients with loss.backward() and
# applies them in a single optimizer step, followed by model.rescale().

N_data = 100  # NOTE(review): not referenced in this cell — confirm it is still needed

batch_size = 400
n_epoch = 15
model_save = None

# Ceiling division so the final, possibly shorter, batch is included.
n_split = (len(data)+batch_size-1) // batch_size

# Drop any gradients left over from a previous, interrupted run of this cell;
# otherwise they would be folded into the first optimizer step.
optimizer.zero_grad()

for epoch in range(n_epoch):
    with tqdm(range(n_split), ncols=80) as _t:
        random.shuffle(data)  # reshuffles the training examples in place
        _t.set_description_str(f'Epoch: {epoch}')
        ssloss = 0.0  # loss summed over the whole epoch
        for split in _t:
            batch_st = split * batch_size
            sloss = 0.0  # loss summed over the current batch
            for r, q, G in data[batch_st:batch_st+batch_size]:
                edge_index, edge_label = to_sparse(G)
                N = G.shape[1]

                # NOTE(review): the trailing arguments (50, 0.7) are defined by
                # LILayer.forward, which is not visible here.
                pred = model(torch.arange(N), edge_index, edge_label, N, 50, 0.7)
                # Classify the relation r from the scores at the query pair (q[0], q[1]).
                loss = F.cross_entropy(pred[q[0], q[1]], r*torch.ones((),dtype=torch.long))
                loss_value = loss.item()
                if loss_value != loss_value:  # NaN is the only value unequal to itself
                    raise RuntimeError(
                        f'loss became NaN (epoch {epoch}, batch {split})')

                loss.backward()
                sloss += loss_value
            _t.set_postfix_str(f'loss: {sloss:.4f}')
            ssloss += sloss
            # Optional: torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
            optimizer.step()
            optimizer.zero_grad()
            model.rescale()
            # This binds a second name to the same object; it is not a snapshot copy.
            model_save = model

Epoch: 0: 100%|█████████████████| 89/89 [02:03<00:00,  1.39s/it, loss: 106.1521]
Epoch: 1: 100%|██████████████████| 89/89 [01:52<00:00,  1.26s/it, loss: 15.6958]
Epoch: 2: 100%|███████████████████| 89/89 [02:08<00:00,  1.45s/it, loss: 6.6457]
Epoch: 3: 100%|███████████████████| 89/89 [01:53<00:00,  1.28s/it, loss: 4.1973]
Epoch: 4:  46%|████████▊          | 41/89 [00:52<01:00,  1.27s/it, loss: 7.4972]


KeyboardInterrupt: 

In [26]:
# test on sparse model

# Evaluate in eval mode so that any train-only behaviour inside the module is
# disabled, then restore whatever mode the model was in beforehand.
# NOTE(review): relies on `model` exposing the torch.nn.Module `training` /
# `train(mode)` API — confirm LILayer subclasses nn.Module.
was_training = model.training
model.eval()
try:
    for filename in ['1.2_test.csv','1.3_test.csv','1.4_test.csv','1.5_test.csv','1.6_test.csv',
                     '1.7_test.csv','1.8_test.csv','1.9_test.csv','1.10_test.csv']:
        test_data = dataloader.get_data(filename)
        lm = list()  # highest score per example
        lp = list()  # score given to the true relation per example
        with torch.no_grad():
            cnt = 0  # number of correctly classified examples
            with tqdm(test_data, ncols=80) as _t:
                for r, q, G in _t:
                    edge_index, edge_label = to_sparse(G)
                    N = G.shape[1]
                    # NOTE(review): the trailing arguments (100, 1.0) differ
                    # from the training call (50, 0.7); their meaning is
                    # defined by LILayer.forward, which is not visible here.
                    pred = model(torch.arange(N), edge_index, edge_label, N, 100, 1.0)[q[0], q[1], :]
                    if pred.argmax() == r:
                        cnt += 1

                    lp.append(pred[r])
                    lm.append(pred.max())
                    _t.set_postfix_str(f'{cnt}')
        # Guard against an empty split instead of raising ZeroDivisionError.
        n_total = len(test_data)
        accuracy = cnt / n_total if n_total else float('nan')
        print(f'{filename}: {accuracy}')
finally:
    model.train(was_training)

100%|████████████████████████████████| 2332/2332 [00:11<00:00, 208.50it/s, 2332]


1.2_test.csv: 1.0


100%|████████████████████████████████| 2289/2289 [00:11<00:00, 195.16it/s, 2215]


1.3_test.csv: 0.9676714722586283


100%|████████████████████████████████| 5009/5009 [00:26<00:00, 188.68it/s, 4806]


1.4_test.csv: 0.9594729486923538


100%|████████████████████████████████| 5074/5074 [00:27<00:00, 182.89it/s, 4491]


1.5_test.csv: 0.8851005124162397


100%|████████████████████████████████| 5002/5002 [00:27<00:00, 180.49it/s, 3882]


1.6_test.csv: 0.7760895641743303


100%|████████████████████████████████| 5047/5047 [00:23<00:00, 218.50it/s, 3073]


1.7_test.csv: 0.608876560332871


100%|████████████████████████████████| 5033/5033 [00:22<00:00, 225.64it/s, 2203]


1.8_test.csv: 0.43771110669580765


100%|████████████████████████████████| 5031/5031 [00:23<00:00, 215.85it/s, 1869]


1.9_test.csv: 0.3714967203339296


100%|████████████████████████████████| 5008/5008 [00:27<00:00, 179.75it/s, 1598]


1.10_test.csv: 0.3190894568690096
