In [None]:
!pip install transformers

In [None]:
import os
import shutil
import time

import math

import numpy as np

import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import tokenizers
from tokenizers import Tokenizer
from transformers import BertTokenizerFast, BertModel

import csv

In [None]:
#Need a special generator for random sampling:

class GenerateData():
  """Loads the train/val/test splits and yields tokenized (text, mol2vec) examples.

  Each split file is read as tab-separated text with three columns:
  ``cid``, ``mol2vec`` (whitespace-separated floats) and ``desc`` (the text
  description).  Descriptions are tokenized with the SciBERT tokenizer and
  padded/truncated to ``text_trunc_length`` tokens.

  Attributes set by the constructor:
    descriptions:      dict cid -> description string (all splits).
    mols:              dict cid -> raw mol2vec string (all splits).
    training_cids / validation_cids / test_cids: list of cids per split.
    molecule_sentences: dict cid -> substructure token string.
    max_mol_length:    largest number of substructure tokens for any cid.
    token_embs:        object stored in the ``path_token_embs`` .npy file.
    batch_size:        batch size used by the data loaders (32).
  """

  # Column layout shared by the three split files.
  _SPLIT_FIELDS = ['cid', 'mol2vec', 'desc']

  def __init__(self, path_train, path_val, path_test, path_molecules, path_token_embs):
    """Store the input paths, then load tokenizer, substructures and splits.

    Args:
      path_train, path_val, path_test: tab-separated split files.
      path_molecules: text file with one ``cid: token token ...`` line per molecule.
      path_token_embs: .npy file holding the pickled token-embedding object.
    """
    self.path_train = path_train
    self.path_val = path_val
    self.path_test = path_test
    self.path_molecules = path_molecules
    self.path_token_embs = path_token_embs

    # Maximum number of text tokens fed to the text encoder.
    self.text_trunc_length = 256

    self.prep_text_tokenizer()

    self.load_substructures()

    self.batch_size = 32

    self.store_descriptions()

  def load_substructures(self):
    """Read the substructure corpus and the token-embedding file.

    Populates ``molecule_sentences``, ``max_mol_length`` and ``token_embs``.
    Blank lines in the corpus are skipped instead of raising IndexError.
    """
    self.molecule_sentences = {}
    # Kept for backward compatibility; nothing in this class populates it.
    self.molecule_tokens = {}

    self.max_mol_length = 0
    with open(self.path_molecules) as f:
      for line in f:
        if not line.strip():
          continue  # tolerate blank / trailing empty lines
        spl = line.split(":")
        cid = spl[0]
        tokens = spl[1].strip()
        self.molecule_sentences[cid] = tokens
        self.max_mol_length = max(self.max_mol_length, len(tokens.split()))

    # SECURITY: allow_pickle=True executes pickle on load; only point this at
    # trusted files.  `[()]` unwraps the 0-d object array np.save produces.
    self.token_embs = np.load(self.path_token_embs, allow_pickle = True)[()]

  def prep_text_tokenizer(self):
    """Load the pretrained SciBERT (uncased) fast tokenizer."""
    self.text_tokenizer = BertTokenizerFast.from_pretrained("allenai/scibert_scivocab_uncased")

  def _read_split(self, path):
    """Read one split file into ``descriptions``/``mols``; return its cids in file order."""
    cids = []
    with open(path) as f:
      reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE, fieldnames = self._SPLIT_FIELDS)
      for row in reader:
        cid = row['cid']
        self.descriptions[cid] = row['desc']
        self.mols[cid] = row['mol2vec']
        cids.append(cid)
    return cids

  def store_descriptions(self):
    """Load all three splits, filling ``descriptions``, ``mols`` and the cid lists."""
    self.descriptions = {}

    self.mols = {}

    self.training_cids = self._read_split(self.path_train)
    self.validation_cids = self._read_split(self.path_val)
    self.test_cids = self._read_split(self.path_test)

  def _generate_examples(self, cids):
    """Shuffle ``cids`` in place, then yield one example dict per cid.

    Each example has the form
    ``{'cid', 'input': {'text': {'input_ids', 'attention_mask'},
    'molecule': {'mol2vec', 'cid'}}}``.
    """
    np.random.shuffle(cids)

    for cid in cids:
      text_input = self.text_tokenizer(self.descriptions[cid], truncation=True, max_length=self.text_trunc_length,
                                        padding='max_length', return_tensors = 'np')

      yield {
          'cid': cid,
          'input': {
              'text': {
                # squeeze() drops the leading batch axis of the tokenizer output
                'input_ids': text_input['input_ids'].squeeze(),
                'attention_mask': text_input['attention_mask'].squeeze(),
              },
              'molecule' : {
                    'mol2vec' : np.fromstring(self.mols[cid], sep = " "),
                    'cid' : cid
              },
          },
      }

  def generate_examples_train(self):
    """Yields training examples in a freshly shuffled order."""
    yield from self._generate_examples(self.training_cids)

  def generate_examples_val(self):
    """Yields validation examples in a freshly shuffled order."""
    yield from self._generate_examples(self.validation_cids)

  def generate_examples_test(self):
    """Yields test examples in a freshly shuffled order."""
    yield from self._generate_examples(self.test_cids)



# Input files; the paths are relative to the notebook's working directory.
mounted_path_train = "input/mol2vec_ChEBI_20_training.txt"
mounted_path_val = "input/mol2vec_ChEBI_20_val.txt"
mounted_path_test = "input/mol2vec_ChEBI_20_test.txt"
mounted_path_molecules = "input/ChEBI_defintions_substructure_corpus.cp"
mounted_path_token_embs = "input/token_embedding_dict.npy"

# Build the shared data generator used by every later cell.
gt = GenerateData(
    path_train=mounted_path_train,
    path_val=mounted_path_val,
    path_test=mounted_path_test,
    path_molecules=mounted_path_molecules,
    path_token_embs=mounted_path_token_embs,
)


In [None]:


class Dataset(torch.utils.data.Dataset):
  """PyTorch dataset that serves samples from a restartable generator.

  The base class is spelled out as ``torch.utils.data.Dataset`` because this
  class reuses the name ``Dataset``: inheriting from the bare name would make
  the class subclass its own previous definition when the cell is re-run.

  Args:
    gen: zero-argument callable returning an iterator of example dicts that
      each contain an ``'input'`` entry.
    length: value reported by ``len()``.
  """

  def __init__(self, gen, length):
    """Remember the generator factory and open a first iterator over it."""
    self.gen = gen
    self.it = iter(self.gen())

    self.length = length

  def __len__(self):
    """Denotes the total number of samples."""
    return self.length

  def __getitem__(self, index):
    """Return the next ``(X, y)`` sample from the underlying iterator.

    ``index`` is ignored: samples come back in whatever order the generator
    produces them, and the generator is restarted once it is exhausted.
    ``y`` is always the constant 1.
    """
    try:
      ex = next(self.it)
    except StopIteration:
      # Generator exhausted: start a new pass over the data.
      self.it = iter(self.gen())
      ex = next(self.it)

    X = ex['input']
    y = 1

    return X, y

# One generator-backed dataset per split; `length` is the number of cids in that split.
training_set = Dataset(gen=gt.generate_examples_train, length=len(gt.training_cids))
validation_set = Dataset(gen=gt.generate_examples_val, length=len(gt.validation_cids))
test_set = Dataset(gen=gt.generate_examples_test, length=len(gt.test_cids))


In [None]:

# DataLoader settings shared by all three splits.
params = dict(
    batch_size=gt.batch_size,
    shuffle=True,
    num_workers=1,
)

training_generator = DataLoader(dataset=training_set, **params)
validation_generator = DataLoader(dataset=validation_set, **params)
test_generator = DataLoader(dataset=test_set, **params)


In [None]:

class Model(nn.Module):
    """Two-tower model projecting text and mol2vec vectors into a shared space.

    The text tower is SciBERT's pooled output followed by a linear projection;
    the molecule tower is a three-layer MLP.  Both outputs are layer-normalised
    and multiplied by ``exp(temp)``, where ``temp`` is a learned scalar.

    Args:
        ntoken: accepted for interface compatibility; not used by the model.
        ninp: width of the text encoder's pooled output.
        nout: width of the shared embedding space (also the mol2vec input width).
        nhid: hidden width of the molecule MLP.
        dropout: probability for ``self.drop`` (the layer is created but not
            applied in ``forward``).
    """

    def __init__(self, ntoken, ninp, nout, nhid, dropout=0.5):
        super().__init__()

        # Text tower projection head.
        self.text_hidden1 = nn.Linear(ninp, nout)

        self.ninp, self.nhid, self.nout = ninp, nhid, nout

        self.drop = nn.Dropout(p=dropout)

        # Molecule tower: nout -> nhid -> nhid -> nout.
        self.mol_hidden1 = nn.Linear(nout, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nhid)
        self.mol_hidden3 = nn.Linear(nhid, nout)

        # Learned scalar; outputs are scaled by exp(temp).
        self.register_parameter('temp', nn.Parameter(torch.Tensor([0.07])))

        self.ln1 = nn.LayerNorm((nout))
        self.ln2 = nn.LayerNorm((nout))

        self.relu = nn.ReLU()
        self.selu = nn.SELU()

        # Snapshot taken BEFORE the text encoder is attached, so this list
        # holds every parameter except the BERT ones.
        self.other_params = [p for p in self.parameters()]

        self.text_transformer_model = BertModel.from_pretrained('allenai/scibert_scivocab_uncased')
        self.text_transformer_model.train()

    def forward(self, text, molecule, text_mask = None, molecule_mask = None):
        """Embed a batch of texts and molecules.

        Args:
            text: token ids passed to the text encoder.
            molecule: mol2vec vectors fed to the molecule MLP.
            text_mask: attention mask forwarded to the text encoder.
            molecule_mask: unused.

        Returns:
            ``(text_embedding, molecule_embedding)``.
        """
        encoded = self.text_transformer_model(text, attention_mask = text_mask)
        text_x = self.text_hidden1(encoded['pooler_output'])

        hidden = self.relu(self.mol_hidden1(molecule))
        hidden = self.relu(self.mol_hidden2(hidden))
        x = self.ln1(self.mol_hidden3(hidden))

        text_x = self.ln2(text_x)

        x = x * torch.exp(self.temp)
        text_x = text_x * torch.exp(self.temp)

        return text_x, x


In [None]:
# Instantiate the two-tower model; argument values are unchanged from the original call.
model = Model(
    ntoken=gt.text_tokenizer.vocab_size,
    ninp=768,
    nout=300,
    nhid=600,
)

In [None]:
import torch.optim as optim
from transformers.optimization import get_linear_schedule_with_warmup

epochs = 40

init_lr = 1e-4  # default learning rate (applies to the non-BERT parameter group)
bert_lr = 3e-5  # learning rate override for the text-encoder parameters
bert_params = list(model.text_transformer_model.parameters())

# Two parameter groups: everything except BERT uses `init_lr`, BERT uses `bert_lr`.
optimizer = optim.Adam([
                {'params': model.other_params},
                {'params': bert_params, 'lr': bert_lr}
            ], lr=init_lr)

num_warmup_steps = 1000
# `num_training_steps` is the TOTAL number of scheduler steps, warmup included.
# Subtracting the warmup here would make the linear decay reach 0 a full
# `num_warmup_steps` before training ends, so the final steps would run at lr 0.
num_training_steps = epochs * len(training_generator)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps = num_warmup_steps, num_training_steps = num_training_steps)

In [None]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

# Move the model to the chosen device; the result is bound to a name so the
# cell does not end on a bare expression.
tmp = model.to(device)


In [None]:
criterion = nn.CrossEntropyLoss()

def loss_func(v1, v2):
  """Symmetric cross-entropy loss over the pairwise similarity matrix.

  Row ``i`` of ``v1`` is treated as the match for row ``i`` of ``v2``: the
  logits are ``v1 @ v2.T`` and the target for row/column ``i`` is ``i``.

  Args:
    v1: 2-D tensor of embeddings, one row per sample.
    v2: 2-D tensor with the same number of rows and the same width as ``v1``.

  Returns:
    Scalar tensor: cross-entropy over rows plus cross-entropy over columns.
  """
  logits = torch.matmul(v1, torch.transpose(v2, 0, 1))
  # Create the targets directly on the logits' device rather than relying on
  # a module-level `device` variable and a separate host-to-device copy.
  labels = torch.arange(logits.shape[0], device=logits.device)
  return criterion(logits, labels) + criterion(torch.transpose(logits, 0, 1), labels)


In [None]:
train_losses = []
val_losses = []

# NOTE(review): `running_acc` is never incremented in the loops below, so
# these two histories only ever receive 0.0. Kept so later cells that
# reference them still work.
train_acc = []
val_acc = []

mounted_path = "MLP_outputs/"
# makedirs(exist_ok=True) replaces the exists()/mkdir() pair and also creates
# missing parent directories.
os.makedirs(mounted_path, exist_ok=True)


def _contrastive_batch_loss(batch):
    """Move one collated batch to `device`, run `model`, and return the loss tensor."""
    text = batch['text']['input_ids'].to(device)
    text_mask = batch['text']['attention_mask'].bool().to(device)
    molecule = batch['molecule']['mol2vec'].float().to(device)

    text_out, chem_out = model(text, molecule, text_mask)
    return loss_func(text_out, chem_out)


# Loop over epochs
for epoch in range(epochs):
    # Training
    start_time = time.time()
    running_loss = 0.0
    running_acc = 0.0
    model.train()
    for i, d in enumerate(training_generator):
        batch, labels = d  # `labels` is the dataset's constant placeholder; unused

        loss = _contrastive_batch_loss(batch)

        running_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The learning-rate schedule advances once per optimizer step.
        scheduler.step()

        if (i+1) % 100 == 0: print(i+1, "batches trained. Avg loss:\t", running_loss / (i+1), ". Avg ms/step =", 1000*(time.time()-start_time)/(i+1))
    train_losses.append(running_loss / (i+1))
    train_acc.append(running_acc / (i+1))

    print("Epoch", epoch, "training loss:\t\t", running_loss / (i+1), ". Time =", (time.time()-start_time), "seconds.")

    # Validation
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        running_acc = 0.0
        running_loss = 0.0
        for i, d in enumerate(validation_generator):
            batch, labels = d

            loss = _contrastive_batch_loss(batch)
            running_loss += loss.item()

            if (i+1) % 100 == 0: print(i+1, "batches eval. Avg loss:\t", running_loss / (i+1), ". Avg ms/step =", 1000*(time.time()-start_time)/(i+1))
        val_losses.append(running_loss / (i+1))
        val_acc.append(running_acc / (i+1))

        # Checkpoint whenever this epoch's validation loss is the best so far.
        min_loss = np.min(val_losses)
        if val_losses[-1] == min_loss:
            torch.save(model.state_dict(), mounted_path + 'weights_pretrained.{epoch:02d}-{min_loss:.2f}.pt'.format(epoch = epoch, min_loss = min_loss))

    print("Epoch", epoch, "validation loss:\t", running_loss / (i+1), ". Time =", (time.time()-start_time), "seconds.")


# Always keep the weights from the last epoch as well.
torch.save(model.state_dict(), mounted_path + "final_weights."+str(epochs)+".pt")


## Extract Embeddings



In [None]:

mounted_path = "MLP_outputs/"


def _embed_split(examples, log_every=None):
    """Run `model` on every example of one split, one sample at a time.

    Args:
        examples: iterable of example dicts as yielded by
            `gt.generate_examples_*` (keys 'cid' and 'input').
        log_every: if set, print a progress line every `log_every` samples.

    Returns:
        `(cids, chem_embeddings, text_embeddings)` as NumPy arrays whose rows
        are in the order the examples were yielded. All three are empty
        arrays when `examples` yields nothing.
    """
    cids, chem_rows, text_rows = [], [], []
    for i, d in enumerate(examples):
        inputs = d['input']
        # reshape(1, -1) adds the batch axis expected by the model.
        text_mask = torch.Tensor(inputs['text']['attention_mask']).bool().reshape(1,-1).to(device)
        text = torch.Tensor(inputs['text']['input_ids']).int().reshape(1,-1).to(device)
        molecule = torch.Tensor(inputs['molecule']['mol2vec']).float().reshape(1,-1).to(device)

        text_emb, chem_emb = model(text, molecule, text_mask)

        # Collect rows in lists and concatenate once at the end; concatenating
        # onto a growing array per sample is quadratic in the split size.
        cids.append(d['cid'])
        chem_rows.append(chem_emb.cpu().numpy())
        text_rows.append(text_emb.cpu().numpy())

        if log_every and (i+1) % log_every == 0: print(i+1, "samples eval.")

    if not cids:
        return np.array([]), np.array([]), np.array([])
    return np.array(cids), np.concatenate(chem_rows), np.concatenate(text_rows)


# Make inference mode explicit instead of relying on whatever mode the
# previous cell left the model in.
model.eval()
with torch.no_grad():
    cids_train, chem_embeddings_train, text_embeddings_train = _embed_split(gt.generate_examples_train(), log_every=100)
    print(cids_train.shape, chem_embeddings_train.shape)

    cids_val, chem_embeddings_val, text_embeddings_val = _embed_split(gt.generate_examples_val())
    print(cids_val.shape, chem_embeddings_val.shape)

    cids_test, chem_embeddings_test, text_embeddings_test = _embed_split(gt.generate_examples_test())

print(cids_test.shape, chem_embeddings_test.shape)

emb_path = mounted_path+"embeddings/"
# Creates missing parent directories too, so this cell does not depend on the
# training cell having created `mounted_path`.
os.makedirs(emb_path, exist_ok=True)
np.save(emb_path+"cids_train.npy", cids_train)
np.save(emb_path+"cids_val.npy", cids_val)
np.save(emb_path+"cids_test.npy", cids_test)
np.save(emb_path+"chem_embeddings_train.npy", chem_embeddings_train)
np.save(emb_path+"chem_embeddings_val.npy", chem_embeddings_val)
np.save(emb_path+"chem_embeddings_test.npy", chem_embeddings_test)
np.save(emb_path+"text_embeddings_train.npy", text_embeddings_train)
np.save(emb_path+"text_embeddings_val.npy", text_embeddings_val)
np.save(emb_path+"text_embeddings_test.npy", text_embeddings_test)
