In [65]:
%cd ~/fast/ProtoTSNetDPL/

/home/jovyan/fast/ProtoTSNetDPL


In [66]:
from scipy import signal
import numpy as np
import torch
import json

from deepproblog.dataset import Dataset as DPLDataset, DataLoader
from deepproblog.query import Query
from deepproblog.network import Network
from deepproblog.model import Model
from deepproblog.engines import ExactEngine
from deepproblog.train import train_model
from deepproblog.evaluate import get_confusion_matrix, get_fact_accuracy
from problog.logic import Term, Constant, list2term

from model import ProtoTSNet
from autoencoder import RegularConvEncoder

from matplotlib import pyplot as plt

In [67]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [68]:
%env CUDA_VISIBLE_DEVICES=0
torch.cuda.set_per_process_memory_fraction(fraction=0.5, device=0)

env: CUDA_VISIBLE_DEVICES=0


In [69]:
class ArtificialProtosDataset():
    def __init__(self, N, feature_noise_power=0.1, randomize_right_side=False):
        self.data = []
        x = np.linspace(0, 100, 100)
        for _ in range(N):
            label = np.random.randint(0, 2)
            ts = np.zeros((3, 100))
            if label == 0:
                ts[0, :40] = signal.sawtooth(x[:40] / (1+1))
                ts[1, :40] = signal.square(x[:40] / (2+1))
            else:
                ts[0, :40] = signal.square(x[:40] / (1+1))
                ts[1, :40] = signal.sawtooth(x[:40] / (2+1))
            if np.random.choice([0, 1]) == 0:
                ts[2, :40] = signal.square(np.random.choice([-1, 1]) * x[:40] / 3)
            else:
                ts[2, :40] = signal.sawtooth(np.random.choice([-1, 1]) * x[:40] / 3)
            for i in range(3):
                if randomize_right_side:
                    ts[i, 40:] = np.sin(x[40:] / (np.random.randint(0, 4)+i+1)) / 3
                else:
                    ts[i, 40:] = np.sin(x[40:] / (i+1)) / 3
                ts[i, :] += np.random.normal(0, feature_noise_power, 100)
            self.data.append((ts.astype('float32'), label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[int(idx[0])][0])
    
    def get_label(self, idx):
        return self.data[idx][1]

class ArtificialProtosQueries(DPLDataset):
    def __init__(self, dataset: ArtificialProtosDataset, phase: str):
        self.phase = phase
        self.dataset = dataset
        self.dataset_len = len(dataset)
        self.num_classes = 2 # len(set([dataset.get_label(i) for i in range(self.dataset_len)]))

    def to_query(self, i: int) -> Query:
        
        ds_entry = i // self.num_classes
        cls_num = i % self.num_classes
        correct_cls = self.dataset.get_label(ds_entry)

        ts_term = Term(f'ts{ds_entry}')
        q = Query(
            Term(
                'is_class',
                ts_term,
                Term(f'c{cls_num}')
            ),
            {
                ts_term: Term(
                    "tensor",
                    Term(
                        self.phase,
                        Constant(ds_entry),
                    ),
                )
            },
            p = float(cls_num == correct_cls)
        )
        return q

    def __len__(self):
        return self.dataset_len * self.num_classes

In [70]:
test_dataset = ArtificialProtosDataset(50)
test_queries = ArtificialProtosQueries(test_dataset, "test")

In [71]:
idx = 6
print(f'Label: {test_dataset.get_label(idx)}')
print(f'Query 1: {test_queries.to_query(2*idx)}')
print(f'Query 2: {test_queries.to_query(2*idx + 1)}')

Label: 0
Query 1: (1.0::is_class(ts6,c0), {ts6: tensor(test(6))})
Query 2: (0.0::is_class(ts6,c1), {ts6: tensor(test(6))})


In [72]:
protos_per_class = 1
latent_features = 32

print('Preparing ProtoTSNet...')
autoencoder = RegularConvEncoder(num_features=3, latent_features=latent_features, padding='same')
encoder = autoencoder.encoder
net = ProtoTSNet(
    cnn_base=encoder,
    for_deepproblog=True,
    num_features=3,
    ts_sample_len=100,
    prototype_shape=(protos_per_class*2, latent_features, 20),
    num_classes=2,
    prototype_activation_function='log'
)

dpl_net = Network(net, "ptsnet", batching=False)
dpl_net.optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

print('Loading logic file...')
model = Model("proto_logic.pl", [dpl_net])
model.set_engine(ExactEngine(model))

Preparing ProtoTSNet...
Loading logic file...


In [77]:
model.load_state('./snapshots/initial_model.pth')
model.train()
model.eval()
model.add_tensor_source("test", test_dataset)

ProtoTSNet.eval()


In [74]:
next(net.parameters()).shape

torch.Size([2, 32, 20])

In [75]:
# idx = 7
# print(f'Label: {test_dataset.get_label(idx)}')
# plt.plot(test_dataset[[idx]][0])

In [79]:
idx = 3
print(f'Label: {test_dataset.get_label(idx)}')
print(model.solve([test_queries.to_query(2*idx+0)]))
print(model.solve([test_queries.to_query(2*idx+1)]))

Label: 1
[{is_class(tensor(test(3)),c0): tensor(0.0596, grad_fn=<SelectBackward0>)}]
[{is_class(tensor(test(3)),c1): tensor(0.9075, grad_fn=<SelectBackward0>)}]
