# An Episodic Training Example
## Dataloaders setup

In [1]:
from data.unified_emotion.unified_emotion import unified_emotion
from transformers import AutoTokenizer

In [2]:
unified = unified_emotion("./data/datasets/unified-dataset.jsonl")

unified.prep()

## Model setup

In [3]:
from transformers import BertConfig

from modules.mlp_clf import SF_CLF


# Pre-trained model name (from Huggingface)
bert_name = 'bert-base-uncased'
# Pre-trained configuration from Huggingface
bert_config = BertConfig.from_pretrained(bert_name)
# Max to layer keep frozen. 11 keeps model frozen, -1 makes BERT totally trainable
nu = 11

# MLP hidden layers
hidden_dims = [512, 256, 128]
# Which activation to use. Currently either tanh or ReLU
act_fn = 'ReLU'

config = {'bert_name': bert_name, 
'bert_config': bert_config, 
'nu': nu,
'hidden_dims': hidden_dims,
'act_fn': act_fn
}

tokenizer = AutoTokenizer.from_pretrained(bert_name)


In [4]:
from models.meta_bert import MetaBert

model = MetaBert(config)

print(model)

MetaBert(
  (encoder): BertSequence(
    (model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                

## Training setup

In [5]:
import higher
import torch.nn as nn
import torch.optim as optim

from data.utils.sampling import dataset_sampler

task_optimizer = optim.SGD(model.parameters(), lr=1e-1)
meta_optimizer = optim.AdamW(model.parameters(), lr=1e-3)

lossfn = nn.CrossEntropyLoss()

k = 4 # Number of shots
n_inner = 5 # Number of inner loop updates
n_outer = 10 # Number of outer loop updates before meta update

In [6]:
for i in range(n_outer):
    
    # A single episode
    # Set optimizer outside to 0
    meta_optimizer.zero_grad()

    # Sample a task
    source_name = dataset_sampler(unified, sampling_method='sqrt')

    # Get task dataloaders
    trainloader, testloader = unified.get_dataloader(source_name, k=k, tokenizer=tokenizer, shuffle=True)

    # Re-initialize the softmax layer
    # Can be informed (e.g. ProtoMAML, LEOPARD, etc.)
    n_classes = len(unified.label_map[source_name].keys())
    clf_layer = SF_CLF(n_classes=n_classes, hidden_dims=hidden_dims)

    print(f"Enter the innerloop for {source_name}, N={n_classes}")
    with higher.innerloop_ctx(model, task_optimizer, copy_initial_weights=False) as (fmodel, diffopt):
        
        # MAML support sets
        for ii in range(n_inner):
            # Sample batch
            batch = next(trainloader)
            labels, text, attn_mask = batch
            
            # Use higher to perform inner loop updates
            y = model(text, attn_mask)
            logits = clf_layer(y)
            inner_loss = lossfn(logits, labels)
            diffopt.step(inner_loss)

            print(f"\tInner {ii} | Loss={inner_loss.detach().tolist()}", flush=True)
        
        # MAML query set
        batch = next(trainloader)
        labels, text, attn_mask = batch
        
        y = model(text, attn_mask)
        logits = clf_layer(y)
        outer_loss = lossfn(logits, labels)
        outer_loss.backward()

        print(f"\tOuter   | Loss={outer_loss.detach().tolist()}", flush=True)

    meta_optimizer.step()


Enter the innerloop for dailydialog, N=7
	Inner 0 | Loss=1.949546217918396
	Inner 1 | Loss=1.9495809078216553
	Inner 2 | Loss=1.9483307600021362
	Inner 3 | Loss=1.9500354528427124
	Inner 4 | Loss=1.948388934135437
	Outer   | Loss=1.9478541612625122
Enter the innerloop for tec, N=6
	Inner 0 | Loss=1.795456051826477
	Inner 1 | Loss=1.7922641038894653
	Inner 2 | Loss=1.7948722839355469
	Inner 3 | Loss=1.7933133840560913
	Inner 4 | Loss=1.7916769981384277
	Outer   | Loss=1.7909350395202637
Enter the innerloop for dailydialog, N=7
	Inner 0 | Loss=1.9476946592330933
	Inner 1 | Loss=1.9475170373916626
	Inner 2 | Loss=1.9461147785186768
	Inner 3 | Loss=1.9486711025238037
	Inner 4 | Loss=1.9497442245483398
	Outer   | Loss=1.9504438638687134
Enter the innerloop for dailydialog, N=7
	Inner 0 | Loss=1.959351897239685
	Inner 1 | Loss=1.9613603353500366
	Inner 2 | Loss=1.954295039176941
	Inner 3 | Loss=1.9514131546020508
	Inner 4 | Loss=1.951809287071228
	Outer   | Loss=1.954148292541504
Enter the i