In [1]:
import torch
import pickle
import argparse
import numpy as np
import torch.nn as nn
from torchmetrics import Accuracy
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import ImageFolder
import os
from fastai.vision.all import untar_data, URLs
import pandas as pd
from pytorch_lightning import Trainer
from dime.data_utils import HistopathologyDownsampledDataset
from dime.utils import MaskLayer2d
from dime import MaskingPretrainer
from dime import CMIEstimator, MaskLayer
from dime.resnet_imagenet import resnet18, resnet34, Predictor, ValueNetwork, ResNet18Backbone, resnet50
from dime.vit import PredictorViT, ValueNetworkViT
import timm
import matplotlib.pyplot as plt

  from pkg_resources import DistributionNotFound, get_distribution
  from .autonotebook import tqdm as notebook_tqdm


# Load Dataset

In [2]:
# Load the MNIST *training* split. ToTensor converts each image to a tensor and
# torch.flatten collapses it to 1-D. The split is not used by the evaluation
# cells visible in this notebook; it is kept so any cell that still expects
# `mnist_dataset` keeps working.
mnist_dataset = MNIST('/tmp/mnist/', download=True, train=True,
                      transform=transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)]))
np.random.seed(0)
# Load the held-out MNIST test split with the same transform.
test_dataset = MNIST('/tmp/mnist/', download=True, train=False,
                     transform=transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)]))

# NOTE(review): the Trainer in the network-setup cell is configured with
# accelerator='gpu' and devices=[device.index], so a CUDA device is required.
device = torch.device('cuda:0')

# drop_last must be False for evaluation: with batch_size=128 the 10000 test
# images form 78 full batches plus one batch of 16, and dropping that last
# batch would silently exclude 16 images from every reported metric.
test_dataloader = DataLoader(
        test_dataset, batch_size=128, shuffle=False, pin_memory=True,
        drop_last=False, num_workers=2)

In [3]:
# Bare expression: the notebook renders the dataset's repr as the cell output,
# a quick sanity check that the test split was loaded.
test_dataset

Dataset MNIST
    Number of datapoints: 10000
    Root location: /tmp/mnist/
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Lambda()
           )

# Set up networks

In [4]:
# Multiclass accuracy metric used to score predictions in the evaluation cells.
acc_metric = Accuracy(task='multiclass', num_classes=10)

# Layer sizes shared by both networks.
d_in, d_out = 784, 10
hidden = 512
dropout = 0.3


def _make_mlp(n_outputs):
    """Build a Linear-ReLU-Dropout x2 + Linear MLP on `device`.

    The first layer takes ``d_in * 2`` inputs (presumably the features
    concatenated with the mask, since the mask layer below is created with
    ``append=True`` -- TODO confirm against ``MaskLayer``).
    """
    layers = [
        nn.Linear(d_in * 2, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, hidden),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden, n_outputs),
    ]
    return nn.Sequential(*layers).to(device)


# Outcome predictor: one output per class.
predictor = _make_mlp(d_out)

# CMI predictor (value network): one output per input feature.
value_network = _make_mlp(d_in)

# Tie the two Linear layers before the output head: the value network reuses
# the predictor's module objects, so their parameters are shared.
for shared_idx in (0, 3):
    value_network[shared_idx] = predictor[shared_idx]

mask_layer = MaskLayer(append=True, mask_size=d_in)

# Single-GPU trainer running in 16-bit precision; used only for inference here.
trainer = Trainer(accelerator='gpu', devices=[device.index], precision=16)

Using 16bit None Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


# Evaluate Penalized Policy

In [6]:
# Evaluate the penalized policy: for each trial checkpoint, sweep the penalty
# `lam` and record accuracy against the average number of selected features.
for trial in range(1):
    results_dict = {"acc": {}}
    path = f"/home/haozhe/paper/DIME/logs/max_features_50_eps_0.05_with_decay_rate_0.2_save_best_loss_with_entropy_fix_trial_{trial}/version_2/checkpoints/best_val_perf_model.ckpt"

    # Constructor arguments handed to load_from_checkpoint alongside the path.
    estimator_kwargs = dict(
        value_network=value_network,
        predictor=predictor,
        mask_layer=mask_layer,
        lr=1e-3,
        min_lr=1e-6,
        max_features=50,
        eps=0.05,
        loss_fn=nn.CrossEntropyLoss(reduction='none'),
        val_loss_fn=acc_metric,
        eps_decay=0.2,
        eps_steps=10,
        patience=3,
        feature_costs=None,
    )
    greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(path, **estimator_kwargs).to(device)

    # Per-lambda results, index-aligned with `lamda_values`.
    avg_num_features_lamda, accuracy_scores_lamda, all_masks_lamda = [], [], []

    # Ten penalty values, geometrically spaced between 0.00016 and 0.28.
    lamda_values = list(np.geomspace(0.00016, 0.28, num=10))
    for lamda in lamda_values:
        metric_dict = greedy_cmi_estimator.inference(trainer, test_dataloader, feature_costs=None, lam=lamda)

        y, pred = metric_dict['y'], metric_dict['pred']
        accuracy_score = acc_metric(pred, y)
        final_masks = np.array(metric_dict['mask'])
        # Row sums count the entries set in each mask; the mean over rows is
        # the average number of selected features.
        mean_selected = np.mean(np.sum(final_masks, axis=1))

        accuracy_scores_lamda.append(accuracy_score)
        avg_num_features_lamda.append(mean_selected)
        all_masks_lamda.append(final_masks)
        results_dict['acc'][mean_selected] = accuracy_score

        print(f"Lambda={lamda}, Acc={accuracy_score}, Avg. num features={mean_selected}")
    # Saving of the per-trial results is currently disabled:
    # with open(f'results/mnist_lamda_ours_trial_{trial}.pkl', 'wb') as f:
    #     pickle.dump(results_dict, f)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0:   0%|          | 0/79 [00:00<?, ?it/s]

Predicting DataLoader 0: 100%|██████████| 79/79 [00:15<00:00,  5.10it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Lambda=0.00016, Acc=0.9419000148773193, Avg. num features=19.170900344848633
Predicting DataLoader 0: 100%|██████████| 79/79 [00:12<00:00,  6.46it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Lambda=0.0003668238322014383, Acc=0.939300000667572, Avg. num features=16.771799087524414
Predicting DataLoader 0: 100%|██████████| 79/79 [00:10<00:00,  7.29it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Lambda=0.0008409982741934297, Acc=0.9368000030517578, Avg. num features=14.5423002243042
Predicting DataLoader 0: 100%|██████████| 79/79 [00:09<00:00,  8.64it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Lambda=0.0019281138113401836, Acc=0.9307000041007996, Avg. num features=12.472599983215332
Predicting DataLoader 0: 100%|██████████| 79/79 [00:07<00:00, 10.81it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Lambda=0.004420488107477043, Acc=0.9235000014305115, Avg. num features=10.827899932861328
Predicting DataLoader 0: 100%|██████████| 79/79 [00:05<00:00, 13.63it/s]

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]



Lambda=0.010134627423660066, Acc=0.9128000140190125, Avg. num features=9.391400337219238
Predicting DataLoader 0: 100%|██████████| 79/79 [00:06<00:00, 12.10it/s]
Lambda=0.02323514293425482, Acc=0.8880000114440918, Avg. num features=8.028499603271484


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 79/79 [00:03<00:00, 20.49it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Lambda=0.05327002608057195, Acc=0.8518000245094299, Avg. num features=7.089300155639648
Predicting DataLoader 0: 100%|██████████| 79/79 [00:03<00:00, 23.71it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Lambda=0.12212946942716223, Acc=0.7943999767303467, Avg. num features=5.749199867248535
Predicting DataLoader 0: 100%|██████████| 79/79 [00:01<00:00, 50.55it/s]
Lambda=0.28, Acc=0.11349999904632568, Avg. num features=0.0


# Evaluate Budget Constrained Policy

In [7]:
# Evaluate the budget-constrained policy: run inference once per feature
# budget and record accuracy against the average number of selected features.
results_dict = {"acc": {}}
# for trial in range(0, 5):
path = "/home/haozhe/paper/DIME/logs/max_features_50_eps_0.05_with_decay_rate_0.2_save_best_loss_with_entropy_fix_trial_4/version_0/checkpoints/best_val_perf_model.ckpt"
# Moved to `device` explicitly so this cell matches the penalized-policy cell.
greedy_cmi_estimator = CMIEstimator.load_from_checkpoint(path,
                                                         value_network=value_network,
                                                         predictor=predictor,
                                                         mask_layer=mask_layer,
                                                         lr=1e-3,
                                                         min_lr=1e-6,
                                                         max_features=50,
                                                         eps=0.05,
                                                         loss_fn=nn.CrossEntropyLoss(reduction='none'),
                                                         val_loss_fn=acc_metric,
                                                         eps_decay=0.2,
                                                         eps_steps=10,
                                                         patience=3,
                                                         feature_costs=None).to(device)

# Per-budget results, index-aligned with `max_budget_values`.
avg_num_features_budget = []
accuracy_scores_budget = []
all_masks_budget = []

# Budgets evaluated: 3, 5, 10, 15, 20, 25.
max_budget_values = [3] + list(range(5, 30, 5))
for budget in max_budget_values:
    metric_dict_budget = greedy_cmi_estimator.inference(trainer, test_dataloader,
                                                        feature_costs=None, budget=budget)

    y = metric_dict_budget['y']
    pred = metric_dict_budget['pred']
    accuracy_score = acc_metric(pred, y)
    final_masks = np.array(metric_dict_budget['mask'])
    # Row sums count the entries set in each mask; compute the mean once and
    # reuse it for the list, the dict key and the log line.
    avg_num_features = np.mean(np.sum(final_masks, axis=1))

    accuracy_scores_budget.append(accuracy_score)
    avg_num_features_budget.append(avg_num_features)
    results_dict['acc'][avg_num_features] = accuracy_score
    print(f"Budget={budget}, Acc={accuracy_score}, Avg. num features={avg_num_features}")

    # Store the masks for every budget. This append previously sat after the
    # loop, so only the last budget's masks were kept.
    all_masks_budget.append(final_masks)

# Saving is currently disabled.
# NOTE(review): `trial` is not assigned in this cell; if re-enabled, the
# filename below depends on a value left over from an earlier cell -- confirm
# the intended trial index before using it.
# with open(f'results/mnist_ours_trial_{trial-4}.pkl', 'wb') as f:
#     pickle.dump(results_dict, f)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0:   0%|          | 0/79 [00:00<?, ?it/s]

Predicting DataLoader 0: 100%|██████████| 79/79 [00:02<00:00, 37.97it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Budget=3, Acc=0.4246000051498413, Avg. num features=3.0
Predicting DataLoader 0: 100%|██████████| 79/79 [00:02<00:00, 33.46it/s]
Budget=5, Acc=0.5522000193595886, Avg. num features=5.0


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting DataLoader 0: 100%|██████████| 79/79 [00:03<00:00, 24.54it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Budget=10, Acc=0.6887000203132629, Avg. num features=10.0
Predicting DataLoader 0: 100%|██████████| 79/79 [00:04<00:00, 19.20it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Budget=15, Acc=0.7526000142097473, Avg. num features=15.0
Predicting DataLoader 0: 100%|██████████| 79/79 [00:04<00:00, 15.95it/s]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Budget=20, Acc=0.7922999858856201, Avg. num features=20.0
Predicting DataLoader 0: 100%|██████████| 79/79 [00:05<00:00, 13.71it/s]
Budget=25, Acc=0.8148000240325928, Avg. num features=25.0
