In [1]:
import pandas as pd
import numpy as np
import sys
sys.path.insert(0, '/home/mm22d016/TransPolymer')
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from transformers import AdamW, get_linear_schedule_with_warmup, RobertaModel, RobertaConfig, RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from PolymerSmilesTokenization import PolymerSmilesTokenizer
from dataset import Downstream_Dataset, DataAugmentation, LoadPretrainData
import torch
import torch.nn as nn
from torchmetrics import R2Score
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy
from torch.utils.data import Dataset, DataLoader

In [2]:
value = []

In [3]:
test_data_METSA = pd.read_csv('../ckpt/METSA/test_for_LLM.csv')
test_data_METSA

Unnamed: 0,smiles,value
0,*CCN[Si](C)(C)N*,
1,*C=Cc1ccc(C=Cc2nc3cc4nc(*)[nH]c4cc3[nH]2)cc1,
2,*CC(*)C(=O)c1ccc(Br)cc1,
3,*C(S1)=CC=C1C(=S)OC(=S)*,
4,*c1cccc(-c2ccc(-c3ccc(*)s3)s2)n1,
...,...,...
840,*CCCCCCNC(=O)CCCCCCCCCCCCCCC(=O)N*,
841,*CC(*)OC(=O)c1ccc(OC)cc1,
842,*CCCCNC(=O)CCCCCCCC(=O)NCCCCNC(=O)C(=O)N*,
843,*CCCCCCCCCCOC(=O)c1ccc2cc(C(=O)O*)ccc2c1,


In [4]:
class DownstreamRegression(nn.Module):
    def __init__(self, drop_rate=0.1):
        super(DownstreamRegression, self).__init__()
        self.PretrainedModel = deepcopy(PretrainedModel)
        self.PretrainedModel.resize_token_embeddings(len(tokenizer))
        
        self.Regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(self.PretrainedModel.config.hidden_size, self.PretrainedModel.config.hidden_size),
            nn.SiLU(),
            nn.Linear(self.PretrainedModel.config.hidden_size, 1)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.PretrainedModel(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.last_hidden_state[:, 0, :] #fingerprint
        output = self.Regressor(logits)
        return output
    
def test(model, train_dataloader,device, scaler):
        model.eval()
        for step, batch in enumerate(train_dataloader):
            print(f'Smiles: {step+1}')
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            output = model(input_ids, attention_mask)
            output = scaler.inverse_transform(output.detach().cpu().numpy().reshape(-1, 1))
            value.append(output.item())

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
saved_state = torch.load('../ckpt/METSA/METSA_train.pt')
PretrainedModel = RobertaModel.from_pretrained('../ckpt/pretrain.pt')
tokenizer = PolymerSmilesTokenizer.from_pretrained("roberta-base", max_len=411)
model = DownstreamRegression(drop_rate=0.1).to(device)
model = model.double()
model.load_state_dict(saved_state['model'])

Some weights of the model checkpoint at ../ckpt/pretrain.pt were not used when initializing RobertaModel: ['lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ../ckpt/pretrain.pt and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The tokenizer clas

<All keys matched successfully>

In [7]:
data = pd.read_csv('../ckpt/METSA/train_for_LLM.csv')
scaler = StandardScaler()
splits = KFold(n_splits=5, shuffle=True, random_state=1)   
for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(data.shape[0]))):
    print('Fold {}'.format(fold + 1))
    if(fold==4):
        train_data = data.loc[train_idx, :].reset_index(drop=True)
        test_data = data.loc[val_idx, :].reset_index(drop=True)
        train_data.iloc[:, 1] = scaler.fit_transform(train_data.iloc[:, 1].values.reshape(-1, 1))
        test_data.iloc[:, 1] = scaler.transform(test_data.iloc[:, 1].values.reshape(-1, 1))

Fold 1
Fold 2
Fold 3
Fold 4
Fold 5


In [8]:
test_dataset = Downstream_Dataset(test_data_METSA, tokenizer, 411)
test_dataloader = DataLoader(test_dataset, 1, shuffle=False, num_workers=8)

In [9]:
test(model, test_dataloader, device, scaler)

Smiles: 1
Smiles: 2
Smiles: 3
Smiles: 4
Smiles: 5
Smiles: 6
Smiles: 7
Smiles: 8
Smiles: 9
Smiles: 10
Smiles: 11
Smiles: 12
Smiles: 13
Smiles: 14
Smiles: 15
Smiles: 16
Smiles: 17
Smiles: 18
Smiles: 19
Smiles: 20
Smiles: 21
Smiles: 22
Smiles: 23
Smiles: 24
Smiles: 25
Smiles: 26
Smiles: 27
Smiles: 28
Smiles: 29
Smiles: 30
Smiles: 31
Smiles: 32
Smiles: 33
Smiles: 34
Smiles: 35
Smiles: 36
Smiles: 37
Smiles: 38
Smiles: 39
Smiles: 40
Smiles: 41
Smiles: 42
Smiles: 43
Smiles: 44
Smiles: 45
Smiles: 46
Smiles: 47
Smiles: 48
Smiles: 49
Smiles: 50
Smiles: 51
Smiles: 52
Smiles: 53
Smiles: 54
Smiles: 55
Smiles: 56
Smiles: 57
Smiles: 58
Smiles: 59
Smiles: 60
Smiles: 61
Smiles: 62
Smiles: 63
Smiles: 64
Smiles: 65
Smiles: 66
Smiles: 67
Smiles: 68
Smiles: 69
Smiles: 70
Smiles: 71
Smiles: 72
Smiles: 73
Smiles: 74
Smiles: 75
Smiles: 76
Smiles: 77
Smiles: 78
Smiles: 79
Smiles: 80
Smiles: 81
Smiles: 82
Smiles: 83
Smiles: 84
Smiles: 85
Smiles: 86
Smiles: 87
Smiles: 88
Smiles: 89
Smiles: 90
Smiles: 91
Smiles: 

Smiles: 706
Smiles: 707
Smiles: 708
Smiles: 709
Smiles: 710
Smiles: 711
Smiles: 712
Smiles: 713
Smiles: 714
Smiles: 715
Smiles: 716
Smiles: 717
Smiles: 718
Smiles: 719
Smiles: 720
Smiles: 721
Smiles: 722
Smiles: 723
Smiles: 724
Smiles: 725
Smiles: 726
Smiles: 727
Smiles: 728
Smiles: 729
Smiles: 730
Smiles: 731
Smiles: 732
Smiles: 733
Smiles: 734
Smiles: 735
Smiles: 736
Smiles: 737
Smiles: 738
Smiles: 739
Smiles: 740
Smiles: 741
Smiles: 742
Smiles: 743
Smiles: 744
Smiles: 745
Smiles: 746
Smiles: 747
Smiles: 748
Smiles: 749
Smiles: 750
Smiles: 751
Smiles: 752
Smiles: 753
Smiles: 754
Smiles: 755
Smiles: 756
Smiles: 757
Smiles: 758
Smiles: 759
Smiles: 760
Smiles: 761
Smiles: 762
Smiles: 763
Smiles: 764
Smiles: 765
Smiles: 766
Smiles: 767
Smiles: 768
Smiles: 769
Smiles: 770
Smiles: 771
Smiles: 772
Smiles: 773
Smiles: 774
Smiles: 775
Smiles: 776
Smiles: 777
Smiles: 778
Smiles: 779
Smiles: 780
Smiles: 781
Smiles: 782
Smiles: 783
Smiles: 784
Smiles: 785
Smiles: 786
Smiles: 787
Smiles: 788
Smil

In [11]:
submit = pd.DataFrame(columns=[['id','band_gap']])
submit

Unnamed: 0,id,band_gap


In [13]:
original = pd.read_csv('../ckpt/METSA/test.csv')
original

Unnamed: 0,id,polymer,fp_mqns_1,fp_mqns_2,fp_mqns_3,fp_mqns_4,fp_mqns_5,fp_mqns_6,fp_mqns_7,fp_mqns_8,...,fp_o_desc_chi1,fp_o_desc_chi1n,fp_o_desc_chi1v,fp_o_desc_chi2n,fp_o_desc_chi2v,fp_o_desc_chi3n,fp_o_desc_chi3v,fp_o_desc_chi4n,fp_o_desc_chi4v,fp_o_desc_HallKierAlpha
0,402,[*]CCN[Si](C)(C)N[*],1,0.0,0.0,0.000000,0.0,0.000000,0.0,0.500000,...,0.797200,0.673409,1.425826,0.612559,1.596830,0.295143,0.830742,0.191847,0.575542,0.034221
1,321,[*]C=Cc1ccc(C=Cc2nc3cc4nc([*])[nH]c4cc3[nH]2)cc1,1,0.0,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.601635,0.398572,0.398572,0.301130,0.301130,0.216130,0.216130,0.152212,0.152212,0.174444
2,457,[*]CC([*])C(=O)c1ccc(Br)cc1,1,0.0,0.0,0.111111,0.0,0.000000,0.0,0.000000,...,0.584644,0.395481,0.483592,0.306438,0.408180,0.221428,0.280169,0.156136,0.185506,0.070000
3,879,[*]C(S1)=CC=C1C(=S)OC(=S)[*],1,0.0,0.0,0.000000,0.0,0.500000,0.0,0.000000,...,0.796636,0.436934,0.711731,0.299270,0.596946,0.192372,0.505789,0.118563,0.338875,0.053333
4,1536,[*]c1cccc(-c2ccc(-c3ccc([*])s3)s2)n1,1,0.0,0.0,0.000000,0.0,0.153846,0.0,0.000000,...,0.607938,0.389791,0.516013,0.283723,0.449150,0.204320,0.387516,0.144337,0.299251,0.111538
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
840,1999,[*]CCCCCCNC(=O)CCCCCCCCCCCCCCC(=O)N[*],1,0.0,0.0,0.000000,0.0,0.000000,0.0,0.090909,...,0.580555,0.514045,0.514045,0.355778,0.355778,0.239026,0.239026,0.160479,0.160479,0.048182
841,859,[*]CC([*])OC(=O)c1ccc(OC)cc1,1,0.0,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.628507,0.413238,0.413238,0.294333,0.294333,0.205662,0.205662,0.135275,0.135275,0.151000
842,2093,[*]CCCCNC(=O)CCCCCCCC(=O)NCCCCNC(=O)C(=O)N[*],1,0.0,0.0,0.000000,0.0,0.000000,0.0,0.210526,...,0.688250,0.535492,0.535492,0.361213,0.361213,0.231253,0.231253,0.144982,0.144982,0.111579
843,2578,[*]CCCCCCCCCCOC(=O)c1ccc2cc(C(=O)O[*])ccc2c1,1,0.0,0.0,0.000000,0.0,0.000000,0.0,0.000000,...,0.578260,0.436916,0.436916,0.311542,0.311542,0.217445,0.217445,0.146681,0.146681,0.107273


In [14]:
submit['id'] = original['id']

In [15]:
submit

Unnamed: 0,id,band_gap
0,402,
1,321,
2,457,
3,879,
4,1536,
...,...,...
840,1999,
841,859,
842,2093,
843,2578,


In [16]:
value

[5.354493618495919,
 2.710038972819945,
 2.3135321025741145,
 2.5864388716302935,
 2.540879481072279,
 6.892379398711077,
 1.6732362608835536,
 1.8690193159838762,
 6.472431987361124,
 3.684964391835725,
 3.4662957187409558,
 4.500926017583922,
 0.2855854353579925,
 3.856860002828565,
 2.970672253614627,
 4.35253394885131,
 3.143310145994129,
 6.610614092195594,
 5.67196045467207,
 6.007422947105294,
 5.083595172291362,
 2.7881165070063982,
 6.873921045625343,
 4.813089790534937,
 5.3373189398258205,
 6.061309184513309,
 3.0020822197822534,
 4.137946913795932,
 5.929725353980939,
 2.3636447104344156,
 2.834535166427839,
 2.2930620318269814,
 5.7522709815122095,
 4.675562690395461,
 2.6707167922191357,
 2.8527811342284313,
 5.8018608383899535,
 7.440743154757687,
 5.840707795086743,
 4.444940769653643,
 3.787768771031312,
 5.563848002976417,
 2.54545047238943,
 4.647325211304881,
 4.686366613664133,
 3.3075962863998534,
 5.938986701807822,
 4.067893747319403,
 0.9034442296937129,
 6.366

In [17]:
submit['band_gap'] = value

In [18]:
submit

Unnamed: 0,id,band_gap
0,402,5.354494
1,321,2.710039
2,457,2.313532
3,879,2.586439
4,1536,2.540879
...,...,...
840,1999,5.942865
841,859,4.056190
842,2093,5.285124
843,2578,3.880980


In [19]:
submit.to_csv('../ckpt/METSA/submit.csv',index=False)