In [54]:
import argparse
import os

import torch
from torch import optim

from brain_multimodal_vae.dataset.deeprecon import load_data, select_train_data, DeepReconDataset
from brain_multimodal_vae.models import DMVAE, MMVAE, MVAE
from brain_multimodal_vae.training import Trainer
import brain_multimodal_vae.utils as utils

# Use the GPU when torch can see one; otherwise run on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# One explicit generator shared by training-data selection and DataLoader
# shuffling. It is seeded with --select_seed before select_train_data and
# re-seeded with --seed right before training.
g = torch.Generator()

In [55]:
all_subject_list = ["x01", "x02", "x03", "x04", "x05"]
# Per-subject voxel counts, handed to the model constructor.
n_voxels_dict = {'x01': 15316, 'x02': 14597, 'x03': 13135, 'x04': 13596, 'x05': 13149}
models = {"MVAE": MVAE, "MMVAE": MMVAE, "DMVAE": DMVAE}

# The data and checkpoint locations can be overridden from the environment, so
# the notebook is not tied to one user's home directory. When the variables are
# unset, the original defaults apply.
DEFAULT_DATA_DIR = os.environ.get("DEEPRECON_DATA_DIR", "/home/acg17270jl/projects/brain-multimodal-vae/data/deeprecon/")
DEFAULT_CKPT_DIR = os.environ.get("DEEPRECON_CKPT_DIR", "/home/acg17270jl/projects/brain-multimodal-vae/checkpoints/deeprecon/")

parser = argparse.ArgumentParser()

# Path setting
parser.add_argument("--data_dir", type=str, default=DEFAULT_DATA_DIR)
parser.add_argument("--ckpt_dir", type=str, default=DEFAULT_CKPT_DIR)
# Data setting
# Only known subjects are accepted. A typo therefore fails at parse time rather
# than later, when the per-subject dicts are indexed. The default is a copy, so
# mutating params["subject_list"] cannot alter all_subject_list.
parser.add_argument("--subject_list", nargs="+", choices=all_subject_list, default=list(all_subject_list))
parser.add_argument("--n_train_repetitions", type=int, default=5)
parser.add_argument("--normalize", action=argparse.BooleanOptionalAction, default=True)
parser.add_argument("--n_shared_labels", type=int, default=600)
parser.add_argument("--n_unique_labels", type=int, default=120)
parser.add_argument("--select_seed", type=int, default=42)
# Dataset setting
parser.add_argument("--train_group", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--test_group", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--set_mode", choices=["and", "or"], default="or")
parser.add_argument("--include_missing", action=argparse.BooleanOptionalAction, default=True)
# DataLoader setting (--seed also seeds the global torch RNG before the model is built)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--batch_size", type=int, default=128)
# Model setting
parser.add_argument('--model_name', choices=list(models.keys()), default="MVAE")
parser.add_argument("--z_dim", type=int, default=128)
parser.add_argument("--zp_dim", type=int, default=64)
parser.add_argument("--zs_dim", type=int, default=128)
parser.add_argument("--hidden_dim", type=int, default=4096)
# Training setting
parser.add_argument("--optimizer_name", choices=["adam"], default="adam")
parser.add_argument("--lr", type=float, default=1e-4)
parser.add_argument("--weight_decay", type=float, default=1e-1)
parser.add_argument("--n_epochs", type=int, default=15)
parser.add_argument("--eval", action=argparse.BooleanOptionalAction, default=False)

# Arguments used when running inside Jupyter, where sys.argv holds kernel
# arguments. Each trailing backslash joins the lines inside the string, and
# .split() then tokenizes on whitespace.
jupyter_args = """
    --subject_list x01 x02 \
    --n_train_repetitions 5 \
    --n_shared_labels 1200 \
    --n_unique_labels 0 \
    --select_seed 42 \
    --no-train_group \
    --include_missing \
    --model_name MVAE \
    --z_dim 128 \
    --eval
"""

jupyter_args = jupyter_args.split()

if utils.is_interactive():
    args = parser.parse_args(jupyter_args)
else:
    args = parser.parse_args()

params = vars(args)

In [56]:
# Load every subject, not just params["subject_list"]. select_train_data below
# is given all_subject_list when it chooses the shared and unique training labels.
train_brain_dict, train_label_dict, test_brain_dict, test_label_dict = load_data(
    params["data_dir"], all_subject_list, params["n_train_repetitions"], params["normalize"]
)

# Seed the shared generator for training-data selection. When no seed is given,
# draw one with Generator.seed() instead of torch.seed(). torch.seed() would also
# reseed the *global* torch RNG non-deterministically as a side effect.
# The value actually used is stored back in params, which is later passed to
# trainer.save.
if params["select_seed"] is None:
    params["select_seed"] = g.seed()
else:
    g.manual_seed(params["select_seed"])
train_brain_dict, train_label_dict = select_train_data(
    train_brain_dict, train_label_dict, all_subject_list,
    params["n_shared_labels"], params["n_unique_labels"], generator=g,
)

# Keep only the subjects chosen for this run.
train_brain_dict, train_label_dict, test_brain_dict, test_label_dict = [
    utils.get_sub_dict(subject_dict, params["subject_list"])
    for subject_dict in (train_brain_dict, train_label_dict, test_brain_dict, test_label_dict)
]

In [57]:
# Train/test datasets for the selected subjects. Each receives its own group
# flag (--train_group / --test_group) plus the shared set_mode and
# include_missing options.
train_ds = DeepReconDataset(train_brain_dict, train_label_dict, params["subject_list"], params["train_group"], params["set_mode"], params["include_missing"])
test_ds = DeepReconDataset(test_brain_dict, test_label_dict, params["subject_list"], params["test_group"], params["set_mode"], params["include_missing"])

# Page-locked host memory only speeds up host-to-GPU copies. On a CPU-only
# machine PyTorch warns and ignores it, so enable it only when using CUDA.
pin_memory = device == "cuda"

# Training batches are shuffled with the shared generator `g` and the last
# partial batch is dropped. Evaluation keeps every sample, in a fixed order.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=params["batch_size"], shuffle=True, drop_last=True, pin_memory=pin_memory, generator=g)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=params["batch_size"], shuffle=False, drop_last=False, pin_memory=pin_memory)

In [58]:
# Seed the global torch RNG so that module weight initialization is
# reproducible. The generator `g` only covers data selection and shuffling.
# Any global-RNG sampling the model does during training is presumably covered
# too — confirm.
torch.manual_seed(params["seed"])
# NOTE(review): every CLI option in `params` is forwarded as a keyword argument,
# so the model constructors presumably accept and ignore extra keys — confirm.
model = models[params["model_name"]](**params, n_voxels_dict=n_voxels_dict, device=device)

In [59]:
# Pull the run options out of params once so the calls below read plainly.
train_seed, n_epochs, run_eval = params["seed"], params["n_epochs"], params["eval"]

# run_eval is forwarded as-is. With --eval, the log below shows a test pass
# after every training epoch.
trainer = Trainer(model, train_dl, test_dl, run_eval)

# Re-seed the shared generator right before training. DataLoader shuffling then
# depends only on --seed, not on how many draws select_train_data consumed.
g.manual_seed(train_seed)
trainer.train(n_epochs)

# The full run configuration is handed to save (presumably stored alongside the
# checkpoint — confirm in Trainer.save).
trainer.save(params)

  0%|          | 0/46 [00:00<?, ?it/s]

100%|██████████| 46/46 [00:02<00:00, 20.66it/s]


Epoch: 1 Train loss: 128918.5781


100%|██████████| 10/10 [00:00<00:00, 17.86it/s]


Epoch: 1 Test loss: 133574.7812


100%|██████████| 46/46 [00:02<00:00, 18.51it/s]


Epoch: 2 Train loss: 125050.9297


100%|██████████| 10/10 [00:00<00:00, 27.09it/s]


Epoch: 2 Test loss: 132322.3281


100%|██████████| 46/46 [00:02<00:00, 20.28it/s]


Epoch: 3 Train loss: 122542.5312


100%|██████████| 10/10 [00:00<00:00, 26.66it/s]


Epoch: 3 Test loss: 132002.7969


100%|██████████| 46/46 [00:03<00:00, 14.55it/s]


Epoch: 4 Train loss: 120696.3906


100%|██████████| 10/10 [00:00<00:00, 18.65it/s]


Epoch: 4 Test loss: 131683.9219


100%|██████████| 46/46 [00:03<00:00, 13.67it/s]


Epoch: 5 Train loss: 119283.7812


100%|██████████| 10/10 [00:00<00:00, 18.88it/s]


Epoch: 5 Test loss: 131641.4688


100%|██████████| 46/46 [00:03<00:00, 13.66it/s]


Epoch: 6 Train loss: 118154.4297


100%|██████████| 10/10 [00:00<00:00, 19.75it/s]


Epoch: 6 Test loss: 131513.1562


100%|██████████| 46/46 [00:02<00:00, 18.04it/s]


Epoch: 7 Train loss: 117210.2188


100%|██████████| 10/10 [00:00<00:00, 36.30it/s]


Epoch: 7 Test loss: 131729.2500


100%|██████████| 46/46 [00:02<00:00, 20.48it/s]


Epoch: 8 Train loss: 116388.1797


100%|██████████| 10/10 [00:00<00:00, 25.29it/s]


Epoch: 8 Test loss: 131485.5625


100%|██████████| 46/46 [00:02<00:00, 16.09it/s]


Epoch: 9 Train loss: 115680.6250


100%|██████████| 10/10 [00:00<00:00, 36.19it/s]


Epoch: 9 Test loss: 131559.7969


100%|██████████| 46/46 [00:02<00:00, 20.92it/s]


Epoch: 10 Train loss: 115094.9375


100%|██████████| 10/10 [00:00<00:00, 36.40it/s]


Epoch: 10 Test loss: 131619.2344


100%|██████████| 46/46 [00:02<00:00, 21.05it/s]


Epoch: 11 Train loss: 114530.6016


100%|██████████| 10/10 [00:00<00:00, 32.52it/s]


Epoch: 11 Test loss: 131639.9688


100%|██████████| 46/46 [00:02<00:00, 21.13it/s]


Epoch: 12 Train loss: 114053.9688


100%|██████████| 10/10 [00:00<00:00, 35.64it/s]


Epoch: 12 Test loss: 131510.0625


100%|██████████| 46/46 [00:02<00:00, 21.13it/s]


Epoch: 13 Train loss: 113604.7500


100%|██████████| 10/10 [00:00<00:00, 29.27it/s]


Epoch: 13 Test loss: 131682.0781


100%|██████████| 46/46 [00:02<00:00, 20.97it/s]


Epoch: 14 Train loss: 113211.4453


100%|██████████| 10/10 [00:00<00:00, 36.92it/s]


Epoch: 14 Test loss: 131596.2969


100%|██████████| 46/46 [00:02<00:00, 21.11it/s]


Epoch: 15 Train loss: 112867.0000


100%|██████████| 10/10 [00:00<00:00, 36.81it/s]


Epoch: 15 Test loss: 131681.9219
