In [1]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install warnings
!pip install dgl
!pip install texttable

2.0.0+cu118
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m55.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m58.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dgl
  Downloading dgl-1.1.0-cp310-cp310-manylinux1_x86_64.whl (5.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.9/5.9 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.1.0
Looking in indexes: https

In [2]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device("cpu")
device

device(type='cpu')

In [3]:
import random
from torchvision import transforms, datasets

import os
import pickle
from scipy.spatial.distance import cdist
from scipy import ndimage
import numpy as np

import dgl
import torch
import time
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib
def sigma(dists, kth=8):
    """Return a per-row scale computed from the smallest entries of `dists`.

    For every row, the `kth + 1` smallest values are summed and the sum is
    divided by `kth`; a small epsilon is then added.

    Args:
        dists: 2-D array whose rows are processed independently.
            (presumably a pairwise distance matrix whose smallest entry per
            row is the zero self-distance — confirm against callers)
        kth: partition index selecting how many of the smallest values are used.

    Returns:
        Array of shape `(dists.shape[0], 1)`.
    """
    # After partitioning, columns 0..kth hold the kth+1 smallest values of each
    # row; the reversed slice picks exactly those columns.
    nearest = np.partition(dists, kth, axis=-1)[:, kth::-1]

    row_scale = nearest.sum(axis=1).reshape(-1, 1) / kth
    # The epsilon keeps the scale non-zero for callers that divide by it.
    return row_scale + 1e-8

def compute_adjacency_matrix_images(coord, feat, use_feat=False, kth=8):
    """Build a dense, symmetric similarity matrix between superpixels.

    Similarity is a Gaussian kernel on pairwise coordinate distance (and,
    optionally, feature distance), where each row is scaled by its own
    bandwidth from `sigma`.

    NOTE(review): a later cell defines another function with this same name and
    a different signature; after that cell runs, this version is shadowed.

    Args:
        coord: array reshaped to `(-1, 2)` node coordinates.
        feat: node features; only read when `use_feat` is True.
        use_feat: also include feature distance in the kernel.
        kth: neighbourhood size forwarded to `sigma` (the default matches the
            value `sigma` used before this argument was forwarded).

    Returns:
        Square array with zeros on the diagonal.
    """
    coord = coord.reshape(-1, 2)
    # Compute coordinate distance
    c_dist = cdist(coord, coord)

    if use_feat:
        # Compute feature distance
        f_dist = cdist(feat, feat)
        # Compute adjacency
        A = np.exp(- (c_dist/sigma(c_dist, kth))**2 - (f_dist/sigma(f_dist, kth))**2 )
    else:
        A = np.exp(- (c_dist/sigma(c_dist, kth))**2)

    # Convert to symmetric matrix. The row-wise bandwidth makes A asymmetric,
    # so average it with its transpose (the previous element-wise product
    # squared and halved the similarities instead of averaging them).
    A = 0.5 * (A + A.T)
    # No self-similarity: self loops are not wanted in the kNN graph.
    A[np.diag_indices_from(A)] = 0
    return A

def compute_edges_list(A, kth=8+1):
    """Select, for every node, the indices and weights of its most similar nodes.

    Args:
        A: square similarity matrix (larger value = more similar).
        kth: size of the top-similarity partition taken per row.

    Returns:
        `(knns, knns_d)`: neighbour indices and the matching similarity values,
        one row per node.

        * `num_nodes > kth`: both arrays have `kth - 1` columns.
        * `num_nodes <= kth`: the graph is treated as fully connected and every
          node lists all other nodes (`num_nodes - 1` columns); for a single
          node the lone index `[[0]]` and `A` itself are returned.
    """
    num_nodes = A.shape[0]

    if num_nodes <= kth:
        # Too few nodes for a top-`kth` partition: `num_nodes - kth - 1` would
        # be a negative partition index and select the wrong columns, so fall
        # back to a fully connected graph.
        all_idx = np.tile(np.arange(num_nodes), (num_nodes, 1))
        if num_nodes <= 1:
            return all_idx, A
        off_diag = ~np.eye(num_nodes, dtype=bool)  # mask out self loops
        knns = all_idx[off_diag].reshape(num_nodes, -1)
        knns_d = A[off_diag].reshape(num_nodes, -1)
        return knns, knns_d

    # Get k-similar neighbor indices for each node: after partitioning at
    # `new_kth - 1`, columns `new_kth:` hold the `kth` largest entries per row.
    # NOTE(review): the slice ends at -1, so one of those `kth` entries is
    # dropped; partition order inside that range is unspecified, so which one
    # is dropped is not guaranteed to be the largest or the smallest.
    new_kth = num_nodes - kth
    knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1]
    knns_d = np.partition(A, new_kth-1, axis=-1)[:, new_kth:-1]
    return knns, knns_d
class newCIFARSuperPix(torch.utils.data.Dataset):
    """Superpixel graph dataset backed by a pickled `(labels, sp_data)` pair.

    Usage: construct, call `precompute_graph_images()` once, then index the
    dataset to obtain `(dgl_graph, label)` pairs.

    Args:
        data_dir: path of the pickle file to load (despite the name, a file).
        use_mean_px: include the per-superpixel mean pixel values as features.
        use_coord: include the (normalised) superpixel coordinates as features.
        use_feat_for_graph_construct: also use the pixel features when
            computing node similarity for edge construction.
    """

    def __init__(self,
                 data_dir,
                 use_mean_px=True,
                 use_coord=True,
                 use_feat_for_graph_construct=False,):

        # NOTE(review): pickle.load can execute arbitrary code while
        # unpickling; only open files from a trusted source.
        with open(data_dir, 'rb') as f:
            self.labels, self.sp_data = pickle.load(f)

        self.use_mean_px = use_mean_px
        self.use_feat_for_graph = use_feat_for_graph_construct
        self.use_coord = use_coord
        self.n_samples = len(self.labels)
        # Coordinates are divided by this value before graph construction.
        # NOTE(review): presumably the image side length in pixels — confirm it
        # matches the images the superpixels were extracted from.
        self.img_size = 32

    def precompute_graph_images(self):
        """Compute adjacency, node features and neighbour lists for all samples.

        Fills `self.Adj_matrices`, `self.node_features` and `self.edges_lists`
        with one entry per element of `self.sp_data`.
        """
        self.Adj_matrices, self.node_features, self.edges_lists = [], [], []
        for index, sample in enumerate(self.sp_data):
            mean_px, coord = sample[:2]
            coord = coord / self.img_size
            A = compute_adjacency_matrix_images(coord, mean_px, use_feat=self.use_feat_for_graph)
            edges_list, _ = compute_edges_list(A)
            N_nodes = A.shape[0]

            x = None
            if self.use_mean_px:
                x = mean_px.reshape(N_nodes, -1)
            if self.use_coord:
                coord = coord.reshape(N_nodes, 2)
                if self.use_mean_px:
                    x = np.concatenate((x, coord), axis=1)
                else:
                    x = coord
            if x is None:
                # Dummy features. The shape must be passed as a tuple:
                # `np.ones(N_nodes, 1)` would treat `1` as the dtype and raise.
                x = np.ones((N_nodes, 1))

            self.node_features.append(x)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

    def __len__(self):
        """Number of labelled samples in the pickle file."""
        return self.n_samples

    def __getitem__(self, index):
        """Build the DGL graph for sample `index` and return `(graph, label)`.

        Raises:
            RuntimeError: if `precompute_graph_images()` has not been called.
        """
        if not hasattr(self, 'node_features'):
            raise RuntimeError(
                'precompute_graph_images() must be called before indexing the dataset')

        g = dgl.DGLGraph()
        g.add_nodes(self.node_features[index].shape[0])
        g.ndata['feat'] = torch.Tensor(self.node_features[index])
        for src, dsts in enumerate(self.edges_lists[index]):
            # Skip self loops that may appear in the neighbour list.
            g.add_edges(src, dsts[dsts!=src])

        return g, self.labels[index]

# Edges are built from superpixel coordinates only; pixel features are not
# used for the similarity computation.
use_feat_for_graph_construct = False
tt = time.time()  # timestamp taken before loading; not read again in the cells shown here

data_with_feat_knn = newCIFARSuperPix(
    "/content/drive/MyDrive/CMINST_data/CMNIST09_60000_75sp_train.pkl",
    use_feat_for_graph_construct=use_feat_for_graph_construct,
)
data_with_feat_knn.precompute_graph_images()

# Arrays stored next to the superpixel pickle.
# NOTE(review): presumably the raw colour-MNIST images and labels — confirm.
training_data = np.load('/content/drive/MyDrive/CMINST_data/colorMNIST09_60000_data.npy')
training_label = np.load('/content/drive/MyDrive/CMINST_data/colorMNIST09_60000_label.npy')

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [4]:
import numpy as np
import os.path as osp
import pickle
import torch
import torch.utils
import torch.utils.data
import torch.nn.functional as F
from scipy.spatial.distance import cdist
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import InMemoryDataset,Data
#/content/drive/MyDrive/Colab_Notebooks/mnist08_83_75sp_train.pkl
def compute_adjacency_matrix_images(coord, sigma=0.1):
    """Dense exponential-kernel adjacency between 2-D points.

    NOTE(review): this redefines a function of the same name from an earlier
    cell (different signature); whichever cell ran last wins.

    Args:
        coord: array reshaped to `(-1, 2)` point coordinates.
        sigma: bandwidth; the kernel scale is `(sigma * pi) ** 2`.

    Returns:
        Square array `exp(-dist / (sigma * pi) ** 2)` with a zero diagonal.
        The distance enters the exponent unsquared.
    """
    points = coord.reshape(-1, 2)
    pairwise = cdist(points, points)
    scale = (sigma * np.pi) ** 2
    A = np.exp(-pairwise / scale)
    np.fill_diagonal(A, 0)  # no self connections
    return A


def list_to_torch(data):
    """Convert, in place, every NumPy array in a (nested) list to a float tensor.

    `None` entries and entries of any other type are left untouched; nested
    lists are converted recursively.

    Args:
        data: list that may contain `None`, `np.ndarray` and nested lists.

    Returns:
        The same list object, with arrays replaced by `torch.float32` tensors.
    """
    for i, item in enumerate(data):
        if item is None:
            continue
        if isinstance(item, np.ndarray):
            # `np.bool_` is the NumPy boolean scalar type; the bare `np.bool`
            # alias is not available in every NumPy release.
            if item.dtype == np.bool_:
                item = item.astype(np.float32)
            data[i] = torch.from_numpy(item).float()
        elif isinstance(item, list):
            data[i] = list_to_torch(item)
    return data
def process(data_file):
  """Load a pickled superpixel dataset and convert it to PyG `Data` graphs.

  The pickle must contain a `(labels, sp_data)` pair; the first two entries of
  every sample are read as `(mean_px, coord)`.

  Args:
    data_file: path of the pickle file.

  Returns:
    List with one `Data` object per sample, carrying `x`, `y`, `edge_index`,
    `edge_attr`, `node_gt_att` and `edge_gt_att`.
  """
  use_mean_px = True
  use_coord = True
  node_gt_att_threshold = 0
  # Coordinates are divided by this value.
  # NOTE(review): presumably the image side length in pixels — confirm.
  img_size = 32

  # NOTE(review): pickle.load can execute arbitrary code while unpickling;
  # only open files from a trusted source.
  with open(data_file, 'rb') as f:
    labels, sp_data = pickle.load(f)

  data_list = []
  for index, sample in enumerate(sp_data):
    mean_px, coord = sample[:2]
    coord = coord / img_size
    A = compute_adjacency_matrix_images(coord)
    N_nodes = A.shape[0]

    # Keep only connections whose weight exceeds 0.1, then go sparse.
    A = torch.FloatTensor((A > 0.1) * A)
    edge_index, edge_attr = dense_to_sparse(A)

    x = None
    if use_mean_px:
      x = mean_px.reshape(N_nodes, -1)
    if use_coord:
      coord = coord.reshape(N_nodes, 2)
      if use_mean_px:
        x = np.concatenate((x, coord), axis=1)
      else:
        x = coord
    if x is None:
      # Dummy features. The shape must be a tuple: `np.ones(N_nodes, 1)`
      # would treat `1` as the dtype and raise.
      x = np.ones((N_nodes, 1))

    # replicate features to make it possible to test on colored images:
    # two columns copying the first feature column are prepended.
    x = np.pad(x, ((0, 0), (2, 0)), 'edge')
    if node_gt_att_threshold == 0:
      node_gt_att = (mean_px > 0).astype(np.float32)
    else:
      node_gt_att = mean_px.copy()
      node_gt_att[node_gt_att < node_gt_att_threshold] = 0

    node_gt_att = torch.LongTensor(node_gt_att).view(-1)
    row, col = edge_index
    # The edge value is the product of the values at its two endpoints.
    edge_gt_att = torch.LongTensor(node_gt_att[row] * node_gt_att[col]).view(-1)

    data_list.append(
      Data(
        x=torch.tensor(x),
        y=torch.LongTensor([labels[index]]),
        edge_index=edge_index,
        edge_attr=edge_attr,
        node_gt_att=node_gt_att,
        edge_gt_att=edge_gt_att
      )
    )

  return data_list

In [5]:
# Pickled superpixel files for the three splits.
train_dir = '/content/drive/MyDrive/CMINST_data/CMNIST09_10000_75sp_train.pkl'
val_dir = '/content/drive/MyDrive/CMINST_data/CMNIST5000_75sp_val.pkl'
test_dir = '/content/drive/MyDrive/CMINST_data/CMNIST10000_75sp_test.pkl'

# Convert every split into a list of PyG graphs.
training_final = process(data_file=train_dir)
valing_final = process(data_file=val_dir)
testing_final = process(data_file=test_dir)



In [6]:
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros
import pdb

def mask_graph(graph_x,select_node):
  """Return a copy of `graph_x` in which every unselected row is masked out.

  Rows whose index is listed in `select_node` are kept; all other rows are
  filled with the constant 0.0001. Any feature width is supported and the
  result stays on the device of `graph_x`.

  Args:
    graph_x: 2-D tensor of node features, one row per node.
    select_node: indices (tensor, array or list) of the rows to keep.

  Returns:
    New tensor with the same shape, dtype and device as `graph_x`, detached
    from the autograd graph; the input is not modified.
  """
  keep_idx = torch.as_tensor(select_node, dtype=torch.long, device=graph_x.device)
  keep = torch.zeros(graph_x.shape[0], dtype=torch.bool, device=graph_x.device)
  keep[keep_idx] = True

  result = graph_x.detach().clone()
  result[~keep] = 0.0001
  return result





def split_graph(graph_x,node_of_graph,type_of_graph=True):
  """Keep the nodes with the largest mean feature value and mask the rest.

  Args:
    graph_x: 2-D tensor of node features, one row per node.
    node_of_graph: node count used to size the selection.
    type_of_graph: when equal to True one third of `node_of_graph` is kept,
      otherwise two thirds.

  Returns:
    `(masked_features, select_node)` where `select_node` holds the indices of
    the kept rows as returned by `torch.topk`.
  """
  if type_of_graph==True:
    select_node_number = int(node_of_graph/3)
  else:
    select_node_number = int(node_of_graph/3*2)

  # Rank nodes by their mean feature value and keep the top ones.
  node_score = graph_x.mean(axis=1)
  select_node = torch.topk(node_score, select_node_number)[1]
  return mask_graph(graph_x, select_node), select_node

def uncertainty_mask_gnerate(node_in_graph,number_of_mask):
  """Draw `number_of_mask` random node subsets, each one third of the nodes.

  Args:
    node_in_graph: sized object with a `tolist()` method listing the node ids.
    number_of_mask: how many subsets to draw.

  Returns:
    List of lists; each inner list is a sample without replacement of
    `int(len(node_in_graph) / 3)` node ids drawn with the `random` module.
  """
  return [
    random.sample(node_in_graph.tolist(), int(len(node_in_graph)/3))
    for _ in range(number_of_mask)
  ]
def mean(l):
  """Arithmetic mean of a non-empty sized collection of numbers."""
  total = sum(l)
  count = len(l)
  return total/count
class GCNConv(MessagePassing):
    """Graph convolution layer built on `MessagePassing` with 'add' aggregation.

    The layer multiplies the node features by a learned weight matrix and then
    either returns that result directly (`gfn=True`) or propagates it along the
    edges, optionally scaled by a symmetric degree normalisation
    (`edge_norm=True`), before adding the bias.

    Args:
        in_channels: size of each input node feature vector.
        out_channels: size of each output node feature vector.
        improved: use weight 2 instead of 1 for the added self loops.
        cached: reuse the `(edge_index, norm)` pair computed on the first
            forward pass for all later calls, whatever inputs they receive.
        bias: learn an additive bias.
        edge_norm: apply the degree normalisation to messages.
        gfn: skip message passing entirely and act as a linear transform
            without bias.
    """
    
    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 edge_norm=True,
                 gfn=False):
        super(GCNConv, self).__init__('add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.edge_norm = edge_norm
        self.gfn = gfn
        # Set here but not read anywhere else in this class.
        self.message_mask = None
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise weight (glorot) and bias (zeros) and drop the cache."""
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        """Replace self loops and compute symmetrically normalised edge weights.

        Args:
            edge_index: tensor unpacked below as `(row, col)` edge endpoints.
            num_nodes: number of nodes in the graph.
            edge_weight: one weight per edge, or None for all-ones weights.
            improved: give the added self loops weight 2 instead of 1.
            dtype: dtype of the default all-ones weights.

        Returns:
            `(edge_index, weights)` where `edge_index` includes one self loop
            per node and `weights[e] = deg[row]^-1/2 * w[e] * deg[col]^-1/2`.
        """
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)
        
        edge_weight = edge_weight.view(-1)
        
        
        assert edge_weight.size(0) == edge_index.size(1)
        
        # Existing self loops are dropped first so every node ends up with
        # exactly one self loop of the weight chosen below.
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        # Add edge_weight for loop edges.
        loop_weight = torch.full((num_nodes, ),
                                 1 if not improved else 2,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        # Weighted degree: sum of edge weights grouped by the `row` endpoint.
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        # Zero degree gives inf after pow(-0.5); map it to 0.
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """Apply the layer.

        Args:
            x: node feature matrix; multiplied by `self.weight`.
            edge_index: edge endpoints passed to `norm` / `propagate`.
            edge_weight: optional per-edge weights used by the normalisation.

        Returns:
            `x @ weight` when `gfn` is set, otherwise the propagated features.
        """
        
        x = torch.matmul(x, self.weight)
        if self.gfn:
            # Linear transform only: no propagation and no bias.
            return x
    
        if not self.cached or self.cached_result is None:
            if self.edge_norm:
                edge_index, norm = GCNConv.norm(
                    edge_index, 
                    x.size(0), 
                    edge_weight, 
                    self.improved, 
                    x.dtype)
            else:
                norm = None
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """Scale each message `x_j` by its edge norm when `edge_norm` is set."""

        if self.edge_norm:
            return norm.view(-1, 1) * x_j
        else:
            return x_j
        
    def update(self, aggr_out):
        """Add the bias (if any) to the aggregated messages."""
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [7]:
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, Sequential, ReLU
from torch_geometric.nn import global_mean_pool, global_add_pool, GINConv, GATConv

import random
import pdb
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device("cpu")

# Experiment settings.
# NOTE(review): these do not appear to be read by the cells shown here;
# GATNet is built with its own defaults.
layers = 3
with_random = True
fc_num = 222
hidden = 32
eval_random = False
class GATNet(torch.nn.Module):
    """Graph-level classifier: linear embedding, GAT layers, sum pooling, MLP head.

    Args:
        num_features: width of the input node features.
        num_classes: number of output classes.
        hidden: hidden width used throughout the network.
        head: number of attention heads per GAT layer; each head outputs
            `int(hidden / head)` channels.
        num_fc_layers: fully connected layers in the head, counting the final
            classifier.
        num_conv_layers: number of GAT layers.
        dropout: dropout probability used inside the GAT layers and before the
            classifier.
    """

    def __init__(self, num_features,
                       num_classes,
                       hidden=32,
                       head=4,
                       num_fc_layers=2,
                       num_conv_layers=3,
                       dropout=0.2):
        super().__init__()

        self.global_pool = global_add_pool
        self.dropout = dropout
        hidden_in, hidden_out = num_features, num_classes

        # Input normalisation followed by a purely linear embedding
        # (GCNConv with gfn=True performs no message passing).
        self.bn_feat = BatchNorm1d(hidden_in)
        self.conv_feat = GCNConv(hidden_in, hidden, gfn=True)

        # Attention layers, each preceded by its own batch norm.
        self.bns_conv = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()
        for _ in range(num_conv_layers):
            self.bns_conv.append(BatchNorm1d(hidden))
            self.convs.append(GATConv(hidden, int(hidden / head), heads=head, dropout=dropout))

        # Classification head applied to the pooled graph embedding.
        self.bn_hidden = BatchNorm1d(hidden)
        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        for _ in range(num_fc_layers - 1):
            self.bns_fc.append(BatchNorm1d(hidden))
            self.lins.append(Linear(hidden, hidden))
        self.lin_class = Linear(hidden, hidden_out)

        # BN initialization.
        for module in self.modules():
            if isinstance(module, BatchNorm1d):
                torch.nn.init.constant_(module.weight, 1)
                torch.nn.init.constant_(module.bias, 0.0001)

    def forward(self, data):
        """Return per-graph log-probabilities for a (batched) graph object.

        `data` must expose `x` (or `feat` when `x` is None), `edge_index`
        and `batch`.
        """
        node_feat = data.x
        if node_feat is None:
            node_feat = data.feat
        edge_index = data.edge_index
        batch = data.batch

        h = self.bn_feat(node_feat)
        h = F.relu(self.conv_feat(h, edge_index))

        for bn, conv in zip(self.bns_conv, self.convs):
            h = bn(h)
            h = F.relu(conv(h, edge_index))

        # Sum node embeddings per graph.
        h = self.global_pool(h, batch)
        for bn, lin in zip(self.bns_fc, self.lins):
            h = bn(h)
            h = F.relu(lin(h))

        h = self.bn_hidden(h)
        if self.dropout > 0:
            h = F.dropout(h, p=self.dropout, training=self.training)
        logits = self.lin_class(h)
        return F.log_softmax(logits, dim=-1)

In [8]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader
from torch import tensor
import torch_geometric.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
import pdb
import random
import numpy as np
from torch.autograd import grad
from torch_geometric.data import Batch


In [9]:
# Mini-batch loaders: training batches are reshuffled every epoch, validation
# batches keep a fixed order.
train_loader = DataLoader(
    training_final,
    batch_size=128,
    shuffle=True,
)
val_loader = DataLoader(
    valing_final,
    batch_size=128,
    shuffle=False,
)




In [10]:
number_of_class = 10
Epo = 500  # maximum number of training epochs

# The first argument is the input feature width.
# NOTE(review): 7 presumably equals the width of `x` produced by `process`
# (pixel values + coordinates + 2 padded columns) — confirm against the data.
model = GATNet(7, number_of_class).to(device)

optimizer = Adam(model.parameters(), lr=0.001)
# Cosine decay of the learning rate from 1e-3 to 1e-6 over `Epo` epochs.
# `last_epoch` and `verbose` are left at their defaults; `verbose` is
# deprecated in newer PyTorch releases, so it is no longer passed explicitly.
lr_scheduler = CosineAnnealingLR(optimizer, T_max=Epo, eta_min=1e-6)


In [11]:
import time
import json

# Per-epoch history: mean training loss and validation accuracy.
loss_value = list()
loss_value_valation = list()

def num_graphs(data):
  """Number of graphs in `data`.

  Returns `data.num_graphs` when a batch vector is present, otherwise the
  number of rows of `data.x`.
  """
  if data.batch is None:
    return data.x.size(0)
  return data.num_graphs
from tqdm import tqdm

for epoch in range(Epo):
  # ---- training pass ----
  model.train()
  total_loss = 0
  nb = 0  # number of mini-batches seen in this epoch
  print(f"-----training-------{epoch}")
  loop = tqdm(enumerate(train_loader),total=len(train_loader))
  for it, data in loop:
    nb += 1
    optimizer.zero_grad()
    data = data.to(device)
    out = model(data)
    loss = F.nll_loss(out, data.y.view(-1))
    loss.backward()
    total_loss += loss.item()
    optimizer.step()
    loop.set_description(f"Epoch [{epoch}/{Epo}]")
    loop.set_postfix(loss = loss.item())

  # Mean of the per-batch losses (not weighted by batch size).
  total_loss = total_loss / nb
  lr_scheduler.step()

  print(f'number of {epoch} with total loss:{total_loss}')
  loss_value.append(total_loss)

  # ---- validation pass ----
  model.eval()
  with torch.no_grad():
    correct = 0
    print(f"------valation---------{epoch}")
    for data in val_loader:
      data = data.to(device)
      pred = model(data).max(1)[1]
      correct += pred.eq(data.y.view(-1)).sum().item()
  acc_o = correct / len(val_loader.dataset)
  print(f"causal val accuracy:{acc_o}")
  loss_value_valation.append(acc_o)

  # Persist the running history after every epoch so an interrupted run
  # still leaves usable logs.
  dictionary={"number of epoch":epoch,
              "training loss list":loss_value,
              "valation accuracy list":loss_value_valation}
  json_object = json.dumps(dictionary,indent=3)
  with open("/content/drive/MyDrive/running_nodebias_mnist/numberGAT_tl_va_e09.json", "w") as outfile:
    outfile.write(json_object)

  # Checkpoint of the model and optimizer state, overwritten every epoch.
  torch.save({
          'GCN_model.state_dic': model.state_dict(),
          'opt':optimizer.state_dict()
          }, '/content/drive/MyDrive/running_nodebias_mnist/GATmodel_09.pt')

  # Early stopping: after epoch 50, stop once the validation accuracy differs
  # from the value recorded 19 epochs earlier by at most 0.002 (0.0001 * 20).
  # NOTE(review): this compares two single epochs, so it can trigger on a
  # coincidental match rather than a real plateau.
  if epoch > 50:
    check = abs(acc_o - loss_value_valation[-20]) / 20
    if check <= 0.0001:
      break


# ---- final evaluation on the held-out test split ----
test_loader = DataLoader(testing_final, batch_size=1, shuffle=False)

model.eval()
correct = 0
print(f"------test---------{00}")
# Inference only: no autograd graph is needed.
with torch.no_grad():
  for data in test_loader:
    data = data.to(device)
    pred = model(data).max(1)[1]
    correct += pred.eq(data.y.view(-1)).sum().item()
acc_o = correct / len(test_loader.dataset)
# This is the test-set accuracy; it was previously printed with a
# validation label.
print(f"test accuracy:{acc_o}")
dictionary={"number of epoch":epoch,
        "training loss list":loss_value,
        "valation accuracy list":loss_value_valation,
        "test accuracy value":acc_o}

# Overwrite the per-epoch log with the final record that also carries the
# test accuracy.
json_object = json.dumps(dictionary,indent=4)
with open("/content/drive/MyDrive/running_nodebias_mnist/numberGAT_tl_va_e09.json", "w") as outfile:
  outfile.write(json_object)

-----training-------0


Epoch [0/500]: 100%|██████████| 79/79 [00:16<00:00,  4.88it/s, loss=0.966]


number of 0 with total loss:1.0586319829844222
------valation---------0
causal val accuracy:0.5006
-----training-------1


Epoch [1/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.853]


number of 1 with total loss:0.6062512367586547
------valation---------1
causal val accuracy:0.5016
-----training-------2


Epoch [2/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.571]


number of 2 with total loss:0.549055634797374
------valation---------2
causal val accuracy:0.5124
-----training-------3


Epoch [3/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.402]


number of 3 with total loss:0.5133008790921562
------valation---------3
causal val accuracy:0.5162
-----training-------4


Epoch [4/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.136]


number of 4 with total loss:0.47582089599174787
------valation---------4
causal val accuracy:0.5244
-----training-------5


Epoch [5/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=1.04]


number of 5 with total loss:0.4534171876273578
------valation---------5
causal val accuracy:0.5338
-----training-------6


Epoch [6/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.316]


number of 6 with total loss:0.41294903770277774
------valation---------6
causal val accuracy:0.5474
-----training-------7


Epoch [7/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.526]


number of 7 with total loss:0.403076695490487
------valation---------7
causal val accuracy:0.561
-----training-------8


Epoch [8/500]: 100%|██████████| 79/79 [00:15<00:00,  5.04it/s, loss=0.634]


number of 8 with total loss:0.377538765533061
------valation---------8
causal val accuracy:0.5706
-----training-------9


Epoch [9/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.646]


number of 9 with total loss:0.3649615739719777
------valation---------9
causal val accuracy:0.585
-----training-------10


Epoch [10/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.327]


number of 10 with total loss:0.35268778061564965
------valation---------10
causal val accuracy:0.5996
-----training-------11


Epoch [11/500]: 100%|██████████| 79/79 [00:15<00:00,  5.02it/s, loss=0.431]


number of 11 with total loss:0.33598347404335116
------valation---------11
causal val accuracy:0.6056
-----training-------12


Epoch [12/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.596]


number of 12 with total loss:0.3318550624424898
------valation---------12
causal val accuracy:0.612
-----training-------13


Epoch [13/500]: 100%|██████████| 79/79 [00:16<00:00,  4.77it/s, loss=0.4]


number of 13 with total loss:0.31982784369323825
------valation---------13
causal val accuracy:0.6172
-----training-------14


Epoch [14/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.669]


number of 14 with total loss:0.30922395613374587
------valation---------14
causal val accuracy:0.6338
-----training-------15


Epoch [15/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.554]


number of 15 with total loss:0.29878616540492337
------valation---------15
causal val accuracy:0.6316
-----training-------16


Epoch [16/500]: 100%|██████████| 79/79 [00:16<00:00,  4.66it/s, loss=0.969]


number of 16 with total loss:0.3008185927815075
------valation---------16
causal val accuracy:0.644
-----training-------17


Epoch [17/500]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s, loss=0.627]


number of 17 with total loss:0.2777213346731814
------valation---------17
causal val accuracy:0.6426
-----training-------18


Epoch [18/500]: 100%|██████████| 79/79 [00:16<00:00,  4.81it/s, loss=0.998]


number of 18 with total loss:0.2778552121753934
------valation---------18
causal val accuracy:0.6418
-----training-------19


Epoch [19/500]: 100%|██████████| 79/79 [00:17<00:00,  4.61it/s, loss=0.0823]


number of 19 with total loss:0.25911326187698147
------valation---------19
causal val accuracy:0.6674
-----training-------20


Epoch [20/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.383]


number of 20 with total loss:0.2539155983094928
------valation---------20
causal val accuracy:0.673
-----training-------21


Epoch [21/500]: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s, loss=0.205]


number of 21 with total loss:0.2433813836755632
------valation---------21
causal val accuracy:0.6876
-----training-------22


Epoch [22/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.111]


number of 22 with total loss:0.23517153888374945
------valation---------22
causal val accuracy:0.695
-----training-------23


Epoch [23/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=1.32]


number of 23 with total loss:0.23876143907067143
------valation---------23
causal val accuracy:0.6914
-----training-------24


Epoch [24/500]: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s, loss=0.0847]


number of 24 with total loss:0.21997474520644056
------valation---------24
causal val accuracy:0.7186
-----training-------25


Epoch [25/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.219]


number of 25 with total loss:0.21351694240223004
------valation---------25
causal val accuracy:0.715
-----training-------26


Epoch [26/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.121]


number of 26 with total loss:0.21004417887594126
------valation---------26
causal val accuracy:0.7014
-----training-------27


Epoch [27/500]: 100%|██████████| 79/79 [00:16<00:00,  4.74it/s, loss=0.117]


number of 27 with total loss:0.20372378882728046
------valation---------27
causal val accuracy:0.729
-----training-------28


Epoch [28/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.451]


number of 28 with total loss:0.20106876872574228
------valation---------28
causal val accuracy:0.7256
-----training-------29


Epoch [29/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.0265]


number of 29 with total loss:0.1912239583892913
------valation---------29
causal val accuracy:0.7382
-----training-------30


Epoch [30/500]: 100%|██████████| 79/79 [00:16<00:00,  4.74it/s, loss=0.354]


number of 30 with total loss:0.19115063293447979
------valation---------30
causal val accuracy:0.7508
-----training-------31


Epoch [31/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.474]


number of 31 with total loss:0.18936664714843413
------valation---------31
causal val accuracy:0.743
-----training-------32


Epoch [32/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.407]


number of 32 with total loss:0.1884529520344885
------valation---------32
causal val accuracy:0.7462
-----training-------33


Epoch [33/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.0409]


number of 33 with total loss:0.17309027160458926
------valation---------33
causal val accuracy:0.759
-----training-------34


Epoch [34/500]: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s, loss=0.391]


number of 34 with total loss:0.18075838383240037
------valation---------34
causal val accuracy:0.737
-----training-------35


Epoch [35/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.188]


number of 35 with total loss:0.16911368741642072
------valation---------35
causal val accuracy:0.7508
-----training-------36


Epoch [36/500]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s, loss=0.125]


number of 36 with total loss:0.17217548493343063
------valation---------36
causal val accuracy:0.7712
-----training-------37


Epoch [37/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.0522]


number of 37 with total loss:0.1683346097699449
------valation---------37
causal val accuracy:0.7536
-----training-------38


Epoch [38/500]: 100%|██████████| 79/79 [00:15<00:00,  5.03it/s, loss=0.19]


number of 38 with total loss:0.1649306635784952
------valation---------38
causal val accuracy:0.7708
-----training-------39


Epoch [39/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.336]


number of 39 with total loss:0.16820182490952407
------valation---------39
causal val accuracy:0.7738
-----training-------40


Epoch [40/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.503]


number of 40 with total loss:0.15851785704682145
------valation---------40
causal val accuracy:0.7808
-----training-------41


Epoch [41/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.444]


number of 41 with total loss:0.15484435262182092
------valation---------41
causal val accuracy:0.7766
-----training-------42


Epoch [42/500]: 100%|██████████| 79/79 [00:16<00:00,  4.81it/s, loss=0.216]


number of 42 with total loss:0.15401136417743527
------valation---------42
causal val accuracy:0.7934
-----training-------43


Epoch [43/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.471]


number of 43 with total loss:0.15629961138850526
------valation---------43
causal val accuracy:0.7842
-----training-------44


Epoch [44/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.372]


number of 44 with total loss:0.15691440665646444
------valation---------44
causal val accuracy:0.8016
-----training-------45


Epoch [45/500]: 100%|██████████| 79/79 [00:16<00:00,  4.87it/s, loss=0.0528]


number of 45 with total loss:0.1447757724436778
------valation---------45
causal val accuracy:0.7978
-----training-------46


Epoch [46/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.188]


number of 46 with total loss:0.1394409774036347
------valation---------46
causal val accuracy:0.7882
-----training-------47


Epoch [47/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.183]


number of 47 with total loss:0.14351837884021712
------valation---------47
causal val accuracy:0.7936
-----training-------48


Epoch [48/500]: 100%|██████████| 79/79 [00:16<00:00,  4.88it/s, loss=0.237]


number of 48 with total loss:0.14562341240765173
------valation---------48
causal val accuracy:0.78
-----training-------49


Epoch [49/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.935]


number of 49 with total loss:0.14787825008359137
------valation---------49
causal val accuracy:0.7982
-----training-------50


Epoch [50/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.0362]


number of 50 with total loss:0.14889297360860848
------valation---------50
causal val accuracy:0.7772
-----training-------51


Epoch [51/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.334]


number of 51 with total loss:0.13988652962106693
------valation---------51
causal val accuracy:0.7956
-----training-------52


Epoch [52/500]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s, loss=0.459]


number of 52 with total loss:0.14042399969847896
------valation---------52
causal val accuracy:0.766
-----training-------53


Epoch [53/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.23]


number of 53 with total loss:0.1365802275323415
------valation---------53
causal val accuracy:0.8052
-----training-------54


Epoch [54/500]: 100%|██████████| 79/79 [00:16<00:00,  4.83it/s, loss=0.358]


number of 54 with total loss:0.13259506815030606
------valation---------54
causal val accuracy:0.7948
-----training-------55


Epoch [55/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.011]


number of 55 with total loss:0.13220224070916825
------valation---------55
causal val accuracy:0.8078
-----training-------56


Epoch [56/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.0345]


number of 56 with total loss:0.12293015906139265
------valation---------56
causal val accuracy:0.7934
-----training-------57


Epoch [57/500]: 100%|██████████| 79/79 [00:16<00:00,  4.72it/s, loss=0.0404]


number of 57 with total loss:0.12203796311647078
------valation---------57
causal val accuracy:0.808
-----training-------58


Epoch [58/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.138]


number of 58 with total loss:0.12654133362671996
------valation---------58
causal val accuracy:0.8132
-----training-------59


Epoch [59/500]: 100%|██████████| 79/79 [00:16<00:00,  4.69it/s, loss=0.0914]


number of 59 with total loss:0.11960022467412526
------valation---------59
causal val accuracy:0.8054
-----training-------60


Epoch [60/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.0473]


number of 60 with total loss:0.1204229143884363
------valation---------60
causal val accuracy:0.7898
-----training-------61


Epoch [61/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.0199]


number of 61 with total loss:0.11823493065430393
------valation---------61
causal val accuracy:0.8064
-----training-------62


Epoch [62/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.0119]


number of 62 with total loss:0.10921520600684836
------valation---------62
causal val accuracy:0.7932
-----training-------63


Epoch [63/500]: 100%|██████████| 79/79 [00:16<00:00,  4.83it/s, loss=0.00745]


number of 63 with total loss:0.10813039943835215
------valation---------63
causal val accuracy:0.815
-----training-------64


Epoch [64/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.104]


number of 64 with total loss:0.11049890178668348
------valation---------64
causal val accuracy:0.8104
-----training-------65


Epoch [65/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.00733]


number of 65 with total loss:0.11517967767070365
------valation---------65
causal val accuracy:0.8232
-----training-------66


Epoch [66/500]: 100%|██████████| 79/79 [00:16<00:00,  4.77it/s, loss=0.0616]


number of 66 with total loss:0.11638470210983784
------valation---------66
causal val accuracy:0.8266
-----training-------67


Epoch [67/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.116]


number of 67 with total loss:0.11240197652125661
------valation---------67
causal val accuracy:0.8132
-----training-------68


Epoch [68/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.051]


number of 68 with total loss:0.1096274077279281
------valation---------68
causal val accuracy:0.8232
-----training-------69


Epoch [69/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.0394]


number of 69 with total loss:0.1093760930190358
------valation---------69
causal val accuracy:0.8132
-----training-------70


Epoch [70/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.0102]


number of 70 with total loss:0.10179953644924526
------valation---------70
causal val accuracy:0.8258
-----training-------71


Epoch [71/500]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s, loss=0.166]


number of 71 with total loss:0.10487798757002323
------valation---------71
causal val accuracy:0.8256
-----training-------72


Epoch [72/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.0835]


number of 72 with total loss:0.10824910409842865
------valation---------72
causal val accuracy:0.802
-----training-------73


Epoch [73/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.689]


number of 73 with total loss:0.11893509766912158
------valation---------73
causal val accuracy:0.829
-----training-------74


Epoch [74/500]: 100%|██████████| 79/79 [00:15<00:00,  5.05it/s, loss=0.493]


number of 74 with total loss:0.11228695874915848
------valation---------74
causal val accuracy:0.8328
-----training-------75


Epoch [75/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.0896]


number of 75 with total loss:0.11459443317372588
------valation---------75
causal val accuracy:0.8204
-----training-------76


Epoch [76/500]: 100%|██████████| 79/79 [00:15<00:00,  5.07it/s, loss=0.0234]


number of 76 with total loss:0.10482019593915608
------valation---------76
causal val accuracy:0.8328
-----training-------77


Epoch [77/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.182]


number of 77 with total loss:0.09640949306703067
------valation---------77
causal val accuracy:0.8358
-----training-------78


Epoch [78/500]: 100%|██████████| 79/79 [00:16<00:00,  4.73it/s, loss=0.332]


number of 78 with total loss:0.1064591868485831
------valation---------78
causal val accuracy:0.8266
-----training-------79


Epoch [79/500]: 100%|██████████| 79/79 [00:15<00:00,  5.03it/s, loss=0.0153]


number of 79 with total loss:0.10493552550390552
------valation---------79
causal val accuracy:0.8254
-----training-------80


Epoch [80/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.128]


number of 80 with total loss:0.10420060586891597
------valation---------80
causal val accuracy:0.8228
-----training-------81


Epoch [81/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.119]


number of 81 with total loss:0.10103787646829328
------valation---------81
causal val accuracy:0.8394
-----training-------82


Epoch [82/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.21]


number of 82 with total loss:0.1009680808439285
------valation---------82
causal val accuracy:0.8416
-----training-------83


Epoch [83/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.166]


number of 83 with total loss:0.0999856371951254
------valation---------83
causal val accuracy:0.8364
-----training-------84


Epoch [84/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.494]


number of 84 with total loss:0.09426448115772462
------valation---------84
causal val accuracy:0.8176
-----training-------85


Epoch [85/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.0558]


number of 85 with total loss:0.10327778749545163
------valation---------85
causal val accuracy:0.8326
-----training-------86


Epoch [86/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.0266]


number of 86 with total loss:0.09430893883109093
------valation---------86
causal val accuracy:0.8368
-----training-------87


Epoch [87/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.0145]


number of 87 with total loss:0.08926573518333555
------valation---------87
causal val accuracy:0.8334
-----training-------88


Epoch [88/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.14]


number of 88 with total loss:0.09201148385771468
------valation---------88
causal val accuracy:0.8416
-----training-------89


Epoch [89/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.106]


number of 89 with total loss:0.09947418567689159
------valation---------89
causal val accuracy:0.8258
------test---------0
causal val accuracy:0.6632
