In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


import functools
import enum
import os



from BH.data_loader import *
from BH.generate_data import *
from training_info import *
from Model_e import Model_e,Direction,Reduction
from Train import train,print_accuracies



os.environ["CUDA_VISIBLE_DEVICES"] = "0"
use_pretrained_weights = True  #@param{type:"boolean"}
hold_graphs_in_memory = False  #@param{type:"boolean"}

gb = 1024**3
total_memory = psutil.virtual_memory().total / gb
if total_memory < 20 and hold_graphs_in_memory:
    raise RuntimeError(f"It is unlikely your machine (with {total_memory}Gb) will have enough memory to complete the colab's execution!")

print("Loading input data...")
full_dataset, train_dataset, test_dataset = load_input_data(DIR_PATH)

  PyTreeDef = type(jax.tree_structure(None))


Loading input data...
Generating data from the directory /Data/Ptab/n=7


In [None]:
num_classes = 2
model = Model_e(
    num_layers=num_layers,
    num_features=num_features,
    num_classes=num_classes,
    direction=Direction.FORWARD,
    reduction=Reduction.SUM,
    apply_relu_activation=True,
    use_mask=False,
    share=False,
    message_relu=True,
    with_bias=True)
loss_val_gr = jax.value_and_grad(model.loss)
opt_init, opt_update = optax.adam(step_size)

In [3]:
from torch_geometric.utils import from_scipy_sparse_matrix

class CustomDataset(Dataset):
    def __init__(self, input_data):
        self.features = input_data.features
        self.labels = input_data.labels
        self.rows = input_data.rows
        self.columns = input_data.columns
        self.edge_types = input_data.edge_types

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        edge_index, _ = from_scipy_sparse_matrix(self.rows[idx])  # Assuming rows represent the adjacency matrix
        return {
            'x': torch.from_numpy(self.features[idx]),
            'edge_index': edge_index,
            'edge_attr': torch.tensor(self.edge_types[idx], dtype=torch.long),  # Assuming edge_types represent edge features
            'y': torch.from_numpy(self.labels[idx])
        }

ModuleNotFoundError: No module named 'torch_geometric'

In [None]:
for epoch in range(num_epochs):
    total_loss = 0
    model.train()
    for batch in loader:
        data = {k: v.to(device) for k, v in batch.items()}
        optimizer.zero_grad()
        # forward
        out = model(data['x'], data['edge_index'], data['edge_attr'])
        # assuming the labels are node-wise
        loss = F.cross_entropy(out, data['y'])
        # backward
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch: {epoch+1}, Loss: {total_loss / len(loader)}')

In [None]:
loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MPNN(node_in_channels, edge_in_channels, node_out_channels).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    total_loss = 0
    model.train()
    for data in loader:
        # data is a batch from the loader
        data = data.to(device)
        optimizer.zero_grad()
        # forward
        out = model(data.x, data.edge_index, data.edge_attr)
        # assuming the labels are node-wise and stored in data.y
        loss = F.cross_entropy(out, data.y)
        # backward
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f'Epoch: {epoch+1}, Loss: {total_loss / len(loader)}')

In [None]:
trained_params = model.net.init(
    jax.random.PRNGKey(42),
    features=train_dataset.features[0],
    rows=train_dataset.rows[0],
    cols=train_dataset.columns[0],
    batch_size=1,
    edge_types=train_dataset.edge_types[0])
trained_opt_state = opt_init(trained_params)

best_acc = None
for ep in range(1, num_epochs + 1):
    tr_data = list(
        zip(
            train_dataset.features,
            train_dataset.rows,
            train_dataset.columns,
            train_dataset.labels,
            train_dataset.edge_types,
        ))
    random.shuffle(tr_data)
    features_train, rows_train, cols_train, ys_train, edge_types_train = zip(
        *tr_data)

    features_train = list(features_train)
    rows_train = list(rows_train)
    cols_train = list(cols_train)
    ys_train = np.array(ys_train)
    edge_types_train = list(edge_types_train)

    for i in range(0, len(features_train), batch_size):
        b_features, b_rows, b_cols, b_ys, b_edges = batch_e(
            features_train[i:i + batch_size],
            rows_train[i:i + batch_size],
            cols_train[i:i + batch_size],
            ys_train[i:i + batch_size],
            edge_types_train[i:i + batch_size],
        )

        trained_params, trained_opt_state, curr_loss = train(
            loss_val_gr,
            opt_update,
            trained_params,
            trained_opt_state,
            b_features,
            b_rows,
            b_cols,
            b_ys,
            b_edges,
        )

        accs = model.accuracy(
            trained_params,
            b_features,
            b_rows,
            b_cols,
            b_ys,
            b_edges,
        )
        print(datetime.datetime.now(),
              f"Iteration {i:4d} | Batch loss {curr_loss:.6f}",
              f"Batch accuracy {accs:.2f}")

    print(datetime.datetime.now(), f"Epoch {ep:2d} completed!")

    # Calculate accuracy across full dataset once per epoch
    print(datetime.datetime.now(), f"Epoch {ep:2d}       | ", end="")
    test_acc = print_accuracies(model,trained_params, test_dataset, train_dataset)
    if best_acc == None or best_acc < test_acc:
        best_acc = test_acc
        if save_trained_weights and best_acc > 0.6:
            with open(PARAM_FILE, 'wb') as f:
                pickle.dump(trained_params, f)

# In[ ]: