In [1]:
# import modules
import os

import numpy as np

from onto_vae.ontobj import *
from onto_vae.vae_model import *

In [2]:
# Create the Ontobj container. `description` is a free-form identifier for the
# ontology being used -- here PWO, the Pathway Ontology.
ontology_id = 'PWO'
pwo = Ontobj(description=ontology_id)

In [3]:
# Build the ontology DAG from two input files:
#   obo        -> the ontology definition in OBO format
#   gene_annot -> a tab-separated file with two columns: genes and ontology IDs
obo_file = data_path() + 'pw.obo'
annot_file = data_path() + 'gene_term_mapping.txt'
pwo.initialize_dag(obo=obo_file, gene_annot=annot_file)

# Trim the DAG; the (top_thresh, bottom_thresh) pair selects this trimmed
# version and is reused by create_masks and OntoVAE below.
pwo.trim_dag(top_thresh=1000, bottom_thresh=30)

# Create masks for decoder initialization (next cell).

################### Parameter: which mask positions to change ##########################

In [4]:
# Matrix passed to create_masks via `reg_mask`. Its shape must equal the last
# decoder mask for the 1000_30 trimmed version with the PBMC_CD4T dataset,
# which is (5892, 657) -- see the mask shapes listed in the next cell.
# NOTE(review): these sizes are hard-coded; update them if the thresholds or
# the dataset change (the assert at the end of this cell will catch a mismatch).
N_GENES, N_TERMS = 5892, 657
reg = np.zeros((N_GENES, N_TERMS))
# Mark a single position (row 0, column 0). Presumably this flags that mask
# entry as one to change -- TODO confirm the semantics of reg_mask in create_masks.
reg[0, 0] = 1

pwo.create_masks(top_thresh=1000,
                 bottom_thresh=30,
                 reg_mask=reg)

pwo.match_dataset(expr_data=data_path() + 'pbmc_sample_expr.csv',
                  name='PBMC_CD4T')

# Fail fast if the hard-coded reg shape no longer matches the generated masks.
assert tuple(pwo.masks["1000_30"][-1].shape) == reg.shape, (
    f"reg_mask shape {reg.shape} != last mask shape "
    f"{tuple(pwo.masks['1000_30'][-1].shape)}"
)

In [5]:
# Shape of every decoder mask in the 1000_30 trimmed version, in list order.
[layer_mask.shape for layer_mask in pwo.masks["1000_30"]]

[(3, 76),
 (12, 79),
 (131, 91),
 (169, 222),
 (160, 391),
 (101, 551),
 (5, 652),
 (5892, 657)]

In [6]:
# Construct the OntoVAE on top of the Ontobj and move it to the model's device.
# `dataset` is the name given to pwo.match_dataset; top_thresh/bottom_thresh
# select which trimmed ontology version (and its masks) to use.
pwo_model = OntoVAE(
    ontobj=pwo,
    dataset='PBMC_CD4T',
    top_thresh=1000,
    bottom_thresh=30,
)
pwo_model.to(pwo_model.device)

OntoVAE(
  (encoder): Encoder(
    (encoder): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=5892, out_features=228, bias=True)
        (1): BatchNorm1d(228, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Dropout(p=0.2, inplace=False)
        (3): ReLU()
      )
    )
    (mu): Sequential(
      (0): Linear(in_features=228, out_features=228, bias=True)
      (1): Dropout(p=0.5, inplace=False)
    )
    (logvar): Sequential(
      (0): Linear(in_features=228, out_features=228, bias=True)
      (1): Dropout(p=0.5, inplace=False)
    )
  )
  (decoder): OntoDecoder(
    (decoder): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=228, out_features=9, bias=True)
      )
      (1): Sequential(
        (0): Linear(in_features=237, out_features=36, bias=True)
      )
      (2): Sequential(
        (0): Linear(in_features=273, out_features=393, bias=True)
      )
      (3): Sequential(
        (0): Linear(in_features=666, out_fe

In [7]:
# Re-list the mask shapes after building the model; presumably a check that
# OntoVAE construction left the masks unchanged (the output matches In [5]).
list(map(lambda mask: mask.shape, pwo.masks["1000_30"]))

[(3, 76),
 (12, 79),
 (131, 91),
 (169, 222),
 (160, 391),
 (101, 551),
 (5, 652),
 (5892, 657)]

In [8]:
# Train the model and store the best model at MODEL_PATH.
# os.path.join inserts the separator that `os.getcwd() + 'models/...'` dropped
# (that form produced '<cwd>models/best_model.pt', outside the intended folder).
MODEL_PATH = os.path.join(os.getcwd(), 'models', 'best_model.pt')
# Make sure the target directory exists before train_model tries to write there.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# NOTE(review): train_model itself prints the last decoder weight tensor on
# every batch and emits a torch.tensor() copy warning; both come from the
# library, not from this cell.
pwo_model.train_model(MODEL_PATH,       # where to store the best model
                      lr=1e-4,          # the learning rate
                      l1=1,             # presumably the weight of an L1 penalty -- TODO confirm in train_model
                      kl_coeff=1e-4,    # the weighting coefficient for the Kullback Leibler loss
                      batch_size=128,   # the size of the minibatches
                      epochs=5,         # over how many epochs to train
                      log=False)

Epoch 1 of 5


  0%|          | 0/7 [00:00<?, ?it/s]

  self.decoder.decoder[i][0].weight.data = torch.tensor(self.decoder.decoder[i][0].weight.data.clamp(0))
 14%|█▍        | 1/7 [00:00<00:01,  5.14it/s]

tensor([[1.2250e-06, 1.0000e-08, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 29%|██▊       | 2/7 [00:00<00:00,  5.95it/s]

tensor([[0.0000e+00, 9.7369e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 43%|████▎     | 3/7 [00:00<00:00,  6.15it/s]

tensor([[0.0000e+00, 9.8703e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 57%|█████▋    | 4/7 [00:00<00:00,  6.28it/s]

tensor([[0.0000e+00, 9.8092e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 71%|███████▏  | 5/7 [00:00<00:00,  6.38it/s]

tensor([[0.0000e+00, 9.6342e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 86%|████████▌ | 6/7 [00:00<00:00,  6.16it/s]

tensor([[0.0000e+00, 9.6016e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 7/7 [00:01<00:00,  6.24it/s]


tensor([[0.0000e+00, 8.9190e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 2/2 [00:00<00:00, 96.59it/s]

New best model!





Train Loss: 25140.2930
Val Loss: 21918.4731
Epoch 2 of 5


 14%|█▍        | 1/7 [00:00<00:01,  5.37it/s]

tensor([[0.0000e+00, 9.0365e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 29%|██▊       | 2/7 [00:00<00:00,  5.76it/s]

tensor([[0.0000e+00, 9.2565e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 43%|████▎     | 3/7 [00:00<00:00,  6.15it/s]

tensor([[0.0000e+00, 8.9150e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 57%|█████▋    | 4/7 [00:00<00:00,  6.05it/s]

tensor([[0.0000e+00, 9.2812e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 71%|███████▏  | 5/7 [00:00<00:00,  6.19it/s]

tensor([[0.0000e+00, 9.0865e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 86%|████████▌ | 6/7 [00:00<00:00,  6.34it/s]

tensor([[0.0000e+00, 8.7725e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 7/7 [00:01<00:00,  6.29it/s]


tensor([[0.0000e+00, 8.3380e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 2/2 [00:00<00:00, 97.30it/s]

New best model!





Train Loss: 25050.5011
Val Loss: 21877.6587
Epoch 3 of 5


 14%|█▍        | 1/7 [00:00<00:01,  5.08it/s]

tensor([[0.0000e+00, 8.4503e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 29%|██▊       | 2/7 [00:00<00:00,  5.80it/s]

tensor([[0.0000e+00, 8.8470e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 43%|████▎     | 3/7 [00:00<00:00,  5.82it/s]

tensor([[0.0000e+00, 8.6033e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 57%|█████▋    | 4/7 [00:00<00:00,  6.06it/s]

tensor([[0.0000e+00, 8.4669e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 86%|████████▌ | 6/7 [00:01<00:00,  5.52it/s]

tensor([[0.0000e+00, 8.9003e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])
tensor([[0.0000e+00, 9.1470e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
    

100%|██████████| 7/7 [00:01<00:00,  5.73it/s]


tensor([[0.0000e+00, 8.6829e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 2/2 [00:00<00:00, 99.59it/s]


New best model!
Train Loss: 24982.3791
Val Loss: 21833.3472
Epoch 4 of 5


 14%|█▍        | 1/7 [00:00<00:01,  5.70it/s]

tensor([[0.0000e+00, 8.4692e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 29%|██▊       | 2/7 [00:00<00:00,  5.92it/s]

tensor([[0.0000e+00, 8.4831e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 43%|████▎     | 3/7 [00:00<00:00,  6.27it/s]

tensor([[0.0000e+00, 8.3701e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 57%|█████▋    | 4/7 [00:00<00:00,  6.32it/s]

tensor([[0.0000e+00, 8.1803e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 71%|███████▏  | 5/7 [00:00<00:00,  6.14it/s]

tensor([[0.0000e+00, 8.2685e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 86%|████████▌ | 6/7 [00:00<00:00,  6.04it/s]

tensor([[0.0000e+00, 8.0892e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 7/7 [00:01<00:00,  6.26it/s]


tensor([[0.0000e+00, 7.4350e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 2/2 [00:00<00:00, 76.59it/s]

New best model!





Train Loss: 24907.2460
Val Loss: 21780.1279
Epoch 5 of 5


 14%|█▍        | 1/7 [00:00<00:00,  6.11it/s]

tensor([[0.0000e+00, 7.3532e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 29%|██▊       | 2/7 [00:00<00:00,  6.26it/s]

tensor([[0.0000e+00, 7.3858e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 43%|████▎     | 3/7 [00:00<00:00,  6.18it/s]

tensor([[0.0000e+00, 7.4650e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 57%|█████▋    | 4/7 [00:00<00:00,  6.02it/s]

tensor([[0.0000e+00, 7.7202e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 71%|███████▏  | 5/7 [00:00<00:00,  6.39it/s]

tensor([[0.0000e+00, 7.2322e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


 86%|████████▌ | 6/7 [00:00<00:00,  6.41it/s]

tensor([[0.0000e+00, 7.0141e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 7/7 [00:01<00:00,  6.32it/s]


tensor([[0.0000e+00, 6.6102e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])


100%|██████████| 2/2 [00:00<00:00, 102.21it/s]

New best model!





Train Loss: 24828.2215
Val Loss: 21723.2222


In [10]:
# Inspect the weight matrix of the Linear layer in the final decoder stage.
final_layer_weights = pwo_model.decoder.decoder[-1][0].weight.data
final_layer_weights

tensor([[0.0000e+00, 6.6102e-09, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, -0.0000e+00,
         0.0000e+00],
        ...,
        [0.0000e+00, 0.0000e+00, -0.0000e+00,  ..., -0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., -0.0000e+00, -0.0000e+00,
         -0.0000e+00],
        [0.0000e+00, -0.0000e+00, -0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00]])