In [1]:
import numpy as np
import torch
from torch.autograd import Variable
from math_support import graph_random_walk, convert_sequence_to_graph, compute_index_subsample, graph_random_walk_fixed_start
from data_loader import read_data, perform_ttv_split
from tqdm.auto import tqdm
import os

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import numpy as np
import torch
from torch.autograd import Variable

def extract_features_for_classifier(nnet, xembed_numpy, ylabel_numpy, eglist, 
                                    idx_train, idx_valid, idx_test, 
                                    nbatch=64):
    """Average message-passing outputs over random-walk subgraphs, per vertex.

    For every vertex, a walk is generated with
    ``graph_random_walk_fixed_start(eglist, nbatch, start_vertex)``, turned into
    a weighted subgraph by ``convert_sequence_to_graph``, and pushed through
    ``perform_message_passing``.  The output row of each subgraph vertex that
    ``compute_index_subsample`` reports for the train, validation or test index
    set is added to that vertex's running sum.  At the end each sum is divided
    by the number of contributions; a vertex that never received one keeps its
    original embedding and a warning is printed for it.

    Parameters
    ----------
    nnet : model handed to ``perform_message_passing``; it is switched to
        evaluation mode via ``nnet.eval()`` and left in that mode.
    xembed_numpy : ``np.ndarray`` (converted to a float64 tensor) or a tensor
        (cloned, dtype kept).  The first axis indexes vertices.
    ylabel_numpy : label container indexed with ``idx_train`` / ``idx_valid`` /
        ``idx_test``.
    eglist : forwarded unchanged to ``graph_random_walk_fixed_start``
        (presumably the edge list -- confirm against ``math_support``).
    idx_train, idx_valid, idx_test : index collections of the three splits.
    nbatch : forwarded to ``graph_random_walk_fixed_start`` (presumably the
        walk size -- confirm against ``math_support``).

    Returns
    -------
    dict with the averaged features and labels for all vertices
    (``'features'``, ``'labels'``), the per-split slices
    (``'train_*'``, ``'val_*'``, ``'test_*'``) and the three index collections
    (``'idx_train'``, ``'idx_valid'``, ``'idx_test'``).
    """
    nnet.eval()

    if isinstance(xembed_numpy, np.ndarray):
        xembed = torch.from_numpy(xembed_numpy).double()
    else:
        # Tensor input keeps its own dtype/device; clone so the caller's data is never touched.
        xembed = xembed_numpy.clone()

    num_vertices = xembed.shape[0]

    final_features = torch.zeros_like(xembed)
    # Same dtype/device as the sums so the final division broadcasts without
    # mixed-device or mixed-precision operands.
    feature_counts = torch.zeros(num_vertices, dtype=final_features.dtype,
                                 device=final_features.device)

    for start_vertex in tqdm(range(num_vertices)):
        random_walk_data = graph_random_walk_fixed_start(eglist, nbatch, start_vertex)

        wgraph_numpy, idx_node = convert_sequence_to_graph(random_walk_data)
        wgraph = torch.from_numpy(wgraph_numpy).double()

        # Positions inside idx_node that fall in each split
        # (as reported by compute_index_subsample).
        idx_subsample_train = compute_index_subsample(idx_node, idx_train)
        idx_subsample_valid = compute_index_subsample(idx_node, idx_valid)
        idx_subsample_test = compute_index_subsample(idx_node, idx_test)

        with torch.no_grad():
            subgraph_features = perform_message_passing(nnet, xembed, wgraph, idx_node)

        for idx_subset in (idx_subsample_train, idx_subsample_valid, idx_subsample_test):
            for local_idx in idx_subset:
                # Map the subgraph-local position back to the global vertex id.
                global_idx = idx_node[local_idx]
                final_features[global_idx] += subgraph_features[local_idx]
                feature_counts[global_idx] += 1

    visited = feature_counts > 0
    for i in torch.nonzero(~visited).flatten().tolist():
        print(f"Warning: Vertex {i} not included in any random walk. Using original features.")

    # Broadcast the per-vertex count over all trailing feature axes.  Counts of
    # unvisited vertices are clamped to 1 only to keep the division finite;
    # those rows are replaced by the original embedding right after.
    column_shape = (-1,) + (1,) * (final_features.dim() - 1)
    counts_col = feature_counts.clamp(min=1).reshape(column_shape)
    final_features = torch.where(visited.reshape(column_shape),
                                 final_features / counts_col,
                                 xembed)

    final_features_np = final_features.detach().cpu().numpy()

    features_dict = {
        'features': final_features_np,
        'labels': ylabel_numpy,

        'train_features': final_features_np[idx_train],
        'train_labels': ylabel_numpy[idx_train],

        'val_features': final_features_np[idx_valid],
        'val_labels': ylabel_numpy[idx_valid],

        'test_features': final_features_np[idx_test],
        'test_labels': ylabel_numpy[idx_test],

        'idx_train': idx_train,
        'idx_valid': idx_valid,
        'idx_test': idx_test
    }

    print("Features extracted successfully")
    print(f"Train set: {len(idx_train)} samples")
    print(f"Validation set: {len(idx_valid)} samples")
    print(f"Test set: {len(idx_test)} samples")

    return features_dict

def perform_message_passing(nnet, xembed, wgraph, idx_node):
    """Run ``nnet.nconv`` rounds of weighted message passing on one subgraph.

    Every entry ``wgraph[i, j] > 0`` is treated as a directed edge from local
    node ``i`` to local node ``j``.  In each round the message along an edge is
    ``nnet.get_edge_matrix(x_i, x_j) @ x_i``; node ``j`` receives the
    ``wgraph``-weighted mean of its incoming messages.  A node without any
    incoming edge ends the round with an all-zero row.  If the subgraph has no
    edges at all, the gathered features are returned unchanged.

    Parameters
    ----------
    nnet : object exposing an integer ``nconv`` and a method
        ``get_edge_matrix(source_features, target_features)`` that returns one
        square matrix per edge (it is fed to ``torch.bmm``).
    xembed : 2-D tensor of node features; only rows ``idx_node`` are read and
        the tensor itself is never modified.
    wgraph : weight matrix addressed by local node positions; only its leading
        ``len(idx_node) x len(idx_node)`` corner is consulted.
    idx_node : global row indices of the subgraph nodes, in local order.

    Returns
    -------
    Tensor with one row per entry of ``idx_node`` holding the updated features.
    """
    xmaped = xembed[idx_node].clone()
    num_subgraph_vertices = len(idx_node)

    adjacency = torch.as_tensor(wgraph)[:num_subgraph_vertices, :num_subgraph_vertices]

    # Row-major list of (source, target) pairs with a positive weight.
    edges = torch.nonzero(adjacency > 0)
    if edges.shape[0] == 0:
        return xmaped

    # The graph is fixed across rounds, so indices, weights and the per-target
    # normalisers are computed once, in the feature dtype and on its device.
    edge_weights = adjacency[edges[:, 0], edges[:, 1]].to(dtype=xmaped.dtype,
                                                          device=xmaped.device)
    source_indices = edges[:, 0].to(xmaped.device)
    target_indices = edges[:, 1].to(xmaped.device)

    weight_totals = torch.zeros(num_subgraph_vertices, dtype=xmaped.dtype,
                                device=xmaped.device)
    weight_totals.index_add_(0, target_indices, edge_weights)
    # Nodes with no incoming edge have a zero message sum; dividing by 1
    # instead of 0 leaves their row at zero rather than producing NaN.
    safe_totals = torch.where(weight_totals > 0, weight_totals,
                              torch.ones_like(weight_totals)).unsqueeze(1)

    # Message-passing iterations
    for _ in range(nnet.nconv):
        source_features = xmaped[source_indices]
        target_features = xmaped[target_indices]

        edge_matrices_batch = nnet.get_edge_matrix(source_features, target_features)

        # (E, d, d) @ (E, d, 1) -> (E, d): one message vector per edge.
        messages = torch.bmm(
            edge_matrices_batch,
            source_features.unsqueeze(2)
        ).squeeze(2)

        # Scatter-add the weighted messages onto their target nodes.
        aggregated = torch.zeros_like(xmaped)
        aggregated.index_add_(0, target_indices, edge_weights.unsqueeze(1) * messages)

        xmaped = aggregated / safe_totals

    return xmaped

In [3]:
# Load each benchmark graph once and keep everything the feature-extraction
# cell needs, keyed by dataset name.
datasets = {}

for dataset_name in ('CiteSeer', 'PubMed'):
    xembed, eglist, ylabel, ylprob, xsvd = read_data(
        embedding_dimension=1,
        dataset_name=dataset_name,
        eps=1.0e-6,
    )
    print(f'ylabel.shape = {ylabel.shape}')

    # 60/20/20 split over all vertices; the unpacking order here follows the
    # original cell (train, test, valid) -- confirm against perform_ttv_split.
    nsample = xembed.shape[0]
    idx_train, idx_ttest, idx_valid = perform_ttv_split(
        nsample, ftrain=0.6, fttest=0.2, fvalid=0.2
    )

    # Stored order: (xembed, ylabel, eglist, train, valid, test).
    datasets[dataset_name] = (xembed, ylabel, eglist, idx_train, idx_valid, idx_ttest)

ylabel.shape = (4230,)
ylabel.shape = (19717,)


In [4]:
# For every saved network: load it on CPU, extract averaged features on the
# dataset named in its file name, and dump each entry of the result as .npy.
NNET_DIR = "/home/ubuntu/simulations/nnet_folder"
FEATURES_DIR = "/home/ubuntu/simulations/classificator_features"

for filename in os.listdir(NNET_DIR):
    model_path = f"{NNET_DIR}/{filename}"

    # NOTE(review): weights_only=False unpickles arbitrary Python objects;
    # only point this at checkpoints from a trusted source.
    model = torch.load(model_path, map_location=torch.device('cpu'), weights_only=False)

    # The third '_'-separated token of the file name selects the dataset.
    dataset_name = filename.split('_')[2]
    xembed, ylabel, eglist, idx_train, idx_valid, idx_ttest = datasets[dataset_name]

    extracted_features_result = extract_features_for_classifier(
        model, xembed, ylabel, eglist,
        idx_train, idx_valid, idx_ttest,
        nbatch=64,
    )

    # Output folder = file name up to its first '.'.
    # NOTE(review): names containing extra dots (e.g. "lr0.01") get truncated
    # there and could collide -- confirm the checkpoint naming scheme.
    folder_name = filename.partition('.')[0]
    output_dir = f"{FEATURES_DIR}/{folder_name}"
    os.makedirs(output_dir, exist_ok=True)

    # np.save appends the ".npy" suffix to each key-named path.
    for key, arr in extracted_features_result.items():
        np.save(f"{output_dir}/{key}", arr)

100%|██████████| 19717/19717 [3:27:52<00:00,  1.58it/s]  


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples


100%|██████████| 4230/4230 [1:12:48<00:00,  1.03s/it]


Features extracted successfully
Train set: 2538 samples
Validation set: 846 samples
Test set: 846 samples


100%|██████████| 4230/4230 [1:11:59<00:00,  1.02s/it]


Features extracted successfully
Train set: 2538 samples
Validation set: 846 samples
Test set: 846 samples


100%|██████████| 4230/4230 [1:11:39<00:00,  1.02s/it]


Features extracted successfully
Train set: 2538 samples
Validation set: 846 samples
Test set: 846 samples


100%|██████████| 19717/19717 [3:37:49<00:00,  1.51it/s]  


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples


100%|██████████| 4230/4230 [58:32<00:00,  1.20it/s]  


Features extracted successfully
Train set: 2538 samples
Validation set: 846 samples
Test set: 846 samples


100%|██████████| 19717/19717 [3:29:36<00:00,  1.57it/s]  


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples


100%|██████████| 19717/19717 [1:26:08<00:00,  3.81it/s]


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples


100%|██████████| 19717/19717 [1:12:50<00:00,  4.51it/s]


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples


100%|██████████| 19717/19717 [1:12:44<00:00,  4.52it/s]


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples


100%|██████████| 19717/19717 [1:11:59<00:00,  4.57it/s]


Features extracted successfully
Train set: 11831 samples
Validation set: 3943 samples
Test set: 3943 samples
