In [1]:
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
!pip install warnings
!pip install dgl
!pip install texttable

2.0.0+cu118
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m83.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.8/4.8 MB[0m [31m64.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting dgl
  Downloading dgl-1.1.0-cp310-cp310-manylinux1_x86_64.whl (5.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.9/5.9 MB[0m [31m67.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.1.0
Looking in indexes: https

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [3]:
import random
from torchvision import transforms, datasets

import os
import pickle
from scipy.spatial.distance import cdist
from scipy import ndimage
import numpy as np

import dgl
import torch
import time
import networkx as nx
import matplotlib.pyplot as plt
import matplotlib
def sigma(dists, kth=8):
    # Get k-nearest neighbors for each node
    knns = np.partition(dists, kth, axis=-1)[:, kth::-1]

    # Compute sigma and reshape
    sigma = knns.sum(axis=1).reshape((knns.shape[0], 1))/kth
    return sigma + 1e-8 # adding epsilon to avoid zero value of sigma

def compute_adjacency_matrix_images(coord, feat, use_feat=False, kth=8):
    coord = coord.reshape(-1, 2)
    # Compute coordinate distance
    c_dist = cdist(coord, coord)
    
    if use_feat:
        # Compute feature distance
        f_dist = cdist(feat, feat)
        # Compute adjacency
        A = np.exp(- (c_dist/sigma(c_dist))**2 - (f_dist/sigma(f_dist))**2 )
    else:
        A = np.exp(- (c_dist/sigma(c_dist))**2)
        
    # Convert to symmetric matrix
    A = 0.5 * A * A.T
    A[np.diag_indices_from(A)] = 0
    return A

def compute_edges_list(A, kth=8+1):
    # Get k-similar neighbor indices for each node
    if 1==1:   
        num_nodes = A.shape[0]
        new_kth = num_nodes - kth
        knns = np.argpartition(A, new_kth-1, axis=-1)[:, new_kth:-1]
        knns_d = np.partition(A, new_kth-1, axis=-1)[:, new_kth:-1]
    else:
        knns = np.argpartition(A, kth, axis=-1)[:, kth::-1]
        knns_d = np.partition(A, kth, axis=-1)[:, kth::-1]
    return knns, knns_d
class newCIFARSuperPix(torch.utils.data.Dataset):
    def __init__(self,
                 data_dir,
                 use_mean_px=True,
                 use_coord=True,
                 use_feat_for_graph_construct=False,):

        #self.split = split
        #self.is_test = split.lower() in ['test', 'val']
        with open(data_dir, 'rb') as f:
            self.labels, self.sp_data = pickle.load(f)

        self.use_mean_px = use_mean_px
        self.use_feat_for_graph = use_feat_for_graph_construct
        self.use_coord = use_coord
        self.n_samples = len(self.labels)
        self.img_size = 32

    def precompute_graph_images(self):
        #print('precompute all data for the %s set...' % self.split.upper())
        self.Adj_matrices, self.node_features, self.edges_lists = [], [], []
        for index, sample in enumerate(self.sp_data):
            mean_px, coord = sample[:2]
            coord = coord / self.img_size
            A = compute_adjacency_matrix_images(coord, mean_px, use_feat=self.use_feat_for_graph)
            edges_list, _ = compute_edges_list(A)
            N_nodes = A.shape[0]
            
            x = None
            if self.use_mean_px:
                x = mean_px.reshape(N_nodes, -1)
            if self.use_coord:
                coord = coord.reshape(N_nodes, 2)
                if self.use_mean_px:
                    x = np.concatenate((x, coord), axis=1)
                else:
                    x = coord
            if x is None:
                x = np.ones(N_nodes, 1)  # dummy features
            
            self.node_features.append(x)
            self.Adj_matrices.append(A)
            self.edges_lists.append(edges_list)

    def __len__(self):
        return self.n_samples

    def __getitem__(self, index):
        g = dgl.DGLGraph()
        g.add_nodes(self.node_features[index].shape[0])
        g.ndata['feat'] = torch.Tensor(self.node_features[index])
        for src, dsts in enumerate(self.edges_lists[index]):
            g.add_edges(src, dsts[dsts!=src])

        return g, self.labels[index]

use_feat_for_graph_construct = False
tt = time.time()
data_with_feat_knn = newCIFARSuperPix("/content/drive/MyDrive/CMINST_data/CMNIST095_60000_75sp_train.pkl",use_feat_for_graph_construct=use_feat_for_graph_construct)

data_with_feat_knn.precompute_graph_images()
training_data = np.load('/content/drive/MyDrive/CMINST_data/colorMNIST095_60000_data.npy')
training_label=np.load('/content/drive/MyDrive/CMINST_data/colorMNIST095_60000_label.npy')

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)


In [4]:
import numpy as np
import os.path as osp
import pickle
import torch
import torch.utils
import torch.utils.data
import torch.nn.functional as F
from scipy.spatial.distance import cdist
from torch_geometric.utils import dense_to_sparse
from torch_geometric.data import InMemoryDataset,Data
#/content/drive/MyDrive/Colab_Notebooks/mnist08_83_75sp_train.pkl
def compute_adjacency_matrix_images(coord, sigma=0.1):
    coord = coord.reshape(-1, 2)
    dist = cdist(coord, coord)
    A = np.exp(- dist / (sigma * np.pi) ** 2)
    A[np.diag_indices_from(A)] = 0
    return A


def list_to_torch(data):
    for i in range(len(data)):
        if data[i] is None:
            continue
        elif isinstance(data[i], np.ndarray):
            if data[i].dtype == np.bool:
                data[i] = data[i].astype(np.float32)
            data[i] = torch.from_numpy(data[i]).float()
        elif isinstance(data[i], list):
            data[i] = list_to_torch(data[i])
    return data
def process(data_file):
  use_mean_px=True
  use_coord=True
  node_gt_att_threshold=0
  transform=None
  pre_transform=None
  pre_filter=None

  #data_file ='/content/drive/MyDrive/Colab_Notebooks/colorMNIST05_2000_75sp_train.pkl' 

  with open(osp.join(data_file), 'rb') as f:
      labels,sp_data = pickle.load(f)
      
  #use_mean_px = self.use_mean_px
  #self.use_coord = self.use_coord
  n_samples = len(labels)
  img_size = 32
  #node_gt_att_threshold = self.node_gt_att_threshold

  edge_indices,xs,edge_attrs,node_gt_atts,edge_gt_atts = [], [], [], [], []
  data_list = []
  for index, sample in enumerate(sp_data):
      mean_px, coord = sample[:2]
      coord = coord / img_size
      A = compute_adjacency_matrix_images(coord)
      N_nodes = A.shape[0]
      
      A = torch.FloatTensor((A > 0.1) * A)
      edge_index, edge_attr = dense_to_sparse(A)

      x = None
      if use_mean_px:
          x = mean_px.reshape(N_nodes, -1)
      if use_coord:
          coord = coord.reshape(N_nodes, 2)
          if use_mean_px:
              x = np.concatenate((x, coord), axis=1)
          else:
              x = coord
      if x is None:
          x = np.ones(N_nodes, 1)  # dummy features
          
      # replicate features to make it possible to test on colored images
      x = np.pad(x, ((0, 0), (2, 0)), 'edge')  
      if node_gt_att_threshold == 0:
          node_gt_att = (mean_px > 0).astype(np.float32)
      else:
          node_gt_att = mean_px.copy()
          node_gt_att[node_gt_att < node_gt_att_threshold] = 0

      node_gt_att = torch.LongTensor(node_gt_att).view(-1)
      row, col = edge_index
      edge_gt_att = torch.LongTensor(node_gt_att[row] * node_gt_att[col]).view(-1)

      data_list.append(
          Data(
              x=torch.tensor(x), 
              y=torch.LongTensor([labels[index]]), 
              edge_index=edge_index,
              edge_attr=edge_attr, 
              node_gt_att=node_gt_att,
              edge_gt_att=edge_gt_att

          )
      )

  #torch.save(InMemoryDataset.collate(data_list), '/content/drive/MyDrive/Colab_Notebooks/colorMINST05_2000.pt')
  return data_list

In [5]:
train_dir='/content/drive/MyDrive/CMINST_data/CMNIST095_10000_75sp_train.pkl'
val_dir='/content/drive/MyDrive/CMINST_data/CMNIST5000_75sp_val.pkl'
test_dir='/content/drive/MyDrive/CMINST_data/CMNIST10000_75sp_test.pkl'
training_final=process(data_file=train_dir)
valing_final=process(data_file=val_dir)
testing_final=process(data_file=test_dir)



In [6]:
import torch
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
from torch_geometric.nn.inits import glorot, zeros
import pdb

def mask_graph(graph_x,select_node):
  mask_value=np.array([0.0001]*32)
  result=np.array([t.detach().numpy() for t in graph_x])
  for i in range(graph_x.shape[0]):
    if(i in select_node):
      continue
    result[i]=mask_value
  return torch.tensor(result)





def split_graph(graph_x,node_of_graph,type_of_graph=True):
  if(type_of_graph==True):
    select_node_number=int(node_of_graph/3)
    select_node=torch.topk(graph_x.mean(axis=1),select_node_number)[1]
    #print(select_node)
    return mask_graph(graph_x,select_node),select_node
  else:
    select_node_number=int(node_of_graph/3*2)
    select_node=torch.topk(graph_x.mean(axis=1),select_node_number)[1]
    return mask_graph(graph_x,select_node),select_node

def uncertainty_mask_gnerate(node_in_graph,number_of_mask):
  all_mask=[]
  for i in range(number_of_mask):
    random_mask=random.sample(node_in_graph.tolist(),int(len(node_in_graph)/3))
    all_mask.append(random_mask)
  return all_mask
def mean(l):
  return sum(l)/len(l)
class GCNConv(MessagePassing):
    
    def __init__(self,
                 in_channels,
                 out_channels,
                 improved=False,
                 cached=False,
                 bias=True,
                 edge_norm=True,
                 gfn=False):
        super(GCNConv, self).__init__('add')

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.cached_result = None
        self.edge_norm = edge_norm
        self.gfn = gfn
        self.message_mask = None
        self.weight = Parameter(torch.Tensor(in_channels, out_channels))

        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        glorot(self.weight)
        zeros(self.bias)
        self.cached_result = None

    @staticmethod
    def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
        if edge_weight is None:
            edge_weight = torch.ones((edge_index.size(1), ),
                                     dtype=dtype,
                                     device=edge_index.device)
        
        edge_weight = edge_weight.view(-1)
        
        
        assert edge_weight.size(0) == edge_index.size(1)
        
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
        # Add edge_weight for loop edges.
        loop_weight = torch.full((num_nodes, ),
                                 1 if not improved else 2,
                                 dtype=edge_weight.dtype,
                                 device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

        row, col = edge_index
        deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
        deg_inv_sqrt = deg.pow(-0.5)
        deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
        
        return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    def forward(self, x, edge_index, edge_weight=None):
        """"""
        
        x = torch.matmul(x, self.weight)
        if self.gfn:
            return x
    
        if not self.cached or self.cached_result is None:
            if self.edge_norm:
                edge_index, norm = GCNConv.norm(
                    edge_index, 
                    x.size(0), 
                    edge_weight, 
                    self.improved, 
                    x.dtype)
            else:
                norm = None
            self.cached_result = edge_index, norm

        edge_index, norm = self.cached_result
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):

        if self.edge_norm:
            return norm.view(-1, 1) * x_j
        else:
            return x_j
        
    def update(self, aggr_out):
        if self.bias is not None:
            aggr_out = aggr_out + self.bias
        return aggr_out

    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)

In [7]:
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear, BatchNorm1d, Sequential, ReLU
from torch_geometric.nn import global_mean_pool, global_add_pool, GINConv, GATConv

import random
import pdb
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

layers=3
with_random=True
fc_num=222
hidden=32
eval_random=False
class GATNet(torch.nn.Module):
    def __init__(self, num_features, 
                       num_classes,
                       hidden=32,
                       head=4,
                       num_fc_layers=2, 
                       num_conv_layers=3, 
                       dropout=0.2):

        super(GATNet, self).__init__()

        self.global_pool = global_add_pool
        self.dropout = dropout
        hidden_in = num_features
        hidden_out = num_classes
   
        self.bn_feat = BatchNorm1d(hidden_in)
        self.conv_feat = GCNConv(hidden_in, hidden, gfn=True) # linear transform
        self.bns_conv = torch.nn.ModuleList()
        self.convs = torch.nn.ModuleList()

        for i in range(num_conv_layers):
            self.bns_conv.append(BatchNorm1d(hidden))
            self.convs.append(GATConv(hidden, int(hidden / head), heads=head, dropout=dropout))
        self.bn_hidden = BatchNorm1d(hidden)
        self.bns_fc = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()

        for i in range(num_fc_layers - 1):
            self.bns_fc.append(BatchNorm1d(hidden))
            self.lins.append(Linear(hidden, hidden))
        self.lin_class = Linear(hidden, hidden_out)

        # BN initialization.
        for m in self.modules():
            if isinstance(m, (torch.nn.BatchNorm1d)):
                torch.nn.init.constant_(m.weight, 1)
                torch.nn.init.constant_(m.bias, 0.0001)

    def forward(self, data):
        
        x = data.x if data.x is not None else data.feat
        edge_index, batch = data.edge_index, data.batch
        
        x = self.bn_feat(x)
        x = F.relu(self.conv_feat(x, edge_index))
        
        for i, conv in enumerate(self.convs):
            x = self.bns_conv[i](x)
            x = F.relu(conv(x, edge_index))

        x = self.global_pool(x, batch)
        for i, lin in enumerate(self.lins):
            x = self.bns_fc[i](x)
            x = F.relu(lin(x))

        x = self.bn_hidden(x)
        if self.dropout > 0:
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin_class(x)
        return F.log_softmax(x, dim=-1)

In [8]:
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torch_geometric.data import DataLoader, DenseDataLoader as DenseLoader
from torch import tensor
import torch_geometric.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR
import pdb
import random
import numpy as np
from torch.autograd import grad
from torch_geometric.data import Batch


In [9]:
train_loader = DataLoader(training_final, batch_size=128, shuffle=True)
val_loader = DataLoader(valing_final, batch_size=128, shuffle=False)
#t_load=[]
#for i in train_loader:
#  t_load.append(i)
#  if(len(t_load)==10000):
#    break




In [10]:
number_of_class=10
Epo=500
model= GATNet(7,number_of_class).to(device)

optimizer = Adam(model.parameters(), lr=0.001)
lr_scheduler = CosineAnnealingLR(optimizer, T_max=Epo, eta_min=1e-6, last_epoch=-1, verbose=False)


In [11]:
import time
import json

loss_value=[]
loss_value_valation=[]

def num_graphs(data):
  if data.batch is not None:
      return data.num_graphs
  else:
      return data.x.size(0)
from tqdm import tqdm

for epoch in range(Epo):
  model.train()
  total_loss = 0
  correct = 0
  nb=0
  print(f"-----training-------{epoch}")
  loop = tqdm(enumerate(train_loader),total=len(train_loader))
  for it, data in loop:
#  for it, data in enumerate(train_loader):
      nb+=1
      optimizer.zero_grad()
      data = data.to(device)
      out = model(data)
      loss = F.nll_loss(out, data.y.view(-1))
      pred = out.max(1)[1]
      correct += pred.eq(data.y.view(-1)).sum().item()
      loss.backward()
      total_loss += loss.item() #* num_graphs(data)
      optimizer.step()
      loop.set_description(f"Epoch [{epoch}/{Epo}]")
      loop.set_postfix(loss = loss.item())

  #num = len(train_loader.dataset)
  total_loss = total_loss / nb
  lr_scheduler.step()

  print(f'number of {epoch} with total loss:{total_loss}')
  loss_value.append(total_loss)
  correct = correct / nb
  with torch.no_grad():
    model.eval()
    correct = 0
    print(f"------valation---------{epoch}")
    for data in val_loader:
      data = data.to(device)
      pred = model(data).max(1)[1]
      correct += pred.eq(data.y.view(-1)).sum().item()
    acc_o = correct / len(val_loader.dataset)
    print(f"causal val accuracy:{acc_o}")
    loss_value_valation.append(acc_o)
    dictionary={"number of epoch":epoch,
                "training loss list":loss_value,
                "valation accuracy list":loss_value_valation}
 
    # Serializing json
    json_object = json.dumps(dictionary,indent=3)
    
    # Writing to sample.json
    with open("/content/drive/MyDrive/running_nodebias_mnist/numberGAT_tl_va_e095.json", "w") as outfile:
        outfile.write(json_object)

    #torch.save(causal_model.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430cau_my6000.pt')
    #torch.save(predictco_model.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430caupred_my6000.pt')
    #torch.save(predictno_model.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430noncaupred_my6000.pt')
    #torch.save(model_heter.state_dict(), '/content/drive/MyDrive/Colab_Notebooks/430heter6000.pt')
    torch.save({
            'GCN_model.state_dic': model.state_dict(),
            'opt':optimizer.state_dict()
            }, '/content/drive/MyDrive/running_nodebias_mnist/GATmodel_095.pt')
    if(epoch>50):
      check=abs(acc_o-loss_value_valation[len(loss_value_valation)-50])/50
      if(check<=0.0001):
        break


  #model_optimizer.zero_grad()
  #total_loss.backward()
  #model_optimizer.step()
test_loader = DataLoader(testing_final, batch_size=1, shuffle=False)

model.eval()
correct = 0
print(f"------test---------{00}")
for data in test_loader:
  data = data.to(device)
  pred = model(data).max(1)[1]
  correct += pred.eq(data.y.view(-1)).sum().item()
acc_o = correct / len(test_loader.dataset)
print(f"causal val accuracy:{acc_o}")
dictionary={"number of epoch":epoch,
        "training loss list":loss_value,
        "valation accuracy list":loss_value_valation,
        "test accuracy value":acc_o}

# Serializing json
json_object = json.dumps(dictionary,indent=4)

# Writing to sample.json
with open("/content/drive/MyDrive/running_nodebias_mnist/numberGAT_tl_va_e095.json", "w") as outfile:
  outfile.write(json_object)

-----training-------0


Epoch [0/500]: 100%|██████████| 79/79 [00:16<00:00,  4.73it/s, loss=0.994]


number of 0 with total loss:0.9065336656721332
------valation---------0
causal val accuracy:0.5006
-----training-------1


Epoch [1/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.604]


number of 1 with total loss:0.3933515852387947
------valation---------1
causal val accuracy:0.5006
-----training-------2


Epoch [2/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.59]


number of 2 with total loss:0.33589904934545106
------valation---------2
causal val accuracy:0.5006
-----training-------3


Epoch [3/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.277]


number of 3 with total loss:0.3120342519464372
------valation---------3
causal val accuracy:0.5006
-----training-------4


Epoch [4/500]: 100%|██████████| 79/79 [00:15<00:00,  5.03it/s, loss=0.167]


number of 4 with total loss:0.30409726783444607
------valation---------4
causal val accuracy:0.5014
-----training-------5


Epoch [5/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.669]


number of 5 with total loss:0.29834614543220667
------valation---------5
causal val accuracy:0.505
-----training-------6


Epoch [6/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.525]


number of 6 with total loss:0.2974803980581368
------valation---------6
causal val accuracy:0.5054
-----training-------7


Epoch [7/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.106]


number of 7 with total loss:0.2847896453134621
------valation---------7
causal val accuracy:0.5046
-----training-------8


Epoch [8/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.713]


number of 8 with total loss:0.2842545392392557
------valation---------8
causal val accuracy:0.5034
-----training-------9


Epoch [9/500]: 100%|██████████| 79/79 [00:16<00:00,  4.75it/s, loss=0.655]


number of 9 with total loss:0.2792958899200717
------valation---------9
causal val accuracy:0.5074
-----training-------10


Epoch [10/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.105]


number of 10 with total loss:0.2724177112119107
------valation---------10
causal val accuracy:0.5098
-----training-------11


Epoch [11/500]: 100%|██████████| 79/79 [00:16<00:00,  4.94it/s, loss=0.709]


number of 11 with total loss:0.27004607191568686
------valation---------11
causal val accuracy:0.5144
-----training-------12


Epoch [12/500]: 100%|██████████| 79/79 [00:16<00:00,  4.83it/s, loss=0.072]


number of 12 with total loss:0.27298381481366824
------valation---------12
causal val accuracy:0.5076
-----training-------13


Epoch [13/500]: 100%|██████████| 79/79 [00:15<00:00,  5.00it/s, loss=0.556]


number of 13 with total loss:0.25973228723565234
------valation---------13
causal val accuracy:0.5122
-----training-------14


Epoch [14/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=1.25]


number of 14 with total loss:0.2522515646075901
------valation---------14
causal val accuracy:0.517
-----training-------15


Epoch [15/500]: 100%|██████████| 79/79 [00:16<00:00,  4.88it/s, loss=0.313]


number of 15 with total loss:0.24181119975032686
------valation---------15
causal val accuracy:0.5152
-----training-------16


Epoch [16/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.0937]


number of 16 with total loss:0.2266472402629973
------valation---------16
causal val accuracy:0.5282
-----training-------17


Epoch [17/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.535]


number of 17 with total loss:0.22507488793587382
------valation---------17
causal val accuracy:0.5322
-----training-------18


Epoch [18/500]: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s, loss=0.1]


number of 18 with total loss:0.21257015242229535
------valation---------18
causal val accuracy:0.538
-----training-------19


Epoch [19/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.456]


number of 19 with total loss:0.20666695716260355
------valation---------19
causal val accuracy:0.542
-----training-------20


Epoch [20/500]: 100%|██████████| 79/79 [00:15<00:00,  5.00it/s, loss=0.303]


number of 20 with total loss:0.19537254932183254
------valation---------20
causal val accuracy:0.5628
-----training-------21


Epoch [21/500]: 100%|██████████| 79/79 [00:16<00:00,  4.77it/s, loss=0.228]


number of 21 with total loss:0.1925224907036069
------valation---------21
causal val accuracy:0.5702
-----training-------22


Epoch [22/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.411]


number of 22 with total loss:0.18624181536179554
------valation---------22
causal val accuracy:0.5568
-----training-------23


Epoch [23/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.583]


number of 23 with total loss:0.18896325257948682
------valation---------23
causal val accuracy:0.572
-----training-------24


Epoch [24/500]: 100%|██████████| 79/79 [00:16<00:00,  4.69it/s, loss=0.163]


number of 24 with total loss:0.17173373057872435
------valation---------24
causal val accuracy:0.5718
-----training-------25


Epoch [25/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.0509]


number of 25 with total loss:0.17278915384336363
------valation---------25
causal val accuracy:0.582
-----training-------26


Epoch [26/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.456]


number of 26 with total loss:0.16324401128141186
------valation---------26
causal val accuracy:0.57
-----training-------27


Epoch [27/500]: 100%|██████████| 79/79 [00:16<00:00,  4.81it/s, loss=0.602]


number of 27 with total loss:0.17057370809437353
------valation---------27
causal val accuracy:0.572
-----training-------28


Epoch [28/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.0498]


number of 28 with total loss:0.16097293676266186
------valation---------28
causal val accuracy:0.5932
-----training-------29


Epoch [29/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.427]


number of 29 with total loss:0.1581427267740799
------valation---------29
causal val accuracy:0.6032
-----training-------30


Epoch [30/500]: 100%|██████████| 79/79 [00:16<00:00,  4.84it/s, loss=0.419]


number of 30 with total loss:0.15895625213279
------valation---------30
causal val accuracy:0.597
-----training-------31


Epoch [31/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.289]


number of 31 with total loss:0.16235595121036603
------valation---------31
causal val accuracy:0.6042
-----training-------32


Epoch [32/500]: 100%|██████████| 79/79 [00:15<00:00,  5.04it/s, loss=0.6]


number of 32 with total loss:0.1583151654635049
------valation---------32
causal val accuracy:0.599
-----training-------33


Epoch [33/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.0721]


number of 33 with total loss:0.1478317809067195
------valation---------33
causal val accuracy:0.6136
-----training-------34


Epoch [34/500]: 100%|██████████| 79/79 [00:15<00:00,  5.00it/s, loss=0.354]


number of 34 with total loss:0.14200492535682419
------valation---------34
causal val accuracy:0.6276
-----training-------35


Epoch [35/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.0707]


number of 35 with total loss:0.14782732121552092
------valation---------35
causal val accuracy:0.6302
-----training-------36


Epoch [36/500]: 100%|██████████| 79/79 [00:16<00:00,  4.84it/s, loss=0.0245]


number of 36 with total loss:0.1312453634093834
------valation---------36
causal val accuracy:0.6332
-----training-------37


Epoch [37/500]: 100%|██████████| 79/79 [00:15<00:00,  5.04it/s, loss=0.176]


number of 37 with total loss:0.13099774025097677
------valation---------37
causal val accuracy:0.6284
-----training-------38


Epoch [38/500]: 100%|██████████| 79/79 [00:15<00:00,  5.05it/s, loss=0.351]


number of 38 with total loss:0.13097333054565177
------valation---------38
causal val accuracy:0.6382
-----training-------39


Epoch [39/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.18]


number of 39 with total loss:0.1238926866952377
------valation---------39
causal val accuracy:0.6354
-----training-------40


Epoch [40/500]: 100%|██████████| 79/79 [00:15<00:00,  5.00it/s, loss=0.0855]


number of 40 with total loss:0.12689081188055534
------valation---------40
causal val accuracy:0.6256
-----training-------41


Epoch [41/500]: 100%|██████████| 79/79 [00:15<00:00,  5.02it/s, loss=0.145]


number of 41 with total loss:0.11967211075221436
------valation---------41
causal val accuracy:0.6404
-----training-------42


Epoch [42/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.071]


number of 42 with total loss:0.12178078095746946
------valation---------42
causal val accuracy:0.6432
-----training-------43


Epoch [43/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.128]


number of 43 with total loss:0.11789317089545576
------valation---------43
causal val accuracy:0.6586
-----training-------44


Epoch [44/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.271]


number of 44 with total loss:0.12021789343768283
------valation---------44
causal val accuracy:0.6514
-----training-------45


Epoch [45/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.229]


number of 45 with total loss:0.11568183692384369
------valation---------45
causal val accuracy:0.6448
-----training-------46


Epoch [46/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.141]


number of 46 with total loss:0.11832283354729792
------valation---------46
causal val accuracy:0.6568
-----training-------47


Epoch [47/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.0469]


number of 47 with total loss:0.11451765517645245
------valation---------47
causal val accuracy:0.6512
-----training-------48


Epoch [48/500]: 100%|██████████| 79/79 [00:16<00:00,  4.69it/s, loss=0.222]


number of 48 with total loss:0.10935874415349357
------valation---------48
causal val accuracy:0.6638
-----training-------49


Epoch [49/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.136]


number of 49 with total loss:0.11420179545124874
------valation---------49
causal val accuracy:0.6618
-----training-------50


Epoch [50/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.325]


number of 50 with total loss:0.11518580480655537
------valation---------50
causal val accuracy:0.6714
-----training-------51


Epoch [51/500]: 100%|██████████| 79/79 [00:16<00:00,  4.73it/s, loss=0.0924]


number of 51 with total loss:0.11026569592613208
------valation---------51
causal val accuracy:0.6688
-----training-------52


Epoch [52/500]: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s, loss=0.821]


number of 52 with total loss:0.11582401166139525
------valation---------52
causal val accuracy:0.6898
-----training-------53


Epoch [53/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.00532]


number of 53 with total loss:0.1113643747212106
------valation---------53
causal val accuracy:0.6678
-----training-------54


Epoch [54/500]: 100%|██████████| 79/79 [00:16<00:00,  4.74it/s, loss=0.0704]


number of 54 with total loss:0.10467168801947485
------valation---------54
causal val accuracy:0.6728
-----training-------55


Epoch [55/500]: 100%|██████████| 79/79 [00:16<00:00,  4.84it/s, loss=0.36]


number of 55 with total loss:0.10341158351283285
------valation---------55
causal val accuracy:0.6828
-----training-------56


Epoch [56/500]: 100%|██████████| 79/79 [00:16<00:00,  4.94it/s, loss=0.303]


number of 56 with total loss:0.10967086649299422
------valation---------56
causal val accuracy:0.694
-----training-------57


Epoch [57/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.354]


number of 57 with total loss:0.10237806299819222
------valation---------57
causal val accuracy:0.6688
-----training-------58


Epoch [58/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.226]


number of 58 with total loss:0.09816574338304845
------valation---------58
causal val accuracy:0.682
-----training-------59


Epoch [59/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.0462]


number of 59 with total loss:0.09500320274618608
------valation---------59
causal val accuracy:0.684
-----training-------60


Epoch [60/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.203]


number of 60 with total loss:0.09595934568043751
------valation---------60
causal val accuracy:0.6852
-----training-------61


Epoch [61/500]: 100%|██████████| 79/79 [00:15<00:00,  5.01it/s, loss=0.118]


number of 61 with total loss:0.09247335709065577
------valation---------61
causal val accuracy:0.693
-----training-------62


Epoch [62/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.0772]


number of 62 with total loss:0.09749005244502539
------valation---------62
causal val accuracy:0.6946
-----training-------63


Epoch [63/500]: 100%|██████████| 79/79 [00:16<00:00,  4.80it/s, loss=0.0975]


number of 63 with total loss:0.09152741246868538
------valation---------63
causal val accuracy:0.6998
-----training-------64


Epoch [64/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.117]


number of 64 with total loss:0.09074926568501734
------valation---------64
causal val accuracy:0.7146
-----training-------65


Epoch [65/500]: 100%|██████████| 79/79 [00:16<00:00,  4.83it/s, loss=0.285]


number of 65 with total loss:0.0959411691166932
------valation---------65
causal val accuracy:0.6856
-----training-------66


Epoch [66/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.0252]


number of 66 with total loss:0.08845958691326124
------valation---------66
causal val accuracy:0.6934
-----training-------67


Epoch [67/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.0368]


number of 67 with total loss:0.08843730880489832
------valation---------67
causal val accuracy:0.7074
-----training-------68


Epoch [68/500]: 100%|██████████| 79/79 [00:16<00:00,  4.79it/s, loss=0.247]


number of 68 with total loss:0.09335447252504056
------valation---------68
causal val accuracy:0.6972
-----training-------69


Epoch [69/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.0945]


number of 69 with total loss:0.09073013229932211
------valation---------69
causal val accuracy:0.7026
-----training-------70


Epoch [70/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.0223]


number of 70 with total loss:0.08261858242787892
------valation---------70
causal val accuracy:0.6972
-----training-------71


Epoch [71/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.183]


number of 71 with total loss:0.08416355384773092
------valation---------71
causal val accuracy:0.7066
-----training-------72


Epoch [72/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.0994]


number of 72 with total loss:0.09075838514851241
------valation---------72
causal val accuracy:0.6848
-----training-------73


Epoch [73/500]: 100%|██████████| 79/79 [00:15<00:00,  5.05it/s, loss=0.137]


number of 73 with total loss:0.08027046936552358
------valation---------73
causal val accuracy:0.7044
-----training-------74


Epoch [74/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.102]


number of 74 with total loss:0.08073714903638332
------valation---------74
causal val accuracy:0.6982
-----training-------75


Epoch [75/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.425]


number of 75 with total loss:0.07893693975255459
------valation---------75
causal val accuracy:0.7096
-----training-------76


Epoch [76/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.558]


number of 76 with total loss:0.09076514726952661
------valation---------76
causal val accuracy:0.7236
-----training-------77


Epoch [77/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.338]


number of 77 with total loss:0.089406790518308
------valation---------77
causal val accuracy:0.7204
-----training-------78


Epoch [78/500]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s, loss=0.181]


number of 78 with total loss:0.08241099879570023
------valation---------78
causal val accuracy:0.7124
-----training-------79


Epoch [79/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.0286]


number of 79 with total loss:0.07600876712535001
------valation---------79
causal val accuracy:0.7124
-----training-------80


Epoch [80/500]: 100%|██████████| 79/79 [00:16<00:00,  4.75it/s, loss=0.378]


number of 80 with total loss:0.08034439461565093
------valation---------80
causal val accuracy:0.7124
-----training-------81


Epoch [81/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.113]


number of 81 with total loss:0.08263414143289946
------valation---------81
causal val accuracy:0.696
-----training-------82


Epoch [82/500]: 100%|██████████| 79/79 [00:15<00:00,  4.95it/s, loss=0.0482]


number of 82 with total loss:0.07667011260703395
------valation---------82
causal val accuracy:0.7142
-----training-------83


Epoch [83/500]: 100%|██████████| 79/79 [00:16<00:00,  4.71it/s, loss=0.0325]


number of 83 with total loss:0.07412116534866486
------valation---------83
causal val accuracy:0.741
-----training-------84


Epoch [84/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.00818]


number of 84 with total loss:0.07536846626832892
------valation---------84
causal val accuracy:0.709
-----training-------85


Epoch [85/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.039]


number of 85 with total loss:0.06671189546254994
------valation---------85
causal val accuracy:0.704
-----training-------86


Epoch [86/500]: 100%|██████████| 79/79 [00:16<00:00,  4.83it/s, loss=0.247]


number of 86 with total loss:0.0782227688881604
------valation---------86
causal val accuracy:0.722
-----training-------87


Epoch [87/500]: 100%|██████████| 79/79 [00:16<00:00,  4.80it/s, loss=0.018]


number of 87 with total loss:0.06955482981674656
------valation---------87
causal val accuracy:0.7238
-----training-------88


Epoch [88/500]: 100%|██████████| 79/79 [00:16<00:00,  4.83it/s, loss=0.0594]


number of 88 with total loss:0.07185587591692051
------valation---------88
causal val accuracy:0.737
-----training-------89


Epoch [89/500]: 100%|██████████| 79/79 [00:16<00:00,  4.82it/s, loss=0.414]


number of 89 with total loss:0.07260514084909912
------valation---------89
causal val accuracy:0.7206
-----training-------90


Epoch [90/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.031]


number of 90 with total loss:0.07009463596947585
------valation---------90
causal val accuracy:0.7132
-----training-------91


Epoch [91/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.0975]


number of 91 with total loss:0.06974122390339646
------valation---------91
causal val accuracy:0.7216
-----training-------92


Epoch [92/500]: 100%|██████████| 79/79 [00:16<00:00,  4.77it/s, loss=0.186]


number of 92 with total loss:0.07308389267683783
------valation---------92
causal val accuracy:0.7228
-----training-------93


Epoch [93/500]: 100%|██████████| 79/79 [00:16<00:00,  4.87it/s, loss=0.352]


number of 93 with total loss:0.08256406531539522
------valation---------93
causal val accuracy:0.7158
-----training-------94


Epoch [94/500]: 100%|██████████| 79/79 [00:16<00:00,  4.85it/s, loss=0.111]


number of 94 with total loss:0.07959463643028011
------valation---------94
causal val accuracy:0.7292
-----training-------95


Epoch [95/500]: 100%|██████████| 79/79 [00:16<00:00,  4.92it/s, loss=0.191]


number of 95 with total loss:0.06746492151736835
------valation---------95
causal val accuracy:0.7318
-----training-------96


Epoch [96/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.0586]


number of 96 with total loss:0.06753570581137945
------valation---------96
causal val accuracy:0.7054
-----training-------97


Epoch [97/500]: 100%|██████████| 79/79 [00:16<00:00,  4.80it/s, loss=0.171]


number of 97 with total loss:0.07409696140668437
------valation---------97
causal val accuracy:0.7272
-----training-------98


Epoch [98/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.194]


number of 98 with total loss:0.06848164364884171
------valation---------98
causal val accuracy:0.7226
-----training-------99


Epoch [99/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.0609]


number of 99 with total loss:0.06712956318655346
------valation---------99
causal val accuracy:0.7326
-----training-------100


Epoch [100/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.207]


number of 100 with total loss:0.06812271191679602
------valation---------100
causal val accuracy:0.7116
-----training-------101


Epoch [101/500]: 100%|██████████| 79/79 [00:15<00:00,  5.02it/s, loss=0.00866]


number of 101 with total loss:0.06466704057647457
------valation---------101
causal val accuracy:0.743
-----training-------102


Epoch [102/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.278]


number of 102 with total loss:0.06665549327065295
------valation---------102
causal val accuracy:0.7298
-----training-------103


Epoch [103/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.00208]


number of 103 with total loss:0.07071804929866538
------valation---------103
causal val accuracy:0.7274
-----training-------104


Epoch [104/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.432]


number of 104 with total loss:0.06493449354945105
------valation---------104
causal val accuracy:0.7288
-----training-------105


Epoch [105/500]: 100%|██████████| 79/79 [00:16<00:00,  4.87it/s, loss=0.348]


number of 105 with total loss:0.073067946713182
------valation---------105
causal val accuracy:0.7272
-----training-------106


Epoch [106/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.193]


number of 106 with total loss:0.07529074640921023
------valation---------106
causal val accuracy:0.7332
-----training-------107


Epoch [107/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.42]


number of 107 with total loss:0.06910296555467044
------valation---------107
causal val accuracy:0.728
-----training-------108


Epoch [108/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.026]


number of 108 with total loss:0.0633464454307775
------valation---------108
causal val accuracy:0.7268
-----training-------109


Epoch [109/500]: 100%|██████████| 79/79 [00:16<00:00,  4.78it/s, loss=0.00253]


number of 109 with total loss:0.05327216822276764
------valation---------109
causal val accuracy:0.7362
-----training-------110


Epoch [110/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.155]


number of 110 with total loss:0.06064795382037948
------valation---------110
causal val accuracy:0.7364
-----training-------111


Epoch [111/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.0248]


number of 111 with total loss:0.06511055311494612
------valation---------111
causal val accuracy:0.7154
-----training-------112


Epoch [112/500]: 100%|██████████| 79/79 [00:16<00:00,  4.76it/s, loss=0.0346]


number of 112 with total loss:0.061902322482223375
------valation---------112
causal val accuracy:0.7252
-----training-------113


Epoch [113/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.0781]


number of 113 with total loss:0.056735252955621936
------valation---------113
causal val accuracy:0.745
-----training-------114


Epoch [114/500]: 100%|██████████| 79/79 [00:16<00:00,  4.93it/s, loss=0.0254]


number of 114 with total loss:0.06323299205661574
------valation---------114
causal val accuracy:0.7354
-----training-------115


Epoch [115/500]: 100%|██████████| 79/79 [00:16<00:00,  4.86it/s, loss=0.00191]


number of 115 with total loss:0.05627174985365164
------valation---------115
causal val accuracy:0.732
-----training-------116


Epoch [116/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.278]


number of 116 with total loss:0.05976090915619007
------valation---------116
causal val accuracy:0.7356
-----training-------117


Epoch [117/500]: 100%|██████████| 79/79 [00:16<00:00,  4.90it/s, loss=0.0251]


number of 117 with total loss:0.05975731693302529
------valation---------117
causal val accuracy:0.7346
-----training-------118


Epoch [118/500]: 100%|██████████| 79/79 [00:16<00:00,  4.84it/s, loss=0.0484]


number of 118 with total loss:0.05822263279623246
------valation---------118
causal val accuracy:0.731
-----training-------119


Epoch [119/500]: 100%|██████████| 79/79 [00:15<00:00,  5.00it/s, loss=0.192]


number of 119 with total loss:0.05298227684784539
------valation---------119
causal val accuracy:0.744
-----training-------120


Epoch [120/500]: 100%|██████████| 79/79 [00:15<00:00,  4.97it/s, loss=0.0492]


number of 120 with total loss:0.05733627062055129
------valation---------120
causal val accuracy:0.7496
-----training-------121


Epoch [121/500]: 100%|██████████| 79/79 [00:16<00:00,  4.79it/s, loss=0.00435]


number of 121 with total loss:0.04985916144745071
------valation---------121
causal val accuracy:0.7452
-----training-------122


Epoch [122/500]: 100%|██████████| 79/79 [00:15<00:00,  5.06it/s, loss=0.0944]


number of 122 with total loss:0.05186266176213947
------valation---------122
causal val accuracy:0.7316
-----training-------123


Epoch [123/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.02]


number of 123 with total loss:0.052535158829598486
------valation---------123
causal val accuracy:0.7408
-----training-------124


Epoch [124/500]: 100%|██████████| 79/79 [00:16<00:00,  4.89it/s, loss=0.0959]


number of 124 with total loss:0.05648651084850861
------valation---------124
causal val accuracy:0.741
-----training-------125


Epoch [125/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.009]


number of 125 with total loss:0.05625690286389635
------valation---------125
causal val accuracy:0.7454
-----training-------126


Epoch [126/500]: 100%|██████████| 79/79 [00:15<00:00,  4.99it/s, loss=0.388]


number of 126 with total loss:0.05340759385719047
------valation---------126
causal val accuracy:0.7518
-----training-------127


Epoch [127/500]: 100%|██████████| 79/79 [00:16<00:00,  4.87it/s, loss=0.487]


number of 127 with total loss:0.0658671610599643
------valation---------127
causal val accuracy:0.733
-----training-------128


Epoch [128/500]: 100%|██████████| 79/79 [00:16<00:00,  4.91it/s, loss=0.095]


number of 128 with total loss:0.061198706843277224
------valation---------128
causal val accuracy:0.7548
-----training-------129


Epoch [129/500]: 100%|██████████| 79/79 [00:15<00:00,  4.98it/s, loss=0.222]


number of 129 with total loss:0.05518472437522834
------valation---------129
causal val accuracy:0.7716
-----training-------130


Epoch [130/500]: 100%|██████████| 79/79 [00:15<00:00,  4.96it/s, loss=0.000921]


number of 130 with total loss:0.05756665993249044
------valation---------130
causal val accuracy:0.7298
-----training-------131


Epoch [131/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.264]


number of 131 with total loss:0.05795770610057855
------valation---------131
causal val accuracy:0.7572
-----training-------132


Epoch [132/500]: 100%|██████████| 79/79 [00:15<00:00,  4.94it/s, loss=0.385]


number of 132 with total loss:0.05595901135732479
------valation---------132
causal val accuracy:0.7428
------test---------0
causal val accuracy:0.5241
