In [1]:
import os
from datetime import datetime

import numpy as np
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

from train_utils_regr import (
    compute_metrics,
    get_config,
    get_datasets,
    get_preds,
    save_regression_outputs,
)


# Process-wide environment switches for this notebook:
#   * TOKENIZERS_PARALLELISM is forced on.
#   * CUDA_VISIBLE_DEVICES exposes only the GPU with index 1 to this kernel.
# NOTE(review): training later uses dataloader_num_workers=8; confirm that
# forcing tokenizer parallelism on is safe together with worker processes.
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "true",
        "CUDA_VISIBLE_DEVICES": "1",
    }
)

In [2]:
# --- Filesystem roots (relative to the notebook's working directory) ---
datasplits_dir = "splits_biogen"  # where the train/test splits are read from
checkpoints_dir = "checkpoints"  # root for Trainer checkpoints
predictions_dir = "predictions"  # root for saved test-set predictions

# --- Task selection: these become path components under each root ---
task_type = "regression"
task_name = "Sol"
target_label = "activity"
split_type = "original"

# --- Model and optimisation hyperparameters ---
MODEL_NAME = "korolewadim/gselformer-large"
BATCH_SIZE = 32
# NOTE(review): the outputs saved with this notebook show 10 epochs per run;
# confirm that 50 is the intended value before re-running.
NUM_EPOCHS = 50
WARMUP_RATIO = 0.0
LEARNING_RATE = 2e-5

# Compact tag describing the hyperparameters; used as a directory name so that
# runs with different settings never overwrite each other.
hyperparameters_str = "__".join(
    (
        f"batch_size_{BATCH_SIZE}",
        f"num_epochs_{NUM_EPOCHS}",
        f"warmup_ratio_{WARMUP_RATIO}",
        f"lr_{LEARNING_RATE}",
    )
)

# --- Sweep axes: one training run per (group-SELFIES variant, seed) pair ---
NUM_GROUP_SELFIES_ = [1]
SEEDS = [
    1574211741,
    32662977,
    54088142,
    1593098056,
    4245326646,
]

# Load the tokenizer that belongs to MODEL_NAME once; the training cell reuses
# this single instance for dataset preparation and for every Trainer.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/958 [00:00<?, ?B/s]

In [3]:
# Fine-tune one regression model per (group-SELFIES variant, seed) pair, then
# save that run's test-set predictions. Every run writes to its own
# timestamped directory under `checkpoints_dir` / `predictions_dir`.
for num_gr_sf in NUM_GROUP_SELFIES_:
    # Path component for this variant; it does not depend on the seed.
    num_group_selfies = str(num_gr_sf)

    for SEED in SEEDS:
        # NOTE(review): none of these arguments depend on SEED, so this call
        # could probably be hoisted out of the seed loop -- confirm that
        # get_datasets is deterministic and that Trainer does not mutate the
        # returned datasets before doing so.
        train_dataset, test_dataset = get_datasets(
            os.path.join(
                datasplits_dir,
                task_type,
                task_name,
                target_label,
                split_type,
                num_group_selfies,
            ),
            tokenizer,
            "cache",
        )
        formatting_time = datetime.now().strftime("%Y-%m-%d__%H-%M-%S")

        # Seed all RNGs *before* building the model. The regression head is
        # not part of the checkpoint and is randomly initialised here, so
        # without this call its starting weights would not be tied to SEED
        # (Trainer only seeds after the model already exists).
        set_seed(SEED)
        model = AutoModelForSequenceClassification.from_pretrained(
            MODEL_NAME, num_labels=1
        )

        # Both output trees share the same run-specific sub-path.
        run_subdir = os.path.join(
            task_type,
            task_name,
            target_label,
            split_type,
            num_group_selfies,
            MODEL_NAME,
            hyperparameters_str,
            str(SEED),
            formatting_time,
        )
        output_dir = os.path.join(checkpoints_dir, run_subdir)
        preds_dir = os.path.join(predictions_dir, run_subdir)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(preds_dir, exist_ok=True)

        training_args = TrainingArguments(
            output_dir=output_dir,
            eval_strategy="epoch",
            per_device_train_batch_size=BATCH_SIZE,
            per_device_eval_batch_size=BATCH_SIZE,
            learning_rate=LEARNING_RATE,
            num_train_epochs=NUM_EPOCHS,
            warmup_ratio=WARMUP_RATIO,
            save_strategy="epoch",
            save_total_limit=1,
            seed=SEED,
            fp16=True,
            dataloader_num_workers=8,
            # The best checkpoint is NOT reloaded, so the predictions below
            # come from the weights as they are when training finishes.
            load_best_model_at_end=False,
            metric_for_best_model="eval_loss",
            dataloader_pin_memory=True,
        )
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            # The test split doubles as the per-epoch evaluation set.
            eval_dataset=test_dataset,
            processing_class=tokenizer,
            compute_metrics=compute_metrics,
        )
        train_result = trainer.train()

        # Identifiers stored next to each prediction so rows can be traced
        # back to the molecule they belong to.
        test_mol_indices = test_dataset.data["mol_index"].to_numpy()
        test_selfies = test_dataset.data["selfies"].to_numpy()

        outputs = trainer.predict(test_dataset)
        labels, preds = get_preds(outputs)
        save_regression_outputs(
            labels, preds, test_mol_indices, test_selfies, preds_dir
        )

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

Map:   0%|          | 0/3476 [00:00<?, ? examples/s]

Map:   0%|          | 0/870 [00:00<?, ? examples/s]

Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at korolewadim/gselformer-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'head.dense.weight', 'head.norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

Epoch,Training Loss,Validation Loss,R2 Score,Mae,Rmse
1,No log,0.418704,0.248681,0.484478,0.647073
2,No log,0.448223,0.195712,0.485024,0.669495
3,No log,0.660921,-0.185951,0.598548,0.812971
4,No log,0.373839,0.329186,0.435339,0.611424
5,0.352500,0.358903,0.355988,0.416198,0.599085
6,0.352500,0.406843,0.269965,0.531282,0.637842
7,0.352500,0.359227,0.355406,0.409382,0.599355
8,0.352500,0.356414,0.360453,0.413522,0.597004
9,0.352500,0.359967,0.354077,0.410586,0.599973
10,0.056000,0.355347,0.362369,0.402716,0.59611


Map:   0%|          | 0/3476 [00:00<?, ? examples/s]

Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at korolewadim/gselformer-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'head.dense.weight', 'head.norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,R2 Score,Mae,Rmse
1,No log,0.441182,0.208347,0.538148,0.664215
2,No log,0.38057,0.317109,0.465091,0.616903
3,No log,0.359163,0.355521,0.441471,0.599302
4,No log,0.377251,0.323064,0.453609,0.614208
5,0.354000,0.37267,0.331284,0.452989,0.610467
6,0.354000,0.377818,0.322046,0.444181,0.614669
7,0.354000,0.341317,0.387543,0.402282,0.584224
8,0.354000,0.39877,0.28445,0.420496,0.631482
9,0.354000,0.355897,0.361381,0.392432,0.596571
10,0.056600,0.348087,0.375395,0.404559,0.589989


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at korolewadim/gselformer-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'head.dense.weight', 'head.norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,R2 Score,Mae,Rmse
1,No log,0.651014,-0.168175,0.526757,0.806855
2,No log,0.39978,0.282639,0.476093,0.632281
3,No log,0.391683,0.297168,0.45839,0.625846
4,No log,0.379294,0.319398,0.447827,0.615869
5,0.345700,0.389206,0.301612,0.410778,0.623864
6,0.345700,0.367699,0.340205,0.398907,0.606381
7,0.345700,0.376694,0.324063,0.412436,0.613754
8,0.345700,0.344255,0.382271,0.404286,0.586733
9,0.345700,0.372681,0.331264,0.398086,0.610476
10,0.051300,0.362206,0.35006,0.390883,0.601836


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at korolewadim/gselformer-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'head.dense.weight', 'head.norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,R2 Score,Mae,Rmse
1,No log,0.693399,-0.24423,0.598078,0.832706
2,No log,0.426617,0.234482,0.444591,0.653159
3,No log,0.393174,0.294492,0.440729,0.627036
4,No log,0.417317,0.251169,0.455704,0.646001
5,0.330900,0.35911,0.355616,0.399077,0.599258
6,0.330900,0.363866,0.347081,0.413584,0.603213
7,0.330900,0.377251,0.323063,0.456392,0.614208
8,0.330900,0.376867,0.323753,0.453533,0.613895
9,0.330900,0.359348,0.355188,0.411702,0.599457
10,0.051400,0.356256,0.360737,0.398252,0.596872


Some weights of ModernBertForSequenceClassification were not initialized from the model checkpoint at korolewadim/gselformer-large and are newly initialized: ['classifier.bias', 'classifier.weight', 'head.dense.weight', 'head.norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,R2 Score,Mae,Rmse
1,No log,0.42857,0.230977,0.490966,0.654653
2,No log,0.425673,0.236175,0.436693,0.652436
3,No log,0.442941,0.205191,0.509542,0.665538
4,No log,0.51317,0.079172,0.628916,0.716359
5,0.325600,0.398683,0.284606,0.447793,0.631414
6,0.325600,0.393383,0.294116,0.413925,0.627203
7,0.325600,0.38499,0.309178,0.422475,0.620475
8,0.325600,0.370979,0.334318,0.400829,0.60908
9,0.325600,0.366546,0.342274,0.393103,0.60543
10,0.050300,0.400382,0.281558,0.415157,0.632758
