In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
os.chdir('../')

In [3]:
import numpy as np
import os
import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from torch.utils.data import TensorDataset, DataLoader
from torch_geometric.data import DataLoader, Dataset
from copy import deepcopy, copy
import random
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Subset, random_split
import glob
from PIL import Image
from collections import Counter
import torch_geometric
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
import torch.nn.functional as F
from torch.nn import Linear, ReLU, BatchNorm1d, Module, Sequential
from torch_geometric.nn import global_mean_pool, global_max_pool, global_sort_pool, global_add_pool
from torch_scatter import scatter
from torch_geometric.data import Data
from torch_geometric.utils import to_undirected
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
import wandb

In [4]:
from experiments import hierarchy
from experiments.cbm_models import *
from experiments.dataset import *

## Load Hierarchies in 

In [5]:
dataset = "dsprites"

In [6]:
def load_hierarchy(hierarchy_name,random_seed):
    cub_hierarchy = hierarchy.Hierarchy()
    cub_hierarchy.from_array(np.array(np.load(open("results/hierarchies/{}_cub.npy".format(hierarchy_name),"rb"))),hierarchy.cub_attributes) 
    cub_groups = hierarchy.traverse_hierarchy(cub_hierarchy.root_split)
    cub_groups_idx = [[hierarchy.cub_attributes.index(j) for j in group] for group in cub_groups]
    
    embedding_matrix = np.load(open("concept_vectors/{}_{}.npy".format(hierarchy_name,random_seed),"rb")) # TODO: Change this back to label
    sim_matrix = cosine_similarity(embedding_matrix)
    
    return cub_groups, cub_groups_idx, sim_matrix

## Run Model

In [7]:
if dataset == 'cub':
    train_sequential, val_sequential, test_sequential = load_cub_sequential()
    train_fixed, val_fixed, test_fixed = load_cub_fixed()
elif dataset == 'dsprites':
    train_sequential, val_sequential, test_sequential = load_dataset("dsprites",False)
    train_fixed, val_fixed, test_fixed = load_dataset("dsprites",True)
elif dataset == 'chexpert':
    train_sequential, val_sequential, test_sequential = load_dataset("chexpert",False)
    train_fixed, val_fixed, test_fixed = load_dataset("chexpert",True)

In [8]:
def run_model(model_type, hyperparameters, train, val, test=None, pretrain=False, use_wandb=False,weights={}):
    model = initialize_model(model_type,hyperparameters,dataset,pretrain=pretrain,use_wandb=use_wandb,weights=weights)
    
    if torch.cuda.is_available():
        model = model.cuda()
    
    model = train_model(model,model_type,train,val,hyperparameters,pretrain=pretrain,use_wandb=use_wandb)

    
    if use_wandb:
        test_score = eval_model(model,model_type,test)
        wandb.log({"test_acc": test_score[1]})
        if len(test_score) == 3:
            wandb.log({'test_auc': test_score[2]})
        wandb.finish()
    return model

## CUB Dataset

In [9]:
if dataset == 'cub':
    bottleneck_size = 112
    output_classes = 200
elif dataset == 'dsprites':
    bottleneck_size = 18
    output_classes = 100
elif dataset == 'chexpert':
    bottleneck_size = 13
    output_classes = 2

In [10]:
baseline_hyperparameters = {
    'lr': None,
    'epochs': 10,
    'num_layers': 1,
    'emb_dim': 64,
    'in_dim': bottleneck_size, 
    'out_dim': output_classes, 
    'edge_dim': 1,
    'sim_matrix': None,
    'attributes': hierarchy.cub_attributes, 
    'group': None,
    'indexes': None,
}

In [11]:
def run_experiment(model_type, hierarchy_name, random_seed, pretrain=True, use_fixed=False):
    if dataset != 'cub':
        if model_type != 'mlp': 
            raise Exception("Need to load hierarchies with {}".format(dataset))
        else:
            groups = [[str(i)] for i in range(bottleneck_size)]
            groups_idx = [[i] for i in range(bottleneck_size)]
            sim_matrix = np.ones((bottleneck_size,bottleneck_size))
    else:
        groups, groups_idx, sim_matrix = load_hierarchy(hierarchy_name,random_seed)
    
    hyperparameters = copy(baseline_hyperparameters)
    hyperparameters['sim_matrix'] = sim_matrix
    hyperparameters['indexes'] = groups_idx
    hyperparameters['group'] = groups
    hyperparameters['seed'] = random_seed
    hyperparameters['hierarchy_name'] = hierarchy_name
    hyperparameters['pretrain'] = pretrain
    
    random.seed(random_seed)
    np.random.seed(random_seed)
    torch.manual_seed(random_seed)
    
    
    hyperparameters = update_hyperparameters_graph(model_type,hyperparameters)        
    lr_values = [1e-4,5e-4,1e-3,5e-3,1e-2,5e-2,1e-1]
    
    if pretrain and 'gnn' in model_type:
        if use_fixed:
            train = train_fixed
            val = val_fixed
            test = test_fixed
        else:
            train = train_sequential
            val = val_sequential
            test = test_sequential

        if 'gnn' in model_type:
            edge_attr, edge_index, train, val, test = get_dataset_graph_cub_pretrain(model_type,cub_groups,hierarchy.cub_attributes,sim_matrix
                                                                     , train, val, test) 

        hyperparameters['epochs'] = 5
        hyperparameters['lr'] = 1e-2
        pretrain_model = run_model(model_type, hyperparameters, train, val, pretrain=True, use_wandb=False)
    
        state_dict = pretrain_model.state_dict()
        new_state_dict = {}

        for key,value in state_dict.items():
            if key.startswith('convs.0'):
                new_state_dict[key] = value
            elif key.startswith('lin_in'):
                new_state_dict[key] = value

    if use_fixed:
        train = train_fixed
        val = val_fixed
        test = test_fixed
    else:
        train = train_sequential
        val = val_sequential
        test = test_sequential

    if 'gnn' in model_type:
        edge_attr, edge_index, train, val, test = get_dataset_graph_cub(model_type,cub_groups,
                                                                        hierarchy.cub_attributes,sim_matrix,
                                                                        train, val, test)
    
    hyperparameters['epochs'] = 2
    score_by_lr = {}

    # TODO: Switch test to val
    for lr in lr_values:
        hyperparameters['lr'] = lr
        model = run_model(model_type, hyperparameters, train, val, pretrain=False, use_wandb=False)
        score_by_lr[lr] = float(eval_model(model,model_type,val,pretrain=False)[0].cpu().detach().numpy())
        
    model_lr = min(score_by_lr, key=score_by_lr.get)
    
    hyperparameters['lr'] = model_lr
    hyperparameters['epochs'] = 10
    
    if pretrain:
        final_model = run_model(model_type, hyperparameters, train, val, test=test, use_wandb=True, pretrain=False,weights=new_state_dict) 
    else:
        final_model = run_model(model_type, hyperparameters, train, val, test=test, use_wandb=True, pretrain=False)     
    

In [None]:
for random_seed in [43,44,45]:
    run_experiment('mlp',"shapley",random_seed,pretrain=False)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 7.7 K 
---------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 7.7 K 
---------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated mod

  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 7.7 K 
---------------------------------------
7.7 K     Trainable params
0         Non-trainable params
7.7 K     Total params
0.031     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


In [None]:
for random_seed in [43,44,45]:
    run_experiment('mlp',"shapley",random_seed,pretrain=False,use_fixed=True)

In [12]:
for random_seed in [43,44,45]:
    for hierarchy_name in ['shapley','cem','labels']:
        run_experiment('gnn',hierarchy_name,random_seed,pretrain=True)

To keep the current behavior, use torch.div(a, b, rounding_mode='trunc'), or for actual floor division, use torch.div(a, b, rounding_mode='floor'). (Triggered internally at  ../aten/src/ATen/native/BinaryOps.cpp:467.)
  return torch.floor_divide(other, self)
[34m[1mwandb[0m: Currently logged in as: [33mnavr414[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.111 MB of 0.111 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▅▅▃▃▂▂▃▂▂▂▁▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▇▇▇████
train_loss,█▄▃▂▂▁▁▁▁▁
val_acc,▁▅▆▇▇▇▇███
val_loss,█▃▁▁▁▁▂▂▂▂

0,1
loss,0.27236
test_acc,0.63721
train_acc,0.91681
train_loss,0.2468
val_acc,0.64107
val_loss,1.90749




VBox(children=(Label(value='0.013 MB of 0.108 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.123289…

0,1
loss,█▇▆▄▄▃▂▂▂▂▂▂▂▂▂▁▁▂▂▂▂▁▂▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▇▇█████
train_loss,█▄▂▂▁▁▁▁▁▁
val_acc,▁▅▇▇▇█████
val_loss,█▃▂▁▁▁▂▁▁▁

0,1
loss,0.53315
test_acc,0.62409
train_acc,0.90388
train_loss,0.30513
val_acc,0.6202
val_loss,1.93992




VBox(children=(Label(value='0.013 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.119110…

0,1
loss,█▆▃▂▂▂▂▁▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▃▅▅▆▇▇███
train_loss,█▅▄▃▂▂▁▁▁▁
val_acc,▁▃▅▆▇▇▇███
val_loss,▁▄▄▅▄▇▇▇▆█

0,1
loss,0.10407
test_acc,0.65516
train_acc,0.94996
train_loss,0.1499
val_acc,0.66528
val_loss,1.89768




VBox(children=(Label(value='0.111 MB of 0.111 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▇▅▄▃▂▂▂▁▁▂▂▁▁▁▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▆▇▇████
train_loss,█▄▃▂▂▁▁▁▁▁
val_acc,▁▅▆▆▇▇████
val_loss,█▂▁▃▄▃▅▃▄▃

0,1
loss,0.19843
test_acc,0.62893
train_acc,0.91451
train_loss,0.25118
val_acc,0.64107
val_loss,1.96855




VBox(children=(Label(value='0.013 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.119274…

0,1
loss,█▇▅▄▃▃▃▂▂▂▂▂▁▂▂▁▁▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁
test_acc,▁
train_acc,▁▅▇▇▇█████
train_loss,█▄▂▂▁▁▁▁▁▁
val_acc,▁▆▆▇▇█████
val_loss,█▃▂▁▁▁▁▂▁▁

0,1
loss,0.13532
test_acc,0.61667
train_acc,0.89158
train_loss,0.33345
val_acc,0.61352
val_loss,1.90638




VBox(children=(Label(value='0.013 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.119020…

0,1
loss,█▇▅▃▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▆▆▇▇▇████
train_loss,█▃▂▂▁▁▁▁▁▁
val_acc,▁▆▆██▇████
val_loss,█▂▁▁▂▂▂▂▂▂

0,1
loss,0.28474
test_acc,0.65775
train_acc,0.93786
train_loss,0.19874
val_acc,0.65359
val_loss,1.71662




VBox(children=(Label(value='0.013 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.119210…

0,1
loss,█▇▄▃▃▂▂▂▂▂▂▂▂▂▁▁▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▆▇▇████
train_loss,█▄▃▂▂▂▁▁▁▁
val_acc,▁▅▆▇▇▇████
val_loss,█▂▁▁▁▂▁▂▂▂

0,1
loss,0.2072
test_acc,0.63497
train_acc,0.91681
train_loss,0.24562
val_acc,0.64608
val_loss,1.87812




VBox(children=(Label(value='0.112 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▇▆▄▃▃▂▂▂▂▁▂▂▂▂▂▁▁▁▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▇▇█████
train_loss,█▃▂▂▁▁▁▁▁▁
val_acc,▁▆▇▇▇█████
val_loss,█▂▁▁▁▂▁▁▁▁

0,1
loss,0.29461
test_acc,0.63255
train_acc,0.91118
train_loss,0.27971
val_acc,0.64691
val_loss,1.81658




VBox(children=(Label(value='0.112 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▆▄▂▂▂▁▂▁▁▁▂▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▄▆▆▇▇███
train_loss,█▄▄▃▂▂▁▁▁▁
val_acc,▁▃▄▆▆▇██▇█
val_loss,▁▁▁▄▅▇▅▇▇█

0,1
loss,0.37437
test_acc,0.65326
train_acc,0.95267
train_loss,0.14585
val_acc,0.66361
val_loss,1.97987


In [13]:
for random_seed in [43,44,45]:
    for hierarchy_name in ['shapley','cem','labels']:
        run_experiment('gnn',hierarchy_name,random_seed,pretrain=False)



VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122541…

0,1
loss,█▆▄▄▂▃▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂
test_acc,▁
train_acc,▁▅▆▆▇▆████
train_loss,█▄▃▂▂▂▁▁▁▁
val_acc,▁▅▆▆▇▆▇███
val_loss,█▂▂▁▂▃▁▂▂▁

0,1
loss,0.20485
test_acc,0.65395
train_acc,0.93224
train_loss,0.20188
val_acc,0.66528
val_loss,1.69238




VBox(children=(Label(value='0.014 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.125129…

0,1
loss,██▆▄▃▂▂▂▁▂▁▂▂▁▂▂▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▇▇▇▇███
train_loss,█▄▃▂▂▂▁▁▁▁
val_acc,▁▅▆▇▇▇▇███
val_loss,█▂▃▂▂▂▂▁▁▁

0,1
loss,0.24114
test_acc,0.64463
train_acc,0.91681
train_loss,0.26453
val_acc,0.6586
val_loss,1.64713




VBox(children=(Label(value='0.013 MB of 0.113 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.118326…

0,1
loss,█▅▃▂▂▁▁▂▂▁▁▁▁▂▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▃▅▅▆▆▇███
train_loss,█▅▄▃▃▂▁▁▁▁
val_acc,▁▄▆▅▆▇████
val_loss,▅▆▁▆█▄█▇▆▇

0,1
loss,0.33473
test_acc,0.6757
train_acc,0.956
train_loss,0.13682
val_acc,0.67613
val_loss,1.73908




VBox(children=(Label(value='0.112 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▆▄▃▃▃▂▂▂▂▂▂▁▁▁▁▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁
test_acc,▁
train_acc,▁▅▆▆▇▇████
train_loss,█▄▃▃▂▂▁▁▁▁
val_acc,▁▄▆▆▆▇████
val_loss,█▃▃▄▂▁▂▂▂▂

0,1
loss,0.25583
test_acc,0.65482
train_acc,0.92807
train_loss,0.20175
val_acc,0.66194
val_loss,1.75263




VBox(children=(Label(value='0.112 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,██▆▄▃▂▂▂▂▂▂▂▂▁▁▂▁▁▂▁▁▁▁▁▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▆▆▇▇▇████
train_loss,█▃▂▂▂▁▁▁▁▁
val_acc,▁▅▆▆▇▇████
val_loss,█▂▁▁▂▁▁▁▁▁

0,1
loss,0.20051
test_acc,0.65136
train_acc,0.92348
train_loss,0.24073
val_acc,0.67112
val_loss,1.68969




VBox(children=(Label(value='0.113 MB of 0.113 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▅▄▂▂▂▂▂▂▂▁▁▁▁▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▅▆▆▇▇███
train_loss,█▅▄▃▂▂▁▁▁▁
val_acc,▁▃▄▆▆▆▇▇██
val_loss,█▄▁▆▂▅▃▅▅▆

0,1
loss,0.3035
test_acc,0.67553
train_acc,0.95163
train_loss,0.13896
val_acc,0.6828
val_loss,1.68317




VBox(children=(Label(value='0.013 MB of 0.108 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.123104…

0,1
loss,█▄▄▃▂▂▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▁▅▅▆▇▇███
train_loss,██▄▄▃▂▂▁▁▁
val_acc,▁▁▆▆▆█▇███
val_loss,▅█▃▂▂▁▁▂▁▂

0,1
loss,0.22553
test_acc,0.65154
train_acc,0.92952
train_loss,0.19254
val_acc,0.65776
val_loss,2.01517




VBox(children=(Label(value='0.112 MB of 0.112 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▇▆▄▂▂▂▃▂▁▂▂▁▁▂▁▁▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▅▆▇▇▇▇███
train_loss,█▃▂▂▂▁▁▁▁▁
val_acc,▁▆▇▇▇█████
val_loss,█▂▂▁▂▁▂▁▁▁

0,1
loss,0.29841
test_acc,0.64273
train_acc,0.92264
train_loss,0.2535
val_acc,0.64441
val_loss,1.66965




VBox(children=(Label(value='0.013 MB of 0.113 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.118064…

0,1
loss,█▅▃▃▂▂▂▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▅▆▆▇▇███
train_loss,█▄▃▃▂▂▁▁▁▁
val_acc,▁▅▅▆▆█▇▇██
val_loss,█▁▄▅▃▃▄▅▄▄

0,1
loss,0.2576
test_acc,0.68088
train_acc,0.95309
train_loss,0.14136
val_acc,0.68447
val_loss,1.62531


In [14]:
for random_seed in [43,44,45]:
    for hierarchy_name in ['shapley','cem','labels']:
        run_experiment('gnn_basic',hierarchy_name,random_seed,pretrain=True)



VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122860…

0,1
loss,█▅▅▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▅▆▆▇▇▇██
train_loss,█▄▄▃▃▂▁▁▁▁
val_acc,▁▄▄▆▆▇█▇██
val_loss,█▆▆▄▃▂▂▂▁▁

0,1
loss,0.17939
test_acc,0.67708
train_acc,0.93703
train_loss,0.18938
val_acc,0.6803
val_loss,1.61988




VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122888…

0,1
loss,█▆▄▃▃▂▂▂▃▂▂▂▂▂▂▁▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▆▆▆▇▇███
train_loss,█▅▃▃▂▂▂▁▁▁
val_acc,▁▅▆▅▆▇▇███
val_loss,█▅▃▆▅▃▃▂▂▁

0,1
loss,0.57092
test_acc,0.6776
train_acc,0.93453
train_loss,0.19384
val_acc,0.67947
val_loss,1.60157




VBox(children=(Label(value='0.013 MB of 0.105 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.126815…

0,1
loss,█▆▅▃▂▂▂▂▂▂▂▁▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▅▆▆▆▇███
train_loss,█▅▄▃▃▂▂▁▁▁
val_acc,▁▃▅▅▅▆▇▇██
val_loss,██▅▇▇▆▃▂▁▁

0,1
loss,0.1745
test_acc,0.67415
train_acc,0.93474
train_loss,0.18876
val_acc,0.68114
val_loss,1.65302




VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122425…

0,1
loss,█▆▄▃▂▂▂▂▂▁▂▁▁▁▁▂▂▁▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▃▅▆▆▇▇███
train_loss,█▅▄▃▃▂▂▁▁▁
val_acc,▁▃▅▅▅▇▇▇██
val_loss,█▆▄▅▇▃▄▁▂▁

0,1
loss,0.1875
test_acc,0.67466
train_acc,0.93369
train_loss,0.19572
val_acc,0.6828
val_loss,1.67513




VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122780…

0,1
loss,█▆▃▃▂▂▂▂▁▁▂▂▁▂▂▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▃▅▆▆▆▇▇██
train_loss,█▆▄▂▂▂▂▁▁▁
val_acc,▁▄▆▇▆▇▇███
val_loss,█▇▄▄▄▄▂▂▁▁

0,1
loss,0.08717
test_acc,0.67846
train_acc,0.93474
train_loss,0.1926
val_acc,0.67279
val_loss,1.67224




VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122438…

0,1
loss,█▆▄▃▃▂▂▂▁▂▁▂▁▂▂▁▁▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▄▅▆▆▇███
train_loss,█▅▄▃▂▂▂▁▁▁
val_acc,▁▄▅▆▆▇▇███
val_loss,▇▅▇█▅▃▃▁▁▁

0,1
loss,0.30615
test_acc,0.67639
train_acc,0.93641
train_loss,0.18978
val_acc,0.67279
val_loss,1.67215




VBox(children=(Label(value='0.109 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
loss,█▇▄▃▂▂▂▂▂▂▃▂▂▂▁▂▂▂▂▂▁▁▁▁▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▄▆▆▇▇███
train_loss,█▅▅▃▃▂▂▁▁▁
val_acc,▁▄▄▆▇▇▆█▇█
val_loss,▆▄█▃▄▃▃▂▁▁

0,1
loss,0.11048
test_acc,0.67708
train_acc,0.93766
train_loss,0.18983
val_acc,0.67696
val_loss,1.65493




VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122528…

0,1
loss,█▆▄▃▂▂▂▁▂▂▁▁▁▂▂▂▁▁▁▁▁▂▂▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁
test_acc,▁
train_acc,▁▄▅▆▆▆▇███
train_loss,█▅▄▃▃▂▂▁▁▁
val_acc,▁▄▅▅▆▆▇▇██
val_loss,█▅▅▇▅▅▄▂▁▂

0,1
loss,0.21621
test_acc,0.67932
train_acc,0.9362
train_loss,0.18603
val_acc,0.68614
val_loss,1.70705




VBox(children=(Label(value='0.013 MB of 0.109 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.122833…

0,1
loss,█▆▅▄▃▃▂▂▂▂▂▂▂▁▂▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁▁▂▁▁▁▂▂▁▁▁
test_acc,▁
train_acc,▁▇▇▇██████
train_loss,█▂▂▁▁▁▁▁▁▁
val_acc,▁▇▇███████
val_loss,█▂▁▁▁▁▁▁▁▁

0,1
loss,0.3821
test_acc,0.66017
train_acc,0.92827
train_loss,0.22514
val_acc,0.6586
val_loss,1.62753


In [90]:
model_type = "gnn"
hyperparameters = {'lr': .01, 'epochs': 200, 'num_layers': 1, 'emb_dim': 4, 
                       'in_dim': bottleneck_size, 'out_dim': output_classes, 
                       'edge_dim': 1, 'sim_matrix': sim_matrix, 'attributes': hierarchy.cub_attributes, 'group': cub_groups, 
                       'indexes': cub_groups_idx} # TODO: Chnage in_dim and emb_dim to 64

hyperparameters = update_hyperparameters_graph(model_type,hyperparameters)

In [73]:
train = cub_train_sequential
val = cub_val_sequential

if 'gnn' in model_type:
    edge_attr, edge_index, train, val = get_dataset_graph_cub_pretrain(model_type,cub_groups,hierarchy.cub_attributes,sim_matrix
                                                             , train, val) # TODO: Turn off use_random

In [74]:
lr_values = [0,1e-4,5e-4,1e-3,5e-3,1e-2,5e-2,1e-1,5e-1]
hyperparameters['epochs'] = 2

In [29]:
for lr in lr_values:
    hyperparameters['lr'] = lr
    model = run_model(model_type, hyperparameters, train, val, pretrain=True, use_wandb=False) # TODO: Change the task back
    print("Eval {} {}".format(lr, eval_model(model,model_type,val,pretrain=True))) # TODO: Change this back to pretrain 

Eval 0 (tensor(0.2345), 0.0)
Eval 0.0001 (tensor(0.0526), 0.0)
Eval 0.0005 (tensor(0.0569), 0.0)
Eval 0.001 (tensor(0.0623), 0.0)
Eval 0.005 (tensor(0.0551), 0.0)
Eval 0.01 (tensor(0.0548), 0.0)
Eval 0.05 (tensor(0.0611), 0.0)
Eval 0.1 (tensor(0.0603), 0.0)



KeyboardInterrupt



In [75]:
hyperparameters['lr'] = 1e-2
hyperparameters['epochs'] = 2

In [76]:
pretrain_model = run_model(model_type, hyperparameters, train, val, use_wandb=False, pretrain=True) # TODO: Remove pretraining

In [84]:
state_dict = pretrain_model.state_dict()
new_state_dict = {}

for key,value in state_dict.items():
    if key.startswith('convs.0'):
        new_state_dict[key] = value
    elif key.startswith('lin_in'):
        new_state_dict[key] = value

In [91]:
train = cub_train_sequential
val = cub_val_sequential
model_type = 'gnn'

if 'gnn' in model_type:
    # TODO: Change this back to regular non-pretrained
    edge_attr, edge_index, train, val = get_dataset_graph_cub(model_type,cub_groups,hierarchy.cub_attributes,sim_matrix
                                                             , train, val) # TODO: Turn off use_random

In [92]:
hyperparameters['epochs'] = 2

In [93]:
for lr in lr_values:
    hyperparameters['lr'] = lr
    model = run_model(model_type, hyperparameters, train, val, pretrain=False, use_wandb=False) # TODO: Change the task back
    print("Eval {} {}".format(lr, eval_model(model,model_type,val,pretrain=False))) # TODO: Change this back to pretrain 

Eval 0 (tensor(5.3456, device='cuda:0'), tensor(0.0017, device='cuda:0'))
Eval 0.0001 (tensor(5.3566, device='cuda:0'), tensor(0.0025, device='cuda:0'))
Eval 0.0005 (tensor(5.1949, device='cuda:0'), tensor(0.0250, device='cuda:0'))
Eval 0.001 (tensor(4.9948, device='cuda:0'), tensor(0.0267, device='cuda:0'))
Eval 0.005 (tensor(2.9838, device='cuda:0'), tensor(0.2262, device='cuda:0'))
Eval 0.01 (tensor(3.1464, device='cuda:0'), tensor(0.1937, device='cuda:0'))
Eval 0.05 (tensor(2.3476, device='cuda:0'), tensor(0.3247, device='cuda:0'))
Eval 0.1 (tensor(1.5449, device='cuda:0'), tensor(0.5217, device='cuda:0'))
Eval 0.5 (tensor(2.6256, device='cuda:0'), tensor(0.3381, device='cuda:0'))


In [94]:
hyperparameters['lr'] = 0.1
hyperparameters['epochs'] = 10

In [None]:
final_model = run_model(model_type, hyperparameters, train, val, use_wandb=True, pretrain=False,weights=new_state_dict) 

In [24]:
eval_model(final_model,model_type,val)

(tensor(0.7204), tensor(0.8197))

In [81]:
hyperparameters = {'lr': .01, 'epochs': 200, 'num_layers': 1, 'emb_dim': 64, 
                       'in_dim': bottleneck_size, 'out_dim': output_classes, 
                       'edge_dim': 1, 'sim_matrix': sim_matrix, 'attributes': hierarchy.cub_attributes, 'group': cub_groups, 
                       'indexes': cub_groups_idx} 

model_type = "mlp"
hyperparameters = update_hyperparameters_graph(model_type,hyperparameters)

In [83]:
train_mlp = cub_train_sequential
val_mlp = cub_val_sequential

In [20]:
hyperparameters['lr'] = 0.005
hyperparameters['epochs'] = 10

In [21]:
final_model_mlp = run_model(model_type, hyperparameters, train_mlp, val_mlp, use_wandb=True, pretrain=False) 

[34m[1mwandb[0m: Currently logged in as: [33mnavr414[0m. Use [1m`wandb login --relogin`[0m to force relogin


  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 20.2 K
---------------------------------------
20.2 K    Trainable params
0         Non-trainable params
20.2 K    Total params
0.081     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=50` reached.


VBox(children=(Label(value='0.004 MB of 0.004 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,▂▁▂▁▁▁▂▁▁▁▁▁▁▁▁█▁▁▅█▁▁▆▅▁▁▇▄▄▁▂▄▁▁▆▃▁▁▁▁
train_acc,▁▇▇▇▇▇██████████████████████████████████
train_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_acc,▁▂▄▅▆▆▆▆▆▇▇▇████████████████████████████
val_loss,██▆▅▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,49.0
loss,0.00095
train_acc,0.98673
train_loss,0.03914
trainer/global_step,22499.0
val_acc,0.90818
val_loss,0.31324


In [22]:
eval_model(final_model_mlp,model_type,val_mlp)

0.908180296421051

## 3-SAT Dataset

In [74]:
num_variables = bottleneck_size = 30
sat_train, sat_val, clauses = create_3_sat_dataset(num_variables,5000,100)
output_classes = 2

In [75]:
hyperparameters = {'lr': .01, 'epochs': 200, 'num_layers': 1, 'emb_dim': 64, 
                       'in_dim': len(clauses)*len(clauses[0]), 'out_dim': output_classes, 
                       'edge_dim': 1,'clauses': clauses}
model_type = "mlp_group"

if 'gnn' in model_type:
    sat_edge_attr, sat_edge_index, sat_train_graph, sat_val_graph = get_dataset_graph(
        model_type, clauses, sat_train, sat_val)
    train = sat_train_graph
    val = sat_val_graph
else:
    train = sat_train
    val = sat_val

if 'mlp' in model_type:
    hyperparameters['in_dim'] = num_variables

hyperparameters = update_hyperparameters_graph(model_type,hyperparameters)

In [76]:
hyperparameters['epochs'] = 10

In [77]:
lr_values = [1e-4,5e-4,1e-3,5e-3,1e-2,5e-2,1e-1]

In [78]:
for lr in lr_values:
    hyperparameters['lr'] = lr
    model = run_model(model_type, hyperparameters, train, val, use_wandb=False)
    print("Eval {} {}".format(lr, eval_model(model,model_type,val)))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params

Eval 0.0001 0.7300000190734863


`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Eval 0.0005 0.7300000190734863


`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Eval 0.001 0.7300000190734863


`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Eval 0.005 0.7200000286102295


`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Eval 0.01 0.6600000262260437


`Trainer.fit` stopped: `max_epochs=10` reached.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)


Eval 0.05 0.7200000286102295


`Trainer.fit` stopped: `max_epochs=10` reached.


Eval 0.1 0.6800000071525574


In [73]:
hyperparameters['lr'] = 0.01
hyperparameters['epochs'] = 50
model = run_model(model_type, hyperparameters, train, val, use_wandb=True)

  rank_zero_warn(
[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
  rank_zero_deprecation(
  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | relu    | ReLU       | 0     
1 | softmax | Softmax    | 0     
2 | sigmoid | Sigmoid    | 0     
3 | fc      | Sequential | 1.3 K 
---------------------------------------
1.3 K     Trainable params
0         Non-trainable params
1.3 K     Total params
0.005     Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=50` reached.


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,▅▄▂▂▂▂▂▆█▂▃▄▃▂▂▃▁▂▃▂▂▂▂▁▂▁▁▂▂▂▂▂▂▂▂▁▁▁▁▁
train_acc,▁▃▄▃▄▄▄▃▂▂▂▃▄▄▅▆▆▇▇▇████████████████████
train_loss,▇▅▅▄▄▄▅▅██▇▆▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val_acc,████████▆▄▃▂▂▃▃▂▃▃▃▄▄▃▃▃▃▄▄▄▄▅▄▄▃▁▂▁▂▂▂▂
val_loss,▁▁▂▃▅▆▇█▃▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇

0,1
epoch,49.0
loss,0.18527
train_acc,0.922
train_loss,0.22591
trainer/global_step,799.0
val_acc,0.53
val_loss,1.39342
