In [1]:
import os
import sys
from datetime import datetime
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import transformers
import wandb
from datasets import Dataset, concatenate_datasets
from modAL.disagreement import max_std_sampling
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, WhiteKernel
from sklearn.model_selection import KFold
from transformers import Trainer, TrainingArguments
from unicore.optim.fused_adam import FusedAdam

sys.path.insert(0, "../")
from unimol.hf_unimol import UniMol, UniMolConfig, init_unimol_backbone
from unimol.hg_mterics import compute_metrics
from unimol.lmdb_dataset import collate_fn, load_dataset
from learners import ActiveLearner, CommitteeRegressor

%matplotlib inline

# set wandb offline
os.environ["WANDB_MODE"] = "offline"



### Try the point sampling with ensemble pytorch model first

1. put the models in the learner
2. Update and sample the data points
3. Evaluate the model with new metrics

In [2]:
weight_name = "baseline"
weight_mapping = {
    "baseline": "/scratch/ssd004/datasets/cellxgene/3d_molecule_save/weights/mol_pre_no_h_220816.pt",
    "20240502-1907": "/datasets/cellxgene/3d_molecule_save/pretrain-20240502-1907/checkpoint_best.pt",
}
weight_path = weight_mapping[weight_name]
dict_path = "/fs01/home/haotian/SDL-LNP/model/unimol/dict.txt"
cache_dir = "/datasets/cellxgene/3d_molecule_data/cache"
output_path = "/datasets/cellxgene/3d_molecule_save/fine-tuning"
# num_folds = 5
# kfold_order = "sequential"  # random | sequential
# if kfold_order == "random":
#     kf = KFold(n_splits=num_folds, shuffle=True, random_state=42)
# else:
#     kf = KFold(n_splits=num_folds, shuffle=False)

# load model and dictionary
model_backbone, dictionary = init_unimol_backbone(weight_path, dict_path=dict_path)

# build huggingface dataset from lmdb
lmdb_dir = Path("/datasets/cellxgene/3d_molecule_data/1920-lib")
train_data = load_dataset(
    dictionary,
    str(lmdb_dir / "train.lmdb"),
    "train",
)
valid_data = load_dataset(
    dictionary,
    str(lmdb_dir / "valid.lmdb"),
    "valid",
)
test_data = load_dataset(
    dictionary,
    str(lmdb_dir / "test.lmdb"),
    "test",
)

hf_train_data = Dataset.from_generator(
    lambda: train_data,
    cache_dir=cache_dir,
)
hf_valid_data = Dataset.from_generator(
    lambda: valid_data,
    cache_dir=cache_dir,
)
hf_test_data = Dataset.from_generator(
    lambda: test_data,
    cache_dir=cache_dir,
)
combined_data = concatenate_datasets([hf_train_data, hf_valid_data])


FileNotFoundError: [Errno 2] No such file or directory: '/fs01/home/haotian/SDL-LNP/model/unimol/dict.txt'

In [None]:
# get the first half of combined_data as the initial training set
shards = 2
initial_train_data = combined_data.shard(2, 0, contiguous=True, keep_in_memory=True)
data_pool = combined_data.shard(2, 1, contiguous=True, keep_in_memory=True)

In [None]:
# # In the active learning framework
# data_pool = combined_data
X_pool = data_pool
y_pool = data_pool["target"]
X_test = hf_test_data
y_test = hf_test_data["target"]
# # X_pool = X_pool[:1000]
# # y_pool = y_pool[:1000]

X_initial = initial_train_data
y_initial = initial_train_data["target"]

In [None]:
# from scipy.stats import spearmanr

# spearmanr(test_output.label_ids[0], test_output.predictions[0])[0]


In [None]:
class ModALModelWrapper:
    """
    Trainer class that is compatible with modAL active learning framework.
    """
    # default model_config args
    default_model_config = {
        "input_dim":512,
        "inner_dim":512,
        "num_classes":1,
        "dropout":0,
        "decoder_type":"mlp",
    }

    default_training_args = {
        "output_dir": None,
        "num_train_epochs": 12,
        "per_device_train_batch_size": 64,
        "per_device_eval_batch_size": 256,
        "dataloader_num_workers": 4,
        "remove_unused_columns": False,
        "logging_dir": "./logs",
        "fp16": True,
        "logging_steps": 100,
        "evaluation_strategy": "steps",
        "save_strategy": "steps",
        "save_steps": 500,
        "eval_steps": 100,
        "report_to": "wandb",
        "label_names": ["target", "smi_name"],
        "load_best_model_at_end": True,
        "optim": "adamw_torch",
        "metric_for_best_model": "relaxed_spearman",
    }


    def __init__(self, weight_path, dict_path, output_path, eval_dataset=None, **kwargs):
        model_config_kwargs = self.default_model_config.copy()
        model_config_kwargs.update(kwargs)
        self.model_config = UniMolConfig(**model_config_kwargs)
        self.default_training_args = self.default_training_args.copy()
        self.default_training_args["output_dir"] = output_path

        self.model_backbone, self.dictionary = init_unimol_backbone(
            weight_path, dict_path=dict_path
        )
        self.eval_dataset = eval_dataset

    def _init_trainer(self, train_dataset, training_args):
        training_arguments = self.default_training_args.copy()
        training_arguments.update(training_args)
        self.training_arguments = TrainingArguments(**training_arguments)

        model = UniMol(self.model_backbone, self.model_config, self.dictionary)

        optimizer = FusedAdam(
            model.parameters(),
            lr=1e-4,
            eps=1e-6,
            betas=(0.9, 0.99),
        )

        warmup_ratio = 0.06
        training_steps = len(train_dataset) * training_arguments["num_train_epochs"]
        warmup_steps = int(training_steps * warmup_ratio)

        scheduler = transformers.get_polynomial_decay_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=training_steps,
        )

        return Trainer(
            model=model,
            args=self.training_arguments,
            train_dataset=train_dataset,
            eval_dataset=self.eval_dataset,
            data_collator=collate_fn,
            compute_metrics=compute_metrics,
            tokenizer=None,
            optimizers=(optimizer, scheduler),
        )
    
    @property
    def model(self):
        return self.trainer.model
    
    @property
    def train_dataset(self):
        return self.trainer.train_dataset
    
    @train_dataset.setter
    def train_dataset(self, dataset):
        self.trainer.train_dataset = dataset

    def fit(self, X, y=None, **training_args):
        """
        Fit the model with the given input data.
        """
        self.trainer = self._init_trainer(X, training_args)
        # self.trainer.train_dataset = X
        self.trainer.train()

    def predict(self, X, return_std=False):
        """
        Predict the target values for the given input data.
        """
        if not hasattr(self, "trainer"):
            raise ValueError("Model not trained yet. Usually you should call ModALModelWrapper.fit first.")
        predictions = self.trainer.predict(X)
        smi_names = predictions.label_ids[1]
        predictions = predictions.predictions[0]
        if return_std:
            # predictions have repeated smi_names, compute the std per smi_name
            import pandas as pd

            df = pd.DataFrame(
                {
                    "smi_name": smi_names,
                    "prediction": predictions,
                }
            )
            std = df.groupby("smi_name").std().values
            # map it back to length of predictions
            std = np.array([std[smi_names == smi][0] for smi in smi_names])
            assert len(std) == len(predictions)
            return predictions, std

        
        return predictions

In [None]:
model_wrapper = ModALModelWrapper(weight_path, dict_path, output_path, eval_dataset=hf_test_data)
# model_wrapper.fit(initial_train_data, num_train_epochs=6)

Model loaded


In [None]:
# initialize ActiveLearner
learner = ActiveLearner(
    estimator=model_wrapper,
    query_strategy=max_std_sampling,
    X_training=X_initial,
    y_training=y_initial,
    num_train_epochs=6,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss,Validation Loss,Pearson,Spearman,Relaxed Spearman
100,23.4896,16.697338,0.193469,0.167758,0.168234
200,14.1758,16.860138,0.530302,0.535511,0.536151
300,13.7764,16.135302,0.628915,0.675216,0.676017
400,13.1751,16.493685,0.640179,0.689035,0.689955
500,12.1358,15.154257,0.671557,0.688011,0.689018
600,11.5065,13.323022,0.651926,0.660493,0.661202
700,11.285,17.330334,0.664367,0.67138,0.672253
800,10.5079,16.075773,0.645682,0.662008,0.663138


In [None]:
X_pool

Dataset({
    features: ['src_tokens', 'src_coord', 'src_distance', 'src_edge_type', 'target', 'smi_name'],
    num_rows: 9504
})

In [None]:
# the active learning loop
n_queries = 10
for idx in range(n_queries):
    query_idx, query_instance = learner.query(X_pool, n_instances=100)
    print(query_idx)
    print(query_instance)
    learner.teach(X_pool[query_idx], y_pool[query_idx], only_new=True, num_train_epochs=6)
    # remove queried instance from pool
    # X_pool = np.delete(X_pool, query_idx, axis=0)
    # y_pool = np.delete(y_pool, query_idx, axis=0)

In [None]:
# # visualizing the data
# with plt.style.context('seaborn-v0_8-bright'):
#     plt.figure(figsize=(7, 7))
#     plt.scatter(X, y, c='k')
#     plt.title('Noisy absolute value function')
#     plt.show()

In [None]:
# # initialize the model

# # five fold cross validation and init five different models

# learner_list = []

# # initializing the Committee
# committee = CommitteeRegressor(
#     learner_list=learner_list,
#     query_strategy=max_std_sampling
# )