# 1. Pretrained Backbone from Hugging Face

`ChemBERT`를 backbone으로 사용
https://huggingface.co/models?sort=downloads&search=chemBERT 

Molecular graph를 GNN에 입력해서 graph representation을 획득하고,
여기에 molecular descriptor로 얻은 추가적인 molecular property를 함께 활용하는 것이 기본적인 접근입니다.
이 모델은 GNN 대신 트랜스포머 계열 모델을 활용해서 graph representation을 획득하는 (매우 단순한) 모델입니다.

> 획득하고 싶은 임베딩 dimension을 `num_labels`에 전달

In [None]:
# molecule predictor 
import torch
import torch.nn as nn 
import torch.nn.functional as F

from transformers import AutoModelForSequenceClassification
from utils import chem



class ChemBERT(nn.Module):
    """Regression head on top of a pretrained ChemBERTa sequence classifier.

    The encoder's classification logits (width ``BERT_out_dim``) are used as a
    molecule embedding, concatenated with the hand-crafted molecular features
    carried by the batch (``batch.mol_f``) and passed through
    Linear -> LayerNorm -> GELU -> Dropout -> Linear.

    Args:
        BERT_out_dim: embedding width requested from the encoder; forwarded as
            ``num_labels`` to the pretrained classifier head.
        projection_dim: width of the concatenated ``[logits, mol_f]`` vector.
            Because the concatenation feeds ``self.projection`` directly, this
            must equal ``BERT_out_dim`` plus the width of ``batch.mol_f``.
        out_dim: number of values predicted per molecule.
    """

    def __init__(self, BERT_out_dim, projection_dim, out_dim) -> None:
        super().__init__()

        # The checkpoint name comes from ``utils.chem.chosen``. The head that
        # produces the logits was noted by the author as:
        #   (classifier): RobertaClassificationHead(
        #       (dense): Linear(in_features=384, out_features=384, bias=True)
        #       (dropout): Dropout(p=0.144, inplace=False)
        #       (out_proj): Linear(in_features=384, out_features=BERT_out_dim, bias=True)
        #   )
        self.ChemBERT_encoder = AutoModelForSequenceClassification.from_pretrained(
            chem.chosen,
            num_labels=BERT_out_dim,
            problem_type="multi_label_classification",
        )

        # Fusion head; creation order is kept so parameter initialisation
        # consumes the random stream in the same sequence.
        self.projection = nn.Linear(projection_dim, projection_dim)
        self.ln = nn.LayerNorm(projection_dim)
        self.out = nn.Linear(projection_dim, out_dim)

        self.act = nn.GELU()
        self.drop = nn.Dropout(0.144)

    def forward(self, batch):
        """Return predictions for a batch exposing ``input_ids`` and ``mol_f``."""
        # Only the token ids are given to the encoder (no attention mask).
        embedding = self.ChemBERT_encoder(batch.input_ids).logits

        # Fuse the learned embedding with the molecular descriptors column-wise.
        fused = torch.cat((embedding, batch.mol_f), dim=1)

        hidden = self.drop(self.act(self.ln(self.projection(fused))))
        return self.out(hidden)

# 2. Chemical data 획득

예전에 GNN을 가볍게 만져봤을 때 사용했던 형태에 얼기설기 붙인 형식이라서 코드가 매끄럽지는 않습니다.

### atomic features
- `get_atomic_features` 함수를 통해서 분자의 atomic feature를 획득할 수 있습니다. 사용하는 features는 `mendeleev` 패키지에서 획득을 했고, `Chemical_feature_generator.mendeleev_atomic_f`를 확인하면 어떠한 atomic feature를 선택했는지 확인하실 수 있습니다. 

- `generate_mol_atomic_features` 함수를 통해서 하나의 분자에 해당하는 atomic feature matrix를 획득할 수 있습니다. 

- `get_adj_matrix` 함수를 통해서 분자의 adjacency matrix를 획득하고, 이를 다시 sparse tensor의 형태로 변형합니다. atomic feature와 함께 사용하면 GCN의 forward 함수에 활용 할 수 있습니다. 
(*deepchem MPNN featurizer가 제공하는 것과 많은 부분이 유사합니다.)

- `get_atomic_feature_from_deepchem` 함수는 더 넓고 다양한 범위의 atomic feature를 획득하는 함수입니다. 대회에서 제공해준 모든 데이터에 대해 유효하게 동작하지는 않아서 사용되지 못했습니다.

- `encoder_smiles` 함수는 chemBERT의 입력을 위해 smiles tokenizing을 하는 함수입니다. 


### molecular features
- `get_molecular_features` 함수를 통해서 분자의 molecular properties를 획득할 수 있습니다. 최대한 intrinsic properties를 획득할 수 있도록 했습니다. 사용하게되는 라이브러리는 'rdkit'인데, 이곳에서 획득하는 feature는 대회에서 제공된 동일한 특성의 값과 차이가 있습니다. (e.g. LogP)

- `get_molecule_fingerprints` 함수를 통해서 분자의 fingerprints 들을 획득할 수 있습니다. 예전에 GCN을 활용했을 때는 FP가 3D information loss가 있기 때문에 단점이 존재한다고 배웠었는데, 제가 직접 사용했을 때는 성능에서 매우 큰 차이는 못 느꼈습니다. 

- `get_mol_feature_from_deepchem` 함수를 통해서 rdkit에서 제공하는 molecular descriptor의 대부분 feature를 획득할 수 있습니다. (200개 이상, deepChem API 사용).



In [None]:
from mendeleev.fetch import fetch_table
from rdkit import Chem
from rdkit.Chem import (Descriptors,
                        Lipinski,
                        Crippen,
                        rdMolDescriptors,
                        MolFromSmiles,
                        AllChem,
                        PandasTools)

from torch import Tensor
from torch_geometric.utils.sparse import dense_to_sparse

from transformers import AutoTokenizer, AutoModel

import numpy as np

from sklearn.preprocessing import StandardScaler
from rdkit import DataStructs
from deepchem import feat

# Candidate pretrained ChemBERTa-style checkpoints on the Hugging Face hub;
# ``chosen`` is the one used for both the tokenizer and the encoder.
archieve = [
    'seyonec/PubChem10M_SMILES_BPE_450k',
    "DeepChem/ChemBERTa-77M-MTR",
    'seyonec/ChemBERTa-zinc-base-v1',
    'seyonec/ChemBERTa_zinc250k_v2_40k',
]
chosen = archieve[2]

class Chemical_feature_generator():
    """Produce atom-level and molecule-level features for a molecule.

    Atom-level features combine element properties from the ``mendeleev``
    table, per-atom RDKit properties and DeepChem DMPNN node features.
    Molecule-level features are RDKit descriptors, fingerprints, Mol2Vec
    embeddings and token ids for the ChemBERTa encoder.
    """

    def __init__(self) -> None:
        """Load the element table, the DeepChem featurizers and the tokenizer."""
        mendeleev_atomic_f = ['atomic_radius', 'atomic_radius_rahm', 'atomic_volume', 'atomic_weight', 'c6', 'c6_gb',
                              'covalent_radius_cordero', 'covalent_radius_pyykko', 'covalent_radius_pyykko_double', 'covalent_radius_pyykko_triple',
                              'density', 'dipole_polarizability', 'dipole_polarizability_unc', 'electron_affinity', 'en_allen', 'en_ghosh', 'en_pauling',
                              'heat_of_formation', 'is_radioactive', 'molar_heat_capacity', 'specific_heat_capacity', 'vdw_radius']
        # Author's notes on the less obvious columns (the others are fine):
        # - Heat of Formation: reflects the energy associated with the formation
        #   of a molecule and might indirectly impact metabolic reactions.
        # - Is Radioactive: binary property, may not be directly relevant to
        #   metabolic stability.
        # - Molar / Specific Heat Capacity: relate to heat transfer but might not
        #   be directly tied to metabolic stability.
        self.mendeleev_atomic_f_table = fetch_table('elements')[mendeleev_atomic_f]

        self.DMPNNFeaturizer = feat.DMPNNFeaturizer()
        self.Mol2VecFingerprint = feat.Mol2VecFingerprint()
        # Kept as an attribute although no method in this class calls it: its
        # use in get_atomic_feature_from_deepchem was disabled by the author.
        self.BPSymmetryFunctionInput = feat.BPSymmetryFunctionInput(max_atoms=150)

        self.tokenizer = AutoTokenizer.from_pretrained(chosen)

        # NOTE: the element table is used unscaled.

    @staticmethod
    def _mol_from_smiles(smiles):
        """Parse ``smiles`` with RDKit, raising ``ValueError`` if it is invalid."""
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # RDKit signals a parse failure by returning None; fail here with
            # the offending string instead of deep inside a featurizer.
            raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")
        return mol

    def get_atomic_features(self, atom):
        """Return ``(mendeleev_features, rdkit_features)`` for one RDKit atom.

        ``mendeleev_features`` is a float32 array with the selected element
        properties; ``rdkit_features`` is a list of 12 per-atom values.
        """
        # The table row is looked up by label ``atomic number - 1``
        # (assumes the fetched table keeps a 0-based row index — TODO confirm).
        row_label = atom.GetAtomicNum() - 1
        mendel_atom_f = (
            self.mendeleev_atomic_f_table.loc[row_label].to_numpy().astype(np.float32)
        )

        rdkit_atom_f = [
            atom.GetDegree(),
            atom.GetTotalDegree(),
            atom.GetFormalCharge(),
            atom.GetIsAromatic() * 1.,   # bool -> float
            atom.GetNumImplicitHs(),
            atom.GetNumExplicitHs(),
            atom.GetTotalNumHs(),
            atom.GetNumRadicalElectrons(),
            atom.GetImplicitValence(),
            atom.GetExplicitValence(),
            atom.GetTotalValence(),
            atom.IsInRing() * 1.,        # bool -> float
        ]

        return mendel_atom_f, rdkit_atom_f

    def get_molecular_features(self, mol):
        """Return 33 RDKit-derived molecular descriptors for ``mol`` as a list.

        The order of the returned values is significant: it matches the column
        names used when these features are written to a DataFrame.
        """

        def count_matches(pattern_smiles):
            # Patterns are parsed as SMILES, not SMARTS; the author noted that
            # SMARTS-style patterns such as "[C;X3](=O)[OH1]" or "[*]-O-[*]"
            # did not work with this approach.
            return len(mol.GetSubstructMatches(MolFromSmiles(pattern_smiles)))

        return [
            Descriptors.MolWt(mol),
            Descriptors.HeavyAtomMolWt(mol),
            Descriptors.NumValenceElectrons(mol),
            Lipinski.FractionCSP3(mol),
            Lipinski.HeavyAtomCount(mol),
            Lipinski.NHOHCount(mol),
            Lipinski.NOCount(mol),
            Lipinski.NumAliphaticCarbocycles(mol),
            Lipinski.NumAliphaticHeterocycles(mol),
            Lipinski.NumAliphaticRings(mol),
            Lipinski.NumAromaticCarbocycles(mol),
            Lipinski.NumAromaticHeterocycles(mol),
            Lipinski.NumAromaticRings(mol),
            Lipinski.NumHAcceptors(mol),
            Lipinski.NumHDonors(mol),
            Lipinski.NumHeteroatoms(mol),
            Lipinski.NumRotatableBonds(mol),
            Lipinski.RingCount(mol),
            Crippen.MolMR(mol),
            rdMolDescriptors.CalcNumBridgeheadAtoms(mol),
            Descriptors.ExactMolWt(mol),
            Descriptors.NumRadicalElectrons(mol),
            # Partial-charge descriptors (Max/Min/MaxAbs/MinAbs) are
            # intentionally left out, as in the original feature set.
            Lipinski.NumSaturatedCarbocycles(mol),
            Lipinski.NumSaturatedHeterocycles(mol),
            Lipinski.NumSaturatedRings(mol),
            Crippen.MolLogP(mol),
            rdMolDescriptors.CalcNumAmideBonds(mol),
            rdMolDescriptors.CalcNumSpiroAtoms(mol),
            count_matches("[C](=O)[OH]"),        # carboxyl groups
            count_matches("[NH2]"),              # amino groups
            count_matches("[NH4+]"),             # ammonium groups
            count_matches("[S](=O)(=O)[O-]"),    # sulfonic acid groups
            count_matches('CO'),                 # alkoxy groups
        ]

    def get_molecule_fingerprints(self, mol):
        """Return six fingerprints of ``mol`` stacked into one 1-D int8 array.

        Order: hashed Morgan radius 6, hashed Morgan radius 3, MACCS keys,
        RDKit, layered, pattern.
        """
        fingerprints = [
            AllChem.GetHashedMorganFingerprint(mol, 6, nBits=2048),  # "ECFP12", 2048
            AllChem.GetHashedMorganFingerprint(mol, 3, nBits=2048),  # "ECFP6", 2048
            Chem.rdMolDescriptors.GetMACCSKeysFingerprint(mol),      # 167
            Chem.RDKFingerprint(mol),                                # 2048
            Chem.rdmolops.LayeredFingerprint(mol),                   # 2048
            Chem.rdmolops.PatternFingerprint(mol),                   # 2048
        ]

        arrays = []
        for fingerprint in fingerprints:
            # ConvertToNumpyArray fills (and resizes) the destination array.
            destination = np.zeros((1,), dtype=np.int8)
            DataStructs.ConvertToNumpyArray(fingerprint, destination)
            arrays.append(destination)
        return np.hstack(arrays)

    def get_mol_feature_from_deepchem(self, smiles):
        """Return the DeepChem Mol2Vec embedding of ``smiles`` (noted as (1, 300))."""
        return self.Mol2VecFingerprint(smiles)

    def encoder_smiles(self, smiles):
        """Tokenize ``smiles`` for the ChemBERTa encoder and return its ``input_ids``."""
        encoded = self.tokenizer.encode_plus(
            smiles, padding=True, return_tensors='pt', add_special_tokens=True
        )
        return encoded['input_ids']

    def get_atomic_feature_from_deepchem(self, smiles):
        """Return the DeepChem DMPNN node-feature matrix for ``smiles``.

        Combining these with BPSymmetryFunctionInput features was tried and
        disabled by the author because it crashed on some molecules, e.g.
        CCCC(=O)OC[C@@H](OC(=O)CCC)[C@@H](OC(=O)CCC)[C@H](Cn1c2nc(=O)[nH]c(=O)c-2nc2cc(C)c(C)cc21)OC(=O)CCC
        """
        return self.DMPNNFeaturizer(smiles)[0].node_features

    def generate_mol_atomic_features(self, smiles):
        """Return the float32 atom-feature matrix (one row per atom) for ``smiles``.

        Columns are mendeleev element properties, RDKit atom properties and
        DeepChem DMPNN node features, in that order.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``.
        """
        mol = self._mol_from_smiles(smiles)

        mendel_atom_features = []
        rdkit_atom_features = []
        for atom in mol.GetAtoms():
            mendel_atom_f, rdkit_atom_f = self.get_atomic_features(atom)
            mendel_atom_features.append(mendel_atom_f)
            rdkit_atom_features.append(rdkit_atom_f)

        deepchem_atom_features = self.get_atomic_feature_from_deepchem(smiles)
        return np.concatenate(
            [mendel_atom_features, rdkit_atom_features, deepchem_atom_features],
            axis=1,
            dtype=np.float32,
        )

    def get_adj_matrix(self, smiles):
        """Return ``(edge_index, edge_attr)`` for the bond graph of ``smiles``.

        The dense RDKit adjacency matrix is converted with
        ``torch_geometric``'s ``dense_to_sparse``.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``.
        """
        mol = self._mol_from_smiles(smiles)
        adjacency = Chem.rdmolops.GetAdjacencyMatrix(mol)

        edge_index, edge_attr = dense_to_sparse(Tensor(adjacency))
        return edge_index, edge_attr


if __name__ == '__main__':
    # One-off preprocessing script: computes DeepChem RDKit descriptors for the
    # train and test sets and writes them, prefixed with the SMILES column, to
    # rdkit_train.csv / rdkit_test.csv.
    import pandas as pd
    from tqdm import tqdm
    from deepchem.feat.molecule_featurizers import RDKitDescriptors

    DATA_DIR = '/root/Competitions/DACON/4. JUMP AI 2023/data'

    train = pd.read_csv(f'{DATA_DIR}/dup_removed/new_train.csv')
    test = pd.read_csv(f'{DATA_DIR}/test.csv')

    generator = Chemical_feature_generator()

    # Column names for get_molecular_features(); order must match its output.
    MOLECULAR_COLUMNS = ['MolWt','HeavyAtomMolWt','NumValenceElectrons','FractionCSP3','HeavyAtomCount','NHOHCount','NOCount','NumAliphaticCarbocycles','NumAliphaticHeterocycles','NumAliphaticRings','NumAromaticCarbocycles','NumAromaticHeterocycles','NumAromaticRings','NumHAcceptors','NumHDonors','NumHeteroatoms','NumRotatableBonds','RingCount','MolMR','CalcNumBridgeheadAtom','ExactMolWt','NumRadicalElectrons','NumSaturatedCarbocycles','NumSaturatedHeterocycles','NumSaturatedRings','MolLogP','CalcNumAmideBonds','CalcNumSpiroAtoms','num_carboxyl_groups','num_amion_groups','num_ammonium_groups','num_sulfonic_acid_groups','num_alkoxy_groups',]

    def process(df):
        """Return a DataFrame of hand-crafted molecular features for ``df.SMILES``.

        Not invoked by this script; kept so the merged new_train.csv /
        new_test.csv files can be regenerated with
        ``pd.concat([df, process(df)], axis=1)``.
        """
        rows = [
            generator.get_molecular_features(mol=Chem.MolFromSmiles(smiles))
            for smiles in tqdm(df.SMILES)
        ]
        return pd.DataFrame(data=np.asarray(rows), columns=MOLECULAR_COLUMNS)

    def deepchem_rdkit(df):
        """Return the stacked DeepChem RDKit descriptor matrix for ``df.SMILES``."""
        featurizer = RDKitDescriptors()
        return np.concatenate([featurizer(smiles) for smiles in tqdm(df.SMILES)])

    def with_smiles(df, matrix):
        """Prefix ``matrix`` with the SMILES column and wrap it in a DataFrame."""
        # Convert explicitly: calling np.reshape on a pandas Series depends on
        # how NumPy wraps the result back into a Series.
        smiles_column = df.SMILES.to_numpy().reshape(-1, 1)
        return pd.DataFrame(np.concatenate([smiles_column, matrix], axis=1))

    # --- train: choose the descriptor columns -------------------------------
    train_features = deepchem_rdkit(train)
    column_means = np.mean(train_features, axis=0)
    non_zero_mean_columns = np.where(column_means != 0)[0]
    train_frame = with_smiles(train, train_features[:, non_zero_mean_columns])
    train_frame = train_frame.dropna(axis=1)  # drop descriptors with any NaN
    train_frame.to_csv(f'{DATA_DIR}/dup_removed/rdkit_train.csv', index=False)
    train_col = train_frame.columns

    # --- test: reuse exactly the columns kept for train ---------------------
    # The test frame is aligned to the train columns directly. Dropping the
    # test set's own NaN columns first could remove a column that train kept,
    # which would make the selection below raise KeyError.
    test_features = deepchem_rdkit(test)
    test_frame = with_smiles(test, test_features[:, non_zero_mean_columns])[train_col]
    nan_columns = test_frame.columns[test_frame.isna().any()]
    if len(nan_columns):
        print('test columns containing NaN:', list(nan_columns))
    test_frame.to_csv(f'{DATA_DIR}/dup_removed/rdkit_test.csv', index=False)
    test_col = test_frame.columns

    print(list(train_col), len(train_col))
    print(list(test_col), len(test_col))
    


    
    

# 3. Custom Dataset(PyG)과 PyTorch Lightning을 활용한 DataLoader 

(2)에서 생성된 atomic & molecular feature를 dataset 형태로 만드는 코드입니다. 

### dataset in PyG
원래 코드가 GCN을 학습시키기 위한 코드였기 때문에 dataset은 torch_geometric을 사용했습니다.
> batch.x (atomic feature) batch.edge_index (molecular bonds)를 통해서 GCN forward에 제공하면 동작합니다. e.g. `h1 = self.conv1(g.x, g.edge_index)`


### pl.LightningDataModule 
DataLoader를 만드는 클래스입니다. KFold를 상정하고 코드를 만들었고, pytorch-lightning을 사용해서 더 간단하게 사용하실 수도 있습니다. 
학습데이터에 존재하는 LogP 결측치는 rdkit 패키지의 LogP의 값으로 대체하였습니다. 


In [None]:
from typing import Callable, Optional, Union

import numpy as np
import pandas as pd

from sklearn import preprocessing
from sklearn.model_selection import KFold
from sklearn.feature_selection import VarianceThreshold

import torch 
from torch.utils.data import Dataset

from torch_geometric.data import Data
from torch_geometric.data.dataset import IndexType
from torch_geometric.loader import DataLoader

import pytorch_lightning as pl

from utils.chem import Chemical_feature_generator
from rdkit.Chem import PandasTools

# ['AlogP', 'Molecular_Weight', 'Num_H_Acceptors', 'Num_H_Donors',
#                  'Num_RotatableBonds', 'LogD', 'Molecular_PolarSurfaceArea']
                 
# Hand-crafted RDKit descriptor columns (29 entries).
# NumRadicalElectrons, num_carboxyl_groups, num_amion_groups and
# num_sulfonic_acid_groups are left out: they are all zero in the train set.
feature_label = [
    'MolWt',
    'HeavyAtomMolWt',
    'NumValenceElectrons',
    'FractionCSP3',
    'HeavyAtomCount',
    'NHOHCount',
    'NOCount',
    'NumAliphaticCarbocycles',
    'NumAliphaticHeterocycles',
    'NumAliphaticRings',
    'NumAromaticCarbocycles',
    'NumAromaticHeterocycles',
    'NumAromaticRings',
    'NumHAcceptors',
    'NumHDonors',
    'NumHeteroatoms',
    'NumRotatableBonds',
    'RingCount',
    'MolMR',
    'CalcNumBridgeheadAtom',
    'ExactMolWt',
    'NumSaturatedCarbocycles',
    'NumSaturatedHeterocycles',
    'NumSaturatedRings',
    'MolLogP',
    'CalcNumAmideBonds',
    'CalcNumSpiroAtoms',
    'num_ammonium_groups',
    'num_alkoxy_groups',
]

# Descriptor columns supplied with the competition data (7 entries).
given_features = [
    'AlogP',
    'Molecular_Weight',
    'Num_H_Acceptors',
    'Num_H_Donors',
    'Num_RotatableBonds',
    'LogD',
    'Molecular_PolarSurfaceArea',
]

# Shared feature generator, constructed once when this module is imported and
# used by both the dataset and the data module defined below.
generator = Chemical_feature_generator()

class Chemcial_dataset(Dataset):
    """Map-style dataset that turns one DataFrame row into a PyG ``Data`` graph.

    Args:
        data_frame: rows with at least a ``SMILES`` column, plus ``MLM`` and
            ``HLM`` target columns when ``is_train`` is true.
        fps: per-row fingerprint/embedding vectors, indexed by row position.
        mol_f: per-row molecular feature vectors, indexed by row position.
        transform: stored on the instance; not applied anywhere in this class.
        is_train: when false, both targets are filled with the placeholder -99.
    """

    def __init__(self,
                 data_frame:pd.DataFrame,
                 fps,
                 mol_f,
                 transform=None,
                 is_train=True):
        super().__init__()
        self.df, self.fps, self.mol_f = data_frame, fps, mol_f
        self.transform = transform
        self.is_train = is_train

    def __getitem__(self, idx: IndexType | int ):
        return self.get_chem_prop(idx)

    def __len__(self) -> int:
        n_rows, *_ = self.df.shape
        return n_rows

    def get_chem_prop(self, idx):
        """Build the ``Data`` object for the row at position ``idx``."""

        def as_row(values):
            # Float tensor reshaped to a single row so PyG stacks it per graph.
            return torch.tensor(values, dtype=torch.float).view(1, -1)

        row = self.df.iloc[idx]
        fp_values = self.fps[idx]
        mol_values = self.mol_f[idx]
        smiles = row['SMILES']

        # Graph structure, per-atom features and token ids all come from the
        # module-level feature generator.
        edge_index, edge_attr = generator.get_adj_matrix(smiles=smiles)
        atom_values = generator.generate_mol_atomic_features(smiles=smiles)
        input_ids = generator.encoder_smiles(smiles)

        # Targets: real values for training rows, placeholders otherwise.
        if self.is_train:
            mlm_value, hlm_value = row['MLM'], row['HLM']
        else:
            mlm_value = hlm_value = -99.

        x = torch.tensor(atom_values, dtype=torch.float)
        mol_f = as_row(mol_values)
        fp = as_row(fp_values)
        MLM = as_row(mlm_value)
        HLM = as_row(hlm_value)
        y = torch.concat([MLM, HLM], dim=1)

        return Data(x=x, mol_f=mol_f, fp=fp,
                    edge_index=edge_index, edge_attr=edge_attr, input_ids=input_ids,
                    y=y, MLM=MLM, HLM=HLM)
            
        
        
        
        
class KFold_pl_DataModule(pl.LightningDataModule):
    """LightningDataModule serving one train/validation fold of a KFold split.

    The CSV at ``train_df`` is read once, missing ``AlogP`` values are filled
    from the ``MolLogP`` column, Mol2Vec embeddings are computed per SMILES
    and the competition-provided descriptors are standardised. Rows are then
    split with ``KFold(num_split, shuffle=True, random_state=split_seed)`` and
    fold ``k_idx`` is used for validation.
    """

    def __init__(self,
                 train_df: str = '/root/Competitions/DACON/4. JUMP AI 2023/data/new_train.csv',
                 k_idx: int =1, # zero-based index of the fold used for validation
                 num_split: int = 5, # number of KFold splits
                 split_seed: int = 41,
                 batch_size: int = 1, 
                 num_workers: int = 0,
                 pin_memory: bool = False,
                 persistent_workers: bool=True,
                 train_transform=None,
                 val_transform =None
                 ) -> None:
        super().__init__()
        # The passed value is overridden: persistent workers only make sense
        # when worker processes exist. The local is reassigned before
        # save_hyperparameters so that (presumably) the stored hparam reflects
        # it — TODO confirm against the Lightning version in use.
        persistent_workers = num_workers > 0
        self.save_hyperparameters(logger=False)

        self.train_data = None
        self.val_data = None
        self.num_cls = 0

        # Datasets are built eagerly so the loaders work without a Trainer.
        self.setup()

    def setup(self, stage=None) -> None:
        """Build the train/validation datasets once; later calls are no-ops."""
        # Compare against None rather than relying on truthiness: a dataset's
        # truth value is its length, so an empty fold would trigger a reload.
        if self.train_data is not None or self.val_data is not None:
            return

        df = pd.read_csv(self.hparams.train_df, index_col=0)

        # Fill missing AlogP with the RDKit-computed MolLogP.
        missing_alogp = df['AlogP'].isna()
        df.loc[missing_alogp, 'AlogP'] = df.loc[missing_alogp, 'MolLogP']

        # One Mol2Vec embedding per molecule, stacked row-wise.
        mol2vec = np.concatenate(
            [generator.get_mol_feature_from_deepchem(smiles=smiles) for smiles in df.SMILES],
            axis=0,
        )

        # NOTE(review): the scaler is fit on all rows before the fold split, so
        # validation rows contribute to the scaling statistics — confirm this
        # is intended.
        scaler = preprocessing.StandardScaler()
        craft_mol_f = scaler.fit_transform(df[given_features].to_numpy())

        kf = KFold(n_splits=self.hparams.num_split,
                   shuffle=True,
                   random_state=self.hparams.split_seed)
        train_idx, val_idx = list(kf.split(df))[self.hparams.k_idx]
        train_idx, val_idx = train_idx.tolist(), val_idx.tolist()

        self.train_data = Chemcial_dataset(data_frame=df.iloc[train_idx],
                                           fps=mol2vec[train_idx],
                                           mol_f=craft_mol_f[train_idx],
                                           transform=None,
                                           is_train=True)
        # Validation rows also carry targets, hence is_train=True.
        self.val_data = Chemcial_dataset(data_frame=df.iloc[val_idx],
                                         fps=mol2vec[val_idx],
                                         mol_f=craft_mol_f[val_idx],
                                         transform=None,
                                         is_train=True)

    def _loader_kwargs(self):
        """Return the DataLoader options shared by the train and val loaders."""
        return dict(batch_size=self.hparams.batch_size,
                    num_workers=self.hparams.num_workers,
                    persistent_workers=self.hparams.persistent_workers,
                    pin_memory=self.hparams.pin_memory)

    def train_dataloader(self):
        """Return the shuffled training loader; the last partial batch is dropped."""
        return DataLoader(self.train_data,
                          shuffle=True,
                          drop_last=True,
                          **self._loader_kwargs())

    def val_dataloader(self):
        """Return the validation loader (no shuffling, all batches kept)."""
        return DataLoader(self.val_data,
                          shuffle=False,
                          **self._loader_kwargs())
        

if __name__ == '__main__':
    data = KFold_pl_DataModule()
    
    train_lodaer = data.train_dataloader()
    
    for batch in train_lodaer:
        # DataBatch(x=[29, 34], edge_index=[2, 62], edge_attr=[62], mol_f=[1, 36], fp=[5235], MLM=[1], HLM=[1], batch=[29], ptr=[2])
        print(batch)
        
    