In [1]:
from copy import deepcopy 

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, AutoTokenizer, get_constant_schedule_with_warmup

from models.meta_bert import MetaBert
from data.unified_emotion.unified_emotion import unified_emotion
from data.utils.sampling import dataset_sampler
from utils.metrics import logging_metrics

###############################################
### HANDLE WITH ARGPARSER #####################
###############################################

###
### Meta-training definition hyper parameters 
###
k = 8 # Number of shots
n_inner = 5 # Number of inner loop updates
n_outer = 2 # Number of outer loop updates before meta update
n_episodes = 20

###
### Model definition hyperparameters 
###
# Pre-trained model name (from Huggingface)
bert_name = 'bert-base-uncased'
# Pre-trained configuration from Huggingface
bert_config = BertConfig.from_pretrained(bert_name)
# Max to layer keep frozen. 11 keeps model frozen, -1 makes BERT totally trainable
nu = 11

# MLP hidden layers
hidden_dims = [512, 256, 128]
# Which activation to use. Currently either tanh or ReLU
act_fn = 'ReLU'

# Emulate config file
config = {'bert_name': bert_name, 
'bert_config': bert_config, 
'nu': nu,
'hidden_dims': hidden_dims,
'act_fn': act_fn
}

tokenizer = AutoTokenizer.from_pretrained(bert_name)

###
### Optimizer hyper parameters 
###
meta_lr = 1e-5
inner_lr = 1e-3
warm_up_steps = 100

In [2]:

unified = unified_emotion("./data/datasets/unified-dataset.jsonl")

unified.prep()

# Initialization and all that jazz

model = MetaBert(config)

shared_optimizer = optim.SGD(model.parameters(), lr=meta_lr)
shared_lr_schedule = get_constant_schedule_with_warmup(shared_optimizer, warm_up_steps)

lossfn = nn.CrossEntropyLoss()


In [4]:
with torch.autograd.set_detect_anomaly(True):
    for i in range(n_episodes):
        
        for ii in range(n_outer):
 
            # Sample a dataset
            source_name = dataset_sampler(unified, sampling_method='sqrt')
            trainloader, testloader = unified.get_dataloader(source_name, k=k, tokenizer=tokenizer, shuffle=True)

            # Copy shared model to task-specific model
            model_task = deepcopy(model)
            model_task.zero_grad()

            # Inner loop optimizer
            task_optimizer = optim.SGD(model_task.parameters(), lr=inner_lr)    

            ########################
            # Proto Initialization #
            ########################
            model.eval()
            
            # Embed with init_model
            batch = next(trainloader)
            labels, text, attn_mask = batch

            y_init = model(text, attn_mask)

            # Generate initial classification weights and biases
            prototypes = torch.stack([torch.mean(y_init[labels == i], dim=0) for \
                i, _ in enumerate(unified.label_map[source_name].keys())])

            W = 2 * prototypes
            b = -torch.norm(prototypes, p=2, dim=1)

            model.train()

            # Set the weights and biases
            # Separate graph
            W_output, b_output = W.detach(),  b.detach()
            W_output.requires_grad, b_output.requires_grad = True, True

            # Classifier optimizer
            output_optimizer = optim.SGD([W_output, b_output], lr=inner_lr)

            ##############
            # Inner loop #
            ############## 
            for iii in range(n_inner):
                # Load data
                batch = next(trainloader)
                labels, text, attn_mask = batch

                # Embed, encode, classify and compute loss
                y = model_task(text, attn_mask)
                logits = F.linear(y, W_output, b_output)
                loss = lossfn(logits, labels)

                # Backprop the output parameters
                # Retrain graph for shared parameters
                W_output.grad, b_output.grad = torch.autograd.grad(loss, [W_output, b_output], retain_graph=True)

                # Calculate the gradients on shared parameters here    
                updateable_task_params = [param for param in model_task.parameters() if param.requires_grad]
                task_grads = torch.autograd.grad(loss, updateable_task_params)

                # Store task-specific gradients
                for param, grad in zip(updateable_task_params, task_grads):
                    param.grad = grad

                # Update the parameters
                output_optimizer.step()
                task_optimizer.step()

                output_optimizer.zero_grad()
                task_optimizer.zero_grad()
            
            # Re-attach the graphs
            W_output = W + (W_output - W).detach()
            b_output = b + (b_output - b).detach()

            ########################
            # Outer loop gradients #
            ######################## 
            # Load Query
            batch = next(testloader)
            labels, text, attn_mask = batch

            # Push data through task-specific model
            y = model_task(text, attn_mask)
            logits = F.linear(y, W_output, b_output)
            loss = lossfn(logits, labels)

            # TODO: add logging for outer loop updates here
            n_classes = len(unified.label_map[source_name].keys())
            mets = logging_metrics(logits.detach().cpu(), labels.detach().cpu())
            print("Episode {} | Task {}/{}: {:<20s}, N={} | Loss {:5.2f}, Acc {:5.2f}, F1 {:5.2f}".format(i, ii+1, n_outer, source_name, n_classes, loss.detach().item(), mets['acc']*100, mets['f1']*100))

            # Calculate gradients for task-specific parameters    
            updateable_task_params = [param for param in model_task.parameters() if param.requires_grad]
            task_grads = torch.autograd.grad(loss, updateable_task_params, retain_graph=True)

            # Calculate gradients for shared model parameters    
            updateable_model_params = [param for param in model.parameters() if param.requires_grad]
            model_grads = torch.autograd.grad(loss, updateable_model_params)

            # Store the gradients
            for param, g_task, g_init in zip(updateable_model_params, task_grads, model_grads):
                if param.grad == None:
                    param.grad = g_task + g_init
                else:
                    param.grad += g_task + g_init

        print("Episode {} finished.\n".format(i))

        for param in updateable_model_params:
            param.grad = param.grad / n_outer

        shared_optimizer.step()
        shared_optimizer.zero_grad()

Episode 0 | Task 1/2: emoint              , N=4 | Loss  1.39, Acc 25.00, F1 24.43
Episode 0 | Task 2/2: dailydialog         , N=7 | Loss  1.95, Acc 17.86, F1 14.00
Episode 0 finished.

Episode 1 | Task 1/2: crowdflower         , N=8 | Loss  2.08, Acc 12.50, F1  2.86
Episode 1 | Task 2/2: dailydialog         , N=7 | Loss  1.95, Acc 10.71, F1  4.63
Episode 1 finished.

Episode 2 | Task 1/2: dailydialog         , N=7 | Loss  2.05, Acc 14.29, F1  3.69
Episode 2 | Task 2/2: emoint              , N=4 | Loss  1.43, Acc 25.00, F1 17.03
Episode 2 finished.

Episode 3 | Task 1/2: dailydialog         , N=7 | Loss  1.94, Acc 23.21, F1 14.27
Episode 3 | Task 2/2: grounded_emotions   , N=2 | Loss  0.69, Acc 50.00, F1 33.33
Episode 3 finished.

Episode 4 | Task 1/2: dailydialog         , N=7 | Loss  1.94, Acc 16.07, F1  7.65
Episode 4 | Task 2/2: tales-emotion       , N=7 | Loss  2.82, Acc 14.29, F1  5.77
Episode 4 finished.

Episode 5 | Task 1/2: crowdflower         , N=8 | Loss  2.09, Acc 10.94, F1

KeyboardInterrupt: 