In [1]:
import numpy as np
import time
import os
from matplotlib import pyplot as plt
import torch
from torch import nn, optim
from torch.nn.utils import spectral_norm
import torchvision

In [2]:
class LeNet5(nn.Module):
    def __init__(self, use_bn=True, use_sn=False):
        super().__init__()
        
        modules = []
        modules.append(nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2, bias=not use_bn))
        if use_bn:
            modules.append(nn.BatchNorm2d(6))
        modules.append(nn.ReLU(inplace=True))
        modules.append(nn.AvgPool2d(2))
        modules.append(nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=2, bias=not use_bn))
        if use_bn:
            modules.append(nn.BatchNorm2d(16))
        modules.append(nn.ReLU(inplace=True))
        modules.append(nn.AvgPool2d(2))
        modules.append(nn.Flatten())
        modules.append(nn.Linear(784, 120))
        modules.append(nn.ReLU(inplace=True))
        modules.append(nn.Linear(120, 84))
        modules.append(nn.ReLU(inplace=True))
        if use_sn:
            for m_idx, m in enumerate(modules):
                if isinstance(m, (nn.Linear, nn.Conv2d)):
                    modules[m_idx] = spectral_norm(m)
        
        self.fe = nn.Sequential(*modules)
        self.classifier_head = nn.Linear(84, 10)
        self.class_embedding = nn.Linear(10, 84)
        self.critic_head = nn.Linear(84, 1)
        if use_sn:
            self.classifier_head = spectral_norm(self.classifier_head)
            self.class_embedding = spectral_norm(self.class_embedding)
            self.critic_head = spectral_norm(self.critic_head)
        
    def criticize_example(self, x, y):
        x_fe = self.fe(x)
        embedded_y = self.class_embedding(nn.functional.one_hot(y, num_classes=10).to(torch.float))
        out = self.critic_head(x_fe) + (x_fe * embedded_y).sum(dim=1)
        return out
    
    def classify_example(self, x):
        x_fe = self.fe(x)
        out = self.classifier_head(x_fe)
        return out

In [3]:
def hinge_loss(logits, y):
    return nn.functional.relu(1 - y*logits).mean()

def extrapolate_example(x, y_target, model, no_grad=False):
    def loss_fn(x):
        return nn.functional.cross_entropy(model.classify_example(x), y_target)
    x_g = x.clone().requires_grad_(True)
    loss = loss_fn(x_g)
    loss.backward(inputs=[x_g], create_graph=not no_grad)
    delta_x = x_g.grad
    alpha = torch.zeros(x.size(0), 1, 1, 1, device=x.device, requires_grad=True)
    alpha_opt = optim.LBFGS([alpha], line_search_fn='strong_wolfe')
    alpha_opt.zero_grad()
    alpha_opt.step(lambda: loss_fn(torch.sigmoid(x+alpha*delta_x.detach())))
    alpha = alpha.detach()
    print(alpha)
    return delta_x, alpha

def train_step(batch, classifiers, optimizers, device, single_purpose_models=True):
    x, y = batch
    x, y = x.to(device), y.to(device)
    if single_purpose_models:
        critic_idx = 0
    else:
        critic_idx = np.random.randint(2)
    critic, critic_opt = classifiers[critic_idx], optimizers[critic_idx]
    classifier, classifier_opt = classifiers[1-critic_idx], optimizers[1-critic_idx]
    y_target = torch.randint_like(y, 10)
    
    # update critic
    delta_x, alpha_opt = extrapolate_example(x, y_target, classifier, no_grad=True)
    critic_logits_fake = critic.criticize_example(torch.sigmoid(x+alpha_opt*delta_x), y_target)
    critic_logits_real = critic.criticize_example(x, y)
    critic_loss = 0.5*hinge_loss(critic_logits_fake, -1) + 0.5*hinge_loss(critic_logits_real, 1)
    critic_opt.zero_grad()
    critic_loss.backward()
    critic_opt.step()
    
    # update classifier
    delta_x, alpha_opt = extrapolate_example(x, y_target, classifier)
    critic_logits = critic.criticize_example(torch.sigmoid(x + alpha_opt*delta_x), y_target)
    realism_loss = -critic_logits.mean()
    alpha = alpha_opt*torch.rand_like(alpha_opt)
    classifier_logits = classifier.classify_example(alpha*x + (1-alpha)*torch.sigmoid(x+alpha_opt*delta_x))
    classifier_loss = (((alpha_opt-alpha)/alpha_opt)*nn.functional.cross_entropy(classifier_logits, y, reduction='none') + (alpha/alpha_opt)*nn.functional.cross_entropy(classifier_logits, y_target, reduction='none')).mean()
    classifier_opt.zero_grad()
    classifier_loss.backward()
    classifier_grads = [p.grad.clone() for p in classifier.parameters()]
    classifier_grad_norm = torch.norm(classifier_grads)
    classifier_opt.zero_grad()
    realism_loss.backward()
    torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=classifier_grad_norm)
    for p, cg in zip(classifier.parameters(), classifier_grads):
        p.grad += cg
    classifier_opt.step()
    
def eval_step(batch, classifiers, device, single_purpose_models=True):
    x, y = batch
    x, y = x.to(device), y.to(device)
    if single_purpose_models:
        critic_idx = 0
    else:
        critic_idx = np.random.randint(2)
    critic = classifiers[critic_idx]
    classifier = classifiers[1-critic_idx]
    y_target = torch.randint_like(y, 10)
    
    delta_x, alpha_opt = extrapolate_example(x, y_target, classifier, no_grad=True)
    with torch.no_grad():
        critic_logits_fake = critic.criticize_example(torch.sigmoid(x + alpha_opt*delta_x), y_target)
        critic_logits_real = critic.criticize_example(x, y)
        critic_loss = 0.5*hinge_loss(critic_logits_fake, -1) + 0.5*hinge_loss(critic_logits_real, 1)
        realism_loss = -critic_logits_fake.mean()
        classifier_logits = classifier.classify_example(x)
        classifier_loss = nn.functional.cross_entropy(classifier_logits, y)
        classifier_acc = np.mean(np.equal(np.argmax(classifier_logits.detach().cpu().numpy()), y.cpu().numpy()))
    
    return {
        'critic_loss': critic_loss.detach().cpu().numpy(),
        'realism_loss': realism_loss.detach().cpu().numpy(),
        'classifier_loss': classifier_loss.detach().cpu().numpy(),
        'classifier_acc': classifier_acc,
        'alpha_opt': alpha_opt.mean().detach().cpu().numpy()
    }

def run_epoch(train_dataloader, test_dataloader, classifiers, optimizers, device, **step_kwargs):
    t0 = time.time()
    for batch in train_dataloader:
        train_step(batch, classifiers, optimizers, device, **step_kwargs)
    rv = {}
    for batch in test_dataloader:
        step_rv = eval_step(batch, classifiers, device, **step_kwargs)
        for key, item in step_rv.items():
            if not key in rv.keys():
                rv[key] = []
            rv[key].append(item)
    for key, item in rv.items():
        rv[key] = np.mean(item)
        
    x, y = next(iter(test_dataloader))
    x, y = x.to(device), y.to(device)
    y_target = torch.tensor([i for i in range(10) for _ in range(len(y)//10)], device=device, dtype=torch.long)
    x = x[:len(y_target)]
    delta_x, alpha_opt = extrapolate_example(x, y_target, classifiers[1], no_grad=True)
    rv['reference_images'] = x.cpu().numpy()
    rv['generated_images'] = (torch.sigmoid(x + alpha_opt*delta_x)).detach().cpu().numpy()
    rv['time_taken'] = time.time()-t0
    return rv

def display_results(rv):
    for key, item in rv.items():
        if not hasattr(item, '__len__'):
            print('{}: {}'.format(key, item))
        else:
            num_cols = int(np.sqrt(len(item)))
            num_rows = num_cols + int(len(item) > num_cols**2)
            (fig, axes) = plt.subplots(num_rows, num_cols, figsize=(2*num_cols, 2*num_rows))
            for x, ax in zip(item, axes.flatten()):
                ax.imshow(np.moveaxis(x, 0, -1))
            fig.suptitle(key)
    plt.show()
    plt.close('all')

In [4]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
classifier = LeNet5(use_bn=True, use_sn=False).to(device)
critic = LeNet5(use_bn=False, use_sn=True).to(device)
classifier_opt = optim.Adam(classifier.parameters(), lr=1e-5, betas=(0.5, 0.999))
critic_opt = optim.Adam(critic.parameters(), lr=5e-5, betas=(0.5, 0.999))
batch_size = 64
train_dataset = torchvision.datasets.MNIST(root=os.path.join('.', 'downloads'), train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(root=os.path.join('.', 'downloads'), train=False, transform=torchvision.transforms.ToTensor(), download=True)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [4]:
for epoch in range(50):
    print('Epoch {}'.format(epoch+1))
    rv = run_epoch(train_dataloader, test_dataloader, [critic, classifier], [critic_opt, classifier_opt], device, single_purpose_models=True)
    display_results(rv)
    print()

Epoch 1
tensor([[[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0.]]],


        [[[0

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/home/min/a/jgammell/anaconda3/envs/sca_defense/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3444, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_2931292/823666781.py", line 3, in <module>
    rv = run_epoch(train_dataloader, test_dataloader, [critic, classifier], [critic_opt, classifier_opt], device, single_purpose_models=True)
  File "/tmp/ipykernel_2931292/4240372477.py", line 89, in run_epoch
    train_step(batch, classifiers, optimizers, device, **step_kwargs)
  File "/tmp/ipykernel_2931292/4240372477.py", line 48, in train_step
    classifier_grads = [p.grad.clone() for p in classifier.parameters()]
  File "/tmp/ipykernel_2931292/4240372477.py", line 48, in <listcomp>
    classifier_grads = [p.grad.clone() for p in classifier.parameters()]
AttributeError: 'NoneType' object has no attribute 'clone'

During handling of the above exception, another exception occurred:

Traceba

TypeError: object of type 'NoneType' has no len()