<a href="https://colab.research.google.com/github/oklaja/Virtual-Drug-Screening-for-Covid-19/blob/main/train_GNNs_covid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install rdkit

# !curl -L bit.ly/rdkit-colab | tar xz -C /
!pip install rdkit-pypi

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rdkit-pypi
  Downloading rdkit_pypi-2022.3.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (22.7 MB)
[K     |████████████████████████████████| 22.7 MB 82.6 MB/s 
Installing collected packages: rdkit-pypi
Successfully installed rdkit-pypi-2022.3.3


In [None]:
# Install compatible versions of torch geometric and its dependencies

import torch
print(torch.__version__)
!pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-{torch.__version__}.html
!pip install torch_geometric

1.11.0+cu113
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-1.11.0+cu113.html
Collecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_scatter-2.0.9-cp37-cp37m-linux_x86_64.whl (7.9 MB)
[K     |████████████████████████████████| 7.9 MB 34.9 MB/s 
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-1.11.0%2Bcu113/torch_sparse-0.6.13-cp37-cp37m-linux_x86_64.whl (3.5 MB)
[K     |████████████████████████████████| 3.5 MB 64.8 MB/s 
Installing collected packages: torch-sparse, torch-scatter
Successfully installed torch-scatter-2.0.9 torch-sparse-0.6.13
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch_geometric
  Downloading torch_geometric-2.0.4.tar.gz (407 kB)
[K     |████████████████████████████████| 407 kB 36.5 MB/s 
Building wheels for collected packages: torch-geo

In [None]:
# Mount your google drive

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!ls /content/drive/MyDrive/AI/datasets/covid/datasetF/test/

processed


# Train Attentive FP

In [None]:
# Import dependencies

import os.path as osp
from math import sqrt
 
import torch
import torch.nn.functional as F
from rdkit import Chem
from torch_geometric.loader import DataLoader
from torch_geometric.datasets import MoleculeNet
from torch_geometric.nn.models import AttentiveFP
import torch
from sklearn import metrics

# Build a Torch Geometric datasets from CSV with SMILES and binary Activity columns
# Generate features and save Graph datasets for training, validation and testing
 
class GenFeatures(object):
    """Pre-transform that attaches AttentiveFP-style atom and bond features.

    Calling an instance parses ``data.smiles`` with RDKit, overwrites
    ``data.x``, ``data.edge_index`` and ``data.edge_attr`` on the object it
    was given, and returns that same object.
    """

    def __init__(self):
        # Atom symbols that get their own one-hot slot. The trailing 'other'
        # entry is the catch-all slot for every symbol not listed here.
        self.symbols = [
            'B', 'C', 'N', 'O', 'F', 'Si', 'P', 'S', 'Cl', 'As', 'Se', 'Br',
            'Te', 'I', 'At', 'H', 'Na', 'K', 'Al', 'other'
        ]

        # Hybridization states; again the trailing 'other' is the catch-all.
        self.hybridizations = [
            Chem.rdchem.HybridizationType.S,
            Chem.rdchem.HybridizationType.SP,
            Chem.rdchem.HybridizationType.SP2,
            Chem.rdchem.HybridizationType.SP3,
            Chem.rdchem.HybridizationType.SP3D,
            Chem.rdchem.HybridizationType.SP3D2,
            'other',
        ]

        # Bond stereo tags with a one-hot slot (there is no catch-all here).
        self.stereos = [
            Chem.rdchem.BondStereo.STEREONONE,
            Chem.rdchem.BondStereo.STEREOANY,
            Chem.rdchem.BondStereo.STEREOZ,
            Chem.rdchem.BondStereo.STEREOE,
        ]

    @staticmethod
    def _one_hot_or_other(value, choices):
        """One-hot encode ``value`` over ``choices``.

        A value that is not listed is mapped to the last slot, which both
        vocabularies using this helper reserve as 'other'.
        """
        encoding = [0.] * len(choices)
        slot = choices.index(value) if value in choices else len(choices) - 1
        encoding[slot] = 1.
        return encoding

    def __call__(self, data):
        """Compute the features for one molecule.

        Args:
            data: object with a ``smiles`` attribute; it is modified in place.

        Returns:
            The same ``data`` object with ``x`` (44 values per atom),
            ``edge_index`` (both directions of every bond) and ``edge_attr``
            (10 values per directed edge) set.

        Raises:
            IndexError: if an atom has a degree above 5 or more than 4
                hydrogens (the one-hot lists have 6 and 5 slots).
            ValueError: if a bond stereo tag is not in ``self.stereos`` or a
                CIP code is neither 'R' nor 'S'.
        """
        # Generate AttentiveFP features according to Table 1.
        mol = Chem.MolFromSmiles(data.smiles)

        xs = []
        for atom in mol.GetAtoms():
            # Unlisted symbols / hybridizations fall into the 'other' slot
            # instead of raising ValueError.
            symbol = self._one_hot_or_other(atom.GetSymbol(), self.symbols)
            degree = [0.] * 6
            degree[atom.GetDegree()] = 1.
            formal_charge = atom.GetFormalCharge()
            radical_electrons = atom.GetNumRadicalElectrons()
            hybridization = self._one_hot_or_other(atom.GetHybridization(),
                                                   self.hybridizations)
            aromaticity = 1. if atom.GetIsAromatic() else 0.
            hydrogens = [0.] * 5
            hydrogens[atom.GetTotalNumHs()] = 1.
            chirality = 1. if atom.HasProp('_ChiralityPossible') else 0.
            chirality_type = [0.] * 2
            if atom.HasProp('_CIPCode'):
                chirality_type[['R', 'S'].index(atom.GetProp('_CIPCode'))] = 1.

            # 20 + 6 + 1 + 1 + 7 + 1 + 5 + 1 + 2 = 44 values per atom.
            x = torch.tensor(symbol + degree + [formal_charge] +
                             [radical_electrons] + hybridization +
                             [aromaticity] + hydrogens + [chirality] +
                             chirality_type)
            xs.append(x)

        data.x = torch.stack(xs, dim=0)

        edge_indices = []
        edge_attrs = []
        for bond in mol.GetBonds():
            # Every bond is stored twice, once per direction.
            edge_indices += [[bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]]
            edge_indices += [[bond.GetEndAtomIdx(), bond.GetBeginAtomIdx()]]

            bond_type = bond.GetBondType()
            single = 1. if bond_type == Chem.rdchem.BondType.SINGLE else 0.
            double = 1. if bond_type == Chem.rdchem.BondType.DOUBLE else 0.
            triple = 1. if bond_type == Chem.rdchem.BondType.TRIPLE else 0.
            aromatic = 1. if bond_type == Chem.rdchem.BondType.AROMATIC else 0.
            conjugation = 1. if bond.GetIsConjugated() else 0.
            ring = 1. if bond.IsInRing() else 0.
            stereo = [0.] * 4
            stereo[self.stereos.index(bond.GetStereo())] = 1.

            # 6 + 4 = 10 values per edge, shared by both directions.
            edge_attr = torch.tensor(
                [single, double, triple, aromatic, conjugation, ring] + stereo)

            edge_attrs += [edge_attr, edge_attr]

        if len(edge_attrs) == 0:
            # Molecule without bonds: keep the tensors correctly shaped.
            data.edge_index = torch.zeros((2, 0), dtype=torch.long)
            data.edge_attr = torch.zeros((0, 10), dtype=torch.float)
        else:
            data.edge_index = torch.tensor(edge_indices).t().contiguous()
            data.edge_attr = torch.stack(edge_attrs, dim=0)

        return data
 
 
# Vocabularies for the categorical atom features. Molecule.process() stores
# each feature as the position of its value inside the matching list.
x_map = dict(
    atomic_num=list(range(119)),
    chirality=[
        'CHI_UNSPECIFIED',
        'CHI_TETRAHEDRAL_CW',
        'CHI_TETRAHEDRAL_CCW',
        'CHI_OTHER',
    ],
    degree=list(range(11)),
    formal_charge=list(range(-5, 7)),
    num_hs=list(range(9)),
    num_radical_electrons=list(range(5)),
    hybridization=[
        'UNSPECIFIED',
        'S',
        'SP',
        'SP2',
        'SP3',
        'SP3D',
        'SP3D2',
        'OTHER',
    ],
    is_aromatic=[False, True],
    is_in_ring=[False, True],
)

# Vocabularies for the categorical bond features, encoded the same way.
e_map = dict(
    bond_type=['misc', 'SINGLE', 'DOUBLE', 'TRIPLE', 'AROMATIC'],
    stereo=[
        'STEREONONE',
        'STEREOZ',
        'STEREOE',
        'STEREOCIS',
        'STEREOTRANS',
        'STEREOANY',
    ],
    is_conjugated=[False, True],
)
 
 
import torch
from torch_geometric.data import (InMemoryDataset, Data)
import re
 
class Molecule(InMemoryDataset):
    """In-memory graph dataset built from a CSV file of SMILES and targets.

    The processing follows the MoleculeNet dataset of PyTorch Geometric
    (`"MoleculeNet: A Benchmark for Molecular Machine Learning"
    <https://arxiv.org/abs/1703.00564>`_): every row becomes one graph whose
    node and edge features are the categorical indices defined by ``x_map``
    and ``e_map``. A ``pre_transform`` may replace those features before the
    graphs are saved.

    Args:
        root_dir (string): Root directory; the processed file is stored in
            ``<root_dir>/processed/data.pt``.
        name (string): Path of the CSV file. It is opened exactly as given,
            i.e. relative to the current working directory, not to
            ``raw_dir``.
        smi_idx (integer): index of smiles column
        target_idx (integer): index of target column
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    def __init__(self, root_dir, name, smi_idx, target_idx, transform=None,
                 pre_transform=None, pre_filter=None):

        self.root_dir = root_dir
        self.name = name
        self.smi_idx = smi_idx
        self.target_idx = target_idx
        # The base class gets root=None; raw_dir and processed_dir are
        # derived from root_dir by the properties below instead.
        super(Molecule, self).__init__(None, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_dir(self):
        return osp.join(self.root_dir, 'raw')

    @property
    def processed_dir(self):
        return osp.join(self.root_dir, 'processed')

    @property
    def raw_file_names(self):
        return f'{self.name}'

    @property
    def processed_file_names(self):
        return 'data.pt'

    def process(self):
        """Convert the CSV rows to graphs and save them to the processed file.

        Rows that cannot be parsed or featurized are skipped; a single
        warning reports how many were dropped.

        Raises:
            ValueError: if no row produced a graph.
        """
        import warnings

        from rdkit import Chem

        with open(self.raw_file_names, 'r') as f:
            # Drop the header row and blank lines. Slicing with [1:] (not
            # [1:-1]) keeps the last record when the file has no trailing
            # newline; a trailing empty string is removed by the filter.
            dataset = [x for x in f.read().split('\n')[1:] if len(x) > 0]

        data_list = []
        skipped = 0
        for line in dataset:
            try:
                line = re.sub(r'\".*\"', '', line)  # Replace ".*" strings.
                line = line.split(',')

                smiles = line[self.smi_idx]
                ys = float(line[self.target_idx])
                y = torch.tensor(ys, dtype=torch.float).view(1, -1)

                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    skipped += 1
                    continue

                xs = []
                for atom in mol.GetAtoms():
                    x = []
                    x.append(x_map['atomic_num'].index(atom.GetAtomicNum()))
                    x.append(x_map['chirality'].index(str(atom.GetChiralTag())))
                    x.append(x_map['degree'].index(atom.GetTotalDegree()))
                    x.append(x_map['formal_charge'].index(atom.GetFormalCharge()))
                    x.append(x_map['num_hs'].index(atom.GetTotalNumHs()))
                    x.append(x_map['num_radical_electrons'].index(
                        atom.GetNumRadicalElectrons()))
                    x.append(x_map['hybridization'].index(
                        str(atom.GetHybridization())))
                    x.append(x_map['is_aromatic'].index(atom.GetIsAromatic()))
                    x.append(x_map['is_in_ring'].index(atom.IsInRing()))
                    xs.append(x)

                x = torch.tensor(xs, dtype=torch.long).view(-1, 9)

                edge_indices, edge_attrs = [], []
                for bond in mol.GetBonds():
                    i = bond.GetBeginAtomIdx()
                    j = bond.GetEndAtomIdx()

                    e = []
                    e.append(e_map['bond_type'].index(str(bond.GetBondType())))
                    e.append(e_map['stereo'].index(str(bond.GetStereo())))
                    e.append(e_map['is_conjugated'].index(bond.GetIsConjugated()))

                    # Store both directions of the bond.
                    edge_indices += [[i, j], [j, i]]
                    edge_attrs += [e, e]

                edge_index = torch.tensor(edge_indices)
                edge_index = edge_index.t().to(torch.long).view(2, -1)
                edge_attr = torch.tensor(edge_attrs, dtype=torch.long).view(-1, 3)

                # Sort indices.
                if edge_index.numel() > 0:
                    perm = (edge_index[0] * x.size(0) + edge_index[1]).argsort()
                    edge_index, edge_attr = edge_index[:, perm], edge_attr[perm]

                data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y,
                            smiles=smiles)

                if self.pre_filter is not None and not self.pre_filter(data):
                    continue

                if self.pre_transform is not None:
                    data = self.pre_transform(data)

                data_list.append(data)

            except Exception:
                # Best effort: a malformed row or a molecule the featurizers
                # cannot encode is dropped. Unlike a bare ``except`` this
                # still lets KeyboardInterrupt / SystemExit through.
                skipped += 1

        if skipped:
            warnings.warn(
                f'{self.name}: skipped {skipped} of {len(dataset)} rows that '
                f'could not be parsed or featurized.')

        if not data_list:
            raise ValueError(f'{self.name}: no usable rows, nothing to save.')

        # Collate and save once, after all rows are processed (this used to
        # run inside the loop and rewrote the whole file for every row).
        torch.save(self.collate(data_list), self.processed_paths[0])

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.name, len(self))



# Build or load the datasets


def _load_split(root_dir, name):
    """Return the shuffled Molecule dataset for one split.

    SMILES are read from column 1 and the target from column 2; GenFeatures
    is applied as the pre-transform.
    """
    split = Molecule(root_dir=root_dir,
                     name=name,
                     smi_idx=1,
                     target_idx=2,
                     pre_transform=GenFeatures())
    return split.shuffle()


train_dataset = _load_split('/content/drive/MyDrive/AI/datasets/covid/datasetF/',
                            './trainC.csv')
test_dataset = _load_split('/content/drive/MyDrive/AI/datasets/covid/datasetF/test/',
                           './testC.csv')
val_dataset = _load_split('/content/drive/MyDrive/AI/datasets/covid/datasetF/val/',
                          './valC.csv')


# Initiate DataLoader object

# Options shared by all three loaders; only the training loader reshuffles.
_loader_options = dict(batch_size=4096, num_workers=0)

train_loader = DataLoader(train_dataset, shuffle=True, **_loader_options)
val_loader = DataLoader(val_dataset, **_loader_options)
test_loader = DataLoader(test_dataset, **_loader_options)


# Pick the compute device and build the Attentive FP model.
#
# NOTE(review): `device` and `model` are defined here so that the notebook
# runs from a fresh kernel; they were previously only available as leftover
# kernel state. The sizes are taken from the printed model (44 -> 250 -> 1,
# four atom layers) and from GenFeatures (10 bond features). num_timesteps
# and dropout cannot be read from that printout: num_timesteps=3 is inferred
# from the "t3" in the checkpoint file name and dropout=0.2 is a
# placeholder -- TODO confirm both against the original run.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = AttentiveFP(in_channels=44, hidden_channels=250, out_channels=1,
                    edge_dim=10, num_layers=4, num_timesteps=3,
                    dropout=0.2).to(device)
print(model)


# Set optimizer parameters

optimizer = torch.optim.Adam(model.parameters(), lr=10**-3.5,
                             weight_decay=10**-2.9)


# WEIGHT is the factor passed as `weight` to the BCE loss in train()/test().
# print_every: inside train(), evaluate and print every `print_every`
# training batches; 0 disables this, leaving only the per-epoch report.

WEIGHT = 300.0
print_every = 0  # 0: report once per epoch only

def train():
    """Train the global ``model`` for one pass over ``train_loader``.

    Uses the notebook globals ``model``, ``optimizer``, ``device``,
    ``train_loader``, ``WEIGHT`` and ``print_every`` (plus ``epoch``,
    ``val_loader`` and ``test_loader`` when ``print_every`` is non-zero).

    Returns:
        tuple: ``(loss, roc_auc)``; ``loss`` is the mean of the batch losses
        weighted by the number of graphs per batch, ``roc_auc`` is computed
        from all predictions collected during the epoch.
    """
    model.train()

    # The weight is constant, so build it and move it to the device once.
    # NOTE(review): a one-element weight is applied to every sample, so it
    # rescales the whole loss rather than re-weighting one class -- confirm
    # that this is the intended "loss scaling".
    weight = torch.tensor([WEIGHT], dtype=torch.float).to(device)

    total_pred = []
    total_y = []
    total_loss = total_examples = 0
    for i, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        out = out.sigmoid()

        loss = F.binary_cross_entropy(out, data.y, weight=weight)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
        total_pred.append(out.detach().cpu().numpy())
        total_y.append(data.y.detach().cpu().numpy())

        if print_every != 0 and i % print_every == 0:
            val_loss, val_roc = test(val_loader)
            test_loss, test_roc = test(test_loader)
            # test() puts the model in eval mode; switch back so that the
            # remaining batches of this epoch are trained with dropout.
            model.train()
            fpr, tpr, threshold = metrics.roc_curve(np.concatenate(total_y, 0),
                                                    np.concatenate(total_pred, 0))
            roc_auc = metrics.auc(fpr, tpr)
            print(f'Epoch: {epoch:03d} | LOSS: Train: {total_loss / total_examples:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {roc_auc:.3f} Val: {val_roc:.3f}, '
                  f'Test: {test_roc:.3f}')

    fpr, tpr, threshold = metrics.roc_curve(np.concatenate(total_y, 0),
                                            np.concatenate(total_pred, 0))
    roc_auc = metrics.auc(fpr, tpr)
    return total_loss / total_examples, roc_auc
 

from sklearn import metrics
import numpy as np
from tqdm import tqdm

@torch.no_grad()
def test(loader):
    """Evaluate the global ``model`` on every batch of ``loader``.

    The model is switched to eval mode and is left in that mode.

    Returns:
        tuple: ``(loss, roc_auc)``; ``loss`` is the mean of the batch losses
        weighted by the number of graphs per batch, ``roc_auc`` is computed
        from the sigmoid outputs of all batches.
    """
    model.eval()

    loss_sum, graph_count = 0, 0
    predictions, labels = [], []

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)
        probs = torch.sigmoid(logits)
        predictions.append(probs.detach().cpu().numpy())
        labels.append(batch.y.detach().cpu().numpy())

        scale = torch.tensor([WEIGHT], dtype=torch.float).to(device)
        batch_loss = F.binary_cross_entropy(probs, batch.y, weight=scale)
        loss_sum += float(batch_loss) * batch.num_graphs
        graph_count += batch.num_graphs

    y_true = np.concatenate(labels, 0)
    y_score = np.concatenate(predictions, 0)
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
    return loss_sum / graph_count, metrics.auc(fpr, tpr)
    

AttentiveFP(
  (lin1): Linear(in_features=44, out_features=250, bias=True)
  (atom_convs): ModuleList(
    (0): GATEConv()
    (1): GATConv(250, 250, heads=1)
    (2): GATConv(250, 250, heads=1)
    (3): GATConv(250, 250, heads=1)
  )
  (atom_grus): ModuleList(
    (0): GRUCell(250, 250)
    (1): GRUCell(250, 250)
    (2): GRUCell(250, 250)
    (3): GRUCell(250, 250)
  )
  (mol_conv): GATConv(250, 250, heads=1)
  (mol_gru): GRUCell(250, 250)
  (lin2): Linear(in_features=250, out_features=1, bias=True)
)


In [None]:
# Number of graphs in the train / validation / test splits, in that order.
split_sizes = [len(split) for split in (train_dataset, val_dataset, test_dataset)]
print(*split_sizes)

234617 33517 67036


In [None]:
# Train the Attentive FP

# One record per epoch with the losses and ROC AUCs of all three splits.
history = []

for epoch in range(1, 200):
    # Training metrics come from the training pass itself (dropout active).
    train_loss, train_roc = train()
    val_loss, val_roc = test(val_loader)
    test_loss, test_roc = test(test_loader)
    # To report the training metrics without dropout instead, use:
    # train_loss, train_roc = test(train_loader)

    summary = (
        f'Epoch: {epoch:03d} | LOSS: Train: {train_loss:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {train_roc:.3f}, Val: {val_roc:.3f}, '
        f'Test: {test_roc:.3f}'
    )
    print(summary)

    history.append(dict(
        epoch=epoch,
        train_loss=train_loss,
        val_loss=val_loss,
        test_loss=test_loss,
        train_roc=train_roc,
        val_roc=val_roc,
        test_roc=test_roc,
    ))

    # Uncomment to save the model after every epoch: first as a dictionary
    # with the optimizer state (to continue training), then as the whole
    # model (inference only).
    # Note: the second save uses osp.splitext rather than PATH.strip(".pt"),
    # because str.strip removes a set of characters, not a suffix.

    # PATH = f"/content/drive/MyDrive/AI/covid-1/model250/covid-250-l4-t3--1-epoch{epoch:03d}.pt"
    # torch.save({
    #         'epoch': epoch,
    #         'model_state_dict': model.state_dict(),
    #         'optimizer_state_dict': optimizer.state_dict(),
    #         'loss': train_loss,
    #         }, PATH)
    # torch.save(model, osp.splitext(PATH)[0] + "model1.pt")

Epoch: 001 | LOSS: Train: 24.2889 Val: 7.6985 Test: 7.6301  | ROC: Train: 0.440, Val: 0.271, Test: 0.364
Epoch: 002 | LOSS: Train: 7.3749 Val: 6.6244 Test: 6.8560  | ROC: Train: 0.536, Val: 0.710, Test: 0.631
Epoch: 003 | LOSS: Train: 7.0393 Val: 6.5624 Test: 6.8216  | ROC: Train: 0.614, Val: 0.720, Test: 0.634
Epoch: 004 | LOSS: Train: 6.9041 Val: 6.5502 Test: 6.8289  | ROC: Train: 0.631, Val: 0.717, Test: 0.644
Epoch: 005 | LOSS: Train: 6.7473 Val: 6.4012 Test: 6.7633  | ROC: Train: 0.669, Val: 0.732, Test: 0.654
Epoch: 006 | LOSS: Train: 6.7424 Val: 6.3948 Test: 6.7736  | ROC: Train: 0.664, Val: 0.737, Test: 0.669
Epoch: 007 | LOSS: Train: 6.6201 Val: 6.3267 Test: 6.7040  | ROC: Train: 0.696, Val: 0.718, Test: 0.682
Epoch: 008 | LOSS: Train: 6.4670 Val: 6.2125 Test: 6.5775  | ROC: Train: 0.710, Val: 0.730, Test: 0.693
Epoch: 009 | LOSS: Train: 6.4344 Val: 6.1438 Test: 6.5513  | ROC: Train: 0.710, Val: 0.731, Test: 0.702
Epoch: 010 | LOSS: Train: 6.4025 Val: 6.2087 Test: 6.5415  | RO

# Train the GraphConv Network

In [None]:
# !pip install class-resolver  # Might be needed depending on the torch version
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, GraphSAGE
from torch_geometric.nn import global_mean_pool, global_add_pool

class Net(torch.nn.Module):
    """Graph classifier: three GraphConv layers, mean pooling, linear head.

    ``forward`` returns the raw output of the final linear layer (one value
    per graph); the sigmoid is applied by the caller.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # Fixed seed so the layers below are always initialised identically.
        torch.manual_seed(12345)
        width = hidden_channels
        self.conv1 = GraphConv(44, width)
        self.conv2 = GraphConv(width, width)
        self.conv3 = GraphConv(width, width)
        self.lin = Linear(width, 1)

    def forward(self, x, edge_index, batch):
        # Node embeddings: ReLU after the first two convolutions only.
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))
        h = self.conv3(h, edge_index)

        # Readout: average the node embeddings of each graph
        # -> [batch_size, hidden_channels].
        pooled = global_mean_pool(h, batch)

        # Classifier head with dropout (active in training mode only).
        pooled = F.dropout(pooled, p=0.2, training=self.training)
        return self.lin(pooled)


model = Net(hidden_channels=200)
print(model)

Net(
  (conv1): GraphConv(44, 200)
  (conv2): GraphConv(200, 200)
  (conv3): GraphConv(200, 200)
  (lin): Linear(in_features=200, out_features=1, bias=True)
)


In [None]:
# Settings for the GraphConv run: loss weight and in-epoch reporting interval
# (0 = report once per epoch only).
WEIGHT, print_every = 230, 0

# Module.to() returns the module itself, so rebinding the name is a no-op.
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**-3.5,
    weight_decay=10**-2.9,
)

def train():
    """Train the global ``model`` for one pass over ``train_loader``.

    This variant calls the model without edge attributes. It uses the
    notebook globals ``model``, ``optimizer``, ``device``, ``train_loader``,
    ``WEIGHT`` and ``print_every`` (plus ``epoch``, ``val_loader`` and
    ``test_loader`` when ``print_every`` is non-zero).

    Returns:
        tuple: ``(loss, roc_auc)``; ``loss`` is the mean of the batch losses
        weighted by the number of graphs per batch, ``roc_auc`` is computed
        from all predictions collected during the epoch.
    """
    model.train()

    # The weight is constant, so build it and move it to the device once.
    # NOTE(review): a one-element weight is applied to every sample, so it
    # rescales the whole loss rather than re-weighting one class -- confirm
    # that this is the intended "loss scaling".
    weight = torch.tensor([WEIGHT], dtype=torch.float).to(device)

    total_pred = []
    total_y = []
    total_loss = total_examples = 0
    for i, data in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch)
        out = out.sigmoid()

        loss = F.binary_cross_entropy(out, data.y, weight=weight)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * data.num_graphs
        total_examples += data.num_graphs
        total_pred.append(out.detach().cpu().numpy())
        total_y.append(data.y.detach().cpu().numpy())

        if print_every != 0 and i % print_every == 0:
            val_loss, val_roc = test(val_loader)
            test_loss, test_roc = test(test_loader)
            # test() puts the model in eval mode; switch back so that the
            # remaining batches of this epoch are trained with dropout.
            model.train()
            fpr, tpr, threshold = metrics.roc_curve(np.concatenate(total_y, 0), np.concatenate(total_pred, 0))
            roc_auc = metrics.auc(fpr, tpr)
            print(f'Epoch: {epoch:03d} | LOSS: Train: {total_loss / total_examples:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {roc_auc:.3f} Val: {val_roc:.3f}, '
                  f'Test: {test_roc:.3f}')

    fpr, tpr, threshold = metrics.roc_curve(np.concatenate(total_y, 0), np.concatenate(total_pred, 0))
    roc_auc = metrics.auc(fpr, tpr)
    return total_loss / total_examples, roc_auc
 

from sklearn import metrics
import numpy as np
from tqdm import tqdm

@torch.no_grad()
def test(loader):
    """Evaluate the global ``model`` (called without edge attributes).

    The model is switched to eval mode and is left in that mode.

    Returns:
        tuple: ``(loss, roc_auc)``; ``loss`` is the mean of the batch losses
        weighted by the number of graphs per batch, ``roc_auc`` is computed
        from the sigmoid outputs of all batches.
    """
    model.eval()

    loss_sum, graph_count = 0, 0
    predictions, labels = [], []

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.batch)
        probs = torch.sigmoid(logits)
        predictions.append(probs.detach().cpu().numpy())
        labels.append(batch.y.detach().cpu().numpy())

        scale = torch.tensor([WEIGHT], dtype=torch.float).to(device)
        batch_loss = F.binary_cross_entropy(probs, batch.y, weight=scale)
        loss_sum += float(batch_loss) * batch.num_graphs
        graph_count += batch.num_graphs

    y_true = np.concatenate(labels, 0)
    y_score = np.concatenate(predictions, 0)
    fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
    return loss_sum / graph_count, metrics.auc(fpr, tpr)
 
# First GraphConv training run: epochs 1 to 19, one report line per epoch.
for epoch in range(1, 20):
    train_loss, train_roc = train()
    val_loss, val_roc = test(val_loader)
    test_loss, test_roc = test(test_loader)
    # To report the training metrics without dropout instead, use:
    # train_loss, train_roc = test(train_loader)
    summary = (
        f'Epoch: {epoch:03d} | LOSS: Train: {train_loss:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {train_roc:.3f}, Val: {val_roc:.3f}, '
        f'Test: {test_roc:.3f}'
    )
    print(summary)

Epoch: 001 | LOSS: Train: 13.7756 Val: 5.7125 Test: 5.7591  | ROC: Train: 0.487, Val: 0.440, Test: 0.442
Epoch: 002 | LOSS: Train: 5.5904 Val: 5.5062 Test: 5.5534  | ROC: Train: 0.489, Val: 0.459, Test: 0.460
Epoch: 003 | LOSS: Train: 5.4968 Val: 5.3784 Test: 5.4370  | ROC: Train: 0.522, Val: 0.530, Test: 0.540
Epoch: 004 | LOSS: Train: 5.3182 Val: 5.2966 Test: 5.3624  | ROC: Train: 0.629, Val: 0.626, Test: 0.614
Epoch: 005 | LOSS: Train: 5.2299 Val: 5.2272 Test: 5.2941  | ROC: Train: 0.672, Val: 0.663, Test: 0.649
Epoch: 006 | LOSS: Train: 5.1241 Val: 5.2195 Test: 5.2895  | ROC: Train: 0.717, Val: 0.682, Test: 0.676
Epoch: 007 | LOSS: Train: 5.0542 Val: 5.1010 Test: 5.1778  | ROC: Train: 0.732, Val: 0.704, Test: 0.700
Epoch: 008 | LOSS: Train: 4.9864 Val: 5.0393 Test: 5.1258  | ROC: Train: 0.744, Val: 0.718, Test: 0.718
Epoch: 009 | LOSS: Train: 4.9270 Val: 4.9738 Test: 5.0805  | ROC: Train: 0.752, Val: 0.727, Test: 0.727
Epoch: 010 | LOSS: Train: 4.8398 Val: 5.0000 Test: 5.1380  | RO

In [None]:
# Continue training the same model; the epoch counter restarts at 1 and runs
# to 179.
for epoch in range(1, 180):
    train_loss, train_roc = train()
    val_loss, val_roc = test(val_loader)
    test_loss, test_roc = test(test_loader)
    # To report the training metrics without dropout instead, use:
    # train_loss, train_roc = test(train_loader)
    summary = (
        f'Epoch: {epoch:03d} | LOSS: Train: {train_loss:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {train_roc:.3f}, Val: {val_roc:.3f}, '
        f'Test: {test_roc:.3f}'
    )
    print(summary)

Epoch: 001 | LOSS: Train: 4.5085 Val: 4.6167 Test: 4.8653  | ROC: Train: 0.814, Val: 0.784, Test: 0.768
Epoch: 002 | LOSS: Train: 4.4815 Val: 4.5844 Test: 4.8378  | ROC: Train: 0.818, Val: 0.779, Test: 0.767
Epoch: 003 | LOSS: Train: 4.4854 Val: 4.6032 Test: 4.8825  | ROC: Train: 0.818, Val: 0.782, Test: 0.765
Epoch: 004 | LOSS: Train: 4.4725 Val: 4.7274 Test: 5.0042  | ROC: Train: 0.815, Val: 0.793, Test: 0.771
Epoch: 005 | LOSS: Train: 4.4776 Val: 4.5497 Test: 4.8346  | ROC: Train: 0.814, Val: 0.791, Test: 0.768
Epoch: 006 | LOSS: Train: 4.4796 Val: 4.5761 Test: 4.8700  | ROC: Train: 0.815, Val: 0.794, Test: 0.769
Epoch: 007 | LOSS: Train: 4.4421 Val: 4.5141 Test: 4.8180  | ROC: Train: 0.816, Val: 0.790, Test: 0.766
Epoch: 008 | LOSS: Train: 4.3750 Val: 4.5394 Test: 4.8269  | ROC: Train: 0.829, Val: 0.796, Test: 0.770
Epoch: 009 | LOSS: Train: 4.3941 Val: 4.5273 Test: 4.8296  | ROC: Train: 0.821, Val: 0.793, Test: 0.771
Epoch: 010 | LOSS: Train: 4.3764 Val: 4.5068 Test: 4.8049  | ROC

In [None]:
# Save the model

PATH = "/content/drive/MyDrive/AI/covid-1/covid-GraphConv-3x--200-d02--1.pt"

# Checkpoint with the optimizer state, so that training can be continued.
torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
        }, PATH)

# Whole pickled model next to it. The extension is removed with splitext:
# str.strip(".pt") strips any run of the characters '.', 'p' and 't' from
# both ends of the string, so it only worked by luck for this file name.
torch.save(model, osp.splitext(PATH)[0] + "-model1.pt")

# Train the 'Conv-GRU' Network

In [None]:
# !pip install class-resolver
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, GraphSAGE
from torch_geometric.nn import global_mean_pool, global_add_pool
from torch.nn import GRUCell

class Net(torch.nn.Module):
    """Graph classifier mixing GraphConv layers with GRU cells.

    Layer order: two GraphConv layers, three GRU cells, a third GraphConv,
    sum pooling and a linear head. ``forward`` returns the raw output of the
    linear layer (one value per graph); the sigmoid is applied by the caller.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # Fixed seed so the layers below are always initialised identically.
        torch.manual_seed(12345)
        width = hidden_channels
        self.conv1 = GraphConv(44, width)
        self.conv2 = GraphConv(width, width)
        self.gru1 = GRUCell(width, width)
        self.gru2 = GRUCell(width, width)
        self.gru3 = GRUCell(width, width)
        self.conv3 = GraphConv(width, width)
        self.lin = Linear(width, 1)

    def forward(self, x, edge_index, batch):
        # Node embeddings from the first two convolutions.
        h = F.relu(self.conv1(x, edge_index))
        h = F.relu(self.conv2(h, edge_index))

        # Each GRU cell receives the node embeddings as both its input and
        # its hidden state, followed by a ReLU.
        for cell in (self.gru1, self.gru2, self.gru3):
            h = F.relu(cell(h, h))

        h = self.conv3(h, edge_index)

        # Readout: sum the node embeddings of each graph
        # -> [batch_size, hidden_channels].
        pooled = global_add_pool(h, batch)

        # Classifier head with dropout (active in training mode only).
        pooled = F.dropout(pooled, p=0.2, training=self.training)
        return self.lin(pooled)

# Build the Conv-GRU model, give it a fresh optimizer and train it with the
# train()/test() functions defined for the GraphConv run.
model = Net(hidden_channels=200)
print(model)

WEIGHT, print_every = 230, 0

# Module.to() returns the module itself, so rebinding the name is a no-op.
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**-3.5,
    weight_decay=10**-2.9,
)

for epoch in range(1, 200):
    train_loss, train_roc = train()
    val_loss, val_roc = test(val_loader)
    test_loss, test_roc = test(test_loader)
    # To report the training metrics without dropout instead, use:
    # train_loss, train_roc = test(train_loader)
    summary = (
        f'Epoch: {epoch:03d} | LOSS: Train: {train_loss:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {train_roc:.3f}, Val: {val_roc:.3f}, '
        f'Test: {test_roc:.3f}'
    )
    print(summary)


Net(
  (conv1): GraphConv(44, 200)
  (conv2): GraphConv(200, 200)
  (gru1): GRUCell(200, 200)
  (gru2): GRUCell(200, 200)
  (gru3): GRUCell(200, 200)
  (conv3): GraphConv(200, 200)
  (lin): Linear(in_features=200, out_features=1, bias=True)
)
Epoch: 001 | LOSS: Train: 19.0634 Val: 5.5373 Test: 5.5719  | ROC: Train: 0.491, Val: 0.440, Test: 0.453
Epoch: 002 | LOSS: Train: 5.5375 Val: 5.4603 Test: 5.4997  | ROC: Train: 0.492, Val: 0.456, Test: 0.473
Epoch: 003 | LOSS: Train: 5.4619 Val: 5.3818 Test: 5.4223  | ROC: Train: 0.534, Val: 0.506, Test: 0.526
Epoch: 004 | LOSS: Train: 5.3507 Val: 5.3946 Test: 5.4467  | ROC: Train: 0.601, Val: 0.607, Test: 0.607
Epoch: 005 | LOSS: Train: 5.2396 Val: 5.1992 Test: 5.2697  | ROC: Train: 0.671, Val: 0.669, Test: 0.651
Epoch: 006 | LOSS: Train: 5.1737 Val: 5.1930 Test: 5.2733  | ROC: Train: 0.701, Val: 0.683, Test: 0.667
Epoch: 007 | LOSS: Train: 5.1160 Val: 5.1030 Test: 5.1886  | ROC: Train: 0.711, Val: 0.695, Test: 0.684
Epoch: 008 | LOSS: Train: 5.

In [None]:
PATH = "/content/drive/MyDrive/AI/covid-1/covid-GraphConv-3x-3GRU--200-d02--1.pt"

# Checkpoint with the optimizer state, so that training can be continued.
torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
        }, PATH)

# Whole pickled model next to it. The extension is removed with splitext:
# str.strip(".pt") strips any run of the characters '.', 'p' and 't' from
# both ends of the string, so it only worked by luck for this file name.
torch.save(model, osp.splitext(PATH)[0] + "-model1.pt")

# Train the GraphSage Network

In [None]:
class Net(torch.nn.Module):
    """Graph classifier: 4-layer GraphSAGE, sum pooling, linear head.

    ``forward`` returns the raw output of the final linear layer (one value
    per graph); the sigmoid is applied by the caller. A stack of three
    GraphConv layers and mean pooling were earlier alternatives here.
    """

    def __init__(self, hidden_channels):
        super().__init__()
        # Fixed seed so the layers below are always initialised identically.
        torch.manual_seed(12345)
        self.conv1 = GraphSAGE(44, hidden_channels, num_layers=4)
        self.lin = Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch):
        # Node embeddings from the GraphSAGE stack, followed by a ReLU.
        h = F.relu(self.conv1(x, edge_index))

        # Readout: sum the node embeddings of each graph.
        pooled = global_add_pool(h, batch)

        # Classifier head with dropout (active in training mode only).
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return self.lin(pooled)

# Build the GraphSAGE model, give it a fresh optimizer and train it with the
# train()/test() functions defined for the GraphConv run.
model = Net(hidden_channels=250)
print(model)

WEIGHT, print_every = 230, 0

# Module.to() returns the module itself, so rebinding the name is a no-op.
model = model.to(device)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=10**-3.5,
    weight_decay=10**-2.9,
)

for epoch in range(1, 200):
    train_loss, train_roc = train()
    val_loss, val_roc = test(val_loader)
    test_loss, test_roc = test(test_loader)
    # To report the training metrics without dropout instead, use:
    # train_loss, train_roc = test(train_loader)
    summary = (
        f'Epoch: {epoch:03d} | LOSS: Train: {train_loss:.4f} Val: {val_loss:.4f} Test: {test_loss:.4f}  | ROC: Train: {train_roc:.3f}, Val: {val_roc:.3f}, '
        f'Test: {test_roc:.3f}'
    )
    print(summary)


Net(
  (conv1): GraphSAGE(44, 250, num_layers=4)
  (lin): Linear(in_features=250, out_features=1, bias=True)
)
Epoch: 001 | LOSS: Train: 14.7445 Val: 7.1723 Test: 6.8976  | ROC: Train: 0.384, Val: 0.271, Test: 0.364
Epoch: 002 | LOSS: Train: 7.4244 Val: 7.0724 Test: 6.7874  | ROC: Train: 0.358, Val: 0.299, Test: 0.404
Epoch: 003 | LOSS: Train: 7.0092 Val: 6.8574 Test: 6.6405  | ROC: Train: 0.424, Val: 0.360, Test: 0.449
Epoch: 004 | LOSS: Train: 6.7787 Val: 6.5842 Test: 6.4652  | ROC: Train: 0.474, Val: 0.431, Test: 0.489
Epoch: 005 | LOSS: Train: 6.6524 Val: 6.4293 Test: 6.3346  | ROC: Train: 0.508, Val: 0.455, Test: 0.502
Epoch: 006 | LOSS: Train: 6.3423 Val: 6.4724 Test: 6.4509  | ROC: Train: 0.560, Val: 0.490, Test: 0.510
Epoch: 007 | LOSS: Train: 6.2522 Val: 6.1250 Test: 6.1633  | ROC: Train: 0.570, Val: 0.521, Test: 0.535
Epoch: 008 | LOSS: Train: 6.0589 Val: 5.9030 Test: 6.0261  | ROC: Train: 0.602, Val: 0.557, Test: 0.564
Epoch: 009 | LOSS: Train: 5.9684 Val: 5.7980 Test: 5.957

In [None]:
# NOTE(review): the file name says "200-d02", while the model saved here is
# built with hidden_channels=250 and dropout p=0.5 -- confirm the name.
PATH = "/content/drive/MyDrive/AI/covid-1/covid-GraphSage-4lay--200-d02--1.pt"

# Checkpoint with the optimizer state, so that training can be continued.
torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
        }, PATH)

# Whole pickled model next to it. The extension is removed with splitext:
# str.strip(".pt") strips any run of the characters '.', 'p' and 't' from
# both ends of the string, so it only worked by luck for this file name.
torch.save(model, osp.splitext(PATH)[0] + "-model1.pt")