In [1]:
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from my_dataloader import GraphDataLoader, collate
from gin import GIN
from tqdm import tqdm
import time
import pickle as pkl
from torch.utils.data import DataLoader

Using backend: pytorch


In [None]:
# Path of the pickled dataset that the rest of the script evaluates on.
dataset_name = '/home/dldx/gin/data_predict.p'

# NOTE(review): pickle can execute arbitrary code while loading; only point
# this at dataset files that come from a trusted source.
with open(dataset_name, 'rb') as dataset_file:
    data2700 = pkl.load(dataset_file)

# Report how many entries were loaded.
num_samples = len(data2700)
print(num_samples)

def train(net, trainloader, optimizer, criterion, epoch):
    """Run one optimisation pass over ``trainloader`` and return the mean loss.

    Args:
        net: Model called as ``net(graphs, feat)``; switched to train mode.
        trainloader: Iterable yielding ``(graphs, labels, names)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
        epoch: Epoch number, used only for the progress-bar caption.

    Returns:
        Sum of the per-batch loss values divided by ``len(trainloader)``.

    Note:
        Tensors are moved to the module-level ``device``.
    """
    net.train()
    num_batches = len(trainloader)
    loss_sum = 0

    # position=2 offsets the bar from the other progress bars on screen.
    progress = tqdm(range(num_batches), unit='batch', position=2, file=sys.stdout)

    # Zipping with the bar advances it by one tick per batch.
    for _step, (graphs, labels, _names) in zip(progress, trainloader):
        # The batched graph itself is shipped to the device inside the
        # model's forward pass; only labels and node features move here.
        targets = labels.to(device)
        node_feat = graphs.ndata['feat'].float().to(device)

        logits = net(graphs, node_feat)
        batch_loss = criterion(logits, targets)
        loss_sum += batch_loss.item()

        # Backpropagation and parameter update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        progress.set_description('epoch-{}'.format(epoch))

    progress.close()

    # Average over the number of batches reported by the loader.
    return loss_sum / num_batches

@torch.no_grad()
def eval_net(net, dataloader, criterion):
    """Evaluate ``net`` on ``dataloader`` without tracking gradients.

    Args:
        net: Model called as ``net(graphs, feat)``; switched to eval mode.
        dataloader: Iterable yielding ``(graphs, labels, names)`` batches.
        criterion: Loss callable invoked as ``criterion(output, labels)``.

    Returns:
        A tuple ``(acc, targets, outputs)``: the fraction of correctly
        predicted samples, the list of per-batch label tensors, and the
        list of per-batch model outputs.

    Note:
        Tensors are moved to the module-level ``device``. A sample-weighted
        mean loss is computed as well but is not part of the return value.
    """
    net.eval()
    sample_count = 0
    correct_count = 0
    weighted_loss_sum = 0
    targets, outputs = [], []

    for graphs, labels, _names in dataloader:
        feat = graphs.ndata['feat'].float().to(device)
        labels = labels.to(device)
        batch_size = len(labels)
        sample_count += batch_size

        output = net(graphs, feat)

        # Index of the largest score along dim 1 is the predicted class.
        predicted = torch.max(output.data, 1)[1]
        correct_count += (predicted == labels.data).sum().item()

        # Weight each batch loss by its size so the mean is per sample.
        batch_loss = criterion(output, labels)
        weighted_loss_sum += batch_loss.item() * batch_size

        targets.append(labels)
        outputs.append(output)

    # Computed for parity with the accuracy, but intentionally not returned.
    _mean_loss = 1.0 * weighted_loss_sum / sample_count
    acc = 1.0 * correct_count / sample_count

    return acc, targets, outputs

# Fix the RNG seeds (torch and numpy) so runs are repeatable.
seed = 2021
torch.manual_seed(seed=seed)
np.random.seed(seed=seed)

# Select the compute device. The GPU index is only set once CUDA is known to
# be available: calling torch.cuda.set_device() on a machine without CUDA
# raises, which previously made the CPU fallback below unreachable.
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    device = torch.device("cuda")
    torch.cuda.manual_seed_all(seed=seed)
else:
    device = torch.device("cpu")
print(device)

# Wall-clock start time, used for the run-time report at the end of the script.
start = time.time()

# Evaluate one saved checkpoint on the validation split of each of 10 folds.
#
# The model/optimizer/progress-bar setup that used to precede the loop was
# overwritten on the first iteration without ever being used, so it has been
# dropped; the per-fold setup below leaves the same module-level names bound.

# Read the checkpoint once instead of once per fold. map_location makes a
# checkpoint saved on GPU loadable when the script falls back to CPU.
checkpoint_path = '/home/dldx/gin/gin21_onehot9'
state_dict = torch.load(checkpoint_path, map_location=device)

epochs = 100
for fold_idx in range(10):
    # Fresh model per fold; its weights are replaced by the checkpoint below.
    model = GIN(5, 2, 21, 64, 3, 0.5, True, "sum", "sum").to(device)

    criterion = nn.CrossEntropyLoss()  # default reduction averages over the batch
    # NOTE(review): optimizer, scheduler and the three progress bars are not
    # used by the evaluation below; they are kept only so the module-level
    # names stay defined — confirm whether anything else relies on them.
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

    tbar = tqdm(range(epochs), unit="epoch", position=3, ncols=0, file=sys.stdout)
    vbar = tqdm(range(epochs), unit="epoch", position=4, ncols=0, file=sys.stdout)
    lrbar = tqdm(range(epochs), unit="epoch", position=5, ncols=0, file=sys.stdout)

    # Build the train/validation loaders for this fold of the 10-fold split.
    trainloader, validloader = GraphDataLoader(data2700, batch_size=64, device=device,
                                                collate_fn=collate, seed=2021, shuffle=True,
                                                split_name='fold10', fold_idx=fold_idx).train_valid_loader()

    model.load_state_dict(state_dict)
    # Prints the (accuracy, targets, outputs) tuple returned by eval_net.
    print(eval_net(model, validloader, criterion))

end = time.time()
print("运行时间:%.2f秒"%(end-start))
print("work down!")