In [1]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

#!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
#!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cpu.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install warnings
!pip install dgl
!pip install texttable

if torch.cuda.is_available():
  device = torch.device("cuda")
else:
  device = torch.device("cpu")
print(device)

2.0.1+cu118
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in links: https://data.pyg.org/whl/torch-2.0.0+cpu.html
Collecting pyg_lib
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcpu/pyg_lib-0.2.0%2Bpt20cpu-cp310-cp310-linux_x86_64.whl (627 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m627.0/627.0 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_scatter
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_scatter-2.1.1%2Bpt20cpu-cp310-cp310-linux_x86_64.whl (504 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m504.1/504.1 kB[0m [31m906.9 kB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse
  Downloading https://data.pyg.org/whl/torch-2.0.0%2Bcpu/torch_sparse-0.6.17%2Bpt20cpu-cp310-cp310-linux_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [2]:
import random
from torchvision import transforms, datasets

import os
import pickle
from scipy.spatial.distance import cdist
from scipy import ndimage
import numpy as np

import dgl
import torch
import time
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib
def sigma(dists, kth=8):
    # Get k-nearest neighbors for each node
    knns = np.partition(dists, kth, axis=-1)[:, kth::-1]

    # Compute sigma and reshape
    sigma = knns.sum(axis=1).reshape((knns.shape[0], 1))/kth
    return sigma + 1e-8 # adding epsilon to avoid zero value of sigma

def compute_adjacency_matrix_images(coord, feat, use_feat=False, kth=8):
    coord = coord.reshape(-1, 2)
    # Compute coordinate distance
    c_dist = cdist(coord, coord)

    if use_feat:
        # Compute feature distance
        f_dist = cdist(feat, feat)
        # Compute adjacency
        A = np.exp(- (c_dist/sigma(c_dist))**2 - (f_dist/sigma(f_dist))**2 )
    else:
        A = np.exp(- (c_dist/sigma(c_dist))**2)

    # Convert to symmetric matrix
    A = 0.5 * A * A.T
    A[np.diag_indices_from(A)] = 0
    return A

def compute_edges_list(A, kth=8+1):
    # Get k-similar neighbor indices for each node
    if 1==1:
        num_nodes = A.shape[0]
        new_kth = num_nodes - kth
        knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1]
        knns_d = np.partition(A, new_kth-1, axis=-1)[:, new_kth:-1]
    else:
        knns = np.argpartition(A, kth, axis=-1)[:, kth::-1]
        knns_d = np.partition(A, kth, axis=-1)[:, kth::-1]
    return knns, knns_d
class newCIFARSuperPix(torch.utils.data.Dataset):
    def __init__(self,
                 data_dir,
                 use_mean_px=True,
                 use_coord=True,
                 use_feat_for_graph_construct=False,):

        #self.split = split
        #self.is_test = split.lower() in ['test', 'val']
        with open(data_dir, 'rb') as f:
            self.labels, self.sp_data = pickle.load(f)

        self.use_mean_px = use_mean_px
        self.use_feat_for_graph = use_feat_for_graph_construct
        self.use_coord = use_coord
        self.n_samples = len(self.labels)
        self.img_size = 32

    def precompute_graph_images(self):
        #print('precompute all data for the %s set...' % self.split.upper())
        self.Adj_matrices, self.node_features, self.edges_lists = [], [], []
        for index, sample in enumerate(self.sp_data):
            mean_px, coord = sample[:2]
            coord = coord / self.img_size
            A = compute_adjacency_matrix_images(coord, mean_px, use_feat=self.use_feat_for_graph)
            edges_list, _ = compute_edges_list(A)
            N_nodes = A.shape[0]

            x = None
            if self.use_mean_px:
                x = mean_px.reshape(N_nodes, -1)
            if self.use_coord:
                coord = coord.reshape(N_nodes, 2)
                if self.use_mean_px:
                    x = np.concatenate((x, coord), axis=1)
                else:
                    x = coord
            if x is None:
                x = np.ones(N_nodes, 1)  # dummy features

            self.node_features.append(x)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        g = dgl.DGLGraph()
        g.add_nodes(self.node_features[index].shape[0])
        g.ndata['feat'] = torch.Tensor(self.node_features[index])
        for src, dsts in enumerate(self.edges_lists[index]):
            g.add_edges(src, dsts[dsts!=src])

        return g, self.labels[index]

use_feat_for_graph_construct = False
tt = time.time()
data_with_feat_knn = newCIFARSuperPix("/content/drive/MyDrive/CMINST_data/CMNIST09_60000_75sp_train.pkl",use_feat_for_graph_construct=use_feat_for_graph_construct)

data_with_feat_knn.precompute_graph_images()
training_data = np.load('/content/drive/MyDrive/CMINST_data/colorMNIST09_60000_data.npy')
training_label=np.load('/content/drive/MyDrive/CMINST_data/colorMNIST09_60000_label.npy')

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [3]:
import numpy as np
import os.path as osp
import pickle
import torch
import torch.utils
import torch.utils.data
import torch.nn.functional as F
from scipy.spatial.distance import cdist
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import InMemoryDataset,Data
#/content/drive/MyDrive/Colab_Notebooks/mnist08_83_75sp_train.pkl
def compute_adjacency_matrix_images(coord, sigma=0.1):
    coord = coord.reshape(-1, 2)
    dist = cdist(coord, coord)
    A = np.exp(- dist / (sigma * np.pi) ** 2)
    A[np.diag_indices_from(A)] = 0
    return A


def list_to_torch(data):
    for i in range(len(data)):
        if data[i] is None:
            continue
        elif isinstance(data[i], np.ndarray):
            if data[i].dtype == np.bool:
                data[i] = data[i].astype(np.float32)
            data[i] = torch.from_numpy(data[i]).float()
        elif isinstance(data[i], list):
            data[i] = list_to_torch(data[i])
    return data
def process(data_file):
  use_mean_px=True
  use_coord=True
  node_gt_att_threshold=0
  transform=None
  pre_transform=None
  pre_filter=None

  #data_file ='/content/drive/MyDrive/Colab_Notebooks/colorMNIST05_2000_75sp_train.pkl'

  with open(osp.join(data_file), 'rb') as f:
      labels,sp_data = pickle.load(f)

  #use_mean_px = self.use_mean_px
  #self.use_coord = self.use_coord
  n_samples = len(labels)
  img_size = 32
  #node_gt_att_threshold = self.node_gt_att_threshold

  edge_indices,xs,edge_attrs,node_gt_atts,edge_gt_atts = [], [], [], [], []
  data_list = []
  for index, sample in enumerate(sp_data):
      mean_px, coord = sample[:2]
      coord = coord / img_size
      A = compute_adjacency_matrix_images(coord)
      N_nodes = A.shape[0]

      A = torch.FloatTensor((A > 0.1) * A)
      edge_index, edge_attr = dense_to_sparse(A)

      x = None
      if use_mean_px:
          x = mean_px.reshape(N_nodes, -1)
      if use_coord:
          coord = coord.reshape(N_nodes, 2)
          if use_mean_px:
              x = np.concatenate((x, coord), axis=1)
          else:
              x = coord
      if x is None:
          x = np.ones(N_nodes, 1)  # dummy features

      # replicate features to make it possible to test on colored images
      x = np.pad(x, ((0, 0), (2, 0)), 'edge')
      if node_gt_att_threshold == 0:
          node_gt_att = (mean_px > 0).astype(np.float32)
      else:
          node_gt_att = mean_px.copy()
          node_gt_att[node_gt_att < node_gt_att_threshold] = 0

      node_gt_att = torch.LongTensor(node_gt_att).view(-1)
      row, col = edge_index
      edge_gt_att = torch.LongTensor(node_gt_att[row] * node_gt_att[col]).view(-1)

      data_list.append(
          Data(
              x=torch.tensor(x),
              y=torch.LongTensor([labels[index]]),
              edge_index=edge_index,
              edge_attr=edge_attr,
              node_gt_att=node_gt_att,
              edge_gt_att=edge_gt_att

          )
      )

  #torch.save(InMemoryDataset.collate(data_list), '/content/drive/MyDrive/Colab_Notebooks/colorMINST05_2000.pt')
  return data_list

In [14]:
train_dir='/content/drive/MyDrive/CMINST_data/CMNIST095_10000_75sp_train.pkl'
val_dir='/content/drive/MyDrive/CMINST_data/CMNIST5000_75sp_val.pkl'
test_dir='/content/drive/MyDrive/CMINST_data/CMNIST10000_75sp_test.pkl'
training_final=process(data_file=train_dir)
valing_final=process(data_file=val_dir)
testing_final=process(data_file=test_dir)



In [5]:
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros
import pdb

def mask_graph(graph_x,select_node):
  mask_value=np.array([0.0001]*32)
  result=np.array([t.detach().numpy() for t in graph_x])
  for i in range(graph_x.shape[0]):
    if(i in select_node):
      continue
    result[i]=mask_value
  return torch.tensor(result)





def split_graph(graph_x,node_of_graph,type_of_graph=True):
  if(type_of_graph==True):
    select_node_number=int(node_of_graph/3)
    select_node=torch.topk(graph_x.mean(axis=1),select_node_number)[1]
    #print(select_node)
    return mask_graph(graph_x,select_node),select_node
  else:
    select_node_number=int(node_of_graph/3*2)
    select_node=torch.topk(graph_x.mean(axis=1),select_node_number)[1]
    return mask_graph(graph_x,select_node),select_node

def uncertainty_mask_gnerate(node_in_graph,number_of_mask):
  all_mask=[]
  for i in range(number_of_mask):
    random_mask=random.sample(node_in_graph.tolist(),int(len(node_in_graph)/3))
    all_mask.append(random_mask)
  return all_mask
def mean(l):
  return sum(l)/len(l)
class GCNConv(MessagePassing):

    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 edge_norm=True,
                 gfn=False):
        super(GCNConv, self).__init__('add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.edge_norm = edge_norm
        self.gfn = gfn
        self.message_mask = None
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)

        edge_weight = edge_weight.view(-1)


        assert edge_weight.size(0) == edge_index.size(1)

        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        # Add edge_weight for loop edges.
        loop_weight = torch.full((num_nodes, ),
                                 1 if not improved else 2,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0

        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """"""

        x = torch.matmul(x, self.weight)
        if self.gfn:
            return x

        if not self.cached or self.cached_result is None:
            if self.edge_norm:
                edge_index, norm = GCNConv.norm(
                    edge_index,
                    x.size(0),
                    edge_weight,
                    self.improved,
                    x.dtype)
            else:
                norm = None
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):

        if self.edge_norm:
            return norm.view(-1, 1) * x_j
        else:
            return x_j

    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [6]:
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, Sequential, ReLU
from torch_geometric.nn import global_mean_pool, global_add_pool, GINConv, GATConv

import random
import pdb
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

layers=3
with_random=True
fc_num=222
hidden=32
eval_random=False
class GINNet(torch.nn.Module):
    def __init__(self, num_features,
                       num_classes,
                       hidden=32,
                       num_fc_layers=2,
                       num_conv_layers=3,
                       dropout=0):

        super(GINNet, self).__init__()
        self.global_pool = global_add_pool
        self.dropout = dropout
        hidden_in = num_features
        hidden_out = num_classes

        self.bn_feat = BatchNorm1d(hidden_in)
        self.conv_feat = GCNConv(hidden_in, hidden, gfn=True) # linear transform

        self.convs = torch.nn.ModuleList()
        for i in range(num_conv_layers):
            self.convs.append(GINConv(
            Sequential(Linear(hidden, hidden),
                       BatchNorm1d(hidden),
                       ReLU(),
                       Linear(hidden, hidden),
                       ReLU())))

        self.bn_hidden = BatchNorm1d(hidden)
        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()

        for i in range(num_fc_layers - 1):
            self.bns_fc.append(BatchNorm1d(hidden))
            self.lins.append(Linear(hidden, hidden))
        self.lin_class = Linear(hidden, hidden_out)

        # BN initialization.
        for m in self.modules():
            if isinstance(m, (torch.nn.BatchNorm1d)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0.0001)

    def forward(self, data):
        x = data.x if data.x is not None else data.feat
        edge_index, batch = data.edge_index, data.batch
        # x, edge_index, batch = data.feat, data.edge_index, data.batch
        x = self.bn_feat(x)
        x = F.relu(self.conv_feat(x, edge_index))

        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)

        x = self.global_pool(x, batch)
        for i, lin in enumerate(self.lins):
            x = self.bns_fc[i](x)
            x = F.relu(lin(x))
        x = self.bn_hidden(x)
        x = self.lin_class(x)

        prediction = F.log_softmax(x, dim=-1)
        return prediction


In [7]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader
from torch import tensor
import torch_geometric.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
import pdb
import random
import numpy as np
from torch.autograd import grad
from torch_geometric.data import Batch


In [15]:
train_loader = DataLoader(training_final, batch_size=128, shuffle=True)
val_loader = DataLoader(valing_final, batch_size=128, shuffle=False)
test_loader = DataLoader(testing_final, batch_size=128, shuffle=False)
#t_load=[]
#for i in train_loader:
#  t_load.append(i)
#  if(len(t_load)==10000):
#    break


In [16]:
number_of_class=10
Epo=500
model= GINNet(7,number_of_class).to(device)

optimizer = Adam(model.parameters(), lr=0.001)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=Epo, eta_min=1e-6, last_epoch=-1, verbose=False)


In [17]:
import time
import json

loss_value=[]
loss_value_valation=[]

def num_graphs(data):
  if data.batch is not None:
      return data.num_graphs
  else:
      return data.x.size(0)
from tqdm import tqdm

for epoch in range(Epo):
  model.train()
  total_loss = 0
  correct = 0
  nb=0
  print(f"-----training-------{epoch}")
  loop = tqdm(enumerate(train_loader),total=len(train_loader))
  for it, data in loop:
#  for it, data in enumerate(train_loader):
      nb+=1
      optimizer.zero_grad()
      data = data.to(device)
      out = model(data)
      loss = F.nll_loss(out, data.y.view(-1))
      pred = out.max(1)[1]
      correct += pred.eq(data.y.view(-1)).sum().item()
      loss.backward()
      total_loss += loss.item() #* num_graphs(data)
      optimizer.step()
      loop.set_description(f"Epoch [{epoch}/{Epo}]")
      loop.set_postfix(loss = loss.item())

  #num = len(train_loader.dataset)
  total_loss = total_loss / nb
  lr_scheduler.step()

  print(f'number of {epoch} with total loss:{total_loss}')
  loss_value.append(total_loss)
  correct = correct / nb
  with torch.no_grad():
    model.eval()
    correct = 0
    print(f"------valation---------{epoch}")
    for data in test_loader:
      data = data.to(device)
      pred = model(data).max(1)[1]
      correct += pred.eq(data.y.view(-1)).sum().item()
    acc_o = correct / len(test_loader.dataset)
    print(f"causal val accuracy:{acc_o}")
    loss_value_valation.append(acc_o)
    dictionary={"number of epoch":epoch,
                "training loss list":loss_value,
                "valation accuracy list":loss_value_valation}

    # Serializing json
    json_object = json.dumps(dictionary,indent=3)

    # Writing to sample.json
    with open("/content/drive/MyDrive/running_nodebias_mnist/numberGIN_tl_va_e09.json", "w") as outfile:
        outfile.write(json_object)

    #torch.save(causal_model.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430cau_my6000.pt')
    #torch.save(predictco_model.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430caupred_my6000.pt')
    #torch.save(predictno_model.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430noncaupred_my6000.pt')
    #torch.save(model_heter.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430heter6000.pt')
    torch.save({
            'GCN_model.state_dic': model.state_dict(),
            'opt':optimizer.state_dict()
            }, '/content/drive/MyDrive/running_nodebias_mnist/GINmodel_09.pt')
    if(epoch>50):
      check=abs(acc_o-loss_value_valation[len(loss_value_valation)-20])/20
      if(check<=0.001):
        break


  #model_optimizer.zero_grad()
  #total_loss.backward()
  #model_optimizer.step()
#test_loader = DataLoader(testing_final, batch_size=1, shuffle=False)

model.eval()
correct = 0
print(f"------test---------{00}")
for data in test_loader:
  data = data.to(device)
  pred = model(data).max(1)[1]
  correct += pred.eq(data.y.view(-1)).sum().item()
acc_o = correct / len(test_loader.dataset)
print(f"causal val accuracy:{acc_o}")
dictionary={"number of epoch":epoch,
        "training loss list":loss_value,
        "valation accuracy list":loss_value_valation,
        "test accuracy value":acc_o}

# Serializing json
json_object = json.dumps(dictionary,indent=4)

# Writing to sample.json
with open("/content/drive/MyDrive/running_nodebias_mnist/numberGIN_tl_va_e09.json", "w") as outfile:
  outfile.write(json_object)

-----training-------0


Epoch [0/500]: 100%|██████████| 79/79 [00:07<00:00, 10.09it/s, loss=0.92]


number of 0 with total loss:0.8482138680506356
------valation---------0
causal val accuracy:0.1042
-----training-------1


Epoch [1/500]: 100%|██████████| 79/79 [00:07<00:00, 11.15it/s, loss=0.411]


number of 1 with total loss:0.34317672686486306
------valation---------1
causal val accuracy:0.0995
-----training-------2


Epoch [2/500]: 100%|██████████| 79/79 [00:08<00:00,  8.94it/s, loss=0.51]


number of 2 with total loss:0.2925340548346314
------valation---------2
causal val accuracy:0.1076
-----training-------3


Epoch [3/500]: 100%|██████████| 79/79 [00:08<00:00,  9.56it/s, loss=0.253]


number of 3 with total loss:0.27823452236531654
------valation---------3
causal val accuracy:0.0978
-----training-------4


Epoch [4/500]: 100%|██████████| 79/79 [00:07<00:00, 11.04it/s, loss=0.517]


number of 4 with total loss:0.2693803995093213
------valation---------4
causal val accuracy:0.1126
-----training-------5


Epoch [5/500]: 100%|██████████| 79/79 [00:08<00:00,  9.57it/s, loss=0.0669]


number of 5 with total loss:0.2663979685004753
------valation---------5
causal val accuracy:0.1191
-----training-------6


Epoch [6/500]: 100%|██████████| 79/79 [00:08<00:00,  8.93it/s, loss=0.118]


number of 6 with total loss:0.25299726030494596
------valation---------6
causal val accuracy:0.119
-----training-------7


Epoch [7/500]: 100%|██████████| 79/79 [00:07<00:00, 11.14it/s, loss=0.251]


number of 7 with total loss:0.25006031622237795
------valation---------7
causal val accuracy:0.1151
-----training-------8


Epoch [8/500]: 100%|██████████| 79/79 [00:07<00:00, 10.69it/s, loss=0.9]


number of 8 with total loss:0.25444487067340293
------valation---------8
causal val accuracy:0.1138
-----training-------9


Epoch [9/500]: 100%|██████████| 79/79 [00:08<00:00,  8.90it/s, loss=0.898]


number of 9 with total loss:0.24809577542392514
------valation---------9
causal val accuracy:0.1154
-----training-------10


Epoch [10/500]: 100%|██████████| 79/79 [00:07<00:00,  9.88it/s, loss=0.295]


number of 10 with total loss:0.2521335549558265
------valation---------10
causal val accuracy:0.1208
-----training-------11


Epoch [11/500]: 100%|██████████| 79/79 [00:06<00:00, 11.39it/s, loss=1.43]


number of 11 with total loss:0.252011372507373
------valation---------11
causal val accuracy:0.129
-----training-------12


Epoch [12/500]: 100%|██████████| 79/79 [00:08<00:00,  9.14it/s, loss=0.0631]


number of 12 with total loss:0.2388455137799058
------valation---------12
causal val accuracy:0.1262
-----training-------13


Epoch [13/500]: 100%|██████████| 79/79 [00:08<00:00,  9.00it/s, loss=0.971]


number of 13 with total loss:0.23536147856259648
------valation---------13
causal val accuracy:0.1268
-----training-------14


Epoch [14/500]: 100%|██████████| 79/79 [00:07<00:00, 10.98it/s, loss=0.124]


number of 14 with total loss:0.2508641863359681
------valation---------14
causal val accuracy:0.1331
-----training-------15


Epoch [15/500]: 100%|██████████| 79/79 [00:07<00:00,  9.99it/s, loss=0.341]


number of 15 with total loss:0.22707797002188768
------valation---------15
causal val accuracy:0.1271
-----training-------16


Epoch [16/500]: 100%|██████████| 79/79 [00:09<00:00,  8.69it/s, loss=0.459]


number of 16 with total loss:0.22669229882804653
------valation---------16
causal val accuracy:0.1219
-----training-------17


Epoch [17/500]: 100%|██████████| 79/79 [00:07<00:00,  9.92it/s, loss=0.305]


number of 17 with total loss:0.2180659432860115
------valation---------17
causal val accuracy:0.133
-----training-------18


Epoch [18/500]: 100%|██████████| 79/79 [00:07<00:00, 11.24it/s, loss=0.865]


number of 18 with total loss:0.22157107812317112
------valation---------18
causal val accuracy:0.1335
-----training-------19


Epoch [19/500]: 100%|██████████| 79/79 [00:08<00:00,  9.04it/s, loss=0.515]


number of 19 with total loss:0.22267491144092777
------valation---------19
causal val accuracy:0.1172
-----training-------20


Epoch [20/500]: 100%|██████████| 79/79 [00:08<00:00,  8.95it/s, loss=0.125]


number of 20 with total loss:0.21162551353815234
------valation---------20
causal val accuracy:0.1343
-----training-------21


Epoch [21/500]: 100%|██████████| 79/79 [00:07<00:00, 11.04it/s, loss=0.118]


number of 21 with total loss:0.20126400878535042
------valation---------21
causal val accuracy:0.141
-----training-------22


Epoch [22/500]: 100%|██████████| 79/79 [00:07<00:00, 10.12it/s, loss=0.588]


number of 22 with total loss:0.21610771987257124
------valation---------22
causal val accuracy:0.1298
-----training-------23


Epoch [23/500]: 100%|██████████| 79/79 [00:09<00:00,  8.74it/s, loss=0.0384]


number of 23 with total loss:0.20403415728596191
------valation---------23
causal val accuracy:0.1359
-----training-------24


Epoch [24/500]: 100%|██████████| 79/79 [00:08<00:00,  9.61it/s, loss=0.668]


number of 24 with total loss:0.20305045486628254
------valation---------24
causal val accuracy:0.1367
-----training-------25


Epoch [25/500]: 100%|██████████| 79/79 [00:07<00:00, 11.06it/s, loss=0.287]


number of 25 with total loss:0.21299382417073734
------valation---------25
causal val accuracy:0.1446
-----training-------26


Epoch [26/500]: 100%|██████████| 79/79 [00:08<00:00,  9.41it/s, loss=0.0445]


number of 26 with total loss:0.19677178674860846
------valation---------26
causal val accuracy:0.1339
-----training-------27


Epoch [27/500]: 100%|██████████| 79/79 [00:08<00:00,  8.84it/s, loss=0.215]


number of 27 with total loss:0.1905370348993736
------valation---------27
causal val accuracy:0.1369
-----training-------28


Epoch [28/500]: 100%|██████████| 79/79 [00:07<00:00, 10.95it/s, loss=0.245]


number of 28 with total loss:0.18963140202096745
------valation---------28
causal val accuracy:0.1554
-----training-------29


Epoch [29/500]: 100%|██████████| 79/79 [00:07<00:00, 10.35it/s, loss=0.154]


number of 29 with total loss:0.18753657508877258
------valation---------29
causal val accuracy:0.1485
-----training-------30


Epoch [30/500]: 100%|██████████| 79/79 [00:08<00:00,  8.85it/s, loss=0.114]


number of 30 with total loss:0.19256037685878669
------valation---------30
causal val accuracy:0.1582
-----training-------31


Epoch [31/500]: 100%|██████████| 79/79 [00:07<00:00, 10.00it/s, loss=0.0646]


number of 31 with total loss:0.18626866989497898
------valation---------31
causal val accuracy:0.1459
-----training-------32


Epoch [32/500]: 100%|██████████| 79/79 [00:07<00:00, 11.17it/s, loss=0.817]


number of 32 with total loss:0.18798545243430742
------valation---------32
causal val accuracy:0.1535
-----training-------33


Epoch [33/500]: 100%|██████████| 79/79 [00:08<00:00,  8.93it/s, loss=0.248]


number of 33 with total loss:0.1860834164427051
------valation---------33
causal val accuracy:0.1674
-----training-------34


Epoch [34/500]: 100%|██████████| 79/79 [00:08<00:00,  8.83it/s, loss=0.137]


number of 34 with total loss:0.18127908699120146
------valation---------34
causal val accuracy:0.1597
-----training-------35


Epoch [35/500]: 100%|██████████| 79/79 [00:07<00:00, 11.03it/s, loss=0.582]


number of 35 with total loss:0.18185265748938428
------valation---------35
causal val accuracy:0.1713
-----training-------36


Epoch [36/500]: 100%|██████████| 79/79 [00:07<00:00, 10.17it/s, loss=0.356]


number of 36 with total loss:0.18197207954488223
------valation---------36
causal val accuracy:0.175
-----training-------37


Epoch [37/500]: 100%|██████████| 79/79 [00:08<00:00,  8.83it/s, loss=0.0694]


number of 37 with total loss:0.1756074989709673
------valation---------37
causal val accuracy:0.1757
-----training-------38


Epoch [38/500]: 100%|██████████| 79/79 [00:07<00:00, 10.36it/s, loss=1.19]


number of 38 with total loss:0.1776611720553682
------valation---------38
causal val accuracy:0.1723
-----training-------39


Epoch [39/500]: 100%|██████████| 79/79 [00:07<00:00, 11.02it/s, loss=1.52]


number of 39 with total loss:0.21865779835777946
------valation---------39
causal val accuracy:0.1456
-----training-------40


Epoch [40/500]: 100%|██████████| 79/79 [00:09<00:00,  8.73it/s, loss=0.171]


number of 40 with total loss:0.19809550546769855
------valation---------40
causal val accuracy:0.153
-----training-------41


Epoch [41/500]: 100%|██████████| 79/79 [00:08<00:00,  9.12it/s, loss=0.222]


number of 41 with total loss:0.17653076770373538
------valation---------41
causal val accuracy:0.1658
-----training-------42


Epoch [42/500]: 100%|██████████| 79/79 [00:07<00:00, 11.18it/s, loss=0.215]


number of 42 with total loss:0.18652558637947975
------valation---------42
causal val accuracy:0.1599
-----training-------43


Epoch [43/500]: 100%|██████████| 79/79 [00:07<00:00,  9.95it/s, loss=0.0894]


number of 43 with total loss:0.16818679411765897
------valation---------43
causal val accuracy:0.1802
-----training-------44


Epoch [44/500]: 100%|██████████| 79/79 [00:08<00:00,  8.85it/s, loss=0.3]


number of 44 with total loss:0.1667749780359902
------valation---------44
causal val accuracy:0.1497
-----training-------45


Epoch [45/500]: 100%|██████████| 79/79 [00:07<00:00, 10.47it/s, loss=0.548]


number of 45 with total loss:0.1635385867255398
------valation---------45
causal val accuracy:0.166
-----training-------46


Epoch [46/500]: 100%|██████████| 79/79 [00:07<00:00, 11.10it/s, loss=0.567]


number of 46 with total loss:0.1672742119224011
------valation---------46
causal val accuracy:0.1875
-----training-------47


Epoch [47/500]: 100%|██████████| 79/79 [00:09<00:00,  8.77it/s, loss=0.303]


number of 47 with total loss:0.17466160117448132
------valation---------47
causal val accuracy:0.1764
-----training-------48


Epoch [48/500]: 100%|██████████| 79/79 [00:08<00:00,  9.28it/s, loss=0.543]


number of 48 with total loss:0.16996486574600014
------valation---------48
causal val accuracy:0.166
-----training-------49


Epoch [49/500]: 100%|██████████| 79/79 [00:07<00:00, 11.16it/s, loss=0.0224]


number of 49 with total loss:0.1611057416406236
------valation---------49
causal val accuracy:0.1985
-----training-------50


Epoch [50/500]: 100%|██████████| 79/79 [00:08<00:00,  9.63it/s, loss=0.134]


number of 50 with total loss:0.15038080935519707
------valation---------50
causal val accuracy:0.1746
-----training-------51


Epoch [51/500]: 100%|██████████| 79/79 [00:08<00:00,  8.80it/s, loss=0.447]


number of 51 with total loss:0.15664366132850888
------valation---------51
causal val accuracy:0.1943
-----training-------52


Epoch [52/500]: 100%|██████████| 79/79 [00:07<00:00, 10.34it/s, loss=0.0334]


number of 52 with total loss:0.15797060397984106
------valation---------52
causal val accuracy:0.201
-----training-------53


Epoch [53/500]: 100%|██████████| 79/79 [00:07<00:00, 10.91it/s, loss=0.121]


number of 53 with total loss:0.14788537439477595
------valation---------53
causal val accuracy:0.1851
-----training-------54


Epoch [54/500]: 100%|██████████| 79/79 [00:09<00:00,  8.68it/s, loss=0.091]


number of 54 with total loss:0.15043567043222203
------valation---------54
causal val accuracy:0.1764
------test---------0
causal val accuracy:0.1764


In [None]:
aaa

NameError: ignored

In [None]:
checkpoint = torch.load('/content/drive/MyDrive/running_nodebias_mnist/GINmodel_095.pt')
model.load_state_dict(checkpoint['GCN_model.state_dic'])

In [None]:
test_dir='/content/drive/MyDrive/CMINST_data/CMNIST6000_75sp_test.pkl'
testing_final=process(data_file=test_dir)
test_loader = DataLoader(testing_final, batch_size=1, shuffle=False)

model.eval()
correct = 0
print(f"------test---------{00}")
for data in test_loader:
  data = data.to(device)
  pred = model(data).max(1)[1]
  correct += pred.eq(data.y.view(-1)).sum().item()
acc_o = correct / len(test_loader.dataset)
print(f"causal val accuracy:{acc_o}")
