In [1]:
import sys
sys.path.append('/home/dldx/UniRep/pipgcn')
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from my_dataloader import GraphDataLoader, collate
from gin import GIN
from tqdm import tqdm
import random
import os
import time
import scipy.sparse as sp
import pickle
import dgl
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F

Using backend: pytorch


In [2]:
def train( net, trainloader, optimizer, criterion, epoch):
    net.train()
    running_loss = 0
    total_iters = len(trainloader)
    # setup the offset to avoid the overlap with mouse cursor
    bar = tqdm(range(total_iters), unit='batch', position=2, file=sys.stdout)

    for pos, (graphs, labels,names) in zip(bar, trainloader):
        # batch graphs will be shipped to device in forward part of model
        labels = labels.to(device)
        #feat = graphs.ndata['attr'].to(device)
        feat = graphs.ndata['feat'].float().to(device)
        output = net(graphs, feat)
#         outputs=output[0]
        loss = criterion(output, labels)
        running_loss += loss.item()

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # report
        bar.set_description('epoch-{}'.format(epoch))
    bar.close()
    # the final batch will be aligned
    running_loss = running_loss / total_iters

    return running_loss 

@torch.no_grad()
def eval_net( net, dataloader, criterion):
    net.eval()

    total = 0
    total_loss = 0
    total_correct = 0
    targets,outputs=[],[]
    for data in dataloader:
        graphs, labels,names = data
        feat = graphs.ndata['feat'].float().to(device)
        labels = labels.to(device)
        total += len(labels)
        output = net(graphs, feat)
#         outputs=output[0]
        _, predicted = torch.max(output.data, 1)

        total_correct += (predicted == labels.data).sum().item()
        loss = criterion(output, labels)
        # crossentropy(reduce=True) for default
        total_loss += loss.item() * len(labels)
        targets.append(labels)
        outputs.append(output)
    loss, acc = 1.0*total_loss / total, 1.0*total_correct / total

    return acc,targets,outputs

In [3]:
# set up seeds, args.seed supported
seed=2021
torch.manual_seed(seed=seed)
np.random.seed(seed=seed)

#指定GPU
torch.cuda.set_device(1)
if torch.cuda.is_available():

    device = torch.device("cuda")
    torch.cuda.manual_seed_all(seed=seed)
else:
    device = torch.device("cpu")
print(device)
    
with open("/home/dldx/UniRep/data/data45_true.p", 'rb') as f:                      
    dataset = pickle.load(f)
print(len(dataset))

cuda
45


In [4]:
#模型参数
start = time.time()
acc_scores=[]
datalist=[]
for fold_idx in range(10):
    model =GIN(5, 2, 17, 64, 3, 0.5, True, "sum", "sum").to(device)

    epochs=100
    criterion = nn.CrossEntropyLoss()  # defaul reduce is true
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)


    tbar = tqdm(range(epochs), unit="epoch", position=3, ncols=0, file=sys.stdout)
    vbar = tqdm(range(epochs), unit="epoch", position=4, ncols=0, file=sys.stdout)
    lrbar = tqdm(range(epochs), unit="epoch", position=5, ncols=0, file=sys.stdout)
    trainloader, validloader = GraphDataLoader(dataset, batch_size=45, device=device,
                                                collate_fn=collate, seed=2021, shuffle=True,
                                                split_name='fold10', fold_idx=fold_idx).train_valid_loader()


    valid_acc_tmp=0
    for epoch, _, _ in zip(tbar, vbar, lrbar):

        train( model, trainloader, optimizer, criterion, epoch)
        scheduler.step()

        
#---------------------------动态输出训练结果-----------------------------     
#         train_loss, train_acc = eval_net(model, trainloader, criterion)
#         train_losses.append(train_loss)
#         train_accs.append(train_acc)
        
#         tbar.set_description(
#             'train set - average loss: {:.4f}, accuracy: {:.0f}%'
#             .format(train_loss, 100. * train_acc))
        


        valid_acc,target,output = eval_net(model, validloader, criterion)
        targets=target
        outputs=output
        if valid_acc > valid_acc_tmp:
            valid_acc_tmp=valid_acc
        
    datalist.append([valid_acc_tmp,targets,outputs])


    
    tbar.close()
    vbar.close()
    lrbar.close()
    #    #保存模型
    #PATH='/home/dldx/UniRep/Model_Trained/gain_1440_5fold_'+str(fold_idx)
    #torch.save(model.state_dict(),PATH)
    print("第{}折完成,准确率{}".format(fold_idx,valid_acc_tmp)) 

loss_acc='gin_acc_5fold_45.p'
with open(loss_acc, 'wb') as f:
    pickle.dump(datalist, f)
end = time.time()
print("运行时间:%.2f秒"%(end-start))      
print("work down!")




  0% 0/100 [00:00<?, ?epoch/s][A[A[A



  0% 0/100 [00:00<?, ?epoch/s][A[A[A[A




  0% 0/100 [00:00<?, ?epoch/s][A[A[A[A[Atrain_set : test_set = 40 : 5


  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-0:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-0: 100%|██████████| 1/1 [00:00<00:00,  1.27batch/s][A[A



  1% 1/100 [00:00<01:31,  1.08epoch/s][A[A[A



  1% 1/100 [00:00<01:31,  1.08epoch/s][A[A[A[A




  1% 1/100 [00:00<01:31,  1.08epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-1:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-1: 100%|██████████| 1/1 [00:00<00:00,  6.21batch/s][A[A



  2% 2/100 [00:01<01:13,  1.34epoch/s][A[A[A



  2% 2/100 [00:01<01:13,  1.34epoch/s][A[A[A[A




  2% 2/100 [00:01<01:13,  1.34epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-2:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-2: 100%|██████████| 1/1 [00:00<00:00,  6.88batch/s][

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-48:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-48: 100%|██████████| 1/1 [00:00<00:00,  7.23batch/s][A[A



 49% 49/100 [00:14<00:14,  3.63epoch/s][A[A[A



 49% 49/100 [00:14<00:14,  3.63epoch/s][A[A[A[A




 49% 49/100 [00:14<00:14,  3.63epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-49:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-49: 100%|██████████| 1/1 [00:00<00:00,  7.07batch/s][A[A



 50% 50/100 [00:14<00:13,  3.62epoch/s][A[A[A



 50% 50/100 [00:14<00:13,  3.62epoch/s][A[A[A[A




 50% 50/100 [00:14<00:13,  3.62epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-50:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-50: 100%|██████████| 1/1 [00:00<00:00,  7.09batch/s][A[A



 51% 51/100 [00:14<00:13,  3.61epoch/s][A[A[A



 51% 51/100 [00:14<00:13,  3.61epoch/s][A[A[A[A




 51% 51/100 [00:14<00:13,  3.61epoch/s][A

epoch-96: 100%|██████████| 1/1 [00:00<00:00,  7.19batch/s][A[A



 97% 97/100 [00:27<00:00,  3.66epoch/s][A[A[A



 97% 97/100 [00:27<00:00,  3.66epoch/s][A[A[A[A




 97% 97/100 [00:27<00:00,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-97:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-97: 100%|██████████| 1/1 [00:00<00:00,  7.21batch/s][A[A



 98% 98/100 [00:27<00:00,  3.67epoch/s][A[A[A



 98% 98/100 [00:27<00:00,  3.67epoch/s][A[A[A[A




 98% 98/100 [00:27<00:00,  3.67epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-98:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-98: 100%|██████████| 1/1 [00:00<00:00,  7.10batch/s][A[A



 99% 99/100 [00:27<00:00,  3.66epoch/s][A[A[A



 99% 99/100 [00:27<00:00,  3.66epoch/s][A[A[A[A




 99% 99/100 [00:27<00:00,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-99:   0%|          | 0/1 [00:00<?, ?ba

epoch-44: 100%|██████████| 1/1 [00:00<00:00,  7.39batch/s][A[A



 45% 45/100 [00:12<00:14,  3.68epoch/s][A[A[A



 45% 45/100 [00:12<00:14,  3.68epoch/s][A[A[A[A




 45% 45/100 [00:12<00:14,  3.68epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-45:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-45: 100%|██████████| 1/1 [00:00<00:00,  7.11batch/s][A[A



 46% 46/100 [00:12<00:14,  3.66epoch/s][A[A[A



 46% 46/100 [00:12<00:14,  3.66epoch/s][A[A[A[A




 46% 46/100 [00:12<00:14,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-46:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-46: 100%|██████████| 1/1 [00:00<00:00,  7.39batch/s][A[A



 47% 47/100 [00:12<00:14,  3.66epoch/s][A[A[A



 47% 47/100 [00:12<00:14,  3.66epoch/s][A[A[A[A




 47% 47/100 [00:12<00:14,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-47:   0%|          | 0/1 [00:00<?, ?ba

 93% 93/100 [00:25<00:01,  3.64epoch/s][A[A[A[A




 93% 93/100 [00:25<00:01,  3.64epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-93:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-93: 100%|██████████| 1/1 [00:00<00:00,  7.21batch/s][A[A



 94% 94/100 [00:25<00:01,  3.64epoch/s][A[A[A



 94% 94/100 [00:25<00:01,  3.64epoch/s][A[A[A[A




 94% 94/100 [00:25<00:01,  3.64epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-94:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-94: 100%|██████████| 1/1 [00:00<00:00,  7.30batch/s][A[A



 95% 95/100 [00:25<00:01,  3.65epoch/s][A[A[A



 95% 95/100 [00:25<00:01,  3.65epoch/s][A[A[A[A




 95% 95/100 [00:25<00:01,  3.65epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-95:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-95: 100%|██████████| 1/1 [00:00<00:00,  7.28batch/s][A[A



 96% 96/100 [00:26<00:01,  3.66epoch/s

 41% 41/100 [00:11<00:16,  3.65epoch/s][A[A[A[A




 41% 41/100 [00:11<00:16,  3.65epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-41:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-41: 100%|██████████| 1/1 [00:00<00:00,  7.35batch/s][A[A



 42% 42/100 [00:11<00:15,  3.66epoch/s][A[A[A



 42% 42/100 [00:11<00:15,  3.66epoch/s][A[A[A[A




 42% 42/100 [00:11<00:15,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-42:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-42: 100%|██████████| 1/1 [00:00<00:00,  7.31batch/s][A[A



 43% 43/100 [00:11<00:15,  3.66epoch/s][A[A[A



 43% 43/100 [00:11<00:15,  3.66epoch/s][A[A[A[A




 43% 43/100 [00:11<00:15,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-43:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-43: 100%|██████████| 1/1 [00:00<00:00,  7.38batch/s][A[A



 44% 44/100 [00:12<00:15,  3.67epoch/s

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-89:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-89: 100%|██████████| 1/1 [00:00<00:00,  7.36batch/s][A[A



 90% 90/100 [00:24<00:02,  3.65epoch/s][A[A[A



 90% 90/100 [00:24<00:02,  3.65epoch/s][A[A[A[A




 90% 90/100 [00:24<00:02,  3.65epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-90:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-90: 100%|██████████| 1/1 [00:00<00:00,  7.13batch/s][A[A



 91% 91/100 [00:24<00:02,  3.64epoch/s][A[A[A



 91% 91/100 [00:24<00:02,  3.64epoch/s][A[A[A[A




 91% 91/100 [00:24<00:02,  3.64epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-91:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-91: 100%|██████████| 1/1 [00:00<00:00,  7.14batch/s][A[A



 92% 92/100 [00:25<00:02,  3.63epoch/s][A[A[A



 92% 92/100 [00:25<00:02,  3.63epoch/s][A[A[A[A




 92% 92/100 [00:25<00:02,  3.63epoch/s][A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-37:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-37: 100%|██████████| 1/1 [00:00<00:00,  7.24batch/s][A[A



 38% 38/100 [00:10<00:17,  3.63epoch/s][A[A[A



 38% 38/100 [00:10<00:17,  3.63epoch/s][A[A[A[A




 38% 38/100 [00:10<00:17,  3.63epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-38:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-38: 100%|██████████| 1/1 [00:00<00:00,  7.13batch/s][A[A



 39% 39/100 [00:10<00:16,  3.63epoch/s][A[A[A



 39% 39/100 [00:10<00:16,  3.63epoch/s][A[A[A[A




 39% 39/100 [00:10<00:16,  3.63epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-39:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-39: 100%|██████████| 1/1 [00:00<00:00,  7.13batch/s][A[A



 40% 40/100 [00:10<00:16,  3.62epoch/s][A[A[A



 40% 40/100 [00:10<00:16,  3.62epoch/s][A[A[A[A




 40% 40/100 [00:10<00:16,  3.62epoch/s][A

epoch-85: 100%|██████████| 1/1 [00:00<00:00,  7.28batch/s][A[A



 86% 86/100 [00:23<00:03,  3.65epoch/s][A[A[A



 86% 86/100 [00:23<00:03,  3.65epoch/s][A[A[A[A




 86% 86/100 [00:23<00:03,  3.65epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-86:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-86: 100%|██████████| 1/1 [00:00<00:00,  7.35batch/s][A[A



 87% 87/100 [00:23<00:03,  3.66epoch/s][A[A[A



 87% 87/100 [00:23<00:03,  3.66epoch/s][A[A[A[A




 87% 87/100 [00:23<00:03,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-87:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-87: 100%|██████████| 1/1 [00:00<00:00,  7.27batch/s][A[A



 88% 88/100 [00:24<00:03,  3.66epoch/s][A[A[A



 88% 88/100 [00:24<00:03,  3.66epoch/s][A[A[A[A




 88% 88/100 [00:24<00:03,  3.66epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-88:   0%|          | 0/1 [00:00<?, ?ba

epoch-33: 100%|██████████| 1/1 [00:00<00:00,  7.26batch/s][A[A



 34% 34/100 [00:09<00:18,  3.65epoch/s][A[A[A



 34% 34/100 [00:09<00:18,  3.65epoch/s][A[A[A[A




 34% 34/100 [00:09<00:18,  3.65epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-34:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-34: 100%|██████████| 1/1 [00:00<00:00,  7.09batch/s][A[A



 35% 35/100 [00:09<00:17,  3.63epoch/s][A[A[A



 35% 35/100 [00:09<00:17,  3.64epoch/s][A[A[A[A




 35% 35/100 [00:09<00:17,  3.64epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-35:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-35: 100%|██████████| 1/1 [00:00<00:00,  7.27batch/s][A[A



 36% 36/100 [00:09<00:17,  3.64epoch/s][A[A[A



 36% 36/100 [00:09<00:17,  3.64epoch/s][A[A[A[A




 36% 36/100 [00:09<00:17,  3.64epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-36:   0%|          | 0/1 [00:00<?, ?ba

 82% 82/100 [00:22<00:05,  3.54epoch/s][A[A[A[A




 82% 82/100 [00:22<00:05,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-82:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-82: 100%|██████████| 1/1 [00:00<00:00,  7.04batch/s][A[A



 83% 83/100 [00:22<00:04,  3.55epoch/s][A[A[A



 83% 83/100 [00:22<00:04,  3.55epoch/s][A[A[A[A




 83% 83/100 [00:22<00:04,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-83:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-83: 100%|██████████| 1/1 [00:00<00:00,  7.40batch/s][A[A



 84% 84/100 [00:23<00:04,  3.58epoch/s][A[A[A



 84% 84/100 [00:23<00:04,  3.58epoch/s][A[A[A[A




 84% 84/100 [00:23<00:04,  3.59epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-84:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-84: 100%|██████████| 1/1 [00:00<00:00,  7.10batch/s][A[A



 85% 85/100 [00:23<00:04,  3.59epoch/s

 30% 30/100 [00:08<00:19,  3.55epoch/s][A[A[A[A




 30% 30/100 [00:08<00:19,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-30:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-30: 100%|██████████| 1/1 [00:00<00:00,  7.16batch/s][A[A



 31% 31/100 [00:08<00:19,  3.56epoch/s][A[A[A



 31% 31/100 [00:08<00:19,  3.56epoch/s][A[A[A[A




 31% 31/100 [00:08<00:19,  3.56epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-31:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-31: 100%|██████████| 1/1 [00:00<00:00,  7.07batch/s][A[A



 32% 32/100 [00:09<00:19,  3.56epoch/s][A[A[A



 32% 32/100 [00:09<00:19,  3.56epoch/s][A[A[A[A




 32% 32/100 [00:09<00:19,  3.56epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-32:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-32: 100%|██████████| 1/1 [00:00<00:00,  7.02batch/s][A[A



 33% 33/100 [00:09<00:18,  3.56epoch/s

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-78:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-78: 100%|██████████| 1/1 [00:00<00:00,  7.01batch/s][A[A



 79% 79/100 [00:22<00:05,  3.56epoch/s][A[A[A



 79% 79/100 [00:22<00:05,  3.56epoch/s][A[A[A[A




 79% 79/100 [00:22<00:05,  3.56epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-79:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-79: 100%|██████████| 1/1 [00:00<00:00,  6.84batch/s][A[A



 80% 80/100 [00:22<00:05,  3.54epoch/s][A[A[A



 80% 80/100 [00:22<00:05,  3.54epoch/s][A[A[A[A




 80% 80/100 [00:22<00:05,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-80:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-80: 100%|██████████| 1/1 [00:00<00:00,  7.10batch/s][A[A



 81% 81/100 [00:22<00:05,  3.55epoch/s][A[A[A



 81% 81/100 [00:22<00:05,  3.55epoch/s][A[A[A[A




 81% 81/100 [00:22<00:05,  3.55epoch/s][A

  2% 2/100 [00:00<00:28,  3.43epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-2:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-2: 100%|██████████| 1/1 [00:00<00:00,  6.92batch/s][A[A



  3% 3/100 [00:00<00:28,  3.46epoch/s][A[A[A



  3% 3/100 [00:00<00:28,  3.46epoch/s][A[A[A[A




  3% 3/100 [00:00<00:28,  3.46epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-3:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-3: 100%|██████████| 1/1 [00:00<00:00,  7.14batch/s][A[A



  4% 4/100 [00:01<00:27,  3.49epoch/s][A[A[A



  4% 4/100 [00:01<00:27,  3.49epoch/s][A[A[A[A




  4% 4/100 [00:01<00:27,  3.49epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-4:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-4: 100%|██████████| 1/1 [00:00<00:00,  7.11batch/s][A[A



  5% 5/100 [00:01<00:27,  3.52epoch/s][A[A[A



  5% 5/100 [00:01<00:27,  3.52epoch/s][A[A[A[A




 

epoch-50: 100%|██████████| 1/1 [00:00<00:00,  7.25batch/s][A[A



 51% 51/100 [00:14<00:13,  3.57epoch/s][A[A[A



 51% 51/100 [00:14<00:13,  3.57epoch/s][A[A[A[A




 51% 51/100 [00:14<00:13,  3.57epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-51:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-51: 100%|██████████| 1/1 [00:00<00:00,  7.04batch/s][A[A



 52% 52/100 [00:14<00:13,  3.57epoch/s][A[A[A



 52% 52/100 [00:14<00:13,  3.56epoch/s][A[A[A[A




 52% 52/100 [00:14<00:13,  3.57epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-52:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-52: 100%|██████████| 1/1 [00:00<00:00,  6.88batch/s][A[A



 53% 53/100 [00:14<00:13,  3.55epoch/s][A[A[A



 53% 53/100 [00:14<00:13,  3.55epoch/s][A[A[A[A




 53% 53/100 [00:14<00:13,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-53:   0%|          | 0/1 [00:00<?, ?ba

 99% 99/100 [00:27<00:00,  3.56epoch/s][A[A[A[A




 99% 99/100 [00:27<00:00,  3.56epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-99:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-99: 100%|██████████| 1/1 [00:00<00:00,  7.11batch/s][A[A



100% 100/100 [00:28<00:00,  3.55epoch/s][A[A[A
 99% 99/100 [00:28<00:00,  3.51epoch/s]
 99% 99/100 [00:28<00:00,  3.51epoch/s]
第6折完成,准确率0.75



  0% 0/100 [00:00<?, ?epoch/s][A[A[A



  0% 0/100 [00:00<?, ?epoch/s][A[A[A[A




  0% 0/100 [00:00<?, ?epoch/s][A[A[A[A[Atrain_set : test_set = 41 : 4


  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-0:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-0: 100%|██████████| 1/1 [00:00<00:00,  7.05batch/s][A[A



  1% 1/100 [00:00<00:27,  3.58epoch/s][A[A[A



  1% 1/100 [00:00<00:27,  3.58epoch/s][A[A[A[A




  1% 1/100 [00:00<00:27,  3.58epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-1:   0%|     

 47% 47/100 [00:13<00:14,  3.55epoch/s][A[A[A[A




 47% 47/100 [00:13<00:14,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-47:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-47: 100%|██████████| 1/1 [00:00<00:00,  7.03batch/s][A[A



 48% 48/100 [00:13<00:14,  3.55epoch/s][A[A[A



 48% 48/100 [00:13<00:14,  3.55epoch/s][A[A[A[A




 48% 48/100 [00:13<00:14,  3.56epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-48:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-48: 100%|██████████| 1/1 [00:00<00:00,  7.06batch/s][A[A



 49% 49/100 [00:13<00:14,  3.56epoch/s][A[A[A



 49% 49/100 [00:13<00:14,  3.56epoch/s][A[A[A[A




 49% 49/100 [00:13<00:14,  3.56epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-49:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-49: 100%|██████████| 1/1 [00:00<00:00,  7.19batch/s][A[A



 50% 50/100 [00:14<00:14,  3.57epoch/s

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-95:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-95: 100%|██████████| 1/1 [00:00<00:00,  7.03batch/s][A[A



 96% 96/100 [00:27<00:01,  3.55epoch/s][A[A[A



 96% 96/100 [00:27<00:01,  3.55epoch/s][A[A[A[A




 96% 96/100 [00:27<00:01,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-96:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-96: 100%|██████████| 1/1 [00:00<00:00,  6.88batch/s][A[A



 97% 97/100 [00:27<00:00,  3.55epoch/s][A[A[A



 97% 97/100 [00:27<00:00,  3.55epoch/s][A[A[A[A




 97% 97/100 [00:27<00:00,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-97:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-97: 100%|██████████| 1/1 [00:00<00:00,  6.90batch/s][A[A



 98% 98/100 [00:27<00:00,  3.54epoch/s][A[A[A



 98% 98/100 [00:27<00:00,  3.54epoch/s][A[A[A[A




 98% 98/100 [00:27<00:00,  3.54epoch/s][A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-43:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-43: 100%|██████████| 1/1 [00:00<00:00,  6.96batch/s][A[A



 44% 44/100 [00:12<00:15,  3.55epoch/s][A[A[A



 44% 44/100 [00:12<00:15,  3.55epoch/s][A[A[A[A




 44% 44/100 [00:12<00:15,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-44:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-44: 100%|██████████| 1/1 [00:00<00:00,  7.02batch/s][A[A



 45% 45/100 [00:12<00:15,  3.55epoch/s][A[A[A



 45% 45/100 [00:12<00:15,  3.55epoch/s][A[A[A[A




 45% 45/100 [00:12<00:15,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-45:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-45: 100%|██████████| 1/1 [00:00<00:00,  7.05batch/s][A[A



 46% 46/100 [00:13<00:15,  3.55epoch/s][A[A[A



 46% 46/100 [00:13<00:15,  3.55epoch/s][A[A[A[A




 46% 46/100 [00:13<00:15,  3.56epoch/s][A

epoch-91: 100%|██████████| 1/1 [00:00<00:00,  7.02batch/s][A[A



 92% 92/100 [00:25<00:02,  3.54epoch/s][A[A[A



 92% 92/100 [00:25<00:02,  3.54epoch/s][A[A[A[A




 92% 92/100 [00:25<00:02,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-92:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-92: 100%|██████████| 1/1 [00:00<00:00,  7.02batch/s][A[A



 93% 93/100 [00:26<00:01,  3.55epoch/s][A[A[A



 93% 93/100 [00:26<00:01,  3.55epoch/s][A[A[A[A




 93% 93/100 [00:26<00:01,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-93:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-93: 100%|██████████| 1/1 [00:00<00:00,  6.96batch/s][A[A



 94% 94/100 [00:26<00:01,  3.55epoch/s][A[A[A



 94% 94/100 [00:26<00:01,  3.55epoch/s][A[A[A[A




 94% 94/100 [00:26<00:01,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-94:   0%|          | 0/1 [00:00<?, ?ba

epoch-39: 100%|██████████| 1/1 [00:00<00:00,  7.08batch/s][A[A



 40% 40/100 [00:11<00:16,  3.54epoch/s][A[A[A



 40% 40/100 [00:11<00:16,  3.54epoch/s][A[A[A[A




 40% 40/100 [00:11<00:16,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-40:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-40: 100%|██████████| 1/1 [00:00<00:00,  6.96batch/s][A[A



 41% 41/100 [00:11<00:16,  3.55epoch/s][A[A[A



 41% 41/100 [00:11<00:16,  3.55epoch/s][A[A[A[A




 41% 41/100 [00:11<00:16,  3.55epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-41:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-41: 100%|██████████| 1/1 [00:00<00:00,  6.83batch/s][A[A



 42% 42/100 [00:11<00:16,  3.54epoch/s][A[A[A



 42% 42/100 [00:11<00:16,  3.54epoch/s][A[A[A[A




 42% 42/100 [00:11<00:16,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-42:   0%|          | 0/1 [00:00<?, ?ba

 88% 88/100 [00:24<00:03,  3.54epoch/s][A[A[A[A




 88% 88/100 [00:24<00:03,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-88:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-88: 100%|██████████| 1/1 [00:00<00:00,  7.01batch/s][A[A



 89% 89/100 [00:25<00:03,  3.54epoch/s][A[A[A



 89% 89/100 [00:25<00:03,  3.54epoch/s][A[A[A[A




 89% 89/100 [00:25<00:03,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-89:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-89: 100%|██████████| 1/1 [00:00<00:00,  7.03batch/s][A[A



 90% 90/100 [00:25<00:02,  3.54epoch/s][A[A[A



 90% 90/100 [00:25<00:02,  3.54epoch/s][A[A[A[A




 90% 90/100 [00:25<00:02,  3.54epoch/s][A[A[A[A[A

  0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-90:   0%|          | 0/1 [00:00<?, ?batch/s][A[A

epoch-90: 100%|██████████| 1/1 [00:00<00:00,  7.06batch/s][A[A



 91% 91/100 [00:25<00:02,  3.55epoch/s

In [2]:
#K folds result analyse

loss_acc='gin_acc_5fold_45.p'
with open(loss_acc, 'rb') as f:
    all_info=pickle.load(f)

In [3]:
acc=[]
for i in range(10):
    acc.append(all_info[i][0])
print(np.mean(acc))

0.665


In [4]:
np.std(acc)

0.11191514642799696

In [5]:
from itertools import cycle
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import average_precision_score
from sklearn.metrics import precision_recall_fscore_support
all_rocauc=[]
AP=[]
f1=[]
f1_micro=[]
for k in range(5):
    
    score=(all_info[k][2][0].cpu().numpy())
    label=(all_info[k][1][0].cpu().numpy())
    label=label_binarize(label, classes=[0, 1, 2])
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    y_test,y_score=label,score
    for i in range(3):
        fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
    #print(roc_auc["micro"])
    all_rocauc.append(roc_auc["micro"])
    #=========================================
    Y_test,y_score=y_test,y_score
    n_classes=3
# For each class
    precision = dict()
    recall = dict()
    average_precision = dict()
    for i in range(n_classes):
        precision[i], recall[i], _ = precision_recall_curve(Y_test[:, i],
                                                        y_score[:, i])
        average_precision[i] = average_precision_score(Y_test[:, i], y_score[:, i])

    # A "micro-average": quantifying score on all classes jointly
    precision["micro"], recall["micro"], _ = precision_recall_curve(Y_test.ravel(),
        y_score.ravel())
    average_precision["micro"] = average_precision_score(Y_test, y_score,
                                                     average="micro")
    _, predicted = torch.max(torch.from_numpy(y_score), 1)
    predict=label_binarize(predicted, classes=[0, 1, 2])
    macro= precision_recall_fscore_support(Y_test, predict, average='macro')
    micro= precision_recall_fscore_support(Y_test, predict, average='micro')
#     macro=f1_score(Y_test, y_score, average='macro')
#     micro=f1_score(Y_test, y_score, average='micro')
    f1.append(macro[2])
    f1_micro.append(micro[2])
    AP.append(average_precision["micro"])
    print('Average precision score, micro-averaged over all classes: {0:0.2f}'
          .format(average_precision["micro"]))




print(np.mean(AP),np.std(AP),np.mean(f1),np.std(f1),np.mean(f1_micro),np.std(f1_micro),)

Average precision score, micro-averaged over all classes: 0.70
Average precision score, micro-averaged over all classes: 0.75
Average precision score, micro-averaged over all classes: 0.71
Average precision score, micro-averaged over all classes: 0.49
Average precision score, micro-averaged over all classes: 0.57
0.6451854811854811 0.09719645690239952 0.3755555555555556 0.12656916325242792 0.48 0.15999999999999998


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [7]:
macro

(0.5, 0.5, 0.4444444444444444, None)

In [8]:
micro

(0.6, 0.6, 0.6, None)

In [14]:
np.std(all_rocauc)

0.12887621240942798

In [None]:
train_acc=[]
valid_acc=[]
for i in range(100):
    d=data[i*2+1].split()
    train_acc.append(float(d[1]))
    valid_acc.append(float(d[3]))  

In [None]:
print(len(valid_acc))

In [None]:
plt.plot(train_acc, color='cornflowerblue')
plt.plot(valid_acc, color='darkorange')
labels=['train_acc','valid_acc']
#plt.xlim([0, 50])
#plt.ylim([0, 1])
plt.xlabel('epoch')
plt.ylabel('acc')
plt.title('Train_acc vs Valid_acc')
plt.legend(labels,loc="lower right")
plt.savefig('T_V.png',format='png',dpi=300)