In [None]:
import copy
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, TensorDataset

from evaluation_metric import evaluate_all
from src.evaluation import evaluate
from src.utils import MatrixVectorizer


In [None]:
# Seed every RNG the notebook relies on (Python, NumPy, PyTorch) so runs are reproducible.
random_seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(random_seed)

# Pick the compute device; on GPU also seed CUDA and force deterministic cuDNN kernels.
if not torch.cuda.is_available():
    DEVICE = torch.device("cpu")
    print("CUDA not available. Using CPU.")
else:
    DEVICE = torch.device("cuda")
    print("CUDA is available. Using GPU.")
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # every visible GPU
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [None]:
def get_datasets(lr_dim=160, hr_dim=268, data_root="../../RandomCV"):
    """Load the three RandomCV training splits and rebuild their adjacency matrices.

    Split ``k`` (1, 2, 3) lives in ``{data_root}/Train/Split{k}`` and holds a
    low-resolution CSV ``lr_cluster{A,B,C}.csv`` and a high-resolution CSV
    ``hr_cluster{A,B,C}_modified.csv``, one vectorised matrix per row.
    Negative values are clipped to 0 and NaNs replaced by 0. Each row is then
    turned back into a matrix with ``MatrixVectorizer.anti_vectorize`` and the
    identity matrix is added (self-loops).

    Every split is converted independently. Splits with different numbers of
    subjects are therefore kept in full rather than truncated to the shortest.

    Args:
        lr_dim (int): number of nodes in a low-resolution graph.
        hr_dim (int): number of nodes in a high-resolution graph.
        data_root (str): directory containing the ``Train/Split*`` folders.

    Returns:
        tuple: ``(lr_train1, lr_train2, lr_train3, hr_train1, hr_train2, hr_train3)``,
        each an ``np.ndarray`` stacking one matrix per subject of that split.

    Raises:
        ValueError: if a split's low- and high-resolution files contain
            different numbers of subjects (they must be paired row by row).
    """
    def _load_vectors(path):
        # Same cleaning for every file: drop negative weights, replace NaNs with 0.
        return pd.read_csv(path).clip(lower=0).fillna(0).values

    def _to_matrices(vectors, dim):
        # Rebuild each flattened row as a matrix and add self-loops.
        return np.array([
            MatrixVectorizer.anti_vectorize(vec, dim).astype(float) + np.eye(dim)
            for vec in vectors
        ])

    lr_splits, hr_splits = [], []
    for split_number, cluster in enumerate("ABC", start=1):
        split_dir = f"{data_root}/Train/Split{split_number}"
        lr_vecs = _load_vectors(f"{split_dir}/lr_cluster{cluster}.csv")
        hr_vecs = _load_vectors(f"{split_dir}/hr_cluster{cluster}_modified.csv")
        if len(lr_vecs) != len(hr_vecs):
            raise ValueError(
                f"Split{split_number}: {len(lr_vecs)} low-resolution rows but "
                f"{len(hr_vecs)} high-resolution rows"
            )
        lr_splits.append(_to_matrices(lr_vecs, lr_dim))
        hr_splits.append(_to_matrices(hr_vecs, hr_dim))

    return (*lr_splits, *hr_splits)


def get_svd_dataset(adjs, k=1):
    """Replace each adjacency matrix with its rank-``k`` truncated-SVD reconstruction.

    Args:
        adjs: iterable of 2-D arrays (e.g. a batched ``np.ndarray``).
        k (int): number of leading singular values/vectors to keep.

    Returns:
        np.ndarray stacking ``U[:, :k] @ diag(S[:k]) @ Vt[:k, :]`` for every input.
    """
    def _truncate(matrix):
        u, s, vt = np.linalg.svd(matrix)
        return u[:, :k] @ np.diag(s[:k]) @ vt[:k, :]

    return np.array([_truncate(matrix) for matrix in adjs])


# Load the three cross-validation splits: low-resolution inputs (lr_*) and
# high-resolution targets (hr_*), one array per split.
lr_train1, lr_train2, lr_train3, hr_train1, hr_train2, hr_train3 = get_datasets()

In [None]:
# Per-split model inputs (low-resolution graphs).
graph_items_train = [lr_train1, lr_train2, lr_train3]
# Per-split high-resolution ground truths. Despite the name these are HR targets,
# not a separate test set.
graph_items_test = [hr_train1, hr_train2, hr_train3]

for split_array in (*graph_items_train, *graph_items_test):
    print(split_array.shape)


In [None]:

class GCNConv(torch.nn.Module):
    """Linear graph convolution ``adjacency @ features @ W`` (no bias, no activation).

    The adjacency is used as given; no normalisation is applied here.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.in_dim, self.out_dim = in_dim, out_dim
        self.W = self.init_parameters()

    def init_parameters(self):
        """Create ``W`` (in_dim x out_dim) drawn uniformly from +/- 1/sqrt(out_dim)."""
        bound = 1.0 / np.sqrt(self.out_dim)
        weight = torch.empty(self.in_dim, self.out_dim).uniform_(-bound, bound)
        return nn.Parameter(weight)

    def forward(self, features, adjacency):
        """Aggregate neighbour features through ``adjacency``, then project with ``W``."""
        aggregated = adjacency @ features
        return aggregated @ self.W


class GCNBlock(torch.nn.Module):
    """One ``GCNConv`` layer followed by a tanh non-linearity."""

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Attribute names are part of the state_dict; keep them stable.
        self.conv = GCNConv(in_dim, out_dim)
        self.activation = nn.Tanh()

    def forward(self, features, adjacency):
        """Propagate ``features`` over ``adjacency`` and squash the result with tanh."""
        return self.activation(self.conv(features, adjacency))


class GraphSvdModel(torch.nn.Module):
    """Graph SVD model for brain graph super-resolution.

    Five GCN blocks encode the low-resolution graph. The node embeddings are
    summed into one graph-level vector, from which three tanh MLP heads decode
    the factors U (out_dim x rank), V (out_dim x rank) and the diagonal values
    S (rank). The output is ``U @ diag(S) @ V^T``.
    """

    def __init__(self, in_dim, gcn_hidden_dim, out_dim, rank):
        """
        Construct the encoder and the three decoder heads.

        Args:
            in_dim (int): the node attribute size
            gcn_hidden_dim (int): the hidden layers node attribute size
            out_dim (int): the number of nodes in the high resolution graph output
            rank (int): the number of singular values & vectors to use

        Returns:
            None
        """
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.rank = rank
        hid = gcn_hidden_dim

        # Submodules are created in the original order so that seeded weight
        # initialisation and state_dict keys are unchanged.

        # GCN encoder
        self.conv1 = GCNBlock(in_dim, hid)
        self.conv2 = GCNBlock(hid, hid)
        self.conv3 = GCNBlock(hid, hid)
        self.conv4 = GCNBlock(hid, hid)
        self.conv5 = GCNBlock(hid, hid)

        # U head: graph embedding -> flattened (out_dim x rank) matrix
        self.svd_u_fc1 = nn.Linear(hid, hid, bias=False)
        self.svd_u_fc2 = nn.Linear(hid, hid, bias=False)
        self.svd_u_proj = nn.Linear(hid, rank * out_dim, bias=False)

        # V head: graph embedding -> flattened (out_dim x rank) matrix
        self.svd_v_fc1 = nn.Linear(hid, hid, bias=False)
        self.svd_v_fc2 = nn.Linear(hid, hid, bias=False)
        self.svd_v_proj = nn.Linear(hid, rank * out_dim, bias=False)

        # S head: graph embedding -> rank diagonal values
        self.svd_sv_fc1 = nn.Linear(hid, hid, bias=False)
        self.svd_sv_fc2 = nn.Linear(hid, hid, bias=False)
        self.sv_proj = nn.Linear(hid, rank, bias=False)

    @staticmethod
    def _decode_head(x, first, second, proj):
        """Apply ``proj(tanh(second(tanh(first(x)))))``."""
        return proj(F.tanh(second(F.tanh(first(x)))))

    def forward(self, features, adjacency):
        """
        Run a forward pass of the model.

        Args:
            features (torch.tensor): batched node features
            adjacency (torch.tensor): batched adjacency matrices

        Returns:
            torch.tensor of batched super-resolved adjacency matrices
        """
        hidden = features
        for block in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            hidden = block(hidden, adjacency)

        # Graph-level embedding: sum over the node axis, keeping it as size 1.
        pooled = hidden.sum(dim=1, keepdim=True)

        u_flat = self._decode_head(pooled, self.svd_u_fc1, self.svd_u_fc2, self.svd_u_proj)
        v_flat = self._decode_head(pooled, self.svd_v_fc1, self.svd_v_fc2, self.svd_v_proj)
        s_vals = self._decode_head(pooled, self.svd_sv_fc1, self.svd_sv_fc2, self.sv_proj)

        batch = features.shape[0]
        u_mat = u_flat.reshape(batch, self.out_dim, self.rank)
        v_mat = v_flat.reshape(batch, self.out_dim, self.rank)
        # diag_embed adds a (rank x rank) diagonal; squeeze drops the pooled node axis.
        s_diag = torch.diag_embed(s_vals).squeeze(1)

        return u_mat @ s_diag @ v_mat.transpose(1, 2)

def test_model(model, test_adj, test_labels,
               source_res=160,feature_dim=50, HR_size=268):
    """
    Test the GAN AGSR model function

    :param model: The trained GAN model
    :param test_adj: The adjacency matrices of the test subjects
    :param test_labels: The labels of the test subjects
    :param args: The arguments for the model
    :return: The mean absolute error of the model on the test data
    """
    def cal_error(model_outputs, hr):
        return torch.nn.L1Loss(model_outputs, hr)

    model.eval()
    features = torch.ones(test_adj.shape[0], source_res, feature_dim).to(DEVICE)
    test_error = []
    predictions = []

    # TESTING
    with torch.no_grad():
            test_adj = torch.from_numpy(test_adj).type(torch.FloatTensor).to(DEVICE)
            test_labels = torch.from_numpy(test_labels).type(torch.FloatTensor).to(DEVICE)
            preds = model(features,test_adj)
            # evaluate_all(hr,preds)
            predictions.append(preds)
            error = torch.nn.functional.l1_loss(preds, test_labels)
            

    return error

def train_model(
    lr_train,
    hr_train,
    learning_rate=0.0001,
    batch_size=16,
    num_epochs=1000,
    rank=10,
    feature_dim=50,
    hidden_dim=50,
    source_res=160,
    target_res=268,
    step_size=500,
    gamma=0.1,
    test_adj=None,
    test_ground_truth=None,
    patience=300,
):
    """
    Main model training loop, with optional validation-based early stopping.

    Args:
        lr_train (np.ndarray): batched low-resolution training dataset
        hr_train (np.ndarray): batched high-resolution training dataset
        learning_rate (float): learning rate to use in the optimiser
        batch_size (int): mini-batch size
        num_epochs (int): number of epochs to train for
        rank (int): rank of approximation to use in model
        feature_dim (int): dimension of input node features
        hidden_dim (int): encoder hidden dimension
        source_res (int): number of nodes in low-resolution graph
        target_res (int): number of nodes in high-resolution graph
        step_size (int): number of epochs between learning-rate decays
            (the scheduler is stepped once per epoch)
        gamma (float): decay factor applied to the learning rate every
            ``step_size`` epochs
        test_adj (np.ndarray, optional): low-resolution validation matrices
        test_ground_truth (np.ndarray, optional): matching high-resolution
            validation targets; validation runs only when both are given
        patience (int): number of further non-improving validation epochs
            tolerated before stopping early

    Returns:
        model (GraphSvdModel): with validation data, the checkpoint with the
        lowest validation MAE; without it, the model after the final epoch.
    """
    model = GraphSvdModel(
        in_dim=feature_dim, gcn_hidden_dim=hidden_dim, out_dim=target_res, rank=rank
    ).to(DEVICE)

    print(
        f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}\n"
    )

    # Every node gets an all-ones feature vector; graph structure comes from the adjacency.
    node_features = torch.ones(lr_train.shape[0], source_res, feature_dim).to(DEVICE)
    lr_train_tensor = torch.tensor(lr_train, dtype=torch.float32).to(DEVICE)
    hr_train_tensor = torch.tensor(hr_train, dtype=torch.float32).to(DEVICE)
    train_loader = DataLoader(
        TensorDataset(node_features, lr_train_tensor, hr_train_tensor),
        batch_size=batch_size,
        shuffle=True,
    )

    loss_fn = torch.nn.L1Loss()
    optimizer = torch.optim.Adamax(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=step_size, gamma=gamma
    )

    has_validation = test_adj is not None and test_ground_truth is not None

    def _validate(candidate):
        # Forward the feature shape so validation inputs match training inputs.
        return float(test_model(candidate, test_adj, test_ground_truth,
                                source_res=source_res, feature_dim=feature_dim))

    best_mae = np.inf
    best_model = None
    early_stop_count = 0
    loss = None
    val_error = None  # stays None when no validation data is supplied

    for epoch in range(num_epochs):
        # Validation may switch the model to eval mode; always train in train mode.
        model.train()
        for batch_features, batch_lr, batch_hr in train_loader:
            optimizer.zero_grad()
            out = model(batch_features, batch_lr)
            loss = loss_fn(out, batch_hr)
            loss.backward()
            optimizer.step()

        if has_validation:
            val_error = _validate(model)
            if val_error < best_mae:
                best_mae = val_error
                early_stop_count = 0
                best_model = copy.deepcopy(model)
            elif early_stop_count >= patience:
                # Early stopping condition met
                chosen = best_model if best_model is not None else model
                print(f"Val Error: {_validate(chosen):.6f}")
                return chosen
            else:
                early_stop_count += 1

        if epoch % 100 == 0:
            train_loss = loss.item() if loss is not None else float("nan")
            print(f"Epoch {epoch}: \t loss: {train_loss} val loss: {val_error}")
        scheduler.step()

    return best_model if best_model is not None else model


def predict(model, lr_data, source_res=160, feature_dim=50, device=None):
    """Run ``model`` on a batch of low-resolution adjacency matrices.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards.

    Args:
        model: module called as ``model(features, adjacency)``.
        lr_data (np.ndarray): batched low-resolution adjacency matrices.
        source_res (int): number of nodes in each low-resolution graph.
        feature_dim (int): length of the all-ones node feature vector.
        device (torch.device, optional): defaults to the notebook-wide ``DEVICE``.

    Returns:
        np.ndarray of predictions clamped to [0, 1]. Note that ``squeeze(0)``
        drops the batch axis when ``lr_data`` holds a single matrix.
    """
    if device is None:
        device = DEVICE
    features = torch.ones(lr_data.shape[0], source_res, feature_dim, device=device)
    lr_tensor = torch.tensor(lr_data, dtype=torch.float32, device=device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            raw = model(features, lr_tensor)
    finally:
        model.train(was_training)
    return torch.clamp(raw.squeeze(0), 0, 1).cpu().numpy()


def write_predictions(filename, preds):
    """Write predictions as a Kaggle-style CSV.

    Each predicted matrix is flattened with ``MatrixVectorizer.vectorize``. The
    resulting vectors are concatenated in order and written under an
    ``ID,Predicted`` header, one value per row with 1-based IDs.
    """
    flat_values = np.array(
        [MatrixVectorizer.vectorize(matrix) for matrix in preds]
    ).flatten()

    with open(filename, "w") as outfile:
        outfile.write("ID,Predicted\n")
        outfile.writelines(
            f"{row_id},{value}\n" for row_id, value in enumerate(flat_values, start=1)
        )

In [None]:


# 3-fold cross-validation over the precomputed splits: each split is held out once
# while the model trains on the other two.
predictions = []    # per-fold predictions on the held-out split
ground_truths = []  # per-fold HR targets matching `predictions`
models = []         # per-fold trained models

print("Starting Cross Validation.")
for index in range(len(graph_items_test)):
    print(f"----- Fold {index} -----")

    # Held-out split. NOTE: `train` is the held-out LR *input* and `test` its HR
    # *target*; the names are kept for compatibility with later cells.
    train = graph_items_train[index]
    test = graph_items_test[index]

    # Training data: the remaining splits stacked along the subject axis.
    new_train = np.concatenate(
        [lr for split_idx, lr in enumerate(graph_items_train) if split_idx != index], axis=0
    )
    new_test = np.concatenate(
        [hr for split_idx, hr in enumerate(graph_items_test) if split_idx != index], axis=0
    )

    model = train_model(new_train, new_test, test_adj=train, test_ground_truth=test)
    models.append(model)
    preds = predict(model, train)
    predictions.append(preds)
    # Ground truth must be the held-out HR matrices that `preds` were predicted for,
    # not the training targets.
    ground_truths.append(test)
    evaluate_all(test, preds)

In [None]:
# Re-evaluate each fold's model on the held-out split it was validated on.
# `predict` builds the all-ones node features the model expects, runs a batched
# no-grad forward pass and clamps the outputs to [0, 1].
for fold_idx, (fold_model, held_out_lr, held_out_hr) in enumerate(
    zip(models, graph_items_train, graph_items_test)
):
    fold_model.eval()
    fold_preds = predict(fold_model, held_out_lr)
    print(f"Fold {fold_idx}")
    evaluate_all(held_out_hr, fold_preds)

In [None]:
def visualise_prediction(prediction, target, save_path="figures/prediction.pdf"):
    """Plot ground truth, prediction and their residual side by side, then save the figure.

    Args:
        prediction (np.ndarray): predicted adjacency matrix.
        target (np.ndarray): ground-truth adjacency matrix of the same shape.
        save_path (str): output file; its parent directory is created if missing.
    """
    fig, axs = plt.subplots(1, 3)
    panels = [
        (target, "Ground Truth"),
        (prediction, "Prediction"),
        (target - prediction, "Residual"),
    ]
    for ax, (matrix, title) in zip(axs, panels):
        ax.imshow(matrix)
        ax.set_title(title)

    fig.set_size_inches(15, 9)
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)  # savefig does not create directories
    fig.savefig(save_path, bbox_inches="tight")
    plt.show()


# Inspect one subject from the first fold's predictions.
i = 5
target = ground_truths[0][i]
prediction = predictions[0][i]
visualise_prediction(prediction, target)
print(f"MAE: {np.abs(target - prediction).mean()}")

In [None]:
# `visualise_prediction` is already defined in the previous cell. Redefining it here
# would silently shadow that definition, so this cell only re-runs the example.
i = 5
target = ground_truths[0][i]
prediction = predictions[0][i]
visualise_prediction(prediction, target)
print(f"MAE: {np.abs(target - prediction).mean()}")

In [None]:
# Plot a sample of predicted matrices from the second fold.
n_examples = 5
fig, axs = plt.subplots(1, n_examples)
for ax, pred_matrix in zip(axs, predictions[1][:n_examples]):
    ax.imshow(pred_matrix)
    ax.set_xticks([])
    ax.set_yticks([])
fig.set_size_inches(14, 8)
os.makedirs("figures", exist_ok=True)  # savefig does not create directories
fig.savefig("figures/prediction_examples.pdf", bbox_inches="tight")
plt.show()

In [None]:
# Compute evaluations (note this is extremely slow due to the betweeness centrality calculation)
evaluations = []
for i, (pred, gt) in enumerate(zip(predictions, ground_truths)):
    evaluation = evaluate(pred.shape[0], pred, gt)
    evaluations.append(evaluation)
    write_predictions(f"predictions_fold_{i+1}.csv", pred)
print(evaluations)

In [None]:
def plot_cv_results(evaluations):
    """Tabulate per-fold metrics, save them to ``evaluations.csv`` and plot them.

    Args:
        evaluations (list[dict]): one dict per fold, read via the keys "mae",
            "pcc", "js_dis", "avg_mae_pc", "avg_mae_ec" and "avg_mae_bc".
    """
    column_to_key = {
        "MAE": "mae",
        "PCC": "pcc",
        "JSD": "js_dis",
        "MAE (PC)": "avg_mae_pc",
        "MAE (EC)": "avg_mae_ec",
        "MAE (BC)": "avg_mae_bc",
    }
    rows = [
        {"Fold": fold, **{column: evaluation[key] for column, key in column_to_key.items()}}
        for fold, evaluation in enumerate(evaluations, start=1)
    ]
    df = pd.DataFrame(rows, columns=["Fold", *column_to_key])
    df.to_csv("evaluations.csv")

    # Distributional metrics and centrality (topological) metrics get separate figures.
    plot_folds(df[["MAE", "PCC", "JSD"]], "distribution_measures")
    plot_folds(df[["MAE (PC)", "MAE (EC)", "MAE (BC)"]], "centrality_measures")


def plot_folds(df, name):
    """Bar-plot each fold's metrics plus the cross-fold mean, and save the figure.

    There is one panel per row of ``df`` (one per fold) plus a final
    "Avg Across Folds" panel whose error bars are two standard deviations across
    folds. Panels are laid out two per row, so three folds produce the same 2x2,
    8x8-inch grid as before; other fold counts no longer raise IndexError.

    Args:
        df (pd.DataFrame): one row per fold, one column per metric.
        name (str): file stem; the figure is written to ``figures/{name}.pdf``.
    """
    colors = ["tab:blue", "tab:orange", "tab:green"]
    n_folds = len(df)
    n_panels = n_folds + 1
    n_cols = 2
    n_rows = (n_panels + n_cols - 1) // n_cols
    fig, axs = plt.subplots(n_rows, n_cols, squeeze=False)
    fig.set_size_inches((8, 4 * n_rows))
    panels = axs.ravel()

    for fold_idx in range(n_folds):
        panels[fold_idx].bar(df.columns.values, df.iloc[fold_idx, :], color=colors)
        panels[fold_idx].set_title(f"Fold {fold_idx + 1}")

    avg_ax = panels[n_folds]
    avg_ax.bar(
        df.columns.values, df.mean(), color=colors, yerr=2 * df.std().values, capsize=5
    )
    avg_ax.set_title("Avg Across Folds")

    # Hide leftover grid cells when the panel count is odd.
    for spare_ax in panels[n_panels:]:
        spare_ax.set_visible(False)

    os.makedirs("figures", exist_ok=True)  # savefig does not create directories
    fig.savefig(f"figures/{name}.pdf", bbox_inches="tight")
    plt.show()


# Tabulate, save and plot the per-fold evaluation metrics.
plot_cv_results(evaluations)

In [None]:
# Train final model on all the data for kaggle submission
# Train the final model on every subject from all three splits for the Kaggle submission.
lr_train = np.concatenate(graph_items_train, axis=0)  # all low-resolution inputs
hr_train = np.concatenate(graph_items_test, axis=0)   # all high-resolution targets
final_model = train_model(lr_train, hr_train)

In [None]:
# `get_datasets` does not load the Kaggle test inputs, so load them here with the
# same preprocessing as the training data: clip negatives, fill NaNs, add self-loops.
# NOTE(review): path taken from the commented-out loader in `get_datasets`; confirm it.
LR_TEST_PATH = "data/lr_test.csv"
lr_test_vec = pd.read_csv(LR_TEST_PATH).clip(lower=0).fillna(0).values
lr_test = np.array([
    MatrixVectorizer.anti_vectorize(vec, 160).astype(float) + np.eye(160)
    for vec in lr_test_vec
])

kaggle_preds = predict(final_model, lr_test, feature_dim=50)
print(kaggle_preds.min(), kaggle_preds.max())  # sanity check: should be between 0 and 1
write_predictions("predictions.csv", kaggle_preds)

In [None]:
# Compute the performance metrics for a random baseline model
# Performance metrics of a uniform-random baseline against every HR target.
hr_targets = np.concatenate(graph_items_test, axis=0)
random_baseline = np.random.rand(*hr_targets.shape)
print(random_baseline.shape)
evaluate(random_baseline.shape[0], random_baseline, hr_targets)