In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import glob
import os
import shutil
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim
import metrics
from scipy.stats import pearsonr

def check_modular(points, num_thetas=500):
    """Test a 2-D point cloud against the modularity inequalities.

    The points are demeaned. Then, for ``num_thetas`` directions
    ``w = (cos t, sin t)``, the square of the smallest projection
    ``(min_i w . p_i) ** 2`` is compared with the quadratic form ``w^T S w``.
    The angles ``t`` step evenly over [0, 2*pi) and are shifted by 0.01.
    ``S`` has the squared per-axis extreme deviation on its diagonal and
    minus the mean cross product off it. A negative difference for any
    direction marks the cloud as mixed.

    Args:
        points: array of shape ``(2, N)`` (dimensions x samples); the 2x2
            ``S`` and 2-vector ``w`` require exactly two rows.
        num_thetas: step count used to build the direction grid.

    Returns:
        dict with ``'mixed'`` (1 if any difference is negative, else 0),
        ``'S'`` (the 2x2 matrix), ``'ds_demean'`` (the demeaned points) and
        ``'diffs'`` (list with one difference per direction).
    """
    centred = points - np.mean(points, axis=1, keepdims=True)

    # Per axis: the smaller of |most negative| and |most positive| deviation.
    low = np.abs(np.min(centred, axis=1))
    high = np.abs(np.max(centred, axis=1))
    extent = np.minimum(low, high)

    cross = np.mean(centred[0, :] * centred[1, :])
    S = np.array([[extent[0] ** 2, -cross],
                  [-cross, extent[1] ** 2]])

    def gap(theta):
        # Squared smallest projection minus the critical value w^T S w.
        w = np.array([np.cos(theta), np.sin(theta)])
        return np.min(w @ centred) ** 2 - w @ S @ w

    thetas = np.arange(0, np.pi * 2, 2 * np.pi / num_thetas) + 0.01
    diffs = [gap(theta) for theta in thetas]

    return {'mixed': int(any(value < 0 for value in diffs)),
            'S': S,
            'ds_demean': centred,
            'diffs': diffs,
            }

# Experiment configuration.
N = 32                # samples per generated dataset (rows of np.random.rand(N, d))
d = 2                 # source dimensionality
n_repeats = 100       # datasets / training runs; even repeats modular, odd mixed
n_epochs = 1_000_000  # training epochs per run
batch_size = 32       # equals N, so each epoch is a single full batch
latent_dim = 16       # width of the autoencoder's latent layer

# Loss weights used in the training objective.
w_reg = 1e-4  # half squared L2 norm of the weight matrices
z_reg = 1e-3  # latent energy: half squared L2 norm of z
z_nn = 5e-1   # penalty on negative latent activity, sum of ReLU(-z)


class AE(torch.nn.Module):
    """Linear autoencoder: input -> latent -> reconstruction, no nonlinearity."""

    def __init__(self, input_dim, latent_dim):
        super().__init__()
        # Encoder then decoder, both single affine maps. The attribute names
        # matter: most_mixed_neuron reads ``model.linear1.weight``.
        self.linear1 = torch.nn.Linear(input_dim, latent_dim)
        self.linear2 = torch.nn.Linear(latent_dim, input_dim)

    def forward(self, x):
        """Return ``(x_hat, z)``: the reconstruction and the latent code."""
        latent = self.linear1(x)
        return self.linear2(latent), latent


def most_mixed_neuron(model):
    """Summarise how mixed the encoder's latent units are.

    Each row of ``model.linear1.weight`` holds one latent unit's weights on the
    two inputs. With ``a, b`` the absolute values of those weights:

    * a unit is pure when one of ``a, b`` is zero;
    * a unit is maximally mixed when ``a == b``;
    * a unit whose ``a + b`` is at most a tenth of the largest unit's is
      treated as inactive and ignored.

    Args:
        model: object with a ``linear1`` ``torch.nn.Linear`` whose weight has
            exactly two input columns (e.g. ``AE(2, latent_dim)``).

    Returns:
        dict with
            'av_mixed' / 'most_mixed': mean / max over active units of
                ``min(a, b) / (a + b)`` (in [0, 0.5]);
            'most_angle' / 'av_angle': max / mean over active units of the
                angle in radians between the unit's (a, b) vector and the
                nearest input axis (in [0, pi/4]).
        All four values are NaN when no unit is active (e.g. all-zero weights).

    Raises:
        ValueError: if the encoder weight does not have exactly two columns.
    """
    # .cpu() so models on an accelerator can be inspected as well.
    weights = torch.abs(model.linear1.weight).detach().cpu().numpy()
    if weights.ndim != 2 or weights.shape[1] != 2:
        raise ValueError(
            "most_mixed_neuron expects an encoder with 2 inputs, got weight "
            f"shape {weights.shape}"
        )
    a, b = weights.T
    total = a + b

    # Drop near-silent units; guard np.max against an empty weight matrix.
    threshold = np.max(total) / 10 if total.size else 0.0
    keep = total > threshold
    if not np.any(keep):
        return {'av_mixed': np.nan,
                'most_mixed': np.nan,
                'most_angle': np.nan,
                'av_angle': np.nan,
                }

    # Filter before dividing so inactive units cannot trigger 0/0 warnings.
    a, b, total = a[keep], b[keep], total[keep]
    mixed = np.minimum(a, b) / total

    # Angle of (a, b) from the b-axis in [0, pi/2]; arctan2 handles b == 0
    # without the divide-by-zero that arctan(a / b) would hit.
    angle = np.arctan2(a, b)
    # Fold to the distance from the nearest axis: range [0, pi/4].
    angle = np.minimum(angle, np.pi / 2 - angle)

    return {'av_mixed': np.mean(mixed),
            'most_mixed': np.max(mixed),
            'most_angle': np.max(angle),
            'av_angle': np.mean(angle),
            }

# Containers for the whole experiment. 'mod' collects runs whose dataset
# check_modular flagged as not mixed (mixed == 0); 'mix' collects runs whose
# dataset was flagged as mixed (mixed == 1). Every metric maps to a list with
# one entry appended per training run.
results = {
    condition: {
        metric: []
        for metric in (
            'av_mixed', 'most_mixed', 'most_angle', 'av_angle',
            'diffs', 'multiinfo', 'lcinfom',
            'pred_loss', 'z_loss', 'nn_loss', 'weight_loss',
        )
    }
    for condition in ('mod', 'mix')
}

# Main experiment loop. Each repeat:
#   1. rejection-samples a uniform N x d dataset until check_modular reports
#      the desired condition (even repeats end with mixed == 0, odd repeats
#      with mixed == 1);
#   2. records the dataset's normalised multi-information (metrics helpers);
#   3. trains a linear AE for n_epochs with reconstruction, latent-energy,
#      negative-activity and weight-L2 terms, logging neuron-mixing stats and
#      losses every 1000 epochs and linear CInfoM every 10000 epochs;
#   4. appends the run's records to results['mod'] or results['mix'].
for repeat in range(n_repeats): # you will want to parallelise this...
    # Even repeats: resample while the dataset is flagged mixed.
    if repeat % 2 == 0:
        mixed = 1
        while mixed == 1:
            dataset = np.random.rand(N, d).astype(np.float32)
            mod_res = check_modular(dataset.T)  # check_modular takes (dim, N)
            mixed = mod_res['mixed']
    # Odd repeats: resample while the dataset is NOT flagged mixed.
    else:
        mixed = 0
        while mixed == 0:
            dataset = np.random.rand(N, d).astype(np.float32)
            mod_res = check_modular(dataset.T)
            mixed = mod_res['mixed']

    sources = dataset
    # NOTE(review): corr_h is computed but never used below.
    corr_h = pearsonr(sources[:, 0], sources[:, 1])[0]
    # Presumably discretises each source column (bins='auto') so the
    # multi-information can be estimated; see the project's metrics module.
    sources = metrics.discretize_binning(sources, bins='auto')
    mi_h = metrics.normalized_multiinformation(sources)

    msg = "repeat={:.2f}, mixed={:.2f}, multi_info={:.2f}".format(repeat, mixed, mi_h)
    print(msg)

    # Per-run records. 'diffs' and 'multiinfo' describe the dataset; the other
    # lists grow every 1000 epochs ('lcinfom': every 10000 epochs).
    results_run = {'av_mixed': [],
                   'most_mixed': [],
                   'most_angle': [],
                   'av_angle': [],
                   'diffs': mod_res['diffs'],
                   'multiinfo': mi_h,
                   'lcinfom': [],
                   'pred_loss': [],
                   'z_loss': [],
                   'nn_loss': [],
                   'weight_loss': []
                   }

    my_dataset = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    loss = torch.nn.MSELoss()
    model = AE(d, latent_dim)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    for epoch in range(n_epochs):
        for X_batch in my_dataset:
            # Clear gradients by setting them to None.
            for param in model.parameters():
                param.grad = None
            x_hat, z = model(X_batch)
            # Reconstruction error (MSE is symmetric, so passing the target
            # first does not change the value).
            pred_loss = loss(X_batch, x_hat)
            # Latent energy: half squared L2 norm per sample, batch-averaged.
            z_loss = 0.5 * torch.mean(torch.sum(z ** 2, dim=1))
            # Negative-activity penalty: sum of ReLU(-z) per sample, batch-averaged.
            nn_loss = torch.mean(torch.sum(torch.nn.ReLU()(-z), dim=1))
            # Half squared L2 norm of all weight matrices (biases excluded).
            weight_l2 = 0.0
            for name, p in model.named_parameters():
                if 'weight' in name:
                    weight_l2 += 0.5 * (p ** 2).sum()
            loss_tot = pred_loss + z_reg * z_loss + z_nn * nn_loss + w_reg * weight_l2

            loss_tot.backward()
            optimizer.step()

            # Overwritten every batch; only the epoch's last batch is logged.
            losses = {'pred_loss': pred_loss.detach().numpy(),
                      'z_loss': z_loss.detach().numpy(),
                      'nn_loss': nn_loss.detach().numpy(),
                      'weight_loss': weight_l2.detach().numpy()}

        # Every 1000 epochs: encoder mixing statistics and the latest losses.
        if epoch % 1000 == 0:
            res = most_mixed_neuron(model)
            for key, val in res.items():
                results_run[key].append(val)
            for key, val in losses.items():
                results_run[key].append(val)
        # Every 10000 epochs: progress dot, then linear CInfoM between the last
        # batch's inputs and latents.
        if epoch % 10000 == 0:
            print('', end='.')
            # this is slow to compute
            # NOTE(review): z is from the forward pass before the last
            # optimizer.step(), so it lags the current weights by one step.
            # With the current settings (batch_size == N) the last batch is the
            # whole dataset.
            latents = z.detach().numpy()
            sources = X_batch.detach().numpy()
            lcinfom = metrics.compute_linear_metrics(sources, latents, 'continuous', 'continuous')
            results_run['lcinfom'].append(lcinfom['linear_cinfom'])

    # File the run under its dataset condition.
    for key, val in results_run.items():
        results['mix' if mixed == 1 else 'mod'][key].append(val)

    # Summary: last logged value of each time series, truncated to 7 chars.
    msg = "repeat={:.2f}, ".format(repeat) + ''.join(
        f'{key}={str(val[-1])[:7]}, ' for key, val in results_run.items() if key not in ['diffs', 'multiinfo'])
    print(msg)

# Persist every run's metrics. np.save stores the dict inside a pickled 0-d
# object array, so it must be reloaded with
# np.load(..., allow_pickle=True).item().
save_path = '.'  # choose path to save to
np.save(f"{save_path}/results_all.npy", results)

In [None]:
import os
import pickle
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from skimage.transform import resize
import itertools
from itertools import repeat 
from disentangled_rnn_utils import DotDict as Dd
from scipy import stats

import seaborn
seaborn.set_style('white')  # plain white background for all plots below

In [None]:
import os
# Reload the results dict written by np.save. It is stored as a pickled 0-d
# object array, hence allow_pickle=True and .item(); only load trusted files.
save_path = '.'
results = np.load(f"{save_path}/results_all.npy", allow_pickle=True).item()

In [None]:
# MOD VS MIXED : VS DIFF / MULTINFO
# For each selected metric, plot its final recorded value per run against
#   (1) the normalised source multi-information              -> nsmi.png
#   (2) the mixing energy gain from theory, -min(diffs) of
#       check_modular                                         -> energy.png
# Modular ('mod') and mixed ('mix') runs are separate labelled series, each
# with its own least-squares line.
no_axes = False  # True: hide spines, ticks and tick labels
num_plots = 4  # NOTE(review): not used below
s = 10  # scatter marker size
figsize = (2.5, 2)
# NOTE: the y-label and output filenames below are fixed, so adding keys here
# would overwrite nsmi.png / energy.png.
plot_keys = ('most_angle',)


def _scatter_with_fit(x, y, label, size):
    """Scatter (x, y) as a labelled series and overlay its degree-1 fit."""
    plt.scatter(x, y, label=label, s=size)
    x_line = np.unique(x)
    plt.plot(x_line, np.poly1d(np.polyfit(x, y, 1))(x_line))


def _finish_panel(xlabel, ylabel, x, y, filename, strip_axes):
    """Label the current figure, title it with Pearson r / p of (x, y), save it."""
    if strip_axes:
        ax = plt.gca()
        for side in ('top', 'right', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
        plt.tick_params(left=False, right=False, labelleft=False, labelright=False, bottom=False, top=False)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    corr = stats.pearsonr(x, y)  # computed once, used for both r and p
    plt.title('Correlation: ' + str(np.round(corr.statistic, 3)) + ' , ' + 'p=' + str(np.round(corr.pvalue, 3)))
    # The series carry labels, so show them; otherwise 'mod' and 'mix' can
    # only be told apart by default colour.
    plt.legend(fontsize='small', frameon=False)
    plt.savefig(filename, bbox_inches='tight', dpi=300)


for key in results['mix'].keys():
    if key not in plot_keys:
        continue
    # Last recorded value of the metric for every run (rows = runs).
    y_mod = np.array(results['mod'][key])[:, -1]
    y_mix = np.array(results['mix'][key])[:, -1]

    plt.figure(figsize=figsize)
    _scatter_with_fit(results['mod']['multiinfo'], y_mod, 'mod', s)
    x, y = results['mix']['multiinfo'], y_mix
    _scatter_with_fit(x, y, 'mix', s)
    # NOTE(review): the title's correlation uses the 'mix' series only;
    # confirm this is intended rather than pooling both series.
    _finish_panel('Normalised Source Multiinformation', "Most Mixed Neuron's Angle", x, y, 'nsmi.png', no_axes)

    plt.figure(figsize=figsize)
    # Mixing energy gain: minus the smallest per-direction difference.
    _scatter_with_fit(-np.min(np.array(results['mod']['diffs']), axis=1), y_mod, 'mod', s)
    x, y = -np.min(np.array(results['mix']['diffs']), axis=1), y_mix
    _scatter_with_fit(x, y, 'mix', s)
    _finish_panel('Mixing Energy Gain from Theory', "Most Mixed Neuron's Angle", x, y, 'energy.png', no_axes)