In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell executes, so edits to the project code imported below are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import os
import shutil
from copy import deepcopy
from pathlib import Path

import molfeat
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from hydra import compose, initialize
from hydra.core.global_hydra import GlobalHydra
from hydra.utils import instantiate
from lightning.pytorch.loggers import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict

from src import utils
from src.models.jump_cl import BasicJUMPModule
from src.modules.collate_fn import default_collate
from src.modules.losses.contrastive_losses import InfoNCE, NTXent, RegInfoNCE, RegNTXent
from src.utils import instantiate_evaluator_list

In [None]:
# Mount the three remote cpjump project directories over sshfs, skipping any
# whose `jump/` subdirectory is already visible locally.
for i in range(1, 4):
    if Path(f"../cpjump{i}/jump/").exists():
        print(f"cpjump{i} already mounted.")
        continue

    print(f"Mounting cpjump{i}...")
    # os.system returns the command's status; report a non-zero value here so a
    # failed mount is visible immediately instead of surfacing later as missing
    # files. The loop stays best-effort and still tries the remaining mounts.
    mount_status = os.system(f"sshfs bioclust:/projects/cpjump{i}/ ../cpjump{i}")
    if mount_status != 0:
        print(f"Failed to mount cpjump{i} (sshfs returned status {mount_status}).")

## Load the config and instantiate the model, loggers and evaluators

In [6]:
# Drop any Hydra global state left over from an earlier run of this notebook,
# so the initialize() call in the next cell can be executed again.
hydra_singleton = GlobalHydra.instance()
hydra_singleton.clear()

In [7]:
# Point Hydra at the criterion config directory only, so a loss config can be
# composed on its own. The path is a fixed literal (no interpolation needed).
initialize(version_base=None, config_path="../configs/model/criterion")

hydra.initialize()

In [8]:
# No overrides are active for this criterion config. The commented-out entries
# are kept only as reference; NOTE(review): they look like overrides for a
# full training/eval config rather than this criterion config - confirm before
# re-enabling any of them.
criterion_overrides = [
    # "evaluate=true",
    # "eval=retrieval",
    # "paths.projects_dir=..",
    # f"paths.output_dir=../cpjump1/jump/logs/train/multiruns/{run}",
    # # "experiment=fp_big",
    # "data.batch_size=4",
    # # "model/molecule_encoder=gin_masking.yaml",
    # "trainer.devices=1",
    # # "eval.moa_image_task.datamodule.data_root_dir=../",
]

# Compose the config and dump it as YAML so the resolved values can be checked
# by eye before anything is instantiated.
cfg = compose(config_name="ntxent_vae.yaml", overrides=criterion_overrides)
cfg_yaml = OmegaConf.to_yaml(cfg)
print(cfg_yaml)

_target_: src.modules.losses.base_losses.CombinationLoss
norm: true
weights:
- 1
- 0.25
- 1
losses:
  NTXent:
    _target_: src.modules.losses.contrastive_losses.NTXent
    norm: true
    temperature: 15
    return_rank: true
    temperature_requires_grad: false
    temperature_min: 0
    temperature_max: 100
  regularization:
    _target_: src.modules.losses.base_losses.RegularizationLoss
    mse_reg: 1
    l1_reg: 0.15
    uniformity_reg: 0
    variance_reg: 1
    covariance_reg: 0.5
  autoencoder:
    _target_: src.modules.losses.autoencoder_losses.VariationalAutoEncoderLoss
    emb_dim: 128
    loss: cosine
    detach_target: true
    beta: 1



In [10]:
# Build the loss object described by the config's `_target_` entries; the YAML
# printed by the previous cell names CombinationLoss at the top level.
criterion = instantiate(cfg)

In [12]:
# Display the individual loss modules held by the combined criterion, keyed by
# the names used under `losses:` in the config.
criterion.losses

{'NTXent': NTXent(), 'regularization': RegularizationLoss(
  (mse_loss): MSELoss()
  (l1_loss): L1Loss()
), 'autoencoder': VariationalAutoEncoderLoss(
  (fc_mu): Linear(in_features=128, out_features=128, bias=True)
  (fc_var): Linear(in_features=128, out_features=128, bias=True)
  (decoder): Sequential(
    (0): Linear(in_features=128, out_features=128, bias=True)
    (1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=128, out_features=128, bias=True)
  )
)}