In [1]:
import torch.nn as nn
# import class ChebConv
from torch_geometric.nn import ChebConv
# import class GraphConv
from torch_geometric.nn import GraphConv
# import class GCNConv
from torch_geometric.nn import GCNConv

# import the functional API (relu, dropout, log_softmax, nll_loss) as F
import torch.nn.functional as F


class InvoiceGCN(nn.Module):
    """Four-layer graph convolutional network for per-node classification.

    The network stacks four graph convolutions with channel widths
    ``input_dim -> 64 -> 32 -> 16 -> n_classes``. The first three layers are
    followed by ReLU and dropout; the last layer's output is passed through
    ``log_softmax`` so the result can be fed directly to ``F.nll_loss``.

    Args:
        input_dim: Number of input features per node.
        chebnet: If True, build the layers with ``ChebConv`` (using filter
            size ``K``); otherwise build them with ``GCNConv``.
        n_classes: Number of output classes (width of the last layer).
        dropout_rate: Dropout probability applied after each hidden layer.
        K: Chebyshev filter size; only used when ``chebnet`` is True.
    """

    def __init__(self, input_dim, chebnet=False, n_classes=6, dropout_rate=0.2, K=3):
        super().__init__()

        self.input_dim = input_dim
        self.n_classes = n_classes
        self.dropout_rate = dropout_rate

        if chebnet:
            self.conv1 = ChebConv(self.input_dim, 64, K=K)
            self.conv2 = ChebConv(64, 32, K=K)
            self.conv3 = ChebConv(32, 16, K=K)
            self.conv4 = ChebConv(16, self.n_classes, K=K)
        else:
            # Fix: the first layer previously read `self.first_dim`, an
            # attribute that is never assigned, so this branch raised
            # AttributeError. The input width is stored as `self.input_dim`.
            #
            # NOTE(review): cached=True makes GCNConv reuse the normalized
            # adjacency computed on its first call (documented as intended for
            # transductive use only). Running the same model on a second,
            # different graph is therefore likely to misbehave -- confirm
            # before using chebnet=False with separate train/val graphs.
            self.conv1 = GCNConv(self.input_dim, 64, improved=True, cached=True)
            self.conv2 = GCNConv(64, 32, improved=True, cached=True)
            self.conv3 = GCNConv(32, 16, improved=True, cached=True)
            self.conv4 = GCNConv(16, self.n_classes, improved=True, cached=True)

    def forward(self, data):
        """Return per-node log-probabilities for the graph in ``data``.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
                ``edge_attr`` is passed to every layer as the edge weight
                (presumably one scalar per edge -- TODO confirm).

        Returns:
            Tensor of log-probabilities, normalized over dimension 1.
        """
        # for transductive setting with full-batch update
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_attr

        # Hidden layers: convolution -> ReLU -> dropout (dropout is a no-op in eval mode).
        for conv in (self.conv1, self.conv2, self.conv3):
            x = F.relu(conv(x, edge_index, edge_weight))
            x = F.dropout(x, p=self.dropout_rate, training=self.training)

        # Output layer: raw class scores, no activation or dropout.
        x = self.conv4(x, edge_index, edge_weight)

        return F.log_softmax(x, dim=1)


In [3]:
from sklearn.utils.class_weight import compute_class_weight

from sklearn.metrics import confusion_matrix
import os
from sklearn.metrics import classification_report
import torch

# Directory holding the serialized train/val graph datasets.
save_fd = "/workspace/nabang1010/LBA_VAIPE/GNN/GNN_Drugnames_Extraction_from_Prescription/data_GNN/dataset"
train_data = torch.load(os.path.join(save_fd, 'train_data.dataset'))
val_data = torch.load(os.path.join(save_fd, 'val_data.dataset'))

# The input width is taken from the training feature matrix (columns of train_data.x).
model = InvoiceGCN(input_dim=train_data.x.shape[1], chebnet=True)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# NOTE(review): weight_decay=0.9 is unusually large for AdamW -- confirm it is intentional.
optimizer = torch.optim.AdamW(
    model.parameters(), lr=0.001, weight_decay=0.9
)
train_data = train_data.to(device)
val_data = val_data.to(device)

# class weights for imbalanced data
# NOTE(review): the weight vector has one entry per label present in
# train_data.y; it must line up with the model's n_classes outputs -- confirm
# every class occurs in the training set.
_class_weights = compute_class_weight(class_weight= "balanced", classes= train_data.y.unique().cpu().numpy(), y = train_data.y.cpu().numpy())
# Build the weight tensor once instead of re-creating it on every epoch.
class_weight_tensor = torch.FloatTensor(_class_weights).to(device)

no_epochs = 2000
log_every = 200  # evaluate and print metrics every `log_every` epochs
# Explicit name -> dataset mapping (replaces eval("{}_data".format(name))).
eval_splits = {'train': train_data, 'val': val_data}

for epoch in range(1, no_epochs + 1):
    # Full-batch training step. Labels are shifted by -1 before being used as
    # class indices (the stored labels appear to be 1-based -- TODO confirm).
    model.train()
    optimizer.zero_grad()

    loss = F.nll_loss(
        model(train_data), train_data.y - 1, weight=class_weight_tensor
    )
    loss.backward()
    optimizer.step()

    if epoch % log_every != 0:
        continue

    # Periodic evaluation: no gradients, dropout disabled.
    model.eval()
    with torch.no_grad():
        for name, _data in eval_splits.items():
            y_pred = model(_data).max(dim=1)[1]
            y_true = (_data.y - 1)
            acc = y_pred.eq(y_true).sum().item() / y_pred.shape[0]

            y_pred = y_pred.cpu().numpy()
            y_true = y_true.cpu().numpy()
            print("\t{} acc: {}".format(name, acc))
            # confusion matrix
            if name == 'val':
                cm = confusion_matrix(y_true, y_pred)
                class_accs = cm.diagonal() / cm.sum(axis=1)
                print(classification_report(y_true, y_pred))

        # Fix: this previously used `test_data`, which is never defined in the
        # notebook (NameError on a fresh kernel). The value is logged as
        # val_loss, so it is computed on val_data. Unlike the training loss it
        # is not class-weighted.
        loss_val = F.nll_loss(model(val_data), val_data.y - 1)

    fmt_log = "Epoch: {:03d}, train_loss:{:.4f}, val_loss:{:.4f}"
    print(fmt_log.format(epoch, loss.item(), loss_val.item()))
    print(">" * 50)



	train acc: 0.9504025566959622
	val acc: 0.9507417709782104
              precision    recall  f1-score   support

           0       0.84      1.00      0.91       606
           1       0.99      1.00      1.00       606
           2       0.72      1.00      0.84       242
           3       0.97      1.00      0.99       816
           4       0.63      0.98      0.77       324
           5       1.00      0.93      0.96      6034

    accuracy                           0.95      8628
   macro avg       0.86      0.98      0.91      8628
weighted avg       0.96      0.95      0.95      8628

Epoch: 200, train_loss:0.2445, val_loss:0.2173
>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
	train acc: 0.9911498985925881
	val acc: 0.9931617987946222
              precision    recall  f1-score   support

           0       0.99      1.00      1.00       606
           1       1.00      1.00      1.00       606
           2       0.99      1.00      1.00       242
           3       1.0

In [8]:
# Persist the trained weights (state_dict only) for later reloading.
save_fd_path = "../weight"
# torch.save does not create missing directories; make sure the target exists.
os.makedirs(save_fd_path, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_fd_path, 'model_2000.pt'))