This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Computational graph not retained for BERT #70

Closed
Nithin-Holla opened this issue Jul 15, 2020 · 8 comments

Nithin-Holla commented Jul 15, 2020

I'm trying to implement the first-order version of ProtoMAML (https://arxiv.org/pdf/1903.03096.pdf) for a sequence labelling task. If I use BERT as the encoder, I run into this error at the diffopt.step line: RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time. If I instead use an LSTM as the encoder, it runs successfully. Perhaps the graph is somehow purged because BERT is a large model?
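
For context, this is the generic autograd error raised when a second backward pass runs through a graph whose buffers have already been freed; a minimal, standalone illustration (unrelated to higher or BERT):

import torch

# Backpropagating twice through the same graph without retain_graph=True
# raises exactly this RuntimeError.
w = torch.randn(3, requires_grad=True)
loss = (w ** 2).sum()
loss.backward()
loss.backward()  # RuntimeError: Trying to backward through the graph a second time...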

Here is self-contained code to reproduce the issue; it occurs on both CPU and GPU. In the __main__ block you can set the encoder to 'bert' or 'lstm'. It requires HuggingFace's transformers library to run.

import higher
import torch

from torch import nn, optim
from transformers import BertModel, BertTokenizer
from torch.nn import functional as F


class BaseModel(nn.Module):

    def __init__(self, encoder, max_length, device):
        super(BaseModel, self).__init__()
        self.max_length = max_length
        self.device = device
        if encoder == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
            self.encoder.pooler.dense.weight.requires_grad = False
            self.encoder.pooler.dense.bias.requires_grad = False
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(batch_first=True, input_size=32, hidden_size=768)
        self.linear = nn.Linear(768, 192)
        self.to(self.device)

    def encode_text(self, text):
        if isinstance(self.encoder, BertModel):
            encode_result = self.tokenizer.batch_encode_plus(text, return_token_type_ids=False, max_length=self.max_length,
                                                             pad_to_max_length=True, return_tensors='pt')
            for key in encode_result:
                encode_result[key] = encode_result[key].to(self.device)
            return encode_result
        elif isinstance(self.encoder, nn.LSTM):
            return torch.randn((len(text), 32, 32), device=self.device)

    def forward(self, inputs):
        if isinstance(self.encoder, BertModel):
            out, _ = self.encoder(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        elif isinstance(self.encoder, nn.LSTM):
            out, _ = self.encoder(inputs)
        out = out[:, 1:-1, :]
        out = self.linear(out)
        return out


class ProtoMAML:

    def __init__(self, device, encoder):
        self.output_layer_weight = None
        self.output_layer_bias = None
        self.learner = BaseModel(encoder=encoder, max_length=32, device=device)
        self.inner_optimizer = optim.SGD([p for p in self.learner.parameters() if p.requires_grad], lr=0.001)
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_lr = 0.001
        self.device = device
        self.updates = 5

    def output_layer(self, input, weight, bias):
        return F.linear(input, self.output_layer_weight + weight, self.output_layer_bias + bias)

    def initialize_with_proto_weights(self, support_repr, support_label, n_classes):
        prototypes = self.build_prototypes(support_repr, support_label, n_classes)
        weight = 2 * prototypes
        bias = -torch.norm(prototypes, dim=1) ** 2
        self.output_layer_weight = torch.zeros_like(weight, requires_grad=True)
        self.output_layer_bias = torch.zeros_like(bias, requires_grad=True)
        return weight, bias

    def build_prototypes(self, data_repr, data_label, num_outputs):
        n_dim = data_repr.shape[2]
        data_repr = data_repr.view(-1, n_dim)
        data_label = data_label.view(-1)

        prototypes = torch.zeros((num_outputs, n_dim), device=self.device)

        for c in range(num_outputs):
            idx = torch.nonzero(data_label == c).view(-1)
            if idx.nelement() != 0:
                prototypes[c] = torch.mean(data_repr[idx], dim=0)

        return prototypes

    def initialize_output_layer(self, n_classes):
        self.output_layer_weight = torch.randn((n_classes, 768), requires_grad=True)
        self.output_layer_bias = torch.randn(n_classes, requires_grad=True)

    def train(self, support_text, labels, n_classes, n_iter):

        for itr in range(n_iter):
            print('Iteration ', itr)

            self.learner.zero_grad()

            self.initialize_output_layer(n_classes)
            x = self.learner.encode_text(support_text)
            y = labels.to(self.device)
            output_repr = self.learner(x)
            init_weights, init_bias = self.initialize_with_proto_weights(output_repr, y, n_classes)

            with higher.innerloop_ctx(self.learner, self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=False) as (flearner, diffopt):

                for i in range(self.updates):
                    output = flearner(x)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    loss = self.loss_fn(output, y)
                    output_weight_grad, output_bias_grad = torch.autograd.grad(loss, [self.output_layer_weight, self.output_layer_bias],
                                                                               retain_graph=True)
                    self.output_layer_weight = self.output_layer_weight - self.output_lr * output_weight_grad
                    self.output_layer_bias = self.output_layer_bias - self.output_lr * output_bias_grad
                    diffopt.step(loss)


if __name__ == '__main__':
    
    encoder = 'bert'  # or 'lstm'

    support_text = [['This is a support text']] * 64
    labels = torch.randint(0, 10, (64 * 30, ))
    n_classes = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ProtoMAML(device=device, encoder=encoder)
    model.train(support_text, labels, n_classes, n_iter=10)

egrefen commented Jul 15, 2020

Can you briefly summarize the proto-MAML approach? Also, why is track_higher_grads=False? That flag is for "test mode", where you don't want higher-order gradients and just want to avoid writing the test loop from scratch.
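
For example, with track_higher_grads=False a test-time adaptation loop would look roughly like the following sketch (model, opt, loss_fn, n_inner_steps and the support/query tensors are placeholders, not code from this issue):

# Evaluation-time adaptation: the inner loop adapts fmodel, but no
# higher-order graph is kept, so nothing flows back into the original parameters.
with higher.innerloop_ctx(model, opt, track_higher_grads=False) as (fmodel, diffopt):
    for _ in range(n_inner_steps):
        diffopt.step(loss_fn(fmodel(x_support), y_support))
    eval_loss = loss_fn(fmodel(x_query), y_query)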


Nithin-Holla commented Jul 15, 2020

In ProtoMAML, the network representations are first computed for all the data points in the support set. Then a "prototype" is generated for each class as the mean representation of the data points belonging to that class. This prototype is used to initialize the final classification (softmax) layer: the weight is 2 * prototype and the bias is the negative squared norm of the prototype. The outer-loop backprop requires gradients from two parts: the inner-loop-updated model as well as the original classification-layer initialization.
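
For reference, the reason this initialization is equivalent to the prototypical classifier: the negative squared Euclidean distance expands as -||x - c||^2 = 2 c·x - ||c||^2 - ||x||^2, and the last term is the same for every class, so a linear layer with weight 2c and bias -||c||^2 yields the same softmax. A small sketch with made-up shapes:

import torch
from torch.nn import functional as F

torch.manual_seed(0)
x = torch.randn(4, 16)        # query representations (made-up shapes)
protos = torch.randn(3, 16)   # one prototype per class

# Linear layer induced by the prototypes: weight = 2 * c_k, bias = -||c_k||^2
logits = F.linear(x, 2 * protos, -protos.norm(dim=1) ** 2)

# Negative squared Euclidean distance to each prototype
dists = -((x.unsqueeze(1) - protos.unsqueeze(0)) ** 2).sum(dim=2)

# The two differ only by the class-independent term -||x||^2,
# so they produce identical softmax probabilities.
assert torch.allclose(F.softmax(logits, dim=1), F.softmax(dists, dim=1), atol=1e-5)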

I used track_higher_grads=False because I use the first-order version and don't need the second-order gradients.

egrefen added the wontfix label ("This will not be worked on") on Jul 15, 2020

egrefen commented Jul 15, 2020

It's not clear to me if you need higher at all, then. The entire purpose of the library is to allow you to do the second order backprop. It sounds to me like you should just be implementing this in vanilla pytorch. You don't actually need to patch modules or use differentiable optimizers for the first order case.
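
A minimal sketch of that first-order setup in plain PyTorch (generic model, purely illustrative names and hyperparameters):

import copy
import torch
from torch import optim

def fomaml_step(model, loss_fn, support_batch, query_batch,
                meta_optimizer, inner_lr=1e-3, inner_steps=5):
    """One first-order MAML step: adapt a copy, then reuse its gradients."""
    adapted = copy.deepcopy(model)
    inner_opt = optim.SGD(adapted.parameters(), lr=inner_lr)

    # Inner loop: plain SGD on the support set; no higher-order terms are kept.
    x, y = support_batch
    for _ in range(inner_steps):
        inner_opt.zero_grad()
        loss_fn(adapted(x), y).backward()
        inner_opt.step()

    # Outer loss on the query set, evaluated with the adapted weights.
    xq, yq = query_batch
    query_loss = loss_fn(adapted(xq), yq)
    grads = torch.autograd.grad(query_loss, tuple(adapted.parameters()))

    # First-order approximation: copy the adapted-model gradients onto the
    # original parameters and take a meta step with them.
    meta_optimizer.zero_grad()
    for p, g in zip(model.parameters(), grads):
        p.grad = g.detach()
    meta_optimizer.step()
    return query_loss.item()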

Anyway, the issue you're having comes from track_higher_grads=False, so I'll close the issue.

egrefen closed this as completed on Jul 15, 2020

Nithin-Holla commented Jul 15, 2020

@egrefen The issue also occurs when track_higher_grads=True. I have modified the code to only fine-tune the top layer of BERT so that it doesn't run out of memory:

import higher
import torch

from torch import nn, optim
from transformers import BertModel, BertTokenizer
from torch.nn import functional as F


class BaseModel(nn.Module):

    def __init__(self, encoder, max_length, device):
        super(BaseModel, self).__init__()
        self.max_length = max_length
        self.device = device
        if encoder == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
            self.encoder.pooler.dense.weight.requires_grad = False
            self.encoder.pooler.dense.bias.requires_grad = False
            self.n_tunable_layers = 1
            tunable_layers = {str(l) for l in range(12 - self.n_tunable_layers, 12)}
            for name, param in self.encoder.named_parameters():
                if not set.intersection(set(name.split('.')), tunable_layers):
                    param.requires_grad = False
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(batch_first=True, input_size=32, hidden_size=768)
        self.linear = nn.Linear(768, 192)
        self.to(self.device)

    def encode_text(self, text):
        if isinstance(self.encoder, BertModel):
            encode_result = self.tokenizer.batch_encode_plus(text, return_token_type_ids=False, max_length=self.max_length,
                                                             pad_to_max_length=True, return_tensors='pt')
            for key in encode_result:
                encode_result[key] = encode_result[key].to(self.device)
            return encode_result
        elif isinstance(self.encoder, nn.LSTM):
            return torch.randn((len(text), 32, 32), device=self.device)

    def forward(self, inputs):
        if isinstance(self.encoder, BertModel):
            out, _ = self.encoder(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        elif isinstance(self.encoder, nn.LSTM):
            out, _ = self.encoder(inputs)
        out = out[:, 1:-1, :]
        out = self.linear(out)
        return out


class ProtoMAML:

    def __init__(self, device, encoder):
        self.output_layer_weight = None
        self.output_layer_bias = None
        self.learner = BaseModel(encoder=encoder, max_length=32, device=device)
        self.inner_optimizer = optim.SGD([p for p in self.learner.parameters() if p.requires_grad], lr=0.001)
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_lr = 0.001
        self.device = device
        self.updates = 5

    def output_layer(self, input, weight, bias):
        return F.linear(input, self.output_layer_weight + weight, self.output_layer_bias + bias)

    def initialize_with_proto_weights(self, support_repr, support_label, n_classes):
        prototypes = self.build_prototypes(support_repr, support_label, n_classes)
        weight = 2 * prototypes
        bias = -torch.norm(prototypes, dim=1) ** 2
        self.output_layer_weight = torch.zeros_like(weight, requires_grad=True)
        self.output_layer_bias = torch.zeros_like(bias, requires_grad=True)
        return weight, bias

    def build_prototypes(self, data_repr, data_label, num_outputs):
        n_dim = data_repr.shape[2]
        data_repr = data_repr.view(-1, n_dim)
        data_label = data_label.view(-1)

        prototypes = torch.zeros((num_outputs, n_dim), device=self.device)

        for c in range(num_outputs):
            idx = torch.nonzero(data_label == c).view(-1)
            if idx.nelement() != 0:
                prototypes[c] = torch.mean(data_repr[idx], dim=0)

        return prototypes

    def initialize_output_layer(self, n_classes):
        self.output_layer_weight = torch.randn((n_classes, 768), requires_grad=True)
        self.output_layer_bias = torch.randn(n_classes, requires_grad=True)

    def train(self, support_text, labels, n_classes, n_iter):

        for itr in range(n_iter):
            print('Iteration ', itr)

            self.learner.zero_grad()

            self.initialize_output_layer(n_classes)
            x = self.learner.encode_text(support_text)
            y = labels.to(self.device)
            output_repr = self.learner(x)
            init_weights, init_bias = self.initialize_with_proto_weights(output_repr, y, n_classes)

            with higher.innerloop_ctx(self.learner, self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=True) as (flearner, diffopt):

                for i in range(self.updates):
                    output = flearner(x)
                    output = self.output_layer(output, init_weights, init_bias)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    loss = self.loss_fn(output, y)
                    output_weight_grad, output_bias_grad = torch.autograd.grad(loss, [self.output_layer_weight, self.output_layer_bias],
                                                                               retain_graph=True)
                    self.output_layer_weight = self.output_layer_weight - self.output_lr * output_weight_grad
                    self.output_layer_bias = self.output_layer_bias - self.output_lr * output_bias_grad
                    diffopt.step(loss)


if __name__ == '__main__':

    encoder = 'bert'  # or 'lstm'

    support_text = [['This is a support text']] * 64
    labels = torch.randint(0, 10, (64 * 30, ))
    n_classes = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ProtoMAML(device=device, encoder=encoder)
    model.train(support_text, labels, n_classes, n_iter=10)

egrefen removed the wontfix label ("This will not be worked on") on Jul 16, 2020

egrefen commented Jul 16, 2020

Okay, I will reopen the issue, but this code is quite complex. Can you please try to produce a minimal example, with as few external dependencies as possible, that reproduces the error? Does it really only crop up with HuggingFace's transformers? It's really time-consuming to dig through a lot of code, understand it, and figure out where the bug is. I'm happy to do that eventually if it's the only way, but I don't have a lot of time right now, so a smaller repro would really speed things up.

egrefen reopened this on Jul 16, 2020

Nithin-Holla commented Jul 16, 2020

I removed components one by one until I had the smallest version that still reproduces the issue. It's no longer ProtoMAML-specific; it's just higher on top of BERT for sequence labelling. If you set encoder = 'lstm' in __main__, it runs fine. Here is the updated code:

import higher
import torch

from torch import nn, optim
from transformers import BertModel, BertTokenizer


class BaseModel(nn.Module):

    def __init__(self, encoder, max_length, device):
        super(BaseModel, self).__init__()
        self.max_length = max_length
        self.device = device
        if encoder == 'bert':
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.encoder = BertModel.from_pretrained('bert-base-uncased')
            self.encoder.pooler.dense.weight.requires_grad = False
            self.encoder.pooler.dense.bias.requires_grad = False
            self.n_tunable_layers = 1
            tunable_layers = {str(l) for l in range(12 - self.n_tunable_layers, 12)}
            for name, param in self.encoder.named_parameters():
                if not set.intersection(set(name.split('.')), tunable_layers):
                    param.requires_grad = False
        elif encoder == 'lstm':
            self.encoder = nn.LSTM(batch_first=True, input_size=32, hidden_size=768)
        self.linear = nn.Linear(768, 10)
        self.to(self.device)

    def encode_text(self, text):
        if isinstance(self.encoder, BertModel):
            encode_result = self.tokenizer.batch_encode_plus(text, return_token_type_ids=False, max_length=self.max_length,
                                                             pad_to_max_length=True, return_tensors='pt')
            for key in encode_result:
                encode_result[key] = encode_result[key].to(self.device)
            return encode_result
        elif isinstance(self.encoder, nn.LSTM):
            return torch.randn((len(text), 32, 32), device=self.device)

    def forward(self, inputs):
        if isinstance(self.encoder, BertModel):
            out, _ = self.encoder(inputs['input_ids'], attention_mask=inputs['attention_mask'])
        elif isinstance(self.encoder, nn.LSTM):
            out, _ = self.encoder(inputs)
        out = out[:, 1:-1, :]
        out = self.linear(out)
        return out


class ProtoMAML:

    def __init__(self, device, encoder):
        self.output_layer_weight = None
        self.output_layer_bias = None
        self.learner = BaseModel(encoder=encoder, max_length=32, device=device)
        self.inner_optimizer = optim.SGD([p for p in self.learner.parameters() if p.requires_grad], lr=0.001)
        self.loss_fn = nn.CrossEntropyLoss()
        self.output_lr = 0.001
        self.device = device
        self.updates = 5

    def train(self, support_text, labels, n_iter):

        for itr in range(n_iter):
            print('Iteration ', itr)
            self.learner.zero_grad()
            x = self.learner.encode_text(support_text)
            y = labels.to(self.device)

            with higher.innerloop_ctx(self.learner, self.inner_optimizer,
                                      copy_initial_weights=False,
                                      track_higher_grads=True) as (flearner, diffopt):
                for i in range(self.updates):
                    output = flearner(x)
                    output = output.view(output.size()[0] * output.size()[1], -1)
                    loss = self.loss_fn(output, y)
                    diffopt.step(loss)


if __name__ == '__main__':

    encoder = 'bert'  # or 'lstm'

    support_text = [['This is a support text']] * 64
    labels = torch.randint(0, 10, (64 * 30, ))
    n_classes = 10
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ProtoMAML(device=device, encoder=encoder)
    model.train(support_text, labels, n_iter=10)


egrefen commented Jul 23, 2020

Hello @Nithin-Holla. Thanks for the pared-down example, although I wouldn't call it minimal. It would have been helpful to get some tracebacks/outputs for when things fail (and what should have happened instead), plus sample output for when things don't fail (e.g. when using the LSTM). It also would have been useful to get a requirements.txt for repro (e.g. pip freeze > requirements.txt).

I've tried to set up an environment to reproduce the error (since there isn't a clear statement of what it actually is) and I get the following failure:

Traceback (most recent call last):
  File "bert.py", line 88, in <module>
    model.train(support_text, labels, n_iter=10)
  File "bert.py", line 66, in train
    x = self.learner.encode_text(support_text)
  File "bert.py", line 32, in encode_text
    pad_to_max_length=True, return_tensors='pt')
  File "/usr/local/anaconda3/lib/python3.7/site-packages/transformers/tokenization_utils_base.py", line 1832, in batch_encode_plus
    **kwargs,
  File "/usr/local/anaconda3/lib/python3.7/site-packages/transformers/tokenization_utils.py", line 534, in _batch_encode_plus
    ids, pair_ids = ids_or_pair_ids
ValueError: not enough values to unpack (expected 2, got 1)

I'm not sure this corresponds to the sort of error you're getting. Can you please help elucidate this matter?

Nithin-Holla commented

This error comes from the new transformers release, which I didn't have before. Anyway, I bumped some of my packages and the original issue is now gone. I guess it was caused by one of the lower-level libraries and not by higher. You can close this for now. Thanks!

egrefen closed this as completed on Jul 27, 2020