## VAE

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons, make_swiss_roll
import sklearn
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
import torch.nn as nn
import pytorch_lightning as pl
import numpy as np


# Fix the global random seed through PyTorch Lightning so repeated runs of the
# notebook start from the same random state.
pl.seed_everything(1)

# VAE Model
class VAE(nn.Module):
    """Fully connected variational autoencoder.

    The encoder maps an input vector to the concatenated mean and
    log-variance of a diagonal Gaussian; the decoder maps a latent sample
    back to input space and squashes the result through a sigmoid.

    Args:
        input_dim: Size of each flattened input vector.
        hidden_dim: Base width of the hidden layers (the middle layer of
            each network is twice as wide).
        latent_dim: Size of the latent code.
    """

    def __init__(self, input_dim, hidden_dim, latent_dim):
        super().__init__()
        wide_dim = 2 * hidden_dim

        encoder_layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, wide_dim),
            nn.Sigmoid(),
            nn.Linear(wide_dim, hidden_dim),
            nn.SiLU(),
            # Last layer emits mean and log-variance side by side.
            nn.Linear(hidden_dim, latent_dim * 2),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, wide_dim),
            nn.Sigmoid(),
            nn.Linear(wide_dim, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, input_dim),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def encode(self, x):
        """Return ``(mean, logvar)``, the two halves of the encoder output."""
        stats = self.encoder(x)
        mean, logvar = torch.chunk(stats, 2, dim=-1)
        return mean, logvar

    def reparameterize(self, mean, logvar):
        """Draw ``mean + eps * std`` with ``eps`` standard normal noise."""
        std = torch.exp(0.5 * logvar)
        return mean + torch.randn_like(std) * std

    def decode(self, z):
        """Map latent codes back to input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(reconstruction, z, mean, logvar)`` for the batch ``x``."""
        mean, logvar = self.encode(x)
        z = self.reparameterize(mean, logvar)
        reconstruction = self.decode(z)
        return reconstruction, z, mean, logvar

# Loss function
def vae_loss(recon_x, x, mean, logvar, beta=0.01, score=None, DSM=None):
    """Compute the VAE objective: summed reconstruction error plus a weighted KL term.

    Args:
        recon_x: Reconstruction tensor, or a list/tuple of reconstruction
            tensors (one per domain).
        x: Target tensor, or a sequence indexable in step with ``recon_x``
            when ``recon_x`` is a list/tuple.
        mean: Posterior means.
        logvar: Posterior log-variances.
        beta: Weight applied to the KL-style term.
        score: Optional scalar that replaces the analytic prior term; when
            given (and ``DSM`` is None) the KL term becomes
            ``kld_encoder_posterior - score``.
        DSM: Optional scalar; when given the KL term becomes
            ``kld_encoder_posterior + DSM`` (takes precedence over ``score``).

    Returns:
        Tuple ``(total_loss, recon_loss, kld_encoder_posterior, kld_prior)``.
        ``kld_prior`` is always the analytic standard-normal prior term, even
        when ``score`` or ``DSM`` replaced it in ``total_loss``.
    """
    mse = nn.functional.mse_loss
    if isinstance(recon_x, (list, tuple)):
        if recon_x:
            # Sum the per-domain squared errors; x is indexed in step with recon_x.
            recon_loss = sum(
                mse(recon_part, x[idx], reduction='sum')
                for idx, recon_part in enumerate(recon_x)
            )
        else:
            # Nothing to reconstruct: keep a tensor so callers can still use .item().
            recon_loss = mean.new_zeros(())
    else:
        recon_loss = mse(recon_x, x, reduction='sum')

    # Split of the Gaussian KL to N(0, I): negative-entropy part and prior part.
    kld_encoder_posterior = 0.5 * torch.sum(- 1 - logvar)
    kld_prior = 0.5 * torch.sum(mean.pow(2) + logvar.exp())

    if DSM is not None:
        kld_loss = kld_encoder_posterior + DSM
    elif score is not None:
        kld_loss = kld_encoder_posterior - score
    else:
        kld_loss = kld_encoder_posterior + kld_prior

    return recon_loss + beta * kld_loss, recon_loss, kld_encoder_posterior, kld_prior

#### CNN VAE

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder that outputs a posterior mean and log-variance.

    Three stride-2 convolutions are followed by a shared linear layer and two
    linear heads of 1 * 16 * 16 = 256 units each. The flatten size expected by
    ``fc1`` (64 * 4 * 4) is what a 1x28x28 input produces.
    """

    def __init__(self):
        super().__init__()
        latent_features = 1 * 16 * 16
        self.conv1 = nn.Conv2d(1, 4, kernel_size=4, stride=2, padding=1)    # 1x28x28 -> 4x14x14
        self.conv2 = nn.Conv2d(4, 16, kernel_size=4, stride=2, padding=1)   # -> 16x7x7
        self.conv3 = nn.Conv2d(16, 64, kernel_size=3, stride=2, padding=1)  # -> 64x4x4
        self.fc1 = nn.Linear(64 * 4 * 4, 256)
        self.fc2_mu = nn.Linear(256, latent_features)
        self.fc2_logvar = nn.Linear(256, latent_features)

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = torch.relu(conv(x))
        flat = x.view(x.size(0), -1)
        hidden = torch.relu(self.fc1(flat))
        return self.fc2_mu(hidden), self.fc2_logvar(hidden)

class Decoder(nn.Module):
    """Transposed-convolution decoder from a 256-d latent to a flat 28*28 image.

    The latent is projected to a 64x7x7 feature map, upsampled by three
    stride-2 transposed convolutions (7 -> 14 -> 28 -> 56), flattened, and
    mapped by a final linear layer to 28 * 28 values in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(256, 64 * 7 * 7)  # sized for the 64x7x7 reshape in forward
        self.deconv1 = nn.ConvTranspose2d(64, 16, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(16, 4, kernel_size=4, stride=2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(4, 1, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(56 * 56, 28 * 28)

    def forward(self, x):
        """Decode a batch of latent vectors into flattened images."""
        batch = x.size(0)
        feat = torch.relu(self.fc(x)).view(batch, 64, 7, 7)
        for upsample in (self.deconv1, self.deconv2, self.deconv3):
            feat = torch.relu(upsample(feat))
        flat = feat.view((batch, -1))
        return torch.sigmoid(self.fc1(flat))
        
class CNN_VAE(nn.Module):
    """Variational autoencoder assembled from ``Encoder`` and ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def reparameterize(self, mu, logvar):
        """Draw ``mu + eps * std`` with ``eps`` standard normal noise."""
        std = torch.exp(0.5 * logvar)
        noise = torch.randn_like(std)
        return mu + noise * std

    def decode(self, z):
        """Run the decoder on latent codes ``z``."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(recon, z, mu, logvar)``, each flattened to ``(batch, -1)``."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decoder(z)
        batch = z.shape[0]
        return tuple(t.view((batch, -1)) for t in (recon, z, mu, logvar))

## UNet

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Residual Block
class ResidualBlock(nn.Module):
    """Two Linear + LayerNorm stages with a skip connection and SiLU activations.

    The input is added to the output of the second stage, so ``in_dim`` and
    ``out_dim`` have to be addition-compatible (equal in practice).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, out_dim)
        self.ln1 = nn.LayerNorm(out_dim)
        self.swish = nn.SiLU()
        self.fc2 = nn.Linear(out_dim, out_dim)
        self.ln2 = nn.LayerNorm(out_dim)

    def forward(self, x):
        """Return ``swish(ln2(fc2(swish(ln1(fc1(x))))) + x)``."""
        hidden = self.swish(self.ln1(self.fc1(x)))
        out = self.ln2(self.fc2(hidden))
        out += x  # skip connection
        return self.swish(out)

# UNet with advanced techniques
class UNet(nn.Module):
    """Time-conditioned MLP used as the score network.

    An encoder stack of linear/residual layers processes the input; its
    output is concatenated with a learned timestep embedding and passed to a
    decoder stack that maps to ``out_dim``.

    Args:
        in_dim: Input feature size.
        out_dim: Output feature size.
        num_timesteps: Number of distinct timesteps the embedding covers.
        embedding_dim: Size of the timestep embedding.
        multiplier: Hidden width is ``multiplier * in_dim``.
        device: Stored on the instance as ``self.device``.
        is_warm_init: When true, ``warm_init`` is applied after construction.
    """

    def __init__(self, in_dim, out_dim, num_timesteps, embedding_dim=2, multiplier=4, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"), is_warm_init=False):
        super().__init__()
        self.num_timesteps = num_timesteps
        self.device = device
        width = multiplier * in_dim

        # Encoder stack
        self.encoder = nn.Sequential(
            nn.Linear(in_dim, width),
            nn.SiLU(),
            ResidualBlock(width, width),
            nn.Dropout(0.1),
            ResidualBlock(width, width),
            nn.Dropout(0.1),
        )

        # Decoder stack; its first layer also consumes the timestep embedding
        self.decoder = nn.Sequential(
            nn.Linear(width + embedding_dim, width),
            nn.SiLU(),
            ResidualBlock(width, width),
            nn.Dropout(0.1),
            ResidualBlock(width, width),
            nn.Dropout(0.1),
            nn.Linear(width, out_dim),
        )

        # Learned embedding of the discrete timestep
        self.embedding = nn.Embedding(num_timesteps, embedding_dim)

        if is_warm_init:
            self.warm_init()

    def forward(self, x, timestep, enc_sigma=None):
        """Encode ``x`` (plus optionally ``enc_sigma``), append the timestep embedding, decode."""
        sigma_features = 0 if enc_sigma is None else self.encoder(enc_sigma)
        hidden = self.encoder(x) + sigma_features
        decoder_input = torch.hstack((hidden, self.embedding(timestep)))
        return self.decoder(decoder_input)

    def warm_init(self):
        """Re-initialize Linear layers (Kaiming normal, zero bias) and Embeddings (N(0, 0.01))."""
        for module in self.modules():
            if isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.01)
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.kaiming_normal_(module.weight, nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

## Score-Based Models

In [None]:
class EMA:
    """Exponential moving average of a model's trainable parameters.

    Args:
        model: Module whose parameters are tracked.
        decay: Weight kept from the previous average on each update.
    """

    def __init__(self, model, decay):
        self.model = model
        self.decay = decay
        self.shadow = self._clone_model_params()

    def _clone_model_params(self):
        """Return a name -> tensor copy of every parameter that requires grad."""
        return {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    def update(self):
        """Blend the current parameters into the shadow copy."""
        keep = self.decay
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            assert name in self.shadow
            blended = (1.0 - keep) * param.data + keep * self.shadow[name]
            self.shadow[name] = blended.clone()

class Score_fn(nn.Module):
    """Denoising-score-matching wrapper around a time-conditioned network.

    The wrapped ``model`` is called as ``model(x, t)``. Noise levels are
    ``num_timesteps`` values spaced geometrically between ``sigma_min`` and
    ``sigma_max``; most score estimates divide the network output by the
    noise level selected by ``t``.
    """

    def __init__(self, model, ema=None, ema_decay=0.99, sigma_min=0.01, sigma_max=50, num_timesteps=1000, device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")):
        """Construct a score function model.

        Args:
          model: network called as ``model(x, t)``.
          ema: optional EMA class/factory, called as ``ema(model, decay=ema_decay)``.
          ema_decay: decay handed to ``ema``.
          sigma_min: smallest sigma.
          sigma_max: largest sigma.
          num_timesteps: number of discretization steps.
          device: device holding the sigma table and any sampled timesteps.
        """
        super(Score_fn, self).__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        # Geometric noise schedule: evenly spaced in log-sigma.
        self.discrete_sigma = torch.exp(torch.linspace(np.log(self.sigma_min), np.log(self.sigma_max), num_timesteps)).to(device)
        self.num_timesteps = num_timesteps
        self.model = model
        self.device = device
        self.loss_dict = {}
        self.total_loss = 0
        self.loss_counter = 0
        if ema is not None:
            self.ema = ema(model, decay=ema_decay)

        # Learnable scalar. get_mixing_score_fn squashes it with a sigmoid to
        # obtain a mixing weight in (0, 1); get_residual_score_fn uses it raw.
        self.lbda = nn.ParameterList([nn.Parameter(torch.tensor([0.0]))])

    def to_device(self):
        """Move the wrapped model to ``self.device``."""
        self.model = self.model.to(self.device)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _sigma_like(self, x, t):
        """Return sigma[t] shaped (batch, 1, ..., 1) so it broadcasts against x."""
        return self.discrete_sigma[t.long()].view(x.shape[0], *([1] * len(x.shape[1:])))

    def _enter_eval(self, detach):
        """Put the model in eval mode when ``detach`` is set; return its previous mode (else None)."""
        if not detach:
            return None
        was_training = self.model.training
        self.model.eval()
        return was_training

    def _exit_eval(self, was_training):
        """Restore the mode recorded by ``_enter_eval`` (no-op for None)."""
        if was_training is not None:
            self.model.train(was_training)

    def _sample_timesteps(self, batch_size, max_timestep):
        """Draw uniform integer timesteps in [0, max_timestep) on ``self.device``."""
        return torch.randint(0, max_timestep, (batch_size,), device=self.device)

    def _apply_update(self, loss, optimizer, verbose, report_alpha=False):
        """Record the loss, optionally report the running average, and take one optimizer step."""
        self.total_loss += loss.item()
        self.loss_counter += 1.
        if verbose:
            avg_loss = self.total_loss / self.loss_counter
            self.reset_loss_count()
            self.update_loss_dict(avg_loss)
            print(avg_loss)
            if report_alpha:
                print(f'alpha: {torch.sigmoid(self.lbda[0])}')

        optimizer.zero_grad()
        # retain_graph is kept from the original implementation so callers that
        # reuse the graph of ``x`` afterwards keep working.
        loss.backward(retain_graph=True)
        optimizer.step()

        # Update EMA
        if hasattr(self, 'ema'):
            self.ema.update()

    # ------------------------------------------------------------------
    # Losses
    # ------------------------------------------------------------------
    def compute_DSM_loss(self, x, t, enc_mu=None, enc_sigma=None, alpha=None, turn_off_enc_sigma=False, learn_lbda=False, is_mixing=False, is_residual=False, is_vanilla=False, is_LSGM=False, divide_by_sigma=False):
        """Compute the denoising score matching loss at timesteps ``t``.

        ``x`` is perturbed with Gaussian noise of scale sigma[t]; the selected
        score estimate is regressed onto ``-noise / sigma**2`` with a
        ``sigma**2`` weighting per sample.

        Args:
            x: Clean data batch.
            t: Integer timestep per sample.
            enc_mu, enc_sigma: Encoder statistics, used by the residual score
                (``is_residual`` without ``is_mixing``).
            alpha: Optional fixed mixing weight for the mixing score.
            turn_off_enc_sigma, learn_lbda: Forwarded to ``get_residual_score_fn``.
            is_mixing, is_residual, is_vanilla, divide_by_sigma: Select and
                configure the score estimate (mixing takes precedence).
            is_LSGM: Sum the per-sample losses instead of averaging them.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If the residual score is requested without the
                encoder statistics it needs.
        """
        sigmas = self._sigma_like(x, t)
        noise = torch.randn_like(x, device=self.device) * sigmas
        perturbed_data = x + noise
        if is_mixing:
            score = self.get_mixing_score_fn(perturbed_data, t, alpha=alpha, is_residual=is_residual, is_vanilla=is_vanilla, divide_by_sigma=divide_by_sigma)
        elif is_residual:
            if enc_mu is None:
                raise ValueError("enc_mu is required for the residual score function")
            if enc_sigma is None and not turn_off_enc_sigma:
                raise ValueError("enc_sigma is required unless turn_off_enc_sigma is set")
            enc_eps = x - enc_mu
            # Keyword arguments: passing these positionally shifted them onto
            # ``detach``/``turn_off_enc_sigma``, and the residual score accepts
            # neither ``is_vanilla`` nor ``divide_by_sigma``.
            score = self.get_residual_score_fn(perturbed_data, t, enc_eps, enc_sigma, turn_off_enc_sigma=turn_off_enc_sigma, learn_lbda=learn_lbda)
        else:
            score = self.get_score_fn(perturbed_data, t)
        target = -noise / (sigmas ** 2)
        losses = torch.square(score - target)
        losses = 1/2. * torch.sum(losses.reshape(losses.shape[0], -1), dim=-1) * sigmas.reshape(-1) ** 2
        if is_LSGM:
            return torch.sum(losses)
        return torch.mean(losses)

    def get_LSGM_loss(self, x, t=None, is_mixing=False, is_residual=False, is_vanilla=False, alpha=None):
        """Return the summed DSM loss (``is_LSGM=True``, ``divide_by_sigma=True``); samples ``t`` if omitted."""
        if t is None:
            t = self._sample_timesteps(x.shape[0], self.num_timesteps)

        return self.compute_DSM_loss(x, t, is_mixing=is_mixing, is_residual=is_residual, alpha=alpha, is_vanilla=is_vanilla, is_LSGM=True, divide_by_sigma=True)

    # ------------------------------------------------------------------
    # Score estimates
    # ------------------------------------------------------------------
    def get_score_fn(self, x, t, detach=False):
        """Return ``model(x, t) / sigma[t]``.

        With ``detach=True`` the model is evaluated in eval mode, its previous
        train/eval mode is restored afterwards, and the result is detached.
        """
        previous_mode = self._enter_eval(detach)
        try:
            score = self.model(x, t) / self._sigma_like(x, t)
        finally:
            self._exit_eval(previous_mode)
        return score.detach() if detach else score

    # Our implementation of residual score function
    def get_residual_score_fn(self, x, t, enc_eps, enc_sigma, detach=False, turn_off_enc_sigma=False, learn_lbda=False):
        """Return the learnable score plus an analytic Gaussian residual term.

        Args:
            x: Perturbed inputs.
            t: Integer timestep per sample.
            enc_eps: Residual ``x - enc_mu`` of the clean data.
            enc_sigma: Encoder standard deviation (unused when
                ``turn_off_enc_sigma`` is set).
            detach: Evaluate in eval mode and detach the result.
            turn_off_enc_sigma: Use unit variance for the residual term.
            learn_lbda: Scale the learnable score by the raw ``lbda`` parameter.
        """
        previous_mode = self._enter_eval(detach)
        try:
            learnable_score = self.model(x, t) / self._sigma_like(x, t)
        finally:
            self._exit_eval(previous_mode)

        if learn_lbda:
            # lbda is a ParameterList; index it to get the actual parameter.
            learnable_score = self.lbda[0] * learnable_score

        # Unit variance when turned off, otherwise the encoder variance.
        if turn_off_enc_sigma:
            residual_score = - enc_eps
        else:
            residual_score = - enc_eps / (enc_sigma ** 2)

        score = learnable_score + residual_score
        return score.detach() if detach else score

    # Training LSGM Mixing Normal and Neural Score Functions based on this paper https://arxiv.org/pdf/2106.05931
    # if no alpha param is given assumed alpha is learned by the model. If it is residual behaves like Prof. Inouye's idea
    def get_mixing_score_fn(self, x, t, alpha=None, is_residual=False, is_vanilla=False, detach=False, divide_by_sigma=False):
        """Return a mixture of the neural score and a standard-normal score.

        Args:
            x: Inputs.
            t: Integer timestep per sample.
            alpha: Mixing weight (tensor or number). When None,
                ``sigmoid(lbda)`` is used, so the weight is learned.
            is_residual: Use ``-x`` instead of ``-(1 - alpha) * x`` as the
                analytic term.
            is_vanilla: Return only the (alpha-scaled) neural term.
            detach: Evaluate in eval mode, restore the previous mode, and
                detach the result.
            divide_by_sigma: Divide the neural term by sigma[t].
        """
        # Converts lbda to alpha to match LSGM notation and bounds it to (0, 1)
        if alpha is None:
            alpha = torch.sigmoid(self.lbda[0])
        else:
            alpha = torch.as_tensor(alpha, device=self.device)

        previous_mode = self._enter_eval(detach)
        try:
            learnable_score = alpha * self.model(x, t)
        finally:
            self._exit_eval(previous_mode)
        if divide_by_sigma:
            learnable_score = learnable_score / self._sigma_like(x, t)

        if is_vanilla:
            score = learnable_score
        elif is_residual:
            score = learnable_score + (- x)
        else:
            score = learnable_score + (- (1 - alpha) * x)
        return score.detach() if detach else score

    # ------------------------------------------------------------------
    # Training steps
    # ------------------------------------------------------------------
    # Update one batch; max_timestep can shrink the sampled timestep range (default: num_timesteps).
    # When verbose is true, the average loss since the last verbose call is saved to the loss dict.
    def update_score_fn(self, x, optimizer, alpha=None, max_timestep=None, t=None, verbose=False, is_mixing=False, is_residual=False, is_vanilla=False, divide_by_sigma=False):
        """Take one optimizer step on the DSM loss of batch ``x``."""
        if max_timestep is None or max_timestep > self.num_timesteps:
            max_timestep = self.num_timesteps

        if t is None:
            t = self._sample_timesteps(x.shape[0], max_timestep)

        loss = self.compute_DSM_loss(x, t, is_mixing=is_mixing, is_residual=is_residual, alpha=alpha, is_vanilla=is_vanilla, divide_by_sigma=divide_by_sigma)
        self._apply_update(loss, optimizer, verbose, report_alpha=True)

    # Update for residual score model training
    def update_residual_score_fn(self, x, enc_mu, enc_sigma, optimizer, max_timestep=None, learn_lbda=False, turn_off_enc_sigma=False, t=None, verbose=False):
        """Take one optimizer step on the residual-score DSM loss of batch ``x``."""
        if max_timestep is None or max_timestep > self.num_timesteps:
            max_timestep = self.num_timesteps

        if t is None:
            t = self._sample_timesteps(x.shape[0], max_timestep)

        loss = self.compute_DSM_loss(x, t, is_residual=True, enc_mu=enc_mu, enc_sigma=enc_sigma, turn_off_enc_sigma=turn_off_enc_sigma, learn_lbda=learn_lbda)
        self._apply_update(loss, optimizer, verbose)

    def add_EMA_training(self, ema, decay=0.99):
        """Attach an EMA tracker built as ``ema(self.model, decay)``."""
        self.ema = ema(self.model, decay)

    def update_param_with_EMA(self):
        """Copy the EMA shadow values into the model parameters.

        Raises:
            AttributeError: If no EMA tracker has been attached.
        """
        if not hasattr(self, 'ema'):
            raise AttributeError("EMA model is not defined in the class. Please use add_EMA_training class function and retrain")
        for name, param in self.model.named_parameters():
            if param.requires_grad and name in self.ema.shadow:
                param.data.copy_(self.ema.shadow[name])

    # ------------------------------------------------------------------
    # Visualisation
    # ------------------------------------------------------------------
    # Draws a vector field of the score function
    def draw_gradient_field(self, xlim, ylim, t=0, x_num=20, y_num=20, file="./Score_Function", noise_label=1, save=False, data=None, labels=None, n_samples=100, alpha=None, is_mixture=False, is_residual=False, is_vanilla=False):
        """Plot the 2-D score field on a grid, optionally over a scatter of ``data``.

        Args:
            xlim, ylim: ``(low, high)`` bounds of the grid.
            t: Timestep at which the score is evaluated.
            x_num, y_num: Grid resolution along each axis.
            file: Path handed to ``plt.savefig`` when ``save`` is set.
            noise_label: Unused.
            save: Save the figure before showing it.
            data: Optional 2-D points (tensor or array-like) to scatter.
            labels: Optional 0/1 labels; ``data`` and ``labels`` are split
                into two halves (two domains) and plotted per domain/label.
            n_samples: Maximum number of points drawn per domain/label group.
            alpha, is_mixture, is_residual, is_vanilla: Select the score
                estimate, as in ``get_mixing_score_fn``.
        """
        x, y = np.meshgrid(np.linspace(xlim[0], xlim[1], x_num), np.linspace(ylim[0], ylim[1], y_num))
        x_ = torch.from_numpy(x.reshape(-1, 1)).type(torch.float).to(self.device)
        y_ = torch.from_numpy(y.reshape(-1, 1)).type(torch.float).to(self.device)
        grid_points = torch.hstack((x_, y_))

        if data is not None:
            # Work on a detached CPU tensor regardless of what was passed in.
            data = torch.as_tensor(data).detach().cpu()

            if labels is not None:
                labels = torch.as_tensor(labels).detach().cpu()
                data1, data2 = data.chunk(2)
                labels1, labels2 = labels.reshape(-1).chunk(2)
                data1_l1, data1_l2 = data1[labels1 == 0], data1[labels1 == 1]
                # Second domain is masked with its own labels.
                data2_l1, data2_l2 = data2[labels2 == 0], data2[labels2 == 1]
                groups = (
                    (data1_l1, 'x', 'D1_L1', 'b'),
                    (data1_l2, 'o', 'D1_L2', 'b'),
                    (data2_l1, '+', 'D2_L1', 'g'),
                    (data2_l2, 'o', 'D2_L2', 'g'),
                )
                for points, marker, label, color in groups:
                    points = points[:n_samples].numpy()
                    plt.scatter(points[:, 0], points[:, 1], marker=marker, label=label, c=color, s=20)
                plt.legend()
            else:
                points = data.numpy()
                plt.scatter(points[:, 0], points[:, 1])

        timesteps = torch.ones((x_num * y_num,), device=self.device).type(torch.long) * t
        if is_mixture:
            score_fn = self.get_mixing_score_fn(grid_points, timesteps, detach=True, alpha=alpha, is_vanilla=is_vanilla)
        elif is_residual:
            score_fn = self.get_mixing_score_fn(grid_points, timesteps, detach=True, alpha=alpha, is_residual=True, is_vanilla=is_vanilla)
        else:
            score_fn = self.get_score_fn(grid_points, timesteps, detach=True)

        # meshgrid returns (y_num, x_num) arrays; reshape the scores to match.
        score_fn_x = score_fn[:, 0].cpu().numpy().reshape(x.shape)
        score_fn_y = score_fn[:, 1].cpu().numpy().reshape(x.shape)
        plt.quiver(x, y, score_fn_x, score_fn_y, color='r')
        plt.title('Score Function')
        plt.grid()
        # Save before show(): showing can finalize the figure, leaving an empty file.
        if save:
            plt.savefig(f"{file}")
        plt.show()

    # ------------------------------------------------------------------
    # Loss bookkeeping
    # ------------------------------------------------------------------
    # Resets the total loss and respective count of updates
    def reset_loss_count(self):
        """Zero the running loss total and update counter."""
        self.total_loss = 0
        self.loss_counter = 0

    def update_loss_dict(self, loss):
        """Append ``loss`` to the 'DSMloss' history."""
        self.loss_dict.setdefault('DSMloss', []).append(loss)

    def get_loss_dict(self):
        """Return the recorded loss history."""
        return self.loss_dict

In [None]:
def get_lp_dist(p=2):
    """Build an ``nn.PairwiseDistance`` module for the L-``p`` norm with ``keepdim=True``."""
    distance = nn.PairwiseDistance(p=p, keepdim=True)
    return distance
# def move_metric_x_to_device(self, metric_x, device):
# metric_x.to(device)
def compute_gp_loss(x, z, dist_func_x, dist_func_z):
    """Penalize mismatch between pairwise distances in data space and latent space.

    For every sample ``i`` the distances to all later samples are computed in
    both spaces; the squared differences are summed and the total is divided
    by ``batch_size - 1``.

    Args:
        x: Batch of data points (first dimension is the batch).
        z: Batch of latent codes aligned with ``x``.
        dist_func_x: Callable ``(point, points) -> distances`` for ``x``.
        dist_func_z: Callable ``(point, points) -> distances`` for ``z``.

    Returns:
        The normalized loss; a zero scalar tensor when the batch holds fewer
        than two samples (there are no pairs to compare).
    """
    batch_size = len(x)
    if batch_size < 2:
        # No pairs: avoid dividing by zero (or by -1 for an empty batch).
        return torch.zeros((), device=getattr(x, 'device', None))

    loss = 0
    for idx in range(batch_size - 1):
        p_dist_x = dist_func_x(x[idx], x[idx + 1:]).squeeze()
        p_dist_z = dist_func_z(z[idx], z[idx + 1:]).squeeze()
        loss += ((p_dist_x - p_dist_z) ** 2).sum()
    return loss / (batch_size - 1)

def pairwise_distances(x, y=None):
    '''
    Input: x is a Nxd matrix
           y is an optional Mxd matrix
    Output: dist is a NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:]
            if y is not given then 'y=x' is used.
    i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    '''
    x_norm = (x ** 2).sum(1).view(-1, 1)
    if y is not None:
        y_t = torch.transpose(y, 0, 1)
        y_norm = (y ** 2).sum(1).view(1, -1)
    else:
        y_t = torch.transpose(x, 0, 1)
        y_norm = x_norm.view(1, -1)

    # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 x.y; computed for both branches
    # (previously this only ran when y was omitted).
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)

    # Ensure diagonal is zero if x=y
    if y is None:
        dist = dist - torch.diag(dist.diag())
    # Rounding can push tiny distances below zero.
    return torch.clamp(dist, min=0.0)

def calculate_gp_loss(X_list, Z_list):
    """Sum, over paired batches, the absolute gap between squared pairwise distances.

    For each ``(X, Z)`` pair the squared-distance matrices of ``X`` and ``Z``
    are compared element-wise; the absolute differences are added up over all
    pairs. Returns ``0`` when the lists are empty.
    """
    return sum(
        torch.sum(torch.abs(pairwise_distances(X) - pairwise_distances(Z)))
        for X, Z in zip(X_list, Z_list)
    )

In [None]:
# Function to display reconstructed images
def display_reconstructed_images(epoch, vae_model, data, n_samples=10, dim=(1, 28, 28), is_flip=False):
    """Show the first ``n_samples`` inputs above their VAE reconstructions.

    The model is switched to eval mode (and left there) and run without
    gradients. Inputs are moved to the device of the model's parameters
    instead of relying on a module-level ``device`` variable.

    Args:
        epoch: Epoch number used in the figure title.
        vae_model: Model whose call returns ``(recon, z, mean, logvar)``.
        data: Batch of inputs.
        n_samples: Maximum number of samples to display.
        dim: ``(channels, height, width)`` used to reshape inputs and
            reconstructions into images.
        is_flip: Unused.
    """
    vae_model.eval()
    with torch.no_grad():
        first_param = next(vae_model.parameters(), None)
        # Parameter-free models: leave the data where it already is.
        target_device = first_param.device if first_param is not None else data.device
        data = data.to(target_device)[:n_samples]
        recon_x, _, _, _ = vae_model(data)
        recon_x = recon_x[:n_samples]
        comparison = torch.cat([data.view(-1, dim[0], dim[1], dim[2]), recon_x.view(-1, dim[0], dim[1], dim[2])])
        # One row of originals followed by one row of reconstructions.
        comparison = make_grid(comparison, nrow=data.size(0))
        comparison = comparison.cpu().numpy().transpose(1, 2, 0)

        plt.figure(figsize=(15, 5))
        plt.imshow(comparison, cmap='gray')
        plt.axis('off')
        plt.title(f'Reconstructed Images at Epoch {epoch}')
        plt.show()

def display_reconstructed_and_flip_images(epoch, vae_model, flip_vae_model, data, n_samples=10, dim=(1, 28, 28), flip_dim=(3, 32, 32), is_mnist=True, is_both=True):
    """Show inputs, their reconstructions, and their cross-domain decodings.

    Row 0 holds the inputs, row 1 the reconstructions from ``vae_model``, and
    row 2 the result of decoding the same latents with ``flip_vae_model``.
    Both models are switched to eval mode (and left there) and run without
    gradients. Inputs are moved to the device of ``vae_model``'s parameters
    instead of relying on a module-level ``device`` variable.

    Args:
        epoch: Unused; kept for call compatibility.
        vae_model: Model whose call returns ``(recon, z, mean, logvar)``.
        flip_vae_model: Model providing ``decode(z)`` for the other domain.
        data: Batch of inputs for ``vae_model``.
        n_samples: Maximum number of columns; fewer are drawn when the batch
            is smaller.
        dim: ``(channels, height, width)`` of inputs and reconstructions.
        flip_dim: ``(channels, height, width)`` of the cross-domain decodings.
        is_mnist: Gray colormap for rows 0-1, default colormap for row 2.
        is_both: When ``is_mnist`` is false, gray colormap for all rows;
            otherwise only row 2 is gray.
    """
    vae_model.eval()
    # The second model also runs inference here, so give it the same treatment.
    flip_vae_model.eval()
    with torch.no_grad():
        first_param = next(vae_model.parameters(), None)
        target_device = first_param.device if first_param is not None else data.device
        data = data.to(target_device)[:n_samples]
        # The batch may hold fewer than n_samples items.
        n_shown = data.shape[0]
        if n_shown == 0:
            return
        recon_x, z, _, _ = vae_model(data)
        recon_x_flip = flip_vae_model.decode(z)

        data = data.view(n_shown, dim[0], dim[1], dim[2])
        recon_x = recon_x[:n_shown].view(n_shown, dim[0], dim[1], dim[2])
        recon_x_flip = recon_x_flip[:n_shown].view(n_shown, flip_dim[0], flip_dim[1], flip_dim[2])

        # squeeze=False keeps ``axes`` two-dimensional even for one column.
        fig, axes = plt.subplots(3, n_shown, figsize=(n_shown * 3 / 2, 4.5), squeeze=False)
        if is_mnist:
            main_color = 'gray'
            flip_color = None
        elif is_both:
            main_color = 'gray'
            flip_color = 'gray'
        else:
            flip_color = 'gray'
            main_color = None

        rows = ((data, main_color), (recon_x, main_color), (recon_x_flip, flip_color))
        for i in range(n_shown):
            for row, (images, cmap) in enumerate(rows):
                # Channels-first tensor -> channels-last array for imshow.
                axes[row, i].imshow(np.transpose(images[i].detach().cpu().numpy(), (1, 2, 0)), cmap=cmap)
                axes[row, i].axis('off')

        plt.tight_layout()
        plt.show()

## Classifier

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F
from torch.optim.lr_scheduler import StepLR

# Images are only converted to PyTorch tensors; no other preprocessing is applied
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split (downloaded on first use), shuffled, in mini-batches of 64
train_dataset = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=transform,
)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# MNIST test split, fixed order, in batches of 1000
test_dataset = datasets.MNIST(
    root='data',
    train=False,
    download=True,
    transform=transform,
)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Small convolutional classifier: three conv + max-pool stages and two linear layers
# (a plain stack, without residual connections)
class CNN(nn.Module):
    """Classifier for single-channel 16x16 inputs.

    Each of the three stages is a 3x3 convolution, ReLU and 2x2 max-pool, so
    a 16x16 input shrinks to 2x2 with 128 channels before the linear layers.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, **conv_kwargs)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, **conv_kwargs)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, **conv_kwargs)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.dropout = nn.Dropout(0.25)
        # 128 * 2 * 2 features, assuming the input is (1, 16, 16)
        self.fc1 = nn.Linear(128 * 2 * 2, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return class logits for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(-1, 128 * 2 * 2)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

# Use the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the classifier on that device and pair it with an Adam optimizer
classifier = CNN()
classifier = classifier.to(device)
optimizer = optim.Adam(params=classifier.parameters(), lr=1e-3)
# scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

# Training function
def train_classifier(classifier, device, train_loader, optimizer, epoch):
    """Train ``classifier`` for one pass over ``train_loader`` with cross-entropy.

    Progress is printed on every 100th batch (including the first).

    Args:
        classifier: Model to train; switched to train mode.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches with a
            ``dataset`` attribute.
        optimizer: Optimizer over the classifier's parameters.
        epoch: Epoch number, used only in the progress message.
    """
    classifier.train()
    for batch_idx, batch in enumerate(train_loader):
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        loss = F.cross_entropy(classifier(inputs), targets)
        loss.backward()
        optimizer.step()

        if batch_idx % 100 != 0:
            continue
        seen = batch_idx * len(inputs)
        total = len(train_loader.dataset)
        percent = 100. * batch_idx / len(train_loader)
        print(f'Train Epoch: {epoch} [{seen}/{total}'
              f' ({percent:.0f}%)]\tLoss: {loss.item():.6f}')

# Test function
def test_classifier(classifier, device, test_loader, preprocess_model=None):
    """Evaluate ``classifier`` on ``test_loader`` and print mean loss and accuracy.

    Args:
        classifier: Model to evaluate; switched to eval mode.
        device: Device the batches are moved to.
        test_loader: Iterable of ``(data, target)`` batches with a
            ``dataset`` attribute.
        preprocess_model: Unused.
    """
    classifier.eval()
    summed_loss = 0
    n_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            logits = classifier(data)
            # Accumulate the summed (not averaged) batch loss
            summed_loss += F.cross_entropy(logits, target, reduction='sum').item()
            # Predicted class = index of the largest logit
            predictions = logits.argmax(dim=1, keepdim=True)
            n_correct += predictions.eq(target.view_as(predictions)).sum().item()

    n_examples = len(test_loader.dataset)
    mean_loss = summed_loss / n_examples
    print(f'\nTest set: Average loss: {mean_loss:.4f}, Accuracy: {n_correct}/{n_examples}'
          f' ({100. * n_correct / n_examples:.0f}%)\n')

# Train and evaluate the classifier
# num_epochs = 100
# for epoch in range(1, num_epochs + 1):
#     train_classifier(classifier, device, train_loader, optimizer, epoch)
#     test_classifier(classifier, device, test_loader)
    # scheduler.step()

## Domain Adaptation

#### MNIST and MNIST Flip

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from itertools import chain
import numpy as np
from tqdm import tqdm
import warnings

# Ignore all warnings
warnings.filterwarnings("ignore")

# Model, optimizer
# Hyperparameters shared by the models and the training loop of this cell.
batch_size = 1024
timesteps = 100  # number of discrete noise levels given to UNet / Score_fn
is_vanilla = True  # forwarded to the score-function calls in the training loop
mnist_input_dim = 1 * 28 * 28  # only referenced by the commented-out fully connected VAE setup
flip_mnist_input_dim = 1 * 28 * 28  # only referenced by the commented-out fully connected VAE setup
hidden_dim = 1024  # only referenced by the commented-out fully connected VAE setup
latent_dim = 256  # in/out size of the score network; latents are viewed as 1x16x16 for the classifier
beta = 2  # weight of the KL-style term in vae_loss
sigma_max = 0.4
sigma_min = 0.01
loops = 1  # score-model updates per VAE update
alpha = None  # None makes get_mixing_score_fn use sigmoid(lbda) as the mixing weight
gp_lambda = 0.05  # weight of the pairwise-distance loss from calculate_gp_loss
n_print_per_epoch = 1  # used as "report every N epochs" in the training loop
classifier_lambda = 10  # weight of the latent-classifier cross-entropy

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the transformation to be applied to the images
mnist_transform = transforms.Compose([
    transforms.ToTensor()  # Convert images to PyTorch tensor
])

# NOTE(review): defined but not used below; both datasets are built with mnist_transform.
flip_mnist_transform = transforms.Compose([
    transforms.ToTensor()  # Convert images to PyTorch tensors
])

# Download and load the training dataset
mnist_dataset = datasets.MNIST(root='data', train=True, download=True, transform=mnist_transform)
# flip_mnist_dataset = datasets.flip_mnist(root='./data', split='train', download=True, transform=flip_mnist_transform)
# The second domain reuses plain MNIST; the intensity inversion (1.0 - x) is
# applied to the first domain's batches inside the training loop.
flip_mnist_dataset = datasets.MNIST(root='data', train=True, download=True, transform=mnist_transform)

# Create a DataLoader
# Separate loaders for the VAE step and for the score-model step.
mnist_dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)
flip_mnist_dataloader = DataLoader(flip_mnist_dataset, batch_size=batch_size, shuffle=True)
mnist_dataloader_score = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)
flip_mnist_dataloader_score = DataLoader(flip_mnist_dataset, batch_size=batch_size, shuffle=True)

# Setup Model
# mnist_vae = VAE(mnist_input_dim, hidden_dim, latent_dim).to(device)
# flip_mnist_vae = VAE(flip_mnist_input_dim, hidden_dim, latent_dim).to(device)
mnist_vae = CNN_VAE().to(device)
flip_mnist_vae = CNN_VAE().to(device)

# Classifier applied to the latent codes of the first domain.
classifier = CNN().to(device)

score_model = Score_fn(UNet(in_dim=latent_dim, out_dim=latent_dim, num_timesteps=timesteps, is_warm_init=False), sigma_min=sigma_min, sigma_max=sigma_max, num_timesteps=timesteps, device=device).to(device)
# optimizer_vae = optim.Adam(chain(mnist_vae.parameters(), flip_mnist_vae.parameters()), lr=1e-3)
# One optimizer for both VAEs and the classifier; the score model has its own.
optimizer_vae = optim.Adam(chain(mnist_vae.parameters(), flip_mnist_vae.parameters(), classifier.parameters()), lr=1e-3)
optimizer_score = torch.optim.Adam(score_model.parameters(), 1e-4)

# mnist_vae.load_state_dict(torch.load('cnn_mnist_vae_test.pth'))
# flip_mnist_vae.load_state_dict(torch.load('cnn_mnist_flip_vae_test.pth'))
# score_model.load_state_dict(torch.load('cnn_score_model_test.pth'))
# classifier.load_state_dict(torch.load('cnn_classifier_test.pth'))

# Per-step loss histories, plotted after training.
total_loss_list = []
recon_loss_list = []
kl_loss_list = []
gp_loss_list = []

# Training
# Each step alternates between (1) one update of both VAEs and the latent
# classifier and (2) `loops` updates of the score model on fresh latents.
num_epochs = 1500
for epoch in tqdm(range(num_epochs)):
    mnist_vae.train()
    flip_mnist_vae.train()
    # Per-epoch running totals, reported at the end of the epoch.
    total_loss = 0
    total_recon_loss = 0
    total_kld_encoder_posterior = 0
    total_kld_prior = 0
    total_gp_loss = 0
    total_classifier_loss = 0
    for i, (data1, data2) in enumerate(zip(mnist_dataloader, flip_mnist_dataloader)):
        x1, label1 = data1
        x2, label2 = data2
        # x1, x2 = x1.to(device).view(x1.shape[0], -1), (1.0 - x2).to(device).view(x2.shape[0], -1) # Reshape
        # x1, x2 = (1.0 - x1).to(device).view(x1.shape[0], -1), (x2).to(device).view(x2.shape[0], -1) # Reshape
        # The first domain is intensity-inverted; the second is left unchanged.
        x1, x2 = (1.0 - x1).to(device), x2.to(device) # Reshape
        optimizer_vae.zero_grad()

        recon_x1, z1, mean1, logvar1 = mnist_vae(x1)
        recon_x2, z2, mean2, logvar2 = flip_mnist_vae(x2)
        # print(recon_x2.shape)
        # Inputs/reconstructions stay as per-domain lists; latents and posterior
        # statistics of both domains are stacked along the batch dimension.
        x, recon_x, z, mean, logvar = [x1.view((x1.shape[0], -1)), x2.view((x1.shape[0], -1))], [recon_x1, recon_x2], torch.vstack((z1, z2)), torch.vstack((mean1, mean2)), torch.vstack((logvar1, logvar2))

        # Score loss
        # NOTE(review): DSM is computed but not used afterwards (vae_loss is called with DSM=None).
        DSM = score_model.get_LSGM_loss(z, is_mixing=True, is_residual=True, is_vanilla=is_vanilla)
        # Detached score at fixed timestep 30, minus 0.05 * z, then the summed
        # per-sample dot product with z; vae_loss subtracts this from the KL term.
        score = score_model.get_mixing_score_fn(z, 30*torch.ones(z.shape[0], device=device).type(torch.long), detach=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha) - 0.05 * z
        score = torch.matmul(score.unsqueeze(1), z.unsqueeze(-1)).sum()

        # VAE loss
        loss, recon_loss, kld_encoder_posterior, kld_prior = vae_loss(recon_x, x, mean, logvar, beta, score=score, DSM=None)

        # dist_func_x = get_lp_dist(p=2)
        # dist_func_z = get_lp_dist(p=2)
        # gp_loss = sum([compute_gp_loss(x, z, dist_func_x, dist_func_z) for x, z in zip([x1, x2], [z1, z2])])
        # Pairwise-distance mismatch between image space and latent space, per domain.
        gp_loss = gp_lambda * calculate_gp_loss([x1.view((x1.shape[0], -1)), x2.view((x1.shape[0], -1))], [z1, z2])
        
        # NOTE(review): with this append disabled, gp_loss_list stays empty.
        # gp_loss_list.append(gp_loss.item())
        
        loss += gp_loss

        # Classify the first domain's latents, viewed as 1x16x16 images.
        output = classifier(z1.view((z1.shape[0], 1, 16, 16)))
        label1 = label1.to(device)
        classifier_loss = classifier_lambda * F.cross_entropy(output, label1, reduction='sum')
        total_classifier_loss += classifier_loss.item()
        
        loss += classifier_loss
        
        total_loss_list.append((loss).item())
        recon_loss_list.append((recon_loss).item())
        kl_loss_list.append((kld_encoder_posterior+kld_prior).item())
        
        loss.backward()
        optimizer_vae.step()

        total_gp_loss += gp_loss.item()
        total_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_kld_encoder_posterior += kld_encoder_posterior.item()
        total_kld_prior += kld_prior.item()

        # Update Score Function
        for loop in range(loops):
            # A new iterator is created on every step, so this takes the first
            # batch of a fresh pass over the score loaders.
            data1, data2 = next(iter(zip(mnist_dataloader_score, flip_mnist_dataloader_score)))
            x1, label1 = data1
            x2, label2 = data2
            x1, x2 = (1.0 - x1).to(device), x2.to(device) # Reshape
            # x1, x2 = x1.to(device).view(x1.shape[0], -1), (1.0 - x2).to(device).view(x2.shape[0], -1) # Reshape
            # x1, x2 = (1.0 - x1).to(device).view(x1.shape[0], -1), (x2).to(device).view(x2.shape[0], -1) # Reshape
            recon_x1, z1, mean1, logvar1 = mnist_vae(x1)
            recon_x2, z2, mean2, logvar2 = flip_mnist_vae(x2)
            # NOTE(review): z is passed to the score update without being detached from the VAEs.
            x, recon_x, z, mean, logvar, labels = [x1.view((x1.shape[0], -1)), x2.view((x1.shape[0], -1))], [recon_x1, recon_x2], torch.vstack((z1, z2)), torch.vstack((mean1, mean2)), torch.vstack((logvar1, logvar2)), torch.vstack((label1, label2))

            # Verbose update (prints the running DSM average) once per reporting
            # epoch, on the first batch; a silent update otherwise.
            if loop == (loops-1) and (epoch+1) % n_print_per_epoch == 0 and i==0:
                print(f"Epoch {epoch+1} DSM average loss:", end=' ') 
                score_model.update_score_fn(z, optimizer=optimizer_score, max_timestep=None, verbose=True, is_mixing=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha)
                # score_model.draw_gradient_field((-0, 0), (-0, 0), t=0, x_num=40, yfl_num=40, data=z.detach().cpu(), labels=labels, save=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha, is_mixture=True)
            else:
                score_model.update_score_fn(z, optimizer=optimizer_score, max_timestep=None, is_mixing=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha)
    # Report every n_print_per_epoch epochs
    if (epoch + 1) % n_print_per_epoch == 0:
        # x1 / x2 here are the batches left over from the last score-model update.
        display_reconstructed_and_flip_images(epoch=epoch, vae_model=mnist_vae, flip_vae_model=flip_mnist_vae, data=x1, dim=[1, 28, 28], flip_dim=[1, 28, 28])
        display_reconstructed_and_flip_images(epoch=epoch, vae_model=flip_mnist_vae, flip_vae_model=mnist_vae, data=x2, dim=[1, 28, 28], flip_dim=[1, 28, 28])
        print(f'Epoch {epoch+1}, Total Loss: {total_loss:.2f}, Recon Loss: {total_recon_loss:.2f}, '
              f'Encoder Posterior Loss: {total_kld_encoder_posterior:.2f}, Prior Loss: {total_kld_prior:.2f}, '
              f'Total Gp loss: {total_gp_loss}, Total Classifier loss: {total_classifier_loss}')
        # print(f'Epoch {epoch+1}, Total Loss: {total_loss:.2f}, Recon Loss: {total_recon_loss:.2f}, '
        #       f'Encoder Posterior Loss: {total_kld_encoder_posterior:.2f}, Prior Loss: {total_kld_prior:.2f}'
        #         )


# Plot each recorded per-step loss history in its own figure. The legend call
# makes the label passed to plot() visible; without it the label is never drawn.
plt.plot(total_loss_list, label='total lost')
plt.legend()
plt.show()
plt.plot(recon_loss_list, label='recon list')
plt.legend()
plt.show()
plt.plot(kl_loss_list, label='kl list')
plt.legend()
plt.show()
# NOTE(review): gp_loss_list is only filled if the training loop appends to it;
# the append there is commented out, so this figure is presumably empty.
plt.plot(gp_loss_list, label='gp list')
plt.legend()
plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from itertools import chain
import numpy as np
from tqdm import tqdm
import warnings

# Suppress every warning for the remainder of the session.
warnings.filterwarnings("ignore")

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Batching and flattened image sizes (channels * height * width).
batch_size = 1024
mnist_input_dim = 1 * 28 * 28
svhn_input_dim = 3 * 32 * 32

# VAE layer widths.
hidden_dim = 1024
latent_dim = 256

# Score-model settings.
timesteps = 100
sigma_min = 0.01
sigma_max = 0.4
is_vanilla = True
alpha = None
loops = 1  # score-model updates performed after each VAE step

# Loss weights.
beta = 2
gp_lambda = 0.2
classifier_lambda = 10

# Progress is reported whenever (epoch + 1) is a multiple of this value.
n_print_per_epoch = 1

# Prefer the GPU when one is available.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Image preprocessing: both datasets are only converted to PyTorch tensors.
# The two pipelines are separate objects so either one can be changed alone.
mnist_transform = transforms.Compose([transforms.ToTensor()])
svhn_transform = transforms.Compose([transforms.ToTensor()])

# Training splits; downloaded into the data directory when not already present.
mnist_dataset = datasets.MNIST(
    root='data',
    train=True,
    download=True,
    transform=mnist_transform,
)
svhn_dataset = datasets.SVHN(
    root='./data',
    split='train',
    download=True,
    transform=svhn_transform,
)

# Shuffled loaders that feed the VAE update.
mnist_dataloader = DataLoader(
    mnist_dataset, batch_size=batch_size, shuffle=True
)
svhn_dataloader = DataLoader(
    svhn_dataset, batch_size=batch_size, shuffle=True
)

# A second, independent pair of shuffled loaders for the score-model update.
mnist_dataloader_score = DataLoader(
    mnist_dataset, batch_size=batch_size, shuffle=True
)
svhn_dataloader_score = DataLoader(
    svhn_dataset, batch_size=batch_size, shuffle=True
)

# One VAE per image domain; both use the same hidden and latent sizes.
mnist_vae = VAE(mnist_input_dim, hidden_dim, latent_dim).to(device)
svhn_vae = VAE(svhn_input_dim, hidden_dim, latent_dim).to(device)

# Classifier that is trained jointly with the two VAEs.
classifier = CNN().to(device)

# Score model operating on latent vectors.
score_model = Score_fn(
    UNet(
        in_dim=latent_dim,
        out_dim=latent_dim,
        num_timesteps=timesteps,
        is_warm_init=False,
    ),
    sigma_min=sigma_min,
    sigma_max=sigma_max,
    num_timesteps=timesteps,
    device=device,
).to(device)

# A single Adam instance drives both VAEs and the classifier; the score model
# has its own optimizer with a smaller learning rate.
optimizer_vae = optim.Adam(
    chain(
        mnist_vae.parameters(),
        svhn_vae.parameters(),
        classifier.parameters(),
    ),
    lr=1e-3,
)
optimizer_score = optim.Adam(score_model.parameters(), lr=1e-4)

# Uncomment to resume from previously saved weights.
# mnist_vae.load_state_dict(torch.load('svhn_mnist_vae_test.pth'))
# svhn_vae.load_state_dict(torch.load('svhn_flip_vae_test.pth'))
# score_model.load_state_dict(torch.load('svhn_score_model_test.pth'))
# classifier.load_state_dict(torch.load('svhn_classifier_test.pth'))

# Per-batch loss histories, appended to on every training step.
total_loss_list = []
recon_loss_list = []
kl_loss_list = []
gp_loss_list = []
classifier_loss_list = []

# Training
#
# Every iteration does two things:
#   1. one optimizer step for both VAEs and the classifier on a paired
#      (MNIST, SVHN) batch, using the VAE loss, a classifier loss on the MNIST
#      latents and a weighted GP loss;
#   2. `loops` score-model updates on latents encoded from freshly drawn batches.
num_epochs = 1500
for epoch in tqdm(range(num_epochs)):
    mnist_vae.train()
    svhn_vae.train()

    # Running sums over the epoch, reported in the progress print below.
    total_loss = 0
    total_recon_loss = 0
    total_kld_encoder_posterior = 0
    total_kld_prior = 0
    total_gp_loss = 0
    total_classifier_loss = 0

    # zip() stops as soon as the shorter of the two loaders is exhausted.
    for i, (data1, data2) in enumerate(zip(mnist_dataloader, svhn_dataloader)):
        x1, label1 = data1
        x2, label2 = data2
        # Only the MNIST labels are used (classifier loss below).
        label1 = label1.to(device)
        # Flatten each image to one row per sample.
        x1, x2 = x1.to(device).view(x1.shape[0], -1), x2.to(device).view(x2.shape[0], -1)
        optimizer_vae.zero_grad()

        recon_x1, z1, mean1, logvar1 = mnist_vae(x1)
        recon_x2, z2, mean2, logvar2 = svhn_vae(x2)
        # Inputs/reconstructions stay as per-domain lists (their widths differ);
        # latent statistics share a width and are stacked along the batch axis.
        x, recon_x, z, mean, logvar = [x1, x2], [recon_x1, recon_x2], torch.vstack((z1, z2)), torch.vstack((mean1, mean2)), torch.vstack((logvar1, logvar2))

        # Score loss
        # NOTE(review): DSM is computed here but `DSM=None` is what gets passed
        # to vae_loss below, so this value is currently unused. The call is kept
        # because it may have side effects that cannot be seen from this cell.
        DSM = score_model.get_LSGM_loss(z, is_mixing=True, is_residual=True, is_vanilla=is_vanilla)
        # The second argument is a constant index of 30 for every sample
        # (presumably a timestep - confirm against Score_fn).
        score = score_model.get_mixing_score_fn(z, 30*torch.ones(z.shape[0], device=device).type(torch.long), detach=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha) - 0.05 * z
        # Batched row-by-column product of `score` with `z`, summed to a scalar.
        score = torch.matmul(score.unsqueeze(1), z.unsqueeze(-1)).sum()

        # VAE loss
        loss, recon_loss, kld_encoder_posterior, kld_prior = vae_loss(recon_x, x, mean, logvar, beta, score=score, DSM=None)

        # Classifier loss on the MNIST latents, viewed as 1x16x16 maps
        # (this reshape requires latent_dim == 256).
        output = classifier(z1.view((z1.shape[0], 1, 16, 16)))
        # `nn.functional` is used directly so this cell does not depend on an
        # `F` alias imported elsewhere in the notebook.
        classifier_loss = classifier_lambda * nn.functional.cross_entropy(output, label1, reduction='sum')
        loss += classifier_loss

        # GP loss. An earlier per-domain `compute_gp_loss` computation was
        # overwritten by this assignment before ever being used, so it has been
        # dropped; the value that enters the loss is unchanged.
        gp_loss = gp_lambda * calculate_gp_loss([x1, x2], [z1, z2])

        classifier_loss_list.append(classifier_loss.item())
        gp_loss_list.append(gp_loss.item())

        loss += gp_loss
        total_classifier_loss += classifier_loss.item()
        total_gp_loss += gp_loss.item()
        total_loss_list.append((loss).item())
        recon_loss_list.append((recon_loss).item())
        kl_loss_list.append((kld_encoder_posterior+kld_prior).item())

        loss.backward()
        optimizer_vae.step()

        total_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_kld_encoder_posterior += kld_encoder_posterior.item()
        total_kld_prior += kld_prior.item()

        # Update Score Function
        for loop in range(loops):
            # Fresh iterators are created on every call; because the score
            # loaders shuffle, this takes the first batch of a new shuffle.
            data1, data2 = next(iter(zip(mnist_dataloader_score, svhn_dataloader_score)))
            x1, label1 = data1
            x2, label2 = data2
            x1, x2 = x1.to(device).view(x1.shape[0], -1), x2.to(device).view(x2.shape[0], -1) # Reshape
            recon_x1, z1, mean1, logvar1 = mnist_vae(x1)
            recon_x2, z2, mean2, logvar2 = svhn_vae(x2)
            # `labels` is only consumed by the optional visualisation below.
            x, recon_x, z, mean, logvar, labels = [x1, x2], [recon_x1, recon_x2], torch.vstack((z1, z2)), torch.vstack((mean1, mean2)), torch.vstack((logvar1, logvar2)), torch.vstack((label1, label2))

            # Verbose output only for the last score update of the first batch
            # in a reporting epoch.
            if loop == (loops-1) and (epoch+1) % n_print_per_epoch == 0 and i==0:
                print(f"Epoch {epoch+1} DSM average loss:", end=' ') 
                score_model.update_score_fn(z, optimizer=optimizer_score, max_timestep=None, verbose=True, is_mixing=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha)
                # Optional visualisation (disabled):
                # score_model.draw_gradient_field((-0, 0), (-0, 0), t=0, x_num=40, y_num=40, data=z.detach().cpu(), labels=labels, save=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha, is_mixture=True)
            else:
                score_model.update_score_fn(z, optimizer=optimizer_score, max_timestep=None, is_mixing=True, is_residual=True, is_vanilla=is_vanilla, alpha=alpha)

    # Report every `n_print_per_epoch` epochs. `x1` / `x2` are the batches most
    # recently drawn in the score-update loop above.
    if (epoch + 1) % n_print_per_epoch == 0:
        display_reconstructed_and_flip_images(epoch=epoch, vae_model=mnist_vae, flip_vae_model=svhn_vae, data=x1, dim=[1, 28, 28], flip_dim=[3, 32, 32], is_both=False)
        display_reconstructed_and_flip_images(epoch=epoch, vae_model=svhn_vae, flip_vae_model=mnist_vae, data=x2, dim=[3, 32, 32], flip_dim=[1, 28, 28], is_mnist=False, is_both=False)
        print(f'Epoch {epoch+1}, Total Loss: {total_loss:.2f}, Recon Loss: {total_recon_loss:.2f}, '
              f'Encoder Posterior Loss: {total_kld_encoder_posterior:.2f}, Prior Loss: {total_kld_prior:.2f}, '
              f'Total Gp loss: {total_gp_loss}, Total Classifier loss: {total_classifier_loss}')

    # Checkpoint all four modules every 100 epochs (files are overwritten).
    if (epoch + 1) % 100 == 0:
        torch.save(mnist_vae.state_dict(), 'svhn_mnist_vae_test.pth')
        torch.save(svhn_vae.state_dict(), 'svhn_flip_vae_test.pth')
        torch.save(score_model.state_dict(), 'svhn_score_model_test.pth')
        torch.save(classifier.state_dict(), 'svhn_classifier_test.pth')

# Plot every recorded loss history in its own figure.
# `plt.legend()` is required for the `label` given to `plt.plot` to be drawn,
# and each figure (including the last one) is shown explicitly so the cell
# behaves the same whether it runs in a notebook or as a plain script.
for loss_history, curve_label in (
    (total_loss_list, 'total lost'),
    (recon_loss_list, 'recon list'),
    (kl_loss_list, 'kl list'),
    (gp_loss_list, 'gp list'),
):
    plt.plot(loss_history, label=curve_label)
    plt.legend()
    plt.show()


In [None]:
# Write the current weights of every trained module to disk, using the same
# file names as the periodic checkpoint inside the training loop.
for trained_module, checkpoint_path in (
    (mnist_vae, 'svhn_mnist_vae_test.pth'),
    (svhn_vae, 'svhn_flip_vae_test.pth'),
    (score_model, 'svhn_score_model_test.pth'),
    (classifier, 'svhn_classifier_test.pth'),
):
    torch.save(trained_module.state_dict(), checkpoint_path)