# ALTeGraD 2023 Data Challenge 
## Molecule Retrieval with Natural Language Queries
### École Polytechnique

MLP

In [1]:
import os
import numpy as np
import pandas as pd
import torch

# Import evaluation metric
from sklearn.metrics import label_ranking_average_precision_score # Use : label_ranking_average_precision_score(y_true, y_pred)

# Loading token embedding dictionary
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects, so this
# file must come from a trusted source.
# The trailing [()] indexes with an empty tuple to pull out the single stored
# object (assumes the file holds a 0-d object array, e.g. a saved dict — TODO confirm).
token_embedding_dict = np.load("data/token_embedding_dict.npy", allow_pickle=True)[()]

# Tokenizer factory for the pretrained text model
from transformers import AutoTokenizer

# Paired graph/text dataset used for training and validation (project-local module)
from dataloader import GraphTextDataset

# PyTorch Geometric data containers and graph-aware batch loader
from torch_geometric.data import Dataset 
from torch_geometric.data import Data
from torch_geometric.data import DataLoader

# Joint graph/text encoder model (project-local module)
from Model import Model

# Optimizers (AdamW is used for training)
from torch import optim

# Counting each epoch training time
import time

# Plain PyTorch loader, aliased so it does not clash with the PyTorch Geometric DataLoader
from torch.utils.data import DataLoader as TorchDataLoader

# Graph-only and text-only datasets used to embed the test set

from dataloader import GraphDataset
from dataloader import TextDataset

## Provided Benchmark : DistilBERT + GCN + Cosine similarity 



### Contrastive Loss

In [2]:
CE = torch.nn.CrossEntropyLoss()


def contrastive_loss(v1, v2):
    """Symmetric cross-entropy loss over pairwise dot products of two batches.

    Row ``i`` of ``v1`` is scored against every row of ``v2`` with a dot
    product; the matching row ``i`` of ``v2`` is the target class. The same
    is done in the other direction and the two losses are summed.

    Args:
        v1: 2-D tensor with one embedding per row.
        v2: 2-D tensor with the same number of rows and columns as ``v1``.

    Returns:
        Scalar tensor: cross-entropy of ``v1 -> v2`` plus cross-entropy of
        ``v2 -> v1``.
    """
    similarity = v1 @ v2.transpose(0, 1)
    # The i-th row should be matched with the i-th column (the diagonal).
    targets = torch.arange(similarity.size(0), device=v1.device)
    forward_loss = CE(similarity, targets)
    backward_loss = CE(similarity.transpose(0, 1), targets)
    return forward_loss + backward_loss

### Text tokenization

In [3]:
# Pretrained checkpoint name; the tokenizer is loaded from this name here.
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Token-embedding dictionary handed to the datasets as `gt`.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]

# Build the validation split first, then the training split, with identical
# settings apart from the split name. nrows=10 is passed to both datasets.
val_dataset, train_dataset = [
    GraphTextDataset(root='./data/', gt=gt, split=split_name, tokenizer=tokenizer, nrows=10)
    for split_name in ('val', 'train')
]

### Train parameters

In [4]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimisation hyper-parameters.
nb_epochs = 5
batch_size = 32
learning_rate = 2e-5

# contrastive_loss scores each item against every other item of its batch, so
# the validation loss depends on how batches are composed. Validation is kept
# un-shuffled so the per-epoch values used for checkpoint selection are
# computed on the same batches every time.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

model = Model(model_name=model_name, num_node_features=300, nout=768, nhid=300, graph_hidden_channels=300) # nout = bert model hidden dim
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)

# Bookkeeping shared with the training loop.
epoch = 0
loss = 0                         # running sum of training losses since the last report
losses = []                      # history of reported training-loss sums
count_iter = 0                   # number of optimisation steps taken so far
time1 = time.time()              # reference time for the elapsed-time printout
printEvery = 50                  # report the training loss every this many steps
best_validation_loss = 1000000   # large sentinel so the first epoch counts as an improvement



### Training the model

In [5]:
# Train for nb_epochs epochs; after each epoch evaluate on the validation set
# and write a checkpoint whenever the validation loss is the best seen so far.
for i in range(nb_epochs):
    print('-----EPOCH{}-----'.format(i+1))
    model.train()
    for batch in train_loader:
        # The batch carries the text tensors alongside the graph; pull them
        # out so that what remains is passed to the model as the graph batch.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch
        
        x_graph, x_text = model(graph_batch.to(device), 
                                input_ids.to(device), 
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)   
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.item()
        
        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            # Note: the history stores the summed loss over the window, not
            # the mean that is printed above.
            losses.append(loss)
            loss = 0 
    model.eval()       
    val_loss = 0        
    # No gradients are needed for validation; disabling autograd avoids
    # building a graph for every validation batch.
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch.input_ids
            batch.pop('input_ids')
            attention_mask = batch.attention_mask
            batch.pop('attention_mask')
            graph_batch = batch
            x_graph, x_text = model(graph_batch.to(device), 
                                    input_ids.to(device), 
                                    attention_mask.to(device))
            current_loss = contrastive_loss(x_graph, x_text)   
            val_loss += current_loss.item()
    # val_loss is the sum over validation batches; the comparison below is on
    # sums, while the printed value is the per-batch mean.
    best_validation_loss = min(best_validation_loss, val_loss)
    print('-----EPOCH'+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    if best_validation_loss==val_loss:
        print('validation loss improoved saving checkpoint...')
        # torch.save does not create missing directories.
        os.makedirs('./model_checkpoints/', exist_ok=True)
        save_path = os.path.join('./model_checkpoints/', 'model'+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        # Despite its name this key holds the summed validation loss; the key
        # is kept unchanged so existing checkpoints/readers stay compatible.
        'validation_accuracy': val_loss,
        # Running training-loss accumulator at the time of saving.
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))


-----EPOCH1-----
-----EPOCH1----- done.  Validation loss:  4.646461009979248
validation loss improoved saving checkpoint...
checkpoint saved to: ./model_checkpoints/model0.pt
-----EPOCH2-----
-----EPOCH2----- done.  Validation loss:  4.641589164733887
validation loss improoved saving checkpoint...
checkpoint saved to: ./model_checkpoints/model1.pt
-----EPOCH3-----
-----EPOCH3----- done.  Validation loss:  4.627495765686035
validation loss improoved saving checkpoint...
checkpoint saved to: ./model_checkpoints/model2.pt
-----EPOCH4-----
-----EPOCH4----- done.  Validation loss:  4.617728233337402
validation loss improoved saving checkpoint...
checkpoint saved to: ./model_checkpoints/model3.pt
-----EPOCH5-----
-----EPOCH5----- done.  Validation loss:  4.603543758392334
validation loss improoved saving checkpoint...
checkpoint saved to: ./model_checkpoints/model4.pt


### Select best model

In [6]:
print('loading best model...')
# save_path is set by the training loop and is only updated when an epoch's
# validation loss equals the best so far, so it points at the best checkpoint.
# map_location lets a checkpoint saved on one device be restored on the
# device currently in use (e.g. a GPU checkpoint on a CPU-only machine).
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode for inference.
model.eval()

loading best model...


Model(
  (graph_encoder): GraphEncoder(
    (relu): ReLU()
    (ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
    (conv1): GCNConv(300, 300)
    (conv2): GCNConv(300, 300)
    (conv3): GCNConv(300, 300)
    (mol_hidden1): Linear(in_features=300, out_features=300, bias=True)
    (mol_hidden2): Linear(in_features=300, out_features=768, bias=True)
  )
  (text_encoder): TextEncoder(
    (bert): DistilBertModel(
      (embeddings): Embeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (transformer): Transformer(
        (layer): ModuleList(
          (0-5): 6 x TransformerBlock(
            (attention): MultiHeadSelfAttention(
              (dropout): Dropout(p=0.1, inplace=False)
              (q_lin): Linear(in_features=768, out_features=768, bias=True)
              (

### Building prediction on test data

In [7]:
# Split the trained model into its two encoders so graphs and texts can be
# embedded independently.
graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids', nrows=10)
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer, nrows=10)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

# shuffle=False keeps the embeddings in dataset order.
test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)

# Inference only: disable autograd so no graph is built for each batch.
graph_embeddings = []
with torch.no_grad():
    for batch in test_loader:
        for output in graph_model(batch.to(device)):
            graph_embeddings.append(output.tolist())

test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)
text_embeddings = []
with torch.no_grad():
    for batch in test_text_loader:
        for output in text_model(batch['input_ids'].to(device), 
                                 attention_mask=batch['attention_mask'].to(device)):
            text_embeddings.append(output.tolist())



### Output submission 

In [10]:
from sklearn.metrics.pairwise import cosine_similarity

# Rows are text queries, columns are candidate graphs.
similarity = cosine_similarity(text_embeddings, graph_embeddings)
solution = pd.DataFrame(similarity)
# Add an explicit ID column (the row index) and move it to the front.
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
# to_csv does not create missing directories.
os.makedirs('outputs', exist_ok=True)
solution.to_csv('outputs/submission_dumb.csv', index=False)

In [12]:
# Read the submission file back from disk to check what was written.
sub=pd.read_csv('outputs/submission_dumb.csv')

In [13]:
# Preview the first rows of the reloaded submission.
sub.head()

Unnamed: 0,ID,0,1,2,3,4,5,6,7,8,9
0,0,-0.009874,-0.009006,-0.015367,-0.013406,-0.013982,-0.001774,-0.008939,-0.017372,-0.010627,-0.006086
1,1,-0.003363,-0.011163,-0.018349,-0.01168,-0.014895,-0.002469,-0.010352,-0.019198,-0.00652,-0.005887
2,2,-0.015993,-0.014658,-0.024656,-0.021249,-0.018638,-0.005473,-0.016849,-0.02152,-0.016214,-0.00942
3,3,-0.003931,-0.010506,-0.018623,-0.013427,-0.011262,-0.001596,-0.00937,-0.01489,-0.009089,-0.006868
4,4,-0.016527,-0.007393,-0.012251,-0.012557,-0.013111,-0.001895,-0.00509,-0.021559,-0.014019,-0.003318
