<a href="https://colab.research.google.com/github/karimul/Adversarially-Learned-Anomaly-Detection/blob/master/IGEBM_PCD_ASGLD_v2_Sinkhorn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

from google.colab import drive

# Mount Google Drive so checkpoints, datasets (./data) and TensorBoard runs
# written relative to the working directory persist across Colab sessions.
drive.mount('/content/drive')

ROOT = "/content/drive/My Drive/Colab Notebooks"
sample_dir = os.path.join(ROOT, 'IGEBM_PCD_ASGLD.v2.Sinkhorn')
# exist_ok avoids the check-then-create race of `if not exists: makedirs`.
os.makedirs(sample_dir, exist_ok=True)
os.chdir(sample_dir)

In [None]:
%load_ext tensorboard

In [None]:
import math
from torch.utils.tensorboard import SummaryWriter
import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

In [None]:
class ConvBNReLU(nn.Module):
    """Spectrally-normalised convolution -> BatchNorm -> ReLU with two BN branches.

    The convolution weights are shared; ``forward(x, adversial=True)``
    normalises with ``bn_adversial`` and otherwise with ``bn_normal``, so the
    two kinds of input keep separate BatchNorm statistics.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: int or ``(kh, kw)`` tuple. Each spatial dimension is
            padded by ``(k - 1) // 2``, so odd kernels at stride 1 preserve
            the spatial size.
        stride: convolution stride.
        groups: convolution groups.
    """

    def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1):
        super(ConvBNReLU, self).__init__()
        if isinstance(kernel_size, tuple):
            # Pad each dimension by its own kernel extent; padding both by
            # max(kernel_size) would over-pad the smaller dimension and change
            # the output size for non-square kernels.
            padding = tuple((k - 1) // 2 for k in kernel_size)
        else:
            padding = (kernel_size - 1) // 2

        self.conv = nn.utils.spectral_norm(nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            padding,
            groups=groups,
            bias=False,  # BatchNorm supplies the affine shift
        ))
        self.bn_normal = nn.BatchNorm2d(out_planes)
        self.bn_adversial = nn.BatchNorm2d(out_planes)
        self.act = nn.ReLU(inplace=True)

    def forward(self, x, adversial=False):
        """Apply the conv, the BatchNorm branch selected by ``adversial``, then ReLU."""
        x = self.conv(x)
        bn = self.bn_adversial if adversial else self.bn_normal
        return self.act(bn(x))

class StandardCNN(nn.Module):
    """Spectrally-normalised 7-conv network mapping an image to one scalar.

    conv2, conv4 and conv6 use kernel 4 / stride 2 / padding 1 and halve the
    spatial size, so a 32x32 input reaches ``dense`` as 512 x 4 x 4 features.
    """

    def __init__(self):
        super(StandardCNN, self).__init__()
        sn = nn.utils.spectral_norm

        # Stage 1: 3 -> 64 channels, then downsample.
        self.conv1 = sn(nn.Conv2d(3, 64, 3, 1, 1))
        self.conv2 = sn(nn.Conv2d(64, 64, 4, 2, 1))
        # Stage 2: 64 -> 128 channels, then downsample.
        self.conv3 = sn(nn.Conv2d(64, 128, 3, 1, 1))
        self.conv4 = sn(nn.Conv2d(128, 128, 4, 2, 1))
        # Stage 3: 128 -> 256 channels, then downsample.
        self.conv5 = sn(nn.Conv2d(128, 256, 3, 1, 1))
        self.conv6 = sn(nn.Conv2d(256, 256, 4, 2, 1))
        # Stage 4: widen to 512 channels at the final resolution.
        self.conv7 = sn(nn.Conv2d(256, 512, 3, 1, 1))

        # Not used by forward(); the explicit pooling steps are disabled.
        self.pool = nn.MaxPool2d(2, 2)
        self.act = nn.LeakyReLU(negative_slope=0.1, inplace=True)
        self.dense = sn(nn.Linear(512 * 4 * 4, 1))

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of scalar outputs for image batch ``x``."""
        stack = (
            self.conv1, self.conv2,
            self.conv3, self.conv4,
            self.conv5, self.conv6,
            self.conv7,
        )
        for conv in stack:
            x = self.act(conv(x))
        features = torch.flatten(x, 1)
        return self.dense(features)


class StandardBNCNN(nn.Module):
    """Eight ConvBNReLU layers in four pairs with a linear head to one output.

    Each pair but the last is followed by a 2x2 max-pool, then a global
    average pool and a spectrally-normalised ``Linear(512, 1)`` produce the
    output. ``adversial`` is forwarded to every ConvBNReLU to choose which of
    its BatchNorm branches is used.
    """

    def __init__(self):
        super(StandardBNCNN, self).__init__()
        self.conv1 = ConvBNReLU(3, 64)
        self.conv2 = ConvBNReLU(64, 64)
        self.conv3 = ConvBNReLU(64, 128)
        self.conv4 = ConvBNReLU(128, 128)
        self.conv5 = ConvBNReLU(128, 256)
        self.conv6 = ConvBNReLU(256, 256)
        self.conv7 = ConvBNReLU(256, 512)
        self.conv8 = ConvBNReLU(512, 512)
        self.max_pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.avg_pool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.utils.spectral_norm(nn.Linear(512, 1))

    def forward(self, x, adversial=False):
        """Return a ``(batch, 1)`` output; ``adversial`` selects the BN branch."""
        pairs = [
            (self.conv1, self.conv2),
            (self.conv3, self.conv4),
            (self.conv5, self.conv6),
            (self.conv7, self.conv8),
        ]
        final = len(pairs) - 1
        for position, (first, second) in enumerate(pairs):
            x = second(first(x, adversial), adversial)
            if position != final:
                x = self.max_pool(x)
        # For a 32x32 input the feature map here is (B, 512, 4, 4).
        pooled = torch.flatten(self.avg_pool(x), 1)
        return self.fc(pooled)

In [None]:
class SampleReplayBuffer:
    """Fixed-capacity ring buffer of past samples, stored on the CPU.

    Args:
        batch_size: number of rows returned by ``sample()``.
        buffer_length: capacity of the ring.
        data_size: shape of a single sample.
        device: device ``sample()`` moves its output to. Defaults to
            ``cuda:0`` when CUDA is available, otherwise the CPU.
    """

    def __init__(self, batch_size, buffer_length=10000, data_size=(3, 32, 32), device=None):
        self.buffer_length = buffer_length
        self.data_size = data_size
        # Pre-filled with uniform [0, 1) noise so draws made before the ring
        # has been written still return valid tensors.
        self.buffer = torch.rand((self.buffer_length,) + self.data_size)
        self.index = 0  # next ring slot to write
        self.batch_size = batch_size
        self.cpu = torch.device("cpu")
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        # Kept for backward compatibility with code that reads `.gpu`.
        self.gpu = self.device

    def sample(self):
        """Return ``batch_size`` stored rows drawn uniformly with replacement."""
        indices = torch.randint(low=0, high=self.buffer_length, size=(self.batch_size,))
        return self.buffer[indices].to(self.device, non_blocking=True)

    def add_sample(self, x):
        """Write the rows of ``x`` at the ring cursor, wrapping at the end.

        Raises:
            ValueError: if ``x`` has more rows than the buffer can hold.
        """
        n = x.shape[0]
        if n > self.buffer_length:
            raise ValueError(
                f"cannot add {n} samples to a buffer of length {self.buffer_length}"
            )
        # Blocking copy on purpose: a non_blocking device->host copy may still
        # be in flight when the CPU reads it into the buffer.
        rows = x.detach().to(self.cpu)
        slots = (self.index + torch.arange(n)) % self.buffer_length
        self.buffer[slots] = rows
        self.index = (self.index + n) % self.buffer_length


In [None]:
class SGLDTrainer:
    """Train an energy model on images with sampled negatives and a Sinkhorn loss.

    Each step refines a batch of negative samples by clipped gradient descent
    on the model output (``get_sample``). It then updates the model to minimise
    the entropic optimal-transport cost between the outputs on the negative
    batch and on the real (positive) batch (``compute_loss``).

    Args:
        dataset: ``'CelebA'`` or ``'AFHQ'`` (an ImageFolder under
            ``./data/AFHQ``), both resized and centre-cropped to 32x32;
            any other value selects CIFAR-10.
    """

    def __init__(self, dataset=''):
        self.batch_size = 128
        self.last_epoch = -1  # train() starts at last_epoch + 1
        self.epochs = 500
        self.num_workers = 3
        self.smoothness_scale = 1.0  # NOTE(review): smoothness term is currently disabled

        if dataset in ('CelebA', 'AFHQ'):
            transform = transforms.Compose([
                transforms.Resize(32),
                transforms.CenterCrop(32),
                transforms.ToTensor(),
            ])
        else:
            transform = transforms.Compose([transforms.ToTensor()])

        if dataset == 'CelebA':
            trainset = torchvision.datasets.CelebA(
                root="./data", split='train', download=True, transform=transform
            )
        elif dataset == 'AFHQ':
            trainset = torchvision.datasets.ImageFolder(root='./data/AFHQ', transform=transform)
        else:
            trainset = torchvision.datasets.CIFAR10(
                root="./data", train=True, download=True, transform=transform
            )

        # drop_last keeps every batch exactly batch_size rows, which the
        # replay buffer's add_sample relies on.
        self.trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            drop_last=True,
        )

        # Same CUDA-or-CPU fallback as the FloatTensor/LongTensor choice below.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.model = StandardCNN()
        self.model.to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.0001, betas=(0.0, 0.999))
        self.step = 0

        self.buffer_sample_rate = 0.95  # probability of starting from the replay buffer
        self.data_size = (3, 32, 32)
        self.sample_replay_buffer = SampleReplayBuffer(
            self.batch_size, data_size=self.data_size
        )

        self.dynamics_steps = 60  # gradient steps per get_sample() call
        self.step_size = 10
        self.noise_scale = 0.005  # NOTE(review): not read anywhere in this class

        # Adaptive-noise (ASGLD) parameters. They are unused while noise
        # injection is disabled in get_sample(); kept so external code that
        # reads them keeps working.
        self.noise = 0.1
        self.momentum = 0.9
        self.eps = 1e-6

        self.writer = SummaryWriter()

        # Kept for backward compatibility; sinkhorn_loss no longer needs them.
        self.FloatTensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
        self.LongTensor = torch.cuda.LongTensor if torch.cuda.is_available() else torch.LongTensor

    def train(self):
        """Run epochs ``last_epoch + 1`` .. ``epochs - 1``, checkpointing after each."""
        for epoch in range(self.last_epoch + 1, self.epochs):
            print(f"\nEpoch {epoch} started...")
            average_loss = 0.0
            for i, (x, _) in enumerate(self.trainloader):
                output = self.process_batch(x)
                loss = output["loss"].item()
                # Incremental running mean of the batch losses in this epoch.
                average_loss += (loss - average_loss) / (i + 1)

                if self.step % 50 == 0:
                    print(f"[{i+1}/{len(self.trainloader)}] -- loss {loss:.5f} -- avg. loss {average_loss:.5f}")
                    self.log({"average_loss": average_loss, **output})

                self.step += 1

            checkpoint = {
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'step': self.step,
                'epoch': epoch,
            }
            torch.save(checkpoint, f"./model{epoch}.pt")

    def process_batch(self, x):
        """Run one optimisation step on real batch ``x``.

        Returns the loss dict from ``compute_loss`` plus the positive and
        negative examples used.
        """
        x = x.to(self.device, non_blocking=True)

        sample = self.get_sample()

        self.optimizer.zero_grad()
        losses = self.compute_loss(positive_examples=x, negative_examples=sample)
        losses["loss"].backward()
        self.optimizer.step()

        return {"positive_examples": x, "negative_examples": sample, **losses}

    def compute_loss(self, positive_examples, negative_examples):
        """Return the Sinkhorn OT cost between negative and positive model outputs.

        The maximum-likelihood and smoothness terms are disabled; the loss is
        only the entropic OT cost between the two sets of outputs.
        """
        positive_energy = self.model(positive_examples)
        negative_energy = self.model(negative_examples)

        # Sinkhorn parameters.
        epsilon = 0.01
        niter = 100
        total_loss = self.sinkhorn_loss(negative_energy, positive_energy, epsilon, niter)

        return {
            "loss": total_loss.mean(),
            "positive_energy": positive_energy,
            "negative_energy": negative_energy,
        }

    def sinkhorn_loss(self, x, y, epsilon, niter):
        """Entropic OT cost between two uniformly weighted empirical measures.

        Args:
            x: point locations of the first measure, ``(n, d)`` or ``(..., n, d)``.
            y: point locations of the second measure, ``(m, d)`` or ``(..., m, d)``.
            epsilon: entropic regularisation strength.
            niter: maximum number of log-domain Sinkhorn iterations.

        Returns:
            ``sum_ij pi_ij * C_ij``, where C is the squared Euclidean cost and
            pi the Sinkhorn plan. This is a scalar for 2-D inputs and has the
            leading batch shape otherwise. The dual potentials are computed
            without gradient; gradients flow through C.
        """
        def _cost_matrix(a, b, p=2):
            """Return the matrix of |a_i - b_j|^p summed over the last dim."""
            return torch.sum((a.unsqueeze(-2) - b.unsqueeze(-3)).abs() ** p, -1)

        C = _cost_matrix(x, y)
        x_points = x.shape[-2]
        y_points = y.shape[-2]

        # Uniform marginals shaped to match C's leading (batch) dims. Building
        # them directly (instead of `.squeeze()`) keeps single-point measures
        # 1-D so the potentials broadcast correctly.
        mu = torch.full(tuple(x.shape[:-2]) + (x_points,), 1.0 / x_points,
                        dtype=C.dtype, device=C.device)
        nu = torch.full(tuple(y.shape[:-2]) + (y_points,), 1.0 / y_points,
                        dtype=C.dtype, device=C.device)
        log_mu = torch.log(mu + 1e-8)
        log_nu = torch.log(nu + 1e-8)
        thresh = 1e-9  # stop when the mean L1 change of u falls below this

        def M(u, v):
            """Log-domain kernel M_ij = (-C_ij + u_i + v_j) / epsilon."""
            return (-C + u.unsqueeze(-1) + v.unsqueeze(-2)) / epsilon

        u = torch.zeros_like(mu)
        v = torch.zeros_like(nu)
        with torch.no_grad():
            for _ in range(niter):
                u_prev = u
                u = epsilon * (log_mu - torch.logsumexp(M(u, v), dim=-1)) + u
                v = epsilon * (log_nu - torch.logsumexp(M(u, v).transpose(-2, -1), dim=-1)) + v
                if (u - u_prev).abs().sum(-1).mean().item() < thresh:
                    break

        pi = torch.exp(M(u, v))  # transport plan diag(a) K diag(b)
        return torch.sum(pi * C, dim=(-2, -1))

    def get_sample(self):
        """Refine a batch of negatives by clipped gradient descent on the model output.

        Starts from ``get_initial_sample()`` and takes ``dynamics_steps`` steps
        of size ``step_size`` against the clipped, weight-decayed input
        gradient, keeping pixels in [0, 1]. The result is stored in the replay
        buffer and returned detached.

        NOTE(review): no noise is injected. The original code computed ASGLD
        noise N(old_mean, old_std) but applied it with the out-of-place
        ``Tensor.add`` and discarded the result. Enabling it as written would
        not help either: old_mean is a momentum-0.9 running sum of the images,
        about 10x their magnitude, so it would push every pixel into the clamp.
        Revisit the noise model before enabling it.
        """
        weight_decay = 5e-4
        was_training = self.model.training
        self.model.eval()

        neg_img = self.get_initial_sample().detach().requires_grad_(True)
        for _ in range(self.dynamics_steps):
            energy = self.model(neg_img)
            # Same gradient as energy.backward(ones_like(energy)), but it does
            # not accumulate unused gradients into the model parameters.
            (grad,) = torch.autograd.grad(energy.sum(), neg_img)
            with torch.no_grad():
                if weight_decay != 0:
                    # L2 decay on the gradient, as in SGD, pulls pixels toward
                    # 0. The original added weight_decay * x to the image itself,
                    # which scaled pixels *up* by (1 + weight_decay) every step.
                    grad = grad.add(neg_img, alpha=weight_decay)
                grad.clamp_(-0.01, 0.01)
                neg_img.add_(grad, alpha=-self.step_size)
                neg_img.clamp_(0, 1)

        sample = neg_img.detach()
        self.sample_replay_buffer.add_sample(sample)
        self.model.train(was_training)
        return sample

    def get_initial_sample(self):
        """Return fresh uniform noise with prob. 1 - buffer_sample_rate, else a buffer draw."""
        if torch.rand(1) > self.buffer_sample_rate:
            return torch.rand((self.batch_size,) + self.data_size, device=self.device)
        return self.sample_replay_buffer.sample()

    def load_checkpoint(self, path="./model28.pt"):
        """Restore model, optimizer, step and epoch from a checkpoint saved by train().

        ``path`` defaults to the previously hard-coded ``./model28.pt``.
        Afterwards train() resumes at the epoch after the stored one.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.step = checkpoint['step']
        self.last_epoch = checkpoint['epoch']

    def log(self, data):
        """Write loss, output statistics and example images to TensorBoard."""
        self.writer.add_scalar("average loss", data["average_loss"], self.step)
        self.writer.add_scalar("batch loss", data["loss"], self.step)

        std, mean = torch.std_mean(data["positive_energy"].detach())
        self.writer.add_scalar("positive energy mean", mean, self.step)
        self.writer.add_scalar("positive energy std", std, self.step)

        std, mean = torch.std_mean(data["negative_energy"].detach())
        self.writer.add_scalar("negative energy mean", mean, self.step)
        self.writer.add_scalar("negative energy std", std, self.step)

        self.writer.add_images(
            "positive examples", data["positive_examples"], self.step
        )
        self.writer.add_images(
            "negative examples", data["negative_examples"], self.step
        )

In [None]:
tensorboard --logdir runs

In [None]:
# An empty dataset name selects the CIFAR-10 branch of SGLDTrainer.
trainer = SGLDTrainer(dataset='')
# To resume from a saved checkpoint instead of starting fresh, uncomment:
# trainer.load_checkpoint()
trainer.train()