<a href="https://colab.research.google.com/github/jt658/CS330-Final-Project/blob/main/SNAIL_Pytorch_Quick_Draw.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import zipfile
# Unpack the QuickDraw archive into a 'QuickDrawData' folder relative to the
# current working directory; the archive handle is closed on leaving the block.
with zipfile.ZipFile("/dataset/QuickDrawData.zip", mode='r') as archive:
    archive.extractall(path='QuickDrawData')

In [None]:
import random
import numpy as np
import torch


class BatchSampler(object):
    '''
    BatchSampler: yield a batch of indexes at each iteration.

    Every yielded batch is the concatenation of `batch_size` episodes. One
    episode holds `classes_per_it * num_samples` shuffled support indexes
    followed by a single query index whose class also appears in the support.
    __len__ returns the number of episodes per epoch (same as 'self.iterations').
    '''

    def __init__(self, labels, classes_per_it, num_samples, iterations, batch_size):
        '''
        Initialize the BatchSampler object
        Args:
        - labels: an iterable containing all the labels for the current dataset
        samples indexes will be infered from this iterable.
        - classes_per_it: number of random classes for each iteration
        - num_samples: number of samples for each iteration for each class
        - iterations: number of iterations (episodes) per epoch
        - batch_size: number of episodes concatenated into one yielded batch
        '''
        super(BatchSampler, self).__init__()
        self.labels = labels
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations
        self.batch_size = batch_size

        # `inverse[i]` is the row (position in self.classes) of sample i, which
        # avoids searching self.classes once per sample.
        self.classes, inverse, self.counts = np.unique(
            self.labels, return_inverse=True, return_counts=True)
        inverse = np.asarray(inverse).reshape(-1)

        self.idxs = range(len(self.labels))
        max_count = int(self.counts.max()) if len(self.counts) else 0
        # Row r lists the dataset indexes of class r, NaN-padded on the right.
        self.label_tens = np.full((len(self.classes), max_count), np.nan)
        self.label_lens = np.zeros(len(self.classes), dtype=int)
        for idx, label_idx in enumerate(inverse):
            # label_lens doubles as the next free column of each row.
            self.label_tens[label_idx, self.label_lens[label_idx]] = idx
            self.label_lens[label_idx] += 1

    def __iter__(self):
        '''
        yield a batch of indexes

        Raises:
        - ValueError: if classes_per_it is < 1 or exceeds the number of classes
        - AssertionError: if a drawn class has fewer than num_samples + 1 samples
        '''
        spc = self.sample_per_class + 1 # To get that extra sample
        cpi = self.classes_per_it
        num_samples = spc * cpi
        true_num_samples = (spc - 1) * cpi + 1

        if cpi < 1 or cpi > len(self.classes):
            raise ValueError('classes_per_it must be between 1 and the number of classes')

        for _ in range(self.iterations):
            episodes = []
            for _ in range(self.batch_size):
                batch = np.empty(num_samples)
                c_idxs = np.random.permutation(len(self.classes))[:cpi]
                for i, label_idx in enumerate(c_idxs):
                    if spc > self.label_lens[label_idx]:
                        raise AssertionError('More samples per class than exist in the dataset')
                    sample_idxs = np.random.permutation(self.label_lens[label_idx])[:spc]
                    # Interleave classes: class i sits at positions i, i + cpi, ...
                    batch[i::cpi] = self.label_tens[label_idx][sample_idxs]
                # A window starting at `offset` covers every class exactly
                # (spc - 1) times and ends on one extra sample of class `offset`.
                # The offset must stay below cpi or the window runs off the end.
                offset = random.randint(0, cpi - 1)
                batch = batch[offset:offset + true_num_samples]
                support = true_num_samples - 1
                batch[:support] = batch[:support][np.random.permutation(support)]
                episodes.append(batch)
            if episodes:
                yield np.concatenate(episodes).astype(int)
            else:
                yield np.array([], dtype=int)

    def __len__(self):
        '''
        returns the number of iterations (episodes) per epoch
        '''
        return self.iterations

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CasualConv1d(nn.Module):
    '''
    1-D causal convolution: output step t only depends on input steps <= t.

    The wrapped nn.Conv1d pads both ends with dilation * (kernel_size - 1)
    zeros; forward() then drops the trailing outputs that would otherwise
    read the right-hand padding.
    '''

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, dilation=1, groups=1, bias=True):
        super(CasualConv1d, self).__init__()
        self.dilation = dilation
        self.stride = stride[0] if isinstance(stride, (tuple, list)) else stride
        self.padding = dilation * (kernel_size - 1)
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, stride,
                                self.padding, dilation, groups, bias)

    def forward(self, input):
        # Takes something of shape (N, in_channels, T),
        # returns (N, out_channels, T) when stride == 1
        out = self.conv1d(input)
        # Output step t reads original steps [t*stride - padding, t*stride],
        # so only steps with t*stride <= T - 1 avoid the right-hand padding.
        causal_len = (input.size(2) - 1) // self.stride + 1
        return out[:, :, :causal_len]

class DenseBlock(nn.Module):
    '''
    Gated causal-convolution block with a dense (concatenating) connection.

    Two parallel causal convolutions produce a tanh "filter" and a sigmoid
    "gate"; their product is appended to the input along the channel axis,
    so the output has in_channels + filters channels.
    '''

    def __init__(self, in_channels, dilation, filters, kernel_size=2):
        super(DenseBlock, self).__init__()
        self.casualconv1 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)
        self.casualconv2 = CasualConv1d(in_channels, filters, kernel_size, dilation=dilation)

    def forward(self, input):
        # input is dimensions (N, in_channels, T)
        xf = self.casualconv1(input)
        xg = self.casualconv2(input)
        # torch.tanh / torch.sigmoid replace the deprecated F.tanh / F.sigmoid.
        activations = torch.tanh(xf) * torch.sigmoid(xg) # shape: (N, filters, T)
        return torch.cat((input, activations), dim=1)
        
class TCBlock(nn.Module):
    '''
    Stack of DenseBlocks with dilations 2, 4, 8, ...

    ceil(log2(seq_length)) blocks are created; each one appends `filters`
    channels, so the output carries in_channels + n_blocks * filters channels.
    '''

    def __init__(self, in_channels, seq_length, filters):
        super(TCBlock, self).__init__()
        num_layers = int(math.ceil(math.log(seq_length, 2)))
        layers = []
        channels = in_channels
        for layer in range(num_layers):
            layers.append(DenseBlock(channels, 2 ** (layer + 1), filters))
            # The next block also sees the channels appended by this one.
            channels += filters
        self.dense_blocks = nn.ModuleList(layers)

    def forward(self, input):
        # input is dimensions (N, T, in_channels); the dense blocks work
        # channels-first, so swap to (N, C, T) and back again at the end.
        features = input.transpose(1, 2)
        for dense_block in self.dense_blocks:
            features = dense_block(features)
        return features.transpose(1, 2)

class AttentionBlock(nn.Module):
    '''
    Single-head causal self-attention with a dense (concatenating) connection.

    Each time step attends to itself and to earlier steps only; the attended
    values are appended to the input along the feature axis, so the output
    has in_channels + value_size features.
    '''

    def __init__(self, in_channels, key_size, value_size):
        super(AttentionBlock, self).__init__()
        self.linear_query = nn.Linear(in_channels, key_size)
        self.linear_keys = nn.Linear(in_channels, key_size)
        self.linear_values = nn.Linear(in_channels, value_size)
        self.sqrt_key_size = math.sqrt(key_size)

    def forward(self, input):
        # input is dim (N, T, in_channels) where N is the batch_size, and T is
        # the sequence length
        seq_len = input.shape[1]
        # mask[q, k] is True when key k lies in the future of query q. It is
        # built on the input's device so the block works on CPU and GPU alike.
        mask = torch.ones(seq_len, seq_len, device=input.device).triu(diagonal=1).bool()

        keys = self.linear_keys(input) # shape: (N, T, key_size)
        query = self.linear_query(input) # shape: (N, T, key_size)
        values = self.linear_values(input) # shape: (N, T, value_size)
        logits = torch.bmm(query, torch.transpose(keys, 1, 2)) # shape: (N, T, T), [n, query, key]
        # Out-of-place fill keeps the operation visible to autograd.
        logits = logits.masked_fill(mask, -float('inf'))
        # Normalise over the key axis so every query's weights sum to 1; the
        # diagonal is never masked, so no row is entirely -inf.
        weights = F.softmax(logits / self.sqrt_key_size, dim=2) # shape: (N, T, T)
        read = torch.bmm(weights, values) # shape: (N, T, value_size)
        return torch.cat((input, read), dim=2) # shape: (N, T, in_channels + value_size)

In [None]:
import numpy as np
from scipy.interpolate import RegularGridInterpolator
def transformation(inputImg, globalScale, rotationAngle, translationM, translationN, outputGridSize, method="linear"):
    """
    Resample a 2-D image under a scale + rotation + translation of the output grid.

    Args:
      inputImg: 2-D array to sample from (the moving image).
      globalScale: factor applied to the output coordinates before sampling.
      rotationAngle: rotation in degrees.
      translationM: offset added to the row coordinates.
      translationN: offset added to the column coordinates.
      outputGridSize: (rows, cols) of the returned image.
      method: interpolation method forwarded to RegularGridInterpolator.
    Returns:
      Array of shape outputGridSize; samples falling outside inputImg are 0.
    """
    outRows, outCols = outputGridSize
    srcRows, srcCols = np.shape(inputImg)

    # Coordinate axes of the output (fixed) grid. Note that -dim//2 floors
    # towards minus infinity, so an odd dim gives an axis that is not
    # symmetric about 0.
    outRowAxis = np.linspace(-outRows//2, outRows//2, outRows)
    outColAxis = np.linspace(-outCols//2, outCols//2, outCols)
    rowGrid, colGrid = np.meshgrid(outRowAxis, outColAxis, indexing='ij')

    theta = np.radians(rotationAngle)
    cosTheta = np.cos(theta)
    sinTheta = np.sin(theta)

    # Scale, rotate, then translate every output coordinate.
    rowCoords = (globalScale*cosTheta*rowGrid + globalScale*(-sinTheta)*colGrid) + translationM
    colCoords = (globalScale*sinTheta*rowGrid + globalScale*cosTheta*colGrid) + translationN

    # Coordinate axes of the input (moving) image.
    srcRowAxis = np.linspace(-srcRows//2, srcRows//2, srcRows)
    srcColAxis = np.linspace(-srcCols//2, srcCols//2, srcCols)

    # Interpolate the moving image at the transformed coordinates; anything
    # outside its extent is filled with 0 rather than raising.
    sampler = RegularGridInterpolator((srcRowAxis, srcColAxis), inputImg,
                                      bounds_error=False, fill_value=0, method=method)
    queryPoints = np.stack((rowCoords, colCoords), axis=2)
    return sampler(queryPoints)

In [None]:
from __future__ import print_function
import torch.utils.data as data
import numpy as np
import errno
import os
from PIL import Image, ImageOps, ImageFilter
import torch
import shutil
import matplotlib.pyplot as plt
import imageio
import cv2

'''
Inspired by https://github.com/pytorch/vision/pull/46
'''

# Module-level cache filled by load_img: maps an image file path to the PIL
# image object opened for it, so each file is opened at most once.
IMG_CACHE = {}


class OmniglotDataset(data.Dataset):
    '''
    Omniglot (train / val / trainval) and QuickDraw (test) few-shot dataset.

    All images are loaded eagerly in __init__ into self.x, with their integer
    class targets in self.y.
    '''
    vinalys_baseurl = 'https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/'
    vinyals_split_sizes = {
        'test': vinalys_baseurl + 'test.txt',
        'train': vinalys_baseurl + 'train.txt',
        'trainval': vinalys_baseurl + 'trainval.txt',
        'val': vinalys_baseurl + 'val.txt',
    }

    urls = [
        'https://github.com/brendenlake/omniglot/raw/master/python/images_background.zip',
        'https://github.com/brendenlake/omniglot/raw/master/python/images_evaluation.zip'
    ]
    splits_folder = os.path.join('splits', 'vinyals')
    raw_folder = 'raw'
    processed_folder = 'data'
    # Folder holding one sub-folder of PNGs per class, used when mode == 'test'.
    quickdraw_root = '../dataset/QuickDrawData'

    def __init__(self, mode='train', root='../dataset/omniglot', transform=None, target_transform=None, download=True):
        '''
        The items are (filename,category). The index of all the categories can be found in self.idx_classes
        Args:
        - mode: 'train', 'val' or 'trainval' read the matching Omniglot split;
          'test' reads every class folder found under quickdraw_root
        - root: the directory where the dataset will be stored
        - transform: how to transform the input
        - target_transform: how to transform the target
        - download: need to download the dataset
        Raises:
        - RuntimeError: if the processed Omniglot folder is missing
        '''
        super(OmniglotDataset, self).__init__()
        self.mode = mode
        self.root = root
        self.transform = transform
        self.target_transform = target_transform

        if download:
            self.download()

        # NOTE: this check runs against the Omniglot root even for mode='test'.
        if not self._check_exists():
            raise RuntimeError(
                'Dataset not found. You can use download=True to download it')

        if mode == 'test':
            self.root = self.quickdraw_root
            self.classes = os.listdir(self.root)
            print(self.classes)
            items_root = self.root
        else:
            self.classes = get_current_classes(os.path.join(
                self.root, self.splits_folder, mode + '.txt'))
            items_root = os.path.join(self.root, self.processed_folder)

        self.all_items = find_items(items_root, self.classes, mode)
        self.idx_classes = index_classes(self.all_items)

        path_labels = [self.get_path_label(pl) for pl in range(len(self))]
        if path_labels:
            paths, self.y = zip(*path_labels)
        else:
            # zip(*[]) cannot be unpacked into two names; keep both empty.
            paths, self.y = (), ()

        self.x = [load_img(path, idx, self.mode) for idx, path in enumerate(paths)]

    def __getitem__(self, idx):
        x = self.x[idx]
        if self.transform:
            x = self.transform(x)
        return x, self.y[idx]

    def __len__(self):
        return len(self.all_items)

    def get_path_label(self, index):
        '''
        Return (image path, integer target) for item `index`.

        Outside test mode the rotation suffix (e.g. '/rot090') is appended to
        the path; load_img strips it again. The class key is always the item's
        label concatenated with its last field, matching index_classes.
        '''
        item = self.all_items[index]
        filename, label, folder = item[0], item[1], item[2]
        suffix = item[-1]

        img = str.join('/', [folder, filename])
        if self.mode != 'test':
            img += suffix

        target = self.idx_classes[label + suffix]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _check_exists(self):
        return os.path.exists(os.path.join(self.root, self.processed_folder))

    def download(self):
        '''
        Download the split files and both Omniglot archives, then move the
        extracted alphabets into the processed folder. Does nothing when the
        processed folder already exists.
        '''
        import urllib.request
        import zipfile

        if self._check_exists():
            return

        # Create each folder independently so that one pre-existing folder
        # does not prevent the others from being created.
        for folder in (self.splits_folder, self.raw_folder, self.processed_folder):
            os.makedirs(os.path.join(self.root, folder), exist_ok=True)

        for k, url in self.vinyals_split_sizes.items():
            print('== Downloading ' + url)
            filename = url.rpartition('/')[-1]
            file_path = os.path.join(self.root, self.splits_folder, filename)
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)

        orig_root = os.path.join(self.root, self.raw_folder)
        for url in self.urls:
            print('== Downloading ' + url)
            filename = url.rpartition('/')[2]
            file_path = os.path.join(self.root, self.raw_folder, filename)
            with urllib.request.urlopen(url) as response, open(file_path, 'wb') as f:
                shutil.copyfileobj(response, f)
            print("== Unzip from " + file_path + " to " + orig_root)
            with zipfile.ZipFile(file_path, 'r') as zip_ref:
                zip_ref.extractall(orig_root)

        file_processed = os.path.join(self.root, self.processed_folder)
        for p in ['images_background', 'images_evaluation']:
            for f in os.listdir(os.path.join(orig_root, p)):
                shutil.move(os.path.join(orig_root, p, f), file_processed)
            os.rmdir(os.path.join(orig_root, p))
        print("Download finished.")


def find_items(root_dir, classes, mode):
    '''
    Walk root_dir and collect the PNG files that belong to `classes`.

    Args:
    - root_dir: directory to walk recursively
    - classes: class names to keep. In "test" mode a class is the name of the
      folder holding the file; otherwise it is '<parent>/<folder>/rotXXX'
    - mode: "test" or any other split name
    Returns:
    - list of (filename, label, folder) tuples in "test" mode, otherwise
      (filename, label, folder, rotation) with one entry per matching rotation
    '''
    # Membership is tested once per directory against a set instead of once
    # per file and rotation against a list.
    wanted = set(classes)
    retour = []

    if mode == "test":
        for (root, dirs, files) in os.walk(root_dir):
            label = root.replace(os.sep, '/').split('/')[-1]
            if label not in wanted:
                continue
            for f in files:
                if f.endswith("png"):
                    retour.append((f, label, root))
    else:
        rots = ['/rot000', '/rot090', '/rot180', '/rot270']
        for (root, dirs, files) in os.walk(root_dir):
            parts = root.replace(os.sep, '/').split('/')
            label = '/'.join(parts[-2:])
            matching = [rot for rot in rots if label + rot in wanted]
            if not matching:
                continue
            for f in files:
                if f.endswith("png"):
                    for rot in matching:
                        retour.append((f, label, root, rot))

    print("== Dataset: Found %d items " % len(retour))
    return retour


def index_classes(items):
    '''
    Assign consecutive integer ids to classes in order of first appearance.

    The class key of an item is its second field concatenated with its last
    field. Returns the {class key: id} dict.
    '''
    idx = {}
    for item in items:
        # setdefault only stores the id when the key has not been seen yet.
        idx.setdefault(item[1] + item[-1], len(idx))
    print("== Dataset: Found %d classes" % len(idx))
    return idx


def get_current_classes(fname):
    '''
    Read a split file and return its lines (without line endings) as a list.
    '''
    with open(fname) as split_file:
        return split_file.read().splitlines()

def image_file_to_array(filename, dim_input):
    """
    Load an image file as a flattened, inverted float32 vector.

    Args:
      filename: Image filename
      dim_input: Flattened shape of image
    Returns:
      1-D float32 array of length dim_input holding 1 - pixel / 255
    """
    pixels = imageio.imread(filename).reshape([dim_input])
    scaled = pixels.astype(np.float32) / 255.0
    # Invert so that a pixel value of 255 maps to 0.0 and 0 maps to 1.0.
    return 1.0 - scaled

def load_img(path, idx, mode):
    '''
    Load one image as an inverted float tensor of shape (1, width, height).

    Args:
    - path: image path; outside test mode it ends with a '/rotXXX' suffix
      giving the rotation in degrees
    - idx: unused, kept for interface compatibility
    - mode: 'test' loads the image as-is; any other mode rotates it and
      resizes it to 28x28
    '''
    rot = None
    if mode != 'test':
        # Split on the LAST '/rot' so a directory whose name starts with
        # 'rot' earlier in the path cannot break the unpacking.
        path, rot = path.rsplit('/rot', 1)
    if path in IMG_CACHE:
        x = IMG_CACHE[path]
    else:
        x = Image.open(path)
        IMG_CACHE[path] = x
    if mode != 'test':
        x = x.rotate(float(rot))
        x = x.resize((28, 28))

    # NOTE(review): test-mode images are neither resized nor rescaled here;
    # confirm the QuickDraw PNGs are already single-channel 28x28 with 0/1 values.
    shape = 1, x.size[0], x.size[1]
    # np.array(..., copy=False) raises under NumPy 2 whenever a copy is
    # required, which converting a PIL image to float32 always is.
    x = np.array(x, dtype=np.float32)
    x = 1.0 - torch.from_numpy(x)
    x = x.transpose(0, 1).contiguous().view(shape)

    return x

In [None]:
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1):
    """Build a bias-free 2-D convolution (defaults: 3x3 kernel, stride 1, padding 1)."""
    options = dict(kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
    return nn.Conv2d(in_planes, out_planes, **options)

def conv_block(in_channels, out_channels):
    '''
    Build a conv(3x3, padding 1) -> batch-norm -> ReLU -> 2x2 max-pool block.

    The sub-modules are named 'conv', 'bn', 'relu' and 'pool'.
    '''
    layers = OrderedDict()
    layers['conv'] = nn.Conv2d(in_channels, out_channels, 3, padding=1)
    # momentum=1 makes the running statistics equal those of the latest batch.
    layers['bn'] = nn.BatchNorm2d(out_channels, momentum=1)
    layers['relu'] = nn.ReLU()
    layers['pool'] = nn.MaxPool2d(2)
    return nn.Sequential(layers)

def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None, training=True,eps=1e-5, momentum=0.1):
    '''
    Functional batch norm using dummy running statistics.

    The running_mean / running_var arguments are ignored and replaced by
    fresh zeros / ones, so no running statistics are tracked between calls.
    The dummies are created on the input's device and dtype, so this works
    on CPU as well as GPU.
    '''
    # momentum = 1 restricts stats to the current mini-batch
    # This hack only works when momentum is 1 and avoids needing to track
    # running stats by substituting dummy variables
    num_features = input.size(1)
    running_mean = torch.zeros(num_features, device=input.device, dtype=input.dtype)
    running_var = torch.ones(num_features, device=input.device, dtype=input.dtype)
    return F.batch_norm(input, running_mean, running_var, weight, bias, training, momentum, eps)

class OmniglotNet(nn.Module):
    '''
    Four-block convolutional encoder (conv-bn-relu-pool per block).
    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84
    '''
    def __init__(self, x_dim=1, hid_dim=64, z_dim=64):
        super(OmniglotNet, self).__init__()
        self.encoder = nn.Sequential(OrderedDict([
            ('block1', conv_block(x_dim, hid_dim)),
            ('block2', conv_block(hid_dim, hid_dim)),
            ('block3', conv_block(hid_dim, hid_dim)),
            ('block4', conv_block(hid_dim, z_dim)),
        ]))

    def forward(self, x, weights=None):
        '''
        Encode x and flatten the result to (batch, features).

        Args:
        - x: input image batch
        - weights: optional mapping keyed like this module's state_dict
          ('encoder.blockN.conv.weight', ...). When given, those tensors are
          used instead of the module's own parameters.
        '''
        if weights is None:
            x = self.encoder(x)
        else:
            for block in ('block1', 'block2', 'block3', 'block4'):
                prefix = 'encoder.' + block + '.'
                # padding=1 mirrors the nn.Conv2d built by conv_block, so both
                # paths compute the same function for the same weights.
                x = F.conv2d(x, weights[prefix + 'conv.weight'],
                             weights[prefix + 'conv.bias'], padding=1)
                x = batchnorm(x, weight=weights[prefix + 'bn.weight'],
                              bias=weights[prefix + 'bn.bias'])
                x = F.relu(x)
                x = F.max_pool2d(x, 2, 2)
        return x.view(x.size(0), -1)

class ResBlock(nn.Module):
    '''
    Residual block: three 3x3 conv / batch-norm / LeakyReLU stages, plus a
    1x1 convolution shortcut from the block input, followed by 2x2 max-pool
    and dropout.
    '''

    def __init__(self, in_channels, filters, pool_padding=0):
        super(ResBlock, self).__init__()
        channels = in_channels
        for stage in (1, 2, 3):
            # Registers conv1/bn1/relu1 ... conv3/bn3/relu3 in that order.
            setattr(self, 'conv%d' % stage, conv(channels, filters))
            setattr(self, 'bn%d' % stage, nn.BatchNorm2d(filters))
            setattr(self, 'relu%d' % stage, nn.LeakyReLU())
            channels = filters
        # 1x1 projection so the shortcut has `filters` channels.
        self.conv4 = conv(in_channels, filters, kernel_size=1, padding=0)

        self.maxpool = nn.MaxPool2d(2, padding=pool_padding)
        self.dropout = nn.Dropout(p=0.9)

    def forward(self, x):
        shortcut = self.conv4(x)

        stages = (
            (self.conv1, self.bn1, self.relu1),
            (self.conv2, self.bn2, self.relu2),
            (self.conv3, self.bn3, self.relu3),
        )
        out = x
        for convolution, norm, activation in stages:
            out = activation(norm(convolution(out)))

        out += shortcut
        return self.dropout(self.maxpool(out))

class MiniImagenetNet(nn.Module):
    '''
    Residual encoder: four ResBlocks (64, 96, 128, 256 filters), a 1x1
    convolution to 2048 channels, 6x6 max-pool, ReLU, dropout and a final
    1x1 convolution to 384 channels. The output is flattened per sample.
    '''
    def __init__(self, in_channels=3):
        super(MiniImagenetNet, self).__init__()
        self.block1 = ResBlock(in_channels, 64)
        self.block2 = ResBlock(64, 96)
        self.block3 = ResBlock(96, 128, pool_padding=1)
        self.block4 = ResBlock(128, 256, pool_padding=1)
        self.conv1 = conv(256, 2048, kernel_size=1, padding=0)
        self.maxpool = nn.MaxPool2d(6)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.9)
        self.conv2 = conv(2048, 384, kernel_size=1, padding=0)

    def forward(self, x, weights=None):
        # Only the module's own parameters are supported for this encoder.
        if weights is not None:
            raise ValueError('Not implemented yet')

        pipeline = (self.block1, self.block2, self.block3, self.block4,
                    self.conv1, self.maxpool, self.relu, self.dropout, self.conv2)
        for layer in pipeline:
            x = layer(x)
        return x.view(x.size(0), -1)

In [None]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SnailFewShot(nn.Module):
    '''
    SNAIL few-shot classifier: an image encoder followed by alternating
    attention and temporal-convolution blocks over a sequence of N * K
    labelled support samples plus one unlabelled query sample.
    '''

    def __init__(self, N, K, task, use_cuda=True):
        '''
        Args:
        - N: number of classes per episode (N-way)
        - K: number of support samples per class (K-shot)
        - task: 'omniglot' or 'mini_imagenet', selects the encoder
        - use_cuda: stored for compatibility; forward() follows the device
          of its inputs
        Raises:
        - ValueError: if task is not recognised
        '''
        # N-way, K-shot
        super(SnailFewShot, self).__init__()
        if task == 'omniglot':
            self.encoder = OmniglotNet()
            num_channels = 64 + N
        elif task == 'mini_imagenet':
            self.encoder = MiniImagenetNet()
            num_channels = 384 + N
        else:
            raise ValueError('Not recognized task value')
        # Number of dense blocks inside each TCBlock (same formula TCBlock
        # uses); every dense block appends 128 channels.
        num_filters = int(math.ceil(math.log(N * K + 1, 2)))
        self.attention1 = AttentionBlock(num_channels, 64, 32)
        num_channels += 32
        self.tc1 = TCBlock(num_channels, N * K + 1, 128)
        num_channels += num_filters * 128
        self.attention2 = AttentionBlock(num_channels, 256, 128)
        num_channels += 128
        self.tc2 = TCBlock(num_channels, N * K + 1, 128)
        num_channels += num_filters * 128
        self.attention3 = AttentionBlock(num_channels, 512, 256)
        num_channels += 256
        self.fc = nn.Linear(num_channels, N)
        self.N = N
        self.K = K
        self.use_cuda = use_cuda

    def forward(self, input, labels):
        '''
        Args:
        - input: images for batch_size * (N * K + 1) samples, episode by episode
        - labels: one-hot labels with one row per sample
        Returns:
        - logits of shape (batch_size, N * K + 1, N)
        '''
        x = self.encoder(input)
        seq_len = self.N * self.K + 1
        batch_size = labels.size()[0] // seq_len
        # Hide the label of the last (query) sample of every episode. Work on
        # a copy so the caller's tensor is left untouched; the assignment
        # runs on whatever device `labels` lives on.
        labels = labels.clone()
        labels[seq_len - 1::seq_len] = 0
        x = torch.cat((x, labels), 1)
        x = x.view((batch_size, seq_len, -1))
        x = self.attention1(x)
        x = self.tc1(x)
        x = self.attention2(x)
        x = self.tc2(x)
        x = self.attention3(x)
        x = self.fc(x)
        return x

In [None]:
import torch


def init_dataset(opt):
    '''
    Initialize the datasets, samplers and dataloaders

    Args:
    - opt: options with dataset, num_cls, num_samples, iterations and
      batch_size attributes
    Returns:
    - (train, val, trainval, test) dataloaders
    Raises:
    - ValueError: if opt.dataset is not a known dataset name
    '''
    if opt.dataset == 'omniglot':
        modes = ('train', 'val', 'trainval', 'test')
        datasets = [OmniglotDataset(mode=mode) for mode in modes]
    elif opt.dataset == 'mini_imagenet':
        # NOTE(review): MiniImagenetDataset is not defined in the visible part
        # of this file, and the trainval slot reuses the 'val' split - confirm
        # both are intended.
        modes = ('train', 'val', 'val', 'test')
        datasets = [MiniImagenetDataset(mode=mode) for mode in modes]
    else:
        raise ValueError('Unrecognized dataset: %s' % opt.dataset)

    dataloaders = []
    for dataset in datasets:
        sampler = BatchSampler(labels=dataset.y,
                               classes_per_it=opt.num_cls,
                               num_samples=opt.num_samples,
                               iterations=opt.iterations,
                               batch_size=opt.batch_size)
        dataloaders.append(torch.utils.data.DataLoader(dataset,
                                                       batch_sampler=sampler))

    tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader = dataloaders
    return tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader

In [None]:
# coding=utf-8
import argparse
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.optim import Adam
import numpy as np
from tqdm import tqdm
import os
import copy
import easydict


def init_model(opt):
    '''
    Build the SNAIL model described by opt and move it to the GPU when
    opt.cuda is set.
    '''
    # Forward opt.cuda so the model does not fall back to its use_cuda=True
    # default on a CPU-only run.
    model = SnailFewShot(opt.num_cls, opt.num_samples, opt.dataset, use_cuda=opt.cuda)
    model = model.cuda() if opt.cuda else model
    return model

def save_list_to_file(path, thelist):
    '''
    Write every element of thelist to `path`, one per line, overwriting the file.
    '''
    with open(path, 'w') as out:
        out.writelines("%s\n" % entry for entry in thelist)

def labels_to_one_hot(opt, labels):
    '''
    One-hot encode a 1-D tensor of labels relative to the labels present.

    Args:
    - opt: unused, kept for interface compatibility
    - labels: 1-D tensor of class labels
    Returns:
    - one_hot: float array of shape (len(labels), number of distinct labels)
    - idxs: list giving, for each label, its rank among the sorted distinct
      labels (the column set to 1 in one_hot)
    '''
    # .cpu() is a no-op for tensors already on the CPU, so it does not need
    # to depend on opt.cuda.
    labels = labels.cpu().numpy()
    unique, inverse = np.unique(labels, return_inverse=True)
    idxs = inverse.reshape(-1).tolist()
    one_hot = np.zeros((labels.size, unique.size))
    one_hot[np.arange(labels.size), idxs] = 1
    return one_hot, idxs

def batch_for_few_shot(opt, x, y):
    '''
    Turn a flat batch of episodes into model inputs.

    Args:
    - opt: options with num_cls, num_samples, batch_size and cuda attributes
    - x: image tensor for opt.batch_size consecutive episodes
    - y: 1-D label tensor aligned with x
    Returns:
    - x: the images (moved to the GPU when opt.cuda is set)
    - y: float one-hot labels, encoded independently per episode
    - last_targets: long tensor with the class index of each episode's last sample
    '''
    seq_size = opt.num_cls * opt.num_samples + 1
    one_hots = []
    last_targets = []
    for i in range(opt.batch_size):
        one_hot, idxs = labels_to_one_hot(opt, y[i * seq_size: (i + 1) * seq_size])
        one_hots.append(torch.as_tensor(one_hot, dtype=torch.float32))
        last_targets.append(idxs[-1])
    # Plain tensors: the Variable wrapper is deprecated and no longer needed.
    last_targets = torch.tensor(last_targets, dtype=torch.long)
    y = torch.cat(one_hots, dim=0)
    if opt.cuda:
        x, y = x.cuda(), y.cuda()
        last_targets = last_targets.cuda()
    return x, y, last_targets

def get_acc(last_model, last_targets):
    '''
    Return the fraction of rows of last_model whose highest-scoring class
    equals the matching entry of last_targets, as a Python float.
    '''
    preds = last_model.max(dim=1).indices
    hits = torch.eq(preds, last_targets).float()
    return hits.mean().item()

def train(opt, tr_dataloader, model, optim, val_dataloader=None):
    '''
    Train the model, optionally validating after every epoch.

    Args:
    - opt: options with exp, epochs, iterations and cuda attributes (plus
      whatever batch_for_few_shot reads)
    - tr_dataloader: dataloader yielding training batches
    - model: the model to train
    - optim: optimizer over the model's parameters
    - val_dataloader: optional validation dataloader
    Returns:
    - (best_state, best_acc, train_loss, train_acc, val_loss, val_acc);
      best_state is None when no validation epoch has run
    '''
    best_state = None
    train_loss = []
    train_acc = []
    val_loss = []
    val_acc = []
    best_acc = 0

    best_model_path = os.path.join(opt.exp, 'best_model.pth')
    last_model_path = os.path.join(opt.exp, 'last_model.pth')

    loss_fn = nn.CrossEntropyLoss()

    if opt.cuda:
        model = model.cuda()

    for epoch in range(opt.epochs):
        print('=== Epoch: {} ==='.format(epoch))
        model.train()
        for batch in tqdm(tr_dataloader):
            optim.zero_grad()
            x, y = batch
            x, y, last_targets = batch_for_few_shot(opt, x, y)
            model_output = model(x, y)
            # Only the prediction for the last element of each sequence is scored.
            last_model = model_output[:, -1, :]
            loss = loss_fn(last_model, last_targets)
            loss.backward()
            optim.step()
            train_loss.append(loss.item())
            train_acc.append(get_acc(last_model, last_targets))
        avg_loss = np.mean(train_loss[-opt.iterations:])
        avg_acc = np.mean(train_acc[-opt.iterations:])
        print('Avg Train Loss: {}, Avg Train Acc: {}'.format(avg_loss, avg_acc))
        if val_dataloader is None:
            continue

        model.eval()
        # No gradients are needed for validation.
        with torch.no_grad():
            for batch in val_dataloader:
                x, y = batch
                x, y, last_targets = batch_for_few_shot(opt, x, y)
                model_output = model(x, y)
                last_model = model_output[:, -1, :]
                loss = loss_fn(last_model, last_targets)
                val_loss.append(loss.item())
                val_acc.append(get_acc(last_model, last_targets))
        avg_loss = np.mean(val_loss[-opt.iterations:])
        avg_acc = np.mean(val_acc[-opt.iterations:])
        postfix = ' (Best)' if avg_acc >= best_acc else ' (Best: {})'.format(
            best_acc)
        print('Avg Val Loss: {}, Avg Val Acc: {}{}'.format(
            avg_loss, avg_acc, postfix))
        if avg_acc >= best_acc:
            torch.save(model.state_dict(), best_model_path)
            best_acc = avg_acc
            # state_dict() returns references to the live parameters, so a
            # deep copy is needed to keep this epoch's weights.
            best_state = copy.deepcopy(model.state_dict())
        histories = {
            'train_loss': train_loss,
            'train_acc': train_acc,
            'val_loss': val_loss,
            'val_acc': val_acc,
        }
        for name, values in histories.items():
            save_list_to_file(os.path.join(opt.exp, name + '.txt'), values)

    torch.save(model.state_dict(), last_model_path)

    return best_state, best_acc, train_loss, train_acc, val_loss, val_acc


def test(opt, test_dataloader, model, epochs=10):
    '''
    Evaluate the model on the test dataloader.

    Args:
    - opt: options forwarded to batch_for_few_shot
    - test_dataloader: dataloader yielding test batches
    - model: the model to evaluate
    - epochs: number of passes over the dataloader (default 10)
    Returns:
    - mean accuracy on the last element of each sequence over all batches
    '''
    # Evaluation mode and no gradient tracking while measuring accuracy.
    model.eval()
    avg_acc = list()
    with torch.no_grad():
        for epoch in range(epochs):
            for batch in test_dataloader:
                x, y = batch
                x, y, last_targets = batch_for_few_shot(opt, x, y)
                model_output = model(x, y)
                last_model = model_output[:, -1, :]
                avg_acc.append(get_acc(last_model, last_targets))
    avg_acc = np.mean(avg_acc)
    print('Test Acc: {}'.format(avg_acc))

    return avg_acc

def main():
    '''
    Initialize everything and train
    '''

    options = easydict.EasyDict({
    "exp": 'default',
    "epochs": 7,
    "iterations": 10000,
    "dataset": 'omniglot',
    "num_cls": 5,
    "num_samples": 1,
    "lr": 0.0001,
    "batch_size": 32,
    # Use the GPU when one is present instead of assuming it.
    "cuda": torch.cuda.is_available()
    })

    os.makedirs(options.exp, exist_ok=True)

    if torch.cuda.is_available() and not options.cuda:
        print("WARNING: You have a CUDA device, so you should probably run with --cuda")

    tr_dataloader, val_dataloader, trainval_dataloader, test_dataloader = init_dataset(options)
    model = init_model(options)
    optim = torch.optim.Adam(params=model.parameters(), lr=options.lr)
    res = train(opt=options,
                tr_dataloader=tr_dataloader,
                val_dataloader=val_dataloader,
                model=model,
                optim=optim)
    best_state, best_acc, train_loss, train_acc, val_loss, val_acc = res

    # train() returns no best state when validation never ran; in that case
    # the model keeps its final weights.
    if best_state is not None:
        model.load_state_dict(best_state)
    print('Testing with best model..')
    test(opt=options,
         test_dataloader=test_dataloader,
         model=model)

In [None]:
# Notebook entry point: build the datasets and model, train, then test.
main()