In [1]:
from sklearn.datasets import fetch_20newsgroups
import spacy
import pandas as pd
import numpy as np
import torch
import seaborn as sns
import argparse
from fastprogress import progress_bar
import pickle
import os
import random

In [2]:
torch.__version__

'1.3.1'

In [3]:
#!pip install seaborn

In [4]:
def get_texts(sampled_character_folders, labels, texts, nb_samples=None, shuffle=False):
    if nb_samples is not None:
        sampler = lambda x: random.sample(x, nb_samples)
    else:
        sampler = lambda x: x
    texts_labels = [(i, text_vectors[idx]) for i in range(len(sampled_character_folders)) for idx in sampler(list(np.where(np.array(labels) == i)[0]))]
    if shuffle:
        random.shuffle(texts_labels)
    return texts_labels
    
class TextGenerator(object):
    """
    Data Generator capable of generating batches of text.
    """
    def __init__(self, df, num_classes, num_samples_per_class, num_meta_test_classes, num_meta_test_samples_per_class):
        """
        Args:
            num_classes: Number of classes for classification (K-way)
            num_samples_per_class: num samples to generate per class in one batch
            num_meta_test_classes: Number of classes for classification (K-way) at meta-test time
            num_meta_test_samples_per_class: num samples to generate per class in one batch at meta-test time
            batch_size: size of meta batch size (e.g. number of functions)
        """
        self.num_samples_per_class = num_samples_per_class
        self.num_classes = num_classes
        self.num_meta_test_samples_per_class = num_meta_test_samples_per_class
        self.num_meta_test_classes = num_meta_test_classes
        self.dim_input = 768
        self.dim_output = 32
        self.texts = df['text'].tolist()
        self.labels = df['target'].tolist()
        #self.nlp = spacy.load('en_trf_bertbaseuncased_lg')
        class_list = np.unique(np.array(df['target'].tolist()))
        random.seed(1)
        random.shuffle(class_list)
        num_val = 6
        num_train = 8
        self.metatrain_character_folders = class_list[: num_train]
        self.metaval_character_folders = class_list[num_train:num_train + num_val]
        self.metatest_character_folders = class_list[num_train + num_val:]
    def sample_batch(self, batch_type, batch_size, shuffle=True, swap=False):
        """
        Samples a batch for training, validation, or testing
        Args:
            batch_type: meta_train/meta_val/meta_test
            shuffle: randomly shuffle classes or not
            swap: swap number of classes (N) and number of samples per class (K) or not
        Returns:
            A a tuple of (1) Image batch and (2) Label batch where
            image batch has shape [B, N, K, 768] and label batch has shape [B, N, K, N] if swap
            where B is batch size, K is number of samples per class, N is number of classes
        """
        if batch_type == "meta_train":
            text_classes = self.metatrain_character_folders
            num_classes = self.num_classes
            num_samples_per_class = self.num_samples_per_class
        elif batch_type == "meta_val":
            text_classes = self.metaval_character_folders
            num_classes = self.num_classes
            num_samples_per_class = self.num_samples_per_class
        else:
            text_classes = self.metatest_character_folders
            num_classes = self.num_meta_test_classes
            num_samples_per_class = self.num_meta_test_samples_per_class
        all_text_batches, all_label_batches = [], []
        for i in range(batch_size):
            sampled_character_folders = random.sample(list(text_classes), num_classes)
            labels_and_texts = get_texts(sampled_character_folders, self.labels, self.texts, nb_samples=num_samples_per_class, shuffle=False)
            labels = [li[0] for li in labels_and_texts]
            texts_ = [li[1] for li in labels_and_texts]
            texts = np.stack(texts_)
            labels = np.array(labels)
            labels = np.reshape(labels, (num_classes, num_samples_per_class))
            labels = np.eye(num_classes)[labels]
            texts = np.reshape(texts, (num_classes, num_samples_per_class, -1))
            batch = np.concatenate([labels, texts], 2)
            if shuffle:
                for p in range(num_samples_per_class):
                    np.random.shuffle(batch[:, p])
            labels = batch[:, :, :num_classes]
            texts = batch[:, :, num_classes:]
            if swap:
                labels = np.swapaxes(labels, 0, 1)
                texts = np.swapaxes(texts, 0, 1)
            all_text_batches.append(texts)
            all_label_batches.append(labels)
        all_text_batches = np.stack(all_text_batches)
        all_label_batches = np.stack(all_label_batches)
        return all_text_batches, all_label_batches

### Load mini newsgroup dataset

In [5]:
text_vectors = pickle.load(open('../data/mini_newsgroup_vectors.pkl','rb'))
mini_df = pickle.load(open('../data/mini_newsgroup_data.pkl','rb'))
#mini_df, text_vectors = get_mini_dataset(samples_per_class = 30, embebeddings = 'BERT')

### Complete protonet pipeline

In [6]:
class ProtoNetText(torch.nn.Module):
    def __init__(self, embedding_size, hidden_size, proto_dim):
        super(ProtoNetText, self).__init__()
        self.embed_size = embedding_size
        self.proto_dim = proto_dim
        self.hidden_size = hidden_size
        self.l1 = torch.nn.Linear(self.embed_size, self.hidden_size)
        self.rep_block =torch.nn.Sequential(*[torch.nn.BatchNorm1d(hidden_size), torch.nn.Linear(self.hidden_size, self.hidden_size)])
        self.final = torch.nn.Linear(self.hidden_size, self.proto_dim)
    def forward(self, x):
        return self.final(self.rep_block(self.l1(x)))

In [7]:
# x_latent, q_latent, labels_onehot, num_classes, num_support, num_queries
class ProtoLoss(torch.nn.Module):
    def __init__(self, num_classes, num_support, num_queries, ndim):
        super(ProtoLoss,self).__init__()
        self.num_classes = num_classes
        self.num_support = num_support
        self.num_queries = num_queries
        self.ndim = ndim
    
    def euclidean_distance(self, a, b):
        N, D = a.shape[0], a.shape[1]
        M = b.shape[0]
        a = torch.repeat_interleave(a.unsqueeze(1), repeats = M, dim = 1)
        b = torch.repeat_interleave(b.unsqueeze(0), repeats = N, dim = 0)
        return 1.*torch.sum(torch.pow((a-b), 2),2)
        
    def forward(self, x, q, labels_onehot):
        protox = torch.mean(1.*x.reshape([self.num_classes,self.num_support,self.ndim]),1)
        dists = self.euclidean_distance(protox, q)
        logpy = torch.log_softmax(-1.*dists,0).transpose(1,0).view(self.num_classes,self.num_queries,self.num_classes)
        ce_loss = -1. * torch.mean(torch.mean(logpy * labels_onehot.float(),1))
        accuracy = torch.mean((torch.argmax(labels_onehot.float(),-1).float() == torch.argmax(logpy,-1).float()).float())
        return ce_loss, accuracy

In [8]:
n_way = 5
k_shot = 5
proto_dim = 32
n_query = 2
n_meta_test_way = 5
k_meta_test_shot = 5
n_meta_test_query = 2
num_epochs = 20
num_episodes = 200
hidden_dim = 100

In [9]:
embed_size = 768
model_text = ProtoNetText(embed_size, hidden_dim, proto_dim)
optimizer_text = torch.optim.Adam(model_text.parameters(), lr=1e-4)
criterion = ProtoLoss(n_way, k_shot, n_query, proto_dim)

In [10]:
text_generator_ = TextGenerator(mini_df, n_way, k_shot+n_query, n_meta_test_way, k_meta_test_shot+n_meta_test_query)

In [11]:
def get_latents(x,y, embed_size, n_way, n_query, k_shot):
    x_support, x_query = x[:,:,:k_shot,:], x[:,:,k_shot:,:]
    y_support, y_query = y[:,:,:k_shot,:], y[:,:,k_shot:,:]
    labels_onehot = y_query.reshape(n_way, n_query, n_way)
    support_input_t = torch.Tensor(x_support).view(-1, embed_size)
    query_input_t = torch.Tensor(x_query).view(-1, embed_size)
    return support_input_t, query_input_t, labels_onehot

In [12]:
def get_latents_new(x,y, embed_size, n_way, n_query, k_shot):
    lookup_dict = {i:np.where(y.reshape(-1,n_way)[:,i] == 1.)[0] for i in range(n_way)}
    lookup_list = np.ravel([lookup_dict[i] for i in range(n_way)])
    ### 
    x_shuffle = x.reshape(-1, embed_size)[lookup_list].reshape(1, n_way, n_query+k_shot, embed_size)
    y_shuffle = y.reshape(-1, n_way)[lookup_list].reshape(1, n_way, n_query+k_shot, n_way)
    ###
    x_support, x_query = x_shuffle[:,:,:k_shot,:], x_shuffle[:,:,k_shot:,:]
    y_support, y_query = y_shuffle[:,:,:k_shot,:], y_shuffle[:,:,k_shot:,:]
    labels_onehot = y_query.reshape(n_way, n_query, n_way)
    support_input_t = torch.Tensor(x_support).view(-1, embed_size)
    query_input_t = torch.Tensor(x_query).view(-1, embed_size)
    return support_input_t, query_input_t, labels_onehot

In [13]:
for ep in range(num_epochs):
    print(f'Epoch: {ep}')
    for epi in range(num_episodes):
        x, y = text_generator_.sample_batch('meta_train', 1, shuffle = True)
        support_input_t, query_input_t, labels_onehot = get_latents_new(x,y, embed_size, n_way, n_query, k_shot)
        x_latent = model_text(support_input_t)
        q_latent = model_text(query_input_t)
        # Compute and print loss
        loss, accuracy = criterion(x_latent, q_latent, torch.tensor(labels_onehot))
        optimizer_text.zero_grad()
        loss.backward()
        optimizer_text.step()
        if epi % 50 == 0:
            with torch.no_grad():
                valid_x, valid_y = text_generator_.sample_batch('meta_val', 1, shuffle = True)
                support_input_valid, query_input_valid, labels_onehot_valid = get_latents_new(valid_x,valid_y, embed_size, n_way, n_query, k_shot)
                x_latent_valid = model_text(support_input_valid)
                q_latent_valid = model_text(query_input_valid)
                # Compute and print loss
                valid_loss, valid_acc = criterion(x_latent_valid, q_latent_valid, torch.tensor(labels_onehot_valid))
                print(f'Epoc {ep}/{num_epochs} Episode {epi}/{num_episodes}, Validation Accuracy: {round(valid_acc.item(),3)}, Validation Loss: {round(valid_loss.item(),3)}')
print('Testing ... . . . .. . . . ')
meta_test_accuracies = []
for epi in range(1000):
    test_x, test_y = text_generator_.sample_batch('meta_test', 1, shuffle = True)
    support_input_test, query_input_test, labels_onehot_test = get_latents_new(test_x,test_y, embed_size, n_way, n_query, k_shot)
    with torch.no_grad():
        x_latent_test = model_text(support_input_test)
        q_latent_test = model_text(query_input_test)
        # Compute and print loss
        test_loss, test_acc = criterion(x_latent_test, q_latent_test, torch.tensor(labels_onehot_valid))
        if (epi + 1) % 50 == 0:
            print(f'Meta test Episode {epi}/{1000}, Test Accuracy: {round(test_acc.item(),3)}, Test Loss: {round(test_loss.item(),3)}')
        meta_test_accuracies.append(test_acc)

Epoch: 0
Epoc 0/40 Episode 0/200, Validation Accuracy: 0.4, Validation Loss: 0.478
Epoc 0/40 Episode 50/200, Validation Accuracy: 0.5, Validation Loss: 0.261
Epoc 0/40 Episode 100/200, Validation Accuracy: 0.6, Validation Loss: 0.188
Epoc 0/40 Episode 150/200, Validation Accuracy: 0.5, Validation Loss: 0.183
Epoch: 1
Epoc 1/40 Episode 0/200, Validation Accuracy: 0.7, Validation Loss: 0.155
Epoc 1/40 Episode 50/200, Validation Accuracy: 0.8, Validation Loss: 0.142
Epoc 1/40 Episode 100/200, Validation Accuracy: 0.6, Validation Loss: 0.165
Epoc 1/40 Episode 150/200, Validation Accuracy: 0.3, Validation Loss: 0.289
Epoch: 2
Epoc 2/40 Episode 0/200, Validation Accuracy: 0.9, Validation Loss: 0.203
Epoc 2/40 Episode 50/200, Validation Accuracy: 0.7, Validation Loss: 0.227
Epoc 2/40 Episode 100/200, Validation Accuracy: 0.8, Validation Loss: 0.171
Epoc 2/40 Episode 150/200, Validation Accuracy: 0.7, Validation Loss: 0.132
Epoch: 3
Epoc 3/40 Episode 0/200, Validation Accuracy: 0.6, Validation

In [15]:
avg_acc = np.mean(meta_test_accuracies)
stds = np.std(meta_test_accuracies)
print('Average Meta-Test Accuracy: {:.5f}, Meta-Test Accuracy Std: {:.5f}'.format(avg_acc, stds))

Average Meta-Test Accuracy: 0.83860, Meta-Test Accuracy Std: 0.22740
