In [None]:
%%capture

# create the submission directory as a copy of the transfer baseline
!cp -r ../baselines/transfer ../baselines/submission

In [None]:
# ###############################################################################################################################################

In [None]:
%%writefile ../baselines/submission/config.gin

DataGenerator.batch_config = [28, 256] # [image_size, batch_size]
DataGenerator.episode_config = [28, 5, 1, 5] # Not used 
DataGenerator.valid_episode_config = [28, 5, 1, 19]
DataGenerator.pool = 'train'
DataGenerator.mode = 'batch'

In [None]:
%%writefile ../baselines/submission/model.py

import os
import logging
import csv 
import datetime

import tensorflow as tf
from tensorflow import keras
import gin

from metadl.api.api import MetaLearner, Learner, Predictor

from tensorflow.keras.models import clone_model, Model
from tensorflow.keras import activations
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Dense, Dropout, Activation, MaxPooling2D
from tensorflow.keras.layers import AvgPool2D, GlobalAveragePooling2D, MaxPool2D, Flatten, ZeroPadding2D, Add, AveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import ReLU, concatenate
import tensorflow.keras.backend as K

import torch
import numpy as np

# Initial learning rate handed to the Adam optimizer in MyMetaLearner.
LEARNING_RATE=0.00025
# Size of the last axis of the support-embedding tensor allocated in MyLearner.
# It has to equal the length of the flattened my_net() output
# (presumably 8 * 8 * 64 = 4096 for 28x28 inputs — confirm if my_net changes).
DATASET_DIM=4096


def basic_block(x, filters, stride=1):
    """Append one two-convolution residual block to the Keras tensor ``x``.

    The main branch is Conv3x3(stride) -> BN -> Conv3x3 -> BN. The result is
    added to the shortcut branch and passed through a ReLU.

    Args:
        x : Keras tensor, input feature map (channels-last).
        filters : Integer, number of output channels of both convolutions.
        stride : Stride of the first convolution (and of the shortcut
            projection when one is needed).

    Returns:
        The Keras tensor produced by the final ReLU.
    """
    x_skip = x

    x = Conv2D(filters, kernel_size=(3, 3), strides=stride, padding='same')(x)
    x = BatchNormalization()(x)
    x = Conv2D(filters, kernel_size=(3, 3), strides=1, padding='same')(x)
    x = BatchNormalization()(x)

    # Add() requires both branches to have identical shapes. The shortcut is
    # therefore projected with a 1x1 convolution whenever the main branch
    # changed the spatial size (stride != 1) or the number of channels.
    if stride != 1 or x_skip.shape[-1] != filters:
        x_skip = Conv2D(filters, kernel_size=(1, 1), strides=stride)(x_skip)
        x_skip = BatchNormalization()(x_skip)

    x = Add()([x, x_skip])
    x = ReLU()(x)

    return x

def make_block(x, filters, blocks, stride=1):
    """Chain ``blocks`` residual blocks on top of ``x``.

    Only the first block receives ``stride``; every following block is built
    with a stride of 1. At least one block is always created.

    Args:
        x : Keras tensor, input feature map.
        filters : Integer, number of channels used by every block.
        blocks : Integer, total number of residual blocks in the stage.
        stride : Stride of the first block.

    Returns:
        The Keras tensor produced by the last block.
    """
    out = basic_block(x, filters, stride)

    remaining = blocks - 1
    while remaining > 0:
        out = basic_block(out, filters, stride=1)
        remaining -= 1

    return out


def my_net(img_size=28):
    """Build the residual embedding network as a pair of Keras tensors.

    Layout: zero padding -> 3x3 convolution -> batch norm -> ReLU ->
    max pooling -> four stages of two residual blocks (64 filters, stride 1)
    -> average pooling -> flatten.

    Args:
        img_size : Integer, height and width of the square RGB input images.

    Returns:
        (inputs, outputs) : the input tensor and the flattened embedding
            tensor; the caller is responsible for wrapping them in a Model.
    """
    n_stages = 4
    stage_filters = 64

    inputs = Input(shape=(img_size, img_size, 3))

    # Stem
    features = ZeroPadding2D(padding=(2, 2))(inputs)
    features = Conv2D(stage_filters, kernel_size=(3, 3))(features)
    features = BatchNormalization()(features)
    features = ReLU()(features)
    features = MaxPooling2D((2, 2))(features)

    # Residual stages, all with the same width and no down-sampling
    for _ in range(n_stages):
        features = make_block(features, stage_filters, blocks=2, stride=1)

    # Head
    features = AveragePooling2D((2, 2), padding='same')(features)
    outputs = Flatten()(features)

    return inputs, outputs



@gin.configurable
class MyMetaLearner(MetaLearner):
    """Trains the my_net() embedding from scratch as a plain classifier.

    A BatchNormalization + softmax Dense head over all meta-train classes is
    put on top of the embedding and optimised with Adam on augmented batches.
    The Learner returned by meta_fit() wraps the trained network.
    """
    def __init__(self,
                iterations=10,
                freeze_base=True,
                total_meta_train_class=883):
        """
        Args:
            iterations : Integer, maximum number of training batches consumed
                per epoch.
            freeze_base : Unused; kept so existing configurations keep working.
            total_meta_train_class : Integer, number of output units of the
                classification head.
        """
        super().__init__()
        self.iterations = iterations
        self.epochs = 40
        self.total_meta_train_class = total_meta_train_class

        # ++++++ nets ++++++
        inputs, my_net_outputs = my_net()

        x = BatchNormalization()(my_net_outputs)
        outputs = Dense(self.total_meta_train_class, activation='softmax')(x)
        self.model = Model(inputs, outputs)

        # summary() returns None, so route its lines through the logger
        # instead of logging the return value.
        self.model.summary(print_fn=logging.info)
        # ++++++ nets ++++++

        self.loss = keras.losses.SparseCategoricalCrossentropy()
        self.learning_rate = LEARNING_RATE

        self.optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate)
        self.acc = keras.metrics.SparseCategoricalAccuracy()

        # Summary Writers
        self.current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        self.train_log_dir = ('logs/transfer/gradient_tape/' + self.current_time 
            + '/meta-train')
        self.valid_log_dir = ('logs/transfer/gradient_tape/' + self.current_time 
            + '/meta-valid')
        self.train_summary_writer = tf.summary.create_file_writer(
            self.train_log_dir)
        self.valid_summary_writer = tf.summary.create_file_writer(
            self.valid_log_dir)

        # Statistics trackers
        self.train_loss = tf.keras.metrics.Mean(name = 'train_loss')
        self.train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name = 'train_accuracy')
        self.valid_loss = tf.keras.metrics.Mean(name = 'valid_loss')
        self.valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
            name = 'valid_accuracy')

    def _augment_batch(self, images, labels):
        """Applies one random colour augmentation and adds the three 90-degree
        rotations of the batch.

        Args:
            images : batch of images.
            labels : labels matching `images`.

        Returns:
            (images, labels) : both four times as long as the inputs; the
                labels are repeated so every rotated copy keeps its class.
        """
        # 0 -> saturation, 1 -> brightness, 2 -> no colour change
        augmentation_type = np.random.randint(3)
        if augmentation_type == 0:
            images = tf.image.adjust_saturation(images, 3)
        elif augmentation_type == 1:
            images = tf.image.adjust_brightness(images, 0.4)

        images = tf.concat([images,
                            tf.image.rot90(images, k=1),
                            tf.image.rot90(images, k=2),
                            tf.image.rot90(images, k=3)], 0)
        labels = tf.concat([labels, labels, labels, labels], 0)

        return images, labels

    def meta_fit(self, meta_dataset_generator) -> Learner:
        """ Trains the classifier on batches of the meta-train split and
        evaluates it on meta-validation episodes before every epoch.
        Args:
            meta_dataset_generator : a DataGenerator object. We can access 
                the meta-train and meta-validation episodes via its attributes.
                Refer to the metadl/data/dataset.py for more details.
        
        Returns:
            MyLearner object : a Learner that stores the current embedding 
                function (Neural Network) of this MetaLearner.
        """
        meta_train_dataset = meta_dataset_generator.meta_train_pipeline
        meta_valid_dataset = meta_dataset_generator.meta_valid_pipeline
        meta_valid_dataset = meta_valid_dataset.batch(2)

        logging.info('Starting meta-fit for the transfer baseline ...')
        meta_iterator = meta_train_dataset.__iter__()
        sample_data = next(meta_iterator)
        logging.info('Images shape : {}'.format(sample_data[0][0].shape))
        logging.info('Labels shape : {}'.format(sample_data[0][1].shape))
        
        for epoch in range(self.epochs):
            self.evaluate(MyLearner(self.model), meta_valid_dataset)
            count = 0

            # Step decay: divide the learning rate by 10 at epochs 20 and 30.
            # A new Adam instance is created, which also discards the
            # optimizer's accumulated moment estimates.
            if (epoch == 20) or (epoch == 30):
                self.learning_rate = self.learning_rate/10
                self.optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate)
                logging.info('New learning rate : {}'.format(self.learning_rate))

            for (images, labels), _ in meta_train_dataset:
                # Augmentation does not involve trainable weights, so it is
                # done outside the gradient tape.
                images, labels = self._augment_batch(images, labels)

                with tf.GradientTape() as tape:
                    # NOTE(review): the model is called without training=True,
                    # so BatchNormalization presumably runs in inference mode
                    # here — confirm this is intended.
                    cpreds = self.model(images)
                    closs = self.loss(labels, cpreds)

                grads = tape.gradient(closs, self.model.trainable_weights)
                self.optimizer.apply_gradients(
                    zip(grads, self.model.trainable_weights))

                if ((count + 1) % 10) == 0:
                    logging.info('Epoch {}/{} - Iteration {}/{} - Loss : {}'.format(
                        epoch + 1, self.epochs, count + 1, self.iterations, closs.numpy()))

                count += 1
                if count >= self.iterations:
                    break

        return MyLearner(self.model)

    def evaluate(self, learner, meta_valid_generator):
        """Evaluates the current meta-learner with episodes generated from the
        meta-validation split and logs the resulting accuracy. At most 20
        batches of episodes are consumed; every episode of a batch is scored.
        Args:
            learner : MyLearner object. The current state of the meta-learner 
                    is embedded in the object via its neural network.
            meta_valid_generator : a tf.data.Dataset object that generates
                                    batches of episodes from the
                                    meta-validation split.
        """
        # Start from a clean metric so the logged value describes this call
        # only, not a running average over every earlier evaluation.
        # reset_state() is the newer name, reset_states() the older one.
        reset_metric = getattr(self.valid_accuracy, 'reset_state', None)
        if reset_metric is None:
            reset_metric = self.valid_accuracy.reset_states
        reset_metric()

        count_val = 0
        for tasks_batch in meta_valid_generator : 
            # tasks_batch[0] holds the episode tensors: index 0/1 are the
            # support images/labels, index 3/4 the query images/labels.
            sup_set = tf.data.Dataset.from_tensor_slices(
                (tasks_batch[0][1], tasks_batch[0][0]))
            que_set = tf.data.Dataset.from_tensor_slices(
                (tasks_batch[0][4],tasks_batch[0][3]))
            new_ds = tf.data.Dataset.zip((sup_set, que_set))
            for ((supp_labs, supp_img), (que_labs, que_img)) in new_ds:
                support_set = tf.data.Dataset.from_tensor_slices(
                    (supp_img, supp_labs))
                query_set = tf.data.Dataset.from_tensor_slices(que_img)
                support_set = support_set.batch(5)
                query_set = query_set.batch(95)
                predictor = learner.fit(support_set)
                preds = predictor.predict(query_set)
                self.valid_accuracy.update_state(que_labs, preds)
            
            count_val += 1 
            if count_val >= 20: break

        logging.info('Meta-Valid accuracy : {:.3%}'.format(
            self.valid_accuracy.result()))


@gin.configurable
class MyLearner(Learner):
    """Embeds the support set of a task and hands it to a MyPredictor."""
    def __init__(self, 
                model=None,
                N_ways=5):
        """
        Args:
            model : A keras.Model object describing the Meta-Learner's neural
                network. Its third-to-last layer is used as embedding output.
                When None, an untrained my_net() embedding is created (its
                weights are expected to be restored with load()).
            N_ways : Integer, the number of classes to consider at meta-test
                time.
        """
        super().__init__()
        self.N_ways = N_ways
        self.K_shots = 1

        if model is None:
            inputs, outputs = my_net()
            self.model = Model(inputs, outputs)
        else:
            # Drop the last two layers of the given model and keep everything
            # up to the third-to-last layer as the embedding function.
            self.model = keras.Model(inputs=model.input, outputs=model.layers[-3].output)

        # Support embeddings, indexed as [run, class, shot, feature].
        self.dataset = torch.zeros((1, self.N_ways, self.K_shots, DATASET_DIM))

    def fit(self, dataset_train) -> Predictor:
        """Embeds the support examples of a new unseen task. The network
        weights are not modified.
        Args:
            dataset_train : a tf.data.Dataset object. Iterates over the support
                examples as (images, labels) batches.
        Returns:
            a Predictor object that holds the embedding network and the
                support embeddings of this task.
        """
        # A fresh tensor per task: predictors returned by earlier calls keep
        # their own support set and no embedding of a previous task can leak
        # into this one.
        support = torch.zeros_like(self.dataset)

        for sup_imgs, sup_lbls in dataset_train:
            emb_imgs = self.model(sup_imgs).numpy()
            sup_lbls_numpy = sup_lbls.numpy()

            for emb_img, sup_lbl in zip(emb_imgs, sup_lbls_numpy):
                # One shot per class: the embedding is stored in shot slot 0
                # of the row selected by its label.
                support[0, int(sup_lbl), 0, :] = torch.from_numpy(emb_img)

        self.dataset = support

        return MyPredictor(self.model, support)

    def save(self, model_dir):
        """
        Saves the weights of the embedding network as a tensorflow checkpoint
        named 'learner.ckpt' inside `model_dir`.

        Raises:
            ValueError : if `model_dir` is not an existing directory.
        """
        if not os.path.isdir(model_dir):
            raise ValueError('The model directory provided is invalid. Please '
                             'check that its path is valid.')
        
        ckpt_file = os.path.join(model_dir, 'learner.ckpt')
        self.model.save_weights(ckpt_file) 

    def load(self, model_dir):
        """
        Loads the weights of the embedding network from the tensorflow
        checkpoint 'learner.ckpt' inside `model_dir`.

        Raises:
            ValueError : if `model_dir` is not an existing directory.
        """
        if not os.path.isdir(model_dir):
            raise ValueError('The model directory provided is invalid. Please '
                             'check that its path is valid.')

        ckpt_path = os.path.join(model_dir, 'learner.ckpt')
        self.model.load_weights(ckpt_path)

    
class MyPredictor(Predictor):
    """Classifies the query examples of one task with the PT+MAP algorithm."""
    def __init__(self,
                model,
                dataset):
        """
        Args:
            model : Embedding network, called on batches of query images.
            dataset : torch tensor with the support embeddings, indexed as
                [run, class, shot, feature].
        """
        super().__init__()
        self.embedding_fn = model
        self.dataset = dataset

        self.config = {
            # The task layout is dictated by the support tensor, so it is read
            # from its shape instead of being repeated here.
            'n_shot' : int(dataset.shape[2]),
            'n_ways' : int(dataset.shape[1]),
            'n_queries' : 19,

            'n_runs' : 1, # competition works only with one run for evaluations

            'pt_epsilon' : 1e-6,
            'pt_beta' : 0.5,
            'gm_lambda' : 10,
            'map_alpha' : 0.4,
            'map_epochs': 30,
        }

        # Number of labelled (support), unlabelled (query) and total samples.
        self.config['n_lsamples'] = self.config['n_ways'] * self.config['n_shot']
        self.config['n_usamples'] = self.config['n_ways'] * self.config['n_queries']
        self.config['n_samples'] = self.config['n_lsamples'] + self.config['n_usamples']

    def center_data(self, data, index):
        """Centres and L2-normalises support and query samples separately.

        Args:
            data : tensor indexed as [run, sample, feature]; modified in place.
            index : Integer, the first `index` samples are the support set,
                the remaining ones the query set.

        Returns:
            The same tensor object, after the in-place update.
        """
        # centering support data
        data[:, :index] = data[:, :index, :] - data[:, :index].mean(1, keepdim=True)
        data[:, :index] = data[:, :index, :] / torch.norm(data[:, :index, :], 2, 2)[:, :, None]
        # centering query data
        data[:, index:] = data[:, index:, :] - data[:, index:].mean(1, keepdim=True)
        data[:, index:] = data[:, index:, :] / torch.norm(data[:, index:, :], 2, 2)[:, :, None]
        
        return data

    def mu_j_estimation(self, M_star, f_data):
        """Computes the mu_j estimation: per-class weighted means of the
        samples, weighted by the assignment matrix.

        Args:
            M_star : tensor indexed as [run, sample, class].
            f_data : tensor indexed as [run, sample, feature].

        Returns:
            Tensor indexed as [run, class, feature].
        """
        # nominator: assignment-weighted sum of the samples of every class
        nominator = M_star.permute(0,2,1).matmul(f_data)
        # denominator: total assignment mass of every class
        denominator = M_star.sum(dim=1).unsqueeze(2)

        return nominator.div(denominator)

    def dataset_test_to_pytorch(self, dataset_test):
        """Creates the pytorch tensor with the query data required for PT+MAP.

        The query embeddings are written, in the order they are produced, into
        consecutive slots of a [run, class, query, feature] tensor. The class
        axis is only a storage layout here, not a label.

        Args:
            dataset_test : iterable of batches of query images.

        Returns:
            Tensor of shape (n_runs, n_ways, n_queries, feature size).
        """
        n_queries = self.config['n_queries']
        query_data = torch.zeros((
            self.config['n_runs'],
            self.config['n_ways'],
            n_queries,
            self.dataset.shape[3]))

        # The running position is kept across batches, so a query set that
        # arrives in several batches does not overwrite earlier embeddings.
        position = 0
        for que_imgs in dataset_test:
            emb_imgs = self.embedding_fn(que_imgs).numpy()

            for emb_img in emb_imgs:
                slot, offset = divmod(position, n_queries)
                query_data[0, slot, offset, :] = torch.from_numpy(emb_img)
                position += 1

        return query_data

    def pt_map_init(self, ndatas):
        """Applies the transformations PT+MAP needs before the MAP iterations.

        Also stores the reduced feature size in self.config['n_nfeat'].

        Args:
            ndatas : tensor indexed as [run, class, sample-of-class, feature],
                support samples first along the third axis.

        Returns:
            (ndatas, labels) : the transformed data indexed as
                [run, sample, feature] and the matching [run, sample] tensor
                of class indices.
        """
        # Flatten to [run, sample, feature]; consecutive samples cycle through
        # the classes.
        ndatas = ndatas.permute(0,2,1,3)
        ndatas = ndatas.reshape(
            self.config['n_runs'], self.config['n_samples'], -1)
            
        # Power transform
        ndatas = torch.pow(ndatas + self.config['pt_epsilon'], self.config['pt_beta'])

        # QR reduction
        # NOTE(review): torch.qr is deprecated in recent PyTorch releases in
        # favour of torch.linalg.qr.
        ndatas = torch.qr(ndatas.permute(0,2,1)).R.permute(0,2,1)

        self.config['n_nfeat'] = ndatas.size(2)
        
        # unit variance projection
        ndatas = ndatas/ndatas.norm(dim=2, keepdim=True)

        # trans-mean-sub
        ndatas = self.center_data(ndatas, self.config['n_lsamples'])

        # labels transformations: sample s gets class index s % n_ways, which
        # matches the sample order produced by the reshape above
        labels = torch.arange(self.config['n_ways']).view(1, 1, self.config['n_ways'])
        labels = labels.expand(
            self.config['n_runs'], self.config['n_shot'] + self.config['n_queries'], self.config['n_ways'])
        labels = labels.clone().view(self.config['n_runs'], self.config['n_samples'])

        return ndatas, labels

    def probs_to_tf(self, probs, labels):
        """Converts pytorch probabilities to tensorflow and reorders the rows
        so that they are grouped by the class index found in `labels`.

        Args:
            probs : tensor with one row of class probabilities per query
                sample.
            labels : tensor of class indices for all samples (support first).

        Returns:
            tf.Tensor with the reordered probabilities.
        """
        # torch data into numpy arrays
        numpy_probs = probs.cpu().numpy()
        numpy_labels = labels.squeeze().cpu().numpy()
        query_labels = numpy_labels[self.config['n_lsamples']:]

        # concatenating the results (probabilities), one class index at a time
        concat_probs = np.concatenate(
            [numpy_probs[query_labels == index] for index in range(self.config['n_ways'])])

        # converting the probabilities to tensorflow tensor for consistency
        return tf.convert_to_tensor(concat_probs)

    def c_j_data_init(self, ndatas):
        """Returns the first c_j estimations -> support classes means.

        Requires self.config['n_nfeat'], which pt_map_init() sets.

        Args:
            ndatas : transformed data indexed as [run, sample, feature].

        Returns:
            Tensor indexed as [run, class, feature].
        """
        reshape_0 = self.config['n_runs']
        reshape_1 = self.config['n_shot'] + self.config['n_queries']
        reshape_2 = self.config['n_ways']
        reshape_3 = self.config['n_nfeat']

        tmp_data = ndatas.reshape(reshape_0, reshape_1, reshape_2, reshape_3)
        # keep the support samples only
        tmp_data = tmp_data[:,:self.config['n_shot'],]

        return tmp_data.mean(1)

    def compute_optimal_transport(self, M, p, q, epsilon=1e-6):
        """
        Estimate the optimal transport from the initial distribution of the feature vectors to one that 
        would correspond to a balanced draw of samples from Gaussian distributions.

        Args:
            M : cost tensor indexed as [run, sample, class].
            p : target row sums, indexed as [run, sample].
            q : target column sums, indexed as [run, class].
            epsilon : the iteration stops once the row sums change by at most
                this value (or after 1000 iterations).

        Returns:
            Tensor with the same shape as `M` holding the transport plan.
        """
        n_runs, n, m = M.shape
        P = torch.exp(- self.config['gm_lambda'] * M)
        P /= P.view((n_runs, -1)).sum(1).unsqueeze(1).unsqueeze(1)
                                        
        u = torch.zeros(n_runs, n)#.cuda()
        maxiters = 1000
        iters = 1
        # normalize this matrix
        while torch.max(torch.abs(u - P.sum(2))) > epsilon:
            u = P.sum(2)
            # rows features
            P *= (p / u).view((n_runs, -1, 1))
            # columns features
            P *= (q / P.sum(1)).view((n_runs, 1, -1))
            
            if iters == maxiters: break
            iters += 1
        
        return P

    def compute_probabilities(self, dist, labels):   
        """Builds the sample-to-class assignment matrix.

        Query rows come from the optimal-transport mapping; support rows are
        one-hot encodings of their labels.

        Args:
            dist : distances indexed as [run, sample, class], support samples
                first.
            labels : class indices indexed as [run, sample].

        Returns:
            Tensor with the same shape as `dist`.
        """
        n_lsamples = self.config['n_lsamples']

        # L = Sinkhorn -> L
        L = dist[:, n_lsamples:]
        # p = Sinkhorn -> p
        _1wq = torch.ones(self.config['n_runs'], self.config['n_usamples'])#.cuda()
        # q = Sinkhorn -> q
        q1w = (torch.ones(self.config['n_runs'], self.config['n_ways']) * self.config['n_queries'])#.cuda()
        
        p_xj_test = self.compute_optimal_transport(L, _1wq, q1w)

        p_xj = torch.zeros_like(dist)
        p_xj[:, n_lsamples:] = p_xj_test
        
        # the support rows are still all zero at this point; set the entry of
        # the true class to 1
        p_xj[:,:n_lsamples].scatter_(2, labels[:,:n_lsamples].unsqueeze(2), 1)
        
        return p_xj


    def predict(self, dataset_test):
        '''
        Leveraging the Feature Distribution in Transfer-based Few-Shot Learning
        Code: https://github.com/yhu01/PT-MAP/tree/b58829b6e24e0441d46b8e1d9ff3149553b4456d
        Paper: https://arxiv.org/pdf/2006.03806.pdf

        Algorithm authors:
            Yuqing Hu, Electronics Dept.,IMT Atlantique, France - Orange Labs, France Cesson-Sévigné; email: yuqing.hu@imt-atlantique.fr
            Vincent Gripon, Electronics Dept.,IMT Atlantique, France Brest; email: vincent.gripon@imt-atlantique.fr
            Stéphane Pateux, Orange Labs, France Cesson-Sévigné; email: stephane.pateux@orange.com

        The following implementation replicates the PT-MAP algorithm proposed by the authors mentioned above.
        It also borrows inspiration from their pytorch implementation from GitHub.

        Args:
            dataset_test : iterable of batches of query images.

        Returns:
            tf.Tensor with one row of class probabilities per query example.
        '''

        query_data = self.dataset_test_to_pytorch(dataset_test)

        # putting the support data and the query data together
        ndatas = torch.cat((self.dataset, query_data), 2)

        ndatas, labels = self.pt_map_init(ndatas)

        # PT-MAP Algorithm 1 (page 5)

        # first c_j data set
        c_j_data = self.c_j_data_init(ndatas)

        for _ in range(self.config['map_epochs']):
            # L_ij matrix creation
            dist = (ndatas.unsqueeze(2) - c_j_data.unsqueeze(1)).norm(dim=3).pow(2)

            # probas maker, Sinkhorn mapping
            probas = self.compute_probabilities(dist, labels)

            # mu_j estimation -> mu_j = g(M*,j)
            mu_j_data = self.mu_j_estimation(probas, ndatas)
                        
            # update centroids
            c_j_data = c_j_data + self.config['map_alpha'] * (mu_j_data - c_j_data)
        
        dist = (ndatas.unsqueeze(2) - c_j_data.unsqueeze(1)).norm(dim=3).pow(2)
        # drop the run axis and keep the query rows only
        torch_probs = self.compute_probabilities(dist, labels).squeeze()[self.config['n_lsamples']:]

        # converting the probabilities to tensorflow tensor for consistency
        probs = self.probs_to_tf(torch_probs, labels)

        return probs

</br>
</br>
</br>

In [None]:
!python -m metadl.core.run --meta_dataset_dir=../../omniglot --code_dir=../baselines/submission