In [32]:
from SeqEN2.session import Session
from SeqEN2.model.model import Model
from SeqEN2.autoencoder.autoencoder import Autoencoder
from glob import glob
from os import system
from torch import no_grad
from torch import argmax, tensor, index_select, mode, diagonal, fliplr, eye
from torch import sum as torch_sum
import torch
import numpy as np
import cProfile
import pstats

In [33]:
def consensus(output, ndx, device):
    """
    Majority vote for one sequence position across all overlapping windows.

    ``output`` is read as a stack of stride-1 sliding windows: row ``r``,
    column ``c`` holds a prediction for sequence position ``r + c``. The value
    returned for position ``ndx`` is the mode of the anti-diagonal
    ``output[r, ndx - r]`` over every row ``r`` that contains that position.

    :param output: 2-D tensor of shape (output_length, w)
    :param ndx: position in the reconstructed sequence,
        ``0 <= ndx < output_length + w - 1``
    :param device: device on which the index tensor is created; it has to be
        the device that holds ``output``
    :return: the most frequent predicted value as a Python scalar
    :raises IndexError: if ``ndx`` lies outside the reconstructed sequence
    """
    output_length, w = output.shape
    seq_length = output_length + w - 1
    if not 0 <= ndx < seq_length:
        raise IndexError(f"index {ndx} is out of bounds for a sequence of length {seq_length}")
    # Rows (windows) that contain position ndx. Clamping the upper bound to the
    # last row keeps this valid when there are fewer windows than the window
    # width (output_length < w).
    r_min = max(0, ndx - w + 1)
    r_max = min(ndx, output_length - 1) + 1
    rows = torch.arange(r_min, r_max, device=device)
    # Gather the anti-diagonal directly instead of masking a square sub-matrix.
    votes = output[rows, ndx - rows]
    return mode(votes).values.item()


def get_seq(ndx, ndx_windows):
    """
    Read the sequence element at position ``ndx`` from its sliding-window view.

    Positions smaller than the number of windows are taken from the first
    column of window ``ndx``; the remaining positions are taken from the last
    window at the matching column offset.

    :param ndx: position in the underlying sequence
    :param ndx_windows: 2-D indexable with a ``shape`` of (output_length, w)
    :return: the element stored for position ``ndx``
    :raises IndexError: if ``ndx`` is at or beyond ``output_length + w - 1``
    """
    n_windows, width = ndx_windows.shape
    if ndx < n_windows:
        return ndx_windows[ndx][0]
    # Past the last window start: the element sits inside the final window.
    tail_offset = ndx - n_windows + 1
    if tail_offset >= width:
        raise IndexError(f'index {tail_offset} is out of bounds for dimension 1 with size {width}')
    return ndx_windows[-1][tail_offset]


def consensus_acc(seq, output, device):
    """
    Consensus sequence of ``output`` and its accuracy against ``seq``.

    For every position of the reconstructed sequence the consensus value is
    computed with ``consensus`` and compared with the target element returned
    by ``get_seq``.

    :param seq: target given as 2-D sliding windows (it is read through
        ``get_seq``, which unpacks a two-element shape)
    :param output: 2-D tensor of predicted windows, shape (output_length, w)
    :param device: forwarded to ``consensus``
    :return: tuple ``(accuracy, consensus_seq)`` where ``accuracy`` is the
        fraction of the ``output_length + w - 1`` positions whose consensus
        value equals the target, and ``consensus_seq`` is the list of
        consensus values
    """
    output_length, w = output.shape
    seq_length = output_length + w - 1
    consensus_seq = [consensus(output, i, device=device) for i in range(seq_length)]
    n_correct = sum(
        1 for i, predicted in enumerate(consensus_seq) if get_seq(i, seq).item() == predicted
    )
    # Normalise by the number of compared positions. len(seq) is the number of
    # windows, which is smaller than seq_length and let the ratio exceed 1.
    return n_correct / seq_length, consensus_seq

def main(w, l):
    """
    Run ``consensus_acc`` once on random data (profiling driver).

    :param w: window width
    :param l: sequence length; ``l - w + 1`` windows are generated
    :return: the ``(accuracy, consensus_seq)`` tuple from ``consensus_acc``
    """
    output = torch.randint(0, 21, (l - w + 1, w)).long()
    seq = torch.randint(0, 21, (l,)).long()
    # consensus_acc reads the target through get_seq, which needs the
    # (l - w + 1, w) sliding-window view rather than the flat sequence.
    seq_windows = seq.unfold(0, w, 1)
    # Return directly: binding the result to a local named consensus_acc would
    # shadow the function and raise UnboundLocalError on the call itself.
    return consensus_acc(seq_windows, output, 'cpu')

In [34]:
# l = 100
# w = 20
# ol = l - w + 1
# output = torch.randint(0, 21, (ol, w)).long()
# get_seq(20, output)

In [35]:
# l = 100
# w = 20
# r = 100
# job = f'tensor_{w}_{l}_{r}'
# job = f'np_{w}_{l}_{r}'

# with cProfile.Profile() as pr:
#     for _ in range(r):
#         main(w, l)
# stats = pstats.Stats(pr)
# stats.sort_stats(pstats.SortKey.TIME)
# stats.print_stats()
# filename=f'./profiling_{job}.prof'
# stats.dump_stats(filename)

In [36]:
class NewAutoencoder(Autoencoder):
    """Autoencoder whose test step also computes the consensus accuracy."""

    def test_batch(self, input_vals, device, input_noise=0.0, wandb_log=True):
        """
        Evaluate a single batch with gradient tracking disabled.

        :param input_vals: batch passed to ``self.transform_input``
        :param device: passed to ``self.transform_input`` and ``consensus_acc``
        :param input_noise: forwarded to ``self.transform_input``
        :param wandb_log: when True the metrics are logged through ``wandb``
            and nothing is returned; when False they are returned instead
        :return: None if ``wandb_log`` is True, otherwise the tuple
            ``(reconstructor_loss, reconstructor_accuracy, consensus_seq_acc,
            consensus_seq)``
        """
        with no_grad():
            input_ndx, one_hot_input = self.transform_input(input_vals, device, input_noise=input_noise)
            reconstructor_output = self.forward_test(one_hot_input)
            # Flattened targets, shared by the loss and the accuracy.
            target_ndx = input_ndx.reshape((-1,))
            reconstructor_loss = self.criterion_NLLLoss(reconstructor_output, target_ndx)
            # Element-wise accuracy of the arg-max prediction.
            reconstructor_ndx = argmax(reconstructor_output, dim=1)
            n_correct = torch_sum(reconstructor_ndx == target_ndx)
            reconstructor_accuracy = n_correct / reconstructor_ndx.shape[0]
            # Consensus accuracy over the predictions regrouped into windows of width self.w.
            predicted_windows = reconstructor_ndx.reshape((-1, self.w))
            consensus_seq_acc, consensus_seq = consensus_acc(input_ndx, predicted_windows, device)
            if not wandb_log:
                return (
                    reconstructor_loss,
                    reconstructor_accuracy,
                    consensus_seq_acc,
                    consensus_seq,
                )
            # NOTE(review): `wandb` is not imported in the visible notebook cells;
            # confirm it is in scope before calling this with wandb_log=True.
            wandb.log({"test_reconstructor_loss": reconstructor_loss.item()})
            wandb.log({"test_reconstructor_accuracy": reconstructor_accuracy.item()})
            wandb.log({"test_consensus_accuracy": consensus_seq_acc})
            # clean up
            del reconstructor_loss, reconstructor_output
            return

class NewModel(Model):
    """Model that builds a ``NewAutoencoder`` and collects per-batch test results."""

    def build_model(self, model_type, arch):
        """
        Create the autoencoder for ``model_type`` and move it to ``self.device``.

        :param model_type: only "AE" is supported here
        :param arch: architecture description passed to ``NewAutoencoder``
        :raises ValueError: for any other ``model_type``; previously the method
            fell through to ``self.autoencoder.to(...)`` without building anything
        """
        if model_type != "AE":
            raise ValueError(f"unsupported model_type {model_type!r}; only 'AE' is available")
        self.autoencoder = NewAutoencoder(self.d0, self.d1, self.dn, self.w, arch)
        # Variants that are not wired up in this prototype:
        # elif model_type == "AAE":
        #     self.autoencoder = AdversarialAutoencoder(
        #         self.d0, self.d1, self.dn, self.w, arch
        #     )
        # elif model_type == "AAEC":
        #     self.autoencoder = AdversarialAutoencoderClassifier(
        #         self.d0, self.d1, self.dn, self.w, arch
        #     )
        self.autoencoder.to(self.device)

    def predict(self, num_test_items=1, input_noise=0.0):
        """
        Run ``test_batch`` on every test batch and collect the results.

        :param num_test_items: forwarded to ``self.data_loader.get_test_batch``
        :param input_noise: forwarded to ``self.autoencoder.test_batch``
        :return: list with one ``test_batch`` result per batch (called with
            ``wandb_log=False``, so each entry is the returned metrics)
        """
        return [
            self.autoencoder.test_batch(test_batch, self.device, input_noise=input_noise, wandb_log=False)
            for test_batch in self.data_loader.get_test_batch(num_test_items=num_test_items)
        ]

class NewSession(Session):
    """Session that creates a ``NewModel`` and works on a capped number of data files."""

    def add_model(self, name, arch, model_type, d0=21, d1=8, dn=10, w=20):
        """
        Load the architecture and create the model if none has been set yet.

        :param name: model name passed to ``NewModel``
        :param arch: architecture identifier resolved with ``self.load_arch``
        :param model_type: model type passed to ``NewModel``
        :param d0: passed to ``NewModel``
        :param d1: passed to ``NewModel``
        :param dn: passed to ``NewModel``
        :param w: passed to ``NewModel``
        """
        arch = self.load_arch(arch)
        if self.model is None:
            self.model = NewModel(name, arch, model_type, d0=d0, d1=d1, dn=dn, w=w)

    def load_data(self, dataset_name, max_files=2):
        """
        Hand the first ``max_files`` data files of a dataset to the model.

        :param dataset_name: directory name under ``<Model.root>/data``
        :param max_files: number of ``*.csv.gz`` files (in sorted order) to use;
            ``None`` uses all of them. Defaults to 2.
        """
        pattern = str(Model.root) + f"/data/{dataset_name}/*.csv.gz"
        data_files = sorted(glob(pattern))[:max_files]
        self.model.load_data(dataset_name, data_files)

    def test(self, num_test_items=1, input_noise=0.0):
        """
        Run the model's test routine and pass its return value through.

        :param num_test_items: forwarded to ``self.model.test``
        :param input_noise: forwarded to ``self.model.test``
        :return: whatever ``self.model.test`` returns
        """
        # NOTE(review): this calls Model.test, not NewModel.predict; confirm
        # which of the two is meant to drive the consensus-accuracy prototype.
        return self.model.test(num_test_items=num_test_items, input_noise=input_noise)

In [37]:
# Run configuration for this prototype.
model_name = 'dummy'
model_type = 'AE'
arch = 'gen1'
dataset = 'w_20_KeggSeq_ndx_wpACT'
run_title = 'prototyping_consensus_acc'
w = 20

# The session is created last; it does not read any of the names above.
session = NewSession()

In [38]:
# Build the model with the configured window width (previously `w` was set in
# the configuration cell but never passed on), load the data, run one test item.
session.add_model(model_name, arch, model_type, w=w)
session.load_data(dataset)
result = session.test(num_test_items=1, input_noise=0.0)

test batch shape:  (353, 22)
input vals shape : (353, 22)
input ndx shape : torch.Size([353, 20])
input onehot shape : torch.Size([353, 20, 21])
output shape : torch.Size([353, 20])


In [39]:
# Display the value assigned from session.test(...) in the previous cell.
result