In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print every file available under the read-only Kaggle input mount.
for root, _subdirs, files in os.walk('/kaggle/input'):
    for name in files:
        print(os.path.join(root, name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
# Start-of-session marker: prints once the kernel has started executing cells.
print('Session Start!')

In [None]:
!pip install torch_geometric
!pip install indexed_bzip2
!pip install rdkit

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_geometric.nn import GATConv
import indexed_bzip2 as ibz2
import os
import pickle
import rdkit
from rdkit import Chem
#from torch_scatter import scatter
from multiprocessing import Pool
from tqdm import tqdm
import gc
from torch_geometric.loader import DataLoader as PyGDataLoader
print('import DONE!')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_max_pool as gmp

class GATNet(torch.nn.Module):
    """Two-layer GAT encoder + global max pooling + MLP head (one logit row per graph).

    Args:
        num_features: width of ``data.x`` (cast to float in forward).
        n_output: number of logits produced per graph.
        n_filters, embed_dim: unused; kept for signature compatibility.
        output_dim: channels of the second GAT layer and of the pooled embedding.
        dropout: probability used for GAT attention dropout, the node-feature
            dropout in forward(), and the MLP dropout.
    """
    def __init__(self, num_features=9, n_output=3,n_filters=32, embed_dim=128, output_dim=1, dropout=0.2):
        super(GATNet, self).__init__()

        # GATConv: 10 concatenated heads of width num_features*10 -> num_features*100 channels
        self.gcn1 = GATConv(num_features, num_features * 10, heads=10, dropout=dropout)
        self.gcn2 = GATConv(num_features * 100, output_dim, dropout=dropout)
        self.fc_g1 = nn.Linear(output_dim, output_dim)

        self.fc1 = nn.Linear(output_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.out = nn.Linear(32, n_output)

        # relu and dropout
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return logits of shape (num_graphs, n_output)."""
        x = data.x.float()
        # Stored graphs keep edge_index as int32 to save memory; message passing
        # indexes/scatters with it, so cast to int64 (as MPNNModel does).
        edge_index = data.edge_index.long()
        batch = data.batch
        # Honour the constructor's dropout rate (previously hard-coded to 0.2).
        p = self.dropout.p

        x = F.dropout(x, p=p, training=self.training)
        x = F.elu(self.gcn1(x, edge_index))
        x = F.dropout(x, p=p, training=self.training)
        x = self.relu(self.gcn2(x, edge_index))
        x = gmp(x, batch)  # global max pooling over each graph's nodes
        x = self.relu(self.fc_g1(x))

        # dense layers
        xc = self.dropout(self.relu(self.fc1(x)))
        xc = self.dropout(self.relu(self.fc2(xc)))
        return self.out(xc)
    
# Default GATNet and an Adam optimiser over all of its parameters.
model = GATNet()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

def custom_loss(output, target):
    """Mean binary cross-entropy over every graph and all 3 binding targets.

    Args:
        output: logits shaped (num_graphs, 3).
        target: labels with num_graphs * 3 elements, e.g. (3,) for a single graph
            or the flattened (3 * num_graphs,) vector PyG builds when batching `y`.

    Returns:
        Scalar loss. For a single graph this equals the average of the three
        per-target losses, as before; for batches every graph now contributes
        (previously only output[0] was scored).
    """
    logits = output.reshape(-1, 3)
    labels = target.reshape(-1, 3).float()
    return F.binary_cross_entropy_with_logits(logits, labels)


# Training
def train_model(training_set):
    """One optimisation pass: a separate optimizer step for every item of `training_set`.

    Uses the notebook-level `model`, `optimizer` and `custom_loss`; items must
    carry their labels in `.y`.
    """
    model.train()
    for graph in training_set:
        optimizer.zero_grad()
        logits = model(graph)
        batch_loss = custom_loss(logits, graph.y)
        batch_loss.backward()
        optimizer.step()

In [None]:
#Load Train_graph
def load_compressed_ibz2_pickle(file):
    """Unpickle the single object stored in a bzip2-compressed file.

    Decompression uses indexed_bzip2 with `os.cpu_count()` parallel workers.
    NOTE: pickle.load can execute arbitrary code - only load trusted files.
    """
    with ibz2.open(file, parallelization=os.cpu_count()) as compressed:
        return pickle.load(compressed)
# Raw training graphs as stored by the processed-dataset input.
gdf_train = load_compressed_ibz2_pickle('/kaggle/input/leash-bio-processed-dataset/train-replace-c-30m.graph.pickle.b2z')
print('train_graph Loaded!')
print(len(gdf_train))

In [None]:
# Load Train_bind
with np.load('/kaggle/input/leash-bio-processed-dataset/train.bind.npz') as trainbind_data:
    # Indexing an NpzFile materialises the array, so it stays valid after close.
    train_bind = trainbind_data['bind']
print('train_bind Loaded!')
print(len(train_bind))

In [None]:
# Create a small subset of the training data (graphs and aligned labels) for quick experiments
# Graphs and labels are index-aligned, so slicing both identically keeps them paired.
train_graph_sample, train_bind_sample = gdf_train[:60000], train_bind[:60000]
print(len(train_graph_sample), len(train_bind_sample))

In [None]:
# Peek at the first five label rows of the training subset.
train_bind_sample[:5]

In [None]:
#Helper: convert graph to pyg list
def to_pyg_list(graph, labels=None):
    """Convert (N, edge, node_feature, edge_feature) tuples to labelled PyG Data, IN PLACE.

    Args:
        graph: mutable list of graph tuples; entry i is replaced by its Data
            object (with idx=i and y=labels[i]) and the same list is returned.
        labels: per-graph label rows aligned with `graph`. Defaults to the
            notebook-global `train_bind_sample`, which was previously hard-wired.

    Raises:
        ValueError: if there are fewer label rows than graphs (checked before
            any entry is converted, so `graph` is not left half-converted).
    """
    if labels is None:
        labels = train_bind_sample
    if len(labels) < len(graph):
        raise ValueError(f'{len(graph)} graphs but only {len(labels)} label rows')
    for i in tqdm(range(len(graph))):
        graph[i] = _graph_tuple_to_data(i, graph[i], labels[i])
    return graph


# Helper for the unlabelled test split
def to_pyg_list_test(graph):
    """Unlabelled variant of to_pyg_list (no `y`); also converts `graph` in place."""
    for i in tqdm(range(len(graph))):
        graph[i] = _graph_tuple_to_data(i, graph[i])
    return graph


def _graph_tuple_to_data(i, graph_tuple, label=None):
    """Build one PyG Data object from a graph tuple; adds `y` only when a label is given."""
    N, edge, node_feature, edge_feature = graph_tuple
    fields = dict(
        idx=i,
        edge_index=torch.from_numpy(edge.T).int(),
        x=torch.from_numpy(node_feature).byte(),
        edge_attr=torch.from_numpy(edge_feature).byte(),
    )
    if label is not None:
        fields['y'] = torch.tensor(label)
    return Data(**fields)

In [None]:
# Converted train-graph, with label
# NOTE: to_pyg_list converts its argument in place and returns it, so `train_graph`
# and `train_graph_sample` are the same list and both now hold PyG Data objects.
train_graph = to_pyg_list(train_graph_sample)
train_graph[:5]

In [None]:
# Create sample training_set of torch_geometric.data objects
# Quick smoke run: train on the first 1,000 converted, labelled graphs.
training_set = train_graph[0:1000]

train_model(training_set)
print('training DONE!')

# Prediction
def predict_label(data, threshold=0.1):
    """Return a 0/1 'binds' decision for `data` using the notebook-level `model`.

    The sigmoid probabilities of all model outputs are averaged; the result is
    1 when that mean is >= `threshold` (default 0.1, the previous hard-coded
    value), else 0.
    """
    model.eval()
    # Inference only: no autograd graph needs to be recorded.
    with torch.no_grad():
        logits = model(data)
    probs = torch.sigmoid(logits).cpu().numpy()
    return 0 if np.mean(probs) < threshold else 1

In [None]:
# Prediction test
# NOTE(review): `test_item` is not defined anywhere earlier in this notebook, so this line
# raises NameError on a fresh kernel. It presumably should be one raw graph tuple
# (N, edge, node_feature, edge_feature), e.g. from a test split - TODO define it first.
predict_label(to_pyg_list_test([test_item])[0])

In [None]:
#Playground

In [None]:
# NOTE(review): this cell uses whichever `model`, `train_model` and `predict_label`
# were defined last. With the GATNet defined above (num_features=9), the random
# x of width 20 below does not match the first GATConv's input width, and
# edge_index=None is passed straight to GATConv - expect this cell to fail; confirm
# which model it was meant for.
# Train the model with the training_set
train_model(training_set)

# Create a new torch_geometric.data object for prediction
data_to_predict = Data(idx=-1, edge_index=None, x=torch.rand(5, 20), edge_attr=None)

# Make predictions using the model
predicted_label = predict_label(data_to_predict)
print(predicted_label)

In [None]:
class GAT(nn.Module):
    """Single GATConv layer + linear head producing 3 logits per graph.

    Args:
        in_features: width of ``data.x``.
        out_features: channels per attention head.
        num_heads: number of attention heads (head outputs are concatenated).
    """
    def __init__(self, in_features, out_features, num_heads):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_features, out_features, heads=num_heads)
        self.fc = nn.Linear(out_features*num_heads, 3)

    def forward(self, data):
        """Return logits of shape (num_graphs, 3)."""
        x = data.x.float()
        edge_index = data.edge_index
        if edge_index is None:
            # No bonds supplied: an empty edge list lets GATConv attend over the
            # self-loops it adds by default instead of failing on None.
            edge_index = torch.empty((2, 0), dtype=torch.long, device=x.device)
        else:
            edge_index = edge_index.long()
        x = F.elu(self.conv1(x, edge_index))
        x = F.dropout(x, training=self.training)
        # Pool node embeddings to one row per graph: per-node logits cannot be
        # compared with the per-graph (num_graphs, 3) labels used in training.
        x = global_mean_pool(x, getattr(data, 'batch', None))
        return self.fc(x)

# Define the model
# Small GAT for the random-feature playground (20 input features, 3 heads of 8).
model = GAT(in_features=20, out_features=8, num_heads=3)
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

# Train the model
def train_model(training_set):
    """Run one pass over `training_set`, taking an optimizer step per item.

    Uses the notebook-level `model` and `optimizer`. Each item's `y` holds 3
    binary labels per graph.
    """
    model.train()
    for data in training_set:
        optimizer.zero_grad()
        output = model(data)
        # BCE-with-logits needs floating-point targets; labels are often int64.
        target = data.y.view(-1, 3).float()
        loss = F.binary_cross_entropy_with_logits(output, target)
        loss.backward()
        optimizer.step()

# Predict the label
def predict_label(data):
    """Return sigmoid probabilities (numpy array) from the notebook-level `model`."""
    model.eval()
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        logits = model(data)
    return torch.sigmoid(logits).cpu().numpy()

# Create a sample training_set of torch_geometric.data objects
# Ten random 5-node graphs (no edges) all labelled [0, 0, 1].
training_set = [
    Data(idx=-1, edge_index=None, x=torch.rand(5, 20), edge_attr=None, y=torch.tensor([0, 0, 1]))
    for _ in range(10)
]
train_model(training_set)

# Predict on one more random, unlabelled graph.
data_to_predict = Data(idx=-1, edge_index=None, x=torch.rand(5, 20), edge_attr=None)
predicted_label = predict_label(data_to_predict)
print(predicted_label)

In [None]:
import numpy as np

import rdkit
from rdkit import Chem

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F

print('import ok!')

In [None]:
# helper
# torch version of np unpackbits
#https://gist.github.com/vadimkantorov/30ea6d278bc492abf6ad328c6965613a

def tensor_dim_slice(tensor, dim, dim_slice):
    """Index `tensor` with `dim_slice` along axis `dim` (negative dims count from the end).

    Equivalent to ``tensor[:, ..., :, dim_slice]`` with `dim` leading full slices.
    """
    axis = dim + tensor.dim() if dim < 0 else dim
    leading = (slice(None),) * axis
    return tensor[leading + (dim_slice,)]

# @torch.jit.script
def packshape(shape, dim: int = -1, mask: int = 0b00000001, dtype=torch.uint8, pack=True):
    """Compute the shape after bit-packing (or unpacking) along `dim`.

    Each element of `dtype` holds `nibbles` fields of `nibble` bits; the field
    width comes from `mask` (0b1 -> 1, 0b11 -> 2, 0b1111 -> 4, 0b11111111 -> 8).

    Returns:
        (new_shape, nibbles, nibble). With pack=True the size along `dim`
        becomes ceil(size / nibbles); with pack=False it becomes size * nibbles.

    Raises:
        AssertionError: for an unsupported `dtype` / `mask` combination.
    """
    dim = dim if dim >= 0 else dim + len(shape)
    bits = (8 if dtype is torch.uint8 else 16 if dtype is torch.int16
            else 32 if dtype is torch.int32 else 64 if dtype is torch.int64 else 0)
    nibble = (1 if mask == 0b00000001 else 2 if mask == 0b00000011
              else 4 if mask == 0b00001111 else 8 if mask == 0b11111111 else 0)
    # bits = torch.iinfo(dtype).bits # does not JIT compile
    # nibble == 0 marks an unsupported mask; test it first so we never evaluate `bits % 0`.
    assert nibble > 0 and nibble <= bits and bits % nibble == 0
    nibbles = bits // nibble
    if pack:
        # Exact integer ceil division (the old math.ceil call needed an import that was missing).
        packed = -(-shape[dim] // nibbles)
        shape = shape[:dim] + (packed,) + shape[1 + dim:]
    else:
        shape = shape[:dim] + (shape[dim] * nibbles,) + shape[1 + dim:]
    return shape, nibbles, nibble

# @torch.jit.script
def F_unpackbits(tensor, dim: int = -1, mask: int = 0b00000001, shape=None, out=None, dtype=torch.uint8):
	"""Unpack bit fields of `tensor` along `dim` (a torch analogue of np.unpackbits).

	Each element of `tensor` holds `nibbles` fields of `nibble` bits (field width
	chosen by `mask`, see packshape). By default the result has `nibbles` times as
	many entries along `dim`, each holding one field value (0/1 for the default mask).

	Args:
		tensor: packed integer tensor, e.g. uint8 bytes produced by np.packbits.
		dim: axis to unpack; negative values count from the end.
		mask: field mask selecting the field width.
		shape: optional output shape; defaults to the fully unpacked shape.
		out: optional preallocated output; must have shape `shape`.
		dtype: dtype of the freshly allocated output when `out` is None.

	Returns:
		`out`, filled with the unpacked field values.
	"""
	dim = dim if dim >= 0 else dim + tensor.dim()
	shape_, nibbles, nibble = packshape(tensor.shape, dim=dim, mask=mask, dtype=tensor.dtype, pack=False)
	shape = shape if shape is not None else shape_
	out = out if out is not None else torch.empty(shape, device=tensor.device, dtype=dtype)
	assert out.shape == shape

	# Fast path - always taken for the default shape (size * nibbles along dim).
	# Shifts run from (nibbles - 1) * nibble down to 0, so each element's fields come
	# out most-significant first, matching np.unpackbits' default bitorder='big'.
	if shape[dim] % nibbles == 0:
		shift = torch.arange((nibbles - 1) * nibble, -1, -nibble, dtype=torch.uint8, device=tensor.device)
		# Reshape the shifts so they broadcast along the new axis inserted after `dim`.
		shift = shift.view(nibbles, *((1,) * (tensor.dim() - dim - 1)))
		return torch.bitwise_and((tensor.unsqueeze(1 + dim) >> shift).view_as(out), mask, out=out)

	else:
		# Only reachable with a caller-supplied shape/out whose size along dim is not a
		# multiple of nibbles.
		# NOTE(review): field i is extracted with shift nibble * i (least-significant first)
		# and written to positions i, i + nibbles, ... - the opposite field order of the
		# fast path above. Verify before relying on this path.
		for i in range(nibbles):
			shift = nibble * i
			sliced_output = tensor_dim_slice(out, dim, slice(i, None, nibbles))
			sliced_input = tensor.narrow(dim, 0, sliced_output.shape[dim])
			torch.bitwise_and(sliced_input >> shift, mask, out=sliced_output)
	return out

class dotdict(dict):
    """A dict whose items can also be read, set and deleted as attributes.

    Reading a missing attribute raises AttributeError (so hasattr/getattr
    defaults work); `del d.missing` raises KeyError, like `del d['missing']`.
    """

    def __getattr__(self, name):
        # Only invoked when normal lookup fails, so dict methods are unaffected.
        if name in self:
            return self[name]
        raise AttributeError(name)

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


print('helper ok!')

# mol-to-graph code adapted from
# https://github.com/LiZhang30/GPCNDTA/blob/main/utils/DrugGraph.py

# Bytes per atom / per bond after np.packbits: 69 atom bits -> 9 bytes, 6 bond bits -> 1 byte.
PACK_NODE_DIM = 9
PACK_EDGE_DIM = 1
# Widths after unpacking back to bits (including the zero padding in the last byte).
NODE_DIM, EDGE_DIM = PACK_NODE_DIM * 8, PACK_EDGE_DIM * 8

def one_of_k_encoding(x, allowable_set, allow_unk=False):
    """One-hot encode `x` as a list of booleans over `allowable_set`.

    With allow_unk=True an unknown value is mapped onto the last entry of
    `allowable_set`; otherwise an unknown value raises Exception.
    """
    known = x in allowable_set
    if not known and not allow_unk:
        raise Exception(f'input {x} not in allowable set{allowable_set}!!!')
    if not known:
        x = allowable_set[-1]
    return [x == candidate for candidate in allowable_set]


#Get features of an atom (one-hot encoding:)
'''
	1.atom element: 44 dimensions (ATOM_SYMBOL; there is no 'Unknown' slot)
	2.the atom's hybridization: 5 dimensions
	3.degree of atom: 6 dimensions
	4.total number of H bound to atom: 6 dimensions
	5.implicit valence of atom: 6 dimensions
	6.whether the atom is on ring: 1 dimension
	7.whether the atom is aromatic: 1 dimension
	Total: 69 dimensions (bits), packed by np.packbits into 9 bytes (PACK_NODE_DIM)
'''

# Element symbols recognised by get_atom_feature, in one-hot order (44 entries).
# There is no 'Unknown' slot: any other element makes one_of_k_encoding raise.
ATOM_SYMBOL = (
    'C N O S F Si P Cl Br Mg '
    'Na Ca Fe As Al I B V K Tl '
    'Yb Sb Sn Ag Pd Co Se Ti Zn H '
    'Li Ge Cu Au Ni Cd In Mn Zr Cr '
    'Pt Hg Pb Dy'
).split()
# Hybridisation states accepted by get_atom_feature, in one-hot order. Any other
# state (e.g. SP3D2) makes one_of_k_encoding raise.
HYBRIDIZATION_TYPE = [
    getattr(Chem.rdchem.HybridizationType, name)
    for name in ('S', 'SP', 'SP2', 'SP3', 'SP3D')
]

def get_atom_feature(atom):
    """One-hot encode an RDKit atom and pack the bits with np.packbits.

    Bit layout: element symbol (ATOM_SYMBOL), hybridisation (HYBRIDIZATION_TYPE),
    degree 0-5, total H count 0-5, implicit valence 0-5, in-ring flag, aromatic
    flag. Values outside these sets raise (one_of_k_encoding is called without
    allow_unk).
    """
    small_counts = [0, 1, 2, 3, 4, 5]
    bits = list(one_of_k_encoding(atom.GetSymbol(), ATOM_SYMBOL))
    bits.extend(one_of_k_encoding(atom.GetHybridization(), HYBRIDIZATION_TYPE))
    for getter in (atom.GetDegree, atom.GetTotalNumHs, atom.GetImplicitValence):
        bits.extend(one_of_k_encoding(getter(), small_counts))
    bits.append(atom.IsInRing())
    bits.append(atom.GetIsAromatic())
    return np.packbits(bits)


#Get features of an edge (one-hot encoding)
'''
	1.single/double/triple/aromatic: 4 dimensions
	2.whether the bond is conjugated: 1 dimension
	3.whether the bond is on ring: 1 dimension
	Total: 6 dimensions (bits), packed by np.packbits into 1 byte (PACK_EDGE_DIM)
'''

def get_bond_feature(bond):
    """Encode an RDKit bond as 6 boolean flags packed into one uint8 byte.

    Flag order (most significant bit first under np.packbits' default
    bitorder): single, double, triple, aromatic, conjugated, in ring.
    """
    bond_kind = bond.GetBondType()
    kinds = (
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    )
    flags = [bond_kind == kind for kind in kinds]
    flags += [bond.GetIsConjugated(), bond.IsInRing()]
    return np.packbits(flags)


def smile_to_graph(smiles):
    """Featurise a SMILES string into a compact, bit-packed graph tuple.

    Returns:
        (N, edge, node_feature, edge_feature):
        N            -- number of atoms;
        edge         -- (E, 2) directed atom-index pairs, both directions of every
                        bond, ordered by source then target index; uint8 while all
                        indices fit (N <= 256), otherwise int32;
        node_feature -- (N, PACK_NODE_DIM) uint8, packed get_atom_feature bits;
        edge_feature -- (E, PACK_EDGE_DIM) uint8, packed get_bond_feature bits.

    Raises:
        ValueError: if RDKit cannot parse `smiles`.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f'RDKit could not parse SMILES: {smiles!r}')
    N = mol.GetNumAtoms()
    node_feature = []
    edge_feature = []
    edge = []
    for i in range(N):
        atom_i = mol.GetAtomWithIdx(i)
        node_feature.append(get_atom_feature(atom_i))
        # Walk only bonded neighbours (sorted, so the edge order matches the old
        # all-pairs scan) instead of probing every (i, j) pair: O(E), not O(N^2).
        for j in sorted(nbr.GetIdx() for nbr in atom_i.GetNeighbors()):
            edge.append([i, j])
            edge_feature.append(get_bond_feature(mol.GetBondBetweenAtoms(i, j)))

    # np.stack raises on empty lists, so atom-/bond-less molecules get empty arrays.
    if node_feature:
        node_feature = np.stack(node_feature)
    else:
        node_feature = np.zeros((0, PACK_NODE_DIM), dtype=np.uint8)
    if edge_feature:
        edge_feature = np.stack(edge_feature)
    else:
        edge_feature = np.zeros((0, PACK_EDGE_DIM), dtype=np.uint8)
    # uint8 keeps stored graphs small but cannot hold atom indices >= 256
    # (NumPy wraps or raises depending on version), so widen for large molecules.
    edge_dtype = np.uint8 if N <= 256 else np.int32
    edge = np.array(edge, dtype=edge_dtype).reshape(-1, 2)
    return N, edge, node_feature, edge_feature

def to_pyg_format(N, edge, node_feature, edge_feature):
    """Wrap one smile_to_graph tuple as an unlabelled PyG Data object (idx=-1).

    `N` is accepted for tuple-unpacking convenience but not stored.
    """
    fields = dict(
        idx=-1,
        edge_index=torch.from_numpy(edge.T).int(),
        x=torch.from_numpy(node_feature).byte(),
        edge_attr=torch.from_numpy(edge_feature).byte(),
    )
    return Data(**fields)

#debug one example
# Smoke test: featurise one SMILES string and wrap it as a PyG Data object.
g = to_pyg_format(*smile_to_graph(smiles="C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1"))
print(g)
# NOTE(review): presumably the [Dy] placeholder of the raw dataset SMILES is replaced by
# C before featurisation; the example string above already contains no [Dy].
print('[Dy] is replaced by C !!')
print('smile_to_graph() ok!')

In [None]:
print('DONE')

In [None]:
import numpy as np

import rdkit
from rdkit import Chem

import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from torch_geometric.nn import MessagePassing, global_mean_pool
from torch_scatter import scatter

import torch
import torch.nn as nn
import torch.nn.functional as F

print('import ok!')

In [None]:
pip install rdkit

In [None]:
# Device used by the MPNN smoke test and the example training loop below.
DEVICE='cpu'

# All comments were removed here to keep it short; refer to the original source this
# MPNNModel was adapted from for detailed code comments.
class MPNNLayer(MessagePassing):
    """One MPNN message-passing layer with MLP message and update functions.

    message(i <- j) = mlp_msg([h_i, h_j, edge_attr])
    aggregate       = scatter of messages onto target nodes with reduction `aggr`
    update          = mlp_upd([h, aggregated messages])

    Args:
        emb_dim: node embedding width (input and output of the layer).
        edge_dim: edge feature width.
        aggr: reduction passed to torch_scatter.scatter (e.g. 'add').
    """

    def __init__(self, emb_dim=64, edge_dim=4, aggr='add'):
        super().__init__(aggr=aggr)

        self.emb_dim = emb_dim
        self.edge_dim = edge_dim
        self.mlp_msg = nn.Sequential(
            nn.Linear(2 * emb_dim + edge_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )
        self.mlp_upd = nn.Sequential(
            nn.Linear(2 * emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU(),
            nn.Linear(emb_dim, emb_dim), nn.BatchNorm1d(emb_dim), nn.ReLU()
        )

    def forward(self, h, edge_index, edge_attr):
        """Propagate messages; returns updated node states of width emb_dim."""
        out = self.propagate(edge_index, h=h, edge_attr=edge_attr)
        return out

    def message(self, h_i, h_j, edge_attr):
        msg = torch.cat([h_i, h_j, edge_attr], dim=-1)
        return self.mlp_msg(msg)

    def aggregate(self, inputs, index, dim_size=None):
        # Without dim_size the scatter output only has max(index) + 1 rows, so a
        # trailing node with no incoming edge would make update()'s torch.cat with
        # `h` fail. PyG fills dim_size with the number of target nodes.
        return scatter(inputs, index, dim=self.node_dim, dim_size=dim_size, reduce=self.aggr)

    def update(self, aggr_out, h):
        upd_out = torch.cat([h, aggr_out], dim=-1)
        return self.mlp_upd(upd_out)

    def __repr__(self) -> str:
        return (f'{self.__class__.__name__}(emb_dim={self.emb_dim}, aggr={self.aggr})')


class MPNNModel(nn.Module):
    """Graph encoder: unpack bit features, stack residual MPNN layers, mean-pool.

    Args:
        num_layers: number of MPNNLayer blocks.
        emb_dim: hidden node embedding width.
        in_dim: node feature width after F_unpackbits (e.g. NODE_DIM).
        edge_dim: edge feature width after F_unpackbits (e.g. EDGE_DIM).
        out_dim: unused; kept for signature compatibility.
    """
    def __init__(self, num_layers=3, emb_dim=64, in_dim=11, edge_dim=4, out_dim=1):
        super().__init__()

        self.lin_in = nn.Linear(in_dim, emb_dim)

        # Stack of MPNN layers
        self.convs = torch.nn.ModuleList(
            [MPNNLayer(emb_dim, edge_dim, aggr='add') for _ in range(num_layers)]
        )

        self.pool = global_mean_pool

    def forward(self, data):  # PyG.Data - batch of PyG graphs
        """Return one pooled embedding per graph, shape (num_graphs, emb_dim)."""
        h = self.lin_in(F_unpackbits(data.x, -1).float())

        # Loop-invariant: cast indices and unpack edge features once, not per layer.
        edge_index = data.edge_index.long()
        edge_attr = F_unpackbits(data.edge_attr, -1).float()
        for conv in self.convs:
            h = h + conv(h, edge_index, edge_attr)  # (n, d) -> (n, d), residual

        return self.pool(h, data.batch)

# our prediction model here !!!!
class Net(nn.Module):
    """Binding predictor: MPNN graph encoder followed by an MLP with 3 logits.

    `output_type` controls forward(): with 'loss' it adds 'bce_loss' (needs
    batch['bind']); with 'infer' it adds sigmoid probabilities under 'bind'.
    """
    def __init__(self):
        super().__init__()

        self.output_type = ['infer', 'loss']

        graph_dim = 96
        self.smile_encoder = MPNNModel(
            in_dim=NODE_DIM, edge_dim=EDGE_DIM, emb_dim=graph_dim, num_layers=4,
        )

        # Three Linear -> ReLU -> Dropout stages, then a 3-logit output layer.
        head = []
        for width_in, width_out in ((graph_dim, 1024), (1024, 1024), (1024, 512)):
            head.extend([
                nn.Linear(width_in, width_out),
                nn.ReLU(inplace=True),
                nn.Dropout(0.1),
            ])
        head.append(nn.Linear(512, 3))
        self.bind = nn.Sequential(*head)

    def forward(self, batch):
        logits = self.bind(self.smile_encoder(batch['graph']))

        result = {}
        if 'loss' in self.output_type:
            result['bce_loss'] = F.binary_cross_entropy_with_logits(
                logits.float(), batch['bind'].float()
            )
        if 'infer' in self.output_type:
            result['bind'] = torch.sigmoid(logits)
        return result

print('Create Model OK!')

In [None]:
#debug: make some dummy data and run

def _make_dummy_graph(idx, node_dim, edge_dim):
    """Build one random PyG graph whose node/edge features are np.packbits-packed.

    5-9 nodes; random (i, j) pairs re-ordered so that i <= j, with (0, 1) forced
    in, self-loops dropped and duplicates removed.
    """
    N = np.random.randint(5, 10)
    E = np.random.randint(3, N * (N - 1))
    edge_index = np.stack([
        np.random.choice(N, E, replace=True),
        np.random.choice(N, E, replace=True),
    ]).T
    edge_index = np.sort(edge_index)  # sorts each (i, j) row so that i <= j
    edge_index = edge_index[edge_index[:, 0].argsort()]
    edge_index[0] = [0, 1]  # guarantee at least one edge survives the filters below
    edge_index = edge_index[edge_index[:, 0] != edge_index[:, 1]]
    edge_index = np.unique(edge_index, axis=0)

    E = len(edge_index)
    edge_index = np.ascontiguousarray(edge_index.T)

    return Data(
        idx=idx,
        edge_index=torch.from_numpy(edge_index).int(),
        x=torch.from_numpy(np.packbits(np.random.choice(2, (N, node_dim)), -1)).byte(),
        edge_attr=torch.from_numpy(np.packbits(np.random.choice(2, (E, edge_dim)), -1)).byte(),
    )


def run_check_net():
    """Smoke-test Net on a random batch of 3 dummy graphs; print shapes and loss."""
    batch_size = 3
    data = [_make_dummy_graph(b, NODE_DIM, EDGE_DIM) for b in range(batch_size)]

    loader = DataLoader(data, batch_size=batch_size)
    graph = next(iter(loader))
    batch = dotdict(
        graph=graph.to(DEVICE),
        bind=torch.from_numpy(np.random.choice(2, (batch_size, 3))).float().to(DEVICE),
    )

    net = Net().to(DEVICE)

    # CUDA autocast does nothing without a GPU (and only warns), so enable it conditionally.
    with torch.no_grad():
        with torch.cuda.amp.autocast(enabled=torch.cuda.is_available()):
            output = net(batch)

    print('batch')
    for k, v in batch.items():
        if k == 'graph':
            print(f'{k:>32} : {graph} ')
        else:
            print(f'{k:>32} : {v.shape} ')

    print('output')
    for k, v in output.items():
        if 'loss' not in k:
            print(f'{k:>32} : {v.shape} ')
    print('loss')
    for k, v in output.items():
        if 'loss' in k:
            print(f'{k:>32} : {v.item()} ')


run_check_net()
print('model ok!')

In [None]:
def my_collate(graph, index=None, device='cpu'):
    """Collate raw graph tuples into one batched, PyG-like dotdict.

    Args:
        graph: indexable of (N, edge, node_feature, edge_feature) tuples as made
            by smile_to_graph (edge holds (E, 2) per-molecule atom indices).
        index: which entries of `graph` to include, in order (default: all).
        device: torch device for the returned tensors.

    Returns:
        dotdict with x and edge_attr (concatenated features), edge_index of shape
        (2, E_total) int64 with per-graph node offsets applied, batch (graph
        position of every node, int64), and idx (the `index` used).
    """
    if index is None:
        index = np.arange(len(graph)).tolist()
    batch = dotdict(
        x=[],
        edge_index=[],
        edge_attr=[],
        batch=[],
        idx=index
    )
    offset = 0
    for b, i in enumerate(index):
        N, edge, node_feature, edge_feature = graph[i]
        batch.x.append(node_feature)
        batch.edge_attr.append(edge_feature)
        # Cast to int64 before adding the offset (stored indices may be uint8), and
        # reshape so bond-less molecules (possibly a 1-D empty array) still concatenate.
        batch.edge_index.append(np.asarray(edge).astype(np.int64).reshape(-1, 2) + offset)
        batch.batch += N * [b]
        offset += N
    batch.x = torch.from_numpy(np.concatenate(batch.x)).to(device)
    batch.edge_attr = torch.from_numpy(np.concatenate(batch.edge_attr)).to(device)
    edge_index = np.ascontiguousarray(np.concatenate(batch.edge_index).T)
    batch.edge_index = torch.from_numpy(edge_index).to(device)
    batch.batch = torch.as_tensor(batch.batch, dtype=torch.long, device=device)
    return batch


#.... more code here ....

while epoch<cfg.num_epoch:
    shuffled_idx = train_idx.copy()
    np.random.shuffle(shuffled_idx)
    for t, index in enumerate(np.arange(0,len(shuffled_idx),cfg.train_batch_size)):
        index = shuffled_idx[index:index+cfg.train_batch_size]
        if len(index)!=cfg.train_batch_size: continue #drop last

        B = len(index)
        batch = dotdict(
            graph = my_collate(train_graph,index,device='cuda'),
            bind = torch.from_numpy(train_bind[index]).float().cuda(),
        )

        net.train()
        net.output_type = ['loss', 'infer']

        with torch.cuda.amp.autocast(enabled=cfg.is_amp):
            output = net(batch)  #data_parallel(net,batch) #
            bce_loss = output['bce_loss']

In [None]:
from multiprocessing import Pool
from tqdm import tqdm
import gc
from torch_geometric.loader import DataLoader as PyGDataLoader

def to_pyg_list(graph):
    """Replace each (N, edge, node_feature, edge_feature) tuple in `graph` with an
    unlabelled PyG Data object (idx = position), in place, and return the list.

    NOTE: this redefinition shadows the earlier, labelled `to_pyg_list`.
    """
    for position in tqdm(range(len(graph))):
        _, edge, node_feature, edge_feature = graph[position]
        graph[position] = Data(
            idx=position,
            edge_index=torch.from_numpy(edge.T).int(),
            x=torch.from_numpy(node_feature).byte(),
            edge_attr=torch.from_numpy(edge_feature).byte(),
        )
    return graph


train_smiles=[ #replace [Dy] with C
    "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1",
    "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1",
    "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1",
    "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1",
    "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1",
    "C#CCOc1ccc(CNc2nc(NCc3cccc(Br)n3)nc(N[C@@H](CC#C)CC(=O)NC)n2)cc1",
]
# NOTE(review): this toy label array overwrites the real `train_bind` loaded from
# train.bind.npz earlier in the notebook; later cells indexing `train_bind` see these rows.
train_bind = np.array([
    [0,0,0],[1,0,0],[0,1,0],[0,0,1],[1,1,0],[0,0,0],
])
num_train = len(train_smiles)
# Never start more worker processes than there are molecules or CPUs (was a fixed 64).
n_workers = max(1, min(num_train, os.cpu_count() or 1))
with Pool(processes=n_workers) as pool:
    train_graph = list(tqdm(pool.imap(smile_to_graph, train_smiles), total=num_train))

train_graph = to_pyg_list(train_graph)
train_loader = PyGDataLoader(train_graph, batch_size=3, shuffle=True)



In [None]:
## example training loop
# Mixed precision (autocast + GradScaler) only applies on CUDA; with DEVICE='cpu'
# the hard-coded enabled=True just produced warnings or mixed CPU/CUDA scaler state.
use_amp = torch.cuda.is_available() and str(DEVICE).startswith('cuda')
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
net = Net()
net.to(DEVICE)

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, net.parameters()), lr=0.001)

num_epoch = 10
epoch = 0
iteration = 0
while epoch < num_epoch:
    for graph_batch in train_loader:
        # `idx` carries each graph's row in `train_bind`, so shuffled batches stay aligned.
        index = graph_batch.idx.tolist()
        batch = dotdict(
            graph=graph_batch.to(DEVICE),
            bind=torch.from_numpy(train_bind[index]).to(DEVICE),
        )

        net.train()
        net.output_type = ['loss', 'infer']
        with torch.cuda.amp.autocast(enabled=use_amp):
            output = net(batch)
            bce_loss = output['bce_loss']

        optimizer.zero_grad()
        scaler.scale(bce_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        torch.clear_autocast_cache()
        print(epoch, iteration, bce_loss.item())
        iteration += 1

    epoch += 1

In [None]:
# duckdb was used here but never imported anywhere in the notebook (NameError on a fresh kernel).
import duckdb

train_path ='/kaggle/input/leash-BELKA/train.parquet'

# Balanced random sample: 30,000 rows with binds = 0 plus 30,000 rows with binds = 1.
con = duckdb.connect()
try:
    df_train = con.query(f"""(SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE binds = 0
                        ORDER BY random()
                        LIMIT 30000)
                        UNION ALL
                        (SELECT *
                        FROM parquet_scan('{train_path}')
                        WHERE binds = 1
                        ORDER BY random()
                        LIMIT 30000)""").df()
finally:
    # Release the connection even if the query fails.
    con.close()
print('DONE!')

In [None]:
# Preview the first rows of the sampled training frame.
df_train.head()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GATConv
from torch_geometric.nn import global_max_pool as gmp

# GAT  model
class GATNet(torch.nn.Module):
    """Two-layer GAT encoder + global max pooling + MLP head (one logit row per graph).

    NOTE: this redefinition shadows the earlier GATNet class in the notebook.

    Args:
        num_features: width of ``data.x`` (cast to float in forward).
        n_output: number of logits produced per graph.
        n_filters, embed_dim: unused; kept for signature compatibility.
        output_dim: channels of the second GAT layer and of the pooled embedding.
        dropout: probability used for GAT attention dropout, the node-feature
            dropout in forward(), and the MLP dropout.
    """
    def __init__(self, num_features=112, n_output=1,n_filters=32, embed_dim=128, output_dim=128, dropout=0.2):
        super(GATNet, self).__init__()

        # graph layers: 10 concatenated heads -> num_features * 10 channels
        self.gcn1 = GATConv(num_features, num_features, heads=10, dropout=dropout)
        self.gcn2 = GATConv(num_features * 10, output_dim, dropout=dropout)
        self.fc_g1 = nn.Linear(output_dim, output_dim)

        # combined layers
        self.fc1 = nn.Linear(output_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.out = nn.Linear(32, n_output)

        # activation and regularization
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """Return logits of shape (num_graphs, n_output)."""
        x = data.x.float()
        # Graphs built in this notebook store edge_index as int32; message passing
        # indexes/scatters with it, so cast to int64 (as MPNNModel does).
        edge_index = data.edge_index.long()
        batch = data.batch
        # Honour the constructor's dropout rate (previously hard-coded to 0.2).
        p = self.dropout.p

        x = F.dropout(x, p=p, training=self.training)
        x = F.elu(self.gcn1(x, edge_index))
        x = F.dropout(x, p=p, training=self.training)
        x = self.relu(self.gcn2(x, edge_index))
        x = gmp(x, batch)  # global max pooling
        x = self.relu(self.fc_g1(x))

        # add some dense layers
        xc = self.dropout(self.relu(self.fc1(x)))
        xc = self.dropout(self.relu(self.fc2(xc)))
        return self.out(xc)