# Fine-tuning a Pre-trained Model for Batch Integration
In this tutorial, we demonstrate how to fine-tune a pre-trained model on a new dataset for the batch integration task. We use the PBMC 10K dataset as an example and fine-tune the pre-trained whole-body model on it.

We summarize the fine-tuning pipeline in the following steps, which can be used as a general recipe for fine-tuning on integration tasks and beyond:

     1. Specify hyper-parameter setup for integration task
     
     2. Load pre-trained scGPT model
     
     3. Load and pre-process data
     
     4. Fine-tune scGPT with task-specific objectives
     
     5. Evaluate fine-tuned scGPT


In [1]:
# Import required packages
import copy
import gc
import json
import os
from pathlib import Path
import sys
import time
import traceback
from typing import List, Tuple, Dict, Union, Optional
import warnings

import torch
from anndata import AnnData
import scanpy as sc
import scvi
import numpy as np
import wandb
from scipy.sparse import issparse
import matplotlib.pyplot as plt
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from torchtext.vocab import Vocab
from torchtext._torchtext import (
    Vocab as VocabPybind,
)

from scgpt.tokenizer.gene_tokenizer import GeneVocab

import scgpt as scg
from scgpt.model import TransformerModel, AdversarialDiscriminator
from scgpt.tokenizer import tokenize_and_pad_batch, random_mask_value
from scgpt.loss import (
    masked_mse_loss,
    masked_relative_error,
    criterion_neg_log_bernoulli,
)
from scgpt.preprocess import Preprocessor
from scgpt import SubsetsBatchSampler
from scgpt.utils import set_seed, category_str2int, eval_scib_metrics

# Set KMP_WARNINGS=off for this process.
# NOTE(review): presumably silences OpenMP (KMP) runtime warnings - confirm.
os.environ["KMP_WARNINGS"] = "off"

# Use 4x4 as the figure size for scanpy's figure parameters.
sc.set_figure_params(figsize=(4, 4))

Global seed set to 0
  IPython.display.set_matplotlib_formats(*ipython_format)


In [2]:
# Seed the run with 42 through scgpt's `set_seed` helper. This is the same
# value stored under `seed` in `hyperparameter_defaults` below.
set_seed(42)

# Step 1: Specify hyper-parameter setup for integration task

In [None]:
# Default configuration for fine-tuning on the integration task. The dict is
# handed to wandb as the run config, so every entry here is a tunable setting.
hyperparameter_defaults = {
    # --- run / data -------------------------------------------------------
    "seed": 42,
    # Name of the dataset to fine-tune on.
    "dataset_name": "PBMC_10K",
    # Whether model parameters are updated during fine-tuning.
    "do_train": True,
    # Path to the pre-trained model.
    "load_model": "/scratch/ssd004/datasets/cellxgene/save/cellxgene_census_human-May23-08-36-2023",
    # --- objectives -------------------------------------------------------
    # Default mask ratio.
    "mask_ratio": 0.4,
    # Default number of fine-tuning epochs.
    "epochs": 30,
    # Default number of bins.
    "n_bins": 51,
    # Masked value prediction for cell embedding.
    "GEPC": True,
    # Elastic cell similarity objective threshold, in [0.0, 1.0];
    # 0.0 disables the objective.
    "ecs_thres": 0.8,
    # NOTE(review): presumably the weight of the DAB (batch-related)
    # objective - confirm against the training loop.
    "dab_weight": 1.0,
    # --- optimisation / architecture ---------------------------------------
    "lr": 1e-4,
    "batch_size": 64,
    "layer_size": 128,
    "nlayers": 4,
    "nhead": 4,
    # When a model is loaded, batch_size, layer_size, nlayers and nhead
    # above will be ignored.
    "dropout": 0.2,
    # Ratio of epochs for the learning rate schedule.
    "schedule_ratio": 0.9,
    # --- bookkeeping --------------------------------------------------------
    # NOTE(review): interval units (epochs vs. batches) are not visible
    # here - confirm where these are consumed.
    "save_eval_interval": 5,
    "log_interval": 100,
    "fast_transformer": True,
    "pre_norm": False,
    # Automatic Mixed Precision.
    "amp": True,
}

In [4]:
# Bare expression: echoes the configuration dict as the cell's output so it
# can be inspected. The defining cell must have been executed first.
hyperparameter_defaults

NameError: name 'hyperparameter_defaults' is not defined

In [None]:
"/scratch/ssd004/datasets/cellxgene/save/cellxgene_census_human-May23-08-36-2023"

In [None]:
# Start a wandb run under the "scGPT" project, using the defaults defined
# above as the run configuration. The returned handle is kept in `run`.
run = wandb.init(
    project="scGPT",
    config=hyperparameter_defaults,
    # Request the "fork" start method for wandb.
    settings=wandb.Settings(start_method="fork"),
    # NOTE(review): presumably lets init be called again in the same
    # process (e.g. when the cell is re-run) - confirm in the wandb docs.
    reinit=True,
)