**Run experiments on CIFAR and SVHN using RandAugment, the Wide-ResNet-28-2 model, the loss from [UDA](https://arxiv.org/pdf/1904.12848.pdf), and treating only a portion of the CIFAR and SVHN images as labeled**

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import random
import math
from PIL import Image
import PIL.ImageOps
import PIL.ImageEnhance
import PIL.ImageFilter

**Original RandAugment**

In [2]:
# Names of the augmentations that rand_augment picks from uniformly at random.
TRANSFORMATIONS = [
    'invert',
    'cutout',
    'sharpness',
    'autocontrast',
    'posterize',
    'shearx',
    'translatex',
    'translatey',
    'sheary',
    'rotate',
    'equalize',
    'contrast',
    'color',
    'solarize',
    'brightness',
]

In [3]:
def transform_image(image, transform, magnitude):
    """Apply a single named augmentation to a PIL image.

    Args:
        image: The PIL image to transform.
        transform: Name of the transformation to apply. Any name that is not
            recognised is treated as 'brightness'.
        magnitude: Strength of the transformation. It is ignored by 'invert'
            and 'equalize', which take no magnitude.
    Returns:
        The transformed image.
    """
    # Operations that take no magnitude.
    if transform == 'invert':
        return PIL.ImageOps.invert(image)
    if transform == 'equalize':
        return PIL.ImageOps.equalize(image)

    # Pixel-value operations driven directly by the magnitude.
    if transform == 'autocontrast':
        return PIL.ImageOps.autocontrast(image, cutoff=magnitude)
    if transform == 'posterize':
        # The number of bits passed to posterize is capped at 8.
        kept_bits = min(magnitude, 8)
        return PIL.ImageOps.posterize(image, kept_bits)
    if transform == 'solarize':
        return PIL.ImageOps.solarize(image, magnitude)

    # Geometric operations that only need methods of the image itself.
    if transform == 'cutout':
        # Trim `magnitude` pixels from every edge, then stretch the remaining
        # centre back to the original size.
        inner_box = (magnitude, magnitude, image.width - magnitude, image.height - magnitude)
        cropped = image.crop(box=inner_box)
        return cropped.resize(image.size)
    if transform == 'rotate':
        return image.rotate(magnitude)

    # Affine operations: pick the six coefficients, then apply them once.
    affine_coefficients = None
    if transform == 'shearx':
        affine_coefficients = (1, magnitude, -1 * magnitude * image.width, 0, 1, 0)
    elif transform == 'sheary':
        affine_coefficients = (1, 0, 0, magnitude, 1, -1 * magnitude * image.height)
    elif transform == 'translatex':
        affine_coefficients = (1, 0, magnitude, 0, 1, 0)
    elif transform == 'translatey':
        affine_coefficients = (1, 0, 0, 0, 1, magnitude)
    if affine_coefficients is not None:
        return image.transform(image.size, Image.AFFINE, affine_coefficients)

    # Enhancer-based operations; brightness is the fallback for any other name.
    if transform == 'sharpness':
        enhancer_class = PIL.ImageEnhance.Sharpness
    elif transform == 'contrast':
        enhancer_class = PIL.ImageEnhance.Contrast
    elif transform == 'color':
        enhancer_class = PIL.ImageEnhance.Color
    else:
        enhancer_class = PIL.ImageEnhance.Brightness
    return enhancer_class(image).enhance(magnitude)

In [4]:
def rand_augment(image, probability=0.5):
    """Applies RandAugment to the given image (a tensor), transforming it with
    the given probability.

    Args:
        image: The image tensor to transform.
        probability: The probability (in [0, 1]) with which the image is
            transformed. With the remaining probability an untouched copy of
            the image is returned.
    Returns:
        The transformed image as a new tensor, or a clone of the input when no
        transformation is applied.
    """
    # Bernoulli draw: 0 means "leave the image alone". The caller-supplied
    # probability drives the draw so that the parameter actually has an effect.
    if np.random.binomial(1, probability) == 0:
        return torch.clone(image)
    cur_transform = random.choice(TRANSFORMATIONS)
    magnitude = random.randint(1, 9)
    # The PIL-based transforms need a PIL image; convert there and back.
    pil_image = transforms.ToPILImage()(image).convert('RGB')
    new_image = transform_image(pil_image, cur_transform, magnitude)
    return transforms.ToTensor()(new_image)

**New Version of RandAugment**

In [5]:
# Names of the augmentations that new_rand_augment picks from uniformly at random.
TRANSFORMATIONS = [
    'invert',
    'cutout',
    'sharpness',
    'autocontrast',
    'posterize',
    'shearx',
    'translatex',
    'translatey',
    'sheary',
    'rotate',
    'equalize',
    'contrast',
    'color',
    'solarize',
    'brightness',
]

In [6]:
def new_transform_image(image, transform):
    """Apply a single named augmentation, drawing its magnitude at random.

    Each transformation that takes a magnitude samples it from its own
    hard-coded range; 'invert' and 'equalize' take no magnitude.

    Args:
        image: The PIL image to transform.
        transform: Name of the transformation to apply. Any name that is not
            recognised is treated as 'brightness'.
    Returns:
        The transformed image.
    """
    # Operations that take no magnitude (and draw no random number).
    if transform == 'invert':
        return PIL.ImageOps.invert(image)
    if transform == 'equalize':
        return PIL.ImageOps.equalize(image)

    # Pixel-value operations with integer magnitudes.
    if transform == 'autocontrast':
        cutoff = random.randint(1, 35)
        return PIL.ImageOps.autocontrast(image, cutoff=cutoff)
    if transform == 'posterize':
        kept_bits = random.randint(3, 8)
        return PIL.ImageOps.posterize(image, kept_bits)
    if transform == 'solarize':
        threshold = random.randint(128, 256)
        return PIL.ImageOps.solarize(image, threshold)

    # Geometric operations that only need methods of the image itself.
    if transform == 'cutout':
        # Trim `margin` pixels from every edge, then stretch the remaining
        # centre back to the original size.
        margin = random.randint(1, 8)
        inner_box = (margin, margin, image.width - margin, image.height - margin)
        cropped = image.crop(box=inner_box)
        return cropped.resize(image.size)
    if transform == 'rotate':
        angle = random.uniform(-10.0, 10.0)
        return image.rotate(angle)

    # Affine operations: draw the magnitude, build the six coefficients.
    affine_coefficients = None
    if transform == 'shearx':
        shear = random.uniform(0.0, 0.5)
        affine_coefficients = (1, shear, -1 * shear * image.width, 0, 1, 0)
    elif transform == 'sheary':
        shear = random.uniform(0.0, 0.5)
        affine_coefficients = (1, 0, 0, shear, 1, -1 * shear * image.height)
    elif transform == 'translatex':
        shift = random.uniform(0.0, 10.0)
        affine_coefficients = (1, 0, shift, 0, 1, 0)
    elif transform == 'translatey':
        shift = random.uniform(0.0, 10.0)
        affine_coefficients = (1, 0, 0, 0, 1, shift)
    if affine_coefficients is not None:
        return image.transform(image.size, Image.AFFINE, affine_coefficients)

    # Enhancer-based operations; brightness is the fallback for any other name.
    if transform == 'sharpness':
        factor = random.uniform(0.0, 2.0)
        enhancer = PIL.ImageEnhance.Sharpness(image)
    elif transform == 'contrast':
        factor = random.uniform(0.3, 2.0)
        enhancer = PIL.ImageEnhance.Contrast(image)
    elif transform == 'color':
        factor = random.uniform(0.0, 2.0)
        enhancer = PIL.ImageEnhance.Color(image)
    else:
        factor = random.uniform(0.3, 2.0)
        enhancer = PIL.ImageEnhance.Brightness(image)
    return enhancer.enhance(factor)

In [7]:
def new_rand_augment(image, probability=0.5):
    """Applies a modified version of RandAugment to the given image (a tensor),
    transforming it with the given probability.

    Args:
        image: The image tensor to transform.
        probability: The probability (in [0, 1]) with which the image is
            transformed. With the remaining probability an untouched copy of
            the image is returned.
    Returns:
        The transformed image as a new tensor, or a clone of the input when no
        transformation is applied.
    """
    # Bernoulli draw: 0 means "leave the image alone". The caller-supplied
    # probability drives the draw so that the parameter actually has an effect.
    if np.random.binomial(1, probability) == 0:
        return torch.clone(image)
    cur_transform = random.choice(TRANSFORMATIONS)
    # The PIL-based transforms need a PIL image; convert there and back.
    pil_image = transforms.ToPILImage()(image).convert('RGB')
    new_image = new_transform_image(pil_image, cur_transform)
    return transforms.ToTensor()(new_image)

**Helper function to enable storing augmented image tensors in a dictionary, where the key is the hash of the original image tensor.**

In [8]:
def hash_tensor(input_tensor):
    """Reduce a tensor to a single number usable as a dictionary key.

    The tensor is multiplied element-wise (with broadcasting) by the weights
    3, 6, 9, ... and the products are summed. The number of weights equals
    the size of dimension 1 of the tensor.

    Args:
        input_tensor: Tensor to hash. It needs at least two dimensions, and the
            weight vector must be broadcastable against it.
    Returns:
        The weighted sum as a Python number. Distinct tensors can produce the
        same value, so this is not collision-free.
    """
    weight_step = 3
    weight_count = input_tensor.size()[1]
    weights = torch.arange(1, weight_count + 1) * weight_step
    # Both operands are moved to the module-level `device` before multiplying.
    weighted = input_tensor.to(device) * weights.to(device)
    return torch.sum(weighted).item()

**Mount Google Drive**

In [None]:
# Colab-only cell: mount Google Drive at /gdrive/.
from google.colab import drive
drive.mount('/gdrive/')
# Notebook shell escape: list the mount point.
!ls /gdrive

**The general setup, training, and evaluation code below is from [CSE 543 Deep Learning Homework 1](https://github.com/pjreddie/uwnet/blob/master/hw1.ipynb)**

In [None]:
import os

# Project root on the mounted Drive; created on first use.
BASE_PATH = '/gdrive/My Drive/599project/cv/'
if not os.path.exists(BASE_PATH):
    os.makedirs(BASE_PATH)
DATA_PATH = BASE_PATH + 'tiny_imagenet/'

# Notebook shell escapes: show the working directory before changing into BASE_PATH.
!pwd
!ls
os.chdir(BASE_PATH)
# Download and unpack the course archive only when train.h5 is not already present.
# NOTE(review): presumably the archive also provides pt_util, which is imported later - confirm.
if not os.path.exists(DATA_PATH + 'train.h5'):
    !wget https://courses.cs.washington.edu/courses/cse599g1/19au/files/homework2.tar
    !tar -xvf homework2.tar
    !rm homework2.tar
!pwd
!ls

In [11]:
import h5py
import sys
sys.path.append(BASE_PATH)
import pt_util

In [12]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [13]:
import glob
import re
import matplotlib.pyplot as plt
try:
    # For 2.7
    import cPickle as pickle
except:
    # For 3.x
    import pickle


def restore(net, save_file):
    """Restores the weights from a saved file

    This does more than the simple Pytorch restore. It checks that the names
    of variables match, and if they don't doesn't throw a fit. It is similar
    to how Caffe acts. This is especially useful if you decide to change your
    network architecture but don't want to retrain from scratch.

    Entries whose name exists in both the checkpoint and the net but whose
    shapes differ are reported and skipped. A summary of what was and was not
    restored is printed at the end.

    Args:
        net(torch.nn.Module): The net to restore
        save_file(str): The file path
    Raises:
        Exception: Re-raises whatever error copying a tensor produced, after
            printing which variable was being copied.
    """

    net_state_dict = net.state_dict()
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    restore_state_dict = torch.load(save_file)

    # Names that were successfully copied into the net.
    restored_var_names = set()

    print('Restoring:')
    for var_name in restore_state_dict.keys():
        # Checkpoint entries with no counterpart in the net are skipped here and
        # reported below as "Did not restore".
        if var_name in net_state_dict:
            var_size = net_state_dict[var_name].size()
            restore_size = restore_state_dict[var_name].size()
            if var_size != restore_size:
                print('Shape mismatch for var', var_name, 'expected', var_size, 'got', restore_size)
            else:
                if isinstance(net_state_dict[var_name], torch.nn.Parameter):
                    # backwards compatibility for serialized parameters
                    net_state_dict[var_name] = restore_state_dict[var_name].data
                try:
                    # In-place copy into the tensor held in the net's state dict.
                    net_state_dict[var_name].copy_(restore_state_dict[var_name])
                    # Size estimate assumes 4 bytes per element.
                    print(str(var_name) + ' -> \t' + str(var_size) + ' = ' + str(int(np.prod(var_size) * 4 / 10**6)) + 'MB')
                    restored_var_names.add(var_name)
                except Exception as ex:
                    print('While copying the parameter named {}, whose dimensions in the model are'
                          ' {} and whose dimensions in the checkpoint are {}, ...'.format(
                              var_name, var_size, restore_size))
                    raise ex

    # In the checkpoint but not copied (unknown name or shape mismatch).
    ignored_var_names = sorted(list(set(restore_state_dict.keys()) - restored_var_names))
    # In the net but not copied, so they keep their current values.
    unset_var_names = sorted(list(set(net_state_dict.keys()) - restored_var_names))
    print('')
    if len(ignored_var_names) == 0:
        print('Restored all variables')
    else:
        print('Did not restore:\n\t' + '\n\t'.join(ignored_var_names))
    if len(unset_var_names) == 0:
        print('No new variables')
    else:
        print('Initialized but did not modify:\n\t' + '\n\t'.join(unset_var_names))

    print('Restored %s' % save_file)


def restore_latest(net, folder):
    """Restores the most recent weights in a folder

    The most recently modified '*.pt' file in the folder is restored into the
    net. If the folder holds no such file, the net is left untouched.

    Args:
        net(torch.nn.module): The net to restore
        folder(str): The folder path
    Returns:
        int: The last run of digits in the checkpoint's file name (the epoch
            for names such as '012.pt'). Returns 0 when there is no checkpoint
            or the file name contains no digits.
    """

    checkpoints = sorted(glob.glob(folder + '/*.pt'), key=os.path.getmtime)
    if not checkpoints:
        return 0

    latest = checkpoints[-1]
    restore(net, latest)
    # Parse the file name only, so digits in directory names (for example
    # '599project') are never mistaken for the epoch.
    digit_runs = re.findall(r'\d+', os.path.basename(latest))
    return int(digit_runs[-1]) if digit_runs else 0


def save(net, file_name, num_to_keep=1):
    """Saves the net to file, creating folder paths if necessary.

    Args:
        net(torch.nn.module): The network to save
        file_name(str): the path to save the file. A bare file name (no
            directory part) is saved in the current working directory.
        num_to_keep(int): How many of the most recently modified files with the
            same extension are kept in the folder once this one has been saved.
            Defaults to 1. Specifying <= 0 will not remove any previous saves.
    """

    # dirname is '' for a bare file name; treat that as the current directory
    # so neither makedirs nor the glob below operates on an empty path.
    folder = os.path.dirname(file_name) or '.'
    os.makedirs(folder, exist_ok=True)
    torch.save(net.state_dict(), file_name)
    print('Saved %s\n' % file_name)
    if num_to_keep > 0:
        extension = os.path.splitext(file_name)[1]
        # Escape the folder so glob metacharacters in the path are literal.
        pattern = os.path.join(glob.escape(folder), '*' + extension)
        checkpoints = sorted(glob.glob(pattern), key=os.path.getmtime)
        # Oldest first: remove everything except the newest num_to_keep files.
        for stale_checkpoint in checkpoints[:-num_to_keep]:
            os.remove(stale_checkpoint)

def write_log(filename, data):
    """Pickles and writes data to a file

    Missing parent folders are created first.

    Args:
        filename(str): File name
        data(pickleable object): Data to save
    """

    folder = os.path.dirname(filename)
    # A bare file name has no folder part; there is nothing to create then.
    if folder:
        os.makedirs(folder, exist_ok=True)
    # The context manager closes the file even if pickling fails.
    with open(filename, 'wb') as log_file:
        pickle.dump(data, log_file)

def read_log(filename, default_value=None):
    """Reads pickled data or returns the default value if none found

    Args:
        filename(str): File name
        default_value(anything): Value to return if no file is found
    Returns:
        unpickled file
    """

    if not os.path.exists(filename):
        return default_value
    # NOTE: unpickling can execute arbitrary code - only read logs written by
    # this project, never files from an untrusted source.
    with open(filename, 'rb') as log_file:
        return pickle.load(log_file)

def show_images(images, titles=None, columns=5, max_rows=5):
    """Shows images in a tiled format

    Args:
        images(list[np.array]): Images to show
        titles(list[string]): Titles for each of the images
        columns(int): How many columns to use in the tiling
        max_rows(int): If there are more than columns * max_rows images, only the first n of them will be shown.
    """

    images = images[:min(len(images), max_rows * columns)]
    # The subplot grid needs an integer row count; floor division keeps it an
    # int on Python 3, where `/` would produce a float.
    rows = len(images) // columns + 1

    plt.figure(figsize=(20, 10))
    for ii, image in enumerate(images):
        plt.subplot(rows, columns, ii + 1)
        plt.axis('off')
        # Titles are optional and may be fewer than the images.
        if titles is not None and ii < len(titles):
            plt.title(str(titles[ii]))
        plt.imshow(image)
    plt.show()

def plot(x_values, y_values, title, xlabel, ylabel):
    """Draw a single line graph and display it.

    Args:
        x_values(list or np.array): x coordinates of the line
        y_values(list or np.array): y coordinates of the line
        title(str): Title shown above the plot
        xlabel(str): Label of the x axis
        ylabel(str): Label of the y axis
    """

    figure_size = (20, 10)
    plt.figure(figsize=figure_size)
    plt.plot(x_values, y_values)
    # Axis labels and title are independent of each other.
    plt.ylabel(ylabel)
    plt.xlabel(xlabel)
    plt.title(title)
    plt.show()

def to_scaled_uint8(array):
    """Returns a normalized uint8 scaled to 0-255. This is useful for showing images especially of floats.

    The minimum of the input maps to 0 and the maximum to 255. A constant
    input maps to all zeros and an empty input stays empty. The input itself
    is not modified.

    Args:
        array(np.array): The array to normalize
    Returns:
        np.array normalized and of type uint8
    """

    # np.array copies, so the in-place arithmetic below never touches the
    # caller's data.
    scaled = np.array(array, dtype=np.float32)
    if scaled.size == 0:
        # min/max are undefined for empty arrays.
        return scaled.astype(np.uint8)
    scaled -= np.min(scaled)
    peak = np.max(scaled)
    # A constant input has peak 0 after the shift; skip the division so the
    # result is all zeros instead of NaN.
    if peak > 0:
        scaled *= (255. / peak)
    return scaled.astype(np.uint8)

**Dataset class:** Used with the PyTorch `DataLoader`. Indexing it returns an (image, label) pair, and `len()` returns the number of images.

In [14]:
class ComputerVisionDataset(torch.utils.data.Dataset):
    """Thin Dataset wrapper around an indexable collection of
    (image, label) pairs, for use with the PyTorch DataLoader.
    """

    def __init__(self, cv_data):
        """Keeps a reference to the given data.

        Args:
            cv_data: Indexable collection whose items hold the image at
                position 0 and the label at position 1.
        """
        super().__init__()
        self.cv_data = cv_data

    def __len__(self):
        """Returns the number of items in the dataset.

        Returns:
            The length of the wrapped collection.
        """
        return len(self.cv_data)

    def __getitem__(self, idx):
        """Returns the image and label stored at the given index.

        Args:
            idx: Index of the item to return.
        Returns:
            A tuple (image, label).
        """
        sample = self.cv_data[idx]
        return sample[0], sample[1]


**Wide-ResNet-28-2 Model**, copied from https://github.com/xternalz/WideResNet-pytorch/blob/master/wideresnet.py because the model could not be loaded automatically. The loss function follows the loss used in [UDA](https://arxiv.org/pdf/1904.12848.pdf).

In [15]:
class BasicBlock(nn.Module):
    """Pre-activation residual block: BN-ReLU-conv3x3 twice, plus a shortcut.

    When the input and output channel counts differ, the shortcut is a 1x1
    convolution applied to the activated input; otherwise it is the identity.
    """

    def __init__(self, in_planes, out_planes, stride, dropRate=0.0):
        """Builds the layers of the block.

        Args:
            in_planes: Number of input channels.
            out_planes: Number of output channels.
            stride: Stride of the first convolution (and of the shortcut
                convolution, when there is one).
            dropRate: Dropout probability applied between the two
                convolutions; 0 disables dropout.
        """
        super().__init__()
        # Layers are created in a fixed order so that random initialisation
        # consumes the random number generator identically on every build.
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(
            in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.droprate = dropRate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(
                in_planes, out_planes, kernel_size=1, stride=stride, padding=0, bias=False)

    def forward(self, x):
        """Runs the block on a batch of feature maps.

        Args:
            x: Input feature maps.
        Returns:
            The sum of the shortcut branch and the convolutional branch.
        """
        activated = self.relu1(self.bn1(x))
        hidden = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            hidden = F.dropout(hidden, p=self.droprate, training=self.training)
        hidden = self.conv2(hidden)
        if self.equalInOut:
            # Identity shortcut uses the raw (not activated) input.
            shortcut = x
        else:
            # Projection shortcut is computed from the activated input.
            shortcut = self.convShortcut(activated)
        return torch.add(shortcut, hidden)

class NetworkBlock(nn.Module):
    """A sequence of residual blocks that share the same output width."""

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, dropRate=0.0):
        """Builds the sequence of blocks.

        Args:
            nb_layers: Number of blocks in the sequence (converted with int()).
            in_planes: Number of input channels of the first block.
            out_planes: Number of output channels of every block.
            block: Callable that builds one block from
                (in_planes, out_planes, stride, dropRate).
            stride: Stride of the first block; later blocks use stride 1.
            dropRate: Dropout probability passed to every block.
        """
        super().__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, dropRate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, dropRate):
        """Creates the blocks and chains them in an nn.Sequential."""
        blocks = []
        for index in range(int(nb_layers)):
            if index == 0:
                # Only the first block changes the width and applies the
                # stride. A falsy in_planes/stride falls back to out_planes/1.
                block_in = in_planes or out_planes
                block_stride = stride or 1
            else:
                block_in = out_planes
                block_stride = 1
            blocks.append(block(block_in, out_planes, block_stride, dropRate))
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Applies the blocks to x in order."""
        return self.layer(x)

class WideResNet(nn.Module):
    """Wide residual network with a UDA-style loss and checkpoint helpers.

    The network is a 3x3 convolution followed by three groups of residual
    blocks, batch norm, ReLU, 8x8 average pooling and a linear classifier.
    """

    def __init__(self, depth, num_classes, widen_factor=1, dropRate=0.0):
        """Builds the network.

        Args:
            depth: Total depth; (depth - 4) must be divisible by 6.
            num_classes: Number of output classes.
            widen_factor: Multiplier for the channel counts of the three
                block groups.
            dropRate: Dropout probability used inside every residual block.
        """
        super(WideResNet, self).__init__()
        nChannels = [16, 16*widen_factor, 32*widen_factor, 64*widen_factor]
        assert((depth - 4) % 6 == 0)
        # Number of residual blocks per group; floor division keeps it an int.
        n = (depth - 4) // 6
        block = BasicBlock
        # 1st conv before any network block
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1,
                               padding=1, bias=False)
        # 1st block
        self.block1 = NetworkBlock(n, nChannels[0], nChannels[1], block, 1, dropRate)
        # 2nd block
        self.block2 = NetworkBlock(n, nChannels[1], nChannels[2], block, 2, dropRate)
        # 3rd block
        self.block3 = NetworkBlock(n, nChannels[2], nChannels[3], block, 2, dropRate)
        # global average pooling and classifier
        self.bn1 = nn.BatchNorm2d(nChannels[3])
        self.relu = nn.ReLU(inplace=True)
        self.fc = nn.Linear(nChannels[3], num_classes)
        self.nChannels = nChannels[3]

        # Best accuracy seen so far by save_best_model (None until first call).
        self.accuracy = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                m.bias.data.zero_()

    def forward(self, x):
        """Computes class logits for a batch of images.

        Args:
            x: Batch of images with 3 channels.
        Returns:
            The raw (unnormalised) class scores, one row per image.
        """
        out = self.conv1(x)
        out = self.block1(out)
        out = self.block2(out)
        out = self.block3(out)
        out = self.relu(self.bn1(out))
        # 8x8 pooling window; the view below flattens to nChannels features.
        out = F.avg_pool2d(out, 8)
        out = out.view(-1, self.nChannels)
        return self.fc(out)

    def loss(self, image_data, label, temperature, beta, only_sup_loss, augmented_data=None, reduction='mean'):
        """Computes the loss used in UDA.

        Args:
            image_data: Batch of original images.
            label: Labels for all images in the batch; a negative label marks
                an unlabeled image.
            temperature: The temperature to use in sharpening predictions.
            beta: Confidence threshold for confidence-based masking.
            only_sup_loss: True if only the supervised loss should be computed,
                False otherwise. In that case every label must be a valid
                class index.
            augmented_data: Mapping from hash_tensor(image) to the augmented
                versions of that unlabeled image. Only needed when
                only_sup_loss is False.
            reduction: Type of reduction to use when computing the supervised
                loss (i.e., mean or sum). When only_sup_loss is False the
                losses are always averaged over the labeled and unlabeled
                images respectively.
        Returns:
            The supervised loss (loss on labeled data) plus 0.5 times the
            consistency loss (loss on unlabeled data), as a 0-dim tensor.
        """

        if only_sup_loss:
            # cross_entropy applies log-softmax itself, so it must be given the
            # raw logits. reshape(-1) keeps the batch dimension even for a
            # batch of one.
            logits = self.forward(image_data)
            return F.cross_entropy(logits, label.reshape(-1), reduction=reduction)

        # Work on whatever device the model lives on instead of a global.
        model_device = next(self.parameters()).device
        labeled_loss_val = torch.zeros((), device=model_device)
        unlabeled_loss_val = torch.zeros((), device=model_device)
        num_unlabeled = 0
        num_labeled = 0

        for i in range(len(image_data)):
            cur_image = torch.unsqueeze(image_data[i], 0).to(model_device)
            if label[i] >= 0:
                # Data is labeled
                cur_logits = self.forward(cur_image)
                cur_label = torch.unsqueeze(label[i], 0).to(model_device)
                labeled_loss_val = labeled_loss_val + F.cross_entropy(cur_logits, cur_label, reduction=reduction)
                num_labeled += 1
            else:
                # Data is unlabeled. The prediction on the original image is a
                # fixed target, so no gradient flows through it.
                num_unlabeled += 1
                with torch.no_grad():
                    cur_logits = self.forward(cur_image)
                softmax_prediction = F.softmax(cur_logits, dim=1)
                # Confidence-based masking: skip images the model is unsure about.
                if torch.max(softmax_prediction) > beta:
                    sharpened = F.log_softmax(cur_logits / temperature, dim=1)
                    augmented_images = augmented_data[hash_tensor(image_data[i])]
                    if not augmented_images:
                        # Nothing to compare against; avoid dividing by zero.
                        continue
                    cur_unlabeled_loss = torch.zeros((), device=model_device)
                    for image in augmented_images:
                        augmented_image = torch.unsqueeze(image, 0).to(model_device)
                        augmented_classification = F.log_softmax(self.forward(augmented_image), dim=1)
                        # kl_div(input, target) computes KL(target || input):
                        # the sharpened prediction is the target and the
                        # augmented prediction is the input being trained.
                        cur_unlabeled_loss = cur_unlabeled_loss + F.kl_div(
                            augmented_classification, sharpened, reduction='batchmean', log_target=True)
                    unlabeled_loss_val = unlabeled_loss_val + (cur_unlabeled_loss / len(augmented_images))

        # Start from a zero that requires grad so backward() works even when
        # no sample contributed a differentiable term.
        total_loss_val = torch.zeros((), device=model_device, requires_grad=True)
        if num_unlabeled > 0:
            total_loss_val = total_loss_val + (0.5 * (unlabeled_loss_val / num_unlabeled))
        if num_labeled > 0:
            total_loss_val = total_loss_val + (labeled_loss_val / num_labeled)
        return total_loss_val

    def save_model(self, file_path, num_to_keep=1):
        """Saves the model to the given file.

        Args:
            file_path: Path to save the model.
            num_to_keep: Number of previous saved states to keep.
        """
        pt_util.save(self, file_path, num_to_keep)

    def save_best_model(self, accuracy, file_path, num_to_keep=1):
        """If a higher accuracy is achieved, saves the current model.

        Args:
            accuracy: New accuracy achieved.
            file_path: Path to save the model.
            num_to_keep: Number of previous saved states to keep.
        """
        if self.accuracy is None or accuracy > self.accuracy:
            self.accuracy = accuracy
            self.save_model(file_path, num_to_keep)

    def load_model(self, file_path):
        """Loads the model from the given file.

        Args:
            file_path: Path of the file to load the model from.
        """
        pt_util.restore(self, file_path)

    def load_last_model(self, dir_path):
        """Loads the most recent model from the given directory.

        Args:
            dir_path: Directory from which to load the model.
        Returns:
            Whatever pt_util.restore_latest returns for that directory.
        """
        return pt_util.restore_latest(self, dir_path)

**Train and test**: This code is adapted from [CSE 543 Deep Learning Homework 1](https://github.com/pjreddie/uwnet/blob/master/hw1.ipynb)

In [16]:
import time
def train(model, device, train_loader, optimizer, epoch, log_interval, temperature, beta, augmented_images, pretrain):
    """Runs one epoch of training.

    Args:
        model: The model to train; must provide a loss(...) method.
        device: The device batches are moved to.
        train_loader: DataLoader yielding (data, label) batches.
        optimizer: Optimizer that updates the model parameters.
        epoch: The current epoch (used for logging only).
        log_interval: A progress line is printed every log_interval batches.
        temperature: The temperature to use in sharpening predictions.
        beta: Confidence threshold for confidence-based masking.
        augmented_images: Augmented versions of the unlabeled images.
        pretrain: True if pretraining the model (supervised loss only),
            False otherwise.
    Returns:
        The mean of the per-batch losses of this epoch.
    """
    model.train()
    batch_losses = []
    for batch_idx, (data, label) in enumerate(train_loader):
        data = data.to(device)
        label = label.to(device)
        optimizer.zero_grad()
        loss = model.loss(data, label, temperature, beta, pretrain, augmented_data=augmented_images)
        # Read the scalar once; the value is reused for the log line below.
        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval != 0:
            continue
        samples_seen = batch_idx * len(data)
        percent_done = 100. * batch_idx / len(train_loader)
        print('{} Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
            time.ctime(time.time()), epoch, samples_seen,
            len(train_loader.dataset), percent_done, batch_loss))
    return np.mean(batch_losses)

def test(model, device, test_loader, temperature, beta, log_interval=None):
    """Evaluates the model on the test data.

    Args:
        model: The model to evaluate; must provide a loss(...) method.
        device: The device batches are moved to.
        test_loader: DataLoader yielding (data, label) batches.
        temperature: The temperature to use in sharpening predictions.
        beta: Confidence threshold for confidence-based masking.
        log_interval: If given, a progress line is printed every
            log_interval batches.
    Returns:
        A tuple (average loss per sample, accuracy in percent).
    """
    model.eval()
    total_loss = 0
    total_correct = 0

    with torch.no_grad():
        for batch_idx, (data, label) in enumerate(test_loader):
            data = data.to(device)
            label = label.to(device)
            probabilities = F.softmax(model(data), dim=1)
            # Summed (not averaged) so it can be divided by the dataset size below.
            batch_loss = model.loss(data, label, temperature, beta, True, reduction='sum').item()
            total_loss += batch_loss
            _, predicted = probabilities.max(1)
            matches = predicted.eq(label.view_as(predicted))
            total_correct += matches.sum().item()
            if log_interval is None or batch_idx % log_interval != 0:
                continue
            samples_seen = batch_idx * len(data)
            percent_done = 100. * batch_idx / len(test_loader)
            print('{} Test: [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                time.ctime(time.time()), samples_seen,
                len(test_loader.dataset), percent_done, batch_loss))

    dataset_size = len(test_loader.dataset)
    test_loss = total_loss / dataset_size
    test_accuracy = 100. * total_correct / dataset_size

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, total_correct, dataset_size, test_accuracy))
    return test_loss, test_accuracy

In [17]:
def run_training(data_train, data_test, data_path, hyperparams, experiment_num, class_names, augmented_images, pretrain):
    """Train and evaluate the model on the given training and test data, respectively.

    Training resumes from the most recent checkpoint in the experiment's log
    folder, if there is one. After every epoch the metrics log is written and
    the model is saved if its test accuracy improved. When training ends (for
    any reason) the current model is saved and the metric curves are plotted.

    Args:
        data_train: The training data.
        data_test: The test data.
        data_path: Location of the image data; logs go to
            data_path + 'logs/' + experiment_num + '/'.
        hyperparams: Hyperparameters to use during training. Reads the keys
            'batch_size', 'test_batch_size', 'epochs', 'learning_rate',
            'weight_decay', 'temperature' and 'beta'.
        experiment_num: Current experiment number, as a string.
        class_names: Names of the classes for the images.
        augmented_images: Augmented versions of the training data images.
        pretrain: True if pretraining the model, False otherwise.
    """
    import multiprocessing
    import traceback

    # A 'momentum' entry in hyperparams is accepted but not read: the Adam
    # optimizer used below takes no momentum argument.
    batch_size = hyperparams['batch_size']
    test_batch_size = hyperparams['test_batch_size']
    epochs = hyperparams['epochs']
    learning_rate = hyperparams['learning_rate']
    weight_decay = hyperparams['weight_decay']
    temperature = hyperparams['temperature']
    beta = hyperparams['beta']
    print_interval = 100

    log_path = data_path + 'logs/' + experiment_num + '/'
    use_cuda = torch.cuda.is_available()

    print('Using device', device)
    print('num cpus:', multiprocessing.cpu_count())

    # Worker processes and pinned memory are only requested when CUDA is used.
    kwargs = {'num_workers': multiprocessing.cpu_count(),
              'pin_memory': True} if use_cuda else {}

    train_loader = torch.utils.data.DataLoader(data_train, batch_size=batch_size,
                                               shuffle=True, **kwargs)
    test_loader = torch.utils.data.DataLoader(data_test, batch_size=test_batch_size,
                                              shuffle=False, **kwargs)

    model = WideResNet(28, len(class_names), widen_factor=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    start_epoch = model.load_last_model(log_path)

    train_losses, test_losses, test_accuracies = pt_util.read_log(log_path + 'log.pkl', ([], [], []))
    # Baseline evaluation before any (further) training.
    test_loss, test_accuracy = test(model, device, test_loader, temperature, beta)

    test_losses.append((start_epoch, test_loss))
    test_accuracies.append((start_epoch, test_accuracy))

    # Defined up front so the final save below has a valid epoch even when the
    # loop body never runs (for example when start_epoch > epochs).
    epoch = start_epoch
    try:
        for epoch in range(start_epoch, epochs + 1):
            train_loss = train(model, device, train_loader, optimizer, epoch, print_interval,
                               temperature, beta, augmented_images, pretrain)
            test_loss, test_accuracy = test(model, device, test_loader, temperature, beta)
            train_losses.append((epoch, train_loss))
            test_losses.append((epoch, test_loss))
            test_accuracies.append((epoch, test_accuracy))
            pt_util.write_log(log_path + 'log.pkl', (train_losses, test_losses, test_accuracies))
            model.save_best_model(test_accuracy, log_path + '%03d.pt' % epoch)
    except KeyboardInterrupt:
        print('Interrupted')
    except Exception:
        # Best effort: report the failure but still save and plot below.
        traceback.print_exc()
    finally:
        # num_to_keep=0 keeps every earlier checkpoint.
        model.save_model(log_path + '%03d.pt' % epoch, 0)
        curves = (
            (train_losses, 'Train loss', 'Error'),
            (test_losses, 'Test loss', 'Error'),
            (test_accuracies, 'Test accuracy', 'Accuracy'),
        )
        for series, title, ylabel in curves:
            # An empty series (no completed epoch) cannot be unzipped or plotted.
            if series:
                ep, val = zip(*series)
                pt_util.plot(ep, val, title, 'Epoch', ylabel)


**Helper function to convert images to tensors**

In [18]:
def convert_images_to_tensor(image_data):
    """Turn the image of every (image, label) pair into a tensor.

    Args:
        image_data: An iterable of pairs whose first element is an image and
            whose second element is the label.
    Returns:
        A new list of (tensor, label) pairs in the same order as the input.
        Labels are passed through untouched.
    """
    return [
        (transforms.ToTensor()(sample[0]), sample[1])
        for sample in image_data
    ]

**Experiments with CIFAR**

**Note:** When trying to download the data, if you get an error saying "Transport Error: not connected to Drive", then re-run the cell to mount Google Drive and try running the cell to download the data again.

Load the CIFAR data

In [None]:
# Download (if needed) the raw CIFAR-10 train and test splits under BASE_PATH.
# No transform is attached here, so samples are whatever CIFAR10 yields by
# default; they are converted to tensors later.
cifar_data_path = BASE_PATH + 'cifar/'
original_cifar_data_train, cifar_data_test = (
    datasets.CIFAR10(root=cifar_data_path, train=is_train, download=True)
    for is_train in (True, False)
)

Randomly select a portion of the data to be labeled images and the remainder to be unlabeled

In [20]:
# Semi-supervised split: after shuffling, the first `num_labeled` samples keep
# their labels and every remaining sample has its label replaced by -1.
# NOTE(review): the shuffle is unseeded, so the labeled subset differs from
# run to run.
num_labeled = 4000
shuffled_cifar_data_train = [*original_cifar_data_train]
random.shuffle(shuffled_cifar_data_train)
labeled_cifar_data_train = shuffled_cifar_data_train[:num_labeled]
unlabeled_cifar_data_train = [
    (sample[0], -1) for sample in shuffled_cifar_data_train[num_labeled:]
]

# Mix labeled and unlabeled samples together before wrapping them in the
# dataset, so they are interleaved rather than grouped.
full_cifar_data_train = [*labeled_cifar_data_train, *unlabeled_cifar_data_train]
random.shuffle(full_cifar_data_train)
new_cifar_data_train = ComputerVisionDataset(
    convert_images_to_tensor(full_cifar_data_train)
)

Create augmented versions of all unlabeled CIFAR images in the training data

In [28]:
# Precompute `num_augmented_images` augmented copies of every unlabeled
# (label < 0) CIFAR training image, keyed by hash_tensor of the original.
num_augmented_images = 5
augmented_cifar_images = {}
for cifar_image in new_cifar_data_train:
    if not cifar_image[1] < 0:
        # Labeled sample: no augmented copies are generated.
        continue
    hashed_image = hash_tensor(cifar_image[0])
    # To use the modified version of RandAugment, call the function
    # new_rand_augment instead of rand_augment
    new_images = [
        rand_augment(cifar_image[0]) for _ in range(num_augmented_images)
    ]
    augmented_cifar_images[hashed_image] = new_images

In [22]:
# Wrap the CIFAR test split in a ComputerVisionDataset after converting each
# image to a tensor; the test labels are kept as-is.
cifar_data_test_tensor = ComputerVisionDataset(convert_images_to_tensor(cifar_data_test))

In [23]:
# Class names for CIFAR-10, passed to run_training.
# NOTE(review): presumably ordered by CIFAR-10 label index (0-9) — confirm.
cifar_class_names = [
    'plane',
    'car',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

In [24]:
# Experiment identifier handed to run_training for both the CIFAR pretraining
# run and the CIFAR UDA run.
# NOTE(review): presumably selects the log/checkpoint location — confirm
# against run_training.
cifar_experiment_num = "0.001"

Pretrain with original CIFAR data

In [None]:
# Augmentation pipeline applied on the fly to the fully labeled CIFAR-10
# training set used for pretraining.
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(size=32),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=3),
    transforms.RandomRotation(degrees=3),
    transforms.ColorJitter(),
    transforms.ToTensor(),
])

original_cifar_data_train_tensor = datasets.CIFAR10(
    root=cifar_data_path,
    train=True,
    download=True,
    transform=transform_train,
)

In [None]:
# Hyperparameters for pretraining on the original CIFAR-10 training set.
cifar_hyperparams = dict(
    batch_size=32,
    test_batch_size=10,
    epochs=500,
    learning_rate=0.001,
    momentum=0.9,
    weight_decay=0.0,
    temperature=0.4,
    beta=1,
)

# NOTE(review): the trailing `None, True` presumably mean "no augmented
# images" and "pretrain mode" — confirm against run_training's signature.
run_training(
    original_cifar_data_train_tensor,
    cifar_data_test_tensor,
    cifar_data_path,
    cifar_hyperparams,
    cifar_experiment_num,
    cifar_class_names,
    None,
    True,
)

Train with augmented CIFAR data using UDA

In [None]:
# Hyperparameters for the UDA run on the partially labeled CIFAR-10 data.
ssl_cifar_hyperparams = dict(
    batch_size=64,
    test_batch_size=10,
    epochs=200,
    learning_rate=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
    temperature=0.4,
    beta=0.7,
)

# NOTE(review): the trailing arguments presumably supply the precomputed
# augmented images and turn pretrain mode off — confirm against run_training.
run_training(
    new_cifar_data_train,
    cifar_data_test_tensor,
    cifar_data_path,
    ssl_cifar_hyperparams,
    cifar_experiment_num,
    cifar_class_names,
    augmented_cifar_images,
    False,
)

**Experiments with SVHN**

Load the SVHN data

In [None]:
# Download (if needed) the raw SVHN train and test splits under BASE_PATH.
# No transform is attached here; images are converted to tensors later.
svhn_data_path = BASE_PATH + 'svhn/'
original_svhn_data_train, svhn_data_test = (
    datasets.SVHN(root=svhn_data_path, split=split_name, download=True)
    for split_name in ('train', 'test')
)

Randomly select a portion of the data to be labeled images and the remainder to be unlabeled

In [30]:
# Semi-supervised split: after shuffling, the first `num_labeled` samples keep
# their labels and every remaining sample has its label replaced by -1.
# NOTE(review): the shuffle is unseeded, so the labeled subset differs from
# run to run.
num_labeled = 1000
shuffled_svhn_data_train = [*original_svhn_data_train]
random.shuffle(shuffled_svhn_data_train)
labeled_svhn_data_train = shuffled_svhn_data_train[:num_labeled]
unlabeled_svhn_data_train = [
    (sample[0], -1) for sample in shuffled_svhn_data_train[num_labeled:]
]

# Mix labeled and unlabeled samples together before wrapping them in the
# dataset, so they are interleaved rather than grouped.
full_svhn_data_train = [*labeled_svhn_data_train, *unlabeled_svhn_data_train]
random.shuffle(full_svhn_data_train)
new_svhn_data_train = ComputerVisionDataset(
    convert_images_to_tensor(full_svhn_data_train)
)

Create augmented versions of all unlabeled SVHN images in the training data

In [31]:
# Precompute `num_augmented_images` augmented copies of every unlabeled
# (label < 0) SVHN training image, keyed by hash_tensor of the original.
num_augmented_images = 5
augmented_svhn_images = {}
for svhn_image in new_svhn_data_train:
    if not svhn_image[1] < 0:
        # Labeled sample: no augmented copies are generated.
        continue
    hashed_image = hash_tensor(svhn_image[0])
    # To use the modified version of RandAugment, call the function
    # new_rand_augment instead of rand_augment
    new_images = [
        rand_augment(svhn_image[0]) for _ in range(num_augmented_images)
    ]
    augmented_svhn_images[hashed_image] = new_images

In [32]:
# Wrap the SVHN test split in a ComputerVisionDataset after converting each
# image to a tensor; the test labels are kept as-is.
svhn_data_test_tensor = ComputerVisionDataset(convert_images_to_tensor(svhn_data_test))

In [33]:
# SVHN classes are the ten digits 0-9; the list is passed to run_training.
svhn_class_names = list(range(10))

In [34]:
# Experiment identifier handed to run_training for both the SVHN pretraining
# run and the SVHN UDA run.
# NOTE(review): presumably selects the log/checkpoint location — confirm
# against run_training.
svhn_experiment_num = "0.001"

Pretrain with original SVHN data

In [None]:
# SVHN pretraining uses no augmentation: images are only converted to tensors.
transform_train = transforms.Compose([transforms.ToTensor()])

original_svhn_data_train_tensor = datasets.SVHN(
    root=svhn_data_path,
    split='train',
    download=True,
    transform=transform_train,
)

In [None]:
# Hyperparameters for pretraining on the original SVHN training set.
svhn_hyperparams = dict(
    batch_size=64,
    test_batch_size=10,
    epochs=50,
    learning_rate=0.0001,
    momentum=0.9,
    weight_decay=0.0,
    temperature=0.4,
    beta=1,
)

# NOTE(review): the trailing `None, True` presumably mean "no augmented
# images" and "pretrain mode" — confirm against run_training's signature.
run_training(
    original_svhn_data_train_tensor,
    svhn_data_test_tensor,
    svhn_data_path,
    svhn_hyperparams,
    svhn_experiment_num,
    svhn_class_names,
    None,
    True,
)

Train with augmented SVHN data using UDA

In [None]:
# Hyperparameters for the UDA run on the partially labeled SVHN data.
ssl_svhn_hyperparams = dict(
    batch_size=64,
    test_batch_size=10,
    epochs=100,
    learning_rate=0.0001,
    momentum=0.9,
    weight_decay=0.0005,
    temperature=0.4,
    beta=0.8,
)

# NOTE(review): the trailing arguments presumably supply the precomputed
# augmented images and turn pretrain mode off — confirm against run_training.
run_training(
    new_svhn_data_train,
    svhn_data_test_tensor,
    svhn_data_path,
    ssl_svhn_hyperparams,
    svhn_experiment_num,
    svhn_class_names,
    augmented_svhn_images,
    False,
)