In [None]:
import torch
import random
import math
from torch_geometric.data import DataLoader

In [None]:
from model import ccVAE
from utils import train

In [None]:
class Args():
    """Hyper-parameter container handed to ``ccVAE``.

    Defaults reproduce the original hard-coded configuration exactly; every field
    can now also be overridden by keyword (``z_dim``/``w_dim`` keep their original
    positional order).

    NOTE(review): in this notebook ``n_nodes`` (61) and ``x_dim`` (500) match the
    reshape ``data.x.view(bs, 61, 500)``, so they appear to be the number of graph
    nodes and the per-node signal length — confirm against ``ccVAE``.
    """

    def __init__(self, z_dim=10, w_dim=10, n_nodes=61, decoder_w_mode='conv',
                 cond_prior_w_mode='lookup', cond_prior_z_mode='lookup',
                 use_cuda=1, n_classes=2, x_dim=500, elbo_coefs=None):
        self.z_dim = z_dim
        self.w_dim = w_dim
        self.n_nodes = n_nodes
        self.decoder_w_mode = decoder_w_mode
        self.cond_prior_w_mode = cond_prior_w_mode
        self.cond_prior_z_mode = cond_prior_z_mode
        self.use_cuda = use_cuda
        self.n_classes = n_classes
        self.x_dim = x_dim
        # None sentinel: a fresh list per instance (a list default would be shared).
        self.elbo_coefs = [1., 2., 10.] if elbo_coefs is None else list(elbo_coefs)

In [None]:
# Model configuration: both latent sizes spelled out (these are the Args defaults).
args = Args(z_dim=10, w_dim=10)

In [None]:
# Build the ccVAE from the configuration and place its parameters on the GPU.
model = ccVAE(args).to('cuda')

In [None]:
# NOTE(review): torch.load unpickles the file — only load trusted data.
dataset = torch.load('data/low_corr')


def f(t):
    """Exponential rescaling applied to every adjacency entry.

    0.00909091 ~= 1/110, so f(t) ~= (121**t - 11) / 110: f(0.5) ~= 0 and f(1) ~= 1.
    """
    return -0.1 + 0.00909091 * 121 ** t


# Integer label code -> 3-bit feature vector. Only codes whose first bit is 1 are
# kept below, and that bit is then dropped, leaving 2 target columns (args.n_classes == 2).
feature_map = {0: [0, 0, 0], 1: [0, 0, 1], 2: [1, 0, 0], 3: [1, 0, 1], 4: [1, 1, 0], 5: [1, 1, 1]}


def get_features(label_code):
    """Look up the feature vector of a one-element label tensor."""
    return torch.tensor(feature_map[label_code.item()])


dataset_filtered = []
for data in dataset:
    data.adj = f(data.adj)
    # Subtract the identity from each adjacency matrix (size taken from the config, was 61).
    data.adj -= torch.eye(args.n_nodes).unsqueeze(0)
    data.y = get_features(data.y).unsqueeze(0).float()
    if data.y[:, 0] == 1:
        data.y = data.y[:, 1:]
        dataset_filtered.append(data)

# torch.manual_seed does NOT seed Python's `random` module, which performs the shuffle;
# seed both so the train/test split is reproducible across runs.
torch.manual_seed(12345)
random.seed(12345)
random.shuffle(dataset_filtered)

n_train = 8 * len(dataset_filtered) // 10  # 80/20 train/test split
train_dataset = dataset_filtered[:n_train]
test_dataset = dataset_filtered[n_train:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True, drop_last=True)
# Evaluate on (at most) the first 500 test graphs as one batch. Clamping the batch size
# keeps a test split smaller than 500 from yielding zero batches under drop_last=True.
assert len(test_dataset) > 0, "empty test split"
n_eval = min(500, len(test_dataset))
test_loader = DataLoader(test_dataset[:n_eval], batch_size=n_eval, shuffle=False, drop_last=True)

#all_test_loader = DataLoader(test_dataset, batch_size=10**10, shuffle=False, drop_last=False)

In [None]:
# Adam over all model parameters, fixed learning rate.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-4)

In [None]:
n_epochs = 100
eval_every = 5  # run the held-out loader every `eval_every` epochs
for epoch in range(n_epochs):
    # NOTE(review): the 4th argument presumably toggles training (True) vs. evaluation (False).
    train(model, optimizer, train_loader, True, 'cuda', 1)
    if epoch % eval_every == 0:
        qyxa = train(model, optimizer, test_loader, False, 'cuda', 1)

## Reconstruction

In [None]:
import torch.distributions as dist

In [None]:
# Take one training mini-batch and move it to the GPU for inspection.
first_batch = next(iter(train_loader))
data = first_batch.cuda()

In [None]:
with torch.no_grad():
    # Recompute, term by term, the quantities of the two ELBOs (for Z and for w) on one
    # batch `data`, keeping the decoder distributions px_za / pa_w for the plots below.
    bs = data.adj.shape[0]
    # Node count and per-node signal length from the config (were the literals 61 / 500).
    n_nodes, x_dim = args.n_nodes, args.x_dim

    # Z ~ q(Z|X), model.k reparameterised samples
    qz_x = dist.Normal(*model.encoder_x(data.x))
    z = qz_x.rsample([model.k])
    # Z = {Z_c, Z_\c}: split into the classify part and the style part
    zc, zs = z.split([model.z_classify, model.z_style], -1)
    # w ~ q(w|A), model.k reparameterised samples
    qw_a = dist.Normal(*model.encoder_a(data.adj))
    w = qw_a.rsample([model.k])
    # w = {w_c, w_\c}
    wc, ws = w.split([model.w_classify, model.w_style], -1)
    # log q(y|Z_c, w_c) and log q(y|X, A)
    qy_zc_wc = dist.ContinuousBernoulli(probs=model.classifier(zc, wc))
    y = qy_zc_wc.rsample()
    log_qy_zc_wc = qy_zc_wc.log_prob(y)
    log_qy_xa = model.classifier_loss(data)
    # log p(y): observed labels used as Continuous-Bernoulli probabilities
    log_py = dist.ContinuousBernoulli(data.y).log_prob(y)

    # --- ELBO for Z ---
    px_za = dist.Normal(*model.decoder_za(z, data.adj))
    log_px_za = px_za.log_prob(data.x.view(bs, n_nodes, x_dim))
    log_qz_x = qz_x.log_prob(z)

    # Prior on Z: label-conditioned for Z_c; model.zeros_z / model.ones_z
    # (presumably a standard normal) for Z_\c.
    locs_pzc_y, scales_pzc_y = model.cond_prior_z(data.y)
    prior_params_z = (torch.cat([locs_pzc_y, model.zeros_z.expand(bs, n_nodes, -1)], dim=-1),
                      torch.cat([scales_pzc_y, model.ones_z.expand(bs, n_nodes, -1)], dim=-1))

    log_pz_y = dist.Normal(*prior_params_z).log_prob(z)

    # NOTE(review): kl_z averages over the leading (sample) dim with .mean(0) but recon_x
    # does not; confirm this asymmetry is intended for the decoder output's shape.
    kl_z = (log_qz_x - log_pz_y).mean(0).sum(-1).sum(-1)
    recon_x = log_px_za.sum(-1).sum(-1)
    elbo_z = recon_x - kl_z

    # --- ELBO for w ---
    pa_w = dist.Normal(*model.decoder_w(w))
    log_pa_w = pa_w.log_prob(data.adj)
    log_qw_a = qw_a.log_prob(w)

    # Prior on w: label-conditioned for w_c; model.zeros_w / model.ones_w for w_\c.
    locs_pwc_y, scales_pwc_y = model.cond_prior_w(data.y)
    prior_params_w = (torch.cat([locs_pwc_y, model.zeros_w.expand(bs, -1)], dim=-1),
                      torch.cat([scales_pwc_y, model.ones_w.expand(bs, -1)], dim=-1))

    log_pw_y = dist.Normal(*prior_params_w).log_prob(w)

    kl_w = (log_qw_a - log_pw_y).mean(0).sum(-1)
    recon_a = log_pa_w.sum(-1).sum(-1)
    elbo_w = recon_a - kl_w

### Reconstruction of the EEG signals $X$

In [None]:
# Decoder mean (reconstruction) and the input signals, detached and moved to CPU for plotting.
x_recon = px_za.loc.detach().cpu()
x = data.x.view(bs, 61, 500).detach().cpu()

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_theme(context="notebook", style="darkgrid", palette="deep")  # seaborn's defaults, made explicit

In [None]:
sns.set_theme()
n_rows, n_cols = 3, 2
# One distinct random channel per panel, drawn from every channel present in x
# (was hard-coded to range(61)).
channels = random.sample(range(x.shape[1]), n_rows * n_cols)
fig, ax = plt.subplots(n_rows, n_cols, figsize=[10, 10], sharex=True)
for i in range(n_rows):
    for j in range(n_cols):
        # Random graph from the batch (was hard-coded to range(100), the train batch size).
        # NOTE(review): assumes x_recon is indexed (batch, channel, time) like x.
        k = random.randrange(x.shape[0])
        c = channels[i * n_cols + j]
        ax[i][j].plot(x_recon[k, c], linewidth=3, label='Reconstruction')
        ax[i][j].plot(x[k, c], alpha=0.5, label='EEG recording')
        ax[i][j].set(xticklabels=[], yticklabels=[])
fig.tight_layout()
# One shared legend, placed below the grid via the bottom-right panel.
ax[-1, -1].legend(prop={'size': 24}, loc='lower left',
                  bbox_to_anchor=(-0.85, -0.35), ncol=2)
#fig.savefig('recon_x_small', bbox_inches='tight')

In [None]:
n = 5

# n*n distinct random channels, one per panel. `sample` / `choice` were never imported
# (only the `random` module is), so the original cell raised NameError.
channels = random.sample(range(x.shape[1]), n * n)
fig, ax = plt.subplots(n, n, figsize=[20, 20], sharex=True)
for i in range(n):
    for j in range(n):
        # Random graph from the batch (was hard-coded to range(100)).
        k = random.randrange(x.shape[0])
        c = channels[i * n + j]
        ax[i][j].plot(x_recon[k, c], linewidth=3)  # reconstruction
        ax[i][j].plot(x[k, c], alpha=0.5)          # input signal
        ax[i][j].set(xticklabels=[], yticklabels=[])
fig.tight_layout()
#fig.savefig('recon_x')

### Reconstruction of the functional-connectivity (adjacency) matrices $A$

In [None]:
# Input adjacency matrices and the decoder mean, detached and moved to CPU for plotting.
a = data.adj.detach().cpu()
a_recon = pa_w.loc.detach().cpu()

In [None]:
# Two distinct random graphs from the batch (was hard-coded to range(100), the batch size).
inds = random.sample(range(a.shape[0]), 2)
sns.set_theme()
sns.set_style("whitegrid", {'axes.grid': False})
cmap = sns.color_palette("viridis", as_cmap=True)
# One fixed colour scale for every panel, so the single colour bar applies to all of them.
imshow_kw = dict(cmap=cmap, vmin=-0.1, vmax=1.)

fig = plt.figure(constrained_layout=True, figsize=[10, 10])
subfigs = fig.subfigures(2, 1, hspace=0.05, wspace=0.05)

# One grey sub-figure per selected graph: input matrix (left) vs. reconstruction (right).
for subfig, i in zip(subfigs, inds):
    axs = subfig.subplots(1, 2)
    im = axs[0].imshow(a[i], **imshow_kw)
    axs[1].imshow(a_recon[i], **imshow_kw)
    for panel in axs:
        panel.set(xticklabels=[], yticklabels=[])
    subfig.set_facecolor('0.9')

cbar_ax = fig.add_axes([1.03, 0.15, 0.05, 0.7])
fig.colorbar(im, cax=cbar_ax).ax.tick_params(labelsize=20)

# Column titles. The original plt.text calls drew into the current axes, which is
# cbar_ax here; the positions are in cbar_ax data coordinates, so target it explicitly.
cbar_ax.text(-20.1, -0.4, 'Functional Connectivity', fontsize=28)
cbar_ax.text(-8.9, -0.4, 'Reconstruction', fontsize=28)

#fig.savefig('recon_a_small', bbox_inches='tight')

## Densities of the latent $\omega$: posterior $q(\omega|A)$ vs. label-conditioned prior $p(\omega|y)$

In [None]:
# Every combination of the two binary labels, one row each.
y_grid = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]]).float().cuda()
bs = y_grid.shape[0]  # was a hard-coded 4 that had to match the number of rows above

with torch.no_grad():  # used for sampling/plotting only; no autograd graph needed
    # Prior on w: label-conditioned part for w_c, model.zeros_w / model.ones_w for the rest.
    locs_pwc_y, scales_pwc_y = model.cond_prior_w(y_grid)
    prior_params_w = (torch.cat([locs_pwc_y, model.zeros_w.expand(bs, -1)], dim=-1),
                      torch.cat([scales_pwc_y, model.ones_w.expand(bs, -1)], dim=-1))
    pw_y = dist.Normal(*prior_params_w)

In [None]:
import pandas as pd
import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
sns.set_context("paper", rc={"xtick.labelsize":32,"axes.labelsize":48})

n_samples = 100
# One label per latent dimension of w. With the .flip(-1) + reversed tiling below,
# dimension d is labelled labels[d]; dims 0 ('X') and 1 ('Y') are renamed for display.
labels = list("XYABCDEFGH")
w_dim = qw_a.batch_shape[-1]
assert len(labels) == w_dim, "need exactly one label per latent dimension of w"

# Posterior draws from q(w|A): (n_samples, batch, w_dim), flattened.
# Named w_post / w_prior so the EEG tensors `x` / `y` from earlier cells are not clobbered.
w_post = qw_a.sample([n_samples]).flip(-1).reshape(-1).cpu().detach().numpy()
# Prior draws from p(w|y): exactly enough to match the posterior's element count
# (previously the literal 6250*2//5 == 2500, valid only for a batch of 100 and 4 label rows).
prior_batch = pw_y.batch_shape.numel()
assert w_post.size % prior_batch == 0, "posterior size must be a multiple of the prior batch"
w_prior = pw_y.sample([w_post.size // prior_batch]).flip(-1).reshape(-1).cpu().detach().numpy()
hue = np.tile(labels[::-1], w_post.size // w_dim)

df = pd.DataFrame(dict(x=w_post, y=w_prior, g=hue))
df["g"] = df["g"].replace({'X': 'schizophrenia', 'Y': 'hallucinations'})
# Drop draws with |value| >= 6 in either column.
df = df[(df["x"].abs() < 6) & (df["y"].abs() < 6)]

# Initialize the FacetGrid object: one row per latent dimension
pal = sns.cubehelix_palette(len(labels), rot=-.25, light=.7)
g = sns.FacetGrid(df, row="g", hue="g", aspect=15, height=1.5, palette=pal)

# Prior draws ("y") in red with a narrow bandwidth
g.map(sns.kdeplot, "y",
      bw_adjust=.1, clip_on=False,
      fill=True, alpha=.5, linewidth=1.5, color='r')
# Posterior draws ("x") filled in the row's hue colour, then a white outline on top
g.map(sns.kdeplot, "x",
      bw_adjust=.5, clip_on=False,
      fill=True, alpha=.8, linewidth=1.5)
g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=2, bw_adjust=.5)

# passing color=None to refline() uses the hue mapping
g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)


# Define and use a simple function to label the plot in axes coordinates
def label(x, color, label):
    """Write this facet's ``label`` in bold at the left edge of the current axes.

    Intended as a ``FacetGrid.map`` callback: ``x`` (the mapped column) is accepted but
    unused; ``color`` and ``label`` are passed by the grid for each facet.
    """
    target_ax = plt.gca()
    target_ax.text(
        0, .4, label,
        transform=target_ax.transAxes,
        ha="left", va="center",
        fontweight="bold", fontsize=32, color=color,
    )


g.map(label, "x")

# Set the subplots to overlap
g.figure.subplots_adjust(hspace=-.4)

# Remove axes details that don't play well with overlap
g.set_titles("")
g.set(yticks=[], ylabel="")
# Raw string: "\o" is an invalid escape sequence in a normal literal (SyntaxWarning on 3.12+).
g.set(xlabel=r"$\omega$")
g.despine(bottom=True, left=True)

g.set(xlim=(-6.5, 6.5))

# FacetGrid.fig is a deprecated alias; use .figure as the overlap step above already does.
fig = g.figure
fig.savefig('latent_space', bbox_inches='tight')