In [63]:
"""
data process
"""

import os
import json
import random
import copy
from collections import Counter
from itertools import chain
from typing import Dict, Tuple, Optional, List, Union

import gensim
import numpy as np

DEBUG_PRINT = False

class PrototypicalData(object):
    def __init__(self, output_path: str, sequence_length: int = 100, num_classes: int = 2, num_support: int = 5,
                 num_queries: int = 50, num_tasks: int = 1000, num_eval_tasks: int = 100,
                 stop_word_path: Optional[str] = None,
                 embedding_size: Optional[int] = None, low_freq: int = 5,
                 word_vector_path: Optional[str] = None, is_training: bool = True):
        """
        init method
        :param output_path: path of train/eval data
        :param num_classes: number of support class
        :param num_support: number of support sample per class
        :param num_queries: number of query sample per class
        :param num_tasks: number of pre-sampling tasks, this will speeding up train
        :param num_eval_tasks: number of pre-sampling tasks in eval stage
        :param stop_word_path: path of stop word file
        :param embedding_size: embedding size
        :param low_freq: frequency of words
        :param word_vector_path: path of word vector file(eg. word2vec, glove)
        :param is_training: bool
        """

        self.__output_path = output_path
        if not os.path.exists(self.__output_path):
            os.makedirs(self.__output_path)

        self.__sequence_length = sequence_length
        self.__num_classes = num_classes
        self.__num_support = num_support
        self.__num_queries = num_queries
        self.__num_tasks = num_tasks
        self.__num_eval_tasks = num_eval_tasks
        self.__stop_word_path = stop_word_path
        self.__embedding_size = embedding_size
        self.__low_freq = low_freq
        self.__word_vector_path = word_vector_path
        self.__is_training = is_training

        self.vocab_size = None
        self.word_vectors = None
        self.current_category_index = 0  # record current sample category

        print("stop word path: ", self.__stop_word_path)
        print("word vector path: ", self.__word_vector_path)

    @staticmethod
    def load_data(data_path: str) -> Dict[str, Dict[str, List[List[str]]]]:
        """
        read train/eval data
        :param data_path:
        :return: dict. {class_name: {sentiment: [[]], }, ...}
        """
        category_files = os.listdir(data_path)
        categories_data = {}
        for category_file in category_files:
            file_path = os.path.join(data_path, category_file)
            sentiment_data = {}
            with open(file_path, "r", encoding="utf8") as fr:
                for line in fr.readlines():
                    content, label = line.strip().split("\t")
                    if sentiment_data.get(label, None):
                        sentiment_data[label].append(content.split(" "))
                    else:
                        sentiment_data[label] = [content.split(" ")]

            # print("task name: ", category_file)
            # print("pos samples length: ", len(sentiment_data["1"]))
            # print("neg samples length: ", len(sentiment_data["-1"]))
            categories_data[category_file] = sentiment_data
        return categories_data

    def remove_stop_word(self, data: Dict[str, Dict[str, List[List[str]]]]) -> List[str]:
        """
        remove low frequency words and stop words, construct vocab
        :param data: {class_name: {sentiment: [[]], }, ...}
        :return:
        """
        all_words = []
        for category, category_data in data.items():
            for sentiment, sentiment_data in category_data.items():
                all_words.extend(list(chain(*sentiment_data)))
        word_count = Counter(all_words)  # statistic the frequency of words
        sort_word_count = sorted(word_count.items(), key=lambda x: x[1], reverse=True)

        # remove low frequency word
        words = [item[0] for item in sort_word_count if item[1] > self.__low_freq]

        # if stop word file exists, then remove stop words
        if self.__stop_word_path:
            with open(self.__stop_word_path, "r", encoding="utf8") as fr:
                stop_words = [line.strip() for line in fr.readlines()]
            words = [word for word in words if word not in stop_words]

        return words

    def get_word_vectors(self, vocab: List[str]) -> np.ndarray:
        """
        load word vector file,
        :param vocab: vocab
        :return:
        """
        pad_vector = np.zeros(self.__embedding_size)  # set the "<pad>" vector to 0
        word_vectors = (1 / np.sqrt(len(vocab) - 1) * (2 * np.random.rand(len(vocab) - 1, self.__embedding_size) - 1))
        word_vectors = np.vstack((pad_vector, word_vectors))
        if DEBUG_PRINT:
            # print(vocab)
            print(f"get_word_vectors word_vectors={word_vectors.shape}")
        
        # load glove vectors
        # glove_vector = {}
        # with open(self.__word_vector_path, "r", encoding="utf8") as fr:
        #     for line in fr.readlines():
        #         line_list = line.strip().split(" ")
        #         glove_vector[line_list[0]] = line_list[1:]

        # for i in range(1, len(vocab)):
        #     if glove_vector.get(vocab[i], None):
        #         word_vectors[i, :] = glove_vector[vocab[i]]
        #     else:
        #         print(vocab[i] + "not exist word vector file")

        # # load gensim word2vec vectors
        # if os.path.splitext(self.__word_vector_path)[-1] == ".bin":
        #     word_vec = gensim.models.KeyedVectors.load_word2vec_format(self.__word_vector_path, binary=True)
        # else:
        #     word_vec = gensim.models.KeyedVectors.load_word2vec_format(self.__word_vector_path)
        #
        # for i in range(1, len(vocab)):
        #     try:
        #         vector = word_vec.wv[vocab[i]]
        #         word_vectors[i, :] = vector
        #     except:
        #         print(vocab[i] + "not exist word vector file")

        return word_vectors

    def gen_vocab(self, words: List[str]) -> Dict[str, int]:
        """
        generate word_to_index mapping table
        :param words:
        :return:
        """
        if self.__is_training:
            vocab = ["<pad>", "<unk>"] + words

            self.vocab_size = len(vocab)

            if self.__word_vector_path:
                word_vectors = self.get_word_vectors(vocab)
                self.word_vectors = word_vectors
                # save word vector to npy file
                np.save(os.path.join(self.__output_path, "word_vectors.npy"), self.word_vectors)

            word_to_index = dict(zip(vocab, list(range(len(vocab)))))

            # save word_to_index to json file
            with open(os.path.join(self.__output_path, "word_to_index.json"), "w") as f:
                json.dump(word_to_index, f)
        else:
            with open(os.path.join(self.__output_path, "word_to_index.json"), "r") as f:
                word_to_index = json.load(f)

        return word_to_index

    @staticmethod
    def trans_to_index(data: Dict[str, Dict[str, List[List[str]]]], word_to_index: Dict[str, int]) -> \
            Dict[str, Dict[str, List[List[int]]]]:
        """
        transformer token to id
        :param data:
        :param word_to_index:
        :return: {class_name: [[], [], ], ..}
        """
        data_ids = {category: {sentiment: [[word_to_index.get(token, word_to_index["<unk>"]) for token in line]
                                           for line in sentiment_data]
                               for sentiment, sentiment_data in category_data.items()}
                    for category, category_data in data.items()}
        return data_ids

    def choice_support_query(self, task_data: Dict[str, List[List[int]]])\
            -> Tuple[List[List[List[int]]], List[List[int]], List[int]]:
        """
        randomly selecting support set, query set form a task.
        :param task_data: all data for a task
        :return:
        """
        label_to_index = {"1": 0, "-1": 1}
        # if self.__is_training:
        #     with open(os.path.join(self.__output_path, "label_to_index.json"), "w") as f:
        #         json.dump(label_to_index, f)

        pos_samples = task_data["1"]
        neg_samples = task_data["-1"]
        pos_support = random.sample(pos_samples, self.__num_support)
        neg_support = random.sample(neg_samples, self.__num_support)

        pos_others = copy.copy(pos_samples)
        [pos_others.remove(data) for data in pos_support]

        neg_others = copy.copy(neg_samples)
        [neg_others.remove(data) for data in neg_support]

        pos_query = random.sample(pos_others, self.__num_queries)
        neg_query = random.sample(neg_others, self.__num_queries)

        # padding
        pos_support = self.padding(pos_support)
        neg_support = self.padding(neg_support)
        pos_query = self.padding(pos_query)
        neg_query = self.padding(neg_query)

        support_set = [pos_support, neg_support]  # [num_classes, num_support, sequence_length]
        query_set = pos_query + neg_query  # [num_classes * num_queries, sequence_length]
        labels = [label_to_index["1"]] * len(pos_query) + [label_to_index["-1"]] * len(neg_query)

        return support_set, query_set, labels

    def samples(self, data_ids: Dict[str, Dict[str, List[List[int]]]]) \
            -> List[Dict[str, Union[List[List[List[int]]], List[List[int]], List[int]]]]:
        """
        positive and negative sample from raw data
        :param data_ids:
        :return:
        """
        # product name list
        category_list = list(data_ids.keys())

        tasks = []
        if self.__is_training:
            num_tasks = self.__num_tasks
        else:
            num_tasks = self.__num_eval_tasks
        for i in range(num_tasks):
            # randomly choice a category to construct train sample
            try:
                support_category = random.choice(category_list)
                support_set, query_set, labels = self.choice_support_query(data_ids[support_category])
                tasks.append(dict(support=support_set, queries=query_set, labels=labels))
            except:
                pass
        return tasks

    def gen_data(self, file_path: str) -> Dict[str, Dict[str, List[List[int]]]]:
        """
        Generate data that is eventually input to the model
        :return:
        """
        # load data
        data = self.load_data(file_path)
        # remove stop word
        words = self.remove_stop_word(data)
        word_to_index = self.gen_vocab(words)
        self.word_to_index = word_to_index
        self.index_to_word = {v:k for k,v in word_to_index.items()}

        data_ids = self.trans_to_index(data, word_to_index)
        return data_ids, data

    def padding(self, sentences: List[List[int]]) -> List[List[int]]:
        """
        padding according to the predefined sequence length
        :param sentences:
        :return:
        """
        sentence_pad = [sentence[:self.__sequence_length] if len(sentence) > self.__sequence_length
                        else sentence + [0] * (self.__sequence_length - len(sentence))
                        for sentence in sentences]
        return sentence_pad

    def next_batch(self, data_ids: Dict[str, Dict[str, List[List[int]]]]) \
            -> Dict[str, Union[List[List[List[int]]], List[List[int]], List[int]]]:
        """
        train a task at every turn
        :param data_ids:
        :return:
        """

        tasks = self.samples(data_ids)

        for task in tasks:
            yield task

config = {
  "model_name": "prototypical",
  "epochs": 30,
  "checkpoint_every": 100,
  "eval_every": 500,
  "learning_rate": 1e-3,
  "optimization": "adam",
  "embedding_size": 300,
  "hidden_sizes": [128],
  "attention_size": 64,
  "num_support": 10,
  "num_queries": 50,
  "num_classes": 2,
  "num_tasks": 200,
  "num_eval_tasks": 100,
  "low_freq": 3,
  "sequence_length": 200,
  "keep_prob": 0.7,
  "l2_reg_lambda": 0.0,
  "max_grad_norm": 5.0,
  "train_data": "./reviews/newtrain",
  "eval_data": "./reviews/neweval",
  "stop_word_path": "./reviews/english",
  "output_path": "./output/prototypical",
  "word_vector_path": "./word_embedded/new_word2vec_model.txt",
  "ckpt_model_path": "./output/prototypical/ckpt_model",
  "pb_model_path": "./output/prototypical/pb_model"
}

In [15]:
"""
performance metrics function
"""
import torch
from torch import nn


def accuracy(pred_y: torch.Tensor, true_y: torch.Tensor):
    """
    Calculate accuracy
    :param pred_y: predict result
    :param true_y: true result
    :return:
    """
    return np.sum(pred_y == true_y) / len(pred_y)


def binary_precision(pred_y: torch.Tensor, true_y: torch.Tensor, positive=1):
    """
    Calculate the precision of binary classification
    :param pred_y: predict result
    :param true_y: true result
    :param positive: index of positive label
    :return:
    """
    tp = np.sum((pred_y == positive) & (true_y == positive))
    fp = np.sum((pred_y == positive) & (true_y != positive))
    # fn = np.sum((pred_y != positive) & (true_y == positive))
    # tn = np.sum((pred_y != positive) & (true_y != positive))
    if (tp + fp) != 0:
        return tp / (tp + fp)
    else:
        return 0.0


def binary_recall(pred_y, true_y, positive=1):
    """
    Calculate the recall of binary classification
    :param pred_y: predict result
    :param true_y: true result
    :param positive: index of positive label
    :return:
    """
    tp = np.sum((pred_y == positive) & (true_y == positive))
    # fp = np.sum((pred_y == positive) & (true_y != positive))
    fn = np.sum((pred_y != positive) & (true_y == positive))
    # tn = np.sum((pred_y != positive) & (true_y != positive))
    if (tp + fn) != 0:
        return tp / (tp + fn)
    else:
        return 0


def binary_f_beta(pred_y, true_y, beta=1.0, positive=1):
    """
    Calculate the f beta of binary classification
    :param pred_y: predict result
    :param beta: beta parameter
    :param true_y: true result
    :param positive: index of positive label
    :return:
    """
    precision = binary_precision(pred_y, true_y, positive)
    recall = binary_recall(pred_y, true_y, positive)
    if (beta * beta * precision + recall) != 0:
        return (1 + beta * beta) * precision * recall / (beta * beta * precision + recall)
    else:
        return 0


def get_binary_metrics(pred_y, true_y, f_beta=1.0):
    """
    Calculate various performance metrics of binary classification
    :param pred_y: predict result
    :param true_y: true result
    :param f_beta: beta parameter
    :return:
    """
    acc = accuracy(pred_y, true_y)
    recall = binary_recall(pred_y, true_y)
    precision = binary_precision(pred_y, true_y)
    f_beta = binary_f_beta(pred_y, true_y, f_beta)
    return acc, recall, precision, f_beta


def multi_precision(pred_y, true_y, labels):
    """
    Calculate the precision of multi classification
    :param pred_y: predict result
    :param true_y: true result
    :param labels: label list
    :return:
    """
    precisions = [binary_precision(pred_y, true_y, label) for label in labels]
    prec = np.mean(precisions)
    return prec


def multi_recall(pred_y, true_y, labels):
    """
    Calculate the recall of multi classification
    :param pred_y: predict result
    :param true_y: true result
    :param labels: label list
    :return:
    """
    recalls = [binary_recall(pred_y, true_y, label) for label in labels]
    rec = np.mean(recalls)
    return rec


def multi_f_beta(pred_y, true_y, labels, beta=1.0):
    """
    Calculate the f value of multi classification
    :param pred_y: predict result
    :param true_y: true result
    :param labels: label list
    :param beta: beta parameter
    :return:
    """
    f_betas = [binary_f_beta(pred_y, true_y, beta, label) for label in labels]
    f_beta = np.mean(f_betas)
    return f_beta


def get_multi_metrics(pred_y, true_y, labels, f_beta=1.0):
    """
    Calculate various performance metrics of multi classification
    :param pred_y: predict result
    :param true_y: true result
    :param labels: label list
    :param beta: beta parameter
    :return:
    """
    acc = accuracy(pred_y, true_y)
    recall = multi_recall(pred_y, true_y, labels)
    precision = multi_precision(pred_y, true_y, labels)
    f_beta = multi_f_beta(pred_y, true_y, labels, f_beta)
    return acc, recall, precision, f_beta

In [21]:
import torch
from torch import nn
import torch.nn.functional as F
import os
from pathlib import Path
import numpy as np

def extract_batch(batch: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    supports = torch.tensor(batch['support'], device=device)
    query = torch.tensor(batch['queries'], device=device)
    labels = torch.tensor(batch['labels'], device=device)
    
    # https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
    # https://github.com/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb

    # supports = [num_classes, batch, length]
    pos_support = supports[0, :, :]
    pos_support_labels = torch.zeros(pos_support.size(0))
    neg_support = supports[1, :, :]
    neg_support_labels = torch.ones(neg_support.size(0))
    supports = torch.cat([pos_support, neg_support], dim=0).to(device=device)
    supports_labels = torch.cat([pos_support_labels, neg_support_labels], dim=0).to(device=device)
    
    return supports, supports_labels, query, labels

class CNN(nn.Module):
    
    def __init__(self, vocab_size: int, embedding_dim: int, fs: List[int], channels: int, output_dim: int) -> None:
        super(CNN, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=(n, embedding_dim)) for n in fs
        ])
        
        self.fc = nn.Linear(len(fs) * channels, output_dim)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if DEBUG_PRINT:
            print(f"CNN x={x.size()}")
        embedded = self.embed(x)
        if DEBUG_PRINT:
            print(f"CNN embeded={embedded.size()}")
        embedded = embedded.unsqueeze(1)
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        output = torch.cat(pooled, dim=1)
        return self.fc(output)
    
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self(x).argmax(1)
    
    def scores(self, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        supports, supports_labels, query, query_labels = extract_batch(batch)
        supports_out = model(supports)
        query_out = model(query)

        if DEBUG_PRINT:
            print(f"train support_out={supports_out.size()} query_out={query_out.size()}")
        
        if DEBUG_PRINT:
            print(f"train indices={torch.nonzero(supports_labels == 0).size()}")
        proto_pos = supports_out[torch.nonzero(supports_labels == 0)]
        proto_neg = supports_out[torch.nonzero(supports_labels == 1)]
        
        if DEBUG_PRINT:
            print(f"train proto_pos={proto_pos.size()}")
            
        proto_pos = proto_pos.mean(0)
        proto_neg = proto_neg.mean(0)
        
        if DEBUG_PRINT:
            print(f"train proto_pos_meaned={proto_pos.size()}")
            
        proto = torch.cat([proto_pos, proto_neg], dim=0).to(dtype=torch.float)
        if DEBUG_PRINT:
            print(f"train proto={proto.size()}")
            
        dists = torch.cdist(query_out, proto)
        if DEBUG_PRINT:
            print(f"train dists={dists.size()} labels={query_labels.size()}")

        scores = -dists
        return scores, query_out, query_labels
    
data_loader = PrototypicalData(output_path=config["output_path"],
                                    sequence_length=config["sequence_length"],
                                    num_classes=config["num_classes"],
                                    num_support=config["num_support"],
                                    num_queries=config["num_queries"],
                                    num_tasks=config["num_tasks"],
                                    num_eval_tasks=config["num_eval_tasks"],
                                    embedding_size=config["embedding_size"],
                                    stop_word_path=config["stop_word_path"],
                                    word_vector_path=config["word_vector_path"],
                                    is_training=True)
eval_loader = PrototypicalData(output_path=config["output_path"],
                                    sequence_length=config["sequence_length"],
                                    num_classes=config["num_classes"],
                                    num_support=config["num_support"],
                                    num_queries=config["num_queries"],
                                    num_tasks=config["num_tasks"],
                                    num_eval_tasks=config["num_eval_tasks"],
                                    embedding_size=config["embedding_size"],
                                    stop_word_path=config["stop_word_path"],
                                    word_vector_path=config["word_vector_path"],
                                    is_training=False)
train_tasks, _ = data_loader.gen_data(config["train_data"])
eval_tasks, _ = eval_loader.gen_data(config["eval_data"])

epochs = config["epochs"]
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNN(data_loader.vocab_size, 512, [1,2,4], 128, 2).to(device)
lr = 0.01
opt = torch.optim.SGD(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
total_loss = 0
interval = config["checkpoint_every"]
current_step = 0

ckpt_path = Path(config['ckpt_model_path'])


for ep in range(1, epochs+1):
    for task in data_loader.next_batch(train_tasks):
        model.zero_grad()
        # https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
        # https://github.com/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb
    
        scores, _, labels = model.scores(task)
        
        loss = loss_fn(scores, labels)
        loss.backward()
        
        opt.step()
        
        total_loss += loss.item()
        
        current_step += 1
        
        if current_step % interval == 0:
            with torch.no_grad():
                val_losses = []
                val_accs = []
                val_recalls = []
                val_precs = []
                val_fbeta = []
                for task in eval_loader.next_batch(eval_tasks):
                    scores, preds_out, labels = model.scores(task)
                    preds = preds_out.argmax(1)
                    val_loss = loss_fn(scores, labels)
                    val_losses.append(val_loss.item())
                    
                    # Move to cpu for faster computation
                    acc, recall, prec, f_beta = get_multi_metrics(pred_y=preds.cpu().numpy(),
                                                                 true_y=labels.cpu().numpy(),
                                                                 labels=np.array([0, 1]))
                    val_accs.append(acc)
                    val_recalls.append(recall)
                    val_precs.append(prec)
                    val_fbeta.append(f_beta)
                    
                print(f"VAL: epoch={ep} loss={np.mean(val_losses):.03f} accuracy={np.mean(val_accs):.03f} recalls={np.mean(val_recalls):.03f} precs={np.mean(val_precs):.03f} fbeta={np.mean(val_fbeta):.03f}")
        
        opt.step()
    print(f"TRAIN: epoch={ep} loss={loss} avg_loss={total_loss/current_step}")

stop word path:  ./reviews/english
word vector path:  ./word_embedded/new_word2vec_model.txt
stop word path:  ./reviews/english
word vector path:  ./word_embedded/new_word2vec_model.txt
VAL: epoch=1 loss=0.658 accuracy=0.534 recalls=0.534 precs=0.611 fbeta=0.426
VAL: epoch=1 loss=0.622 accuracy=0.580 recalls=0.580 precs=0.652 fbeta=0.514
TRAIN: epoch=1 loss=0.5843609571456909 avg_loss=0.6603516525030136
VAL: epoch=2 loss=0.570 accuracy=0.580 recalls=0.580 precs=0.720 fbeta=0.492
VAL: epoch=2 loss=0.534 accuracy=0.704 recalls=0.704 precs=0.746 fbeta=0.689
TRAIN: epoch=2 loss=0.6051499247550964 avg_loss=0.6185775139927864
VAL: epoch=3 loss=0.501 accuracy=0.741 recalls=0.741 precs=0.744 fbeta=0.740
VAL: epoch=3 loss=0.480 accuracy=0.749 recalls=0.749 precs=0.776 fbeta=0.741
TRAIN: epoch=3 loss=0.41845524311065674 avg_loss=0.5750580715139707
VAL: epoch=4 loss=0.465 accuracy=0.768 recalls=0.768 precs=0.772 fbeta=0.768
VAL: epoch=4 loss=0.444 accuracy=0.761 recalls=0.761 precs=0.772 fbeta=0.

In [31]:
import torch
from torch import nn
import torch.nn.functional as F
import os
from pathlib import Path
import numpy as np

def extract_batch(batch: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    supports = torch.tensor(batch['support'], device=device)
    query = torch.tensor(batch['queries'], device=device)
    labels = torch.tensor(batch['labels'], device=device)
    
    # https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
    # https://github.com/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb

    # supports = [num_classes, batch, length]
    pos_support = supports[0, :, :]
    pos_support_labels = torch.zeros(pos_support.size(0))
    neg_support = supports[1, :, :]
    neg_support_labels = torch.ones(neg_support.size(0))
    supports = torch.cat([pos_support, neg_support], dim=0).to(device=device)
    supports_labels = torch.cat([pos_support_labels, neg_support_labels], dim=0).to(device=device)
    
    return supports, supports_labels, query, labels

class CNN(nn.Module):
    
    def __init__(self, vocab_size: int, embedding_dim: int, fs: List[int], channels: int, output_dim: int) -> None:
        super(CNN, self).__init__()
        
        self.embed = nn.Embedding(vocab_size, embedding_dim)
        
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels=1, out_channels=channels, kernel_size=(n, embedding_dim)) for n in fs
        ])
        
        self.fc = nn.Linear(len(fs) * channels, output_dim)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if DEBUG_PRINT:
            print(f"CNN x={x.size()}")
        embedded = self.embed(x)
        if DEBUG_PRINT:
            print(f"CNN embeded={embedded.size()}")
        embedded = embedded.unsqueeze(1)
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
        pooled = [F.max_pool1d(conv, conv.shape[2]).squeeze(2) for conv in conved]
        output = torch.cat(pooled, dim=1)
        return self.fc(output)
    
    def predict(self, x: torch.Tensor) -> torch.Tensor:
        return self(x).argmax(1)
    
    def scores(self, batch: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        supports, supports_labels, query, query_labels = extract_batch(batch)
        supports_out = model(supports)
        query_out = model(query)

        if DEBUG_PRINT:
            print(f"train support_out={supports_out.size()} query_out={query_out.size()}")
        
        if DEBUG_PRINT:
            print(f"train indices={torch.nonzero(supports_labels == 0).size()}")
        proto_pos = supports_out[torch.nonzero(supports_labels == 0)]
        proto_neg = supports_out[torch.nonzero(supports_labels == 1)]
        
        if DEBUG_PRINT:
            print(f"train proto_pos={proto_pos.size()}")
            
        proto_pos = proto_pos.mean(0)
        proto_neg = proto_neg.mean(0)
        
        if DEBUG_PRINT:
            print(f"train proto_pos_meaned={proto_pos.size()}")
            
        proto = torch.cat([proto_pos, proto_neg], dim=0).to(dtype=torch.float)
        if DEBUG_PRINT:
            print(f"train proto={proto.size()}")
            
        dists = torch.cdist(query_out, proto)
        if DEBUG_PRINT:
            print(f"train dists={dists.size()} labels={query_labels.size()}")

        scores = -dists
        return scores, query_out, query_labels
    
data_loader = PrototypicalData(output_path=config["output_path"],
                                    sequence_length=config["sequence_length"],
                                    num_classes=config["num_classes"],
                                    num_support=config["num_support"],
                                    num_queries=config["num_queries"],
                                    num_tasks=config["num_tasks"],
                                    num_eval_tasks=config["num_eval_tasks"],
                                    embedding_size=config["embedding_size"],
                                    stop_word_path=config["stop_word_path"],
                                    word_vector_path=config["word_vector_path"],
                                    is_training=True)
eval_loader = PrototypicalData(output_path=config["output_path"],
                                    sequence_length=config["sequence_length"],
                                    num_classes=config["num_classes"],
                                    num_support=config["num_support"],
                                    num_queries=config["num_queries"],
                                    num_tasks=config["num_tasks"],
                                    num_eval_tasks=config["num_eval_tasks"],
                                    embedding_size=config["embedding_size"],
                                    stop_word_path=config["stop_word_path"],
                                    word_vector_path=config["word_vector_path"],
                                    is_training=False)
train_tasks, _ = data_loader.gen_data(config["train_data"])
eval_tasks, _ = eval_loader.gen_data(config["eval_data"])

epochs = config["epochs"]
device = "cuda" if torch.cuda.is_available() else "cpu"
model = CNN(data_loader.vocab_size, 512, [1,2,4], 256, 2).to(device)
model.train()
lr = 0.01
opt = torch.optim.SGD(model.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()
total_loss = 0
interval = config["checkpoint_every"]
current_step = 0

ckpt_path = Path(config['ckpt_model_path'])

ticks = []
for ep in range(1, epochs+1):
    for task in data_loader.next_batch(train_tasks):
        model.zero_grad()
        # https://github.com/jakesnell/prototypical-networks/blob/master/protonets/models/few_shot.py
        # https://github.com/sicara/easy-few-shot-learning/blob/master/notebooks/my_first_few_shot_classifier.ipynb
    
        scores, _, labels = model.scores(task)
        
        loss = loss_fn(scores, labels)
        loss.backward()
        
        opt.step()
        
        total_loss += loss.item()
        
        current_step += 1
        
        if current_step % interval == 0:
            with torch.no_grad():
                val_losses = []
                val_accs = []
                val_recalls = []
                val_precs = []
                val_fbeta = []
                for task in eval_loader.next_batch(eval_tasks):
                    scores, preds_out, labels = model.scores(task)
                    preds = preds_out.argmax(1)
                    val_loss = loss_fn(scores, labels)
                    val_losses.append(val_loss.item())
                    
                    # Move to cpu for faster computation
                    acc, recall, prec, f_beta = get_multi_metrics(pred_y=preds.cpu().numpy(),
                                                                 true_y=labels.cpu().numpy(),
                                                                 labels=np.array([0, 1]))
                    val_accs.append(acc)
                    val_recalls.append(recall)
                    val_precs.append(prec)
                    val_fbeta.append(f_beta)
                
                mean_loss = np.mean(val_losses)
                mean_acc = np.mean(val_accs)
                mean_recalss = np.mean(val_recalls)
                mean_precs = np.mean(val_precs)
                mean_f1b = np.mean(val_fbeta)
                print(f"VAL: epoch={ep} loss={mean_loss:.03f} accuracy={mean_acc:.03f} recalls={mean_recalss:.03f} precs={mean_precs:.03f} fbeta={mean_f1b:.03f}")

                ticks.append(
                    ((ep, current_step), mean_loss, mean_acc, mean_recalss, mean_precs, mean_f1b)
                )
        opt.step()
    print(f"TRAIN: epoch={ep} loss={loss} avg_loss={total_loss/current_step}")

stop word path:  ./reviews/english
word vector path:  ./word_embedded/new_word2vec_model.txt
stop word path:  ./reviews/english
word vector path:  ./word_embedded/new_word2vec_model.txt
VAL: epoch=1 loss=0.655 accuracy=0.556 recalls=0.556 precs=0.565 fbeta=0.538
VAL: epoch=1 loss=0.611 accuracy=0.586 recalls=0.586 precs=0.603 fbeta=0.565
TRAIN: epoch=1 loss=0.5804505348205566 avg_loss=0.6555341073870659
VAL: epoch=2 loss=0.552 accuracy=0.661 recalls=0.661 precs=0.689 fbeta=0.647
VAL: epoch=2 loss=0.527 accuracy=0.719 recalls=0.719 precs=0.722 fbeta=0.718
TRAIN: epoch=2 loss=0.417949378490448 avg_loss=0.6089077597856521
VAL: epoch=3 loss=0.489 accuracy=0.735 recalls=0.735 precs=0.759 fbeta=0.726
VAL: epoch=3 loss=0.449 accuracy=0.764 recalls=0.764 precs=0.768 fbeta=0.763
TRAIN: epoch=3 loss=0.30760663747787476 avg_loss=0.5650602225959301
VAL: epoch=4 loss=0.446 accuracy=0.592 recalls=0.592 precs=0.711 fbeta=0.521
VAL: epoch=4 loss=0.432 accuracy=0.746 recalls=0.746 precs=0.766 fbeta=0.7

In [32]:
print(ticks)

[((1, 100), 0.6550026260396485, 0.5560215053763442, 0.5560215053763442, 0.5649300672519706, 0.5383688217515725), ((1, 200), 0.6110553638059266, 0.5858163265306123, 0.5858163265306123, 0.6034250659045903, 0.5651431731659589), ((2, 300), 0.551669873652004, 0.660952380952381, 0.660952380952381, 0.6893738408054743, 0.6469467733601613), ((2, 400), 0.5266320000676548, 0.7190588235294116, 0.7190588235294116, 0.7222923723267565, 0.7178597693977845), ((3, 500), 0.48907450234496985, 0.7346153846153846, 0.7346153846153846, 0.7594859706140806, 0.7257330737148866), ((3, 600), 0.44861228700647965, 0.7636170212765958, 0.7636170212765958, 0.7680353074156651, 0.7625187866481069), ((4, 700), 0.44608796474545503, 0.5918604651162791, 0.5918604651162791, 0.7114090663014866, 0.5205709911446809), ((4, 800), 0.43229351050398324, 0.7464044943820225, 0.7464044943820225, 0.7658888654373842, 0.7401985609777428), ((5, 900), 0.4030474927476657, 0.6920430107526881, 0.6920430107526881, 0.7553949815294411, 0.668370329

In [None]:
data_loader = PrototypicalData(output_path=config["output_path"],
                                    sequence_length=config["sequence_length"],
                                    num_classes=config["num_classes"],
                                    num_support=config["num_support"],
                                    num_queries=config["num_queries"],
                                    num_tasks=config["num_tasks"],
                                    num_eval_tasks=config["num_eval_tasks"],
                                    embedding_size=config["embedding_size"],
                                    stop_word_path=config["stop_word_path"],
                                    word_vector_path=config["word_vector_path"],
                                    is_training=True)
eval_loader = PrototypicalData(output_path=config["output_path"],
                                    sequence_length=config["sequence_length"],
                                    num_classes=config["num_classes"],
                                    num_support=config["num_support"],
                                    num_queries=config["num_queries"],
                                    num_tasks=config["num_tasks"],
                                    num_eval_tasks=config["num_eval_tasks"],
                                    embedding_size=config["embedding_size"],
                                    stop_word_path=config["stop_word_path"],
                                    word_vector_path=config["word_vector_path"],
                                    is_training=False)
train_tasks, train_data_orig = data_loader.gen_data(config["train_data"])
eval_tasks, eval_data_orig = eval_loader.gen_data(config["eval_data"])

In [74]:
import random



model.eval()

rk = random.choice([k for k in eval_tasks.keys()])

print(eval_tasks[rk])
print(eval_tasks[rk].keys())
for i in range(10):
    if i % 2 == 0:
        st = "-1"
        gt = "Negative"
    else:
        st = "1"
        gt = "Positive"
    sentences = [(s, eval_data_orig[rk][st][idx]) for idx, s in enumerate(eval_tasks[rk][st]) if len(s) <= 20 and len(s) >= 4]
    sentence, sentence_original = random.choice(sentences)
    
    preds = model.predict(torch.tensor(sentence, device=device).unsqueeze(0)).item()
    if preds == 1:
        pred = "Negative"
    else:
        pred = "Positive"
    print(f"\\makecell[l]{{ {' '.join(sentence_original)} }} & {pred} & {gt} \\\\")

# print(eval_loader.index_to_word)

{'1': [[199, 1, 1049], [2, 12, 5, 1, 497, 1, 1492, 8, 518, 1, 41, 2284, 1, 1, 1323, 1, 1, 6943, 2, 13, 1, 1, 810, 1, 28, 1, 3715, 1, 1, 3423, 424, 1, 361, 1797, 93, 557, 161, 1, 11056, 1, 497, 77, 29, 1, 1, 1637, 1, 1, 2003, 37, 669, 110, 1, 2633, 1, 2, 56, 16, 34, 1, 2990, 2, 10, 98, 1, 1, 389, 1, 10, 1, 61, 1, 1, 1, 551, 13195], [146, 1384, 1, 1, 257, 1779, 1, 2, 5, 1, 1745, 493, 3480, 1, 2, 1, 3253, 960, 714, 1, 806, 1, 21168, 1, 1118, 1, 3480, 1, 1, 3003, 1, 10, 1, 78, 1, 1840, 1, 17359, 1, 12592, 1, 1, 26, 2194, 1, 18163, 1, 204, 1, 145, 1, 3, 324, 510, 390, 89, 1, 1, 397, 1, 1, 1, 1637, 1, 290], [7, 18163, 130, 1, 1789, 1, 1, 810, 247, 2489, 16, 89, 1, 1384, 1, 29, 1, 1637, 1, 290], [108, 1389], [77, 1, 1, 2752, 1, 80, 2256, 511, 1, 4205], [108, 1, 299, 17, 1, 810, 1, 1, 2256, 1, 636, 47, 3211, 961, 289, 22423, 5, 6547, 458, 1, 810, 592, 365], [1384, 1, 1, 394], [47, 1, 1, 1, 117, 1, 30, 6, 1, 1839, 117, 1, 1, 1, 16193, 1, 23, 136, 80, 1, 1, 1, 93, 1, 1, 1, 1, 93, 16158, 1, 2883,