### Initializations

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
torch.manual_seed(0)

def show_tensor_images(image_tensor, num_images=25, size=(3, 32, 32), nrow=5, show=True):
    
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    
    if show:
        plt.show()

### Generator

In [2]:
class Generator(nn.Module):
    
    '''
    Generator Class
    Values:
        input_dim: the dimension of the input vector, a scalar
        im_chan: the number of channels of the output image, a scalar
              (CIFAR100 is in color (red, green, blue), so 3 is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    
    def __init__(self, input_dim=10, im_chan=3, hidden_dim=64):
        
        super(Generator, self).__init__()
        
        self.input_dim = input_dim
        
        self.gen = nn.Sequential(
            
        )
        
    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN;
        a transposed convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        
        if not final_layer:
            
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True)
            )
        
        else:
            
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh()
            )
        
def get_noise(n_samples, input_dim, device='cpu'):
    
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, input_dim)
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        input_dim: the dimension of the input vector, a scalar
        device: the device type
    '''
    
    return torch.randn(n_samples, input_dim, device=device)

def combine_vectors(x, y):
    
    '''
    Function for combining two vectors with shapes (n_samples, ?) and (n_samples, ?)
    Parameters:
    x: (n_samples, ?) the first vector. 
        In this assignment, this will be the noise vector of shape (n_samples, z_dim), 
        but you shouldn't need to know the second dimension's size.
    y: (n_samples, ?) the second vector.
        Once again, in this assignment this will be the one-hot class vector 
        with the shape (n_samples, n_classes), but you shouldn't assume this in your code.
    '''
    return torch.cat([x, y], 1)

def get_one_hot_labels(labels, n_classes):
    
    return F.one_hot(labels, n_classes)

### Training

In [3]:
cifar100_shape = (3, 32, 32)
n_classes = 100

In [4]:
n_epochs = 10000
z_dim = 64
display_step = 500
batch_size = 64
lr = 0.0002
device = 'cuda'

In [5]:
generator_input_dim = z_dim + n_classes

### Classifier

In [6]:
class Classifier(nn.Module):
    
    '''
    Classifier Class
    Values:
        im_chan: the number of channels of the output image, a scalar
        n_classes: the total number of classes in the dataset, an integer scalar
        hidden_dim: the inner dimension, a scalar
    '''
    
    def __init__(self, im_chan, n_classes, hidden_dim=32):
        
        super(Classifier, self).__init__()
        
        self.disc = nn.Sequential(
            self.make_classifier_block(im_chan, hidden_dim),
            self.make_classifier_block(hidden_dim, hidden_dim * 2),
            self.make_classifier_block(hidden_dim * 2, hidden_dim * 4),
            self.make_classifier_block(hidden_dim * 4, n_classes, final_layer=True)
        )
        
    def make_classifier_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        
        '''
        Function to return a sequence of operations corresponding to a classifier block; 
        a convolution, a batchnorm (except in the final layer), and an activation (except in the final
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        
        if not final_layer:
            
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=False)
            )
        
        else:
            
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride)
            )
        
    def forward(self, image):
        
        '''
        Function for completing a forward pass of the classifier: Given an image tensor, 
        returns an n_classes-dimension tensor representing fake/real.
        Parameters:
            image: a flattened image tensor with im_chan channels
        '''
        
        class_pred = self.disc(image)
        return class_pred.view(len(class_pred), -1)

### Tuning the Classifier

You will fine-tune your model by augmenting the original real data with fake data and during that process, observe how to increase the accuracy of your classifier with these fake, GAN-generated bugs. After this, you will prove your worth as a bug master.



In [7]:
def combine_sample(real, fake, p_real):
    
    '''
    Function to take a set of real and fake images of the same length (x)
    and produce a combined tensor with length (x) and sampled at the target probability
    Parameters:
        real: a tensor of real images, length (x)
        fake: a tensor of fake images, length (x)
        p_real: the probability the images are sampled from the real set
    '''
    
    make_fake = torch.rand(len(real)) > p_real
    target_images = real.clone()
    target_images[make_fake] = fake[make_fake]
    
    return target_images

In [8]:
a = torch.rand(5)
c = torch.rand(5)

In [9]:
b = [True, False, True, False, True]
print(a)
print(c)

tensor([0.4963, 0.7682, 0.0885, 0.1320, 0.3074])
tensor([0.6341, 0.4901, 0.8964, 0.4556, 0.6323])


In [10]:
c[b] = a[b]
c

tensor([0.4963, 0.4901, 0.0885, 0.4556, 0.3074])

In [11]:
n_test_samples = 9999
test_combination = combine_sample(
    torch.ones(n_test_samples, 1), 
    torch.zeros(n_test_samples, 1), 
    0.3
)
# Check that the shape is right
assert tuple(test_combination.shape) == (n_test_samples, 1)
# Check that the ratio is right
assert torch.abs(test_combination.mean() - 0.3) < 0.05
# Make sure that no mixing happened
assert test_combination.median() < 1e-5

test_combination = combine_sample(
    torch.ones(n_test_samples, 10, 10), 
    torch.zeros(n_test_samples, 10, 10), 
    0.8
)
# Check that the shape is right
assert tuple(test_combination.shape) == (n_test_samples, 10, 10)
# Make sure that no mixing happened
assert torch.abs((test_combination.sum([1, 2]).median()) - 100) < 1e-5

test_reals = torch.arange(n_test_samples)[:, None].float()
test_fakes = torch.zeros(n_test_samples, 1)
test_saved = (test_reals.clone(), test_fakes.clone())
test_combination = combine_sample(test_reals, test_fakes, 0.3)
# Make sure that the sample isn't biased
assert torch.abs((test_combination.mean() - 1500)) < 100
# Make sure no inputs were changed
assert torch.abs(test_saved[0] - test_reals).sum() < 1e-3
assert torch.abs(test_saved[1] - test_fakes).sum() < 1e-3

test_fakes = torch.arange(n_test_samples)[:, None].float()
test_combination = combine_sample(test_reals, test_fakes, 0.3)
# Make sure that the order is maintained
assert torch.abs(test_combination - test_reals).sum() < 1e-4
if torch.cuda.is_available():
    # Check that the solution matches the input device
    assert str(combine_sample(
        torch.ones(n_test_samples, 10, 10).cuda(), 
        torch.zeros(n_test_samples, 10, 10).cuda(),
        0.8
    ).device).startswith("cuda")
print("Success!")

Success!


Now you have a challenge: find a p_real and a generator image such that your classifier gets an average of a 51% accuracy or higher on the insects, when evaluated with the eval_augmentation function. You'll need to fill in find_optimal to find these parameters to solve this part!

When you're training a generator, you will often have to look at different checkpoints and choose one that does the best (either empirically or using some evaluation method). Here, you are given four generator checkpoints: gen_1.pt, gen_2.pt, gen_3.pt, gen_4.pt. You'll also have some scratch area to write whatever code you'd like to solve this problem, but you must return a p_real and an image name of your selected generator checkpoint. You can hard-code/brute-force these numbers if you would like, but you are encouraged to try to solve this problem in a more general way. In practice, you would also want a test set (since it is possible to overfit on a validation set), but for simplicity you can just focus on the validation set.

In [14]:
def find_optimal():
    
    gen_names = [
        "gen_1.pt",
        "gen_2.pt",
        "gen_3.pt",
        "gen_4.pt"
    ]
    
    best_p_real, best_gen_name = 0, "gen_1.pt"
    max_eval = -1
    
    for gen_name in gen_names:
        
        p_real_all = torch.linspace(0, 1, 21)
        
        for p_real in tqdm(p_real_all):
            
            curr_eval = eval_augmentation(p_real, gen_name, n_test=20)
            
            if curr_eval > max_eval:
                
                max_eval = curr_eval
                best_p_real = p_real
                best_gen_name = gen_name
                
    return best_p_real, best_gen_name

def augmented_train(p_real, gen_name):
    
    gen = Generator(generator_input_dim).to(device)
    gen.load_state_dict(torch.load(gen_name))
    
    classifier = Classifier(cifar100_shape[0], n_classes).to(device)
    classifier.load_state_dict(torch.load('class.pt'))
    
    criterion = nn.CrossEntropyLoss()
    
    batch_size = 256
    
    train_set = torch.load('insect_train.pt')
    val_set = torch.load('insect_val.pt')
    
    dataloader = DataLoader(
        torch.utils.data.TensorDataset(train_set["images"], train_set['labels']),
        batch_size=batch_size,
        shuffle=True
    )
    
    validation_dataloader = DataLoader(
        torch.utils.data.TensorDataset(val_set['images'], val_set['labels']),
        batch_size=batch_size
    )
    
    display_step = 1
    lr = 0.0002
    n_epochs = 20
    classifier_opt = torch.optim.Adam(classifier.parameters(), lr=lr)
    cur_step = 0
    best_score = 0
    
    for epoch in range(n_epochs):
        
        for real, labels in dataloader:
            
            real = real.to(device)
            
            # Flatten the image
            labels = labels.to(device)
            one_hot_labels = get_one_hot_labels(labels, n_classes).float()
            
            # Updating the classifier
            classifier_opt.zero_grad()
            cur_batch_size = len(labels)
            
            fake_noise = get_noise(cur_batch_size, z_dim, device=device)
            noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
            fake = gen(noise_and_labels)
            
            target_images = combine_sample(real.clone(), fake.clone(), p_real)
            labels_hat = classifier(target_images.detach())
            classifier_loss = criterion(labels_hat, labels)
            classifier_loss.backward()
            classifier_opt.step()
            
            # Calculate the accuracy of the validation set
            if cur_step % display_step == 0 and cur_step > 0:
                
                classifier_val_loss = 0
                classifier_correct = 0
                num_validation = 0
                
                with torch.no_grad():
                    
                    for val_example, val_label in validation_dataloader:
                        
                        cur_batch_size = len(val_example)
                        num_validation += cur_batch_size
                        val_example = val_example.to(device)
                        val_label = val_label.to(device)
                        labels_hat = Classifier(val_example)
                        classifier_val_loss += criterion(labels_hat, labels) * cur_batch_size
                        classifier_correct += (labels_hat.argmax(1) == val_label).float().sum()
                        
                    accuracy = classifier_correct.item() / num_validation
                    
                    if accuracy > best_score:
                        
                        best_score = accuracy
                        
            cur_step += 1
            
    return best_score

def eval_augmentation(p_real, gen_name, n_test=20):
    
    total = 0
    
    for i in range(n_test):
        
        total += augmented_train(p_real, gen_name)
        
    return total / n_test

best_p_real, best_gen_name = find_optimal()

performance = eval_augmentation(best_p_real, best_gen_name)

print(f"Your model had an accuracy of {performance:0.1%}")
assert performance > 0.51
print("Success!")

  0%|          | 0/21 [00:00<?, ?it/s]

RuntimeError: Error(s) in loading state_dict for Generator:
	Unexpected key(s) in state_dict: "gen.0.0.weight", "gen.0.0.bias", "gen.0.1.weight", "gen.0.1.bias", "gen.0.1.running_mean", "gen.0.1.running_var", "gen.0.1.num_batches_tracked", "gen.1.0.weight", "gen.1.0.bias", "gen.1.1.weight", "gen.1.1.bias", "gen.1.1.running_mean", "gen.1.1.running_var", "gen.1.1.num_batches_tracked", "gen.2.0.weight", "gen.2.0.bias", "gen.2.1.weight", "gen.2.1.bias", "gen.2.1.running_mean", "gen.2.1.running_var", "gen.2.1.num_batches_tracked", "gen.3.0.weight", "gen.3.0.bias". 