In [None]:
class GraphFeaturizer(Featurizer):  # assume such a Featurizer base class is available
    """Turn a DataFrame of SMILES strings into PyTorch Geometric ``Data`` graphs.

    For every row, the label is read from column ``self.y_column`` and the
    SMILES from ``self.smiles_col`` (both attributes are expected to come from
    the ``Featurizer`` base class, which is defined elsewhere).

    Each returned ``Data`` object has:
        x           float tensor with one row of atom features per atom
                    (column layout documented in ``_atom_features``)
        edge_index  long tensor of shape (2, 2 * n_bonds); every bond is listed
                    in both directions so message-passing layers such as
                    GCNConv treat the molecule as an undirected graph
        y           float tensor of shape (1,) holding the row's label

    Raises:
        ValueError: if RDKit cannot parse a SMILES string.
    """

    def __call__(self, df):
        data_list = []
        for i, row in df.iterrows():
            y = row[self.y_column]
            smiles = row[self.smiles_col]
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles returns None on a parse failure; report the
                # offending row instead of failing later with an AttributeError.
                raise ValueError(f"Could not parse SMILES {smiles!r} in row {i!r}")

            nodes = np.array([self._atom_features(atom) for atom in mol.GetAtoms()])
            edge_index = self._edge_index(mol)

            data_list.append(Data(
                x=torch.FloatTensor(nodes),
                edge_index=torch.LongTensor(edge_index),
                y=torch.FloatTensor([y]),
            ))
        return data_list

    @staticmethod
    def _edge_index(mol):
        """Return a (2, 2 * n_bonds) int64 array listing each bond in both directions."""
        pairs = []
        for bond in mol.GetBonds():
            begin, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            pairs.extend(((begin, end), (end, begin)))
        # reshape(-1, 2) keeps a 2-D (0, 2) shape for bond-less molecules
        # (e.g. a single atom), so the transpose is (2, 0) rather than (0,).
        return np.array(pairs, dtype=np.int64).reshape(-1, 2).T

    @staticmethod
    def _atom_features(atom):
        """Return the feature list for one RDKit atom.

        The list has 49 entries, assuming the one-hot helpers (defined
        elsewhere) return one entry per element of ``range(11)``.
        MyAttentionModule slices the same column ranges:

            0:11   atomic number         one_of_k_encoding_unk over range(11)
            11:22  degree                one_of_k_encoding over range(11)
            22:33  implicit valence      one_of_k_encoding_unk over range(11)
            33     is aromatic
            34:45  total number of Hs    one_of_k_encoding_unk over range(11)
            45     number of implicit Hs
            46     formal charge
            47     number of radical electrons
            48     is aromatic (duplicate of column 33, kept so the vector
                   stays 49 wide)
        """
        return (
            one_of_k_encoding_unk(atom.GetAtomicNum(), range(11))
            + one_of_k_encoding(atom.GetDegree(), range(11))
            + one_of_k_encoding_unk(atom.GetImplicitValence(), range(11))
            + [atom.GetIsAromatic()]
            + one_of_k_encoding_unk(atom.GetTotalNumHs(), range(11))
            + [
                atom.GetNumImplicitHs(),
                atom.GetFormalCharge(),
                atom.GetNumRadicalElectrons(),
                atom.GetIsAromatic(),
            ]
        )

In [None]:
# Attention layer over groups of atom features (attention pooling)
class MyAttentionModule(torch.nn.Module):  # assumes each atom has 49 features
    """Attention over the 9 groups of per-atom features.

    Steps performed by ``forward``:
      1. ``GCNConv(49, 49)`` mixes each atom's features with its neighbours'.
      2. For each feature group in ``FEATURE_GROUPS``, a gate
         ``Linear(group_width, 1)`` scores the group per atom. The (N, 9)
         scores are pooled per graph with ``gap``; presumably this is a global
         mean pool, i.e. an average over the atoms of each molecule. The pooled
         scores are then softmax-normalised over the 9 groups.
      3. For each group, a projection ``Linear(group_width, groupFeatures)``
         embeds the group. Each atom's group embeddings are multiplied by the
         attention weights of the graph the atom belongs to (``batch``).

    Args:
        groupFeatures: output width of each group's projection.

    ``forward`` returns ``(x, attention)``:
        x          (N, 9 * groupFeatures) attention-weighted group embeddings,
                   group-major (group k occupies columns
                   k*groupFeatures .. (k+1)*groupFeatures - 1).
        attention  softmax weights over the 9 groups, one row per graph
                   (assuming ``gap`` pools per graph).
    """

    # Width of the input atom feature vector.
    N_FEATURES = 49
    # (name, start, end) end-exclusive column slices of the atom feature vector.
    # Names must be unique because they become ModuleDict keys.
    FEATURE_GROUPS = (
        ('AtomicNum', 0, 11),
        ('Degree', 11, 22),
        ('ImplicitValence', 22, 33),
        ('IsAromatic', 33, 34),
        ('TotalNumHs', 34, 45),
        ('NumImplicitHs', 45, 46),
        ('FormalCharge', 46, 47),
        ('NumRadicalElectrons', 47, 48),
        ('IsAromatic2', 48, 49),  # second aromaticity column in the featurizer output
    )

    def __init__(self, groupFeatures=1):
        super().__init__()
        # Gathers information from neighbouring atoms before attention.
        self.conv = GCNConv(self.N_FEATURES, self.N_FEATURES)
        # ModuleDict (not a plain dict) so the layers are registered submodules:
        # their parameters appear in .parameters(), follow .to(device) and are
        # saved in state_dict.
        # One scoring gate per feature group.
        self.gates = torch.nn.ModuleDict({
            name: torch.nn.Linear(end - start, 1)
            for name, start, end in self.FEATURE_GROUPS
        })
        # One projection per feature group, mapping the group to groupFeatures values.
        self.feats = torch.nn.ModuleDict({
            name: torch.nn.Linear(end - start, groupFeatures)
            for name, start, end in self.FEATURE_GROUPS
        })

    def forward(self, x, edge_index, batch):
        x = self.conv(x, edge_index)
        groups = [(name, x[:, start:end]) for name, start, end in self.FEATURE_GROUPS]

        # Per-atom group scores (N, 9) -> pooled per graph -> softmax over groups.
        logits = torch.cat([self.gates[name](g) for name, g in groups], dim=-1)
        logits = gap(logits, batch=batch)
        attention = torch.softmax(logits, dim=-1)

        # Group embeddings (N, 9, groupFeatures). attention[batch] picks each
        # atom's graph weights, (N, 9); unsqueeze broadcasts them over groupFeatures.
        embedded = torch.stack([self.feats[name](g) for name, g in groups], dim=1)
        embedded = embedded * attention[batch].unsqueeze(-1)
        x = embedded.reshape(embedded.size(0), -1)
        return x, attention

In [None]:
# Example use of the attention layer in a graph neural network
class GraphNeuralNetwork(torch.nn.Module):
    """Example GNN that applies MyAttentionModule before a three-layer GCN.

    Forward pipeline: feature-group attention -> GCNConv(9, hidden) -> ReLU
    -> GCNConv -> ReLU -> GCNConv -> ``gap`` pooling over ``batch``
    -> dropout -> Linear(hidden, 1).

    Args:
        hidden_size: width of the GCN layers and of the final linear layer's input.
        n_features: not used by this class; MyAttentionModule fixes the input width.
        dropout: dropout probability applied to the pooled embedding (training only).

    ``forward`` returns ``(out, att)``, where ``att`` is the attention tensor
    produced by MyAttentionModule.
    """

    def __init__(self, hidden_size, n_features=49, dropout=0.2):
        super().__init__()
        self.myAttentionModule = MyAttentionModule(1)
        # 9 = number of feature groups x groupFeatures (1) emitted by the attention module.
        self.conv1 = GCNConv(9, hidden_size)
        width = int(hidden_size)
        self.conv2 = GCNConv(hidden_size, width)
        self.conv3 = GCNConv(width, width)
        self.linear = torch.nn.Linear(width, 1)
        self.dropout = dropout

    def forward(self, x, edge_index, batch):
        # Attention runs on the raw atom features, before any convolution,
        # so the trained module can be reused in front of other datasets or
        # architectures; no pooling is needed at this point.
        h, att = self.myAttentionModule(x, edge_index, batch)
        for conv in (self.conv1, self.conv2):
            h = F.relu(conv(h, edge_index))
        h = self.conv3(h, edge_index)

        pooled = F.dropout(gap(h, batch), p=self.dropout, training=self.training)
        return self.linear(pooled), att