In [1]:
import torch
from aligner import Aligner

## Get Data

In [2]:
from data_preparation import load_data_tensor

lr_train, lr_test, hr_train = load_data_tensor("dgl-icl")

lr_X_dim1 = torch.load('model_autoencoder/final_embeddings/encode_lr.pt')
lr_X_dim3 = torch.load('model_autoencoder/final_embeddings/encode_lr_3.pt')
hr_X_dim1 = torch.load('model_autoencoder/final_embeddings/encode_hr.pt')
hr_X_dim3 = torch.load('model_autoencoder/final_embeddings/encode_hr_3.pt')
lr_X_dim1_test = torch.load('model_autoencoder/final_embeddings/encode_lr_test.pt')
hr_X_dim3_test = torch.load('model_autoencoder/final_embeddings/encode_lr_test_3.pt')

## Specify Model

In [15]:
import torch
from torch.nn import Sequential, Linear, ReLU, Sigmoid, Tanh, Dropout, Upsample
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric
from torch_geometric.nn import NNConv
from torch_geometric.nn import GCNConv
from torch_geometric.nn import BatchNorm
import numpy as np
from torch_geometric.data import Data
from torch.autograd import Variable

N_SOURCE_NODES = 160
N_TARGET_NODES = 268

class Aligner(torch.nn.Module):
    def __init__(self):
        
        super(Aligner, self).__init__()

        nn = Sequential(Linear(1, N_SOURCE_NODES*N_SOURCE_NODES), ReLU())
        self.conv1 = NNConv(N_SOURCE_NODES, N_SOURCE_NODES, nn, aggr='mean', root_weight=True, bias=True)
        self.conv11 = BatchNorm(N_SOURCE_NODES, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        nn = Sequential(Linear(1, N_SOURCE_NODES), ReLU())
        self.conv2 = NNConv(N_SOURCE_NODES, 1, nn, aggr='mean', root_weight=True, bias=True)
        self.conv22 = BatchNorm(1, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)

        nn = Sequential(Linear(1, N_SOURCE_NODES), ReLU())
        self.conv3 = NNConv(1, N_SOURCE_NODES, nn, aggr='mean', root_weight=True, bias=True)
        self.conv33 = BatchNorm(N_SOURCE_NODES, eps=1e-03, momentum=0.1, affine=True, track_running_stats=True)


    def forward(self, data):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr

        x1 = F.sigmoid(self.conv11(self.conv1(x, edge_index, edge_attr)))
        x1 = F.dropout(x1, training=self.training)

        x2 = F.sigmoid(self.conv22(self.conv2(x1, edge_index, edge_attr)))
        x2 = F.dropout(x2, training=self.training)

        x3 = torch.cat([F.sigmoid(self.conv33(self.conv3(x2, edge_index, edge_attr))), x1], dim=1)
        x4 = x3[:, 0:N_SOURCE_NODES]
        x5 = x3[:, N_SOURCE_NODES:2*N_SOURCE_NODES]

        x6 = (x4 + x5) / 2
        return x6

def create_batch(X, A):
    data_list = []
    for x, adj in zip(X, A):
        edge_index = adj.nonzero().t()
        edge_weights = adj[edge_index[0], edge_index[1]]
        edge_index, edge_weights = torch_geometric.utils.add_self_loops(edge_index, edge_weights) # add self connections
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_weights.view(-1, 1))
        data_list.append(data)
    return data_list

def convert_generated_to_graph(data):
    """
        convert generated output from G to a graph
    """

    dataset = []

    for data in data1:
        counter = 0
        N_ROI = N_TARGET_NODES
        pos_edge_index = torch.zeros(2, N_ROI * N_ROI, dtype=torch.long)
        for i in range(N_ROI):
            for j in range(N_ROI):
                pos_edge_index[:, counter] = torch.tensor([i, j])
                counter += 1

        x = data
        pos_edge_index = torch.tensor(pos_edge_index, dtype=torch.long)
        data = Data(x=x, pos_edge_index= pos_edge_index, edge_attr=data.view(N_TARGET_NODES**2, 1))
        dataset.append(data)

    return dataset



In [16]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

aligner = Aligner()
aligner.to(DEVICE)
Aligner_optimizer = torch.optim.AdamW(aligner.parameters(), lr=0.025, betas=(0.5, 0.999))

## Choo Choo Train

In [17]:
aligner.train()

Aligner(
  (conv1): NNConv(160, 160, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=25600, bias=True)
    (1): ReLU()
  ))
  (conv11): BatchNorm(160)
  (conv2): NNConv(160, 1, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=160, bias=True)
    (1): ReLU()
  ))
  (conv22): BatchNorm(1)
  (conv3): NNConv(1, 160, aggr=mean, nn=Sequential(
    (0): Linear(in_features=1, out_features=160, bias=True)
    (1): ReLU()
  ))
  (conv33): BatchNorm(160)
)

In [20]:
lr_data = create_batch(lr_X_dim1, lr_train)
hr_data = create_batch(hr_X_dim1, hr_train)
N_TARGET_NODES = 268
nbre_epochs = 50

for epochs in range(nbre_epochs):
        # Train Generator
        with torch.autograd.set_detect_anomaly(True):
            Al_losses = []

            i = 0
            for data_source, data_target in zip(lr_data, hr_data):
                # print(i)
                # print(data_source.shape)
                # print(data_target.shape)
                # ************    Domain alignment    ************
                A_output = aligner(data_source)

                target = data_target
                target_mean = np.mean(target)
                target_std = np.std(target)

                d_target = torch.normal(target_mean, target_std, size=(1, N_SOURCE_NODES_F))
                dd_target = cast_data_vector_RH(d_target)
                dd_target = dd_target[0]
                target_d = dd_target.edge_attr.view(N_SOURCE_NODES, N_SOURCE_NODES)

                kl_loss = Alignment_loss(target_d, A_output)

                Al_losses.append(kl_loss)
                i += 1

            torch.cuda.empty_cache()
            Aligner_optimizer.zero_grad()
            Al_losses = torch.mean(torch.stack(Al_losses))
            Al_losses.backward(retain_graph=True)
            Aligner_optimizer.step()

        print("[Epoch: %d]| [Al loss: %f]|" % (epochs, Al_losses))


AttributeError: 'GlobalStorage' object has no attribute 'pos_edge_index'