In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn as nn
import torch.nn.functional as F
import random
import h5py
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, TensorDataset
import torch.optim as optim

from regression.CNOClassification import CNOClassificationModel_pl
from utils.utils_data import get_loader, load_data, read_cli_inference, find_files_with_extension, save_errors
from diffusion.variance_fn import marginal_prob_std_1, diffusion_coeff_1, marginal_prob_std_2, diffusion_coeff_2
from GenCFD.model.lightning_wrap.pl_conditional_denoiser import PreconditionedDenoiser_pl


# Data processing

In [None]:
# Dataset selector: later cells test it for the substrings "mnist" / "cifar".
which = "cifar_class"
# Modelling mode; the model-loading cell treats "x&y" differently from any other value.
which_type = "x&y"

# Number of training samples requested from get_loader.
N_train = 2000 if "mnist" in which else 40000

ood_share = 0.1
batch_size = 64

# Training split.
train_loader = get_loader(which_data=which, which_type="train", N_samples=N_train,
                          ood_share=ood_share, batch_size=batch_size)

# Validation split (N_samples = 0 is passed through to get_loader unchanged).
valid_loader = get_loader(which_data=which, which_type="val", N_samples=0,
                          ood_share=ood_share, batch_size=batch_size)

# Test split: note the different ood_share (1.0) and the smaller batch size (16).
test_loader = get_loader(which_data=which, which_type="test", N_samples=0,
                         ood_share=1.0, batch_size=16)


### Show the data

In [None]:
# Preview up to 20 images from the first validation batch on a fixed 2 x 10 grid.
images, y = next(iter(valid_loader))
# Channels-first -> channels-last, which is what imshow expects.
images = torch.permute(images, (0, 2, 3, 1))
# Map x -> 255 * (0.5 * x + 0.5); assumes the loader normalised pixels to [-1, 1] — TODO confirm.
images = (255*(images.detach().cpu().numpy()*0.5 + 0.5)).astype("uint8")
fig = plt.figure(figsize=(25, 4))
# Clamp to the batch size so a short batch does not raise IndexError.
for idx in np.arange(min(20, images.shape[0])):
    ax = fig.add_subplot(2, 20//2, idx+1, xticks=[], yticks=[])
    # Draw on the axes just created instead of relying on pyplot's "current axes".
    ax.imshow(images[idx])


## Load the diffusion and classification models

In [None]:
import argparse
import functools
import os


# Checkpoint directory of the diffusion model for the selected dataset.
if "cifar" in which:
    diff_model = "/path_to_diff_model/"
elif "mnist" in which:
    diff_model = "/path_to_diff_model/"
else:
    # Fail here with a clear message instead of a NameError on `diff_model` below.
    raise ValueError(f"Unsupported dataset selector {which!r}: expected a name containing 'cifar' or 'mnist'")

# Single place that names the compute device for everything in this cell.
device = "cuda"

diffusion_model_path = str(find_files_with_extension(diff_model + "/model", "ckpt", [], is_pl = True)[0])
diffusion_config_path = str(find_files_with_extension(diff_model, "json", ["param"])[0])
config_diff = argparse.Namespace(**load_data(diffusion_config_path))
config_diff_arch = load_data(config_diff.config_arch)

sigma = config_diff.sigma
# Optional rescaling left disabled by the author:
#sigma = sigma*(5/6)
marginal_prob_std_fn = functools.partial(marginal_prob_std_2, sigma_min = 0.001, sigma_max=sigma, device = device)
diffusion_coeff_fn = functools.partial(diffusion_coeff_2, sigma_min = 0.001, sigma_max=sigma, device = device)

# In "x&y" mode the model gets one extra input channel and no conditioning channel;
# otherwise it keeps `in_dim` channels and gets one conditioning channel.
if which_type == "x&y":
    dim_cond = 0
    dim = config_diff.in_dim + 1
else:
    dim_cond = 1
    dim = config_diff.in_dim

print(config_diff_arch)
print(config_diff)
diffusion_model = PreconditionedDenoiser_pl(dim = dim,
                                            dim_cond = dim_cond,
                                            loss_fn = None,
                                            marginal_prob_std_fn = marginal_prob_std_fn,
                                            diffusion_coeff_fn = diffusion_coeff_fn,
                                            config_train = vars(config_diff),
                                            config_arch = config_diff_arch,
                                            is_inference = True
                                            )
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(diffusion_model_path, map_location = device)
diffusion_model.load_state_dict(checkpoint["state_dict"])
# From here on `diffusion_model` is the wrapper's `best_model_ema` attribute, not the wrapper.
diffusion_model = diffusion_model.best_model_ema.to(device)


# Classification model: locate its config, rebuild the network and load the weights.

if "cifar" in which:
    path = "/path_to_class_model/"
    reg_path_errors = "/path_to_err_folder/"
elif "mnist" in which:
    path = "/path_to_class_model/"
    reg_path_errors = "/path_to_err_folder/"
else:
    # Fail here with a clear message instead of a NameError on `path` below.
    raise ValueError(f"Unsupported dataset selector {which!r}: expected a name containing 'cifar' or 'mnist'")

class_path = str(find_files_with_extension(path, "json", ["param"])[0])
config_class = load_data(class_path)
config_class["workdir"] = None
config_arch = load_data(config_class["config_arch"])
model = CNOClassificationModel_pl(in_dim = config_class["in_dim"],
                                out_dim = config_class["out_dim"],
                                loss_fn = None,
                                config_train = config_class,
                                config_arch = config_arch)
# NOTE(review): the weights file name is hard-coded to "model-cifar.pt" even when `which`
# selects MNIST — confirm the MNIST folder uses the same file name.
# map_location keeps loading independent of the device the weights were saved from.
# torch.load unpickles the file, so only load weights from a trusted source.
model.load_state_dict(torch.load(f"{path}/model-cifar.pt", map_location = "cuda"))
model = model.to("cuda")

# Folder where evaluation results are written; exist_ok avoids the exists()/makedirs() race.
os.makedirs(reg_path_errors, exist_ok=True)

In [None]:
from scipy.io import loadmat
from torch.utils.data import TensorDataset, DataLoader


def load_svhn(N, device = "cuda", file = "/path_to_svhn/", batch_size = 16):
    """Build a non-shuffling DataLoader over the first ``N`` samples of an SVHN ``.mat`` file.

    Args:
        N: number of samples to keep (taken from the start of the file).
        device: device on which the image and label tensors are created.
        file: path of the ``.mat`` file; it must contain the keys ``'X'`` and ``'y'``.
        batch_size: batch size of the returned loader (defaults to the previous fixed value 16).

    Returns:
        A ``DataLoader`` yielding ``(images, labels)`` batches, where images are float32
        with the last axis of ``X`` moved to the front (``permute(3, 2, 0, 1)``) and
        labels are int64.
    """
    data = loadmat(file)

    x_data = data['X']
    y_data = data['y']
    print(x_data.shape, y_data.shape)

    # Keep the first N samples along the last axis and map pixel values v -> (v/255 - 0.5)/0.5.
    X = ((torch.tensor(np.array(x_data)[:,:,:,:N], device = device))/255-0.5)/0.5
    X = X.type(torch.float32)

    # Cast before subtracting so the uint8 label array cannot wrap around.
    # NOTE(review): labels are simply shifted down by one; if the file stores digit 0 as
    # label 10 it ends up as class 9 rather than 0 — confirm this is intended.
    Y = torch.tensor(np.array(y_data)[:N], device = device).type(torch.int64) - 1

    # Samples-first / channels-second layout for X; last column of the label array for Y.
    dataset = TensorDataset(X.permute(3,2,0,1),Y[:,-1])

    # Order is preserved (shuffle=False) so results line up with the file order.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataloader

In [None]:
'''
    Get p(y)
'''

def p(train_targets, y = 1):
    """Return the fraction of entries in ``train_targets`` that equal ``y``.

    Also prints the matching count and the total count.
    """
    n_matching = len(train_targets[train_targets == y])
    n_total = len(train_targets)
    print(n_matching, n_total)
    return n_matching / n_total

# Gather every training label on the GPU. Only the labels are moved: the image batches
# are not needed here, so copying them to the device was wasted work.
target_batches = [torch.zeros(0, device = 'cuda')]
for i,(data, target) in enumerate(train_loader):
    target = target.cuda()
    target_batches.append(target)
# One concatenation at the end instead of re-allocating the growing tensor every step.
train_targets = torch.cat(target_batches, axis = 0)

# Y[k] is the fraction of training labels equal to k, for k = 0..9.
Y = torch.zeros(10)
for y in range(10):
    Y[y] = p(train_targets, y)


### Evaluate

In [None]:
def plot_image(pred, x, target, classes, is_svhn = False):
    """Show the images of one batch on a two-row grid, optionally titled with the true label.

    Args:
        pred: per-sample predictions; only ``pred.shape[0]`` (number of images) is used.
        x: image batch with the channel axis at position 1.
        target: per-sample labels, used as indices into ``classes`` for the titles.
        classes: sequence of class names.
        is_svhn: when True, the "True label" titles are omitted.
    """
    # Decide on the layout before `x` is converted: after the permute below, axis 1 is
    # no longer the channel axis.
    is_grayscale = x.shape[1] == 1
    if not is_grayscale:
        # Channels-last uint8 array via 255 * (0.5 * x + 0.5).
        x = (255*(x.permute(0,2,3,1).detach().cpu().numpy()*0.5 + 0.5)).astype("uint8")

    n_images = pred.shape[0]
    # Ceil division so an odd (or size-1) batch still fits on the two-row grid.
    n_cols = max(1, (n_images + 1) // 2)

    if n_images > 0:
        fig = plt.figure(figsize=(25, 4))

    for j in range(n_images):
        ax = fig.add_subplot(2, n_cols, j+1, xticks=[], yticks=[])

        if is_grayscale:
            ax.imshow((x[j,0].detach().cpu().numpy()), cmap = "seismic", vmin =0, vmax =1)
        else:
            ax.imshow(x[j], cmap = "seismic", vmin =0, vmax =1)

        # The body previously read the misspelled name `is_svnh`, which is not a
        # parameter of this function.
        if not is_svhn:
            ax.set_title(f"True label: {classes[target[j].item()]}")

    plt.show()

In [None]:
from diffusion.likelihood import ode_likelihood
import copy

A = 0  # NOTE(review): not referenced in any visible cell — confirm whether it is still needed

def evaluate(model, test_loader, stop_id = -1, which_type = "yx", T = 1, noisy_labels = True, is_svnh = False):
    """Classify ``test_loader`` batch by batch and score each sample with ``ode_likelihood``.

    For every batch the classifier prediction is turned into a one-channel label map,
    which is passed to the diffusion model together with the image. Per-class accuracy
    is accumulated, and the first two batches are plotted.

    Reads the notebook globals ``which``, ``device``, ``diffusion_model``,
    ``marginal_prob_std_fn`` and ``diffusion_coeff_fn``.

    Args:
        model: classifier; called as ``model(data)`` and expected to return per-class scores.
        test_loader: iterable of ``(data, target)`` batches.
        stop_id: if not -1, return early once the batch index reaches this value
            (the accuracy summary is then not printed).
        which_type: ``"yx"`` passes the image as the variable and the label map as the
            condition; any other value concatenates the label map to the image as an
            extra channel and passes no condition. ``"x&y"`` also adds one channel to
            the probe noise.
        T: temperature dividing the classifier scores before the softmax.
        noisy_labels: if True, each pixel of the label map is sampled from the softmax
            distribution; otherwise the map is constant at ``pred / 9``.
        is_svnh: if True, plot titles with the true label are omitted. The spelling
            (sic) is kept so existing keyword callers keep working.

    Returns:
        ``(bpds, targets, predicted, class_correct, class_total)``: three 1-D NumPy arrays
        with one entry per processed sample, followed by two length-10 lists of counts.
    """
    if "cifar" in which:
        classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    else:
        classes = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

    # NOTE(review): test_loss is never accumulated, so the "Test Loss" line below always
    # prints 0. The print is kept so the output format stays unchanged.
    test_loss = 0.0
    class_correct = [0. for _ in range(10)]
    class_total = [0. for _ in range(10)]

    model.eval()
    bpds = np.zeros((0,))
    predicted = np.zeros((0,))
    targets = np.zeros((0,))

    epsilon_size = 16
    if "mnist" in which:
        dim = 1
        s = 28
    else:
        dim = 3
        s = 32

    if which_type == "x&y":
        dim+=1
    # Probe noise drawn once and reused for every batch, rescaled so that its norm along
    # the channel axis equals sqrt(dim).
    epsilon = torch.randn((epsilon_size, dim, s, s), device = device).type(torch.float32)
    epsilon = torch.sqrt(torch.prod(torch.tensor(dim, device=device))) * epsilon / torch.norm(epsilon, dim=1, keepdim=True)

    for i,(data, target) in enumerate(test_loader):

        if stop_id!=-1 and i>=stop_id:
            return bpds, targets, predicted, class_correct, class_total

        # Keep only the first column when targets come with a trailing axis.
        if len(target.shape)>1:
            target = target[:,0]

        data, target = data.cuda(), target.cuda()

        # The classifier is only used for inference here, so no graph is needed.
        with torch.no_grad():
            output = model(data)
            _, pred = torch.max(output, 1)
            # Numerically stable softmax; the manual exp/sum form overflows on large scores.
            probs = torch.softmax(output/T, dim = 1)

        targets = np.concatenate((targets, target.detach().cpu().numpy()), axis=0)
        predicted = np.concatenate((predicted, pred.detach().cpu().numpy()), axis=0)

        # Flatten instead of np.squeeze: squeeze turns a batch of one into a 0-d array,
        # which cannot be indexed with correct[j].
        correct = pred.eq(target.data.view_as(pred)).view(-1).cpu().numpy()
        length = data.shape[0]
        for j in range(length):
            label = int(target[j].item())
            class_correct[label] += correct[j].item()
            class_total[label] += 1

        shape = (pred.shape[0], 1) + data.shape[2:]

        # Constant label map: predicted class index scaled by 1/9.
        label_gen = (pred.view(pred.shape[0], 1, 1, 1)/9.0) * torch.ones(shape, device = data.device).type(torch.float32)

        if noisy_labels:
            label_gen = torch.zeros(shape, device = data.device).type(torch.float32)
            for j in range(pred.shape[0]):
                # Categorical distribution over the classes for sample j
                dist = torch.distributions.Categorical(probs=probs[j])

                # One class index per pixel (H x W of the input image), scaled by 1/9
                label_gen[j,0] = dist.sample((data.shape[2], data.shape[3]))/9.0

        if which_type == "yx":
            variable = data
            condition = label_gen
        else:
            variable = torch.cat((data, label_gen), axis = 1)
            condition = None

        _, prior, delta = ode_likelihood(diffusion_model,
                                        variable,
                                        condition,
                                        marginal_prob_std_fn,
                                        diffusion_coeff_fn,
                                        t_batch = None,
                                        batch_size=epsilon_size,
                                        device='cuda',
                                        eps = 1e-10,
                                        rtol = 1e-4,
                                        atol = 1e-4,
                                        epsilon = epsilon,
                                        ode_method = "rk38",
                                        reduce_prior = True)

        # Per-sample score: sum of the two terms returned by ode_likelihood.
        bpd = prior + delta
        bpds = np.concatenate((bpds, bpd.detach().cpu().numpy()))

        print(i, len(test_loader), np.mean(bpds))

        # Plot the label maps of the first two batches (at most 16 per batch).
        if i < 2 and pred.shape[0] > 0:
            fig = plt.figure(figsize=(25, 4))
            cmap = plt.get_cmap('tab10', 10)
            # Ceil division so an odd (or size-1) batch still fits on the two-row grid.
            n_cols = max(1, (pred.shape[0] + 1)//2)

            for j in range(min(pred.shape[0], 16)):
                ax = fig.add_subplot(2, n_cols, j+1, xticks=[], yticks=[])

                # Round before casting: 9 * (k / 9) can land just below k in floating
                # point, and a plain uint8 cast would truncate it to k - 1.
                label_img = np.rint(9.0*label_gen[j,0].detach().cpu().numpy()).astype("uint8")
                cax = ax.imshow(label_img, cmap = cmap, vmin =-0.5, vmax = 9.5)
                if not is_svnh:
                    ax.set_title(f"True label: {classes[target[j].item()]}, l ={round(bpd[j].item(), 1)}")
                cbar = fig.colorbar(cax,ticks=list(np.arange(0,10)))
                cbar.ax.set_yticklabels(classes)  # class names as colorbar tick labels

            plt.show()

        if i <2:
            plot_image(pred, data, target, classes, is_svnh)

    # average test loss
    test_loss = test_loss/len(test_loader.dataset)
    print('Test Loss: {:.6f}\n'.format(test_loss))

    for i in range(10):
        if class_total[i] > 0:
            print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
                classes[i], 100 * class_correct[i] / class_total[i],
                np.sum(class_correct[i]), np.sum(class_total[i])))
        else:
            print('Test Accuracy of %5s: N/A (no training examples)' % (classes[i]))

    print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
        100. * np.sum(class_correct) / np.sum(class_total),
        np.sum(class_correct), np.sum(class_total)))

    return bpds, targets, predicted, class_correct, class_total

In [None]:
import ast

def save_errors(folder, bpds, targets, predicted, class_correct, class_total, is_svhn=False, noisy_labels=True):
    """Write the five evaluation sequences to a text file, one Python list literal per line.

    The file is ``{folder}/{svhn|cifar10}_{noisy_labels|NOT_noisy_labels}_{len(bpds)}.txt``
    and is meant to be read back with ``load_errors``.

    NOTE: this definition shadows the ``save_errors`` imported from ``utils.utils_data``
    at the top of the notebook.

    Args:
        folder: output directory; created if missing.
        bpds, targets, predicted, class_correct, class_total: sequences (lists or NumPy
            arrays) to store, written in this order.
        is_svhn: selects the ``svhn`` / ``cifar10`` part of the file name.
        noisy_labels: selects the ``noisy_labels`` / ``NOT_noisy_labels`` part of the file name.
    """
    os.makedirs(folder, exist_ok=True)

    tag1 = "svhn" if is_svhn else "cifar10"
    tag2 = "noisy_labels" if noisy_labels else "NOT_noisy_labels"
    file = f"{folder}/{tag1}_{tag2}_{len(bpds)}.txt"

    # Convert to plain Python numbers first: with NumPy >= 2, str(list(ndarray)) writes
    # "np.float64(...)" reprs that ast.literal_eval in load_errors cannot parse.
    rows = (bpds, targets, predicted, class_correct, class_total)
    s = "\n".join(str(np.asarray(row).tolist()) for row in rows)

    # Context manager so the handle is closed even if the write fails.
    with open(file, "w") as text_file:
        text_file.write(s)

    print(file, "SAVED")

def load_errors(file):
    """Read a results file with five list literals, one per line, as NumPy arrays.

    Returns:
        ``(bpds, targets, predicted, class_correct, class_total)`` as NumPy arrays,
        in the order the lines appear in the file.

    Raises:
        ValueError: if the file does not contain exactly five lines.
    """
    with open(file, "r") as handle:
        raw_rows = handle.readlines()

    if len(raw_rows) != 5:
        raise ValueError("Expected 5 lines in the file")

    # literal_eval only accepts Python literals, so nothing from the file is executed.
    parsed_rows = [ast.literal_eval(row.strip()) for row in raw_rows]
    return tuple(np.array(values) for values in parsed_rows)