In [17]:
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import numpy as np
import scanpy as sc
import time
import copy
from typing import List, Tuple, Dict, Union, Optional
from scipy.sparse import issparse
from sklearn.model_selection import train_test_split

from utils import set_seed, AttrDict
from myTorchtext import Vocab
from preprocess import Preprocessor
from tokenizer import tokenize_and_pad_batch, retrieve_tfs, random_mask_value
from model import TransformerModel
from loss import masked_mse_loss, masked_relative_error, criterion_neg_log_bernoulli

In [40]:
config = AttrDict({
    "run_name": "",
    "dataset_name": "HYPOXIA_9K",
    "seed": 42,
    "ntokens": 10000,
    "d_model": 512,
    "nhead": 8,
    "d_hid": 2048,
    "nlayers": 6,
    "n_bins": 51,
    "n_hvg": 100,
    "include_zero_gene": False,
    "mask_single_value": False,
    "dropout": 0.2,
    "batch_size": 32,
    "epochs": 1,
    "lr": 0.0001,
    "amp": True,
    "schedule_ratio": 0.9,
    "GEPC": False,   # If Gene Expression Prediction for Cell Modelling objective (MLM from <cls> only)
    "explicit_zero_prob": True, # if modelling gene expression also as bern var
    "do_train": True,
    "log_interval": 100,
    "wandb": True,
})

set_seed(config.seed)

if config.wandb:
    import wandb
    wandb.login()
    run = wandb.init(
        project='BioFormer',
        config = config,
        name = config.run_name if config.run_name else None
    )



0,1
epoch,▁
train/mse,█▇▆▅▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
train/nzlp,█▆▄▃▃▃▃▂▂▃▂▂▃▃▂▂▂▃▃▃▂▂▂▃▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
valid/mre,▁
valid/mse,▁

0,1
epoch,1.0
train/mse,165.19212
train/nzlp,0.00104
valid/mre,13912.42714
valid/mse,153.74775


In [24]:
# Pre-processing
pad_token = "<pad>"
special_tokens = [pad_token, "<cls>", "<eoc>"]
mask_value = -1
pad_value = -2
n_input_bins = config.n_bins
include_zero_gene = config.include_zero_gene
n_hvg = config.n_hvg
max_seq_len = n_hvg + 1

In [25]:
# Import data
dataset_name = config.dataset_name

if dataset_name == 'BREAST_25K':
    adata = sc.read_h5ad('./data/breast_25k.h5ad')
    data_is_raw = True

elif dataset_name == 'BREAST_12K':
    adata = sc.read_h5ad('./data/breast_12k.h5ad')
    data_is_raw = True

elif dataset_name == 'DERMAL_100K':
    adata = sc.read_h5ad('./data/dermal_100k.h5ad')
    adata.var["gene_name"] = adata.var.feature_name.tolist()
    data_is_raw = True

elif dataset_name == 'HYPOXIA_9K':
    adata = sc.read_h5ad('./data/scsHypoxiaTimeSub.h5ad')
    adata.X = adata.layers['raw_count']
    adata.var['gene_name'] = adata.var.index.tolist()
    data_is_raw = True

print(dataset_name)
print(adata)

HYPOXIA_9K
AnnData object with n_obs × n_vars = 9234 × 19046
    obs: 'nCount_RNA', 'nFeature_RNA', 'SampleTags', 'percent.mt', 'HypoxicState', 'TimePoint', 'nCount_SCT', 'nFeature_SCT', 'S.Score', 'G2M.Score', 'Phase', 'seurat_clusters', 'SampleTagsShort', 'active_ident'
    var: 'variable_gene', 'gene_name'
    uns: 'active_ident_colors', 'seurat_clusters_colors'
    obsm: 'X_pca', 'X_umap'
    layers: 'raw_count'


In [26]:
# Pre-process adata
preprocessor = Preprocessor(
    use_key="X",  # the key in adata.layers to use as raw data
    filter_gene_by_counts=3,  # step 1
    filter_cell_by_counts=False,  # step 2
    normalize_total=1e4,  # 3. whether to normalize the raw data and to what sum
    result_normed_key="X_normed",  # the key in adata.layers to store the normalized data
    log1p=data_is_raw,  # 4. whether to log1p the normalized data
    result_log1p_key="X_log1p",
    subset_hvg=config.n_hvg,  # 5. whether to subset the raw data to highly variable genes
    hvg_flavor="seurat_v3" if data_is_raw else "cell_ranger",
    binning=config.n_bins,  # 6. whether to bin the raw data and to what number of bins
    result_binned_key="X_binned",  # the key in adata.layers to store the binned data
)

preprocessor(adata, batch_key=None)


Filtering genes by counts ...
Normalizing total counts ...
Log1p transforming ...
Subsetting highly variable genes ...
No batch_key is provided, will use all cells for HVG selection.
Binning data ...


In [27]:
input_layer_key = "X_binned"
all_counts = (
    adata.layers[input_layer_key].toarray()
    if issparse(adata.layers[input_layer_key])
    else adata.layers[input_layer_key]
)
genes = adata.var["gene_name"].tolist()

train_data, valid_data = train_test_split(all_counts, test_size=0.1, shuffle=True)

# Vocab
stoi = {s:i for i, s in enumerate(genes + special_tokens)}
itos = {i:s for i, s in enumerate(genes + special_tokens)}
vocab = Vocab(stoi, itos)
vocab.set_default_index(vocab["<pad>"]) # index to return if token not found in vocab
gene_ids = np.array(vocab(genes), dtype=int)
print(f'Vocab of size: {len(vocab)} --> {len(genes)} genes, {len(special_tokens)} special tokens {special_tokens}')

Vocab of size: 103 --> 100 genes, 3 special tokens ['<pad>', '<cls>', '<eoc>']


In [28]:
tokenized_train = tokenize_and_pad_batch(
    train_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,  # append <cls> token at the beginning
    include_zero_gene=include_zero_gene,
)
tokenized_valid = tokenize_and_pad_batch(
    valid_data,
    gene_ids,
    max_len=max_seq_len,
    vocab=vocab,
    pad_token=pad_token,
    pad_value=pad_value,
    append_cls=True,
    include_zero_gene=include_zero_gene,
)
print(f"Train samples: {tokenized_train['genes'].shape[0]}")
print(f"Valid samples: {tokenized_valid['genes'].shape[0]}")
print(f"Input length: {tokenized_valid['genes'].shape[1]}")

Train samples: 8310
Valid samples: 924
Input length: 68


In [29]:
# # Import TFs Lookup Table
# if config.use_condition_labels:
#     tf = pd.read_csv(r'./data/transcriptional_interactions.csv',
#                             index_col=0, 
#                             low_memory=False,
#                             dtype={'source_genesymbol': str, 'target_genesymbol':str})[['source_genesymbol', 'target_genesymbol', 'is_stimulation', 'is_inhibition']].rename(columns={'source_genesymbol': 'source', 'target_genesymbol':'target'})
#     tf = tf.drop_duplicates()
#     tf = tf[~(tf.is_inhibition  == tf.is_stimulation)]    # drop rows where both are 1s or 0s

#     source_in_vocab = tf.source.isin(vocab.get_stoi().keys())
#     target_in_vocab = tf.target.isin(vocab.get_stoi().keys())
#     both_in_vocab = source_in_vocab * target_in_vocab

#     print('Unique Sources in GeneVocab:',
#             tf[source_in_vocab].source.unique().shape[0])
#     print('Unique Targets in GeneVocab:',
#             tf[target_in_vocab].target.unique().shape[0])

#     tf = tf[both_in_vocab]

#     print('Unique Pairs in GeneVocab:',
#             tf[~tf[['source', 'target']].duplicated()].shape[0])

#     print('Stimulation Interactions:',
#             tf.is_stimulation.sum())

#     print('Inhibition Interactions:',
#             tf.is_inhibition.sum())

#     tf.reset_index(inplace=True)

In [30]:
def prepare_data(use_condition_labels = False):
    
    masked_values_train = random_mask_value(
        tokenized_train["values"],
        mask_value=mask_value,
        pad_value=pad_value,
        mask_single_value = config.mask_single_value
    )
    masked_values_valid = random_mask_value(
        tokenized_valid["values"],
        mask_value=mask_value,
        pad_value=pad_value,
        mask_single_value = config.mask_single_value
    )
    print(
        f"random masking at epoch {epoch:3d}, ratio of masked values in train: ",
        f"{(masked_values_train == mask_value).sum() / (masked_values_train - pad_value).count_nonzero():.4f}",
    )


    input_gene_ids_train, input_gene_ids_valid = (
        tokenized_train["genes"],
        tokenized_valid["genes"],
    )
    input_values_train, input_values_valid = masked_values_train, masked_values_valid
    target_values_train, target_values_valid = (
        tokenized_train["values"],
        tokenized_valid["values"],
    )

    train_data_pt = {
        "gene_ids": input_gene_ids_train,
        "values": input_values_train,
        "target_values": target_values_train,
    }
    valid_data_pt = {
        "gene_ids": input_gene_ids_valid,
        "values": input_values_valid,
        "target_values": target_values_valid,
    }

    if use_condition_labels:
        train_data_pt['conditions'] = retrieve_tfs(
            input_gene_ids_train,
            input_values_train,     # masked
            tf = tf                                  
        )
        valid_data_pt['conditions'] = retrieve_tfs(
            input_gene_ids_valid,
            input_values_valid,      # masked
            tf = tf                                  
        )

    return train_data_pt, valid_data_pt

# dataset
class SeqDataset(Dataset):
    def __init__(self, data: Dict[str, torch.Tensor]):
        self.data = data

    def __len__(self):
        return self.data["gene_ids"].shape[0]

    def __getitem__(self, idx):
        return {k: v[idx] for k, v in self.data.items()}


# data_loader
def prepare_dataloader(
    data_pt: Dict[str, torch.Tensor],
    batch_size: int,
    shuffle: bool = False,
    drop_last: bool = False,
    num_workers: int = 0,
) -> DataLoader:
    dataset = SeqDataset(data_pt)

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=drop_last,
        num_workers=num_workers,
        pin_memory=True,
    )
    return data_loader


In [43]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
ntoken = len(vocab)  # size of vocabulary
model = TransformerModel(
    ntoken=ntoken,
    d_model=config.d_model,
    nhead=config.nhead,
    d_hid=config.d_model,
    nlayers=config.nlayers,
    vocab=vocab,
    dropout=config.dropout,
    pad_token=pad_token,
    pad_value=pad_value,

    # use_condition_labels=config.use_condition_labels,
    # num_condition_labels= 3 if config.use_condition_labels else None, # (-1, 0, 1)
)
# if config.load_model is not None:
#     load_pretrained(model, torch.load(model_file), verbose=False)
#     print('Pre-trained model successfully loaded')
    
model.to(device)
# wandb.watch(model)

print(f'''
device: {device} | d_model: {config.d_hid} | nhead: {config.nhead} | nlayers: {config.nlayers} | tot. params: {sum(p.numel() for p in model.parameters())/1e6:.0f}M
''')



device: cpu | d_model: 2048 | nhead: 8 | nlayers: 6 | tot. params: 12M



In [32]:
criterion = masked_mse_loss
criterion_dab = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    model.parameters(), lr=config.lr, eps=1e-4 if config.amp else 1e-8
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=config.schedule_ratio)
scaler = torch.cuda.amp.GradScaler(enabled=config.amp)



In [33]:
def train(model: nn.Module, loader: DataLoader) -> None:
    """
    Train the model for one epoch.
    """
    model.train()
    total_loss, total_mse, total_gepc = 0.0, 0.0, 0.0
    total_error = 0.0
    log_interval = config.log_interval
    start_time = time.time()

    num_batches = len(loader)
    for batch, batch_data in enumerate(loader):
        input_gene_ids = batch_data["gene_ids"].to(device)
        input_values = batch_data["values"].to(device)
        target_values = batch_data["target_values"].to(device)
        # input_conditions = batch_data["conditions"].to(device) if 'conditions' in batch_data.keys() else None

        src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
        with torch.cuda.amp.autocast(enabled=config.amp):
            output_dict = model(
                input_gene_ids,
                input_values,
                # src_key_padding_mask=src_key_padding_mask,
                # batch_labels=batch_labels if DSBN else None,
                # conditions = input_conditions,
                # MVC=config.GEPC,
                # ECS=config.ecs_thres > 0,
            )

            masked_positions = input_values.eq(mask_value)  # the postions to predict
            loss = loss_mse = criterion(output_dict["mlm_output"], target_values, masked_positions)
            metrics_to_log = {"train/mse": loss_mse.item()}
            
            if config.explicit_zero_prob:
                loss_zero_log_prob = criterion_neg_log_bernoulli(output_dict["mlm_zero_probs"], target_values, masked_positions)
                loss += loss_zero_log_prob
                metrics_to_log.update({"train/nzlp": loss_zero_log_prob.item()})
            
            if config.GEPC:
                loss_gepc = criterion(output_dict["mvc_output"], target_values, masked_positions)
                loss += loss_gepc
                metrics_to_log.update({"train/mvc": loss_gepc.item()})
            
            if config.GEPC and config.explicit_zero_prob:
                loss_gepc_zero_log_prob = criterion_neg_log_bernoulli(output_dict["mvc_zero_probs"], target_values, masked_positions)
                loss = loss + loss_gepc_zero_log_prob
                metrics_to_log.update({"train/mvc_nzlp": loss_gepc_zero_log_prob.item()})
            
            # if config.ecs_thres > 0:
            #     loss_ecs = 10 * output_dict["loss_ecs"]
            #     loss = loss + loss_ecs
            #     metrics_to_log.update({"train/ecs": loss_ecs.item()})
            # loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)
            # loss = loss + config.dab_weight * loss_dab
            # metrics_to_log.update({"train/dab": loss_dab.item()})

        model.zero_grad()
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        # with warnings.catch_warnings(record=True) as w:
        #     warnings.filterwarnings("always")
        #     torch.nn.utils.clip_grad_norm_(
        #         model.parameters(),
        #         1.0,
        #         error_if_nonfinite=False if scaler.is_enabled() else True,
        #     )
        #     if len(w) > 0:
        #         print(
        #             f"Found infinite gradient. This may be caused by the gradient "
        #             f"scaler. The current scale is {scaler.get_scale()}. This warning "
        #             "can be ignored if no longer occurs after autoscaling of the scaler."
        #         )
        scaler.step(optimizer)
        scaler.update()

        wandb.log(metrics_to_log)

        with torch.no_grad():
            mre = masked_relative_error(output_dict["mlm_output"], target_values, masked_positions)

        total_loss += loss.item()
        total_mse += loss_mse.item()
        total_gepc += loss_gepc.item() if config.GEPC else 0.0
        total_error += mre.item()
        if batch % log_interval == 0 and batch > 0:
            lr = scheduler.get_last_lr()[0]
            ms_per_batch = (time.time() - start_time) * 1000 / log_interval
            cur_loss = total_loss / log_interval
            cur_mse = total_mse / log_interval
            cur_gepc = total_gepc / log_interval if config.GEPC else 0.0
            cur_error = total_error / log_interval
            # ppl = math.exp(cur_loss)
            print(
                f"| epoch {epoch:3d} | {batch:3d}/{num_batches:3d} batches | "
                f"lr {lr:05.4f} | ms/batch {ms_per_batch:5.2f} | "
                f"loss {cur_loss:5.2f} | mse {cur_mse:5.2f} | mre {cur_error:5.2f} |"
                + (f"gepc {cur_gepc:5.2f} |" if config.GEPC else "")
            )
            total_loss = 0
            total_mse = 0
            total_gepc = 0
            total_error = 0
            start_time = time.time()



In [34]:
# def define_wandb_metrcis():
#     wandb.define_metric("valid/mse", summary="min", step_metric="epoch")
#     wandb.define_metric("valid/mre", summary="min", step_metric="epoch")
#     wandb.define_metric("valid/dab", summary="min", step_metric="epoch")
#     wandb.define_metric("valid/sum_mse_dab", summary="min", step_metric="epoch")
#     wandb.define_metric("test/avg_bio", summary="max")

In [35]:
def evaluate(model: nn.Module, loader: DataLoader) -> float:
    """
    Evaluate the model on the evaluation data.
    """
    model.eval()
    total_loss = 0.0
    total_error = 0.0
    total_dab = 0.0
    total_num = 0
    with torch.no_grad():
        for batch_data in loader:
            input_gene_ids = batch_data["gene_ids"].to(device)
            input_values = batch_data["values"].to(device)
            target_values = batch_data["target_values"].to(device)
            # input_conditions = batch_data["conditions"].to(device) if 'conditions' in batch_data.keys() else None

            # src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
            with torch.cuda.amp.autocast(enabled=config.amp):
                output_dict = model(
                    input_gene_ids,
                    input_values,
                    # src_key_padding_mask=src_key_padding_mask,
                    # batch_labels=batch_labels if DSBN else None,
                    # conditions = input_conditions,
                )
                output_values = output_dict["mlm_output"]

                masked_positions = input_values.eq(mask_value)
                loss = criterion(output_values, target_values, masked_positions)
                # loss_dab = criterion_dab(output_dict["dab_output"], batch_labels)

            total_loss += loss.item() * len(input_gene_ids)
            total_error += masked_relative_error(
                output_values, target_values, masked_positions
            ).item() * len(input_gene_ids)
            # total_dab += loss_dab.item() * len(input_gene_ids)
            total_num += len(input_gene_ids)

    wandb.log(
        {
            "valid/mse": total_loss / total_num,
            "valid/mre": total_error / total_num,
            # "valid/dab": total_dab / total_num,
            # "valid/sum_mse_dab": (total_loss + config.dab_weight * total_dab)
            # / total_num,
            "epoch": epoch,
        },
    )

    return total_loss / total_num, total_error / total_num


In [36]:
best_val_loss = float("inf")
best_avg_bio = 0.0
best_model = None
# define_wandb_metrcis()

for epoch in range(1, config.epochs + 1):
    epoch_start_time = time.time()
    
    train_data_pt, valid_data_pt = prepare_data()
    
    train_loader = prepare_dataloader(
        train_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False,
    )
    valid_loader = prepare_dataloader(
        valid_data_pt,
        batch_size=config.batch_size,
        shuffle=False,
        drop_last=False,
    )

    if config.do_train:
        train(
            model,
            loader=train_loader,
        )

    val_loss, val_mre = evaluate(
        model,
        loader=valid_loader,
    )
    elapsed = time.time() - epoch_start_time
    print("-" * 89)
    print(
        f"| end of epoch {epoch:3d} | time: {elapsed:5.2f}s | "
        f"valid loss/mse {val_loss:5.4f} | mre {val_mre:5.4f}"
    )
    print("-" * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)
        best_model_epoch = epoch
        print(f"Best model with score {best_val_loss:5.4f}")

    # if epoch % config.save_eval_interval == 0 or epoch == config.epochs:
    #     print(f"Saving model to {save_dir}")
    #     torch.save(best_model.state_dict(), save_dir / f"model_e{best_model_epoch}.pt")

    scheduler.step()

random masking at epoch   1, ratio of masked values in train:  0.1383




| epoch   1 | 100/260 batches | lr 0.0001 | ms/batch 2718.67 | loss 351.74 | mse 351.74 | mre 446295.37 |
| epoch   1 | 200/260 batches | lr 0.0001 | ms/batch 1879.34 | loss 185.04 | mse 185.04 | mre 377947.74 |
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 586.96s | valid loss/mse 153.7478 | mre 13912.4271
-----------------------------------------------------------------------------------------
Best model with score 153.7478


In [41]:
if config.wandb:
    run.finish()

In [37]:


# def eval_testdata(
#     model: nn.Module,
#     adata_t: AnnData,
#     include_types: List[str] = ["cls"],
# ) -> Optional[Dict]:
#     """evaluate the model on test dataset of adata_t"""
#     model.eval()

#     # copy adata_t to avoid reuse previously computed results stored in adata_t
#     adata_t = adata_t.copy()

#     all_counts = (
#         adata_t.layers[input_layer_key].A
#         if issparse(adata_t.layers[input_layer_key])
#         else adata_t.layers[input_layer_key]
#     )

#     celltypes_labels = adata_t.obs["celltype"].tolist()
#     celltypes_labels = np.array(celltypes_labels)

#     batch_ids = adata_t.obs["batch_id"].tolist()
#     batch_ids = np.array(batch_ids)

#     # Evaluate cls cell embeddings
#     if "cls" in include_types:
#         logger.info("Evaluating cls cell embeddings")
#         tokenized_all = tokenize_and_pad_batch(
#             all_counts,
#             gene_ids,
#             max_len=max_seq_len,
#             vocab=vocab,
#             pad_token=pad_token,
#             pad_value=pad_value,
#             append_cls=True,  # append <cls> token at the beginning
#             include_zero_gene=True,
#         )
#         all_gene_ids, all_values = tokenized_all["genes"], tokenized_all["values"]
#         src_key_padding_mask = all_gene_ids.eq(vocab[pad_token])
#         with torch.no_grad(), torch.cuda.amp.autocast(enabled=config.amp):
#             cell_embeddings = model.encode_batch(
#                 all_gene_ids,
#                 all_values.float(),
#                 src_key_padding_mask=src_key_padding_mask,
#                 batch_size=config.batch_size,
#                 batch_labels=torch.from_numpy(batch_ids).long() if DSBN else None,
#                 time_step=0,
#                 return_np=True,
#             )
#         cell_embeddings = cell_embeddings / np.linalg.norm(
#             cell_embeddings, axis=1, keepdims=True
#         )

#         adata_t.obsm["X_scGPT"] = cell_embeddings

#         results = {}
#         try:
#             results = eval_scib_metrics(adata_t)
#         except Exception as e:
#             traceback.print_exc()
#             logger.error(e)

#         sc.pp.neighbors(adata_t, use_rep="X_scGPT")
#         sc.tl.umap(adata_t, min_dist=0.3)
#         fig = sc.pl.umap(
#             adata_t,
#             color=["str_batch"],
#             title=[f"batch, avg_bio = {results.get('avg_bio', 0.0):.4f}"],
#             frameon=False,
#             return_fig=True,
#             show=False,
#         )

#         results["batch_umap"] = fig

#         sc.pp.neighbors(adata_t, use_rep="X_scGPT")
#         sc.tl.umap(adata_t, min_dist=0.3)
#         fig = sc.pl.umap(
#             adata_t,
#             color=["celltype"],
#             title=[
#                 f"celltype, avg_bio = {results.get('avg_bio', 0.0):.4f}",
#             ],
#             frameon=False,
#             return_fig=True,
#             show=False,
#         )

#         results["celltype_umap"] = fig

#     if len(include_types) == 1:
#         return results



# # Investigate one batch
# batch_data = next(iter(train_loader))
# input_gene_ids = batch_data["gene_ids"].to(device)
# input_values = batch_data["values"].to(device)
# target_values = batch_data["target_values"].to(device)
# input_conditions = batch_data["conditions"].to(device) if 'conditions' in batch_data.keys() else None
# src_key_padding_mask = input_gene_ids.eq(vocab[pad_token])
# with torch.cuda.amp.autocast(enabled=config.amp):
#     output_dict = model(
#         input_gene_ids,
#         input_values,
#         src_key_padding_mask=src_key_padding_mask,
#         # batch_labels=batch_labels if DSBN else None,
#         # conditions = input_conditions,
#         # MVC=config.GEPC,
#         # ECS=config.ecs_thres > 0,
#     )

# masked_positions = input_values.eq(mask_value)  # the postions to predict
# print(output_dict['mlm_output'][masked_positions])
# print(target_values[masked_positions])

# save the best model
# torch.save(best_model.state_dict(), save_dir / "best_model.pt")

# artifact = wandb.Artifact(f"best_model", type="model")
# glob_str = os.path.join(save_dir, "best_model.pt")
# artifact.add_file(glob_str)
# run.log_artifact(artifact)

# run.finish()
# wandb.finish()
# gc.collect()