<a href="https://colab.research.google.com/github/mgdixon/text2mol-team29/blob/main/MLP_Ablations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Instructions

This notebook runs the MLP flavor of the ablations. Some of the ablations are not relevant for the MLP model (for example, removing negative sampling from the loss). However, for consistency's sake, we keep all of the options. Using the global variables below, pick one of the ablations to run and give the model a name. The ablations are:

1. Use BERT instead of SciBERT to gauge the impact.
1. Remove the learned temperature parameter from the general loss function to gauge the impact.
1. Remove negative sampling from the loss function for the cross-modal attention model to gauge the impact.
1. Use one token for each atom (r = 0) instead of two to gauge the impact.
1. Remove layer normalization from the encoders to gauge the impact.

In [None]:
# Ablation switches. The output directory is named after the first flag that
# is True (checked in numeric order), so enable at most one per run.
# Ablation 1: use BERT instead of SciBERT.
ABLATION_1 = False
# Ablation 2: remove the learned temperature parameter.
ABLATION_2 = True
# Ablation 3: remove negative sampling (listed as not relevant for MLP above).
ABLATION_3 = False
# Ablation 4: one token per atom (r = 0) instead of two.
ABLATION_4 = False
# Ablation 5: remove layer normalization from the encoders.
ABLATION_5 = False
# Model name; used as the prefix of the output directory.
MODEL = "MLP"



In [None]:
# Authenticate.
# Uses the google.colab helper, so this cell is expected to run inside Colab.
from google.colab import auth
auth.authenticate_user()

# Install Cloud Storage FUSE.
# 1) Register the gcsfuse apt repository for this distro's codename
#    (the codename comes from `lsb_release -c -s`).
!echo "deb https://packages.cloud.google.com/apt gcsfuse-`lsb_release -c -s` main" | sudo tee /etc/apt/sources.list.d/gcsfuse.list
# 2) Download the Google Cloud apt signing key and add it to apt's keyring.
!curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -
# 3) Refresh the package index and install the gcsfuse package.
!apt -qq update && apt -qq install gcsfuse

In [None]:
# Mount a Cloud Storage bucket or location, without the gs:// prefix.
mount_path = "team29-text2mol"  # or a location like "my-bucket/path/to/mount"
# Local mount point; every later data/output path is built from local_path.
local_path = f"/mnt/gs/{mount_path}"

# Create the mount point (no error if it already exists) and mount the bucket
# there with gcsfuse.
!mkdir -p {local_path}
!gcsfuse --implicit-dirs {mount_path} {local_path}

In [None]:
# Then you can access it like a local path.
# Lists the outputs/ directory under the mount point.
!ls -lh {local_path}/outputs/

In [None]:
import os
import shutil
import time

import math

import numpy as np

import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import cosine_similarity

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import tokenizers
from tokenizers import Tokenizer
from transformers import BertTokenizerFast, BertModel

import csv

In [None]:
# Name the run directory after the first enabled ablation flag, checked in
# numeric order, or "_NoAblation" when every flag is off.
active_ablation = next(
    (number
     for number, enabled in enumerate(
         (ABLATION_1, ABLATION_2, ABLATION_3, ABLATION_4, ABLATION_5), start=1)
     if enabled),
    None,
)

if active_ablation is None:
  output_path = local_path + "/" + MODEL + "_NoAblation/"
else:
  output_path = local_path + "/" + MODEL + f"_Ablation_{active_ablation}/"

emb_path = output_path + "embeddings/"

# makedirs creates output_path and emb_path in one call; exist_ok avoids the
# check-then-create race of `os.path.exists` followed by `os.mkdir`.
os.makedirs(emb_path, exist_ok=True)


In [None]:
#Need a special generator for random sampling:

class GenerateData():
  """Loads the text/molecule corpus and yields tokenized examples per split.

  Construction reads all three split files, the substructure corpus and the
  token-embedding dictionary, and instantiates the text tokenizer.

  Attributes set during construction:
    text_trunc_length: tokenizer max_length / padding length (256).
    batch_size: batch size read by the DataLoader setup (32).
    descriptions, mols: dicts keyed by cid, filled from the split files.
    training_cids, validation_cids, test_cids: cids of each split, file order.
    molecule_sentences: cid -> substructure token string.
    max_mol_length: largest number of substructure tokens of any molecule.
    token_embs: object loaded from ``path_token_embs``.
  """

  def __init__(self, path_train, path_val, path_test, path_molecules,
               path_token_embs, ablation_1 = False):
    """Store the paths and eagerly load everything.

    Args:
      path_train, path_val, path_test: tab-separated split files with the
        columns cid, mol2vec, description (no header row).
      path_molecules: corpus file with one ``cid: tok tok ...`` line per
        molecule.
      path_token_embs: ``.npy`` file loaded with ``allow_pickle=True``.
      ablation_1: when True use the BERT tokenizer instead of SciBERT.
    """
    self.path_train = path_train
    self.path_val = path_val
    self.path_test = path_test
    self.path_molecules = path_molecules
    self.path_token_embs = path_token_embs
    self.ablation_1 = ablation_1

    self.text_trunc_length = 256

    self.prep_text_tokenizer()

    self.load_substructures()

    self.batch_size = 32

    self.store_descriptions()

  def load_substructures(self):
    """Read the substructure corpus and the token-embedding dictionary."""
    self.molecule_sentences = {}
    self.molecule_tokens = {}
    self.max_mol_length = 0

    with open(self.path_molecules) as f:
      for line in f:
        fields = line.split(":")
        if len(fields) < 2:
          # Blank line or line without a "cid:" prefix: nothing to index.
          continue
        cid = fields[0]
        sentence = fields[1].strip()
        self.molecule_sentences[cid] = sentence
        self.max_mol_length = max(self.max_mol_length, len(sentence.split()))

    # NOTE(review): allow_pickle=True unpickles the file's contents; only
    # load embedding files from a trusted source.
    # The file holds a pickled 0-d object array; [()] unwraps the stored object.
    self.token_embs = np.load(self.path_token_embs, allow_pickle = True)[()]

  def prep_text_tokenizer(self):
    """Instantiate the text tokenizer (BERT under ablation 1, else SciBERT)."""
    if self.ablation_1:
      self.text_tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
    else:
      self.text_tokenizer = BertTokenizerFast.from_pretrained("allenai/scibert_scivocab_uncased")

  def read_file(self, file_path):
    """Load one split file into ``descriptions``/``mols`` and return its cids.

    The file is tab-separated with no quoting and no header; the columns are
    cid, mol2vec vector (as text) and description.

    Returns:
      List of cids in file order.
    """
    cids = []
    # newline="" lets the csv module do its own newline handling, as the csv
    # documentation requires for file objects passed to a reader.
    with open(file_path, newline="") as f:
      reader = csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE,
                              fieldnames=['cid', 'mol2vec', 'desc'])
      for row in reader:
        cid = row['cid']
        self.descriptions[cid] = row['desc']
        self.mols[cid] = row['mol2vec']
        cids.append(cid)
    return cids

  def store_descriptions(self):
    """Read the train, validation and test split files."""
    self.descriptions = {}
    self.mols = {}
    self.training_cids = self.read_file(self.path_train)
    self.validation_cids = self.read_file(self.path_val)
    self.test_cids = self.read_file(self.path_test)

  def generate_examples(self, cids):
    """Yield one tokenized example per cid, in a freshly shuffled order.

    The shuffle is applied to a copy, so the caller's sequence (including
    ``self.training_cids`` etc.) keeps its order.

    Yields:
      dict with 'cid' and 'input'; 'input' holds 'text' (input_ids,
      attention_mask, original_text) and 'molecule' (mol2vec array, cid).
    """
    order = list(cids)
    np.random.shuffle(order)

    for cid in order:
      description = self.descriptions[cid]
      text_input = self.text_tokenizer(description, truncation=True,
                                       max_length=self.text_trunc_length,
                                       padding='max_length',
                                       return_tensors = 'np')

      yield {
          'cid': cid,
          'input': {
              'text': {
                # squeeze() drops the leading batch axis added by the tokenizer
                'input_ids': text_input['input_ids'].squeeze(),
                'attention_mask': text_input['attention_mask'].squeeze(),
                'original_text': description
              },
              'molecule' : {
                    # the vector is stored as space-separated text
                    'mol2vec' : np.fromstring(self.mols[cid], sep = " "),
                    'cid' : cid
              },
          },
      }

  def generate_examples_train(self):
    """Yield shuffled examples of the training split."""
    yield from self.generate_examples(self.training_cids)

  def generate_examples_val(self):
    """Yield shuffled examples of the validation split."""
    yield from self.generate_examples(self.validation_cids)

  def generate_examples_test(self):
    """Yield shuffled examples of the test split."""
    yield from self.generate_examples(self.test_cids)

  def generate_examples_custom(self, custom_cids):
    """Yield shuffled examples for an arbitrary collection of cids."""
    yield from self.generate_examples(custom_cids)

# Dataset files live under data/ inside the mounted bucket.
data_path = f"{local_path}/data/"
mounted_path_train = f"{data_path}training.txt"
mounted_path_val = f"{data_path}val.txt"
mounted_path_test = f"{data_path}test.txt"
mounted_path_molecules = f"{data_path}ChEBI_defintions_substructure_corpus.cp"
mounted_path_token_embs = f"{data_path}token_embedding_dict.npy"

# Shared data source for all splits; ablation 1 switches the tokenizer.
gt = GenerateData(
    path_train=mounted_path_train,
    path_val=mounted_path_val,
    path_test=mounted_path_test,
    path_molecules=mounted_path_molecules,
    path_token_embs=mounted_path_token_embs,
    ablation_1=ABLATION_1,
)


In [None]:


class Dataset(Dataset):
  """Serve examples from a generator factory through the Dataset interface."""

  def __init__(self, gen, length):
    """Keep the generator factory and open its first pass.

    gen: zero-argument callable returning an iterable of example dicts.
    length: value reported by ``len()``.
    """
    self.length = length
    self.gen = gen
    self.it = iter(gen())

  def __len__(self):
    """Return the length given at construction."""
    return self.length

  def __getitem__(self, index):
    """Return ``(example['input'], 1)`` for the next example of the stream.

    ``index`` is not used: items come out in the order the generator
    produces them, and a new pass is started once the current one ends.
    """
    exhausted = object()
    example = next(self.it, exhausted)
    if example is exhausted:
      # Current pass is finished; begin a new one and take its first item.
      self.it = iter(self.gen())
      example = next(self.it)

    return example['input'], 1

# One streaming dataset per split; the length is the number of cids in it.
training_set = Dataset(gen=gt.generate_examples_train,
                       length=len(gt.training_cids))
validation_set = Dataset(gen=gt.generate_examples_val,
                         length=len(gt.validation_cids))
test_set = Dataset(gen=gt.generate_examples_test,
                   length=len(gt.test_cids))


In [None]:

# DataLoader settings shared by the three splits.
params = dict(batch_size=gt.batch_size, shuffle=True, num_workers=1)

training_generator, validation_generator, test_generator = (
    DataLoader(split, **params)
    for split in (training_set, validation_set, test_set)
)


In [None]:

class MLPModel(nn.Module):
    """Two-branch encoder: pooled BERT text features and an MLP over the molecule vector.

    Both branches end in ``nout`` features. ``ntoken`` is accepted but not
    used; the dropout and SELU modules are created but not applied in
    ``forward``. ``ablation_3`` and ``ablation_4`` are stored only.
    """

    def __init__(self, ntoken, ninp, nout, nhid, dropout=0.5,
                 ablation_1 = False, ablation_2 = False, ablation_3 = False,
                 ablation_4 = False, ablation_5 = False):
        """Build both branches and load the pretrained text transformer.

        ninp: input width of the text projection layer.
        nout: output width of both branches (also the molecule input width).
        nhid: hidden width of the molecule MLP.
        dropout: probability for the (unused) dropout module.
        ablation_1..5: ablation switches, see the attribute comments.
        """
        super().__init__()

        # Plain (non-module) configuration.
        self.ninp, self.nhid, self.nout = ninp, nhid, nout
        self.ablation_1 = ablation_1  # BERT instead of SciBERT
        self.ablation_2 = ablation_2  # no learned temperature
        self.ablation_3 = ablation_3  # no negative sampling
        self.ablation_4 = ablation_4  # one token per atom
        self.ablation_5 = ablation_5  # no layer norm

        # Text branch head: transformer pooled output -> nout.
        self.text_hidden1 = nn.Linear(ninp, nout)

        self.drop = nn.Dropout(p=dropout)

        # Molecule branch: nout -> nhid -> nhid -> nout.
        self.mol_hidden1 = nn.Linear(nout, nhid)
        self.mol_hidden2 = nn.Linear(nhid, nhid)
        self.mol_hidden3 = nn.Linear(nhid, nout)

        if not ablation_2:
            # Assigning an nn.Parameter attribute registers it on the module.
            self.temp = nn.Parameter(torch.Tensor([0.07]))

        if not ablation_5:
            self.ln1 = nn.LayerNorm(nout)
            self.ln2 = nn.LayerNorm(nout)

        self.relu = nn.ReLU()
        self.selu = nn.SELU()

        # Snapshot taken before the transformer is attached, so it holds every
        # parameter except the transformer's.
        self.other_params = list(self.parameters())

        checkpoint = ("google-bert/bert-base-uncased" if ablation_1
                      else "allenai/scibert_scivocab_uncased")
        self.text_transformer_model = BertModel.from_pretrained(checkpoint)
        self.text_transformer_model.train()

    def forward(self, text, molecule, text_mask = None, molecule_mask = None):
        """Encode a batch and return ``(text_embedding, molecule_embedding)``.

        text: token ids passed to the text transformer.
        molecule: molecule vectors fed to the MLP.
        text_mask: attention mask for ``text``.
        molecule_mask: accepted but not used.
        """
        # Text branch.
        transformer_out = self.text_transformer_model(text, attention_mask = text_mask)
        text_x = self.text_hidden1(transformer_out['pooler_output'])

        # Molecule branch.
        hidden = self.relu(self.mol_hidden1(molecule))
        hidden = self.relu(self.mol_hidden2(hidden))
        x = self.mol_hidden3(hidden)

        if not self.ablation_5:
            x = self.ln1(x)
            text_x = self.ln2(text_x)

        if not self.ablation_2:
            # Both embeddings are scaled by exp(temp).
            x = x * torch.exp(self.temp)
            text_x = text_x * torch.exp(self.temp)

        return text_x, x


In [None]:
# Network sizes plus the ablation switches chosen at the top of the notebook.
model = MLPModel(
    ntoken=gt.text_tokenizer.vocab_size,
    ninp=768,
    nout=300,
    nhid=600,
    ablation_1=ABLATION_1,
    ablation_2=ABLATION_2,
    ablation_3=ABLATION_3,
    ablation_4=ABLATION_4,
    ablation_5=ABLATION_5,
)

In [None]:
import torch.optim as optim
from transformers.optimization import get_linear_schedule_with_warmup

epochs = 40

# The freshly initialised layers use init_lr; the pretrained text transformer
# gets its own, smaller learning rate.
init_lr = 1e-4
bert_lr = 3e-5
bert_params = list(model.text_transformer_model.parameters())

optimizer = optim.Adam([
                {'params': model.other_params},
                {'params': bert_params, 'lr': bert_lr}
            ], lr=init_lr)

num_warmup_steps = 1000
# The scheduler expects the TOTAL number of optimizer steps, warmup included
# (one scheduler step is taken per training batch). Subtracting the warmup
# here would bring the learning rate to zero num_warmup_steps updates before
# training ends.
num_training_steps = epochs * len(training_generator)
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps = num_warmup_steps,
                                            num_training_steps = num_training_steps)

In [None]:
# Use the first CUDA device when one is available, otherwise the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

print(device)

# Move the model to the device; the return value is kept in `tmp`
# (presumably so the notebook cell does not echo it).
tmp = model.to(device)


In [None]:
criterion = nn.CrossEntropyLoss()

def loss_func(v1, v2):
  """Symmetric in-batch cross-entropy between two sets of embeddings.

  Row i of ``v1`` is treated as the match of row i of ``v2``: the pairwise
  dot products form the logits and the diagonal holds the target class.
  The loss is the sum of the cross-entropy over rows (v1 -> v2) and over
  columns (v2 -> v1).

  Args:
    v1, v2: 2-D tensors with the same number of rows and the same width.

  Returns:
    Scalar tensor holding the summed loss.
  """
  logits = torch.matmul(v1, v2.t())
  # Create the targets on the logits' own device instead of relying on a
  # notebook-global `device`.
  labels = torch.arange(logits.shape[0], device=logits.device)
  return criterion(logits, labels) + criterion(logits.t(), labels)


In [None]:
# Per-epoch history. The accuracy lists are filled from running_acc, which is
# never updated in this loop, so they only ever receive 0.0.
train_losses = []
val_losses = []

train_acc = []
val_acc = []


def _batch_loss(batch):
    """Move one collated batch to `device`, run the model and return its loss."""
    text = batch['text']['input_ids'].to(device)
    text_mask = batch['text']['attention_mask'].bool().to(device)
    molecule = batch['molecule']['mol2vec'].float().to(device)

    text_out, chem_out = model(text, molecule, text_mask)
    return loss_func(text_out, chem_out)


# Loop over epochs
for epoch in range(epochs):
    # Training

    start_time = time.time()
    running_loss = 0.0
    running_acc = 0.0
    # Keep the running sum on the device so .item() (a host sync) is only
    # needed when something is printed.
    running_loss_gpu = torch.tensor(0.0).to(device)

    model.train()
    for i, (batch, _labels) in enumerate(training_generator):
        loss = _batch_loss(batch)

        # detach(): the running sum is for reporting only and must not keep
        # every step's autograd history alive.
        running_loss_gpu += loss.detach()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The learning-rate schedule advances once per batch.
        scheduler.step()

        if (i+1) % 100 == 0:
            running_loss = running_loss_gpu.item()
            print(i+1, "batches trained. Avg loss:\t", running_loss / (i+1),
                  ". Avg ms/step =", 1000*(time.time()-start_time)/(i+1))

    running_loss = running_loss_gpu.item()
    train_losses.append(running_loss / (i+1))
    train_acc.append(running_acc / (i+1))

    print("Epoch", epoch, "training loss:\t\t", running_loss / (i+1),
          ". Time =", (time.time()-start_time), "seconds.")


    # Validation
    model.eval()
    with torch.no_grad():
        start_time = time.time()
        running_acc = 0.0
        running_loss = 0.0
        running_loss_gpu = torch.tensor(0.0).to(device)

        for i, (batch, _labels) in enumerate(validation_generator):
            running_loss_gpu += _batch_loss(batch).detach()

            if (i+1) % 100 == 0:
                running_loss = running_loss_gpu.item()
                print(i+1, "batches eval. Avg loss:\t", running_loss / (i+1),
                      ". Avg ms/step =", 1000*(time.time()-start_time)/(i+1))

        running_loss = running_loss_gpu.item()
        val_losses.append(running_loss / (i+1))
        val_acc.append(running_acc / (i+1))

        # Checkpoint whenever this epoch has the lowest validation loss so far.
        min_loss = np.min(val_losses)
        if val_losses[-1] == min_loss:
            torch.save(model.state_dict(),
                       output_path +
                       f'weights_pretrained.{epoch:02d}-{min_loss:.2f}.pt')

    print("Epoch", epoch, "validation loss:\t", running_loss / (i+1),
          ". Time =", (time.time()-start_time), "seconds.")


# Weights after the last epoch, regardless of validation loss.
torch.save(model.state_dict(), output_path + "final_weights."+str(epochs)+".pt")


## Extract Embeddings



In [None]:
def _embed_split(examples, log_every=None):
  """Run the model over every example and collect cids and embeddings.

  Args:
    examples: iterable of dicts as yielded by GenerateData.generate_examples.
    log_every: when set, print a progress line every `log_every` samples.

  Returns:
    (cids, chem_embeddings, text_embeddings) as numpy arrays, one row per
    example in iteration order; three empty arrays when there are no examples.
  """
  cids, chem_rows, text_rows = [], [], []

  for i, d in enumerate(examples):
    sample = d['input']
    # Each example is fed as a batch of one, hence the reshape(1, -1).
    text_mask = torch.Tensor(sample['text']['attention_mask']).bool().reshape(1,-1).to(device)
    text = torch.Tensor(sample['text']['input_ids']).int().reshape(1,-1).to(device)
    molecule = torch.Tensor(sample['molecule']['mol2vec']).float().reshape(1,-1).to(device)

    text_emb, chem_emb = model(text, molecule, text_mask)

    cids.append(d['cid'])
    chem_rows.append(chem_emb.cpu().numpy())
    text_rows.append(text_emb.cpu().numpy())

    if log_every and (i+1) % log_every == 0: print(i+1, "samples eval.")

  if not cids:
    return np.array([]), np.array([]), np.array([])

  # Concatenate once at the end instead of re-copying the whole array for
  # every sample.
  return np.array(cids), np.concatenate(chem_rows), np.concatenate(text_rows)


# Make the eval mode explicit instead of relying on the state left behind by
# an earlier cell.
model.eval()

with torch.no_grad():
  cids_train, chem_embeddings_train, text_embeddings_train = _embed_split(
      gt.generate_examples_train(), log_every=100)
  print(cids_train.shape, chem_embeddings_train.shape)

  cids_val, chem_embeddings_val, text_embeddings_val = _embed_split(
      gt.generate_examples_val())
  print(cids_val.shape, chem_embeddings_val.shape)

  cids_test, chem_embeddings_test, text_embeddings_test = _embed_split(
      gt.generate_examples_test())

print(cids_test.shape, chem_embeddings_test.shape)


# Row k of each embedding file belongs to entry k of the matching cids file.
np.save(emb_path + "cids_train.npy", cids_train)
np.save(emb_path + "cids_val.npy", cids_val)
np.save(emb_path + "cids_test.npy", cids_test)
np.save(emb_path + "chem_embeddings_train.npy", chem_embeddings_train)
np.save(emb_path + "chem_embeddings_val.npy", chem_embeddings_val)
np.save(emb_path + "chem_embeddings_test.npy", chem_embeddings_test)
np.save(emb_path + "text_embeddings_train.npy", text_embeddings_train)
np.save(emb_path + "text_embeddings_val.npy", text_embeddings_val)
np.save(emb_path + "text_embeddings_test.npy", text_embeddings_test)
