Reference:

[1] https://github.com/learnables/learn2learn/blob/master/examples/text/news_topic_classification.py

In [1]:
# !pip install fairseq --user
# !pip install learn2learn --user

### Text Classification using MAML implemented through learn2learn library

learn2learn is a software library for meta-learning research. It is built on top of PyTorch to accelerate two aspects of the meta-learning research cycle:

1. fast prototyping, essential in letting researchers quickly try new ideas, and
2. correct reproducibility, ensuring that these ideas are evaluated fairly.

MAML is a Model-Agnostic Meta-Learning Algorithm proposed in https://arxiv.org/pdf/1703.03400.pdf

### Import Libraries

In [2]:
import argparse
import random

import torch
from torch import nn, optim
from torch.nn import functional as F
from tqdm import tqdm
from transformers import RobertaTokenizer, RobertaModel

import learn2learn as l2l

### Define Model

In [3]:
class Net(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(self, num_classes, input_dim=768, inner_dim=200, pooler_dropout=0.3):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = nn.ReLU()
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, x, **kwargs):
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = F.log_softmax(self.out_proj(x), dim=1)
        return x

  and should_run_async(code)


### Define Accuracy Calculation Function

In [4]:
def accuracy(predictions, targets):
    predictions = predictions.argmax(dim=1)
    acc = (predictions == targets).sum().float()
    acc /= len(targets)
    return acc.item()

### Define Utilisty Function

In [5]:
def collate_tokens(values, pad_idx, eos_idx=None, left_pad=False, move_eos_to_beginning=False):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values)
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if move_eos_to_beginning:
            assert src[-1] == eos_idx
            dst[0] = eos_idx
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res

In [6]:
class _BatchedDataset(torch.utils.data.Dataset):
    def __init__(self, batched):
        self.sents = [s for s in batched[0]]
        self.ys = [y for y in batched[1]]
    
    def __len__(self):
        return len(self.ys)
    
    def __getitem__(self, idx):
        return (self.sents[idx], self.ys[idx])

In [7]:
def compute_loss(task, roberta, device, learner, loss_func, batch=15):
    loss = 0.0
    acc = 0.0
    for i, (x, y) in enumerate(torch.utils.data.DataLoader(
            _BatchedDataset(task), batch_size=batch, shuffle=True, num_workers=0)):
        # RoBERTa ENCODING
        x = collate_tokens([roberta.encode(sent) for sent in x], pad_idx=1)
        with torch.no_grad():
            x = roberta.extract_features(x)
        x = x[:, 0, :]

        # Moving to device
        x, y = x.to(device), y.view(-1).to(device)

        output = learner(x)
        curr_loss = loss_func(output, y)
        acc += accuracy(output, y)
        loss += curr_loss / len(task)
    loss /= len(task)
    return loss, acc

In [8]:
def main(lr=0.005, maml_lr=0.01, iterations=1000, ways=5, shots=1, tps=32, fas=5, device=torch.device("cpu"),
         download_location="/tmp/text"):
    dataset = l2l.text.datasets.NewsClassification(root=download_location, download=True)
    dataset = l2l.data.MetaDataset(dataset)

    classes = list(range(len(dataset.labels))) # 41 classes
    random.shuffle(classes)

    train_dataset, validation_dataset, test_dataset = dataset, dataset, dataset

    train_gen = l2l.data.TaskDataset(
            train_dataset, num_tasks=20000, 
            task_transforms=[
                l2l.data.transforms.FusedNWaysKShots(
                    train_dataset, n=ways, k=shots, filter_labels=classes[:20]),
                l2l.data.transforms.LoadData(train_dataset),
                l2l.data.transforms.RemapLabels(train_dataset)],)

    validation_gen = l2l.data.TaskDataset(
            validation_dataset, num_tasks=20000, 
            task_transforms=[
                l2l.data.transforms.FusedNWaysKShots(
                    validation_dataset, n=ways, k=shots, filter_labels=classes[20:30]),
                l2l.data.transforms.LoadData(validation_dataset),
                l2l.data.transforms.RemapLabels(validation_dataset)],)

    test_gen = l2l.data.TaskDataset(
            test_dataset, num_tasks=20000, 
            task_transforms=[
                l2l.data.transforms.FusedNWaysKShots(
                    test_dataset, n=ways, k=shots, filter_labels=classes[30:]),
                l2l.data.transforms.LoadData(test_dataset),
                l2l.data.transforms.RemapLabels(test_dataset)],)

    torch.hub.set_dir(download_location)
    roberta = torch.hub.load('pytorch/fairseq', 'roberta.base')
    roberta.eval()
    roberta.to(device)
    model = Net(num_classes=ways)
    model.to(device)
    meta_model = l2l.algorithms.MAML(model, lr=maml_lr)
    opt = optim.Adam(meta_model.parameters(), lr=lr)
    loss_func = nn.NLLLoss(reduction="sum")

    tqdm_bar = tqdm(range(iterations))

    accs = []
    for _ in tqdm_bar:
        iteration_error = 0.0
        iteration_acc = 0.0
        for _ in range(tps):
            learner = meta_model.clone()
            train_task, valid_task = train_gen.sample(), validation_gen.sample()

            # Fast Adaptation
            for _ in range(fas):
                train_error, _ = compute_loss(train_task, roberta, device, learner, loss_func, batch=shots * ways)
                learner.adapt(train_error)

            # Compute validation loss
            valid_error, valid_acc = compute_loss(valid_task, roberta, device, learner, loss_func,
                                                  batch=shots * ways)
            iteration_error += valid_error
            iteration_acc += valid_acc

        iteration_error /= tps
        iteration_acc /= tps
        tqdm_bar.set_description("Loss : {:.3f} Acc : {:.3f}".format(iteration_error.item(), iteration_acc))
        accs.append(iteration_acc)
        # Take the meta-learning step
        opt.zero_grad()
        iteration_error.backward()
        opt.step()
    print (f'first and best validation accuracy: {accs[0]:.4f}, {max(accs):.4f}')

In [9]:
class Arguments():
    def __init__(self):
        
        self.ways  = 5 # number of ways (default: 5)
        self.shots = 1 # number of shots (default: 1)
        self.tasks_per_step = 32 # tasks per step (default: 32)
        self.fast_adaption_steps = 5 # steps per fast adaption (default: 5)
        self.iterations = 1000 # number of iterations (default: 1000)
        self.lr = 0.005 # learning rate (default: 0.005)
        self.maml_lr = 0.01 # learning rate for MAML (default: 0.01)
        self.no_cuda = False # disables CUDA training
        self.seed = 1 # random seed (default: 1)
        self.download_location = '/tmp/text' # download location for train data and roberta(default : /tmp/text

args = Arguments()
args.lr

0.005

In [10]:
use_cuda = not args.no_cuda and torch.cuda.is_available()

torch.manual_seed(args.seed)
random.seed(args.seed)

device = torch.device("cuda" if use_cuda else "cpu")

main(lr=args.lr, maml_lr=args.maml_lr, iterations=args.iterations, ways=args.ways, shots=args.shots,
     tps=args.tasks_per_step, fas=args.fast_adaption_steps, device=device,
     download_location=args.download_location)

Using cache found in /tmp/text\pytorch_fairseq_master
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _TENSORSHAPEPROTO_DIM = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.EnumValueDescriptor(
  _DATATYPE = _descriptor.EnumDescriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _RESOURCEHANDLEPROTO_DTYPEANDSHAPE = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.FieldDescriptor(
  _TENSORPROTO = _descriptor.Descriptor(
  DESCRIPTOR = _descriptor.FileDescriptor(
  _descriptor.EnumValueDescriptor(
  _DATACLASS = _descriptor.EnumDescriptor(
  _descriptor.FieldDescriptor(
  _SUMMARYDESCRIPTION = _descriptor.Descriptor(
2022-09-05 19:23:22 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX
2022-09-05 19:23:23 | INFO | fairseq.file_utils | loading archive file http://dl.fbaipublicfiles.com/fairseq/models/roberta.base.tar.gz 

Loss : 2.018 Acc : 0.169:   7%|██▌                                   | 67/1000 [28:29<6:36:48, 25.52s/it]


TypeError: expected string or buffer