In [1]:
repo_path = "/Users/johnzhou/research/decision-making"
expt_dir = f"{repo_path}/experiments"
data_dir = f"{repo_path}/data"

In [2]:
import sys
sys.path.append(repo_path)

In [3]:
import numpy as np
import matplotlib.pyplot as plt
from omegaconf import OmegaConf
from sklearn.model_selection import StratifiedKFold, train_test_split
import torch

from src.data.blocks import RealDataset, SynthDataset
from src.models.agentnet import AgentNet
from src.models.sigmoidnet import SigmoidNet
from src.models.find_shapes import conv1d_shape, convtranspose1d_shape
from src.models.train import train

In [5]:
print(OmegaConf.to_container(vae_config))

{'name': 'test_expt', 'random_seed': 4995, 'model': {'in_channels': 1, 'conv_encoder_layers': [[16, 3, 2], [8, 3, 2], [4, 1, 1]], 'conv_decoder_layers': [[8, 3, 2, 0], [4, 2, 1, 0], [1, 1, 2, 0]], 'encoder_output_dim': [4, 3], 'latent_dim': 6, 'use_batch_norm': True}, 'learning_rate': 0.0001, 'data': {'feature_path': '/Users/johnzhou/research/decision-making/data/synth/sim_features.npy', 'label_path': '/Users/johnzhou/research/decision-making/data/synth/sim_labels.npy', 'train_proportion': 0.8, 'train_batch_size': 100, 'val_batch_size': 100}, 'trainer': {'gpus': 0, 'max_epochs': 100}}


In [6]:
classifier_config = OmegaConf.create({
    "name": "test_expt",
    "random_seed": 4995,
    "model": {
        "in_features": 15,
        "linear_layers": [8],
        "use_batch_norm": True
    },
    "learning_rate": 1e-4,
    "data": {
        "feature_path": f"{repo_path}/data/synth/sim_features.npy",
        "label_path": f"{repo_path}/data/synth/sim_labels.npy",
        "train_proportion": 0.8,
        "train_batch_size": 100,
        "val_batch_size": 100
    },
    "trainer": {
        "gpus": 0,
        "max_epochs": 100
    },

})

OmegaConf.save(config=classifier_config, f=f"{repo_path}/configs/model_configs/classifier_train.yaml")

In [7]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

In [9]:
net = SigmoidNet(OmegaConf.to_container(classifier_config))
model = net.load_from_checkpoint(f"{expt_dir}/test_expt/model-v1.ckpt")

{'in_features': 15, 'linear_layers': [8], 'use_batch_norm': True}
LinearClassifier(
  (layers): Sequential(
    (0): Linear(in_features=15, out_features=8, bias=True)
    (1): LeakyReLU(negative_slope=0.05)
    (2): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=8, out_features=3, bias=True)
  )
)
{'in_features': 15, 'linear_layers': [8], 'use_batch_norm': True}
LinearClassifier(
  (layers): Sequential(
    (0): Linear(in_features=15, out_features=8, bias=True)
    (1): LeakyReLU(negative_slope=0.05)
    (2): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=8, out_features=3, bias=True)
  )
)


In [15]:
X_fname = f"{data_dir}/synth/sim_features.npy"
model_fname = f"{expt_dir}/test_expt/sigmoid.ckpt"
system = SigmoidNet

model = system.load_from_checkpoint(model_fname)

params = torch.squeeze(model(torch.unsqueeze(torch.from_numpy(np.load(X_fname)).float(), 1)))
print(params.shape)

{'in_features': 15, 'linear_layers': [64, 32, 16, 8], 'use_batch_norm': True}
LinearClassifier(
  (layers): Sequential(
    (0): Linear(in_features=15, out_features=64, bias=True)
    (1): LeakyReLU(negative_slope=0.05)
    (2): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Linear(in_features=64, out_features=32, bias=True)
    (4): LeakyReLU(negative_slope=0.05)
    (5): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): Linear(in_features=32, out_features=16, bias=True)
    (7): LeakyReLU(negative_slope=0.05)
    (8): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): Linear(in_features=16, out_features=8, bias=True)
    (10): LeakyReLU(negative_slope=0.05)
    (11): BatchNorm1d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): Linear(in_features=8, out_features=3, bias=True)
  )
)
torch.Size([100000, 3])
