In [4]:
import utils
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LRScheduler
	
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
sde = utils.VPSDE(T_max=1, beta_min=0.01, beta_max=10.0)
image_size = 28
classes_by_index  = np.arange(0,10).astype('str')

transform = transforms.Compose([transforms.Resize(image_size),\
                                transforms.ToTensor(),\
                                transforms.Normalize([0.5],[0.5])]) #Normalize to -1,1
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                    download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False,
                                    download=True, transform=transform)

batch_size = 256
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

In [6]:
class MNISTClassifier(nn.Module):
    """Code from: https://nextjournal.com/gkoehler/pytorch-mnist"""
    def __init__(self):
        super(MNISTClassifier, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x, t):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(x + t[:,None])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)

def train_diffused_classifier(model, sde: utils.ItoSDE, dataloader: DataLoader, optimizer, device, n_epochs: int, print_every: int, scheduler: LRScheduler = None):
    
    model.train()
    model = model.to(DEVICE)
    running_loss_list = []
    lr_list = []
    
    for epoch in range(n_epochs):
        print(f"Epoch: {epoch}")
        running_loss = 0.0
        for idx, (x_inp,target) in enumerate(dataloader):
            
            #Zero gradients:
            optimizer.zero_grad()
            
            #Run forward samples:
            X_t,noise,score,time = sde.run_forward_random_time(x_inp)

            #Send to device:
            X_t = X_t.to(DEVICE)
            noise = noise.to(DEVICE)
            time = time.to(DEVICE)
            
            #Predict score:
            model_pred = model(X_t,time)

            #ONLY THIS LINE CHANGED TO BEFORE: we train the model to minimize the negative log-likelihood:
            loss = F.nll_loss(model_pred, target.to(DEVICE))

            #Optimize:
            loss.backward()
            optimizer.step()

            if scheduler is not None:
                scheduler.step()
                
            # print statistics
            running_loss += loss.detach().item()
            
            if (idx+1) % print_every == 0:
                avg_loss = running_loss/print_every
                running_loss_list.append(avg_loss)
                running_loss = 0.0
                if scheduler is not None:
                    print(f"Loss: {avg_loss:.4f} | {scheduler.get_lr()}")
                    lr_list.append(scheduler.get_lr())
                else:
                    print(f"Loss: {avg_loss:.4f}")

                
    return model,running_loss_list

LEARNING_RATE = 1e-3 #2e-5
WEIGHT_DECAY = 0.0
N_EPOCHS = 500
TRAIN_SCORE = False
RETRAIN = True
classifier = MNISTClassifier()

if RETRAIN:
    optimizer = torch.optim.AdamW(classifier.parameters(),lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY,maximize=False)
    scheduler=torch.optim.lr_scheduler.OneCycleLR(optimizer,LEARNING_RATE,total_steps=N_EPOCHS*len(trainloader),pct_start=0.25,anneal_strategy='cos')
    classifier,running_loss_list = train_diffused_classifier(classifier, sde, trainloader, optimizer=optimizer, scheduler=scheduler, device=DEVICE, n_epochs=N_EPOCHS, print_every=100)
    torch.save(classifier.state_dict(),"mnist_diffusion_classifier.ckpt")
    
else:
    classifier_state_dict = torch.load("mnist_diffusion_classifier.ckpt")
    classifier.load_state_dict(classifier_state_dict)
    classifier = classifier.to(DEVICE)



Epoch: 0


  return F.log_softmax(x)
  _warn_get_lr_called_within_step(self)


Loss: 2.3220 | [4.002745242199905e-05]
Loss: 2.3026 | [4.010980654784845e-05]
Epoch: 1
Loss: 2.2949 | [4.030805478637321e-05]
Loss: 2.2882 | [4.0519379714906907e-05]
Epoch: 2


KeyboardInterrupt: 

In [None]:
n_grid_points = 16
time_vec = torch.linspace(0,1,n_grid_points)**2
X_0, Y = trainset.__getitem__(23410)
X_0 = torch.stack([X_0.unsqueeze(0).squeeze()]*n_grid_points)
X_t, noise, score = sde.run_forward(X_0,time_vec)
X_t = X_t.unsqueeze(1)

results = np.exp(classifier(X_t.to(DEVICE),time_vec.to(DEVICE)).cpu().detach().numpy())
fig, axs = plt.subplots(2, len(results),figsize=(3*len(results),6))
for idx in range(len(results)):
    axs[0,idx].set_title(f"Prediction distribution \n time = {time_vec[idx]:.3f}")
    axs[0, idx].bar(x=classes_by_index, height=results[idx])
    axs[1, idx].set_title(f"Input image at t={time_vec[idx]:.2f}")
    axs[1, idx].imshow(X_t[idx].squeeze(), cmap='grey')

In [None]:
def get_classifier_gradient(x: torch.Tensor, t: torch.Tensor, target: int, scale_factor: float = 8.0):
    classifier.zero_grad()
    x = torch.nn.Parameter(x.to(DEVICE),requires_grad=True)
    t = t.to(DEVICE)
    output = classifier(x,t)
    output[:,target].sum().backward()
    return scale_factor*x.grad.detach()

X_0, Y = trainset.__getitem__(23410)
X_0 = torch.stack([X_0.unsqueeze(0).squeeze()]*n_grid_points)
X_t, noise, score = sde.run_forward(X_0,time_vec)
X_t = X_t.unsqueeze(1)

fig, axs = plt.subplots(2,16,figsize=(16*4,4))
for idx in range(16):
    gradient = get_classifier_gradient(X_t[idx].unsqueeze(0), torch.tensor([0.05]), Y)
    axs[0,idx].imshow(gradient.detach().cpu().numpy().squeeze())
    axs[1,idx].imshow(X_t[idx].squeeze())