In [14]:
from SeqEN2.session import Session
from SeqEN2.model.model import Model
from SeqEN2.autoencoder.autoencoder import Autoencoder
from glob import glob
from os import system
from torch import no_grad
from torch import argmax
from torch import sum as torch_sum

In [21]:
class NewAutoencoder(Autoencoder):
    def test_batch(self, input_vals, device, input_noise=0.0, wandb_log=True):
        """
        Test a single batch of data, this will move into autoencoder
        :param input_vals:
        :return:
        """
        with no_grad():
            input_ndx, one_hot_input = self.transform_input(
                input_vals, device, input_noise=input_noise
            )
            reconstructor_output = self.forward_test(one_hot_input)
            reconstructor_loss = self.criterion_NLLLoss(
                reconstructor_output, input_ndx.reshape((-1,))
            )
            # reconstructor acc
            reconstructor_ndx = argmax(reconstructor_output, dim=1)
            reconstructor_accuracy = (
                torch_sum(reconstructor_ndx == input_ndx.reshape((-1,)))
                / reconstructor_ndx.shape[0]
            )
            # reconstruction_loss, discriminator_loss, classifier_loss
            if wandb_log:
                wandb.log({"test_reconstructor_loss": reconstructor_loss.item()})
                wandb.log(
                    {"test_reconstructor_accuracy": reconstructor_accuracy.item()}
                )
            else:
                return reconstructor_loss, reconstructor_accuracy
            # clean up
            del reconstructor_loss
            del reconstructor_output
            return


class NewModel(Model):
    def build_model(self, model_type, arch):
        if model_type == "AE":
            self.autoencoder = NewAutoencoder(self.d0, self.d1, self.dn, self.w, arch)
        # elif model_type == "AAE":
        #     self.autoencoder = AdversarialAutoencoder(
        #         self.d0, self.d1, self.dn, self.w, arch
        #     )
        # elif model_type == "AAEC":
        #     self.autoencoder = AdversarialAutoencoderClassifier(
        #         self.d0, self.d1, self.dn, self.w, arch
        #     )
        self.autoencoder.to(self.device)

    def test(self, num_test_items=1, input_noise=0.0):
        for test_batch in self.data_loader.get_test_batch(
            num_test_items=num_test_items
        ):
            results = self.autoencoder.test_batch(
                test_batch, self.device, input_noise=input_noise, wandb_log=False
            )
            print(results)


class NewSession(Session):
    def add_model(self, name, arch, model_type, d0=21, d1=8, dn=10, w=20):
        arch = self.load_arch(arch)
        if self.model is None:
            self.model = NewModel(name, arch, model_type, d0=d0, d1=d1, dn=dn, w=w)

    def load_data(self, dataset_name):
        data_files = sorted(glob(str(Model.root) + f"/data/{dataset_name}/*.csv.gz"))[
            :2
        ]
        self.model.load_data(dataset_name, data_files)

In [22]:
session = NewSession()
model_name = "dummy"
arch = "gen1"
model_type = "AE"
dataset = "w_20_KeggSeq_ndx_wpACT"
run_title = "prototyping_consensus_acc"

In [23]:
session.add_model(model_name, arch, model_type)
session.load_data(dataset)
session.test(num_test_items=1, input_noise=0.0)

(tensor(3.0896), tensor(0.0380))


In [10]:
run_dir = session.root / "models" / f"{model_name}" / "versions" / f"{run_title}"

In [12]:
system(f"rm -r {str(run_dir)}")

0