In [2]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import matplotlib.pyplot as plt
import numpy as np
from torchvision.io import read_image

#laplace packages
from laplace.baselaplace import FullLaplace
from laplace.curvature.backpack import BackPackGGN
from laplace import Laplace, marglik_training

# Turn on matplotlib interactive mode; the final cell switches it off again
# with plt.ioff() before its closing plt.show().
plt.ion()

<matplotlib.pyplot._IonContext at 0x7fe344de6fa0>

In [3]:
import wandb

In [4]:
# Start the wandb run that all later cells log to.
wandb.init(
    entity="zefko",
    project="Bayesian DL",
    name='STN_LA_version_1',
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mzefko[0m (use `wandb login --relogin` to force relogin)


In [5]:
class MnistRandomPlacement(Dataset):
  """MNIST digits pasted at a random location on a larger blank canvas.

  Each item is one MNIST digit placed at a random (x, y) offset on a
  ``CANVAS_SIZE`` x ``CANVAS_SIZE`` zero canvas. The canvas is then resized with
  ``transforms.Resize(crop_size)`` and normalised with the MNIST mean/std.

  Args:
    crop_size: size passed to ``transforms.Resize``.
    digits: stored as ``num_images``. NOTE(review): the previous
      ``__getitem__`` returned from inside its per-digit loop, so exactly one
      digit was placed regardless of this value. That single-digit behaviour
      is kept on purpose: callers pass 10 and train a 10-class classifier on
      the returned label.
    mode: 'train' or 'val' load the MNIST training split, anything else the
      test split. 'train' and 'val' currently see identical data (no split).
    download: forwarded to ``datasets.MNIST``.
    root: directory for the MNIST files (default keeps the Colab path).
  """

  CANVAS_SIZE = 96
  # Offsets are drawn from [0, CANVAS_SIZE - PLACEMENT_MARGIN), so a 28x28
  # digit always lies fully inside the canvas.
  PLACEMENT_MARGIN = 32

  def __init__(self, crop_size, digits, mode, download=True, root='/content/'):
    self.datasets = []
    self.cropsize = crop_size
    self.download = download
    self.digits = digits
    self.mode = mode

    # False (test) or True (train, val)
    trainingset = self.mode in ['train', 'val']

    self.datasets.append(datasets.MNIST(root,
                        transform=transforms.Compose([transforms.ToTensor()]),
                        train=trainingset,
                        download=self.download))

    self.num_images = self.digits

    # Built once here instead of on every __getitem__ call.
    self._transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(self.cropsize),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

  def __len__(self):
    return len(self.datasets[0])

  def __getitem__(self, idx):
    """Return ``(image, label)``: the transformed canvas and the digit's int label."""
    canvas = torch.zeros((1, self.CANVAS_SIZE, self.CANVAS_SIZE), dtype=torch.float)

    limit = self.CANVAS_SIZE - self.PLACEMENT_MARGIN
    x = np.random.randint(0, limit)
    y = np.random.randint(0, limit)

    source = self.datasets[0]
    digit, label = source[idx % len(source)]

    # ToTensor yields (C, H, W); the old code unpacked this as (c, w, h).
    _, h, w = digit.shape
    canvas[:, y:y + h, x:x + w] = digit.type(torch.float)

    return self._transform(canvas), int(label)

In [None]:
# Hyperparameters shared by the data loaders, the model and the training loop.
config = {
    "kernel_size": 5,
    "channels": 1,
    "filter_1_out": 16,
    "filter_2_out": 32,
    "padding": 0,
    "stride": 1,
    "pool": 2,
    "learning_rate": 0.01,
    "epochs": 50,
    "batch_size": 64,
    "crop_size": 128,
}

# Record the hyperparameters on the active wandb run. Assigning a dict to
# `wandb.config` (as before) only rebinds the module attribute and attaches
# nothing to the run; it also duplicated values that could drift from `config`.
# allow_val_change lets this cell be re-run after editing `config`.
wandb.config.update(config, allow_val_change=True)

In [None]:
# One randomly-placed MNIST dataset per split (train first, then test).
train_mnist, test_mnist = (
    MnistRandomPlacement(config["crop_size"], 10, split, True)
    for split in ('train', 'test')
)

In [None]:
# Show the first five randomly placed training digits and log the figure to wandb.
n_preview = 5
fig, axes = plt.subplots(1, n_preview)
for i, ax in enumerate(axes):
    image, label = train_mnist[i]

    print(i, image.shape, label)

    ax.set_title('Sample #{}'.format(i))
    ax.axis('off')
    ax.imshow(image[0, :, :].numpy())

# Lay out once, after all panels exist.
fig.tight_layout()
plt.show()
wandb.log({'MNIST examples': wandb.Image(fig)})

In [None]:
from six.moves import urllib
# Make every urllib request send a browser-like User-Agent.
# NOTE(review): presumably needed because an MNIST mirror rejected the default
# Python User-Agent during download — confirm.
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

# Use the GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def data(misplacement=True, batch_size=None):
    """Build the train and test DataLoaders.

    Args:
        misplacement: if True, use ``MnistRandomPlacement`` (one digit pasted at
            a random position on a blank canvas, resized to
            ``config["crop_size"]``); otherwise plain normalised MNIST.
        batch_size: batch size for both loaders. Defaults to
            ``config["batch_size"]``; the misplacement branch previously
            hard-coded 64 instead.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader is shuffled.
    """
    if batch_size is None:
        batch_size = config["batch_size"]

    if misplacement:
        train_set = MnistRandomPlacement(config["crop_size"], 10, 'train', True)
        test_set = MnistRandomPlacement(config["crop_size"], 10, 'test', True)
        num_workers = 2
    else:
        mnist_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
        train_set = datasets.MNIST(root='.', train=True, download=True,
                                   transform=mnist_transform)
        # download=True so the test split does not rely on the train download.
        test_set = datasets.MNIST(root='.', train=False, download=True,
                                  transform=mnist_transform)
        num_workers = 4

    # NOTE(review): MnistRandomPlacement draws positions with np.random inside
    # worker processes; on PyTorch versions before 1.9 workers did not reseed
    # NumPy, which duplicates placements across workers — confirm the version.
    train_loader = torch.utils.data.DataLoader(
        train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    # Evaluation needs no shuffling; keeping it ordered also means the first
    # test batch (used by visualize_stn) always holds the same digits.
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, test_loader


train_loader, test_loader = data()


In [None]:
# Peek at one training batch: print the tensor shapes and show its second image.
train_features, train_labels = next(iter(train_loader))
print("Feature batch shape: {}".format(train_features.size()))
print("Labels batch shape: {}".format(train_labels.size()))

img, label = train_features[1].squeeze(), train_labels[1]
plt.imshow(img, cmap="gray")
plt.show()
print("Label: {}".format(label))

In [None]:
# Log the first ten images of the training batch to a wandb table and show
# them in a 2 x 5 grid on a fresh figure.
preview_dt = wandb.Table(columns=["id", "image", "label", "split"])

fig, axes = plt.subplots(2, 5)
for i, ax in enumerate(axes.flat):
    # Plain int instead of a 0-d tensor: plt previously titled panels
    # "tensor(3)", and the table now receives a JSON-friendly value.
    label = int(train_labels[i])
    ax.imshow(train_features[i][0], cmap='gray')
    ax.set_title(str(label))
    preview_dt.add_data(i, wandb.Image(train_features[i]), label, 'train')
wandb.log({'Train Input': preview_dt})

In [None]:

# Spatial size of the network input, read from the sample batch
# (assumes N x C x H x W layout).
height, width = train_features.shape[2], train_features.shape[3]

def compute_conv_dim(dim_size, kernel_size, padding, stride, dilation=1):
    """Output length of one spatial dimension after an ``nn.Conv2d``.

    Implements ``floor((I + 2*P - D*(F - 1) - 1) / S) + 1`` (PyTorch Conv2d
    shape formula). With the default ``dilation=1`` this is
    ``floor((I - F + 2*P) / S) + 1``.

    Raises:
        ValueError: if the resulting size is not positive (the kernel does not
            fit into the padded input).
    """
    # Integer floor division instead of truncating a float quotient.
    out = int((dim_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1)
    if out < 1:
        raise ValueError(
            f"conv output size {out} < 1 for dim_size={dim_size}, "
            f"kernel_size={kernel_size}, padding={padding}, stride={stride}, "
            f"dilation={dilation}")
    return out

def compute_pool_dim(dim_size, kernel_size, stride, padding=0, dilation=1):
    """Output length of one spatial dimension after an ``nn.MaxPool2d`` (ceil_mode=False).

    Implements ``floor((I + 2*P - D*(F - 1) - 1) / S) + 1``. With the default
    padding/dilation this is ``floor((I - F) / S) + 1``.

    Raises:
        ValueError: if the resulting size is not positive.
    """
    out = int((dim_size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1)
    if out < 1:
        raise ValueError(
            f"pool output size {out} < 1 for dim_size={dim_size}, "
            f"kernel_size={kernel_size}, stride={stride}, padding={padding}, "
            f"dilation={dilation}")
    return out

In [None]:
from numpy.lib import polynomial
class Net(nn.Module):
    """Spatial transformer followed by a small two-convolution CNN classifier.

    Layer sizes come from the global ``config`` dict. The input spatial size
    comes from the globals ``height``/``width``, which are read from a training
    batch. ``forward`` returns log-probabilities over 10 classes.
    """

    def __init__(self):
        super(Net, self).__init__()

        k = config["kernel_size"]
        pad = config["padding"]
        stride = config["stride"]
        pool = config["pool"]

        # First convolution. stride/padding are passed explicitly so the layer
        # agrees with the size bookkeeping below, which uses the same values.
        self.conv1 = nn.Conv2d(config["channels"], config["filter_1_out"], k,
                               stride=stride, padding=pad)
        self.conv1_out_height = compute_conv_dim(height, k, pad, stride)
        self.conv1_out_width = compute_conv_dim(width, k, pad, stride)

        # First pooling; conv2_out_* is the size after pool1 (input to conv2).
        self.pool1 = nn.MaxPool2d(pool, pool)
        self.conv2_out_height = compute_pool_dim(self.conv1_out_height, pool, pool)
        self.conv2_out_width = compute_pool_dim(self.conv1_out_width, pool, pool)

        # Second convolution; conv3_out_* is the size after conv2.
        self.conv2 = nn.Conv2d(config["filter_1_out"], config["filter_2_out"], k,
                               stride=stride, padding=pad)
        self.conv3_out_height = compute_conv_dim(self.conv2_out_height, k, pad, stride)
        self.conv3_out_width = compute_conv_dim(self.conv2_out_width, k, pad, stride)
        self.conv2_drop = nn.Dropout2d()

        # Second pooling; conv4_out_* is the size fed to the classifier head.
        self.pool2 = nn.MaxPool2d(pool, pool)
        self.conv4_out_height = compute_pool_dim(self.conv3_out_height, pool, pool)
        self.conv4_out_width = compute_pool_dim(self.conv3_out_width, pool, pool)

        self.fc1 = nn.Linear(config["filter_2_out"] * self.conv4_out_height * self.conv4_out_width, 50)
        self.fc2 = nn.Linear(50, 10)

        # Spatial transformer localization network.
        self.localization = nn.Sequential(
            nn.Conv2d(config["channels"], 8, kernel_size=7),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(8, 10, kernel_size=5),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True)
        )
        # Localization output size derived from the input size; the previous
        # hard-coded 28 x 28 only held for 128 x 128 inputs.
        self.loc_out_height = self._localization_out_dim(height)
        self.loc_out_width = self._localization_out_dim(width)

        # Regressor for the 2 x 3 affine matrix.
        # Plan: apply the Laplace approximation to the last linear layer as a first attempt.
        self.fc_loc = nn.Sequential(
            nn.Linear(10 * self.loc_out_height * self.loc_out_width, 32),
            nn.ReLU(True),
            nn.Linear(32, 3 * 2)
        )

        # Initialise the final regressor layer so it outputs the identity transform.
        self.fc_loc[2].weight.data.zero_()
        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))

    @staticmethod
    def _localization_out_dim(dim_size):
        """Spatial size after the localization net: conv7 -> pool2 -> conv5 -> pool2."""
        dim_size = compute_pool_dim(compute_conv_dim(dim_size, 7, 0, 1), 2, 2)
        return compute_pool_dim(compute_conv_dim(dim_size, 5, 0, 1), 2, 2)

    def stn(self, x):
        """Warp ``x`` with the affine transform predicted by the localization net."""
        xs = self.localization(x)
        # flatten(1) keeps the batch dimension; a size mismatch now fails in
        # fc_loc instead of silently changing the batch size via view(-1, ...).
        xs = xs.flatten(1)
        theta = self.fc_loc(xs).view(-1, 2, 3)

        # affine_grid and grid_sample must agree on align_corners; grid_sample
        # previously fell back to its default (False).
        grid = F.affine_grid(theta, x.size(), align_corners=True)
        return F.grid_sample(x, grid, align_corners=True)

    def forward(self, x):
        """Transform the input with the STN, then classify; returns log-probabilities."""
        x = self.stn(x)

        x = F.relu(self.pool1(self.conv1(x)))
        x = F.relu(self.pool2(self.conv2_drop(self.conv2(x))))

        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)

        return F.log_softmax(x, dim=1)


# Build the network and move its parameters to the selected device.
model = Net()
model = model.to(device)

In [None]:
print(model)

In [None]:
optimizer = optim.SGD(model.parameters(), lr=config["learning_rate"])


def train(epoch, log_interval=500):
    """Run one training epoch over ``train_loader``.

    Uses the module-level ``model``, ``optimizer``, ``train_loader`` and
    ``device``. The current batch's loss is logged to wandb every
    ``log_interval`` batches, batch 0 included.

    Args:
        epoch: epoch number; only used for logging.
        log_interval: log every this many batches (default 500, as before).
    """
    model.train()
    # `inputs` rather than `data`, which would shadow the data() helper.
    for batch_idx, (inputs, target) in enumerate(train_loader):
        inputs, target = inputs.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(inputs)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            wandb.log({"epoch": epoch, "loss": loss.item()})

In [None]:
def test():
    """Evaluate ``model`` on ``test_loader`` and log loss/accuracy to wandb.

    Uses the module-level ``model``, ``test_loader`` and ``device``.

    Returns:
        ``(test_loss, test_accuracy)``: mean per-sample NLL and accuracy in percent.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for inputs, target in test_loader:
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs)

            # Summed batch loss (divided by the dataset size below);
            # reduction='sum' replaces the deprecated size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()
            # Index of the max log-probability.
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    num_samples = len(test_loader.dataset)
    test_loss /= num_samples
    test_accuracy = 100. * correct / num_samples
    wandb.log({"test loss": test_loss,
               "test_accuracy": test_accuracy})
    return test_loss, test_accuracy

In [None]:
def convert_image_np(inp, mean=(0.1307,), std=(0.3081,)):
    """Convert a normalised (C, H, W) image tensor to an (H, W, C) numpy image in [0, 1].

    Undoes ``transforms.Normalize(mean, std)`` and clips to [0, 1]. The
    defaults are the MNIST statistics used by this notebook's transforms; the
    previous ImageNet RGB statistics did not match the data. A single mean/std
    value broadcasts across all channels, e.g. the 3 channels ``make_grid``
    produces from a 1-channel batch.
    """
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    inp = np.asarray(std) * inp + np.asarray(mean)
    return np.clip(inp, 0, 1)

In [None]:
def visualize_stn():
    """Show a test batch before and after the spatial transformer, and log it to wandb.

    Takes the first batch of ``test_loader`` and runs it through ``model.stn``.
    It plots the input and output grids side by side, and logs both grids as a
    one-row wandb Table inside a per-run "predictions" Artifact.
    """
    # Artifact named after the run id, so each run's visualisation is versioned separately.
    artifact = wandb.Artifact("test_samples_" + str(wandb.run.id), type="predictions")
    table = wandb.Table(columns=["Grid in", "Grid out"])

    with torch.no_grad():
        # Images of the first *test* batch.
        batch = next(iter(test_loader))[0].to(device)

        before = batch.cpu()
        after = model.stn(batch).cpu()

        grids = [convert_image_np(torchvision.utils.make_grid(t)) for t in (before, after)]
        in_grid, out_grid = grids

        table.add_data(wandb.Image(in_grid), wandb.Image(out_grid))

        # Input and transformed grids side by side.
        fig, axes = plt.subplots(1, 2, figsize=(20, 20))
        for ax, grid, title in zip(axes, grids, ('Dataset Images', 'Transformed Images')):
            ax.imshow(grid)
            ax.set_title(title)

    artifact.add(table, "predictions")
    wandb.run.log_artifact(artifact)

      



In [None]:
# Train for config["epochs"] epochs, evaluating on the test set after each one.
n_epochs = config["epochs"]
for epoch in range(1, n_epochs + 1):
    train(epoch)
    test()

# Visualise the learned STN transformation on one test batch.
visualize_stn()

plt.ioff()
plt.show()