# Models

### Load Data

In [2]:
import pickle

# Folder holding the dataset files; `directory` is also used when reading the CSV below
directory = '../Data/biohack/'
load_path = f'{directory}datalist_fungal_test.pkl'

# Deserialize the pickled object into `data_list`.
# NOTE: pickle can execute arbitrary code while loading - only open trusted files.
with open(load_path, mode='rb') as f:
    data_list = pickle.load(f)

In [None]:
# import smiles from csv

import pandas as pd

# Read the CSV that sits in the same folder as the pickled data
df = pd.read_csv(f'{directory}fungal_test.csv')
# Keep the 'SMILES' column on its own for later use
smiles = df['SMILES']

## GCN

In [10]:
import torch
import torch.nn as nn
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

'''
GCN_base and GCN_base_FP
- GCN_base: graph embedding followed by a final classification layer
- GCN_base_FP: graph + fingerprints embedding followed by a final classification layer
'''


class GCN(nn.Module):
    """Graph Convolutional Network producing one output value per graph.

    Three ``GCNConv`` layers (ReLU after the first two) embed the nodes, the
    node embeddings are mean-pooled per graph, optionally concatenated with a
    per-graph fingerprint vector, and passed through dropout and a single
    linear layer.

    Args:
        args: Configuration mapping. Keys read by the constructor:

            * ``'num_node_features'`` - size of each input node feature vector.
            * ``'hidden_channels'`` - output width of every GCN layer.
            * ``'dropout'`` - dropout probability applied before the final
              linear layer.
            * ``'model'`` - ``'GCN_base_FP'`` builds the graph + fingerprint
              variant; any other value builds the graph-only variant.
            * ``'fp_dim'`` - fingerprint length; read only when ``'model'`` is
              ``'GCN_base_FP'``.
    """

    def __init__(self, args):
        super().__init__()
        # Seed before the layers are created so their initial weights are
        # reproducible. NOTE: this reseeds torch's global RNG as a side effect.
        torch.manual_seed(12345)

        num_node_features = args['num_node_features']
        hidden_channels = args['hidden_channels']
        num_classes = 1  # single output unit per graph
        self.dropout = args['dropout']

        if args['model'] == 'GCN_base_FP':
            fp_dim = args['fp_dim']
        else:  # graph-only model: no extra inputs to the final layer
            fp_dim = 0

        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        # The final layer sees the pooled graph embedding plus, for the
        # fingerprint variant, the fingerprint itself.
        self.lin = Linear(hidden_channels + fp_dim, num_classes)

    def forward(self, x, edge_index, batch, fp=None):
        """Compute one output value per graph.

        Args:
            x: Node feature matrix passed to the first ``GCNConv`` layer.
            edge_index: Graph connectivity passed to every ``GCNConv`` layer.
            batch: Node-to-graph assignment passed to ``global_mean_pool``.
            fp: Optional fingerprint tensor. It is reshaped to
                ``[number_of_graphs, -1]``, so it may be given flattened.
                Required for the fingerprint variant and must be omitted for
                the graph-only variant.

        Returns:
            The output of the final linear layer, one row per graph.

        Raises:
            ValueError: If ``fp`` is missing although the model was built with
                a fingerprint input, is supplied to a graph-only model, or has
                a per-graph width different from the configured one.
        """
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # Whatever the final layer expects beyond the graph embedding must be
        # supplied by the fingerprint. Checking here gives a readable error
        # instead of an opaque matmul shape failure inside self.lin.
        expected_fp_dim = self.lin.in_features - x.shape[1]

        if fp is None:
            if expected_fp_dim:
                raise ValueError(
                    f"model expects a fingerprint of width {expected_fp_dim} "
                    "per graph, but fp was not provided"
                )
        else:
            if not expected_fp_dim:
                raise ValueError(
                    "fp was provided, but the model was built without a "
                    "fingerprint input (args['model'] != 'GCN_base_FP')"
                )
            # reshape fp to batch_size x fp_dim
            fp = fp.reshape(x.shape[0], -1)
            if fp.shape[1] != expected_fp_dim:
                raise ValueError(
                    f"fingerprint width {fp.shape[1]} does not match the "
                    f"configured fp_dim {expected_fp_dim}"
                )
            # Match the embedding dtype so integer/bool/float64 fingerprints
            # do not change the dtype fed to the linear layer.
            fp = fp.to(dtype=x.dtype)
            # concatenate graph embeddings with fingerprint
            x = torch.cat([x, fp], dim=1)

        # 3. Apply a final classifier
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin(x)

        return x



In [16]:
# Build the model from `args` and run one forward pass on a single sample

model = GCN(args)

# Third entry of the loaded data; no fingerprint is passed, so fp stays None
data = data_list[2]
out = model(x=data.x, edge_index=data.edge_index, batch=data.batch)
out

tensor([[-0.1081]], grad_fn=<AddmmBackward0>)