In [None]:
import os
import sys
import torch
import numpy as np
from tqdm import tqdm
from util import get_transforms, get_dataset, get_image_size, get_dataloader
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
%reload_ext autoreload
%autoreload 2
import scipy


# Data configuration for the MNIST experiments.
dataset_name = 'MNIST'

batch_size = 1  # record each loss element, not mean
timesteps = 300

image_size = get_image_size(dataset_name)
transform, _ = get_transforms(image_size=image_size)
trainset, testset = get_dataset(dataset_name, transform)
# Pass batch_size (not a literal 1) so the loaders cannot drift from the setting above.
trainloader, testloader = get_dataloader(trainset, testset, batch_size)

def flatten_features(record_latent_Features):
    """Flatten every stored latent to 1-D, in place.

    Args:
        record_latent_Features: dict mapping timestep -> list of arrays; each
            list entry is replaced by its flattened copy. Returns None.
    """
    for feature_list in record_latent_Features.values():
        for idx, feature in enumerate(feature_list):
            feature_list[idx] = feature.flatten()


# Modify UNet 
Add three methods to the `UNet` class:
- _record_latent_features()
- _stop_record_latent_features()
- _start_ood_detection()

In [None]:
import math
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF


class Block(nn.Module):
    """Two-convolution UNet stage conditioned on a time embedding (and optionally a label).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        time_emb_dim: last-dimension size of the time embedding ``t`` given to forward.
        label: when truthy, ``label_mlp`` is built so forward can add a label embedding.
    """

    def __init__(self, in_channels, out_channels, time_emb_dim, label):
        super().__init__()

        if label:
            # NOTE(review): label_embedding is never used in forward; it is kept so
            # existing checkpoints (state_dict keys) still load.
            self.label_embedding = nn.Embedding(1, 8)
            # Maps a label tensor whose last dim is 1 to one value per output channel.
            self.label_mlp = nn.Linear(1, out_channels)

        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.Dropout2d(0.03),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, t, label=None):
        """Apply conv1, add the time (and optional label) embedding per channel, then conv2.

        Args:
            x: input feature map.
            t: time embedding; its last dim must equal ``time_emb_dim``.
            label: optional label tensor (last dim 1). Requires the block to have
                been built with a truthy ``label``.
        """
        h = self.conv1(x)
        # Project the time embedding to one value per output channel and append
        # two singleton dims so it broadcasts over height and width.
        time_emb = self.relu(self.time_mlp(t))
        h = h + time_emb[(...,) + (None,) * 2]
        # Test for presence, not truthiness: `if label:` raises for multi-element
        # tensors and silently skipped a single label whose value is 0.
        if label is not None:
            label_emb = self.relu(self.label_mlp(label))
            h = h + label_emb[(...,) + (None,) * 2]
        return self.conv2(h)


class SinusoidalPositionEmbeddings(nn.Module):
    """Encode timesteps as concatenated sine and cosine features.

    Frequencies are ``exp(-k * log(10000) / (dim // 2 - 1))`` for
    ``k = 0 .. dim // 2 - 1``; the output is ``[sin(t * f), cos(t * f)]`` along
    the last axis, i.e. ``2 * (dim // 2)`` features per timestep.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        half = self.dim // 2
        scale = math.log(10000) / (half - 1)
        freqs = torch.exp(-scale * torch.arange(half, device=time.device))
        angles = time[:, None] * freqs.unsqueeze(0)
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)


class UNet(nn.Module):
    """Time-conditioned UNet that can also record or OOD-screen its bottleneck input.

    Besides the usual encoder -> bottleneck -> decoder pass, two optional hooks act
    on the pooled output of the last encoder Block (the bottleneck input):

    * ``_record_latent_features(record_timesteps)`` stores it per sample as a
      NumPy array for the listed timesteps;
    * ``_start_ood_detection(ood_detector, detect_timesteps)`` passes it (flattened
      to one row) to ``ood_detector.detect_l2_distance_ood`` for the listed
      timesteps and appends the returned pair to ``ood_detect_res``.
    """

    def __init__(
        self,
        input_channels=3,
        output_channels=3,
        channels=(64, 128, 256, 512),
        time_emb_dim=32,
        label=None,
    ):
        super().__init__()

        # Latent recording stays off until _record_latent_features() is called.
        self.record_latent = False
        # OOD screening stays off until _start_ood_detection() is called.
        self.ood_detection_indicator = False

        # Time embedding
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_emb_dim),
            nn.Linear(time_emb_dim, time_emb_dim),
            nn.ReLU(),
        )

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one Block per entry of `channels`; forward max-pools after each.
        self.downs = nn.ModuleList()
        for channel in channels:
            self.downs.append(Block(input_channels, channel, time_emb_dim, label))
            input_channels = channel

        # Bottleneck doubles the deepest channel count.
        self.bottleneck = Block(channels[-1], channels[-1] * 2, time_emb_dim, label)

        # Decoder: (transposed-conv upsample, Block) pairs stored flat in self.ups.
        self.ups = nn.ModuleList()
        for channel in reversed(channels):
            self.ups.append(
                nn.ConvTranspose2d(channel * 2, channel, kernel_size=2, stride=2)
            )
            self.ups.append(Block(channel * 2, channel, time_emb_dim, label))

        self.output = nn.Conv2d(channels[0], output_channels, kernel_size=1)

    def _record_latent_features(self, record_timesteps):
        """Start storing bottleneck inputs for ``record_timesteps`` (resets any earlier recording)."""
        self.record_latent = True
        self.record_timesteps = record_timesteps
        self.record_latent_features = {key: [] for key in record_timesteps}

    def _stop_record_latent_features(self):
        """Stop recording; already recorded features are kept."""
        self.record_latent = False

    def _start_ood_detection(self, ood_detector, detect_timesteps):
        """Screen bottleneck inputs at ``detect_timesteps`` with ``ood_detector``.

        ``ood_detector.detect_l2_distance_ood(latent_row, timestep)`` must return a
        pair; each pair is appended to ``self.ood_detect_res`` (reset here).
        """
        self.detect_timesteps = detect_timesteps
        self.ood_detection_indicator = True
        self.ood_detector = ood_detector
        self.ood_detect_res = []

    def forward(self, x, timestep, label=None):
        """Run the UNet on ``x`` at ``timestep`` (one timestep per sample).

        Args:
            x: image batch.
            timestep: tensor holding one diffusion timestep per sample.
            label: optional conditioning passed to every Block, bottleneck included.
        """
        t = self.time_mlp(timestep)

        residual_inputs = []
        for down in self.downs:
            x = down(x, t, label)
            residual_inputs.append(x)
            x = self.pool(x)

        # x is now the bottleneck input. Membership is tested on plain Python
        # numbers (tolist) so any container of ints works, sets included; a 0-d
        # tensor hashes by identity and is never found in a set.
        if self.record_latent:
            for i, step in enumerate(timestep.tolist()):
                if step in self.record_timesteps:
                    self.record_latent_features[step].append(x[i].detach().cpu().numpy())

        if self.ood_detection_indicator:
            for i, step in enumerate(timestep.tolist()):
                if step in self.detect_timesteps:
                    latent_row = x[i].detach().cpu().numpy().reshape(1, -1)
                    ood_pred, max_dist = self.ood_detector.detect_l2_distance_ood(latent_row, step)
                    self.ood_detect_res.append((ood_pred, max_dist))

        # The bottleneck Block is built with `label` like every other Block, so it
        # must receive it too.
        x = self.bottleneck(x, t, label)
        for i in range(0, len(self.ups), 2):
            conv_t = self.ups[i]
            up = self.ups[i + 1]
            residual_x = residual_inputs.pop()

            x = conv_t(x)

            # Odd spatial sizes make the upsampled map differ from the skip map.
            if x.shape != residual_x.shape:
                x = TF.resize(x, size=residual_x.shape[2:], antialias=True)

            x = torch.cat((x, residual_x), dim=1)
            x = up(x, t, label)

        return self.output(x)

# Train the DDPM model

In [None]:
# skip

# Infer the training dataset

In [None]:
import torch
import argparse
import torchvision
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from unet import UNet
from model import DiffusionModel
from util import get_transforms, get_dataset, get_image_size, get_dataloader
import sys


def run_latent_all_samples(unet, loss_fn, trainloader, diffusion_model, device, batch_size, record_timesteps):
    """Push every training sample through ``unet`` at each recorded timestep.

    The UNet's own recording hook (enabled beforehand with
    ``_record_latent_features``) collects the bottleneck features; this function
    records the per-batch loss and saves both to ``weight/``.

    Args:
        unet: model whose ``record_latent_features`` dict is filled during forward.
        loss_fn: loss between true and predicted noise (per batch, via ``.item()``).
        trainloader: iterable of ``(batch, target)`` pairs.
        diffusion_model: provides ``forward(batch, t, device) -> (noisy, noise)``.
        device: device for batches and timesteps.
        batch_size: kept for backward compatibility; ``t`` is sized from each
            actual batch so a short final batch also works.
        record_timesteps: timesteps to evaluate; keys of the saved loss dict.
    """
    import os  # local import: this cell does not import os at the top

    unet.eval()
    related_loss = {key: [] for key in record_timesteps}
    # Pure inference: no autograd graphs are needed.
    with torch.no_grad():
        for timestep in record_timesteps:
            for batch, _ in trainloader:
                batch = batch.to(device)
                t = torch.full((batch.shape[0], ), timestep).to(device)
                batch_noisy, noise = diffusion_model.forward(batch, t, device)
                predicted_noise = unet(batch_noisy, t)
                loss = loss_fn(noise, predicted_noise)
                related_loss[timestep].append(loss.item())
            # Progress once per timestep instead of once per sample.
            n_recorded = sum(len(feats) for feats in unet.record_latent_features.values())
            print(f'length of record_latent_features: {n_recorded}')

    os.makedirs("weight", exist_ok=True)
    torch.save(unet.record_latent_features, "weight/record_latent_features.pt")
    torch.save(related_loss, "weight/record_latent_features_loss.pt")

    n_recorded = sum(len(feats) for feats in unet.record_latent_features.values())
    print(f'length of record_latent_features: {n_recorded}')


def main():
    """Record bottleneck latents of the trained MNIST UNet over the whole training set.

    NOTE(review): ``UNet`` here is the one imported from the ``unet`` module in
    this cell and must accept ``record_latent``; the UNet class defined earlier in
    this notebook has no such parameter.
    """
    dataset_name = 'MNIST'
    batch_size = 1  # record each loss element, not mean
    timesteps = 300
    device = 'cuda'
    record_timesteps = (0, 1, 5, 10, 25, 50, 100, 200, 299)

    image_size = get_image_size(dataset_name)
    transform, _ = get_transforms(image_size=image_size)
    trainset, _testset = get_dataset(dataset_name, transform)
    # shuffle=False keeps dataset order; drop_last=True keeps every batch full.
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        drop_last=True,
    )

    unet = UNet(input_channels=1, output_channels=1, record_latent=True).to(device)
    unet.load_state_dict(torch.load('weight/parameters_power_mnist.pkl'))
    diffusion_model = DiffusionModel(timesteps=timesteps)
    loss_fn = torch.nn.MSELoss()

    unet._record_latent_features(record_timesteps=record_timesteps)
    run_latent_all_samples(
        unet, loss_fn, trainloader, diffusion_model, device, batch_size, record_timesteps
    )


if __name__ == '__main__':
    main()

# Load the stored features

In [None]:
# Bottleneck features saved by run_latent_all_samples. They are presumably NumPy
# arrays (the notebook's UNet records x[i].numpy()); the weights-only unpickler,
# which is torch.load's default since PyTorch 2.6, rejects those, so opt out
# explicitly. Only do this for files you produced yourself: it unpickles
# arbitrary objects.
record_latent_features = torch.load('./weight/record_latent_features.pt', weights_only=False)
record_latent_features_loss = torch.load('./weight/record_latent_features_loss.pt')
flatten_features(record_latent_features)

# Define the OOD Detector
Only two methods so far:
- distance-based method
- norm-based method

In [None]:
def svd_Phi(Phi, num_k, method):
    """Project ``Phi`` onto its top ``num_k`` right singular vectors.

    The rows of ``Vh`` are the eigenvectors of ``Phi^H Phi``, i.e. the
    (uncentred) principal directions of ``Phi``.

    Args:
        Phi: 2-D array, one sample per row.
        num_k: number of leading singular vectors to keep.
        method: ``'np'`` (``np.linalg.svd`` with full matrices) or ``'scipy'``
            (``scipy.linalg.svd`` with ``full_matrices=False``).

    Returns:
        ``(Phi_projected, (U, S, Vh))``.

    Raises:
        ValueError: for any other ``method``.
    """
    if method == 'np':
        U, S, Vh = np.linalg.svd(Phi)
    elif method == 'scipy':
        U, S, Vh = scipy.linalg.svd(Phi, full_matrices=False)
    else:
        raise ValueError("no such svd method")

    print(U.shape, S.shape, Vh.shape)
    leading_directions = Vh[:num_k, :]
    Phi_projected = Phi @ leading_directions.T
    print(Phi_projected.shape)
    return Phi_projected, (U, S, Vh)


def lower_triangular_flatten(matrix):
    """Return the lower-triangular entries of ``matrix`` (diagonal included) as 1-D.

    Entries come out row by row, in ``np.tril_indices(matrix.shape[0])`` order.
    """
    rows, cols = np.tril_indices(matrix.shape[0])
    return matrix[rows, cols]




class pca_ood_detector:
    """Distance-based OOD detector on PCA-projected bottleneck latents.

    Fitting, per recorded timestep: ``flatten_features`` -> ``pca_analyze`` ->
    ``calc_avg_distance`` -> ``set_threshold``. A query latent is flagged as OOD
    when its distance to the nearest projected training latent exceeds
    ``mean + sigma * std`` of the pairwise distances between (sampled) projected
    training latents.

    Args:
        record_latent_features: dict mapping timestep -> list of latent arrays,
            one per training sample.
    """

    def __init__(self, record_latent_features):
        self.record_latent_features = record_latent_features
        # Timesteps the detector can be queried at (the recorded ones).
        self.timesteps = [key for key in self.record_latent_features]

    def flatten_features(self):
        """Flatten every stored latent to 1-D, in place."""
        for key in self.record_latent_features.keys():
            for i in range(len(self.record_latent_features[key])):
                self.record_latent_features[key][i] = self.record_latent_features[key][i].flatten()

    def pca_analyze(self, topk=100):
        """Fit a rank-``topk`` basis per timestep and project the training latents.

        Sets ``pca_basis[t]`` (top-``topk`` rows of ``Vh``) and
        ``pca_projected_latents[t]`` (latents @ basis.T).

        NOTE(review): the latents are not mean-centred before the SVD, so this is
        an uncentred PCA; confirm that is intended.
        """
        self.pca_basis = {key: None for key in self.record_latent_features.keys()}
        self.pca_projected_latents = {key: None for key in self.record_latent_features.keys()}
        tqdmr = tqdm(self.record_latent_features.keys(), desc='pca_analysis')

        for key in tqdmr:
            # Stack the list of latents once and reuse it for the SVD and the
            # projection (previously the list was converted twice).
            Phi = np.asarray(self.record_latent_features[key])
            U, S, Vh = scipy.linalg.svd(Phi, full_matrices=False)
            print(U.shape, S.shape, Vh.shape)
            topk_pc = Vh[:topk, :]

            self.pca_basis[key] = topk_pc
            self.pca_projected_latents[key] = Phi @ topk_pc.T

    def calc_avg_distance(self, n_samples=10000):
        """Summarise pairwise distances between projected training latents.

        For each timestep, up to ``n_samples`` projected latents are drawn without
        replacement and their pairwise L2 distances computed. The mean and std use
        only distinct pairs (strict lower triangle), so the zero self-distances on
        the diagonal do not bias either statistic.

        Args:
            n_samples: maximum latents per timestep; all are used when fewer exist.

        Raises:
            ValueError: if a timestep has fewer than two latents.
        """
        pca_projected_latents_dist = {}
        pca_projected_latents_stats = {}

        tqdmr = tqdm(self.pca_projected_latents.keys(), desc='calculate the stats of latent features for training samples')

        for key in tqdmr:
            latents = np.asarray(self.pca_projected_latents[key])
            n_total = len(latents)
            if n_total < 2:
                raise ValueError(f'need at least two latents at timestep {key}, got {n_total}')
            size = min(n_samples, n_total)
            rand_ids = np.random.choice(n_total, size=size, replace=False)
            vectors = latents[rand_ids, :]

            # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clipped at 0 against round-off.
            squared_norms = np.sum(vectors**2, axis=1, keepdims=True)
            squared_distances = squared_norms + squared_norms.T - 2 * (vectors @ vectors.T)
            distances = np.sqrt(np.maximum(squared_distances, 0))

            # Distinct pairs only: exclude the diagonal.
            pair_distances = distances[np.tril_indices(size, k=-1)]

            # NOTE(review): keeps the full size x size matrix per timestep
            # (n_samples**2 entries each); retained because it is a public attribute.
            pca_projected_latents_dist[key] = distances
            pca_projected_latents_stats[key] = (np.mean(pair_distances), np.std(pair_distances))

        self.pca_projected_latents_dist = pca_projected_latents_dist
        self.pca_projected_latents_stats = pca_projected_latents_stats

    def set_threshold(self, sigma_threshold=5):
        """Set per-timestep thresholds to ``mean + sigma_threshold * std`` of pairwise distances."""
        self.thresholds = {key: self.pca_projected_latents_stats[key][0] + sigma_threshold * self.pca_projected_latents_stats[key][1]
                           for key in self.pca_projected_latents_stats.keys()}

    def detect_l2_distance_ood(self, latent, timestep):
        """Flag ``latent`` as OOD by its distance to the nearest projected training latent.

        Args:
            latent: flattened latent as a single row, e.g. shape ``(1, d)``.
            timestep: one of the recorded timesteps (asserted).

        Returns:
            ``(is_ood, min_distance)``.
        """
        assert timestep in self.timesteps

        projected_latent = latent @ self.pca_basis[timestep].T
        l2_distances = np.linalg.norm(self.pca_projected_latents[timestep] - projected_latent, axis=1)
        min_distance = np.min(l2_distances)
        if min_distance > self.thresholds[timestep]:
            print('ood sample!')
            return True, min_distance
        else:
            print('id sample!')
            return False, min_distance


class pca_ood_detector_norm:
    """Norm-based OOD detector on PCA-projected bottleneck latents.

    Fitting, per recorded timestep: ``flatten_features`` -> ``pca_analyze`` ->
    ``calc_avg_norm`` -> ``set_threshold``. A query latent is flagged as OOD when
    the L2 norm of its projection falls outside ``mean +/- sigma * std`` of the
    training projections' norms.

    Args:
        record_latent_features: dict mapping timestep -> list of latent arrays,
            one per training sample.
    """

    def __init__(self, record_latent_features):
        self.record_latent_features = record_latent_features
        # Timesteps the detector can be queried at (the recorded ones).
        self.timesteps = [key for key in self.record_latent_features]

    def flatten_features(self):
        """Flatten every stored latent to 1-D, in place."""
        for key in self.record_latent_features.keys():
            for i in range(len(self.record_latent_features[key])):
                self.record_latent_features[key][i] = self.record_latent_features[key][i].flatten()

    def pca_analyze(self, topk=100):
        """Fit a rank-``topk`` basis per timestep and project the training latents.

        Sets ``pca_basis[t]`` (top-``topk`` rows of ``Vh``) and
        ``pca_projected_latents[t]`` (latents @ basis.T).

        NOTE(review): the latents are not mean-centred before the SVD, so this is
        an uncentred PCA; confirm that is intended.
        """
        self.pca_basis = {key: None for key in self.record_latent_features.keys()}
        self.pca_projected_latents = {key: None for key in self.record_latent_features.keys()}
        tqdmr = tqdm(self.record_latent_features.keys(), desc='pca_analysis')

        for key in tqdmr:
            # Stack the list of latents once and reuse it for the SVD and the
            # projection (previously the list was converted twice).
            Phi = np.asarray(self.record_latent_features[key])
            U, S, Vh = scipy.linalg.svd(Phi, full_matrices=False)
            print(U.shape, S.shape, Vh.shape)
            topk_pc = Vh[:topk, :]

            self.pca_basis[key] = topk_pc
            self.pca_projected_latents[key] = Phi @ topk_pc.T

    def calc_avg_norm(self):
        """Compute the L2 norm of every projected training latent and its mean/std per timestep."""
        pca_projected_latents_norm = {}
        pca_projected_latents_norm_stats = {}

        tqdmr = tqdm(self.pca_projected_latents.keys(), desc='calculate the stats of latent features for training samples')

        for key in tqdmr:
            norms = np.linalg.norm(self.pca_projected_latents[key], axis=1)
            pca_projected_latents_norm[key] = norms
            pca_projected_latents_norm_stats[key] = (np.mean(norms), np.std(norms))

        self.pca_projected_latents_norm = pca_projected_latents_norm
        self.pca_projected_latents_norm_stats = pca_projected_latents_norm_stats

    def set_threshold(self, sigma_threshold=5):
        """Set per-timestep ``(upper, lower)`` bounds to ``mean +/- sigma_threshold * std`` of the norms."""
        self.thresholds = {key: (self.pca_projected_latents_norm_stats[key][0] + sigma_threshold * self.pca_projected_latents_norm_stats[key][1],
                                 self.pca_projected_latents_norm_stats[key][0] - sigma_threshold * self.pca_projected_latents_norm_stats[key][1])
                           for key in self.pca_projected_latents_norm_stats.keys()}

    def detect_l2_distance_ood(self, latent, timestep):
        """Flag ``latent`` as OOD when its projected norm leaves the fitted band.

        The name mirrors ``pca_ood_detector`` so both detectors are interchangeable
        for callers; this one tests a norm, not a distance.

        Args:
            latent: flattened latent as a single row, e.g. shape ``(1, d)``.
            timestep: one of the recorded timesteps (asserted).

        Returns:
            ``(is_ood, norm)``.
        """
        assert timestep in self.timesteps
        projected_latent = latent @ self.pca_basis[timestep].T
        norm = np.linalg.norm(projected_latent)
        upper, lower = self.thresholds[timestep]
        if norm > upper or norm < lower:
            print('ood sample!')
            return True, norm
        else:
            print('id sample!')
            return False, norm
        
    


# Fit the OOD Detector

In [None]:
# Fit the norm-based detector on the recorded training latents:
# flatten -> per-timestep PCA -> norm statistics -> thresholds.
SIGMA_THRESHOLD = 2  # band is mean +/- SIGMA_THRESHOLD * std of training norms

ood_norm_detector = pca_ood_detector_norm(record_latent_features)
ood_norm_detector.flatten_features()  # no-op if already flattened
ood_norm_detector.pca_analyze()
ood_norm_detector.calc_avg_norm()
ood_norm_detector.set_threshold(SIGMA_THRESHOLD)

# Test the OOD Detector

In [None]:
import wandb
import torch
import argparse
import torchvision
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
# from unet import UNet
from model import DiffusionModel
from util import get_transforms, get_dataset, get_image_size, get_dataloader
import sys
from unet_power import UNet



def infer(unet, diffusion_model, device, T, reverse_transform, n, image_shape=(1, 1, 32, 32)):
    """Sample ``n`` images while measuring per-step spread with MC dropout.

    The network is left in train mode so dropout stays active. At every reverse
    step, ``T`` stochastic backward passes are drawn from the same current image;
    their mean is stored for display and the chain continues from the last draw,
    which is the plain sampler when ``T == 1``.

    Args:
        unet: denoising network, passed to ``diffusion_model.backward``.
        diffusion_model: provides ``timesteps`` and ``backward(image, t, unet)``.
        device: device for the initial noise and timestep tensors.
        T: stochastic draws per step; the std needs ``T >= 2`` (NaN for ``T == 1``).
        reverse_transform: maps a sample tensor back to image space.
        n: number of images to generate.
        image_shape: shape of the initial noise (batch of one by default).

    Returns:
        ``(total_samples, sd)``: ``total_samples[k]`` holds the transformed mean
        draw for every step ``i`` with ``i % 50 == 0``; ``sd[k]`` is the mean
        per-pixel std across draws at the final step.
    """
    unet.train()  # keep dropout active for MC sampling
    sd = []
    total_samples = []
    with torch.no_grad():
        tqdmr = tqdm(range(n))
        for _ in tqdmr:
            samples = []
            image = torch.randn(image_shape).to(device)
            for i in reversed(range(diffusion_model.timesteps)):
                t = torch.full((1, ), i, dtype=torch.long, device=device)
                # Every MC draw starts from the same input. Chaining them would
                # apply T denoising steps at a single timestep.
                draws = [diffusion_model.backward(image, t, unet) for _ in range(T)]
                samples_at_step = torch.cat(draws, dim=0)
                mean_sample = samples_at_step.mean(dim=0)
                sd_sample = samples_at_step.std(dim=0).mean()
                image = draws[-1]
                if i % 50 == 0:
                    samples.append(reverse_transform(mean_sample).cpu())
            sd.append(sd_sample.cpu().item())
            total_samples.append(samples)
    return total_samples, sd

def infer_ood_abnormal(unet, diffusion_model, ood_detector, device, reverse_transform, n_imgs,
                       detect_timesteps=(0, 199, 399, 599, 799, 999), noise_scale=1.5,
                       image_shape=(1, 1, 32, 32)):
    """Sample from deliberately inflated noise and label each image 'id' or 'ood'.

    OOD screening is switched on in ``unet`` for ``detect_timesteps``; an image is
    labelled 'ood' if any detection recorded while sampling it says OOD.

    NOTE(review): the detectors' ``detect_l2_distance_ood`` asserts the timestep
    was fitted, so ``ood_detector`` must cover every entry of ``detect_timesteps``.

    Args:
        unet: network exposing ``_start_ood_detection`` and ``ood_detect_res``.
        diffusion_model: provides ``timesteps`` and ``backward(image, t, unet)``.
        ood_detector: detector handed to ``unet._start_ood_detection``.
        device: device for the initial noise and timestep tensors.
        reverse_transform: maps a sample tensor back to image space.
        n_imgs: number of images to sample.
        detect_timesteps: timesteps at which the bottleneck is screened.
        noise_scale: multiplier on the initial Gaussian noise (the "abnormal" part).
        image_shape: shape of the initial noise.

    Returns:
        ``(infer_samples, ood_indicator)``: per image, the transformed snapshots at
        steps with ``i % 50 == 0``, and its 'id'/'ood' label.
    """
    unet.eval()
    unet._start_ood_detection(ood_detector, detect_timesteps)

    infer_samples = []
    ood_indicator = []
    with torch.no_grad():
        tqdmr = tqdm(range(n_imgs))
        for _ in tqdmr:
            samples = []
            # Detections accumulate in one shared list; remember where this
            # image's results begin instead of guessing a fixed-size tail.
            first_result = len(unet.ood_detect_res)
            image = torch.randn(image_shape).to(device) * noise_scale
            for i in reversed(range(diffusion_model.timesteps)):
                image = diffusion_model.backward(image, torch.full((1, ), i, dtype=torch.long, device=device), unet)
                if i % 50 == 0:
                    samples.append(reverse_transform(image).cpu())
            infer_samples.append(samples)
            image_results = unet.ood_detect_res[first_result:]
            ood_indicator.append('ood' if any(is_ood for is_ood, _ in image_results) else 'id')

    return infer_samples, ood_indicator

In [None]:
# Sampling configuration for the OOD test run.
TIMESTEPS = 1000
IMAGE_SIZE = (32, 32)
device = 'cuda'
n_imgs = 5  # NOTE(review): not used below; N is what is passed to the sampler
N = 50      # number of images sampled and screened
T = 2       # NOTE(review): not used in this cell

# Band of mean +/- 2 std of training norms (1.5 was an earlier working value).
ood_norm_detector.set_threshold(2)

# NOTE(review): the same checkpoint file is loaded into unet.UNet in main(); only
# one of the two architectures can match it. The detector was fitted on
# timesteps (0, 1, 5, ..., 299), while infer_ood_abnormal screens
# (0, 199, ..., 999) by default; confirm the detector covers those.
unet = UNet(T=TIMESTEPS, ch=32, ch_mult=[1, 2, 2, 2], attn=[1], num_res_blocks=2,
            dropout=0.1).to(device)
unet.load_state_dict(torch.load('weight/parameters_power_mnist.pkl'))
diffusion_model = DiffusionModel(timesteps=TIMESTEPS)
_, reverse_transform = get_transforms(image_size=IMAGE_SIZE)

infer_samples, ood_indicator = infer_ood_abnormal(
    unet, diffusion_model, ood_norm_detector, device, reverse_transform, N
)

# Draw ID/OOD images

In [None]:
# Show up to 20 of the sampled images that the detector flagged as OOD.
fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(15, 12))
axes = axes.flatten()

# Indices of flagged images, taken from ood_indicator itself so this works for
# any number of samples (a fixed range(50) fails whenever N != 50).
ood_ids = np.flatnonzero(np.array(ood_indicator) == 'ood')
for ax, i in zip(axes, ood_ids):  # zip stops after the 20 available axes
    # infer_samples[i][-1] is the last stored snapshot (step 0) of image i; the
    # trailing [-1][0] depends on reverse_transform's output layout - confirm.
    ax.imshow(infer_samples[i][-1][-1][0], cmap='gray', vmin=0, vmax=255)
    ax.set_title(f'Plot {ood_indicator[i]}')

plt.tight_layout()
plt.show()

In [None]:
# Show up to 20 of the sampled images that the detector judged in-distribution
# (this cell previously duplicated the OOD plot above).
fig, axes = plt.subplots(nrows=4, ncols=5, figsize=(15, 12))
axes = axes.flatten()

# Indices of ID images, taken from ood_indicator itself so this works for any
# number of samples (a fixed range(50) fails whenever N != 50).
id_ids = np.flatnonzero(np.array(ood_indicator) == 'id')
for ax, i in zip(axes, id_ids):  # zip stops after the 20 available axes
    # infer_samples[i][-1] is the last stored snapshot (step 0) of image i; the
    # trailing [-1][0] depends on reverse_transform's output layout - confirm.
    ax.imshow(infer_samples[i][-1][-1][0], cmap='gray', vmin=0, vmax=255)
    ax.set_title(f'Plot {ood_indicator[i]}')

plt.tight_layout()
plt.show()