In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import KFold, train_test_split
from sklearn.preprocessing import StandardScaler
from transformers import AdamW, get_linear_schedule_with_warmup, RobertaModel, RobertaConfig, RobertaTokenizer, RobertaForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
from PolymerSmilesTokenization import PolymerSmilesTokenizer
from dataset import Downstream_Dataset, DataAugmentation, LoadPretrainData
import torch
import torch.nn as nn
from torchmetrics import R2Score
from torch.utils.tensorboard import SummaryWriter
from copy import deepcopy
from torch.utils.data import Dataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm
2023-06-16 17:21:49.292834: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-06-16 17:21:49.385838: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-06-16 17:21:49.403757: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-06-16 17:21:49.691374: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not l

In [2]:
# Module-level result buffers, one row per sample. They are filled in place:
# `DownstreamRegression.forward` writes row `step` of `fingerprint`, and
# `test` writes row `step` of `pred_output`.
#
# The buffers are NaN-filled rather than allocated with `torch.empty`:
# `torch.empty` returns uninitialised memory, so any row the inference loop
# never reaches would otherwise be exported as arbitrary garbage instead of
# being recognisable as missing.
NUM_SAMPLES = 3380  # number of rows reserved; must cover every sample fed to `test`
HIDDEN_SIZE = 768   # width of one fingerprint row; presumably the encoder hidden size -- TODO confirm
fingerprint = torch.full((NUM_SAMPLES, HIDDEN_SIZE), float("nan"))
pred_output = torch.full((NUM_SAMPLES, 1), float("nan"))

In [3]:
class DownstreamRegression(nn.Module):
    """Pretrained encoder followed by a small MLP head that outputs one scalar.

    The encoder is a deep copy of the module-level ``PretrainedModel`` whose
    token embeddings are resized to ``len(tokenizer)``; both names must exist
    at module level before this class is instantiated.

    Args:
        drop_rate: Dropout probability applied in front of the regression head.
    """

    def __init__(self, drop_rate=0.1):
        super().__init__()
        # Copy so that the module-level encoder is not modified by the resize.
        self.PretrainedModel = deepcopy(PretrainedModel)
        self.PretrainedModel.resize_token_embeddings(len(tokenizer))

        hidden_size = self.PretrainedModel.config.hidden_size
        self.Regressor = nn.Sequential(
            nn.Dropout(drop_rate),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self, input_ids, attention_mask, step=None):
        """Encode the inputs and regress one value per sequence.

        Args:
            input_ids: Token ids passed to the encoder as ``input_ids``.
            attention_mask: Mask passed to the encoder as ``attention_mask``.
            step: Optional row index. When given, the pooled embedding is also
                stored in row ``step`` of the module-level ``fingerprint``
                buffer; this only fits when the batch holds a single sequence.
                When omitted, no global state is touched.

        Returns:
            The regression head's output for the pooled embedding.
        """
        outputs = self.PretrainedModel(input_ids=input_ids, attention_mask=attention_mask)
        # Final hidden state at sequence position 0 is used as the fingerprint
        # (presumably the start-of-sequence token -- confirm against the tokenizer).
        logits = outputs.last_hidden_state[:, 0, :]
        if step is not None:
            # Detach before storing: otherwise, outside `torch.no_grad()`, the
            # global buffer would become part of the autograd graph and keep
            # every step's graph alive.
            fingerprint[step] = logits.detach()
        output = self.Regressor(logits)
        return output

def test(model, loss_fn, train_dataloader,device):

    model.eval()
    with torch.no_grad():
        for step, batch in enumerate(train_dataloader):
            print(f'Smiles: {step+1}')
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            prop = batch["prop"].to(device).float()
            outputs = model(input_ids, attention_mask,step).float()
            pred_output[step] = outputs

In [4]:
# Use the GPU when CUDA reports one as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Load the dataset; the path is relative to the notebook's working directory.
train_data = pd.read_csv(filepath_or_buffer='data/Egc.csv')


In [5]:
# --- Target scaling ---------------------------------------------------------
# Standardise the second column of `train_data` in place. `original_output`
# is taken *after* this step, so it holds the scaled values, not the raw ones.
scaler = StandardScaler()
train_data.iloc[:, 1] = scaler.fit_transform(
    train_data.iloc[:, 1].values.reshape(-1, 1)
)
original_output = train_data.iloc[:, 1]

# --- Encoder, tokenizer, dataset ---------------------------------------------
# `PretrainedModel` and `tokenizer` are read as module-level names by
# `DownstreamRegression.__init__`, so they must exist before the model is built.
PretrainedModel = RobertaModel.from_pretrained('ckpt/pretrain.pt')
tokenizer = PolymerSmilesTokenizer.from_pretrained("roberta-base", max_len=411)
train_dataset = Downstream_Dataset(train_data, tokenizer, 411)

# --- Model --------------------------------------------------------------------
# NOTE(review): no checkpoint for the regression head is loaded in this cell;
# as far as the visible code shows, `model.Regressor` keeps its freshly
# initialised weights -- confirm that this is intended for the predictions.
loss_fn = nn.MSELoss()
model = DownstreamRegression(drop_rate=0.1).to(device).double()

# One sample per batch, in dataset order: the per-step row indexing used
# during inference relies on both.
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, num_workers=8)

# --- Optimiser / schedule -------------------------------------------------------
# NOTE(review): the visible cells only run inference; the optimizer and the
# scheduler do not appear to be stepped anywhere -- verify before relying on them.
steps_per_epoch = len(train_data) // 1
training_steps = steps_per_epoch * 1
warmup_steps = int(training_steps * 0.05)

optimizer = AdamW(
    [
        dict(params=model.PretrainedModel.parameters(), lr=0.00005, weight_decay=0.0),
        dict(params=model.Regressor.parameters(), lr=0.0001, weight_decay=0.01),
    ],
    no_deprecation_warning=True,
)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=training_steps,
)
original_output

Some weights of the model checkpoint at ckpt/pretrain.pt were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaModel were not initialized from the model checkpoint at ckpt/pretrain.pt and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
The tokenizer class you 

0       1.517933
1       1.276332
2       1.274669
3       1.413256
4       1.416967
          ...   
3375   -0.741061
3376   -2.756854
3377   -2.829091
3378   -2.781807
3379   -1.482115
Name: value, Length: 3380, dtype: float64

In [6]:
# Single inference pass; `test` fills the module-level `fingerprint` and
# `pred_output` buffers as a side effect.
for epoch in range(1):
    test(model, loss_fn, train_dataloader, device)

# Turn the filled buffers and the target column into plain Python lists.
# NOTE: the tensor names are rebound to lists here, so this cell cannot be run
# a second time without first re-creating the buffers.
fingerprint, pred_output = (
    buffer.detach().cpu().numpy().tolist() for buffer in (fingerprint, pred_output)
)
original_output = original_output.to_list()
    

Smiles: 1
Smiles: 2
Smiles: 3
Smiles: 4
Smiles: 5
Smiles: 6
Smiles: 7
Smiles: 8
Smiles: 9
Smiles: 10
Smiles: 11
Smiles: 12
Smiles: 13
Smiles: 14
Smiles: 15
Smiles: 16
Smiles: 17
Smiles: 18
Smiles: 19
Smiles: 20
Smiles: 21
Smiles: 22
Smiles: 23
Smiles: 24
Smiles: 25
Smiles: 26
Smiles: 27
Smiles: 28
Smiles: 29
Smiles: 30
Smiles: 31
Smiles: 32
Smiles: 33
Smiles: 34
Smiles: 35
Smiles: 36
Smiles: 37
Smiles: 38
Smiles: 39
Smiles: 40
Smiles: 41
Smiles: 42
Smiles: 43
Smiles: 44
Smiles: 45
Smiles: 46
Smiles: 47
Smiles: 48
Smiles: 49
Smiles: 50
Smiles: 51
Smiles: 52
Smiles: 53
Smiles: 54
Smiles: 55
Smiles: 56
Smiles: 57
Smiles: 58
Smiles: 59
Smiles: 60
Smiles: 61
Smiles: 62
Smiles: 63
Smiles: 64
Smiles: 65
Smiles: 66
Smiles: 67
Smiles: 68
Smiles: 69
Smiles: 70
Smiles: 71
Smiles: 72
Smiles: 73
Smiles: 74
Smiles: 75
Smiles: 76
Smiles: 77
Smiles: 78
Smiles: 79
Smiles: 80
Smiles: 81
Smiles: 82
Smiles: 83
Smiles: 84
Smiles: 85
Smiles: 86
Smiles: 87
Smiles: 88
Smiles: 89
Smiles: 90
Smiles: 91
Smiles: 

Smiles: 694
Smiles: 695
Smiles: 696
Smiles: 697
Smiles: 698
Smiles: 699
Smiles: 700
Smiles: 701
Smiles: 702
Smiles: 703
Smiles: 704
Smiles: 705
Smiles: 706
Smiles: 707
Smiles: 708
Smiles: 709
Smiles: 710
Smiles: 711
Smiles: 712
Smiles: 713
Smiles: 714
Smiles: 715
Smiles: 716
Smiles: 717
Smiles: 718
Smiles: 719
Smiles: 720
Smiles: 721
Smiles: 722
Smiles: 723
Smiles: 724
Smiles: 725
Smiles: 726
Smiles: 727
Smiles: 728
Smiles: 729
Smiles: 730
Smiles: 731
Smiles: 732
Smiles: 733
Smiles: 734
Smiles: 735
Smiles: 736
Smiles: 737
Smiles: 738
Smiles: 739
Smiles: 740
Smiles: 741
Smiles: 742
Smiles: 743
Smiles: 744
Smiles: 745
Smiles: 746
Smiles: 747
Smiles: 748
Smiles: 749
Smiles: 750
Smiles: 751
Smiles: 752
Smiles: 753
Smiles: 754
Smiles: 755
Smiles: 756
Smiles: 757
Smiles: 758
Smiles: 759
Smiles: 760
Smiles: 761
Smiles: 762
Smiles: 763
Smiles: 764
Smiles: 765
Smiles: 766
Smiles: 767
Smiles: 768
Smiles: 769
Smiles: 770
Smiles: 771
Smiles: 772
Smiles: 773
Smiles: 774
Smiles: 775
Smiles: 776
Smil

Smiles: 1348
Smiles: 1349
Smiles: 1350
Smiles: 1351
Smiles: 1352
Smiles: 1353
Smiles: 1354
Smiles: 1355
Smiles: 1356
Smiles: 1357
Smiles: 1358
Smiles: 1359
Smiles: 1360
Smiles: 1361
Smiles: 1362
Smiles: 1363
Smiles: 1364
Smiles: 1365
Smiles: 1366
Smiles: 1367
Smiles: 1368
Smiles: 1369
Smiles: 1370
Smiles: 1371
Smiles: 1372
Smiles: 1373
Smiles: 1374
Smiles: 1375
Smiles: 1376
Smiles: 1377
Smiles: 1378
Smiles: 1379
Smiles: 1380
Smiles: 1381
Smiles: 1382
Smiles: 1383
Smiles: 1384
Smiles: 1385
Smiles: 1386
Smiles: 1387
Smiles: 1388
Smiles: 1389
Smiles: 1390
Smiles: 1391
Smiles: 1392
Smiles: 1393
Smiles: 1394
Smiles: 1395
Smiles: 1396
Smiles: 1397
Smiles: 1398
Smiles: 1399
Smiles: 1400
Smiles: 1401
Smiles: 1402
Smiles: 1403
Smiles: 1404
Smiles: 1405
Smiles: 1406
Smiles: 1407
Smiles: 1408
Smiles: 1409
Smiles: 1410
Smiles: 1411
Smiles: 1412
Smiles: 1413
Smiles: 1414
Smiles: 1415
Smiles: 1416
Smiles: 1417
Smiles: 1418
Smiles: 1419
Smiles: 1420
Smiles: 1421
Smiles: 1422
Smiles: 1423
Smiles: 1424

Smiles: 1980
Smiles: 1981
Smiles: 1982
Smiles: 1983
Smiles: 1984
Smiles: 1985
Smiles: 1986
Smiles: 1987
Smiles: 1988
Smiles: 1989
Smiles: 1990
Smiles: 1991
Smiles: 1992
Smiles: 1993
Smiles: 1994
Smiles: 1995
Smiles: 1996
Smiles: 1997
Smiles: 1998
Smiles: 1999
Smiles: 2000
Smiles: 2001
Smiles: 2002
Smiles: 2003
Smiles: 2004
Smiles: 2005
Smiles: 2006
Smiles: 2007
Smiles: 2008
Smiles: 2009
Smiles: 2010
Smiles: 2011
Smiles: 2012
Smiles: 2013
Smiles: 2014
Smiles: 2015
Smiles: 2016
Smiles: 2017
Smiles: 2018
Smiles: 2019
Smiles: 2020
Smiles: 2021
Smiles: 2022
Smiles: 2023
Smiles: 2024
Smiles: 2025
Smiles: 2026
Smiles: 2027
Smiles: 2028
Smiles: 2029
Smiles: 2030
Smiles: 2031
Smiles: 2032
Smiles: 2033
Smiles: 2034
Smiles: 2035
Smiles: 2036
Smiles: 2037
Smiles: 2038
Smiles: 2039
Smiles: 2040
Smiles: 2041
Smiles: 2042
Smiles: 2043
Smiles: 2044
Smiles: 2045
Smiles: 2046
Smiles: 2047
Smiles: 2048
Smiles: 2049
Smiles: 2050
Smiles: 2051
Smiles: 2052
Smiles: 2053
Smiles: 2054
Smiles: 2055
Smiles: 2056

Smiles: 2612
Smiles: 2613
Smiles: 2614
Smiles: 2615
Smiles: 2616
Smiles: 2617
Smiles: 2618
Smiles: 2619
Smiles: 2620
Smiles: 2621
Smiles: 2622
Smiles: 2623
Smiles: 2624
Smiles: 2625
Smiles: 2626
Smiles: 2627
Smiles: 2628
Smiles: 2629
Smiles: 2630
Smiles: 2631
Smiles: 2632
Smiles: 2633
Smiles: 2634
Smiles: 2635
Smiles: 2636
Smiles: 2637
Smiles: 2638
Smiles: 2639
Smiles: 2640
Smiles: 2641
Smiles: 2642
Smiles: 2643
Smiles: 2644
Smiles: 2645
Smiles: 2646
Smiles: 2647
Smiles: 2648
Smiles: 2649
Smiles: 2650
Smiles: 2651
Smiles: 2652
Smiles: 2653
Smiles: 2654
Smiles: 2655
Smiles: 2656
Smiles: 2657
Smiles: 2658
Smiles: 2659
Smiles: 2660
Smiles: 2661
Smiles: 2662
Smiles: 2663
Smiles: 2664
Smiles: 2665
Smiles: 2666
Smiles: 2667
Smiles: 2668
Smiles: 2669
Smiles: 2670
Smiles: 2671
Smiles: 2672
Smiles: 2673
Smiles: 2674
Smiles: 2675
Smiles: 2676
Smiles: 2677
Smiles: 2678
Smiles: 2679
Smiles: 2680
Smiles: 2681
Smiles: 2682
Smiles: 2683
Smiles: 2684
Smiles: 2685
Smiles: 2686
Smiles: 2687
Smiles: 2688

Smiles: 3244
Smiles: 3245
Smiles: 3246
Smiles: 3247
Smiles: 3248
Smiles: 3249
Smiles: 3250
Smiles: 3251
Smiles: 3252
Smiles: 3253
Smiles: 3254
Smiles: 3255
Smiles: 3256
Smiles: 3257
Smiles: 3258
Smiles: 3259
Smiles: 3260
Smiles: 3261
Smiles: 3262
Smiles: 3263
Smiles: 3264
Smiles: 3265
Smiles: 3266
Smiles: 3267
Smiles: 3268
Smiles: 3269
Smiles: 3270
Smiles: 3271
Smiles: 3272
Smiles: 3273
Smiles: 3274
Smiles: 3275
Smiles: 3276
Smiles: 3277
Smiles: 3278
Smiles: 3279
Smiles: 3280
Smiles: 3281
Smiles: 3282
Smiles: 3283
Smiles: 3284
Smiles: 3285
Smiles: 3286
Smiles: 3287
Smiles: 3288
Smiles: 3289
Smiles: 3290
Smiles: 3291
Smiles: 3292
Smiles: 3293
Smiles: 3294
Smiles: 3295
Smiles: 3296
Smiles: 3297
Smiles: 3298
Smiles: 3299
Smiles: 3300
Smiles: 3301
Smiles: 3302
Smiles: 3303
Smiles: 3304
Smiles: 3305
Smiles: 3306
Smiles: 3307
Smiles: 3308
Smiles: 3309
Smiles: 3310
Smiles: 3311
Smiles: 3312
Smiles: 3313
Smiles: 3314
Smiles: 3315
Smiles: 3316
Smiles: 3317
Smiles: 3318
Smiles: 3319
Smiles: 3320

In [7]:
# Gather the three per-sample result lists into one table and save it.
# Each `fingerprint` cell holds a whole list and each `pred_out` cell a
# one-element list, so the CSV stores them in their string form.
data = dict(fingerprint=fingerprint, pred_out=pred_output, orig_out=original_output)
df = pd.DataFrame(data=data)
df.to_csv(path_or_buf='data.csv', index=False)
df

Unnamed: 0,fingerprint,pred_out,orig_out
0,"[-0.8549143671989441, 0.07563474029302597, -1....",[-0.02010958455502987],1.517933
1,"[-0.4669734537601471, -0.8146975636482239, -0....",[-0.08788313716650009],1.276332
2,"[-0.46833327412605286, -0.6615654230117798, -0...",[-0.05377301201224327],1.274669
3,"[-0.35044270753860474, -0.8016369938850403, -0...",[-0.03918203338980675],1.413256
4,"[-0.899631679058075, -1.0210680961608887, -0.4...",[0.026506539434194565],1.416967
...,...,...,...
3375,"[-0.025111764669418335, -0.9626762866973877, -...",[0.05290037393569946],-0.741061
3376,"[0.5077277421951294, -2.0220606327056885, -0.1...",[-0.2541065216064453],-2.756854
3377,"[0.829918384552002, -2.096996545791626, -0.247...",[-0.30228862166404724],-2.829091
3378,"[1.0541179180145264, -2.050633192062378, -0.32...",[-0.33889251947402954],-2.781807
