In [1]:
import os
import pickle
# Pickled dataset, addressed relative to the notebook's working directory.
DATA_PATH = '../data/single_channel_nonlinear_mixing_tri_circ.pickle'

# NOTE: unpickling can execute arbitrary code -- only load files from a trusted source.
# encoding='latin1' is forwarded to pickle.load; presumably the file was written
# by Python 2 -- TODO confirm.
with open(DATA_PATH, 'rb') as pickle_file:
    # Later cells index `data` and unpack each entry into (mixture, circle, triangle).
    data = pickle.load(pickle_file, encoding='latin1')

In [2]:
%config InlineBackend.figure_formats = ['svg']
import matplotlib.pyplot as plt
%matplotlib inline
from mpl_toolkits.axes_grid1 import ImageGrid
# Preview a single dataset entry: the mixture next to its two ground-truth sources.
idx = 0
sample, circle, triangle = data[idx]

fig = plt.figure(figsize=(6, 6))
grid = ImageGrid(fig, 111, nrows_ncols=(1, 3), axes_pad=0.1)

labels = ['Mixed', 'Circle', 'Triangle']
panels = (sample, circle, triangle)
for ax, label, im in zip(grid, labels, panels):
    # Title each panel and strip the ticks so only the image is shown.
    ax.set(title=label, xticks=[], yticks=[])
    ax.imshow(im, cmap='gray')

plt.show()

In [3]:
from omegaconf import OmegaConf
import torch

import sys
sys.path.append('..')
from experiments.triangles_circles import Experiment

# Root of the Hydra output directory that holds the trained runs.
# NOTE(review): this is an absolute, machine-specific path -- adjust it when
# running the notebook on another machine.
OUTPUTS_ROOT = '/home/maxja/Uni/self-supervised-bss-via-multi-encoder-ae/outputs'
# Checkpoint and config locations inside a run directory.
CKPT_SUBDIR = 'logs/my_experiment/version_0/checkpoints/'
HYDRA_CONFIG = '.hydra/config.yaml'

# Bigger network
model_path_1 = '/2024-03-12/14-43-53/'
ckpt_name_1 = 'last.ckpt'

# Only two encoders
model_path_2 = '/2024-03-12/14-40-16/'
ckpt_name_2 = 'last.ckpt'

# Reference implementation
model_path_3 = '/2024-03-12/14-40-04/'
ckpt_name_3 = 'last.ckpt'

# Pick the run to inspect HERE (single place to edit); every path below is
# derived from this choice.
selected_model_path, selected_ckpt_name = model_path_2, ckpt_name_2

# load model
EXPR_PATH = f'{OUTPUTS_ROOT}{selected_model_path}{CKPT_SUBDIR}{selected_ckpt_name}'
CONFIG_PATH = f'{OUTPUTS_ROOT}{selected_model_path}{HYDRA_CONFIG}'

# Run 2 is the only two-encoder model; every other run is treated as having
# three encoders. Compared against model_path_2 so the run id is not repeated
# as a second literal that could drift out of sync.
three_encoders = selected_model_path != model_path_2

config = OmegaConf.load(CONFIG_PATH)
experiment = Experiment.load_from_checkpoint(checkpoint_path=EXPR_PATH, config=config.experiment_config)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = experiment.model
model.to(device)
# Switch to evaluation mode; as the cell's last expression this also displays the model.
model.eval()

In [4]:
import torch
import matplotlib
import numpy as np

# Global figure styling: zero-width axes frames and tick marks, plus a
# dark-grey colour for text and axis labels.
line_width = 0
default_c = '#434343'
matplotlib.rcParams.update({
    'axes.linewidth': line_width,
    'ytick.major.width': line_width,
    'xtick.major.width': line_width,
    'axes.titlesize': 15,
    'axes.labelsize': 16,
    'text.color': default_c,
    'axes.labelcolor': default_c,
})

def min_max(x):
    """Linearly rescale ``x`` so its minimum maps to 0 and its maximum to 1.

    Args:
        x: array-like accepted by ``np.min`` / ``np.max``.

    Returns:
        ``(x - min) / (max - min)``. When every element of ``x`` is equal the
        range is zero; the result is then all zeros instead of the NaNs (and
        RuntimeWarning) a 0/0 division would produce.
    """
    lo = np.min(x)
    span = np.max(x) - lo
    if span == 0:
        # Constant input: x - lo is already all zeros, so divide by 1.
        span = 1
    return (x - lo) / span

for idx in range(10):
    sample, circle, triangle = data[idx]
    sample, circle, triangle = min_max(sample), min_max(circle), min_max(triangle)

    x = torch.tensor(sample, dtype=torch.float32).permute(2, 0, 1).unsqueeze_(0).to(device)
    pred, _ = model(x)
    pred = torch.sigmoid(pred).squeeze().unsqueeze(-1).detach().cpu().numpy()

    with torch.no_grad():
        z = model.encode(x)
        
        if three_encoders:
            z_a = [z[0], torch.zeros_like(z[1]), torch.zeros_like(z[0])]
            z_b = [torch.zeros_like(z[0]), z[1], torch.zeros_like(z[0])]
            z_c = [torch.zeros_like(z[0]), torch.zeros_like(z[0]), z[2]]
        else:
            z_a = [z[0], torch.zeros_like(z[1])]
            z_b = [torch.zeros_like(z[0]), z[1]]            
        
        y_a = model.decode(z_a)
        y_b = model.decode(z_b)

        if three_encoders:
            y_c = model.decode(z_c)
        
    x_pred_a = torch.sigmoid(y_a).squeeze().unsqueeze(-1).detach().cpu().numpy()
    x_pred_b = torch.sigmoid(y_b).squeeze().unsqueeze(-1).detach().cpu().numpy()

    if three_encoders:
        x_pred_c = torch.sigmoid(y_c).squeeze().unsqueeze(-1).detach().cpu().numpy()

    %config InlineBackend.figure_formats = ['svg']
    import matplotlib.pyplot as plt
    %matplotlib inline
    from mpl_toolkits.axes_grid1 import ImageGrid
    fig = plt.figure(figsize=(6, 6))
    grid = ImageGrid(fig, 111,
                    nrows_ncols=(2, 4),
                    axes_pad=0.15,
                    )

    labels = ['Mixed', 'Circle', 'Triangle']
    images = [sample, circle, triangle, None, pred, x_pred_c, x_pred_a, x_pred_b] if three_encoders else [sample, circle, triangle, None, pred, x_pred_a, x_pred_b] # order is switched for comparison reasons
    y_labels = ['True', 'Pred.']
    for i, (ax, im) in enumerate(zip(grid, images)):
        if i != 3:
            if i < len(labels):
                ax.set_title(labels[i])
            if i % 4 == 0:
                ax.set_ylabel(y_labels[(i)//4])
            if i+1 == len(images) and three_encoders:
                ax.set_title('(Dead Enc.)', color='gray', fontsize=12)
            ax.set_xticks([])
            ax.set_yticks([])
            ax.imshow(im, cmap='gray')

    plt.show()