In [2]:
pip install torch_geometric

Note: you may need to restart the kernel to use updated packages.


In [1]:
from torch import nn
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool
from transformers import AutoModel
from torch_geometric.nn import GATConv
from torch_geometric.nn import GATv2Conv
from dataloader import GraphTextDataset, GraphDataset, TextDataset
from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader

from Model import Model, GATEncoder, AMAN
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd

In [103]:

# Two independent 3x3 int64 identity matrices, usable as toy embeddings
# for sanity-checking the loss functions.
v1 = torch.eye(3, dtype=torch.int64)
v2 = v1.clone()

In [3]:
def Ltriplet(v1, v2, delta):
    """Bidirectional hardest-negative triplet (max-margin) loss.

    Builds the similarity matrix ``S = v1 @ v2.T`` and treats ``S[i, i]`` as the
    positive pair for anchor ``i``. For every anchor, the hardest negative
    (largest off-diagonal similarity) is taken both along its row and along its
    column. The hinge ``relu(S_neg - S_pos + delta)`` is summed over both
    directions and all anchors.

    :param v1: 2-D tensor, one embedding per row.
    :param v2: 2-D tensor with the same shape as ``v1`` (row i pairs with row i).
    :param delta: margin of the hinge. It is honoured as given; an earlier
        version silently overwrote it with 2.
    :return: scalar tensor holding the summed loss. It is zero (but still
        attached to the autograd graph) when there are fewer than two rows,
        because no negatives exist.
    """
    similarity = torch.matmul(v1, torch.transpose(v2, 0, 1))
    if not similarity.is_floating_point():
        # Masking with -inf below needs a floating dtype (e.g. for int64 inputs).
        similarity = similarity.float()
    size = similarity.size(0)
    if size < 2:
        # No negatives: return a zero that still supports .backward().
        return similarity.sum() * 0
    diag_index = torch.arange(size, device=similarity.device)
    # Hide the positives so argmax can only pick negatives.
    positive_mask = torch.eye(size, dtype=torch.bool, device=similarity.device)
    negatives_only = similarity.masked_fill(positive_mask, float('-inf'))
    hardest_col_per_row = torch.argmax(negatives_only, dim=1)
    hardest_row_per_col = torch.argmax(negatives_only, dim=0)
    positives = similarity[diag_index, diag_index]
    loss = (torch.relu(similarity[diag_index, hardest_col_per_row] - positives + delta)
            + torch.relu(similarity[hardest_row_per_col, diag_index] - positives + delta))
    return torch.sum(loss)

# Module-level binary cross-entropy criterion, averaged over elements.
BCE = torch.nn.BCELoss(reduction='mean')

def L_MA(text_prob, mol_prob):
    """Modality-adversarial BCE loss.

    Text-side probabilities are scored against a target of 1 and
    molecule-side probabilities against a target of 0. The two mean BCE
    terms are returned as their sum.

    :param text_prob: tensor of probabilities in [0, 1] for the text inputs.
    :param mol_prob: tensor of probabilities in [0, 1] for the molecule inputs.
    :return: scalar tensor ``BCE(text_prob, 1) + BCE(mol_prob, 0)``.
    """
    bce = torch.nn.BCELoss()
    text_term = bce(text_prob, torch.ones_like(text_prob))
    mol_term = bce(mol_prob, torch.zeros_like(mol_prob))
    return text_term + mol_term

def L(v1, v2, text_prob, mol_prob, delta, lambda_):
    """Combined objective: triplet loss plus ``lambda_`` times the adversarial loss.

    :param v1, v2: paired embedding matrices passed to ``Ltriplet``.
    :param text_prob, mol_prob: probabilities passed to ``L_MA``.
    :param delta: triplet margin.
    :param lambda_: weight of the adversarial term.
    """
    triplet_term = Ltriplet(v1, v2, delta)
    adversarial_term = L_MA(text_prob, mol_prob)
    return triplet_term + lambda_ * adversarial_term

CE = torch.nn.CrossEntropyLoss()


def contrastive_loss(v1, v2):
    """Symmetric in-batch contrastive (cross-entropy) loss.

    Row i of ``v1`` should match row i of ``v2``. The other rows of the batch
    act as negatives, scored once row-wise and once column-wise.

    :param v1: 2-D tensor of embeddings.
    :param v2: 2-D tensor of embeddings with the same number of rows.
    :return: scalar tensor, the sum of the two mean cross-entropies.
    """
    scores = v1 @ v2.transpose(0, 1)
    targets = torch.arange(scores.shape[0], device=v1.device)
    row_term = CE(scores, targets)
    col_term = CE(scores.transpose(0, 1), targets)
    return row_term + col_term

In [3]:
model_name = 'distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load trusted files.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]
val_dataset = GraphTextDataset(root='./data/', gt=gt, split='val', tokenizer=tokenizer)
train_dataset = GraphTextDataset(root='./data/', gt=gt, split='train', tokenizer=tokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nb_epochs = 5
batch_size = 32
learning_rate = 2e-5
delta=2e-3  # NOTE(review): appears unused in this cell; contrastive_loss takes no margin
# Validation is not shuffled: contrastive_loss scores each sample against the
# other samples of its batch, so a random batch composition would make the
# validation loss (and hence checkpoint selection) vary between epochs by chance.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

model = Model(model_name=model_name, num_node_features=300, nout=768, nhid=300, graph_hidden_channels=300) # nout = bert model hidden dim
model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)

epoch = 0
loss = 0  # running sum of training losses since the last progress print
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000
# Reset so a checkpoint path left in the kernel by an earlier run can never be
# loaded if this run does not write a checkpoint.
save_path = None

for i in range(nb_epochs):
    print('-----EPOCH{}-----'.format(i+1))
    model.train()
    for batch in train_loader:
        # Strip the text tensors off the batch; what remains is the graph batch.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch

        x_graph, x_text = model(graph_batch.to(device),
                                input_ids.to(device),
                                attention_mask.to(device))
        current_loss = contrastive_loss(x_graph, x_text)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.item()

        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            losses.append(loss)
            loss = 0
    model.eval()
    val_loss = 0
    # Validation needs no gradients. Without no_grad every batch would build,
    # and hold, an autograd graph through both encoders.
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch.input_ids
            batch.pop('input_ids')
            attention_mask = batch.attention_mask
            batch.pop('attention_mask')
            graph_batch = batch
            x_graph, x_text = model(graph_batch.to(device),
                                    input_ids.to(device),
                                    attention_mask.to(device))
            current_loss = contrastive_loss(x_graph, x_text)
            val_loss += current_loss.item()
    print('-----EPOCH'+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    # `<=` also saves when the loss ties the best value seen so far.
    if val_loss <= best_validation_loss:
        best_validation_loss = val_loss
        print('validation loss improoved saving checkpoint...')
        save_path = os.path.join('./', 'model'+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,  # NOTE: summed validation *loss*, despite the key name
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))


if save_path is None:
    raise RuntimeError('no checkpoint was saved during training')
print('loading best model...')
# map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)

# Inference only: disable autograd while computing the test embeddings.
graph_embeddings = []
with torch.no_grad():
    for batch in test_loader:
        for output in graph_model(batch.to(device)):
            graph_embeddings.append(output.tolist())

test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)
text_embeddings = []
with torch.no_grad():
    for batch in test_text_loader:
        for output in text_model(batch['input_ids'].to(device),
                                 attention_mask=batch['attention_mask'].to(device)):
            text_embeddings.append(output.tolist())


from sklearn.metrics.pairwise import cosine_similarity

# Rows: test texts, columns: test graphs.
similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)



-----EPOCH1-----


KeyboardInterrupt: 

In [8]:
from dataloader import GraphTextDataset, GraphDataset, TextDataset
from torch_geometric.data import DataLoader
from torch.utils.data import DataLoader as TorchDataLoader

from Model import GATEncoder, AMAN
import numpy as np
from transformers import AutoTokenizer
import torch
from torch import optim
import time
import os
import pandas as pd
model_name = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# NOTE(review): allow_pickle=True unpickles arbitrary Python objects; only load trusted files.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True)[()]
val_dataset = GraphTextDataset(root='./data/', gt=gt, split='val', tokenizer=tokenizer)
train_dataset = GraphTextDataset(root='./data/', gt=gt, split='train', tokenizer=tokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

nb_epochs = 5
batch_size = 32
learning_rate = 2e-5
# NOTE(review): `delta` looks unused; the AMAN training cell passes its margin to L() directly.
delta=2e-3
# Validation is not shuffled: the triplet loss mines hardest negatives inside
# each batch, so a random batch composition would make the validation loss
# (and hence checkpoint selection) vary between epochs by chance.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

vocab.txt: 100%|██████████| 228k/228k [00:00<00:00, 1.79MB/s]


In [9]:
model = AMAN(model_name=model_name, num_node_features=300, nout=768, graph_hidden_channels=300)  # nout = bert model hidden dim

model.to(device)

optimizer = optim.AdamW(model.parameters(), lr=learning_rate,
                                betas=(0.9, 0.999),
                                weight_decay=0.01)

# Hyper-parameters of L = Ltriplet + lambda_ * L_MA, shared by the training and
# validation passes so the two cannot drift apart.
triplet_delta = 0.3
adversarial_lambda = 2.e-3

epoch = 0
loss = 0  # running sum of training losses since the last progress print
losses = []
count_iter = 0
time1 = time.time()
printEvery = 50
best_validation_loss = 1000000
# Reset so a checkpoint path left in the kernel by an earlier cell (e.g. the
# baseline Model run) can never be loaded into this AMAN model.
save_path = None

for i in range(nb_epochs):
    print('-----EPOCH{}-----'.format(i+1))
    model.train()
    for batch in train_loader:
        # Strip the text tensors off the batch; what remains is the graph batch.
        input_ids = batch.input_ids
        batch.pop('input_ids')
        attention_mask = batch.attention_mask
        batch.pop('attention_mask')
        graph_batch = batch

        x_graph, x_text, text_prob, mol_prob = model(graph_batch.to(device),
                                input_ids.to(device),
                                attention_mask.to(device))
        # NOTE(review): encoders and discriminator share one optimizer and one
        # loss; unless AMAN reverses gradients internally, the encoders are
        # trained to help the discriminator rather than fool it. Confirm in Model.py.
        current_loss = L(x_graph, x_text, text_prob, mol_prob,
                         delta=triplet_delta, lambda_=adversarial_lambda)
        optimizer.zero_grad()
        current_loss.backward()
        optimizer.step()
        loss += current_loss.item()

        count_iter += 1
        if count_iter % printEvery == 0:
            time2 = time.time()
            print("Iteration: {0}, Time: {1:.4f} s, training loss: {2:.4f}".format(count_iter,
                                                                        time2 - time1, loss/printEvery))
            losses.append(loss)
            loss = 0
    model.eval()
    val_loss = 0
    # Validation needs no gradients. Without no_grad every batch would build,
    # and hold, an autograd graph through both encoders.
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch.input_ids
            batch.pop('input_ids')
            attention_mask = batch.attention_mask
            batch.pop('attention_mask')
            graph_batch = batch
            x_graph, x_text, text_prob, mol_prob = model(graph_batch.to(device),
                                    input_ids.to(device),
                                    attention_mask.to(device))
            current_loss = L(x_graph, x_text, text_prob, mol_prob,
                             delta=triplet_delta, lambda_=adversarial_lambda)
            val_loss += current_loss.item()
    print('-----EPOCH'+str(i+1)+'----- done.  Validation loss: ', str(val_loss/len(val_loader)) )
    # `<=` also saves when the loss ties the best value seen so far.
    if val_loss <= best_validation_loss:
        best_validation_loss = val_loss
        print('validation loss improoved saving checkpoint...')
        save_path = os.path.join('./', 'model'+str(i)+'.pt')
        torch.save({
        'epoch': i,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'validation_accuracy': val_loss,  # NOTE: summed validation *loss*, despite the key name
        'loss': loss,
        }, save_path)
        print('checkpoint saved to: {}'.format(save_path))


if save_path is None:
    raise RuntimeError('no checkpoint was saved during training')
print('loading best model...')
# map_location lets a checkpoint written on GPU be restored on a CPU-only machine.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

graph_model = model.get_graph_encoder()
text_model = model.get_text_encoder()

test_cids_dataset = GraphDataset(root='./data/', gt=gt, split='test_cids')
test_text_dataset = TextDataset(file_path='./data/test_text.txt', tokenizer=tokenizer)

idx_to_cid = test_cids_dataset.get_idx_to_cid()

test_loader = DataLoader(test_cids_dataset, batch_size=batch_size, shuffle=False)

# Inference only: disable autograd while computing the test embeddings.
graph_embeddings = []
with torch.no_grad():
    for batch in test_loader:
        for output in graph_model(batch.to(device)):
            graph_embeddings.append(output.tolist())

test_text_loader = TorchDataLoader(test_text_dataset, batch_size=batch_size, shuffle=False)
text_embeddings = []
with torch.no_grad():
    for batch in test_text_loader:
        for output in text_model(batch['input_ids'].to(device),
                                 attention_mask=batch['attention_mask'].to(device)):
            text_embeddings.append(output.tolist())


from sklearn.metrics.pairwise import cosine_similarity

# Rows: test texts, columns: test graphs.
similarity = cosine_similarity(text_embeddings, graph_embeddings)

solution = pd.DataFrame(similarity)
solution['ID'] = solution.index
solution = solution[['ID'] + [col for col in solution.columns if col!='ID']]
solution.to_csv('submission.csv', index=False)

pytorch_model.bin: 100%|██████████| 442M/442M [00:23<00:00, 18.8MB/s] 


-----EPOCH1-----


KeyboardInterrupt: 

In [10]:
# Rebuild the shuffled training loader (fresh iterator over train_dataset).
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)

In [11]:
# Shows only the loader's repr, not a batch. The repr points at
# torch_geometric.deprecation.DataLoader: the torch_geometric.data import
# is deprecated (torch_geometric.loader.DataLoader is the current location).
train_loader

<torch_geometric.deprecation.DataLoader at 0x7fad2a040be0>

In [23]:
# Inspect a single training batch: the token ids, the remaining graph batch,
# and its `.x` tensor. Nothing is printed if the loader is empty.
batch = next(iter(train_loader), None)
if batch is not None:
    input_ids = batch.input_ids
    print(input_ids)
    batch.pop('input_ids')
    attention_mask = batch.attention_mask
    batch.pop('attention_mask')
    graph_batch = batch
    print(graph_batch)
    print(graph_batch.x)
        



tensor([[  101, 20199,  1011,  ...,     0,     0,     0],
        [  101,  1006,  1016,  ...,     0,     0,     0],
        [  101,  1006,  1011,  ...,     0,     0,     0],
        ...,
        [  101,  1050,  1011,  ...,     0,     0,     0],
        [  101,  1018,  1011,  ...,     0,     0,     0],
        [  101,  1015,  1010,  ...,     0,     0,     0]])
DataBatch(x=[1071, 300], edge_index=[2, 2226], batch=[1071], ptr=[33])
tensor([[ 0.0675, -0.1169,  0.2940,  ...,  0.1557,  0.1334, -0.0677],
        [ 0.0236, -0.3522,  0.2335,  ...,  0.2811,  0.1083,  0.2195],
        [-0.0474, -0.2763,  0.2340,  ...,  0.1403,  0.1615,  0.0154],
        ...,
        [-0.0474, -0.2763,  0.2340,  ...,  0.1403,  0.1615,  0.0154],
        [ 0.0838,  0.2535,  0.2991,  ...,  0.0321, -0.0651,  0.2079],
        [ 0.0446,  0.1850,  0.1403,  ...,  0.2777,  0.2474,  0.1810]])


In [87]:
# External train.csv (columns include CIDs, ChEBI descriptions, Mol2Vec Embedding).
# Set the TRAIN_CSV environment variable to point elsewhere instead of editing
# this machine-specific absolute path.
TRAIN_CSV = os.environ.get('TRAIN_CSV', '/Users/damien/Downloads/train.csv')
train = pd.read_csv(TRAIN_CSV)

In [58]:
# Raw CID array from the external train.csv (superseded by the sorted list below).
hug_CIDs = train['CIDs'].values

In [82]:
# Ascending list of the CIDs present in the external train.csv.
hug_CIDs = sorted(train['CIDs'].values)

In [83]:
rout = './data/'
# train.tsv: column 0 is the CID, column 1 its description.
# to_dict() yields {1: {cid: description}}; its keys are the CIDs.
description = (
    pd.read_csv(os.path.join(rout, 'train'+'.tsv'), sep='\t', header=None)
    .set_index(0)
    .to_dict()
)
cids = [*description[1]]

In [84]:
# Ascending list of the CIDs from the local train split.
cids = sorted(cids)

In [85]:
print(f"{cids}")  # full list; the output is long

[1, 3, 4, 5, 6, 19, 20, 22, 38, 43, 44, 45, 47, 48, 51, 58, 59, 61, 63, 64, 65, 66, 69, 70, 71, 72, 77, 78, 79, 85, 86, 89, 91, 98, 102, 104, 105, 106, 107, 108, 110, 111, 118, 119, 123, 124, 127, 128, 130, 134, 135, 138, 155, 174, 175, 177, 179, 180, 187, 190, 196, 199, 203, 204, 205, 207, 215, 216, 217, 219, 222, 223, 232, 233, 234, 236, 241, 244, 247, 248, 249, 254, 261, 263, 264, 266, 271, 273, 277, 278, 280, 281, 283, 289, 296, 300, 301, 305, 311, 312, 323, 325, 328, 332, 335, 336, 338, 339, 340, 341, 342, 355, 359, 362, 364, 366, 368, 385, 389, 397, 398, 401, 403, 409, 421, 423, 428, 434, 437, 441, 444, 447, 454, 456, 457, 460, 461, 469, 470, 471, 472, 473, 484, 486, 490, 496, 499, 500, 510, 511, 523, 525, 527, 533, 535, 541, 542, 543, 544, 545, 547, 558, 561, 563, 564, 584, 586, 588, 595, 597, 599, 602, 612, 614, 617, 632, 643, 647, 648, 649, 660, 663, 669, 670, 671, 673, 675, 677, 679, 691, 694, 700, 712, 713, 729, 738, 743, 750, 751, 752, 756, 757, 760, 763, 768, 769, 773, 774

In [86]:
print(f"{hug_CIDs}")  # full list; the output is long

[29, 49, 93, 137, 144, 173, 176, 178, 186, 242, 284, 326, 363, 370, 371, 378, 449, 485, 513, 528, 553, 598, 655, 668, 674, 681, 736, 753, 767, 777, 798, 892, 943, 946, 984, 991, 994, 1003, 1005, 1021, 1047, 1049, 1456, 1639, 1678, 1855, 1983, 2033, 2130, 2131, 2230, 2259, 2284, 2355, 2484, 2582, 2703, 2720, 2754, 2764, 2794, 2833, 2859, 2971, 2972, 2997, 3082, 3126, 3182, 3224, 3245, 3276, 3278, 3292, 3337, 3343, 3347, 3366, 3410, 3488, 3657, 3857, 3866, 3913, 3954, 3960, 3969, 3976, 4062, 4078, 4091, 4107, 4130, 4133, 4156, 4170, 4171, 4380, 4431, 4510, 4534, 4601, 4634, 4650, 4657, 4747, 4874, 4876, 4878, 4931, 4944, 4980, 5216, 5253, 5269, 5311, 5320, 5325, 5329, 5359, 5379, 5419, 5447, 5523, 5564, 5571, 5590, 5787, 5798, 5819, 5852, 5905, 5959, 5984, 6001, 6005, 6014, 6087, 6116, 6132, 6166, 6167, 6209, 6249, 6264, 6297, 6342, 6360, 6386, 6421, 6436, 6477, 6497, 6590, 6593, 6646, 6716, 6724, 6727, 6728, 6805, 6842, 6872, 6944, 7018, 7213, 7250, 7289, 7298, 7329, 7342, 7346, 7360, 7

In [63]:
# CIDs present both in the external train.csv and in the local train split.
intersection = list(set(hug_CIDs).intersection(set(cids)))

In [64]:
print(str(len(intersection)))

3301


In [81]:
# Number of CIDs in the external train.csv, then in the local train split.
print(len(hug_CIDs), len(cids), sep='\n')

3301
26408


In [69]:
# Token-embedding dictionary stored as a 0-d object array; .item() unwraps it.
# NOTE(review): allow_pickle=True unpickles arbitrary objects; trusted files only.
gt = np.load("./data/token_embedding_dict.npy", allow_pickle=True).item()

def load_graph_for_cid(cid, root, gt, split='train'):
    """
    Load the pre-processed object saved for a given CID.

    :param cid: The CID to load (int or str); it is only formatted into the
        file name ``data_{cid}.pt``.
    :param root: Root directory of the dataset; files are read from
        ``<root>/processed/<split>/``.
    :param gt: Not used by this function; kept so existing calls keep working.
    :param split: Sub-directory of ``processed`` to read from. Defaults to
        ``'train'``, the only split this function previously supported.
    :return: Whatever object was stored in that file with ``torch.save``
        (presumably a torch_geometric ``Data`` graph; confirm against the
        dataset's processing code).
    :raises FileNotFoundError: If no processed file exists for ``cid``.
    """
    file_path = os.path.join(root, 'processed', split, f'data_{cid}.pt')

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"No processed file found for CID {cid}")

    # NOTE(review): torch>=2.6 defaults to weights_only=True, which may refuse to
    # unpickle non-tensor objects such as PyG Data; pass weights_only=False if
    # that happens (acceptable only for trusted local files).
    return torch.load(file_path)

# Usage: replace 'cid', 'root' and 'gt' with your own values.
# cid = 12345  # Replace with the actual CID
# root = './data/'  # Replace with the actual root directory
# gt = ...  # Replace with the actual graph tokenizer or feature dictionary
# graph_data = load_graph_for_cid(cid, root, gt)
# print(graph_data)


In [72]:
# Transposed `.x` tensor of the object stored for CID 743 in processed/train
# (presumably its node-feature matrix).
load_graph_for_cid('743', "./data/", gt).x.T

tensor([[ 0.1628,  0.1944,  0.2763,  ...,  0.1944,  0.1628,  0.2763],
        [-0.1553,  0.0250,  0.3478,  ...,  0.0250, -0.1553,  0.3478],
        [ 0.2417, -0.0693,  0.0340,  ..., -0.0693,  0.2417,  0.0340],
        ...,
        [ 0.2659,  0.1446, -0.0065,  ...,  0.1446,  0.2659, -0.0065],
        [ 0.1934,  0.1115, -0.3067,  ...,  0.1115,  0.1934, -0.3067],
        [-0.0798, -0.3250, -0.1269,  ..., -0.3250, -0.0798, -0.1269]])

In [98]:
# Mol2Vec embedding of CID 743 as stored in train.csv: a space-separated string.
train.loc[train['CIDs'] == 743, 'Mol2Vec Embedding'].iloc[0]

'-0.45798618 1.5937392 -0.97944957 -1.8013673 -1.047049 -2.7594686 0.10395634 1.9398652 -0.8634601 0.8261346 0.065829135 -2.6128128 0.044763565 -0.7234365 -0.5508002 -2.0852473 1.4973645 0.08164382 -3.2699642 1.5141386 -1.2357903 -2.2255852 -0.6823597 -3.4186244 -0.51494426 1.1649424 0.92718047 2.0266256 -1.8114111 -3.0991344 -1.3811677 -1.3626362 1.3741624 3.4981418 0.8727637 2.4708643 -0.54434615 3.720817 0.55635947 3.525425 -1.9374981 0.52023613 0.31473002 -0.74832904 -4.082786 0.6401143 -0.26148474 1.0636846 -0.92533016 1.9902483 1.5935513 0.98510027 1.4310788 3.2329817 1.1772429 -0.06121142 0.24663907 -0.42414796 0.4072823 -1.2340124 0.5697938 -2.9446855 1.3028573 2.9703593 -1.4606227 0.3611489 1.1976044 0.3145687 1.6095968 -0.118478864 1.6885631 2.0345788 -1.9686852 0.38184828 -0.874387 -1.7679753 -1.5913996 0.7606907 0.5886597 1.1511185 -1.3328394 -2.7250404 -0.64870936 1.96917 1.0730948 -1.0234574 -1.1129141 2.0876265 1.8025873 -1.0759495 -0.46808034 1.6367911 1.0824349 1.71293

In [73]:
# Display the external train.csv frame (pandas truncates it to head/tail rows).
train

Unnamed: 0,CIDs,ChEBI descriptions,Mol2Vec Embedding
0,688461,(S)-etodolac is the S-enantiomer of etodolac. ...,2.461659 4.760218 -1.6911973 -1.5830767 -1.503...
1,44456859,Alpha-Neup5Ac-(2->3)-beta-D-Galp-(1->4)-[alpha...,-5.8559885 13.2673435 5.7469654 -3.717024 -3.7...
2,100067,"Syringaresinol is a lignan that is 7,9':7',9-d...",-3.830154 3.606443 -1.416855 -1.2316934 -2.037...
3,12512,Tert-butyl ethyl ether is an ether having ethy...,1.3672638 -1.0741857 0.34452504 -0.71439224 -2...
4,11601663,Phoyunbene D is a stilbenoid that is trans-sti...,-2.522672 3.2795463 0.07393134 -2.6464787 0.33...
...,...,...,...
3296,442658,Schaftoside is a C-glycosyl compound that is a...,-2.872496 7.7121634 -0.5171319 -2.676653 -0.36...
3297,92136187,2-O-alpha-D-mannosyl-1-O-{1-O-[(10S)-10-methyl...,9.19132 25.764975 11.173792 -23.566645 7.40210...
3298,138911112,(+)-minovincinine is a monoterpenoid indole al...,5.155534 9.635075 -2.1558852 -3.0859478 -4.690...
3299,5476873,C5-oxacyanine cation is the cationic form of a...,5.199314 9.511706 -1.1213603 -4.190405 2.15142...


In [97]:
# Same lookup as the earlier cell: Mol2Vec embedding string for CID 743.
train.query('CIDs == 743')['Mol2Vec Embedding'].to_numpy()[0]

'-0.45798618 1.5937392 -0.97944957 -1.8013673 -1.047049 -2.7594686 0.10395634 1.9398652 -0.8634601 0.8261346 0.065829135 -2.6128128 0.044763565 -0.7234365 -0.5508002 -2.0852473 1.4973645 0.08164382 -3.2699642 1.5141386 -1.2357903 -2.2255852 -0.6823597 -3.4186244 -0.51494426 1.1649424 0.92718047 2.0266256 -1.8114111 -3.0991344 -1.3811677 -1.3626362 1.3741624 3.4981418 0.8727637 2.4708643 -0.54434615 3.720817 0.55635947 3.525425 -1.9374981 0.52023613 0.31473002 -0.74832904 -4.082786 0.6401143 -0.26148474 1.0636846 -0.92533016 1.9902483 1.5935513 0.98510027 1.4310788 3.2329817 1.1772429 -0.06121142 0.24663907 -0.42414796 0.4072823 -1.2340124 0.5697938 -2.9446855 1.3028573 2.9703593 -1.4606227 0.3611489 1.1976044 0.3145687 1.6095968 -0.118478864 1.6885631 2.0345788 -1.9686852 0.38184828 -0.874387 -1.7679753 -1.5913996 0.7606907 0.5886597 1.1511185 -1.3328394 -2.7250404 -0.64870936 1.96917 1.0730948 -1.0234574 -1.1129141 2.0876265 1.8025873 -1.0759495 -0.46808034 1.6367911 1.0824349 1.71293