In [23]:
# Enable IPython's autoreload so edits to the imported project modules
# (vae, trainer) are picked up live without restarting the kernel.
# Running this cell a second time in the same kernel only prints an
# "already loaded" notice.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [24]:
# Imports
import torch
import torch.nn as nn
import pandas as pd
import copy
import matplotlib.pyplot as plt
import numpy as np

In [25]:
from vae import VAE
from vae import GroupSoftmax
from trainer import Trainer

In [26]:
# Device config: use CUDA when PyTorch reports it as available, otherwise the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f"Device: {device}")

Device: cuda


In [27]:
# File names only; the directories are prepended where the files are opened:
# data_name is read as a CSV from /workspace/data/ and model_name is loaded
# as a state dict from /workspace/models/.
data_name = 'one_hot_noNaNremoved.csv'
model_name = 'model_agep_fNaN.pth'


In [28]:
# load data
data = pd.read_csv(f'/workspace/data/{data_name}')
data_tensor = torch.tensor(data.values, dtype=torch.float32)

# Derive the one-hot group layout from the "<feature>:<label>" column names:
# `cols` holds the feature prefix of every column (duplicates included),
# `onehot_counts` maps each distinct prefix to the number of columns it owns.
cols = list(data.columns)
cols = [col.split(":")[0] for col in cols]

onehot_counts = {
    # dict.fromkeys de-duplicates while keeping first-seen (column) order, so
    # each prefix is counted once instead of once per column.
    col: int(data.columns.str.startswith(f"{col}:").sum())
    for col in dict.fromkeys(cols)
}
group_sizes = list(onehot_counts.values())

# The model input width is the number of data columns; take it from the data
# instead of hard-coding it, and make sure the groups cover every column.
n_features = data_tensor.shape[1]
assert sum(group_sizes) == n_features, (
    f"group sizes sum to {sum(group_sizes)} but the data has {n_features} columns"
)

# load model
# 1500 is the width of the encoder's linear layers (see the printed module);
# NOTE(review): the meaning of 6 and 500 is defined in vae.py — confirm there.
model = VAE(n_features, 1500, 6, 500, group_sizes)
# map_location='cpu' lets a checkpoint that was saved from a GPU be loaded on
# a machine without CUDA; the model itself is kept on the CPU either way.
state_dict = torch.load(f'/workspace/models/{model_name}', map_location='cpu')
model.load_state_dict(state_dict)
model.eval()

VAE(
  (encoder): Encoder(
    (layers): Sequential(
      (0): MLPBlock(
        (layers): Sequential(
          (0): Linear(in_features=526, out_features=1500, bias=True)
          (1): BatchNorm1d(1500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU()
        )
      )
      (1): ResidualBlock(
        (mlp1): MLPBlock(
          (layers): Sequential(
            (0): Linear(in_features=1500, out_features=1500, bias=True)
            (1): BatchNorm1d(1500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
        (mlp2): MLPBlock(
          (layers): Sequential(
            (0): Linear(in_features=1500, out_features=1500, bias=True)
            (1): BatchNorm1d(1500, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): ReLU()
          )
        )
      )
      (2): ResidualBlock(
        (mlp1): MLPBlock(
          (layers): Sequential(
            (0): Linea

In [29]:
# Draw as many samples from the model as there are rows in the data set.
# NOTE(review): no random seed is set anywhere in the visible cells, so the
# drawn samples (and the plots built from them) presumably differ from run to run.
n_samples = len(data_tensor)

# Inference only: no_grad() keeps autograd from recording a graph for the
# n_samples forward passes, which would only cost memory here.
with torch.no_grad():
    sample = model.pretrain_sample(n_samples)

    # apply the group softmax to the data
    # NOTE(review): data_tensor comes from the one-hot CSV; if GroupSoftmax is
    # a plain per-group softmax, its output on 0/1 inputs is not one-hot any
    # more — confirm this is the intended "ground truth".
    group_softmax = GroupSoftmax(group_sizes)
    data_softmax = group_softmax(data_tensor)

# Column-wise means over all rows, one value per column.
predicted = torch.mean(sample, dim=0).detach().cpu().numpy()
ground_truth = torch.mean(data_softmax, dim=0).detach().cpu().numpy()

print(f'ground_truth: {ground_truth.shape}')
print(f'predicted: {predicted.shape}')


ground_truth: (526,)
predicted: (526,)


In [30]:
# prepare dicts and namings for plotting
# Every column is named "<feature>:<label>". Split on the first ':' only, so a
# label that itself contains ':' stays in one piece.
labels = [column.split(":", 1) for column in data.columns]

feature_label_dict = {}  # feature -> its labels, in column order
indices = {}             # feature -> positional column index of each label

# One pass fills both mappings; the enumerate position is the column's
# location in `data`, so no per-column lookup is needed.
for position, (key, label) in enumerate(labels):
    feature_label_dict.setdefault(key, []).append(label)
    indices.setdefault(key, []).append(position)

print(feature_label_dict)
print(indices)

{'TEN': ['owned or mortgaged', 'rented', 'nan'], 'HHL': ['english', 'spanish', 'other indo-european', 'asian and pacific island languages', 'other', 'nan'], 'VEH': ['no vehicles', '1 vehicle', '2 vehicles', '3 vehicles', '4 or more vehicles', 'nan'], 'HINCP': ['under 5k', '5k-10k', '10k-15k', '15k-20k', '20k-25k', '25k-35k', '35k-50k', '50k-75k', '75k-100k', '100k-150k', '150k+', 'nan'], 'R65': ['no', 'yes', 'nan'], 'R18': ['no', 'yes', 'nan'], 'SEX_1': ['male', 'female', 'nan'], 'SEX_2': ['male', 'female', 'nan'], 'SEX_3': ['male', 'female', 'nan'], 'SEX_4': ['male', 'female', 'nan'], 'SEX_5': ['male', 'female', 'nan'], 'SEX_6': ['male', 'female', 'nan'], 'SEX_7': ['male', 'female', 'nan'], 'SEX_8': ['male', 'female', 'nan'], 'SEX_9': ['male', 'female', 'nan'], 'SEX_10': ['male', 'female', 'nan'], 'SEX_11': ['male', 'female', 'nan'], 'SEX_12': ['male', 'female', 'nan'], 'SEX_13': ['male', 'female', 'nan'], 'SEX_14': ['male', 'female', 'nan'], 'SEX_15': ['male', 'female', 'nan'], 'SEX_

In [31]:
def _empty_buckets(reference_group):
    """Return {label: []} for every label of `reference_group`."""
    return {name: [] for name in feature_label_dict[reference_group]}


# One bucket dict per feature family, mapping label -> list of column indices.
# The numbered families (SEX_n, SCHL_n, AGEP_n) take their labels from the
# "_1" group.
ten_dict = _empty_buckets('TEN')
hhl_dict = _empty_buckets('HHL')
veh_dict = _empty_buckets('VEH')
hincp_dict = _empty_buckets('HINCP')
sex_dict = _empty_buckets('SEX_1')
education_dict = _empty_buckets('SCHL_1')
age_dict = _empty_buckets('AGEP_1')
r18_dict = _empty_buckets('R18')
r65_dict = _empty_buckets('R65')

# create dicts for the groups and their indices
# Ordered routing table: the first token contained in the group name decides
# which bucket dict receives that group's column indices.
_family_routing = (
    ('SEX', sex_dict),
    ('SCHL', education_dict),
    ('AGEP', age_dict),
    ('TEN', ten_dict),
    ('HHL', hhl_dict),
    ('VEH', veh_dict),
    ('HINCP', hincp_dict),
    ('R18', r18_dict),
    ('R65', r65_dict),
)

for key, labels in feature_label_dict.items():
    target = next(
        (buckets for token, buckets in _family_routing if token in key),
        None,
    )
    if target is None:
        # Group name matches none of the families: nothing to record.
        continue
    for slot, label in enumerate(labels):
        target[label].append(indices[key][slot])








In [32]:
def expected_count(data, dict):
    """Average the 'nan'-renormalised value of every label over all rows.

    For each label other than 'nan', and for each row of ``data``, the values
    at the label's columns are divided by ``1 - nan_value`` of the matching
    'nan' column. Positions whose 'nan' value equals 1 are left out. The
    remaining ratios are averaged within the row, and the row results are then
    averaged over all rows that had at least one usable position.

    Parameters
    ----------
    data : 2-D array-like, one row per sample.
    dict : mapping label -> list of column indices into a row. It must contain
        the key 'nan', and every list is expected to be as long as
        ``dict['nan']`` (position j of a label pairs with position j of 'nan').
        The name shadows the builtin but is kept so existing calls still work.

    Returns
    -------
    numpy.ndarray of shape (number of non-'nan' labels,), in the key order of
    ``dict``. An entry is NaN only if no row had a usable position.

    Raises
    ------
    KeyError
        If ``dict`` has no 'nan' entry.
    """
    # Skip 'nan' wherever it sits instead of stopping at it, so labels that
    # come after it are still computed.
    categories = [key for key in dict if key != 'nan']
    nan_cols = dict['nan']
    result = np.full(len(categories), np.nan)
    if len(data) == 0:
        return result

    data = np.asarray(data)
    nan_prob = data[:, nan_cols].astype(np.float64)

    # rule out the nans equal to 1 (their denominator would be zero)
    valid = nan_prob != 1
    n_valid = valid.sum(axis=1)
    has_valid = n_valid > 0
    if not has_valid.any():
        return result

    # Placeholder values keep the masked-out positions and empty rows from
    # triggering divide-by-zero warnings; they never reach the result.
    denom = np.where(valid, 1.0 - nan_prob, 1.0)
    safe_n_valid = np.maximum(n_valid, 1)

    for i, key in enumerate(categories):
        values = data[:, dict[key]].astype(np.float64)
        renormalised = np.where(valid, values / denom, 0.0)
        per_row = renormalised.sum(axis=1) / safe_n_valid
        # Rows without any usable position carry no information; leave them
        # out rather than letting them turn the whole average into NaN.
        result[i] = per_row[has_valid].mean()
    return result


In [33]:
def plot_group(group_name, ground_truth, predicted, label_dict, dict,
               out_dir='/workspace/GNN_Project/plots'):
    """Save a grouped bar chart comparing ground truth and prediction.

    Parameters
    ----------
    group_name : key into ``label_dict``; also used as the plot title and as
        the file name (``<group_name>.png``).
    ground_truth, predicted : 2-D arrays passed to ``expected_count``.
    label_dict : mapping group name -> list of labels. It is not modified.
    dict : mapping label -> list of column indices, passed to
        ``expected_count``. The name shadows the builtin but is kept so
        existing calls still work.
    out_dir : directory the PNG is written to; created if it does not exist.

    The 'nan' label is left out of the chart. The figure is closed after
    saving, so nothing is displayed inline.
    """
    import os  # local import keeps this cell independent of the import cell

    # Tick labels: every label of the group except 'nan'. Building a new list
    # avoids copying or mutating the caller's label_dict.
    keys = [label for label in label_dict[group_name] if label != 'nan']

    # Compute before opening a figure so a failure here leaves no figure behind.
    gt_mean = expected_count(ground_truth, dict)
    pr_mean = expected_count(predicted, dict)

    x = np.arange(len(keys))
    width = 0.35

    fig, ax = plt.subplots()
    try:
        ax.bar(x - width/2, gt_mean, width, label='Ground Truth')
        ax.bar(x + width/2, pr_mean, width, label='Predicted')
        ax.set_ylabel('Probability')
        ax.set_title(group_name)
        ax.set_xticks(x)
        ax.set_xticklabels(keys, rotation=70)
        ax.legend()
        fig.tight_layout()
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(os.path.join(out_dir, f'{group_name}.png'))
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

In [34]:
# Per-row arrays (one row per data record / drawn sample) as NumPy, replacing
# the column means stored under the same names in the sampling cell.
ground_truth = data_softmax.detach().cpu().numpy()
predicted = sample.detach().cpu().numpy()

# Groups that appear once in the data, paired with their label -> columns dict.
groups_to_plot = (
    ('TEN', ten_dict),
    ('HHL', hhl_dict),
    ('VEH', veh_dict),
    ('HINCP', hincp_dict),
    ('R18', r18_dict),
    ('R65', r65_dict),
)
for group_name, group_columns in groups_to_plot:
    plot_group(group_name, ground_truth, predicted, feature_label_dict, group_columns)

In [35]:
# Sex, education and age: each dict pools the columns of every SEX_n / SCHL_n /
# AGEP_n group; the "_1" group name supplies the tick labels, title and file name.
pooled_groups = (
    ('SEX_1', sex_dict),
    ('SCHL_1', education_dict),
    ('AGEP_1', age_dict),
)
for group_name, group_columns in pooled_groups:
    plot_group(group_name, ground_truth, predicted, feature_label_dict, group_columns)