In [None]:
%load_ext autoreload
%autoreload 2
import sys
from pathlib import Path

import torch as th
import torch.nn.functional as F

sys.path.append("..")

from src.diffusion.base import DiffusionSampler
from src.diffusion.beta_schedules import improved_beta_schedule
from src.guidance.reconstruction import ReconstructionGuidance
from src.model.resnet import load_classifier
from src.model.unet import load_mnist_diff
from src.utils.classification import accuracy, entropy, logits_to_label
from src.utils.net import Device, get_device
from src.utils.vis import every_nth_el, plot_reconstr_diff_seq, plot_accs, plot_samples_grid

def _load_class(class_path: Path, device):
    """Load a classifier checkpoint and prepare it for inference.

    Args:
        class_path: path of the checkpoint handed to ``load_classifier``.
        device: target device the model is moved to with ``.to(device)``.

    Returns:
        The loaded model, after ``.to(device)`` and ``.eval()`` were called on it.
    """
    model = load_classifier(class_path)
    # `.to` / `.eval` are called for their effect on `model`; their return values are unused.
    model.to(device)
    model.eval()
    return model

def reconstr_accuracy(samples, classifier, guidance, ys, device=None):
    """Per-timestep classifier accuracy on noisy samples and on their x_0 reconstructions.

    For every ``(t, x_t_batch)`` pair in ``samples`` the classifier is evaluated
    twice: once directly on ``x_t_batch`` and once on
    ``guidance.predict_x_0(x_t_batch, t)``. Both label sets are scored against
    ``ys`` with ``accuracy``.

    Args:
        samples: iterable of ``(t, x_t_batch)`` pairs (in this notebook, the
            second output of ``DiffusionSampler.sample``).
        classifier: callable mapping an image batch to logits.
        guidance: object exposing ``predict_x_0(x_t, t)``.
        ys: reference labels passed to ``accuracy``.
        device: device each ``x_t_batch`` is moved to before evaluation. When
            ``None``, the device of the classifier's first parameter is used,
            falling back to CUDA (the previously hard-coded device) if the
            classifier exposes no parameters.

    Returns:
        list of ``(t, base_acc_t, rec_acc_t)`` tuples, in the order of ``samples``.
    """
    if device is None:
        params = getattr(classifier, "parameters", None)
        first_param = next(iter(params()), None) if callable(params) else None
        device = first_param.device if first_param is not None else th.device("cuda")

    def _labels(x):
        # Only arg-max labels are needed, so skip building an autograd graph.
        with th.no_grad():
            return logits_to_label(classifier(x))

    accs = []
    for t, x_t_batch in samples:
        x_t_batch = x_t_batch.to(device)
        base_acc_t = accuracy(ys, _labels(x_t_batch))
        # predict_x_0 is left outside no_grad: the guidance object may rely on autograd.
        x_0_hat = guidance.predict_x_0(x_t_batch, t)
        rec_acc_t = accuracy(ys, _labels(x_0_hat))
        accs.append((t, base_acc_t, rec_acc_t))
    return accs

def _detach_samples():
    for (t, x_t) in samples:
        x_t.detach().cpu()

## Sample from the unconditional diffusion model

In [None]:
# Seed torch's global RNG (all devices) so Restart & Run All reproduces the samples.
# NOTE(review): assumes DiffusionSampler draws its noise from torch's global RNG — confirm.
SEED = 0
th.manual_seed(SEED)

device = get_device(Device.GPU)
models_dir = Path.cwd().parent / "models"
uncond_diff = load_mnist_diff(models_dir / "uncond_unet_mnist.pt", device)

T = 1000  # number of diffusion steps
diff_sampler = DiffusionSampler(improved_beta_schedule, num_diff_steps=T)
diff_sampler.to(device)

num_samples = 20
print("Sampling x_0:T")
# Sample shape is (channels, height, width) = single-channel 28x28.
diff_samples_0, diff_samples = diff_sampler.sample(uncond_diff, num_samples, device, th.Size((1, 28, 28)))


In [None]:
# Bring the final samples to the CPU, outside the autograd graph, before plotting them on a 4x5 grid.
diff_samples_0 = diff_samples_0.cpu().detach()
plot_samples_grid(diff_samples_0, (4, 5))

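## Classification

Each final sample $x_0$ is labelled with the classifier's own prediction. At every diffusion step $t$ we then measure how often the classifier recovers that label, both from the noisy sample $x_t$ and from the reconstruction $\hat{x}_0$ returned by `guidance.predict_x_0`.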

In [None]:
classifier = _load_class(models_dir / "resnet_reconstruction_classifier_mnist.pt", device)
guidance = ReconstructionGuidance(uncond_diff, classifier, diff_sampler.alphas_bar.clone(), F.cross_entropy)
# The classifier's prediction on the final samples is the reference label for
# scoring every intermediate step below. This is pure inference, so no autograd graph is needed.
with th.no_grad():
    pred_class = logits_to_label(classifier(diff_samples_0.clone().to(device)))
accs = reconstr_accuracy(diff_samples, classifier, guidance, pred_class)


In [None]:
# Plot the (t, base accuracy, reconstruction accuracy) triples from reconstr_accuracy;
# the exact layout is defined by plot_accs in src.utils.vis.
plot_accs(accs)

In [None]:
# Entropy of the classifier's softmax output on the final samples: a per-sample
# measure of how (un)certain the classifier is about the generated digits.
with th.no_grad():  # inference only
    logits = classifier(diff_samples_0.clone().to(device))
p = F.softmax(logits, dim=1)
entropy(p)

## Classifier gradient

In [None]:
# NOTE(review): this section is a stub; it only displays the guidance object's repr.
# TODO: compute and visualise the classifier gradient used for guidance.
guidance