In [1]:
import sys
from pathlib import Path

proj_path = Path('/cluster') / 'work' / 'jacobaal' / 'pers-pred'
proj_path = proj_path.resolve()
if proj_path not in sys.path: sys.path.append(str(proj_path))

import torch
from torch import nn
from torch.utils.data import DataLoader
import pandas as pd
import src.fcki as fcki
from src.utils import get_commons
from src.metrics import BCEMetric, MAEMetric
from src.ffn import Encoder, Decoder
from src.datasets import SingleInputDataset, get_mtl_dataloaders
from src.trainer import MTLTrainer
import LibMTL as mtl

  from .autonotebook import tqdm as notebook_tqdm


device: cpu


In [2]:
paths, constants, config, logger, device = get_commons()

2024-06-01 12:57:54,777 - ArgumentLogger - INFO - Arguments:
2024-06-01 12:57:54,777 - ArgumentLogger - INFO - Arguments:
2024-06-01 12:57:54,778 - ArgumentLogger - INFO - seed: 42
2024-06-01 12:57:54,778 - ArgumentLogger - INFO - seed: 42
2024-06-01 12:57:54,780 - ArgumentLogger - INFO - dataframe: {'generate': False, 'mbti_frac': 0.1, 'bigfive_c_frac': 1.0, 'bigfive_s_frac': 1.0}
2024-06-01 12:57:54,780 - ArgumentLogger - INFO - dataframe: {'generate': False, 'mbti_frac': 0.1, 'bigfive_c_frac': 1.0, 'bigfive_s_frac': 1.0}
2024-06-01 12:57:54,781 - ArgumentLogger - INFO - eda: {'generate': False}
2024-06-01 12:57:54,781 - ArgumentLogger - INFO - eda: {'generate': False}
2024-06-01 12:57:54,783 - ArgumentLogger - INFO - reduce: {'generate': False, 'use_full': False}
2024-06-01 12:57:54,783 - ArgumentLogger - INFO - reduce: {'generate': False, 'use_full': False}
2024-06-01 12:57:54,784 - ArgumentLogger - INFO - preprocessing: {'generate_features': False, 'generate_partially_cleaned': Fa

device: cpu


In [3]:
model_name = config['embeddings']['model']
embedding_size = constants['embedding_sizes'][model_name]
stats_size = 0

In [4]:
dataframes = {task: pd.read_csv(paths['split'][model_name][task], header=[0, 1], index_col=0).drop('STATS', axis='columns') for task in constants['tasks']}
datasets = {task: SingleInputDataset(dataframe) for task, dataframe in dataframes.items()}
dataloaders = get_mtl_dataloaders(datasets, config['split']['train'], config['split']['test'], config['dataloaders'], logger=logger)
dataloaders

In [None]:
input_size = [config['encoder']['nn'][-1]]
decoders = nn.ModuleDict({
    'mbti': Decoder(input_size + config['mtl-decoders']['hidden_nn'] + [4], final='sigmoid', dropout=config['mtl-decoders']['dropout']).to(device),
    'bigfive_c': Decoder(input_size + config['mtl-decoders']['hidden_nn'] + [5], final='sigmoid', dropout=config['mtl-decoders']['dropout']).to(device),
    'bigfive_s': Decoder(input_size + config['mtl-decoders']['hidden_nn'] + [5], final='none', dropout=config['mtl-decoders']['dropout']).to(device)
    })

In [None]:
task_dict = {
  'mbti': {
    'metrics': ['MSE'],
    'metrics_fn': BCEMetric(),
    'loss_fn': mtl.loss.CELoss(),
    'weight': [0]
  },
  'bigfive_c': {
    'metrics': ['MSE'],
    'metrics_fn': BCEMetric(),
    'loss_fn': mtl.loss.CELoss(),
    'weight': [0]
  },
  'bigfive_s': {
    'metrics': ['MSE'],
    'metrics_fn': MAEMetric(),
    'loss_fn': mtl.loss.MSELoss(),
    'weight': [0] # 0 means high loss is bad
  },
}

In [None]:
args = config["mtl"]
mtl_trainer = MTLTrainer(
    task_dict=task_dict,
    weighting=args["weighting"],
    architecture=args["architecture"],
    encoder_class=Encoder,
    decoders=decoders,
    rep_grad=True,
    multi_input=True,
    optim_param=config["optim_param"],
    scheduler_param=config["scheduler_param"],
    device=device,
    save_path=paths["training"]["mtl_save"],
    **args["kwargs"]
)

Total Params: 10508305
Trainable Params: 10508305
Non-trainable Params: 0
LOG FORMAT | mbti_LOSS MSE | bigfive_c_LOSS MSE | bigfive_s_LOSS MSE | TIME


In [None]:
mtl_trainer.train(
          train_dataloaders=dataloaders["train"],
          test_dataloaders=dataloaders["test"],
          val_dataloaders=dataloaders["val"],
          epochs=config["training"]["epochs"],
          patience=config["training"]["patience"]
        )

Epoch 0, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:29<00:00,  1.71s/batch]

Epoch: 0000 | TRAIN: 1.7786 37.4113 | 4.1706 3.0105 | 1367.8913 30.3965 | Time: 29.2706 | 




VAL: 1.7777 48.5922 | 4.1324 6.7007 | 1068.6784 28.0461 | Time: 2.9833 | TEST: 1.7777 48.5922 | 4.1324 6.7007 | 1068.6784 28.0461 | Time: 3.1103
Save Model 0 to /cluster/work/stefandt/pers-pred/checkpoints/mtl/best.pt


Epoch 1, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:44<00:00,  2.62s/batch]

Epoch: 0001 | TRAIN: 1.7785 47.0580 | 4.1269 6.1881 | 1024.5507 27.5517 | Time: 44.5743 | 




VAL: 1.7777 48.0597 | 4.1324 4.8599 | 941.9759 26.7792 | Time: 3.1045 | TEST: 1.7777 48.0597 | 4.1324 4.8599 | 941.9759 26.7792 | Time: 3.0008
Save Model 1 to /cluster/work/stefandt/pers-pred/checkpoints/mtl/best.pt


Epoch 2, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:45<00:00,  2.69s/batch]

Epoch: 0002 | TRAIN: 1.7768 48.0794 | 4.1399 4.8855 | 942.5902 26.7568 | Time: 45.6501 | 




VAL: 1.7777 48.0958 | 4.1324 4.9642 | 934.5260 26.7219 | Time: 3.3326 | TEST: 1.7777 48.0958 | 4.1324 4.9642 | 934.5260 26.7219 | Time: 3.1440


Epoch 3, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:44<00:00,  2.65s/batch]

Epoch: 0003 | TRAIN: 1.7791 48.0864 | 4.1217 4.9765 | 939.6957 26.7558 | Time: 44.9871 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3993 26.7210 | Time: 3.1628 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3993 26.7210 | Time: 3.0132


Epoch 4, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:45<00:00,  2.68s/batch]

Epoch: 0004 | TRAIN: 1.7790 48.0841 | 4.1234 4.9831 | 940.0425 26.7695 | Time: 45.6302 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.5121 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.2185


Epoch 5, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:44<00:00,  2.63s/batch]

Epoch: 0005 | TRAIN: 1.7772 48.1047 | 4.1352 4.9641 | 943.5989 26.8345 | Time: 44.7481 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.3752 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.1833


Epoch 6, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:45<00:00,  2.68s/batch]

Epoch: 0006 | TRAIN: 1.7758 48.1543 | 4.1327 4.9699 | 940.8655 26.7776 | Time: 45.5905 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.3348 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.2025


Epoch 7, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:44<00:00,  2.60s/batch]

Epoch: 0007 | TRAIN: 1.7775 48.1240 | 4.1388 4.9641 | 944.8311 26.8516 | Time: 44.2296 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.0572 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.2831


Epoch 8, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:45<00:00,  2.68s/batch]

Epoch: 0008 | TRAIN: 1.7780 48.1057 | 4.1300 4.9760 | 937.3058 26.7206 | Time: 45.6035 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.2089 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.3154


Epoch 9, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:44<00:00,  2.62s/batch]

Epoch: 0009 | TRAIN: 1.7778 48.1178 | 4.1393 4.9619 | 938.4634 26.7521 | Time: 44.6083 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.1983 | TEST: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.3100


Epoch 10, Allocated 0.0GB, Cached 0.0GB: 100%|██████████| 17/17 [00:45<00:00,  2.66s/batch]

Epoch: 0010 | TRAIN: 1.7753 48.1600 | 4.1351 4.9672 | 940.3865 26.7860 | Time: 45.2485 | 




VAL: 1.7777 48.0966 | 4.1324 4.9663 | 934.3992 26.7210 | Time: 3.2164 | Early stopping at epoch 10
Best Result: Epoch 1, result {'mbti': [48.05974663628472], 'bigfive_c': [4.859935601552327], 'bigfive_s': [26.779177711923744]}
