In [None]:
!pip install torchvision
!pip install ipywidgets

In [None]:
# Reload imported modules (including the project modules mnist_data and
# mnist_utils) before each cell executes, so edits to them are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import time
from datetime import datetime, timedelta

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

from mnist_data import MNISTSampler, get_mnist_data
from mnist_utils import CNN, train, test, extract_embeddings, run_distribution_ablation

In [None]:
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# Load the train/test splits via the project helper, downloading into ./data
# if needed. NOTE(review): "none" presumably requests the unshifted data —
# confirm against mnist_data.get_mnist_data.
shift_type = "none"
train_data, test_data = get_mnist_data(
    root="data",
    shift=shift_type,
    download=True,
)

In [None]:
# Print the train and test dataset summaries, one per line.
print(train_data, test_data, sep="\n")

In [None]:
# Batch size shared by the train and test loaders.
BATCH_SIZE = 100

loaders = {
    # Reshuffle training batches every epoch.
    'train': DataLoader(train_data,
                        batch_size=BATCH_SIZE,
                        shuffle=True,
                        num_workers=1),

    # Evaluation gains nothing from shuffling; a fixed order keeps test
    # metrics and per-example outputs (e.g. the embeddings extracted from
    # these loaders later) reproducible and aligned across runs.
    'test': DataLoader(test_data,
                       batch_size=BATCH_SIZE,
                       shuffle=False,
                       num_workers=1),
}

In [None]:
# Baseline setup: fresh CNN, cross-entropy objective, Adam at lr = 1e-2.
# NOTE(review): the model is not moved to `device` here.
model = CNN()
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-2)

In [None]:
# Display the model's repr (its submodule structure) as the cell output.
model

In [None]:
# Baseline: train for a fixed number of epochs, then evaluate on the test split.
num_epochs = 5
train_loader, test_loader = loaders['train'], loaders['test']
train(num_epochs, model, train_loader, loss_func, optimizer)
test(model, test_loader)

In [None]:
# Extract embeddings from the trained baseline model and pickle them to disk.
embeddings_path = 'mnist_embeddings.pkl'
data = extract_embeddings(model, loaders)
with open(embeddings_path, 'wb') as file:
    pickle.dump(data, file)

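# Ablation

Run `run_distribution_ablation` on a freshly initialised `CNN` (`num_samples=10`, `num_epochs=5`, Dirichlet sampling distribution), then save the returned metrics to `ablation.pkl`.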

In [None]:
# The ablation starts from a new loss instance and a freshly initialised CNN.
# NOTE(review): no optimizer is built here; presumably run_distribution_ablation
# constructs one from `optimizer_type`.
loss_func = nn.CrossEntropyLoss()
model = CNN()

In [None]:
# Run the distribution ablation and report its wall-time cost.
# time.perf_counter() is monotonic, unlike datetime.now(), so the measured
# duration cannot jump if the system clock is adjusted during a long run.
ablation_start = time.perf_counter()
out = run_distribution_ablation(
    model,
    loss_func,
    optimizer_type=optim.Adam,
    num_samples=10,
    num_epochs=5,
    sampling_distribution="dirichlet",
)
# Print as a timedelta (H:MM:SS.ffffff), the same format as before.
print(timedelta(seconds=time.perf_counter() - ablation_start))

In [None]:
# Unpack the 8-tuple returned by run_distribution_ablation; the order here
# must match the order of that function's return value.
(
    accuracy,
    embedding_distances,
    sample_weight_distance,
    proportions,
    classification_test_accuracies,
    shift_classification_accuracy,
    sample_weights_train,
    sample_weights_test,
) = out

In [None]:
# Bundle every ablation output under a named key and pickle it to disk.
# (`pickle` comes from the imports cell; no need to re-import it here.)
# NOTE: only pickle.load this file from a trusted source — unpickling can
# execute arbitrary code.
ablation_dict = {
    'accuracies': accuracy,
    'embedding_distances': embedding_distances,
    'sample_weight_distance': sample_weight_distance,
    'sample_weights_train': sample_weights_train,
    'sample_weights_test': sample_weights_test,
    'proportions': proportions,
    'classification_test_accuracies': classification_test_accuracies,
    'shift_classification_accuracy': shift_classification_accuracy
}
with open('ablation.pkl', 'wb') as f:
    pickle.dump(ablation_dict, f)

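# Visualize

Ablation test accuracy (left axis) and positive rate (right axis), both plotted against the sample-weight distance.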

In [None]:
# Scatter ablation results against sample-weight distance: test accuracy on
# the left y-axis, positive rate on a twinned right y-axis.
fig, ax1 = plt.subplots()

# NOTE(review): the x data is `sample_weight_distance`, not
# `embedding_distances`, so the label describes the plotted data. Swap the
# x arrays (and label) if embedding distances were the intended x-axis.
color = 'tab:red'
ax1.set_xlabel('Sample-weight distance')
ax1.set_ylabel('Accuracy', color=color)
ax1.scatter(sample_weight_distance, classification_test_accuracies, color=color)
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  # second y-axis sharing ax1's x-axis (x-label already set)

color = 'tab:blue'
ax2.set_ylabel('Positive rate', color=color)
ax2.scatter(sample_weight_distance, proportions, color=color)
ax2.tick_params(axis='y', labelcolor=color)

ax1.set_title('Test accuracy and positive rate vs. sample-weight distance')
fig.tight_layout()  # otherwise the right y-label is slightly clipped
plt.show()