In [1]:
from copy import deepcopy
import pickle

import psutil
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from transformers import AutoTokenizer, get_constant_schedule_with_warmup
from torch.utils.tensorboard import SummaryWriter

from models.seqtransformer import SeqTransformer
from data.unified_emotion_numpy import unified_emotion
from data.utils.tokenizer import manual_tokenizer, specials
from data.utils.data_loader_numpy import StratifiedLoader
from data.utils.sampling import dataset_sampler
from utils.metrics import logging_metrics

config = {'encoder_name': 'bert-base-cased',
'nu': 12,
'hidden_dims': [512, 256],
'act_fn': 'tanh',
'include': ['grounded_emotions'],
'max_support_size': 32,
'n_inner': 5,
'warmup_steps': 25,
'max_episodes': 250,
'meta_lr': 5e-5,
'inner_lr': 1e-3,
'output_lr': 1e-1,
'checkpoint_path': 'ProtoMAML_Rebuild',
'version': 'overfit_test',
'gpu': False
}

In [2]:
#######################
# Logging Directories #
#######################
log_dir = os.path.join(config['checkpoint_path'], config['version'])

os.makedirs(log_dir, exist_ok=True)
os.makedirs(os.path.join(log_dir, 'tensorboard'), exist_ok=True)
os.makedirs(os.path.join(log_dir, 'checkpoint'), exist_ok=True)
print(f"Saving models and logs to {log_dir}")

with open(os.path.join(log_dir, 'checkpoint', 'hparams.pickle'), 'wb') as file:
    pickle.dump(config, file)

## Initialization
# Device
device = torch.device('cuda' if (torch.cuda.is_available() and config['gpu']) else 'cpu')

# Build the tensorboard writer
writer = SummaryWriter(os.path.join(log_dir, 'tensorboard'))

###################
# Load in dataset #
###################
dataset = unified_emotion("./data/datasets/unified-dataset.jsonl",
                            include=config['include'])
dataset.prep(text_tokenizer=manual_tokenizer)

####################
# Init models etc. #
####################

model_init = SeqTransformer(config)
tokenizer = AutoTokenizer.from_pretrained(config['encoder_name'])

tokenizer.add_special_tokens({'additional_special_tokens': specials()})
model_init.encoder.model.resize_token_embeddings(len(tokenizer.vocab))

meta_optimizer = optim.SGD(model_init.parameters(), lr=config['meta_lr'])
meta_scheduler = get_constant_schedule_with_warmup(meta_optimizer, config['warmup_steps'])

model_init = model_init.to(device)

loss_fn = nn.CrossEntropyLoss()

#######################
# Overfit query batch #
#######################
task = dataset_sampler(dataset, sampling_method='sqrt')

datasubset = dataset.datasets[task]['train']

dataloader = StratifiedLoader(datasubset, k=16, shuffle=True, max_batch_size=config['max_support_size'], tokenizer=tokenizer, device=device, classes_subset=False)

_, _, query_labels, query_input = next(dataloader)
overfit_queries = [(query_labels, query_input)]

for episode in range(config['max_episodes']):
    queries = []
    
    #################
    # Sample a task #
    #################
    task = dataset_sampler(dataset, sampling_method='sqrt')

    datasubset = dataset.datasets[task]['train']

    dataloader = StratifiedLoader(datasubset, k=16, shuffle=True, max_batch_size=config['max_support_size'], tokenizer=tokenizer, device=device, classes_subset=False)


    #####################
    # Create model_task #
    #####################

    model_task = deepcopy(model_init)
    model_task_optimizer = optim.SGD(model_task.parameters(), lr=config['inner_lr'])
    model_task.train()
    model_task.zero_grad()

    #######################
    # Generate prototypes #
    #######################

    support_labels, support_input, query_labels, query_input = next(dataloader)
    queries.append((query_labels, query_input))

    model_init.train()
    
    y = model_init(support_input)

    labs = torch.sort(torch.unique(support_labels))[0]
    prototypes = torch.stack([torch.mean(y[support_labels==c], dim=0) for c in labs])

    W_init = 2 * prototypes
    b_init = -torch.norm(prototypes, p=2, dim=1)

    W_task, b_task = W_init.detach(), b_init.detach()
    W_task.requires_grad, b_task.requires_grad = True, True

    #################
    # Adapt to data #
    #################
    for _ in range(config['n_inner']):
    
        support_labels, support_input, query_labels, query_input = next(dataloader)
        queries.append((query_labels, query_input))

        y = model_task(support_input)
        logits = F.linear(y, W_task, b_task)

        inner_loss = loss_fn(logits, support_labels)

        W_task_grad, b_task_grad = torch.autograd.grad(inner_loss, [W_task, b_task], retain_graph=True)

        inner_loss.backward()

        model_task_optimizer.step()
        W_task = W_task - 1e-1 * W_task_grad
        b_task = b_task - 1e-1 * b_task_grad
    

    #########################
    # Validate on query set #
    #########################
    for module in model_task.modules():
        if isinstance(module, nn.Dropout):
            module.eval()

    W_task = W_init + (W_task - W_init).detach()
    b_task = b_init + (b_task - b_init).detach()
    
    queries = overfit_queries
    outer_loss_agg, acc_agg, f1_agg = 0.0, 0.0, 0.0
    for i, batch in enumerate(queries):

        query_labels, query_input = batch

        y = model_task(query_input)
        logits = F.linear(y, W_task, b_task)

        outer_loss = loss_fn(logits, query_labels)

        model_task_params = [param for param in model_task.parameters() if param.requires_grad]
        model_task_grads = torch.autograd.grad(outer_loss, model_task_params, retain_graph=True)

        model_init_params = [param for param in model_init.parameters() if param.requires_grad]
        model_init_grads = torch.autograd.grad(outer_loss, model_init_params, retain_graph=True if i != len(queries)-1 else False)

        model_init_grads = model_init_grads + model_task_grads

        for param, grad in zip(model_init_params, model_init_grads):
            if param.grad != None:
                param.grad += grad.detach()
            else:
                param.grad = grad.detach()

        with torch.no_grad():
            mets = logging_metrics(logits.detach().cpu(), query_labels.detach().cpu())
            outer_loss_agg += outer_loss.detach().cpu().item()
            acc_agg += mets['acc']
            f1_agg += mets['f1']

            #print(torch.norm(W_task - W_init, p=2), sum((x - y).abs().sum() for x, y in zip(model_init.state_dict().values(), model_task.state_dict().values())))
        
    outer_loss_agg = outer_loss_agg / len(queries)
    acc_agg = acc_agg / len(queries)
    f1_agg = f1_agg / len(queries)

    model_init_params = [param for param in model_init.parameters() if param.requires_grad]
    for param in model_init_params:
        param.grad = param.grad / len(queries)

    meta_optimizer.step()
    meta_scheduler.step()
    meta_optimizer.zero_grad()

    print("Train | Episode {:} | Task {:<20s}, N={:} | Loss {:.4E}, Acc {:5.2f}, F1 {:5.2f} | Mem {:5.2f} GB".format(\
                    episode, task, dataloader.n_classes,
                    outer_loss_agg, acc_agg, f1_agg,
                    psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3))

    writer.add_scalar('Loss/Train', outer_loss_agg, episode)
    writer.add_scalar('Accuracy/Train', acc_agg, episode)
    writer.add_scalar('F1/Train', f1_agg, episode)

    writer.flush()


Saving models and logs to ProtoMAML_Rebuild\overfit_test
Train | Episode 0 | Task grounded_emotions   , N=2 | Loss 6.9240E-01, Acc  0.44, F1  0.42 | Mem  1.74 GB
Train | Episode 1 | Task grounded_emotions   , N=2 | Loss 6.9322E-01, Acc  0.56, F1  0.56 | Mem  1.74 GB


KeyboardInterrupt: 

In [13]:
dataset = unified_emotion("./data/datasets/unified-dataset.jsonl",
                            include=['ssec'])
dataset.prep(text_tokenizer=manual_tokenizer)

In [21]:


task = dataset_sampler(dataset, sampling_method='sqrt')

datasubset = dataset.datasets[task]['train']

dataloader = StratifiedLoader(datasubset, k=16, shuffle=True, max_batch_size=config['max_support_size'], tokenizer=tokenizer, device=device, classes_subset=False)

for _ in range(100):
    support_labels, support_input, query_labels, query_input = next(dataloader)
    print(torch.sort(torch.unique(support_labels))[0])

tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0, 1, 2, 3, 4, 5, 6])
tensor([0,

In [24]:
y = model_init(support_input)

labs = torch.sort(torch.unique(support_labels))[0]

tensor([[-0.1871, -0.1271, -0.2001,  ...,  0.1974, -0.2980,  0.3948],
        [-0.1745, -0.1166, -0.1963,  ...,  0.1862, -0.2828,  0.3881],
        [-0.2274, -0.1126, -0.2577,  ...,  0.0900, -0.3041,  0.3917],
        ...,
        [-0.1750, -0.1205, -0.1878,  ...,  0.2374, -0.2931,  0.3842],
        [-0.1970, -0.1047, -0.2193,  ...,  0.1810, -0.2886,  0.3829],
        [-0.1937, -0.1399, -0.2136,  ...,  0.1904, -0.2758,  0.3870]],
       grad_fn=<StackBackward>)

In [36]:
proto = torch.stack([torch.mean(y[support_labels==c], dim=0) for c in labs])
W = 2 * proto
b = - torch.norm(proto, dim=1, p=2)

F.linear(y, W, b).size()

torch.Size([28, 7])