# Fine-tuning BART for GC-MS using PPO

Fine-tuning a BART model with the TRLX library, which provides RLHF training pipelines (here: PPO).

In [2]:
# Show the working directory: the relative paths used below (../data, ../checkpoints, ../tokenizer) resolve against it.
!pwd

/auto/brno2/home/ahajek/Spektro/MassGenie/RLHF


In [3]:
# Automatically reload edited project modules (dataset, modeling_bart_spektro, ...) before each cell runs.
%load_ext autoreload
%autoreload 2

In [4]:
import sys

# Make the sibling project packages importable; the entries are relative to the
# current working directory and are appended in this order.
sys.path.extend([
    '../data',
    '../bart_spektro',
    '..',
    './trlx_GC_MS_BART/',
])

In [34]:
#### TODO: these imports need tidying up
import os,inspect
# Put the project root (parent of the working directory) first on sys.path.
# Under Jupyter, inspect.getfile(inspect.currentframe()) typically resolves to the
# kernel's temporary cell file (e.g. /tmp/ipykernel_<pid>/...), not to the notebook,
# so the location is derived from the current working directory instead.
currentdir = os.path.abspath(os.getcwd())
parentdir = os.path.dirname(currentdir)
sys.path.insert(0,parentdir)


import argparse
from datetime import datetime
import json
import os
import pickle
import random
import time
import wandb
from pathlib import Path
import gc
import glob
from tqdm import tqdm

import numpy as np
from transformers import TrainingArguments, Trainer, BartConfig, BartForConditionalGeneration, logging, AutoTokenizer
# from transformers.file_utils import logging
# from tensorboardX import SummaryWriter
import torch

# custom veci
from dataset import SpectroDataset, SpectroDataCollator
from modeling_bart_spektro import BartSpektoForConditionalGeneration
from configuration_bart_spektro import BartSpektroConfig
from data_preprocess1 import print_args
from bart_spektro_tokenizer import BartSpektroTokenizer
from tokenizers import Tokenizer

import trlx
from trlx.models.modeling_ppo import PPOConfig
from trlx.data.configs import (ModelConfig, 
                               TrainConfig,
                               SchedulerConfig,
                               TokenizerConfig,
                               OptimizerConfig,
                               TRLConfig)

## Main training parameters

In [6]:
# Epoch count and batch size passed to TrainConfig below.
n_epochs, bs = 10, 2

## Load the BART model, tokenizer, and data

In [7]:
# model
model_path = '../checkpoints/bart_2022-10-14-16_15_31_ft_12M_derivatized/checkpoint-58536/'
# Run name = the directory that contains the checkpoint folder. Path() ignores a
# trailing slash, so this works whether or not model_path ends with '/'.
model_name = Path(model_path).parent.name
model = BartSpektoForConditionalGeneration.from_pretrained(model_path)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device);  # trailing ';' suppresses the (very long) module repr in Jupyter

In [8]:
# tokenizer
from transformers import PreTrainedTokenizerFast

# Tokenizer directory in Hugging Face `transformers` format (loaded with
# AutoTokenizer.from_pretrained in the playground below).
tokenizer_path = (
    "../tokenizer/bbpe_tokenizer/bart_bbpe_1M_transformers_format"
)
# Unused alternative: tokenizers.Tokenizer.from_file(tokenizer_path) expects a single
# tokenizer JSON file, presumably not this directory — confirm before re-enabling.

In [9]:
# data: pickled, pre-processed 1K trial splits (train / valid)
train_data_path = "../data/trial_set/1K_bbpe_1M_bart_prepared_data_train.pkl"
valid_data_path = "../data/trial_set/1K_bbpe_1M_bart_prepared_data_valid.pkl"

# Both splits are loaded the same way, train first, then valid.
train_data, valid_data = (
    SpectroDataset(split_path, original=False)
    for split_path in (train_data_path, valid_data_path)
)

In [10]:
train_data.data.iloc[:1]  # first prepared training row (same as .head(1))

Unnamed: 0,destereo_smiles,input_ids,decoder_input_ids,encoder_attention_mask,decoder_attention_mask,labels,position_ids
629,COCCN1C(=O)C(=O)N(C1=O)CC(=O)c1c(N)n(C)c(=O)n(...,"[15, 28, 29, 30, 31, 32, 33, 39, 40, 41, 42, 4...","[3, 224, 325, 20, 38, 260, 50, 12, 38, 260, 50...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[3, 224, 325, 20, 38, 260, 50, 12, 38, 260, 50...","[3, 1, 2, 5, 5, 4, 4, 5, 6, 7, 8, 6, 7, 9, 6, ..."


## Convert the data to a TRLX-friendly format

In [11]:
# prompts, ratings = ... something like this?? probably not - rather a dict + a reward function

In [12]:
# Scratch check: prepend a (4, 1) column of 5s to a (4, 4) block of ones along dim 1 -> (4, 5).
a = torch.full((4, 1), 5.0)
b = torch.ones(4, 4)
torch.cat([a, b], dim=1)

tensor([[5., 1., 1., 1., 1.],
        [5., 1., 1., 1., 1.],
        [5., 1., 1., 1., 1.],
        [5., 1., 1., 1., 1.]])

In [13]:
torch.ones(4, 1, dtype=torch.long)  # scratch: (4, 1) column of integer ones

tensor([[1],
        [1],
        [1],
        [1]])

## Set up the configuration

In [14]:
# Length of the run in optimizer steps. The cosine scheduler's T_max is tied to it
# so the two cannot drift apart when the run length is changed.
total_steps = 10000

config = TRLConfig(
    train=TrainConfig(
        seq_length=200,
        epochs=n_epochs,
        total_steps=total_steps,
        batch_size=bs,
        checkpoint_interval=10000,
        eval_interval=100,
        pipeline="PromptPipeline",  # TODO(review): confirm this pipeline fits spectrum prompts
        trainer="AcceleratePPOSpektroTrainer",
        # NOTE(review): the trlx.train run below failed while initialising the wandb tracker.
        # If the machine cannot reach wandb, consider disabling tracking (e.g. tracker=None,
        # if this trlx version's TrainConfig supports it) — TODO confirm.
    ),
    # NOTE(review): passes the already-loaded model object rather than a path; presumably
    # supported by the trlx fork in ./trlx_GC_MS_BART — confirm.
    model=ModelConfig(model_path=model, num_layers_unfrozen=-1, model_arch_type="seq2seq"),
    tokenizer=TokenizerConfig(tokenizer_path=tokenizer_path),  # optionally: truncation_side="right"
    optimizer=OptimizerConfig(
        name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)
    ),
    # NOTE(review): eta_min equals the optimizer lr (1.0e-4). Assuming "cosine_annealing" maps to
    # torch's CosineAnnealingLR, the learning rate therefore stays constant; lower eta_min
    # for an actual decay.
    scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=total_steps, eta_min=1.0e-4)),
    # TODO(review): PPO hyper-parameters have not been tuned for this task yet.
    method=PPOConfig(
        name="PPOConfig",
        num_rollouts=32,
        chunk_size=32,
        ppo_epochs=2,
        init_kl_coef=0.05,
        target=6,
        horizon=10000,
        gamma=1,
        lam=0.95,
        cliprange=0.2,
        cliprange_value=0.2,
        vf_coef=1,
        scale_reward="ignored",
        ref_mean=None,
        ref_std=None,
        cliprange_reward=10,
        gen_kwargs=dict(
            max_new_tokens=200,
            top_p=0.8,
            do_sample=True,
            num_beams=1
        ),
    ),
)

## Define the reward function

In [15]:
from rdkit import Chem, DataStructs

def reward_fn(samples, prompts, outputs, invalid_smiles_reward=-0.1):
    """Score each generated SMILES against its reference SMILES.

    Variant with one sample per prompt. The reward for a pair is the RDKit
    fingerprint similarity (``DataStructs.FingerprintSimilarity``, Tanimoto by
    default) between the sample and the corresponding entry of ``outputs``
    (the label SMILES in this notebook). If either SMILES cannot be turned into
    a fingerprint, the pair receives ``invalid_smiles_reward`` instead.

    Args:
        samples: generated SMILES strings.
        prompts: the prompts the samples were generated from (not used here).
        outputs: reference SMILES, aligned element-wise with ``samples``.
        invalid_smiles_reward: reward for pairs containing an unusable SMILES.

    Returns:
        A list of float rewards. Pairs are formed with ``zip``, so the length is
        ``min(len(samples), len(outputs))``.
    """
    def smiles_to_fp(smiles):
        # MolFromSmiles returns None (rather than raising) for unparsable SMILES, so that
        # case is checked explicitly. Other failures (e.g. non-string input) raise and are
        # caught below. `except Exception` rather than a bare `except:` lets
        # KeyboardInterrupt/SystemExit stop a training run.
        try:
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                return Chem.RDKFingerprint(mol)
        except Exception:
            pass
        print("came accross: ", smiles, "which can't be translated to FingerPrint")
        return None

    samples_fp = [smiles_to_fp(s) for s in samples]
    outputs_fp = [smiles_to_fp(s) for s in outputs]
    return [
        DataStructs.FingerprintSimilarity(s_fp, o_fp)
        if s_fp is not None and o_fp is not None
        else invalid_smiles_reward
        for s_fp, o_fp in zip(samples_fp, outputs_fp)
    ]

## Train the model

In [16]:
train_data.data['position_ids'].head(2)  # position ids of the first two training rows

629    [3, 1, 2, 5, 5, 4, 4, 5, 6, 7, 8, 6, 7, 9, 6, ...
338    [1, 1, 3, 6, 6, 7, 4, 5, 2, 1, 1, 2, 4, 7, 6, ...
Name: position_ids, dtype: object

In [36]:
# Debug run on a small subset first; set n_debug_samples = None to train on every row.
n_debug_samples = 2
prompt_columns = ['input_ids', 'position_ids', 'encoder_attention_mask']
rows = slice(0, n_debug_samples)

# NOTE(review): prompts are DataFrames of token-id columns rather than strings; presumably
# AcceleratePPOSpektroTrainer (trlx_GC_MS_BART fork) handles that — confirm.
trlx.train(
    config=config,
    prompts=train_data.data[prompt_columns].iloc[rows],
    labels=train_data.data['labels'].iloc[rows].to_list(),
    reward_fn=reward_fn,
    eval_prompts=valid_data.data[prompt_columns].iloc[rows],
    eval_labels=valid_data.data['labels'].iloc[rows].to_list()
)

[RANK 0] Initializing model: BartSpektoForConditionalGeneration(
  (model): BartSpektroModel(
    (shared): Embedding(1233, 1024, padding_idx=2)
    (encoder): BartSpektroEncoder(
      (embed_tokens): Embedding(1233, 1024, padding_idx=2)
      (embed_positions): BartSpektroLearnedPositionalEmbedding(12, 1024)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bia

Problem at: /storage/brno2/home/ahajek/miniconda3/envs/trlx3.10/lib/python3.10/site-packages/accelerate/tracking.py 237 __init__


Traceback (most recent call last):
  File "/storage/brno2/home/ahajek/miniconda3/envs/trlx3.10/lib/python3.10/site-packages/wandb/sdk/wandb_init.py", line 1144, in init
    run = wi.init()
  File "/storage/brno2/home/ahajek/miniconda3/envs/trlx3.10/lib/python3.10/site-packages/wandb/sdk/wandb_init.py", line 773, in init
    raise error
wandb.errors.CommError: Error communicating with wandb process, exiting...
For more info see: https://docs.wandb.ai/guides/track/tracking-faq#initstarterror-error-communicating-with-wandb-process-
[34m[1mwandb[0m: [32m[41mERROR[0m Abnormal program exit


## Playground

In [29]:
# Playground: load the tokenizer from the directory configured earlier.
t = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=tokenizer_path)

In [40]:
# Decode a hand-copied id sequence; ids 3 / 0 / 2 render as <bos> / <eos> / <pad> in the output.
t.decode(token_ids=[
    3, 3, 224, 278, 20, 266, 11, 271, 20, 38, 12, 54, 260, 50,
    273, 50, 12, 49, 20, 269, 11, 261, 20, 12, 54, 260, 50, 273,
    50, 12, 70, 20, 284, 11, 70, 20, 12, 38, 0, 2, 2, 2,
    2, 2, 2, 2, 2,
])

'<bos><bos> Cn1cc(nc1C)S(=O)(=O)N1CCN(CC1)S(=O)(=O)c1cnn(c1)C<eos><pad><pad><pad><pad><pad><pad><pad><pad>'

In [72]:
# Scratch check: pop a key, then "move the batch to the device".
# A plain dict has no .to() (the original `d.to(device)` raised AttributeError),
# so tensor values are moved individually and other values are kept as-is.
d = {"a": 4, "b": 6}
d.pop("a")
d = {key: (value.to(device) if torch.is_tensor(value) else value) for key, value in d.items()}
d

{'b': 6}

In [15]:
# Scratch check: replace -100 entries with 2 (the tokenizer's pad id) using a vectorised select.
a = np.array([-100, 1, 2, 3, -100])
a = np.where(a == -100, 2, a)
a

array([2, 1, 2, 3, 2])

In [18]:
# Debug snapshot of a (2, 201) integer tensor: each row has four leading values followed by
# a consecutive run of integers (4..200 and 5..201).
# NOTE(review): the huge leading values look like uninitialised memory (e.g. from
# torch.empty) rather than real ids — TODO confirm where they come from.
a = torch.tensor([
    [94199458545776, 94199463207936, 0, 0, *range(4, 201)],
    [94199464076560, 94199463207936, 0, 0, *range(5, 202)],
])
a.size()

torch.Size([2, 201])

In [34]:
t(text=t.pad_token)  # encode the pad token string on its own to see its id

{'input_ids': [2], 'token_type_ids': [0], 'attention_mask': [1]}

In [35]:
t.convert_tokens_to_ids(t.pad_token)  # id of the pad token (what t.pad_token_id reports)

2

In [20]:
# Scratch check: same -100 -> 2 replacement on a torch label tensor.
label = torch.tensor([0, 0, 0, -100, 0])
label = label.masked_fill(label == -100, 2)
label

tensor([0, 0, 0, 2, 0])

In [22]:
import torch.nn.functional as F

In [25]:
# Scratch check: negative padding crops, so (0, -4, 0, 0) drops the last 4 columns -> (4, 1).
# A dedicated name is used so the tokenizer bound to `t` in the playground is not overwritten.
ones_4x5 = torch.ones((4,5))
print(ones_4x5)
F.pad(ones_4x5, (0,-4,0,0))

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])


tensor([[1.],
        [1.],
        [1.],
        [1.]])

In [None]:
# Inspect whatever object is currently bound to `t` (the name is reassigned earlier in this playground).
t