In [1]:
from sklearn.datasets import fetch_20newsgroups
import spacy
import pandas as pd
import numpy as np
import torch
import seaborn as sns
import argparse
from fastprogress import progress_bar
import pickle
import os
import random

In [2]:
torch.__version__

'1.3.1'

In [3]:
def get_texts(sampled_character_folders, labels, texts, nb_samples=None, shuffle=False):
    if nb_samples is not None:
        sampler = lambda x: random.sample(x, nb_samples)
    else:
        sampler = lambda x: x
    texts_labels = [(i, text_vectors[idx]) for i in range(len(sampled_character_folders)) for idx in sampler(list(np.where(np.array(labels) == i)[0]))]
    if shuffle:
        random.shuffle(texts_labels)
    return texts_labels
    
class TextGenerator(object):
    """
    Data Generator capable of generating batches of text.
    """
    def __init__(self, df, num_classes, num_samples_per_class, num_meta_test_classes, num_meta_test_samples_per_class):
        """
        Args:
            num_classes: Number of classes for classification (K-way)
            num_samples_per_class: num samples to generate per class in one batch
            num_meta_test_classes: Number of classes for classification (K-way) at meta-test time
            num_meta_test_samples_per_class: num samples to generate per class in one batch at meta-test time
            batch_size: size of meta batch size (e.g. number of functions)
        """
        self.num_samples_per_class = num_samples_per_class
        self.num_classes = num_classes
        self.num_meta_test_samples_per_class = num_meta_test_samples_per_class
        self.num_meta_test_classes = num_meta_test_classes
        self.dim_input = 768
        self.dim_output = 32
        self.texts = df['text'].tolist()
        self.labels = df['target'].tolist()
        #self.nlp = spacy.load('en_trf_bertbaseuncased_lg')
        class_list = np.unique(np.array(df['target'].tolist()))
        random.seed(1)
        random.shuffle(class_list)
        num_val = 6
        num_train = 8
        self.metatrain_character_folders = class_list[: num_train]
        self.metaval_character_folders = class_list[num_train:num_train + num_val]
        self.metatest_character_folders = class_list[num_train + num_val:]
    def sample_batch(self, batch_type, batch_size, shuffle=True, swap=False):
        """
        Samples a batch for training, validation, or testing
        Args:
            batch_type: meta_train/meta_val/meta_test
            shuffle: randomly shuffle classes or not
            swap: swap number of classes (N) and number of samples per class (K) or not
        Returns:
            A a tuple of (1) Image batch and (2) Label batch where
            image batch has shape [B, N, K, 768] and label batch has shape [B, N, K, N] if swap
            where B is batch size, K is number of samples per class, N is number of classes
        """
        if batch_type == "meta_train":
            text_classes = self.metatrain_character_folders
            num_classes = self.num_classes
            num_samples_per_class = self.num_samples_per_class
        elif batch_type == "meta_val":
            text_classes = self.metaval_character_folders
            num_classes = self.num_classes
            num_samples_per_class = self.num_samples_per_class
        else:
            text_classes = self.metatest_character_folders
            num_classes = self.num_meta_test_classes
            num_samples_per_class = self.num_meta_test_samples_per_class
        all_text_batches, all_label_batches = [], []
        for i in range(batch_size):
            sampled_character_folders = random.sample(list(text_classes), num_classes)
            labels_and_texts = get_texts(sampled_character_folders, self.labels, self.texts, nb_samples=num_samples_per_class, shuffle=False)
            labels = [li[0] for li in labels_and_texts]
            texts_ = [li[1] for li in labels_and_texts]
            texts = np.stack(texts_)
            labels = np.array(labels)
            labels = np.reshape(labels, (num_classes, num_samples_per_class))
            labels = np.eye(num_classes)[labels]
            texts = np.reshape(texts, (num_classes, num_samples_per_class, -1))
            batch = np.concatenate([labels, texts], 2)
            if shuffle:
                for p in range(num_samples_per_class):
                    np.random.shuffle(batch[:, p])
            labels = batch[:, :, :num_classes]
            texts = batch[:, :, num_classes:]
            if swap:
                labels = np.swapaxes(labels, 0, 1)
                texts = np.swapaxes(texts, 0, 1)
            all_text_batches.append(texts)
            all_label_batches.append(labels)
        all_text_batches = np.stack(all_text_batches)
        all_label_batches = np.stack(all_label_batches)
        return all_text_batches, all_label_batches

In [4]:
text_vectors = pickle.load(open('../data/mini_newsgroup_vectors.pkl','rb'))
mini_df = pickle.load(open('../data/mini_newsgroup_data.pkl','rb'))

In [5]:
class ProtoNetText(torch.nn.Module):
    def __init__(self, embedding_size, hidden_size, proto_dim):
        super(ProtoNetText, self).__init__()
        self.embed_size = embedding_size
        self.proto_dim = proto_dim
        self.hidden_size = hidden_size
        self.l1 = torch.nn.Linear(self.embed_size, self.hidden_size)
        self.rep_block =torch.nn.Sequential(*[torch.nn.BatchNorm1d(hidden_size), torch.nn.Linear(self.hidden_size, self.hidden_size)])
        self.final = torch.nn.Linear(self.hidden_size, self.proto_dim)
    def forward(self, x):
        return self.final(self.rep_block(self.l1(x)))

In [6]:
# x_latent, q_latent, labels_onehot, num_classes, num_support, num_queries
class ProtoLoss(torch.nn.Module):
    def __init__(self, num_classes, num_support, num_queries, ndim):
        super(ProtoLoss,self).__init__()
        self.num_classes = num_classes
        self.num_support = num_support
        self.num_queries = num_queries
        self.ndim = ndim
    
    def euclidean_distance(self, a, b):
        # a.shape = N x D
        # b.shape = M x D
        N, D = a.shape[0], a.shape[1]
        M = b.shape[0]
        a = torch.repeat_interleave(a.unsqueeze(1), repeats = M, dim = 1)
        b = torch.repeat_interleave(b.unsqueeze(0), repeats = N, dim = 0)
        return 1.*torch.sum(torch.pow((a-b), 2),2)
        
    def forward(self, x, q, labels_onehot):
        protox = torch.mean(1.*x.reshape([self.num_classes,self.num_support,self.ndim]),1)
        dists = self.euclidean_distance(protox, q)
        logpy = torch.log_softmax(-1.*dists,0).transpose(1,0).view(self.num_classes,self.num_queries,self.num_classes)
        ce_loss = -1. * torch.mean(torch.mean(logpy * labels_onehot.float(),1))
        accuracy = torch.mean((torch.argmax(labels_onehot.float(),-1).float() == torch.argmax(logpy,-1).float()).float())
        return ce_loss, accuracy

In [7]:
n_way = 5
k_shot = 5
proto_dim = 32
n_query = 2
n_meta_test_way = 5
k_meta_test_shot = 5
n_meta_test_query = 2
num_epochs = 20
num_episodes = 200
hidden_dim = 100

In [8]:
embed_size = 768
model_text = ProtoNetText(embed_size, hidden_dim, proto_dim)
optimizer_text = torch.optim.Adam(model_text.parameters(), lr=1e-4)
criterion = ProtoLoss(n_way, k_shot, n_query, proto_dim)

In [9]:
text_generator_ = TextGenerator(mini_df, n_way, k_shot+n_query, n_meta_test_way, k_meta_test_shot+n_meta_test_query)

In [15]:
x, y = text_generator_.sample_batch('meta_train', 1, shuffle = True)

In [45]:
x.shape

(1, 5, 7, 768)

In [46]:
y.shape

(1, 5, 7, 5)

In [37]:
lookup_dict = {i:np.where(y.reshape(-1,n_way)[:,i] == 1.)[0] for i in range(n_way)}

In [72]:
rearrage_list = []
np.ravel([lookup_dict[i] for i in range(n_way)])

array([10, 18, 19, 20, 21, 29, 30,  0, 11, 22, 23, 27, 31, 33,  7,  8, 12,
       13, 16, 24, 32,  1,  3,  5,  6,  9, 14, 25,  2,  4, 15, 17, 26, 28,
       34])

In [77]:
y.reshape(-1,n_way)[np.ravel([lookup_dict[i] for i in range(n_way)])]

array([[1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0., 0., 0., 1.],
       [0., 0.

In [67]:
map(rearrage_list.append(lookup_list), [lookup_dict[i] for i in range(n_way)])

NameError: name 'lookup_list' is not defined

In [66]:
rearrage_list

[array([10, 18, 19, 20, 21, 29, 30])]

In [39]:
lookup_dict[0][:n_way]

array([10, 18, 19, 20, 21])

In [43]:
def get_latents(x,y, embed_size, n_way, n_query, k_shot):
    lookup_dict = {i:np.where(y.reshape(-1,n_way)[:,i] == 1.)[0] for i in range(n_way)}
    lookup_list = np.ravel([lookup_dict[i] for i in range(n_way)])
    ### 
    x_shuffle = x.reshape(-1, embed_size)[lookup_list].reshape(1, n_way, n_query, embed_size)
    y_shuffle = y.reshape(-1, n_way)[lookup_list].reshape(1, n_way, n_query, n_way)
    ###
    x_support, x_query = x_shuffle[:,:,:k_shot,:], x_shuffle[:,:,k_shot:,:]
    y_support, y_query = y_shuffle[:,:,:k_shot,:], y_shuffle[:,:,k_shot:,:]
    labels_onehot = y_query.reshape(n_way, n_query, n_way)
    support_input_t = torch.Tensor(x_support).view(-1, embed_size)
    query_input_t = torch.Tensor(x_query).view(-1, embed_size)
    return support_input_t, query_input_t, labels_onehot

(5, 768)

In [52]:
x.reshape(-1, embed_size)

(35, 768)

In [63]:
x[:,:,:k_shot,:].shape

(1, 5, 5, 768)

In [64]:
x

(1, 5, 7, 768)