In [None]:
!pip install rdkit
!pip install torch
!pip install allennlp

In [None]:
import torch
import pickle as pi
import pandas as pd
import numpy as np
import rdkit
from rdkit import Chem
from rdkit.Chem import Descriptors, rdMolDescriptors
import tensorflow as tf

In [None]:
# Load the pickled classifier used to score drug/coformer descriptor pairs
# (see calculate_clf_error). `with` guarantees the file handle is closed.
# NOTE: pickle.load can execute arbitrary code — only load weights from a trusted source.
with open('weights/clf.pkl', 'rb') as clf_file:
  clf = pi.load(clf_file)

In [None]:
# Drug SMILES used as the default partner molecule in get_clf_input / calculate_clf_error
# (this SMILES encodes caffeine).
drug: str = 'CN1C=NC2=C1C(=O)N(C)C(=O)N2C'

In [None]:
# Sample coformer SMILES (this SMILES encodes piracetam); presumably kept for manual
# checks such as calculate_clf_error(generated_example) — not referenced in the cells shown.
generated_example: str = 'NC(=O)CN1CCCC1=O'

In [None]:
def generate_descriptors(smiles):
  """Compute RDKit molecular-property descriptors for a single molecule.

  Args:
    smiles: SMILES string of the molecule.

  Returns:
    np.ndarray of shape (1, n_properties): one row holding every property listed by
    rdMolDescriptors.Properties.GetAvailableProperties(), in that order.
    NOTE(review): the classifier was built around 43 properties per molecule; the count
    depends on the installed RDKit version — confirm it still matches.

  Raises:
    ValueError: if RDKit cannot parse `smiles` (Chem.MolFromSmiles returns None).
  """
  descriptor_names = list(rdMolDescriptors.Properties.GetAvailableProperties())
  get_descriptors = rdMolDescriptors.Properties(descriptor_names)

  molecule_object = Chem.MolFromSmiles(smiles)
  if molecule_object is None:
    # Generated SMILES are frequently invalid; fail with a readable message instead of
    # an opaque RDKit argument error from ComputeProperties(None).
    raise ValueError(f'Invalid SMILES: {smiles!r}')

  # Always exactly one row per molecule. A hard-coded width such as (-1, 43) would
  # silently split the vector into several rows if the property count changed to a
  # multiple of 43; with (1, -1) any mismatch surfaces at the classifier instead.
  final_descriptors = np.array(get_descriptors.ComputeProperties(molecule_object)).reshape((1, -1))

  return final_descriptors


def get_clf_input(coformer_smiles, drug_smiles=drug):
  """Build the classifier feature row for a (drug, coformer) pair.

  The drug's descriptors come first and the coformer's follow, joined column-wise;
  this column order must match the one the classifier was trained on.

  Args:
    coformer_smiles: SMILES of the candidate coformer.
    drug_smiles: SMILES of the drug; defaults to the module-level `drug`
      (the default is bound when this function is defined).

  Returns:
    np.ndarray: drug and coformer descriptor arrays concatenated along axis 1.
  """
  # Drug descriptors are computed before the coformer's, as in the column order.
  feature_blocks = [generate_descriptors(drug_smiles), generate_descriptors(coformer_smiles)]
  return np.concatenate(feature_blocks, axis=1)


def calculate_clf_error(coformer_smiles, desired_clf_output=1, drug_smiles=drug, classifier=clf):
  """Binary cross-entropy between a desired label and the classifier's prediction.

  Args:
    coformer_smiles: SMILES of the candidate coformer.
    desired_clf_output: target label, 0 or 1 (default 1).
    drug_smiles: SMILES of the drug; defaults to the module-level `drug`.
    classifier: fitted binary classifier exposing predict_proba; defaults to `clf`.

  Returns:
    float: BCE loss, low when the classifier assigns high probability to
    `desired_clf_output` for this (drug, coformer) pair.
  """
  clf_input = get_clf_input(coformer_smiles, drug_smiles)

  # Binary cross-entropy expects y_pred = P(label == 1) whatever the target is.
  # Selecting column `desired_clf_output` instead would, for a target of 0, compute
  # -log(1 - P(class 0)) = -log(P(class 1)) and reward the wrong class.
  # NOTE(review): assumes predict_proba column 1 is class 1, i.e. classes_ == [0, 1].
  positive_prob = classifier.predict_proba(clf_input)[:, 1]

  error = tf.keras.metrics.binary_crossentropy(desired_clf_output,
                                               positive_prob)

  return float(error)

### Объединенная модель

In [None]:
from model import MolGen

In [None]:
# Build the list of SMILES strings (one molecule per line) that is passed to MolGen
# and its dataloader below. Lines are stripped and blank lines skipped so no empty
# string ends up as a training sample.
# NOTE(review): absolute Colab path; if the CSV has a header row it is kept as an
# entry — confirm.
with open('/content/database_cof_100smb_kekule.csv', "r") as file:
  data = [smiles for smiles in (line.strip() for line in file) if smiles]

In [None]:
# Build the generator on the coformer SMILES (`data`) together with the classifier
# weights, then restore the generative model pre-trained on ChEMBL only; the cells
# below fine-tune it.
gan_mol = MolGen(clf_path='weights/clf.pkl', data=data)
gan_mol.load('weights/gan_mol7_50ksteps_100symb_kekuke.pkl')

# Dataloader consumed by train_n_steps: batches of 20, shuffled.
# NOTE(review): num_workers=10 may exceed the CPUs available (e.g. on Colab) — confirm.
loader = gan_mol.create_dataloader(data, shuffle=True, batch_size=20, num_workers=10)

In [None]:
# Prepare the model for fine-tuning; the actual training loop is the train_n_steps
# call in the next cell.
# NOTE(review): presumably MolGen.train() follows torch.nn.Module.train() semantics
# (switches to training mode, runs no optimisation steps) — confirm in model.py.
# The number of training steps is set by max_step in train_n_steps, not here.
gan_mol.train()

In [None]:
import warnings

# The same warning is emitted repeatedly during training. Suppress warnings only for
# the duration of this call: a global filterwarnings('ignore') would also hide every
# warning raised by later cells for the rest of the kernel session.
with warnings.catch_warnings():
  warnings.simplefilter('ignore')
  # 1000 steps, evaluating every 10 -> presumably 100 evaluation points (cf. x1/x2 below).
  gan_mol.train_n_steps(loader, max_step=1000, evaluate_every=10)

### График изменения clf loss в процессе обучения

In [None]:
import matplotlib.pyplot as plt

In [None]:
# One x position per logged clf-loss value. 100 presumably equals
# max_step / evaluate_every (1000 / 10) from the training cell — confirm it matches
# the lengths of y1 / y2.
x1, x2 = np.arange(100), np.arange(100)

In [None]:
# Plot the mean clf loss during training for the two starting models.
# NOTE(review): y1 / y2 are not defined anywhere in this notebook; they must be
# created (e.g. from the clf loss logged during train_n_steps) before this cell runs,
# otherwise Restart & Run All fails here with NameError. Each must have the same
# length as x1 / x2.
fig, ax = plt.subplots(figsize=(15, 10))
ax.plot(x1, y1, label="Предобученная на коформерах модель")
ax.plot(x2, y2, label="Предобученная только на ChEMBL модель")
ax.set(title='Изменение clf loss в процессе обучения',
       xlabel='Итерация обучения',
       ylabel='clf loss mean')
ax.legend()
plt.show()